import torch
import math

from .model import QwenImageTransformer2DModel
from .model import QwenImageTransformerBlock


class QwenImageFunControlBlock(QwenImageTransformerBlock):
    def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
        super().__init__(
            dim=dim,
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            dtype=dtype,
            device=device,
            operations=operations,
        )
        self.has_before_proj = has_before_proj
        if has_before_proj:
            self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
        self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)


class QwenImageFunControlNetModel(torch.nn.Module):
    def __init__(
        self,
        control_in_features=132,
        inner_dim=3072,
        num_attention_heads=24,
        attention_head_dim=128,
        num_control_blocks=5,
        main_model_double=60,
        injection_layers=(0, 12, 24, 36, 48),
        dtype=None,
        device=None,
        operations=None,
    ):
        super().__init__()
        self.dtype = dtype
        self.main_model_double = main_model_double
        self.injection_layers = tuple(injection_layers)
        # Keep base hint scaling at 1.0 so user-facing strength behaves similarly
        # to the reference Gen2/VideoX implementation around strength=1.
        self.hint_scale = 1.0

        self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)

        self.control_blocks = torch.nn.ModuleList([])
        for i in range(num_control_blocks):
            self.control_blocks.append(
                QwenImageFunControlBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    has_before_proj=(i == 0),
                    dtype=dtype,
                    device=device,
                    operations=operations,
                )
            )

    def _process_hint_tokens(self, hint):
        if hint is None:
            return None

        if hint.ndim == 4:
            hint = hint.unsqueeze(2)

        # Fun checkpoints are trained with 33 latent channels before 2x2 packing:
        # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
        # Default behavior (no inpaint input in stock Apply ControlNet) should use
        # zeros for mask/inpaint branches, matching VideoX fallback semantics.
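        # The packing below pads H/W to even sizes, folds each 2x2 spatial patch
        # into the channel dim (c -> c * 4), then pads or trims the feature dim to
        # whatever width control_img_in expects.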
        expected_c = self.control_img_in.weight.shape[1] // 4
        if hint.shape[1] == 16 and expected_c == 33:
            zeros_mask = torch.zeros_like(hint[:, :1])
            zeros_inpaint = torch.zeros_like(hint)
            hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)

        bs, c, t, h, w = hint.shape
        hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))

        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(
            orig_shape[0],
            orig_shape[1],
            orig_shape[-3],
            orig_shape[-2] // 2,
            2,
            orig_shape[-1] // 2,
            2,
        )
        hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
        hidden_states = hidden_states.reshape(
            bs,
            t * ((h + 1) // 2) * ((w + 1) // 2),
            c * 4,
        )

        expected_in = self.control_img_in.weight.shape[1]
        cur_in = hidden_states.shape[-1]
        if cur_in < expected_in:
            pad = torch.zeros(
                (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            )
            hidden_states = torch.cat([hidden_states, pad], dim=-1)
        elif cur_in > expected_in:
            hidden_states = hidden_states[:, :, :expected_in]

        return hidden_states

    def forward(
        self,
        x,
        timesteps,
        context,
        attention_mask=None,
        guidance: torch.Tensor = None,
        hint=None,
        transformer_options={},
        base_model=None,
        **kwargs,
    ):
        if base_model is None:
            raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")

        # Keep the attention mask disabled inside the Fun control blocks to mirror
        # VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
        encoder_hidden_states_mask = None

        hidden_states, img_ids, _ = base_model.process_img(x)

        hint_tokens = self._process_hint_tokens(hint)
        if hint_tokens is None:
            raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")

        if hint_tokens.shape[1] != hidden_states.shape[1]:
            max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
            hint_tokens = hint_tokens[:, :max_tokens]
            hidden_states = hidden_states[:, :max_tokens]
            img_ids = img_ids[:, :max_tokens]

        txt_start = round(
            max(
                ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
                ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
            )
        )
        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()

        hidden_states = base_model.img_in(hidden_states)
        encoder_hidden_states = base_model.txt_norm(context)
        encoder_hidden_states = base_model.txt_in(encoder_hidden_states)

        if guidance is not None:
            guidance = guidance * 1000

        temb = (
            base_model.time_text_embed(timesteps, hidden_states)
            if guidance is None
            else base_model.time_text_embed(timesteps, guidance, hidden_states)
        )

        c = self.control_img_in(hint_tokens)

        for i, block in enumerate(self.control_blocks):
            if i == 0:
                c_in = block.before_proj(c) + hidden_states
                all_c = []
            else:
                all_c = list(torch.unbind(c, dim=0))
                c_in = all_c.pop(-1)

            encoder_hidden_states, c_out = block(
                hidden_states=c_in,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
                transformer_options=transformer_options,
            )

            c_skip = block.after_proj(c_out) * self.hint_scale
            all_c += [c_skip, c_out]
            c = torch.stack(all_c, dim=0)

        # The last entry is the running control hidden state; everything before it
        # is an after_proj residual destined for a base model block.
        hints = torch.unbind(c, dim=0)[:-1]

        controlnet_block_samples = [None] * self.main_model_double
        for local_idx, base_idx in enumerate(self.injection_layers):
            if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
                controlnet_block_samples[base_idx] = hints[local_idx]

        return {"input": controlnet_block_samples}


class QwenImageControlNetModel(QwenImageTransformer2DModel):
    def __init__(
        self,
        extra_condition_channels=0,
        dtype=None,
        device=None,
        operations=None,
        **kwargs
    ):
        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
        self.main_model_double = 60

        # controlnet_blocks
        self.controlnet_blocks = torch.nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(operations.Linear(self.inner_dim, self.inner_dim, device=device, dtype=dtype))

        self.controlnet_x_embedder = operations.Linear(self.in_channels + extra_condition_channels, self.inner_dim, device=device, dtype=dtype)

    def forward(
        self,
        x,
        timesteps,
        context,
        attention_mask=None,
        guidance: torch.Tensor = None,
        ref_latents=None,
        hint=None,
        transformer_options={},
        **kwargs
    ):
        timestep = timesteps
        encoder_hidden_states = context
        encoder_hidden_states_mask = attention_mask

        hidden_states, img_ids, orig_shape = self.process_img(x)
        hint, _, _ = self.process_img(hint)

        txt_start = round(
            max(
                ((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2,
                ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2,
            )
        )
        txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
        ids = torch.cat((txt_ids, img_ids), dim=1)
        image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
        del ids, txt_ids, img_ids

        hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        if guidance is not None:
            guidance = guidance * 1000

        temb = (
            self.time_text_embed(timestep, hidden_states)
            if guidance is None
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

        # Repeat each controlnet block output so the (shorter) stack of controlnet
        # blocks still yields one residual per double block of the main model.
        repeat = math.ceil(self.main_model_double / len(self.controlnet_blocks))

        controlnet_block_samples = ()
        for i, block in enumerate(self.transformer_blocks):
            encoder_hidden_states, hidden_states = block(
                hidden_states=hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_hidden_states_mask=encoder_hidden_states_mask,
                temb=temb,
                image_rotary_emb=image_rotary_emb,
            )

            controlnet_block_samples = controlnet_block_samples + (self.controlnet_blocks[i](hidden_states),) * repeat

        return {"input": controlnet_block_samples[:self.main_model_double]}
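

# Minimal, self-contained shape-check sketch, for illustration only: it is not
# part of the module API and assumes nothing beyond torch. It repeats the 2x2
# packing used in QwenImageFunControlNetModel._process_hint_tokens to show how
# a 33-channel latent yields the 132 features expected by control_img_in, and
# how injection_layers maps the 5 control residuals onto a 60-block main model.
if __name__ == "__main__":
    bs, c, t, h, w = 1, 33, 1, 64, 64
    hint = torch.zeros(bs, c, t, h, w)

    # Fold each 2x2 spatial patch into the channel dimension (same view/permute/
    # reshape sequence as _process_hint_tokens, with h and w already even).
    packed = hint.view(bs, c, t, h // 2, 2, w // 2, 2)
    packed = packed.permute(0, 2, 3, 5, 1, 4, 6)
    packed = packed.reshape(bs, t * (h // 2) * (w // 2), c * 4)
    assert packed.shape == (1, 1024, 132)

    # Injection mapping: control block i feeds base block injection_layers[i];
    # all other base blocks receive no control residual.
    injection_layers = (0, 12, 24, 36, 48)
    samples = [None] * 60
    for local_idx, base_idx in enumerate(injection_layers):
        samples[base_idx] = f"hint_{local_idx}"
    assert samples[12] == "hint_1" and samples[13] is None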