From 28eaab608bc34c4e3b1886b1bddbb429453249d8 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 18 Dec 2025 17:21:14 -0800 Subject: [PATCH] Diffusion model part of Qwen Image Layered. (#11408) Only thing missing after this is some nodes to make using it easier. --- comfy/ldm/qwen_image/model.py | 63 ++++++++++++++++++++++------------- comfy/model_detection.py | 3 ++ 2 files changed, 42 insertions(+), 24 deletions(-) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 902af30ed..00c597535 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis): class QwenTimestepProjEmbeddings(nn.Module): - def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None): + def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None): super().__init__() self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000) self.timestep_embedder = TimestepEmbedding( @@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module): operations=operations ) - def forward(self, timestep, hidden_states): + self.use_additional_t_cond = use_additional_t_cond + if self.use_additional_t_cond: + self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype) + + def forward(self, timestep, hidden_states, addition_t_cond=None): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) + + if self.use_additional_t_cond: + if addition_t_cond is None: + addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long) + timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype) + return timesteps_emb @@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module): num_attention_heads: int = 24, joint_attention_dim: int = 3584, pooled_projection_dim: int = 768, - guidance_embeds: bool = False, axes_dims_rope: Tuple[int, int, int] = (16, 56, 56), default_ref_method="index", image_model=None, final_layer=True, + use_additional_t_cond=False, dtype=None, device=None, operations=None, @@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module): self.time_text_embed = QwenTimestepProjEmbeddings( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim, + use_additional_t_cond=use_additional_t_cond, dtype=dtype, device=device, operations=operations @@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module): patch_size = self.patch_size hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size)) orig_shape = hidden_states.shape - hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) - hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5) - hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2) + hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6) + hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4) + t_len = t h_len = ((h + (patch_size // 2)) // patch_size) w_len = ((w + (patch_size // 2)) // patch_size) h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device) - img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2) - return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape + img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device) - def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs): + if t_len > 1: + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1) + else: + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index + + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2) + return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape + + def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs): return comfy.patcher_extension.WrapperExecutor.new_class_executor( self._forward, self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs) + ).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs) def _forward( self, @@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module): timesteps, context, attention_mask=None, - guidance: torch.Tensor = None, ref_latents=None, + additional_t_cond=None, transformer_options={}, control=None, **kwargs @@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module): index = 0 ref_method = kwargs.get("ref_latents_method", self.default_ref_method) index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero") + negative_ref_method = ref_method == "negative_index" timestep_zero = ref_method == "index_timestep_zero" for ref in ref_latents: if index_ref_method: index += 1 h_offset = 0 w_offset = 0 + elif negative_ref_method: + index -= 1 + h_offset = 0 + w_offset = 0 else: index = 1 h_offset = 0 @@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module): encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_in(encoder_hidden_states) - if guidance is not None: - guidance = guidance * 1000 - - temb = ( - self.time_text_embed(timestep, hidden_states) - if guidance is None - else self.time_text_embed(timestep, guidance, hidden_states) - ) + temb = self.time_text_embed(timestep, hidden_states, additional_t_cond) patches_replace = transformer_options.get("patches_replace", {}) patches = transformer_options.get("patches", {}) @@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module): hidden_states = self.norm_out(hidden_states, temb) hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) - hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5) + hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2) + hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6) return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]] diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7148c77fd..84fd409fd 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511 dit_config["default_ref_method"] = "index_timestep_zero" + if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered + dit_config["use_additional_t_cond"] = True + dit_config["default_ref_method"] = "negative_index" return dit_config if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5