From 117e2143543dd649d47345e183748a82d48d12d3 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 12 Feb 2026 16:51:50 -0800 Subject: [PATCH 1/9] ModelPatcherDynamic: force load non leaf weights (#12433) The current behaviour of the default ModelPatcher is to call .to() on a model only if it is fully loaded, which is how random non-leaf weights get loaded in non-LowVRAM conditions. This, however, means they never get loaded in dynamic_vram. In the dynamic_vram case, force load them to the GPU. --- comfy/model_patcher.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index f278fccac..b1d907ba4 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -679,18 +679,19 @@ class ModelPatcher: for key in list(self.pinned): self.unpin_weight(key) - def _load_list(self, prio_comfy_cast_weights=False): + def _load_list(self, prio_comfy_cast_weights=False, default_device=None): loading = [] for n, m in self.model.named_modules(): - params = [] - skip = False - for name, param in m.named_parameters(recurse=False): - params.append(name) + default = False + params = { name: param for name, param in m.named_parameters(recurse=False) } for name, param in m.named_parameters(recurse=True): if name not in params: - skip = True # skip random weights in non leaf modules + default = True # default random weights in non leaf modules break - if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): + if default and default_device is not None: + for param in params.values(): + param.data = param.data.to(device=default_device) + if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0): module_mem = comfy.model_management.module_size(m) module_offload_mem = module_mem if hasattr(m, "comfy_cast_weights"): @@ -1495,7 +1496,7 @@ class ModelPatcherDynamic(ModelPatcher): #with pin and unpin syncrhonization which can be expensive for small weights #with a high layer rate (e.g. autoregressive LLMs). #prioritize the non-comfy weights (note the order reverse). - loading = self._load_list(prio_comfy_cast_weights=True) + loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to) loading.sort(reverse=True) for x in loading: @@ -1579,7 +1580,7 @@ class ModelPatcherDynamic(ModelPatcher): return 0 if vbar is None else vbar.free_memory(memory_to_free) def partially_unload_ram(self, ram_to_unload): - loading = self._load_list(prio_comfy_cast_weights=True) + loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device) for x in loading: _, _, _, _, m, _ = x ram_to_unload -= comfy.pinned_memory.unpin_memory(m) From ae79e33345dc893f8e0632c380c0e91dc09ac6e8 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Thu, 12 Feb 2026 16:56:42 -0800 Subject: [PATCH 2/9] llama: use a more efficient rope implementation (#12434) Get rid of the cat and unary negation and instead in-place add-cmul the two halves of the rope. Precompute -sin once at the start of the model rather than every transformer block. This is slightly faster on both GPU- and CPU-bound setups.
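For illustration only (not part of the patch): a minimal sketch of why the split-half, in-place addcmul formulation with a precomputed negated sin is equivalent to the previous rotate_half() path. The helper names rope_old/rope_new, the tensor shapes, and the random test tensors are hypothetical and simplified relative to the actual apply_rope in comfy/text_encoders/llama.py.

```
import torch

def rope_old(x, cos, sin):
    # previous approach: build rotate_half(x) with a cat and a unary negation
    h = x.shape[-1] // 2
    rotated = torch.cat((-x[..., h:], x[..., :h]), dim=-1)
    return x * cos + rotated * sin

def rope_new(x, cos, sin):
    # new approach: one multiply, then two in-place addcmul_ calls on the halves,
    # reusing a negated sin (which the patch precomputes once per forward pass)
    h = x.shape[-1] // 2
    nsin = -sin[..., :h]
    e = x * cos
    e[..., :h].addcmul_(x[..., h:], nsin)          # x1*cos - x2*sin
    e[..., h:].addcmul_(x[..., :h], sin[..., h:])  # x2*cos + x1*sin
    return e

x = torch.randn(1, 2, 4, 8)
cos, sin = torch.randn_like(x), torch.randn_like(x)
assert torch.allclose(rope_old(x, cos, sin), rope_new(x, cos, sin), atol=1e-5)
```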
--- comfy/text_encoders/llama.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index b6735d210..54f3d5595 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -355,13 +355,6 @@ class RMSNorm(nn.Module): -def rotate_half(x): - """Rotates half the hidden dims of the input.""" - x1 = x[..., : x.shape[-1] // 2] - x2 = x[..., x.shape[-1] // 2 :] - return torch.cat((-x2, x1), dim=-1) - - def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None): if not isinstance(theta, list): theta = [theta] @@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di else: cos = cos.unsqueeze(1) sin = sin.unsqueeze(1) - out.append((cos, sin)) + sin_split = sin.shape[-1] // 2 + out.append((cos, sin[..., : sin_split], -sin[..., sin_split :])) if len(out) == 1: return out[0] return out - def apply_rope(xq, xk, freqs_cis): org_dtype = xq.dtype cos = freqs_cis[0] sin = freqs_cis[1] - q_embed = (xq * cos) + (rotate_half(xq) * sin) - k_embed = (xk * cos) + (rotate_half(xk) * sin) + nsin = freqs_cis[2] + + q_embed = (xq * cos) + q_split = q_embed.shape[-1] // 2 + q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin) + q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin) + + k_embed = (xk * cos) + k_split = k_embed.shape[-1] // 2 + k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin) + k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin) + return q_embed.to(org_dtype), k_embed.to(org_dtype) From e03fe8b5919a23a473cea6e53f916f7403c082a5 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Thu, 12 Feb 2026 20:29:12 -0800 Subject: [PATCH 3/9] Update command to install AMD stable linux pytorch. (#12437) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 96dc2904b..3ccdc9c19 100644 --- a/README.md +++ b/README.md @@ -227,7 +227,7 @@ Put your VAE in: models/vae AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version: -```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4``` +```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1``` This is the command to install the nightly with ROCm 7.1 which might have some performance improvements: From 8902907d7ab949ce42dd9b658b4a4582ed9fb630 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Fri, 13 Feb 2026 12:29:37 -0800 Subject: [PATCH 4/9] dynamic_vram: Training fixes (#12442) --- comfy/model_patcher.py | 4 ++++ comfy_extras/nodes_train.py | 11 ++++++++++- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index b1d907ba4..67dce088e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1561,6 +1561,8 @@ class ModelPatcherDynamic(ModelPatcher): allocated_size += weight_size vbar.set_watermark_limit(allocated_size) + move_weight_functions(m, device_to) + logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. 
{num_patches} patches attached.") self.model.device = device_to @@ -1601,6 +1603,8 @@ class ModelPatcherDynamic(ModelPatcher): if unpatch_weights: self.partially_unload_ram(1e32) self.partially_unload(None, 1e32) + for m in self.model.modules(): + move_weight_functions(m, device_to) def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): assert not force_patch_weights #See above diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index 630eedc9f..aa2d88673 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -1035,7 +1035,7 @@ class TrainLoraNode(io.ComfyNode): io.Boolean.Input( "offloading", default=False, - tooltip="Depth level for gradient checkpointing.", + tooltip="Offload the Model to RAM. Requires Bypass Mode.", ), io.Combo.Input( "existing_lora", @@ -1124,6 +1124,15 @@ class TrainLoraNode(io.ComfyNode): lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) + if mp.is_dynamic(): + if not bypass_mode: + logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode") + bypass_mode = True + offloading = True + elif offloading: + if not bypass_mode: + logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message") + # Prepare latents and compute counts latents, num_images, multi_res = _prepare_latents_and_count( latents, dtype, bucket_mode From e1add563f9e89026e8c4e8825a2b279fbd67d23a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 13 Feb 2026 12:35:13 -0800 Subject: [PATCH 5/9] Use torch RMSNorm for flux models and refactor hunyuan video code. (#12432) --- comfy/controlnet.py | 1 + comfy/ldm/chroma/layers.py | 3 +- comfy/ldm/chroma_radiance/layers.py | 8 ++--- comfy/ldm/flux/layers.py | 53 +++++++---------------------- comfy/ldm/flux/model.py | 3 +- comfy/ldm/hunyuan_video/model.py | 11 +++--- comfy/lora_convert.py | 2 +- comfy/model_detection.py | 18 +++++++--- comfy/supported_models.py | 32 +++++++++++++++-- comfy/utils.py | 12 +++---- 10 files changed, 74 insertions(+), 69 deletions(-) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 9e1e704e0..8336412f2 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -560,6 +560,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}): def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}): model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options) control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config) + sd = model_config.process_unet_state_dict(sd) control_model = controlnet_load_state_dict(control_model, sd) extra_conds = ['y', 'guidance'] control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 2d5684348..df348a8ed 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -3,7 +3,6 @@ from torch import Tensor, nn from comfy.ldm.flux.layers import ( MLPEmbedder, - RMSNorm, ModulationOut, ) @@ -29,7 +28,7 @@ class Approximator(nn.Module): super().__init__() self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) self.layers = 
nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) - self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) + self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)]) self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) @property diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py index 3c7bc9b6b..08d31e0ba 100644 --- a/comfy/ldm/chroma_radiance/layers.py +++ b/comfy/ldm/chroma_radiance/layers.py @@ -4,8 +4,6 @@ from functools import lru_cache import torch from torch import nn -from comfy.ldm.flux.layers import RMSNorm - class NerfEmbedder(nn.Module): """ @@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module): # We now need to generate parameters for 3 matrices. total_params = 3 * hidden_size_x**2 * mlp_ratio self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device) - self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device) self.mlp_ratio = mlp_ratio @@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module): class NerfFinalLayer(nn.Module): def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None): super().__init__() - self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module): class NerfFinalLayerConv(nn.Module): def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None): super().__init__() - self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations) + self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device) self.conv = operations.Conv2d( in_channels=hidden_size, out_channels=out_channels, diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 60f2bdae2..1f2975fb1 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -5,8 +5,6 @@ import torch from torch import Tensor, nn from .math import attention, rope -import comfy.ops -import comfy.ldm.common_dit class EmbedND(nn.Module): @@ -87,20 +85,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), ) -class RMSNorm(torch.nn.Module): - def __init__(self, dim: int, dtype=None, device=None, operations=None): - super().__init__() - self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device)) - - def forward(self, x: Tensor): - return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6) - class QKNorm(torch.nn.Module): def __init__(self, dim: int, dtype=None, device=None, operations=None): super().__init__() - self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) - self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations) + self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device) + self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device) def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple: q = self.query_norm(q) @@ 
-169,7 +159,7 @@ class SiLUActivation(nn.Module): class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) @@ -197,8 +187,6 @@ class DoubleStreamBlock(nn.Module): self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) - self.flipped_img_txt = flipped_img_txt - def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): if self.modulation: img_mod1, img_mod2 = self.img_mod(vec) @@ -224,32 +212,17 @@ class DoubleStreamBlock(nn.Module): del txt_qkv txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) - if self.flipped_img_txt: - q = torch.cat((img_q, txt_q), dim=2) - del img_q, txt_q - k = torch.cat((img_k, txt_k), dim=2) - del img_k, txt_k - v = torch.cat((img_v, txt_v), dim=2) - del img_v, txt_v - # run actual attention - attn = attention(q, k, v, - pe=pe, mask=attn_mask, transformer_options=transformer_options) - del q, k, v + q = torch.cat((txt_q, img_q), dim=2) + del txt_q, img_q + k = torch.cat((txt_k, img_k), dim=2) + del txt_k, img_k + v = torch.cat((txt_v, img_v), dim=2) + del txt_v, img_v + # run actual attention + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v - img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] - else: - q = torch.cat((txt_q, img_q), dim=2) - del txt_q, img_q - k = torch.cat((txt_k, img_k), dim=2) - del txt_k, img_k - v = torch.cat((txt_v, img_v), dim=2) - del txt_v, img_v - # run actual attention - attn = attention(q, k, v, - pe=pe, mask=attn_mask, transformer_options=transformer_options) - del q, k, v - - txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] + txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] # calculate the img bloks img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index f40c2a7a9..260ccad7e 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -16,7 +16,6 @@ from .layers import ( SingleStreamBlock, timestep_embedding, Modulation, - RMSNorm ) @dataclass @@ -81,7 +80,7 @@ class Flux(nn.Module): self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) if params.txt_norm: - self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations) + self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device) else: self.txt_norm = None diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index 55ab550f8..563f28f6b 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, - flipped_img_txt=True, dtype=dtype, device=device, operations=operations ) for _ in 
range(params.depth) @@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module): extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1) - ids = torch.cat((img_ids, txt_ids), dim=1) + ids = torch.cat((txt_ids, img_ids), dim=1) pe = self.pe_embedder(ids) img_len = img.shape[1] if txt_mask is not None: attn_mask_len = img_len + txt.shape[1] attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device) - attn_mask[:, 0, img_len:] = txt_mask + attn_mask[:, 0, :txt.shape[1]] = txt_mask else: attn_mask = None @@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module): if add is not None: img += add - img = torch.cat((img, txt), 1) + img = torch.cat((txt, img), 1) transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["block_type"] = "single" @@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module): if i < len(control_o): add = control_o[i] if add is not None: - img[:, : img_len] += add + img[:, txt.shape[1]: img_len + txt.shape[1]] += add - img = img[:, : img_len] + img = img[:, txt.shape[1]: img_len + txt.shape[1]] if ref_latent is not None: img = img[:, ref_latent.shape[1]:] diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py index 9d8d21efe..749e81df3 100644 --- a/comfy/lora_convert.py +++ b/comfy/lora_convert.py @@ -5,7 +5,7 @@ import comfy.utils def convert_lora_bfl_control(sd): #BFL loras for Flux sd_out = {} for k in sd: - k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight")) + k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight")) sd_out[k_to] = sd[k] sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]]) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index e8ad725df..30ea03e8e 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string): count += 1 return count +def any_suffix_in(keys, prefix, main, suffix_list=[]): + for x in suffix_list: + if "{}{}{}".format(prefix, main, x) in keys: + return True + return False + def calculate_transformer_depth(prefix, state_dict_keys, state_dict): context_dim = None use_linear_in_transformer = False @@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["meanflow_sum"] = False return dit_config - if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight) + if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys: dit_config["image_model"] = "flux2" @@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') 
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma + + if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma dit_config["image_model"] = "chroma" dit_config["in_channels"] = 64 dit_config["out_channels"] = 64 @@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["out_dim"] = 3072 dit_config["hidden_dim"] = 5120 dit_config["n_layers"] = 5 - if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance + + if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance dit_config["image_model"] = "chroma_radiance" dit_config["in_channels"] = 3 dit_config["out_channels"] = 3 @@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_depth"] = 4 dit_config["nerf_max_freqs"] = 8 dit_config["nerf_tile_size"] = 512 - dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear" + dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear" dit_config["nerf_embedder_dtype"] = torch.float32 if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred dit_config["use_x0"] = True @@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys - dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys + dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"]) if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model dit_config["txt_ids_dims"] = [1, 2] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index d33db7507..c28be1716 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE): supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + def process_unet_state_dict(self, state_dict): + out_sd = {} + for k in list(state_dict.keys()): + key_out = k + if key_out.endswith("_norm.scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) + out_sd[key_out] = state_dict[k] + return out_sd + vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] @@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE): key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.") key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.") key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.") - key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale") - key_out = key_out.replace(".q_norm.weight", 
".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale") + key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight") + key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight") key_out = key_out.replace("_attn_proj.", "_attn.proj.") key_out = key_out.replace(".modulation.linear.", ".modulation.lin.") key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.") + if key_out.endswith(".scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) out_sd[key_out] = state_dict[k] return out_sd @@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE): latent_format = latent_formats.Hunyuan3Dv2 + def process_unet_state_dict(self, state_dict): + out_sd = {} + for k in list(state_dict.keys()): + key_out = k + if key_out.endswith(".scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) + out_sd[key_out] = state_dict[k] + return out_sd + def process_unet_state_dict_for_saving(self, state_dict): replace_prefix = {"": "model."} return utils.state_dict_prefix_replace(state_dict, replace_prefix) @@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE): supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + def process_unet_state_dict(self, state_dict): + out_sd = {} + for k in list(state_dict.keys()): + key_out = k + if key_out.endswith(".scale"): + key_out = "{}.weight".format(key_out[:-len(".scale")]) + out_sd[key_out] = state_dict[k] + return out_sd def get_model(self, state_dict, prefix="", device=None): out = model_base.Chroma(self, device=device) diff --git a/comfy/utils.py b/comfy/utils.py index e0a94e2e1..d553a7c1b 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -675,10 +675,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "ff_context.linear_in.bias": "txt_mlp.0.bias", "ff_context.linear_out.weight": "txt_mlp.2.weight", "ff_context.linear_out.bias": "txt_mlp.2.bias", - "attn.norm_q.weight": "img_attn.norm.query_norm.scale", - "attn.norm_k.weight": "img_attn.norm.key_norm.scale", - "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale", - "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale", + "attn.norm_q.weight": "img_attn.norm.query_norm.weight", + "attn.norm_k.weight": "img_attn.norm.key_norm.weight", + "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight", + "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight", } for k in block_map: @@ -701,8 +701,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): "norm.linear.bias": "modulation.lin.bias", "proj_out.weight": "linear2.weight", "proj_out.bias": "linear2.bias", - "attn.norm_q.weight": "norm.query_norm.scale", - "attn.norm_k.weight": "norm.key_norm.scale", + "attn.norm_q.weight": "norm.query_norm.weight", + "attn.norm_k.weight": "norm.key_norm.weight", "attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2 "attn.to_out.weight": "linear2.weight", # Flux 2 } From 831351a29e91ea758437227c2f3c915a6be6d1a6 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:15:23 -0800 Subject: [PATCH 6/9] Support generating attention masks for left padded text encoders. 
(#12454) --- comfy/sd1_clip.py | 15 +++++++++++---- comfy/text_encoders/ace15.py | 8 +------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 4c817d468..b564d1529 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def process_tokens(self, tokens, device): end_token = self.special_tokens.get("end", None) + pad_token = self.special_tokens.get("pad", -1) if end_token is None: - cmp_token = self.special_tokens.get("pad", -1) + cmp_token = pad_token else: cmp_token = end_token @@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): other_embeds = [] eos = False index = 0 + left_pad = False for y in x: if isinstance(y, numbers.Integral): - if eos: + token = int(y) + if index == 0 and token == pad_token: + left_pad = True + + if eos or (left_pad and token == pad_token): attention_mask.append(0) else: attention_mask.append(1) - token = int(y) + left_pad = False + tokens_temp += [token] - if not eos and token == cmp_token: + if not eos and token == cmp_token and not left_pad: if end_token is None: attention_mask[-1] = 0 eos = True diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py index 0fdd4669f..f135d74c1 100644 --- a/comfy/text_encoders/ace15.py +++ b/comfy/text_encoders/ace15.py @@ -10,7 +10,6 @@ import comfy.utils def sample_manual_loop_no_classes( model, ids=None, - paddings=[], execution_dtype=None, cfg_scale: float = 2.0, temperature: float = 0.85, @@ -36,9 +35,6 @@ def sample_manual_loop_no_classes( embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device) embeds_batch = embeds.shape[0] - for i, t in enumerate(paddings): - attention_mask[i, :t] = 0 - attention_mask[i, t:] = 1 output_audio_codes = [] past_key_values = [] @@ -135,13 +131,11 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102 pos_pad = (len(negative) - len(positive)) positive = [model.special_tokens["pad"]] * pos_pad + positive - paddings = [pos_pad, neg_pad] ids = [positive, negative] else: - paddings = [] ids = [positive] - return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) + return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens) class ACE15Tokenizer(sd1_clip.SD1Tokenizer): From 726af73867c18c5ca8b980a2c28401d77e5b365a Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:21:10 -0800 Subject: [PATCH 7/9] Fix some custom nodes. (#12455) --- comfy/ldm/flux/layers.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 1f2975fb1..3518a1922 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -6,6 +6,8 @@ from torch import Tensor, nn from .math import attention, rope +# Fix import for some custom nodes, TODO: delete eventually. 
+RMSNorm = None class EmbedND(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: list): From 712efb466b9379e6761802c44027783d37d96a87 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 13 Feb 2026 18:56:54 -0800 Subject: [PATCH 8/9] Add left padding to LTXAV text encoder. (#12456) --- comfy/text_encoders/lt.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py index 3f87dfd6a..9cf87c0b2 100644 --- a/comfy/text_encoders/lt.py +++ b/comfy/text_encoders/lt.py @@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs): class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} @@ -97,6 +97,7 @@ class LTXAVTEModel(torch.nn.Module): token_weight_pairs = token_weight_pairs["gemma3_12b"] out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs) + out = out[:, :, -torch.sum(extra["attention_mask"]).item():] out_device = out.device if comfy.model_management.should_use_bf16(self.execution_device): out = out.to(device=self.execution_device, dtype=torch.bfloat16) @@ -138,6 +139,7 @@ class LTXAVTEModel(torch.nn.Module): token_weight_pairs = token_weight_pairs.get("gemma3_12b", []) num_tokens = sum(map(lambda a: len(a), token_weight_pairs)) + num_tokens = max(num_tokens, 64) return num_tokens * constant * 1024 * 1024 def ltxav_te(dtype_llama=None, llama_quantization_metadata=None): From dc9822b7df4785e77690e93b0e09feaff01e2e12 Mon Sep 17 00:00:00 2001 From: krigeta <75309361+krigeta@users.noreply.github.com> Date: Sat, 14 Feb 2026 08:53:52 +0530 Subject: [PATCH 9/9] Add working Qwen 2512 ControlNet (Fun ControlNet) support (#12359) --- comfy/controlnet.py | 73 +++++++++++ comfy/ldm/qwen_image/controlnet.py | 190 +++++++++++++++++++++++++++++ 2 files changed, 263 insertions(+) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 8336412f2..ba670b16d 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -297,6 +297,30 @@ class ControlNet(ControlBase): self.model_sampling_current = None super().cleanup() + +class QwenFunControlNet(ControlNet): + def get_control(self, x_noisy, t, cond, batched_number, transformer_options): + # Fun checkpoints are more sensitive to high strengths in the generic + # ControlNet merge path. Use a soft response curve so strength=1.0 stays + # unchanged while >1 grows more gently. 
+ original_strength = self.strength + self.strength = math.sqrt(max(self.strength, 0.0)) + try: + return super().get_control(x_noisy, t, cond, batched_number, transformer_options) + finally: + self.strength = original_strength + + def pre_run(self, model, percent_to_timestep_function): + super().pre_run(model, percent_to_timestep_function) + self.set_extra_arg("base_model", model.diffusion_model) + + def copy(self): + c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype) + c.control_model = self.control_model + c.control_model_wrapped = self.control_model_wrapped + self.copy_to(c) + return c + class ControlLoraOps: class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp): def __init__(self, in_features: int, out_features: int, bias: bool = True, @@ -606,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}): control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds) return control + +def load_controlnet_qwen_fun(sd, model_options={}): + load_device = comfy.model_management.get_torch_device() + weight_dtype = comfy.utils.weight_dtype(sd) + unet_dtype = model_options.get("dtype", weight_dtype) + manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device) + + operations = model_options.get("custom_operations", None) + if operations is None: + operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True) + + in_features = sd["control_img_in.weight"].shape[1] + inner_dim = sd["control_img_in.weight"].shape[0] + + block_weight = sd["control_blocks.0.attn.to_q.weight"] + attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0] + num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim)) + + model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel( + control_in_features=in_features, + inner_dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + num_control_blocks=5, + main_model_double=60, + injection_layers=(0, 12, 24, 36, 48), + operations=operations, + device=comfy.model_management.unet_offload_device(), + dtype=unet_dtype, + ) + model = controlnet_load_state_dict(model, sd) + + latent_format = comfy.latent_formats.Wan21() + control = QwenFunControlNet( + model, + compression_ratio=1, + latent_format=latent_format, + # Fun checkpoints already expect their own 33-channel context handling. + # Enabling generic concat_mask injects an extra mask channel at apply-time + # and breaks the intended fallback packing path. 
+ concat_mask=False, + load_device=load_device, + manual_cast_dtype=manual_cast_dtype, + extra_conds=[], + ) + return control + def convert_mistoline(sd): return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."}) @@ -683,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}): return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options) elif "controlnet_x_embedder.weight" in controlnet_data: return load_controlnet_flux_instantx(controlnet_data, model_options=model_options) + elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data: + return load_controlnet_qwen_fun(controlnet_data, model_options=model_options) elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options) diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py index a6d408104..c0aae9240 100644 --- a/comfy/ldm/qwen_image/controlnet.py +++ b/comfy/ldm/qwen_image/controlnet.py @@ -2,6 +2,196 @@ import torch import math from .model import QwenImageTransformer2DModel +from .model import QwenImageTransformerBlock + + +class QwenImageFunControlBlock(QwenImageTransformerBlock): + def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None): + super().__init__( + dim=dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + dtype=dtype, + device=device, + operations=operations, + ) + self.has_before_proj = has_before_proj + if has_before_proj: + self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype) + self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype) + + +class QwenImageFunControlNetModel(torch.nn.Module): + def __init__( + self, + control_in_features=132, + inner_dim=3072, + num_attention_heads=24, + attention_head_dim=128, + num_control_blocks=5, + main_model_double=60, + injection_layers=(0, 12, 24, 36, 48), + dtype=None, + device=None, + operations=None, + ): + super().__init__() + self.dtype = dtype + self.main_model_double = main_model_double + self.injection_layers = tuple(injection_layers) + # Keep base hint scaling at 1.0 so user-facing strength behaves similarly + # to the reference Gen2/VideoX implementation around strength=1. + self.hint_scale = 1.0 + self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype) + + self.control_blocks = torch.nn.ModuleList([]) + for i in range(num_control_blocks): + self.control_blocks.append( + QwenImageFunControlBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + has_before_proj=(i == 0), + dtype=dtype, + device=device, + operations=operations, + ) + ) + + def _process_hint_tokens(self, hint): + if hint is None: + return None + if hint.ndim == 4: + hint = hint.unsqueeze(2) + + # Fun checkpoints are trained with 33 latent channels before 2x2 packing: + # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features. + # Default behavior (no inpaint input in stock Apply ControlNet) should use + # zeros for mask/inpaint branches, matching VideoX fallback semantics. 
+ expected_c = self.control_img_in.weight.shape[1] // 4 + if hint.shape[1] == 16 and expected_c == 33: + zeros_mask = torch.zeros_like(hint[:, :1]) + zeros_inpaint = torch.zeros_like(hint) + hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1) + + bs, c, t, h, w = hint.shape + hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2)) + orig_shape = hidden_states.shape + hidden_states = hidden_states.view( + orig_shape[0], + orig_shape[1], + orig_shape[-3], + orig_shape[-2] // 2, + 2, + orig_shape[-1] // 2, + 2, + ) + hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6) + hidden_states = hidden_states.reshape( + bs, + t * ((h + 1) // 2) * ((w + 1) // 2), + c * 4, + ) + + expected_in = self.control_img_in.weight.shape[1] + cur_in = hidden_states.shape[-1] + if cur_in < expected_in: + pad = torch.zeros( + (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + hidden_states = torch.cat([hidden_states, pad], dim=-1) + elif cur_in > expected_in: + hidden_states = hidden_states[:, :, :expected_in] + + return hidden_states + + def forward( + self, + x, + timesteps, + context, + attention_mask=None, + guidance: torch.Tensor = None, + hint=None, + transformer_options={}, + base_model=None, + **kwargs, + ): + if base_model is None: + raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.") + + encoder_hidden_states_mask = attention_mask + # Keep attention mask disabled inside Fun control blocks to mirror + # VideoX behavior (they rely on seq lengths for RoPE, not masked attention). + encoder_hidden_states_mask = None + + hidden_states, img_ids, _ = base_model.process_img(x) + hint_tokens = self._process_hint_tokens(hint) + if hint_tokens is None: + raise RuntimeError("Qwen Fun ControlNet requires a control hint image.") + + if hint_tokens.shape[1] != hidden_states.shape[1]: + max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1]) + hint_tokens = hint_tokens[:, :max_tokens] + hidden_states = hidden_states[:, :max_tokens] + img_ids = img_ids[:, :max_tokens] + + txt_start = round( + max( + ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2, + ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2, + ) + ) + txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) + ids = torch.cat((txt_ids, img_ids), dim=1) + image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous() + + hidden_states = base_model.img_in(hidden_states) + encoder_hidden_states = base_model.txt_norm(context) + encoder_hidden_states = base_model.txt_in(encoder_hidden_states) + + if guidance is not None: + guidance = guidance * 1000 + + temb = ( + base_model.time_text_embed(timesteps, hidden_states) + if guidance is None + else base_model.time_text_embed(timesteps, guidance, hidden_states) + ) + + c = self.control_img_in(hint_tokens) + + for i, block in enumerate(self.control_blocks): + if i == 0: + c_in = block.before_proj(c) + hidden_states + all_c = [] + else: + all_c = list(torch.unbind(c, dim=0)) + c_in = all_c.pop(-1) + + encoder_hidden_states, c_out = block( + hidden_states=c_in, + encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states_mask=encoder_hidden_states_mask, + temb=temb, + image_rotary_emb=image_rotary_emb, + transformer_options=transformer_options, + ) + + c_skip = block.after_proj(c_out) * self.hint_scale + all_c += [c_skip, c_out] + c = 
torch.stack(all_c, dim=0) + + hints = torch.unbind(c, dim=0)[:-1] + + controlnet_block_samples = [None] * self.main_model_double + for local_idx, base_idx in enumerate(self.injection_layers): + if local_idx < len(hints) and base_idx < len(controlnet_block_samples): + controlnet_block_samples[base_idx] = hints[local_idx] + + return {"input": controlnet_block_samples} class QwenImageControlNetModel(QwenImageTransformer2DModel):