mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-20 19:42:59 +08:00
Diffusion model part of Qwen Image Layered. (#11408)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.9) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.9) (push) Has been cancelled
Only thing missing after this is some nodes to make using it easier.
This commit is contained in:
parent
6a2678ac65
commit
28eaab608b
@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
|
|||||||
|
|
||||||
|
|
||||||
class QwenTimestepProjEmbeddings(nn.Module):
|
class QwenTimestepProjEmbeddings(nn.Module):
|
||||||
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
|
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
|
||||||
self.timestep_embedder = TimestepEmbedding(
|
self.timestep_embedder = TimestepEmbedding(
|
||||||
@ -72,9 +72,19 @@ class QwenTimestepProjEmbeddings(nn.Module):
|
|||||||
operations=operations
|
operations=operations
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, timestep, hidden_states):
|
self.use_additional_t_cond = use_additional_t_cond
|
||||||
|
if self.use_additional_t_cond:
|
||||||
|
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
def forward(self, timestep, hidden_states, addition_t_cond=None):
|
||||||
timesteps_proj = self.time_proj(timestep)
|
timesteps_proj = self.time_proj(timestep)
|
||||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
|
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
|
||||||
|
|
||||||
|
if self.use_additional_t_cond:
|
||||||
|
if addition_t_cond is None:
|
||||||
|
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
|
||||||
|
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
|
||||||
|
|
||||||
return timesteps_emb
|
return timesteps_emb
|
||||||
|
|
||||||
|
|
||||||
@ -320,11 +330,11 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
num_attention_heads: int = 24,
|
num_attention_heads: int = 24,
|
||||||
joint_attention_dim: int = 3584,
|
joint_attention_dim: int = 3584,
|
||||||
pooled_projection_dim: int = 768,
|
pooled_projection_dim: int = 768,
|
||||||
guidance_embeds: bool = False,
|
|
||||||
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
|
||||||
default_ref_method="index",
|
default_ref_method="index",
|
||||||
image_model=None,
|
image_model=None,
|
||||||
final_layer=True,
|
final_layer=True,
|
||||||
|
use_additional_t_cond=False,
|
||||||
dtype=None,
|
dtype=None,
|
||||||
device=None,
|
device=None,
|
||||||
operations=None,
|
operations=None,
|
||||||
@ -342,6 +352,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
self.time_text_embed = QwenTimestepProjEmbeddings(
|
self.time_text_embed = QwenTimestepProjEmbeddings(
|
||||||
embedding_dim=self.inner_dim,
|
embedding_dim=self.inner_dim,
|
||||||
pooled_projection_dim=pooled_projection_dim,
|
pooled_projection_dim=pooled_projection_dim,
|
||||||
|
use_additional_t_cond=use_additional_t_cond,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
operations=operations
|
operations=operations
|
||||||
@ -375,27 +386,33 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
patch_size = self.patch_size
|
patch_size = self.patch_size
|
||||||
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
|
||||||
orig_shape = hidden_states.shape
|
orig_shape = hidden_states.shape
|
||||||
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
|
||||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
|
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||||
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
|
||||||
|
t_len = t
|
||||||
h_len = ((h + (patch_size // 2)) // patch_size)
|
h_len = ((h + (patch_size // 2)) // patch_size)
|
||||||
w_len = ((w + (patch_size // 2)) // patch_size)
|
w_len = ((w + (patch_size // 2)) // patch_size)
|
||||||
|
|
||||||
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
|
||||||
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
|
||||||
|
|
||||||
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
|
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
|
||||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
|
||||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
|
|
||||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
|
|
||||||
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
|
|
||||||
|
|
||||||
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
|
if t_len > 1:
|
||||||
|
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
|
||||||
|
else:
|
||||||
|
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
|
||||||
|
|
||||||
|
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
|
||||||
|
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
|
||||||
|
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
|
||||||
|
|
||||||
|
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
|
||||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||||
self._forward,
|
self._forward,
|
||||||
self,
|
self,
|
||||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||||
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
|
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
|
||||||
|
|
||||||
def _forward(
|
def _forward(
|
||||||
self,
|
self,
|
||||||
@ -403,8 +420,8 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
timesteps,
|
timesteps,
|
||||||
context,
|
context,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
guidance: torch.Tensor = None,
|
|
||||||
ref_latents=None,
|
ref_latents=None,
|
||||||
|
additional_t_cond=None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
control=None,
|
control=None,
|
||||||
**kwargs
|
**kwargs
|
||||||
@ -423,12 +440,17 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
index = 0
|
index = 0
|
||||||
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
|
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
|
||||||
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
|
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
|
||||||
|
negative_ref_method = ref_method == "negative_index"
|
||||||
timestep_zero = ref_method == "index_timestep_zero"
|
timestep_zero = ref_method == "index_timestep_zero"
|
||||||
for ref in ref_latents:
|
for ref in ref_latents:
|
||||||
if index_ref_method:
|
if index_ref_method:
|
||||||
index += 1
|
index += 1
|
||||||
h_offset = 0
|
h_offset = 0
|
||||||
w_offset = 0
|
w_offset = 0
|
||||||
|
elif negative_ref_method:
|
||||||
|
index -= 1
|
||||||
|
h_offset = 0
|
||||||
|
w_offset = 0
|
||||||
else:
|
else:
|
||||||
index = 1
|
index = 1
|
||||||
h_offset = 0
|
h_offset = 0
|
||||||
@ -458,14 +480,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||||
|
|
||||||
if guidance is not None:
|
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
|
||||||
guidance = guidance * 1000
|
|
||||||
|
|
||||||
temb = (
|
|
||||||
self.time_text_embed(timestep, hidden_states)
|
|
||||||
if guidance is None
|
|
||||||
else self.time_text_embed(timestep, guidance, hidden_states)
|
|
||||||
)
|
|
||||||
|
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
@ -513,6 +528,6 @@ class QwenImageTransformer2DModel(nn.Module):
|
|||||||
hidden_states = self.norm_out(hidden_states, temb)
|
hidden_states = self.norm_out(hidden_states, temb)
|
||||||
hidden_states = self.proj_out(hidden_states)
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|
||||||
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
|
||||||
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
|
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
|
||||||
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]
|
||||||
|
|||||||
@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
|||||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||||
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
|
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
|
||||||
dit_config["default_ref_method"] = "index_timestep_zero"
|
dit_config["default_ref_method"] = "index_timestep_zero"
|
||||||
|
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
|
||||||
|
dit_config["use_additional_t_cond"] = True
|
||||||
|
dit_config["default_ref_method"] = "negative_index"
|
||||||
return dit_config
|
return dit_config
|
||||||
|
|
||||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user