Diffusion model part of Qwen Image Layered. (#11408)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.9) (push) Has been cancelled

Only thing missing after this is some nodes to make using it easier.
This commit is contained in:
comfyanonymous 2025-12-18 17:21:14 -08:00 committed by GitHub
parent 6a2678ac65
commit 28eaab608b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 42 additions and 24 deletions

View File

@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
self.timestep_embedder = TimestepEmbedding(
@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
operations=operations
)
def forward(self, timestep, hidden_states):
self.use_additional_t_cond = use_additional_t_cond
if self.use_additional_t_cond:
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
def forward(self, timestep, hidden_states, addition_t_cond=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
if self.use_additional_t_cond:
if addition_t_cond is None:
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
return timesteps_emb
@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module):
num_attention_heads: int = 24,
joint_attention_dim: int = 3584,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
default_ref_method="index",
image_model=None,
final_layer=True,
use_additional_t_cond=False,
dtype=None,
device=None,
operations=None,
@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module):
self.time_text_embed = QwenTimestepProjEmbeddings(
embedding_dim=self.inner_dim,
pooled_projection_dim=pooled_projection_dim,
use_additional_t_cond=use_additional_t_cond,
dtype=dtype,
device=device,
operations=operations
@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module):
patch_size = self.patch_size
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
t_len = t
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
img_ids[:, :, 0] = img_ids[:, :, 1] + index
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
if t_len > 1:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
else:
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
def _forward(
self,
@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module):
timesteps,
context,
attention_mask=None,
guidance: torch.Tensor = None,
ref_latents=None,
additional_t_cond=None,
transformer_options={},
control=None,
**kwargs
@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module):
index = 0
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
negative_ref_method = ref_method == "negative_index"
timestep_zero = ref_method == "index_timestep_zero"
for ref in ref_latents:
if index_ref_method:
index += 1
h_offset = 0
w_offset = 0
elif negative_ref_method:
index -= 1
h_offset = 0
w_offset = 0
else:
index = 1
h_offset = 0
@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module):
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
if guidance is not None:
guidance = guidance * 1000
temb = (
self.time_text_embed(timestep, hidden_states)
if guidance is None
else self.time_text_embed(timestep, guidance, hidden_states)
)
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
patches_replace = transformer_options.get("patches_replace", {})
patches = transformer_options.get("patches", {})
@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module):
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

View File

@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
dit_config["default_ref_method"] = "index_timestep_zero"
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
dit_config["use_additional_t_cond"] = True
dit_config["default_ref_method"] = "negative_index"
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5