import torch
from torch import nn

from .model import JointTransformerBlock


class ZImageControlTransformerBlock(JointTransformerBlock):
    # Control-branch variant of JointTransformerBlock: block 0 gets a
    # before_proj that mixes the denoiser hidden states into the control
    # stream, and every block gets an after_proj that emits the per-layer
    # control residual.
    def __init__(
        self,
        layer_id: int,
        dim: int,
        n_heads: int,
        n_kv_heads: int,
        multiple_of: int,
        ffn_dim_multiplier: float,
        norm_eps: float,
        qk_norm: bool,
        modulation=True,
        block_id=0,
        operation_settings=None,
    ):
        super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
        self.block_id = block_id
        if block_id == 0:
            self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))

    def forward(self, c, x, **kwargs):
        if self.block_id == 0:
            c = self.before_proj(c) + x  # inject denoiser hidden states into the control stream
        c = super().forward(c, **kwargs)
        c_skip = self.after_proj(c)  # residual handed back to the base model
        return c_skip, c


class ZImage_Control(torch.nn.Module):
    def __init__(
        self,
        dim: int = 3840,
        n_heads: int = 30,
        n_kv_heads: int = 30,
        multiple_of: int = 256,
        ffn_dim_multiplier: float = (8.0 / 3.0),
        norm_eps: float = 1e-5,
        qk_norm: bool = True,
        n_control_layers=6,
        control_in_dim=16,
        additional_in_dim=0,
        broken=False,
        refiner_control=False,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        super().__init__()
        operation_settings = {"operations": operations, "device": device, "dtype": dtype}
        self.broken = broken
        self.additional_in_dim = additional_in_dim
        self.control_in_dim = control_in_dim
        n_refiner_layers = 2
        self.n_control_layers = n_control_layers
        self.control_layers = nn.ModuleList(
            [
                ZImageControlTransformerBlock(
                    i,
                    dim,
                    n_heads,
                    n_kv_heads,
                    multiple_of,
                    ffn_dim_multiplier,
                    norm_eps,
                    qk_norm,
                    block_id=i,
                    operation_settings=operation_settings,
                )
                for i in range(self.n_control_layers)
            ]
        )

        # Patch embedder for the control latent, keyed by "{patch_size}-{f_patch_size}".
        all_x_embedder = {}
        patch_size = 2
        f_patch_size = 1
        x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
        all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
        self.refiner_control = refiner_control

        self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
        if self.refiner_control:
            self.control_noise_refiner = nn.ModuleList(
                [
                    ZImageControlTransformerBlock(
                        layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        multiple_of,
                        ffn_dim_multiplier,
                        norm_eps,
                        qk_norm,
                        block_id=layer_id,
                        operation_settings=operation_settings,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )
        else:
            self.control_noise_refiner = nn.ModuleList(
                [
                    JointTransformerBlock(
                        layer_id,
                        dim,
                        n_heads,
                        n_kv_heads,
                        multiple_of,
                        ffn_dim_multiplier,
                        norm_eps,
                        qk_norm,
                        modulation=True,
                        z_image_modulation=True,
                        operation_settings=operation_settings,
                    )
                    for layer_id in range(n_refiner_layers)
                ]
            )

    def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
        # cap_feats is unused here but kept for interface parity with the base model.
        patch_size = 2
        f_patch_size = 1
        pH = pW = patch_size
        B, C, H, W = control_context.shape
        # Patchify the (B, C, H, W) control latent into (B, (H//2)*(W//2), 2*2*C)
        # tokens, then project them to the model dim.
        control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](
            control_context.view(B, C, H // pH, pH, W // pW, pW)
            .permute(0, 2, 4, 3, 5, 1)
            .flatten(3)
            .flatten(1, 2)
        )

        x_attn_mask = None

        if not self.refiner_control:
            for layer in self.control_noise_refiner:
                control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)

        return control_context

    def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
        if self.refiner_control:
            if self.broken:
                # "broken" wiring variant: layer 0 runs only the first control
                # block; any later layer_id runs all remaining control blocks in
                # a single pass and returns the first skip output.
                if layer_id == 0:
                    return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
                if layer_id > 0:
                    out = None
                    for i in range(1, len(self.control_layers)):
                        o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
                        if out is None:
                            out = o
                    return (out, control_context)
            else:
                return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
        else:
            return (None, control_context)

    def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
        return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
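

# A minimal, self-contained sketch (not part of the original module) of the
# patchify reshape that ZImage_Control.forward feeds into the x_embedder: a
# (B, C, H, W) control latent becomes (B, (H//pH)*(W//pW), pH*pW*C) tokens,
# one per patch. The _demo_ name is illustrative only.
def _demo_patchify(control_context, patch_size=2):
    B, C, H, W = control_context.shape
    pH = pW = patch_size
    tokens = (
        control_context.view(B, C, H // pH, pH, W // pW, pW)
        .permute(0, 2, 4, 3, 5, 1)  # (B, H//pH, W//pW, pH, pW, C)
        .flatten(3)                 # (B, H//pH, W//pW, pH*pW*C)
        .flatten(1, 2)              # (B, num_patches, pH*pW*C)
    )
    return tokens
    # Example: torch.randn(1, 16, 8, 8) -> tokens of shape (1, 16, 64),
    # which the "2-1" embedder Linear then projects from 64 to `dim`.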