From b94d394a64dd0af06bca44b96c66549bb463331d Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Dec 2025 18:38:31 -0800 Subject: [PATCH 1/6] Support Z Image alibaba pai fun controlnets. (#11062) These are not actual controlnets so put it in the models/model_patches folder and use the ModelPatchLoader + QwenImageDiffsynthControlnet node to use it. --- comfy/ldm/lumina/controlnet.py | 113 ++++++++++++++++++++++++++++++ comfy/ldm/lumina/model.py | 24 ++++--- comfy_extras/nodes_model_patch.py | 101 +++++++++++++++++++++++++- 3 files changed, 229 insertions(+), 9 deletions(-) create mode 100644 comfy/ldm/lumina/controlnet.py diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py new file mode 100644 index 000000000..fd7ce3b5c --- /dev/null +++ b/comfy/ldm/lumina/controlnet.py @@ -0,0 +1,113 @@ +import torch +from torch import nn + +from .model import JointTransformerBlock + +class ZImageControlTransformerBlock(JointTransformerBlock): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + modulation=True, + block_id=0, + operation_settings=None, + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings) + self.block_id = block_id + if block_id == 0: + self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + return c_skip, 
c + +class ZImage_Control(torch.nn.Module): + def __init__( + self, + dim: int = 3840, + n_heads: int = 30, + n_kv_heads: int = 30, + multiple_of: int = 256, + ffn_dim_multiplier: float = (8.0 / 3.0), + norm_eps: float = 1e-5, + qk_norm: bool = True, + dtype=None, + device=None, + operations=None, + **kwargs + ): + super().__init__() + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.additional_in_dim = 0 + self.control_in_dim = 16 + n_refiner_layers = 2 + self.n_control_layers = 6 + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock( + i, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + block_id=i, + operation_settings=operation_settings, + ) + for i in range(self.n_control_layers) + ] + ) + + all_x_embedder = {} + patch_size = 2 + f_patch_size = 1 + x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + self.control_noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + z_image_modulation=True, + operation_settings=operation_settings, + ) + for layer_id in range(n_refiner_layers) + ] + ) + + def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input): + patch_size = 2 + f_patch_size = 1 + pH = pW = patch_size + B, C, H, W = control_context.shape + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) + + x_attn_mask = None + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], 
:control_context.shape[1]], adaln_input) + return control_context + + def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): + return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index 070b5da09..f1c1a0ec3 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -568,7 +568,7 @@ class NextDiT(nn.Module): ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) # def forward(self, x, t, cap_feats, cap_mask): - def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs): t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask @@ -585,16 +585,24 @@ class NextDiT(nn.Module): cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. 
redundant compute + patches = transformer_options.get("patches", {}) transformer_options = kwargs.get("transformer_options", {}) x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) - freqs_cis = freqs_cis.to(x.device) + img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) + freqs_cis = freqs_cis.to(img.device) - for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + for i, layer in enumerate(self.layers): + img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] - x = self.final_layer(x, adaln_input) - x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] + img = self.final_layer(img, adaln_input) + img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] - return -x + return -img diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 783c59b6b..c61810dbf 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -6,6 +6,7 @@ import comfy.ops import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats +import comfy.ldm.lumina.controlnet class BlockWiseControlBlock(torch.nn.Module): @@ -189,6 +190,35 @@ class SigLIPMultiFeatProjModel(torch.nn.Module): return embedding +def z_image_convert(sd): + replace_keys = {".attention.to_out.0.bias": 
".attention.out.bias", + ".attention.norm_k.weight": ".attention.k_norm.weight", + ".attention.norm_q.weight": ".attention.q_norm.weight", + ".attention.to_out.0.weight": ".attention.out.weight" + } + + out_sd = {} + for k in sorted(sd.keys()): + w = sd[k] + + k_out = k + if k_out.endswith(".attention.to_k.weight"): + cc = [w] + continue + if k_out.endswith(".attention.to_q.weight"): + cc = [w] + cc + continue + if k_out.endswith(".attention.to_v.weight"): + cc = cc + [w] + w = torch.cat(cc, dim=0) + k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight") + + for r, rr in replace_keys.items(): + k_out = k_out.replace(r, rr) + out_sd[k_out] = w + + return out_sd + class ModelPatchLoader: @classmethod def INPUT_TYPES(s): @@ -211,6 +241,9 @@ class ModelPatchLoader: elif 'feature_embedder.mid_layer_norm.bias' in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True) model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet + sd = z_image_convert(sd) + model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) model.load_state_dict(sd) model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) @@ -263,6 +296,69 @@ class DiffSynthCnetPatch: def models(self): return [self.model_patch] +class ZImageControlPatch: + def __init__(self, model_patch, vae, image, strength): + self.model_patch = model_patch + self.vae = vae + self.image = image + self.strength = strength + self.encoded_image = self.encode_latent_cond(image) + self.encoded_image_size = (image.shape[1], image.shape[2]) + self.temp_data = None + + def encode_latent_cond(self, image): + latent_image = 
comfy.latent_formats.Flux().process_in(self.vae.encode(image)) + return latent_image + + def __call__(self, kwargs): + x = kwargs.get("x") + img = kwargs.get("img") + txt = kwargs.get("txt") + pe = kwargs.get("pe") + vec = kwargs.get("vec") + block_index = kwargs.get("block_index") + spacial_compression = self.vae.spacial_compression_encode() + if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): + image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1)) + self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1]) + comfy.model_management.load_models_gpu(loaded_models) + + cnet_index = (block_index // 5) + cnet_index_float = (block_index / 5) + + kwargs.pop("img") # we do ops in place + kwargs.pop("txt") + + cnet_blocks = self.model_patch.model.n_control_layers + if cnet_index_float > (cnet_blocks - 1): + self.temp_data = None + return kwargs + + if self.temp_data is None or self.temp_data[0] > cnet_index: + self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) + + while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks: + next_layer = self.temp_data[0] + 1 + self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) + + if cnet_index_float == self.temp_data[0]: + img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) + if cnet_blocks == self.temp_data[0] + 1: + self.temp_data = None + + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = 
self.encoded_image.to(device_or_dtype) + self.temp_data = None + return self + + def models(self): + return [self.model_patch] + class QwenImageDiffsynthControlnet: @classmethod def INPUT_TYPES(s): @@ -289,7 +385,10 @@ class QwenImageDiffsynthControlnet: mask = mask.unsqueeze(2) mask = 1.0 - mask - model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) + if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control): + model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength)) + else: + model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) return (model_patched,) From 3f512f5659cfbb3c53999cde6ff557591740252b Mon Sep 17 00:00:00 2001 From: Jim Heising Date: Tue, 2 Dec 2025 19:29:27 -0800 Subject: [PATCH 2/6] Added PATCH method to CORS headers (#11066) Added PATCH http method to access-control-allow-header-methods header because there are now PATCH endpoints exposed in the API. See https://github.com/comfyanonymous/ComfyUI/blob/277237ccc1499bac7fcd221a666dfe7a32ac4206/api_server/routes/internal/internal_routes.py#L34 for an example of an API endpoint that uses the PATCH method. 
--- server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server.py b/server.py index e3bd056d9..ac4f42222 100644 --- a/server.py +++ b/server.py @@ -98,7 +98,7 @@ def create_cors_middleware(allowed_origin: str): response = await handler(request) response.headers['Access-Control-Allow-Origin'] = allowed_origin - response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' + response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH' response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization' response.headers['Access-Control-Allow-Credentials'] = 'true' return response From 73f5649196f472d3719e2e7513e0a9d029cc3e38 Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 3 Dec 2025 13:49:29 +1000 Subject: [PATCH 3/6] Implement temporal rolling VAE (Major VRAM reductions in Hunyuan and Kandinsky) (#10995) * hunyuan upsampler: rework imports Remove the transitive import of VideoConv3d and Resnet and takes these from actual implementation source. * model: remove unused give_pre_end According to git grep, this is not used now, and was not used in the initial commit that introduced it (see below). This semantic is difficult to implement temporal roll VAE for (and would defeat the purpose). Rather than implement the complex if, just delete the unused feature. (venv) rattus@rattus-box2:~/ComfyUI$ git log --oneline 220afe33 (HEAD) Initial commit. (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: (venv) rattus@rattus-box2:~/ComfyUI$ git co origin/master Previous HEAD position was 220afe33 Initial commit. 
HEAD is now at 9d8a8179 Enable async offloading by default on Nvidia. (#10953) (venv) rattus@rattus-box2:~/ComfyUI$ git grep give_pre comfy/ldm/modules/diffusionmodules/model.py: resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, comfy/ldm/modules/diffusionmodules/model.py: self.give_pre_end = give_pre_end comfy/ldm/modules/diffusionmodules/model.py: if self.give_pre_end: * move refiner VAE temporal roller to core Move the carrying conv op to the common VAE code and give it a better name. Roll the carry implementation logic for Resnet into the base class and scrap the Hunyuan specific subclass. * model: Add temporal roll to main VAE decoder If there are no attention layers, it's a standard resnet and VideoConv3d is asked for, substitute in the temporal rolling VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings). * model: Add temporal roll to main VAE encoder If there are no attention layers, it's a standard resnet and VideoConv3d is asked for, substitute in the temporal rolling VAE algorithm. This reduces VAE usage by the temporal dimension (can be huge VRAM savings). 
--- comfy/ldm/hunyuan_video/upsampler.py | 3 +- comfy/ldm/hunyuan_video/vae_refiner.py | 94 +++------ comfy/ldm/modules/diffusionmodules/model.py | 207 ++++++++++++++------ 3 files changed, 174 insertions(+), 130 deletions(-) diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py index 9f5e91a59..85f515f67 100644 --- a/comfy/ldm/hunyuan_video/upsampler.py +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -1,7 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm, ResnetBlock, VideoConv3d +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d +from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm import model_management, model_patcher class SRResidualCausalBlock3D(nn.Module): diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index 9f750dcc4..ddf77cd0e 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -1,42 +1,12 @@ import torch import torch.nn as nn import torch.nn.functional as F -from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed import comfy.ops import comfy.ldm.models.autoencoder import comfy.model_management ops = comfy.ops.disable_weight_init -class NoPadConv3d(nn.Module): - def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): - super().__init__() - self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) - - def forward(self, x): - return self.conv(x) - - -def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): - - x = xl[0] - xl.clear() - - if conv_carry_out is not None: - to_push = x[:, :, -2:, :, :].clone() - 
conv_carry_out.append(to_push) - - if isinstance(op, NoPadConv3d): - if conv_carry_in is None: - x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') - else: - carry_len = conv_carry_in[0].shape[2] - x = torch.cat([conv_carry_in.pop(0), x], dim=2) - x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') - - out = op(x) - - return out - class RMS_norm(nn.Module): def __init__(self, dim): @@ -49,7 +19,7 @@ class RMS_norm(nn.Module): return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) class DnSmpl(nn.Module): - def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tds, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tds else 1 * 2 * 2 assert oc % fct == 0 @@ -109,7 +79,7 @@ class DnSmpl(nn.Module): class UpSmpl(nn.Module): - def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tus, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tus else 1 * 2 * 2 self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1) @@ -163,23 +133,6 @@ class UpSmpl(nn.Module): return h + x -class HunyuanRefinerResnetBlock(ResnetBlock): - def __init__(self, in_channels, out_channels, conv_op=NoPadConv3d, norm_op=RMS_norm): - super().__init__(in_channels=in_channels, out_channels=out_channels, temb_channels=0, conv_op=conv_op, norm_op=norm_op) - - def forward(self, x, conv_carry_in=None, conv_carry_out=None): - h = x - h = [ self.swish(self.norm1(x)) ] - h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) - - h = [ self.dropout(self.swish(self.norm2(h))) ] - h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) - - if self.in_channels != self.out_channels: - x = self.nin_shortcut(x) - - return x+h - class Encoder(nn.Module): def __init__(self, in_channels, z_channels, 
block_out_channels, num_res_blocks, ffactor_spatial, ffactor_temporal, downsample_match_channel=True, refiner_vae=True, **_): @@ -191,7 +144,7 @@ class Encoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = NoPadConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -206,9 +159,10 @@ class Encoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks)]) ch = tgt if i < depth: @@ -218,9 +172,9 @@ class Encoder(nn.Module): self.down.append(stage) self.mid = nn.Module() - self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.norm_out = norm_op(ch) self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) @@ -246,22 +200,20 @@ class Encoder(nn.Module): conv_carry_out = [] if i == len(x) - 1: conv_carry_out = None + x1 = [ x1 ] x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) for stage in self.down: for blk in stage.block: - x1 = blk(x1, conv_carry_in, conv_carry_out) + x1 = blk(x1, None, conv_carry_in, conv_carry_out) if hasattr(stage, 'downsample'): x1 = stage.downsample(x1, conv_carry_in, conv_carry_out) out.append(x1) conv_carry_in = conv_carry_out - if len(out) > 1: - out = torch.cat(out, dim=2) - else: 
- out = out[0] + out = torch_cat_if_needed(out, dim=2) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out))) del out @@ -288,7 +240,7 @@ class Decoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = NoPadConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -298,9 +250,9 @@ class Decoder(nn.Module): self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = HunyuanRefinerResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 1).bit_length() @@ -308,9 +260,10 @@ class Decoder(nn.Module): for i, tgt in enumerate(block_out_channels): stage = nn.Module() - stage.block = nn.ModuleList([HunyuanRefinerResnetBlock(in_channels=ch if j == 0 else tgt, - out_channels=tgt, - conv_op=conv_op, norm_op=norm_op) + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_op=conv_op, norm_op=norm_op) for j in range(num_res_blocks + 1)]) ch = tgt if i < depth: @@ -340,7 +293,7 @@ class Decoder(nn.Module): conv_carry_out = None for stage in self.up: for blk in stage.block: - x1 = blk(x1, conv_carry_in, conv_carry_out) + x1 = blk(x1, None, conv_carry_in, conv_carry_out) if hasattr(stage, 'upsample'): x1 = stage.upsample(x1, conv_carry_in, conv_carry_out) @@ -350,10 +303,7 @@ class Decoder(nn.Module): conv_carry_in = conv_carry_out del x - if len(out) > 1: - out = torch.cat(out, dim=2) - else: - out = out[0] + out = torch_cat_if_needed(out, dim=2) if 
not self.refiner_vae: if z.shape[-3] == 1: diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index de1e01cc8..681a55db5 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae(): import xformers import xformers.ops +def torch_cat_if_needed(xl, dim): + if len(xl) > 1: + return torch.cat(xl, dim) + else: + return xl[0] + def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: @@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32): return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) +class CarriedConv3d(nn.Module): + def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): + super().__init__() + self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + return self.conv(x) + + +def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): + + x = xl[0] + xl.clear() + + if isinstance(op, CarriedConv3d): + if conv_carry_in is None: + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate') + else: + carry_len = conv_carry_in[0].shape[2] + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate') + x = torch.cat([conv_carry_in.pop(0), x], dim=2) + + if conv_carry_out is not None: + to_push = x[:, :, -2:, :, :].clone() + conv_carry_out.append(to_push) + + out = op(x) + + return out + + class VideoConv3d(nn.Module): def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs): super().__init__() @@ -89,29 +126,24 @@ class Upsample(nn.Module): stride=1, padding=1) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): 
scale_factor = self.scale_factor if isinstance(scale_factor, (int, float)): scale_factor = (scale_factor,) * (x.ndim - 2) if x.ndim == 5 and scale_factor[0] > 1.0: - t = x.shape[2] - if t > 1: - a, b = x.split((1, t - 1), dim=2) - del x - b = interpolate_up(b, scale_factor) - else: - a = x - - a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2) - if t > 1: - x = torch.cat((a, b), dim=2) - else: - x = a + results = [] + if conv_carry_in is None: + first = x[:, :, :1, :, :] + results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)) + x = x[:, :, 1:, :, :] + if x.shape[2] > 0: + results.append(interpolate_up(x, scale_factor)) + x = torch_cat_if_needed(results, dim=2) else: x = interpolate_up(x, scale_factor) if self.with_conv: - x = self.conv(x) + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) return x @@ -127,17 +159,20 @@ class Downsample(nn.Module): stride=stride, padding=0) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): if self.with_conv: - if x.ndim == 4: + if isinstance(self.conv, CarriedConv3d): + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) + elif x.ndim == 4: pad = (0, 1, 0, 1) mode = "constant" x = torch.nn.functional.pad(x, pad, mode=mode, value=0) + x = self.conv(x) elif x.ndim == 5: pad = (1, 1, 1, 1, 2, 0) mode = "replicate" x = torch.nn.functional.pad(x, pad, mode=mode) - x = self.conv(x) + x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x @@ -183,23 +218,23 @@ class ResnetBlock(nn.Module): stride=1, padding=0) - def forward(self, x, temb=None): + def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None): h = x h = self.norm1(h) - h = self.swish(h) - h = self.conv1(h) + h = [ self.swish(h) ] + h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if temb is not None: h = h + 
self.temb_proj(self.swish(temb))[:,:,None,None] h = self.norm2(h) h = self.swish(h) - h = self.dropout(h) - h = self.conv2(h) + h = [ self.dropout(h) ] + h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - x = self.conv_shortcut(x) + x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) else: x = self.nin_shortcut(x) @@ -520,9 +555,14 @@ class Encoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels + self.carried = False if conv3d: - conv_op = VideoConv3d + if not attn_resolutions: + conv_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -535,6 +575,7 @@ class Encoder(nn.Module): stride=1, padding=1) + self.time_compress = 1 curr_res = resolution in_ch_mult = (1,)+tuple(ch_mult) self.in_ch_mult = in_ch_mult @@ -561,10 +602,15 @@ class Encoder(nn.Module): if time_compress is not None: if (self.num_resolutions - 1 - i_level) > math.log2(time_compress): stride = (1, 2, 2) + else: + self.time_compress *= 2 down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op) curr_res = curr_res // 2 self.down.append(down) + if time_compress is not None: + self.time_compress = time_compress + # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, @@ -590,15 +636,42 @@ class Encoder(nn.Module): def forward(self, x): # timestep embedding temb = None - # downsampling - h = self.conv_in(x) - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](h, temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - if i_level != self.num_resolutions-1: - h = self.down[i_level].downsample(h) + + if self.carried: + xl = 
[x[:, :, :1, :, :]] + if x.shape[2] > self.time_compress: + tc = self.time_compress + xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2) + x = xl + else: + x = [x] + out = [] + + conv_carry_in = None + + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + + # downsampling + x1 = [ x1 ] + h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) + + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out) + if len(self.down[i_level].attn) > 0: + assert i == 0 #carried should not happen if attn exists + h1 = self.down[i_level].attn[i_block](h1) + if i_level != self.num_resolutions-1: + h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out) + + out.append(h1) + conv_carry_in = conv_carry_out + + h = torch_cat_if_needed(out, dim=2) + del out # middle h = self.mid.block_1(h, temb) @@ -607,15 +680,15 @@ class Encoder(nn.Module): # end h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h = [ nonlinearity(h) ] + h = conv_carry_causal_3d(h, self.conv_out) return h class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + resolution, z_channels, tanh_out=False, use_linear_attn=False, conv_out_op=ops.Conv2d, resnet_op=ResnetBlock, attn_op=AttnBlock, @@ -629,12 +702,18 @@ class Decoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels - self.give_pre_end = give_pre_end self.tanh_out = tanh_out + self.carried = False if conv3d: - conv_op = VideoConv3d - conv_out_op = VideoConv3d + if not attn_resolutions and resnet_op == ResnetBlock: + conv_op = CarriedConv3d + conv_out_op = CarriedConv3d + 
self.carried = True + else: + conv_op = VideoConv3d + conv_out_op = VideoConv3d + mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -709,29 +788,43 @@ class Decoder(nn.Module): temb = None # z to block_in - h = self.conv_in(z) + h = conv_carry_causal_3d([z], self.conv_in) # middle h = self.mid.block_1(h, temb, **kwargs) h = self.mid.attn_1(h, **kwargs) h = self.mid.block_2(h, temb, **kwargs) + if self.carried: + h = torch.split(h, 2, dim=2) + else: + h = [ h ] + out = [] + + conv_carry_in = None + # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks+1): - h = self.up[i_level].block[i_block](h, temb, **kwargs) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h, **kwargs) - if i_level != 0: - h = self.up[i_level].upsample(h) + for i, h1 in enumerate(h): + conv_carry_out = [] + if i == len(h) - 1: + conv_carry_out = None + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs) + if len(self.up[i_level].attn) > 0: + assert i == 0 #carried should not happen if attn exists + h1 = self.up[i_level].attn[i_block](h1, **kwargs) + if i_level != 0: + h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out) - # end - if self.give_pre_end: - return h + h1 = self.norm_out(h1) + h1 = [ nonlinearity(h1) ] + h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out) + if self.tanh_out: + h1 = torch.tanh(h1) + out.append(h1) + conv_carry_in = conv_carry_out - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h, **kwargs) - if self.tanh_out: - h = torch.tanh(h) - return h + out = torch_cat_if_needed(out, dim=2) + + return out From c120eee5bacca643062657d2a7efad83c7d4d828 Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Tue, 2 Dec 2025 21:17:13 -0800 Subject: [PATCH 4/6] Add MatchType, DynamicCombo, and Autogrow 
support to V3 Schema (#10832) * Added output_matchtypes to generated json for v3, initial backend support for MatchType, created nodes_logic.py and added SwitchNode * Fixed providing list of allowed_types * Add workaround in validation.py for V3 Combo outputs not working as Combo inputs * Make match type receive_type pass validation * Also add MatchType check to input_type in validation - will likely trigger when connecting to non-lazy stuff * Make sure this PR only has MatchType stuff * Initial work on DynamicCombo * Add get_dynamic function, not yet filled out correctly * Mark Switch node as Beta * Make sure other unfinished dynamic types are not accidentally used * Send DynamicCombo.Option inputs in the same format as normal v1 inputs * add dynamic combo test node * Support validation of inputs and outputs * Add missing input params to DynamicCombo.Input * Add get_all function to inputs for id validation purposes * Fix imports for v3 returning everything when doing io/ui/IO/UI instead of what is in __all__ of _io.py and _ui.py * Modifying behavior of get_dynamic in V3 + serialization so can be used in execution code * Fix v3 schema validation code after changes * Refactor hidden_values for v3 in execution.py to be more general v3_data, add helper functions for dynamic behavior, preparing for restructuring dynamic type into object (not finished yet) * Add nesting of inputs on DynamicCombo during execution * Work with latest frontend commits * Fix cringe arrows * frontend will no longer namespace dynamic inputs widgets so reflect that in code, refactor build_nested_inputs * Prepare Autogrow support for the love of the game * satisfy ruff * Create test nodes for Autogrow to collab with frontend development * Add nested combo to DCTestNode * Remove array support from build_nested_inputs, properly handle missing expected values * Make execution.validate_inputs properly validate required dynamic inputs, renamed dynamic_data to dynamic_paths for clarity * MatchType 
does not need any DynamicInput/Output features on backend; will increase compatibility with dynamic types * Probably need this for ruff check * Change MatchType to have template be the first and only required param; output id's do nothing right now, so no need * Fix merge regression with LatentUpscaleModel type not being put in __all__ for _io.py, fix invalid type hint for validate_inputs * Make Switch node inputs optional, disallow both inputs from being missing, and still work properly with lazy; when one input is missing, use the other no matter what the switch is set to * Satisfy ruff * Move MatchType code above the types that inherit from DynamicInput * Add DynamicSlot type, awaiting frontend support * Make curr_prefix creation happen in Autogrow, move curr_prefix in DynamicCombo to only be created if input exists in live_inputs * I was confused, fixing accidentally redundant curr_prefix addition in Autogrow * Make sure Autogrow inputs are force_input = True when WidgetInput, fix runtime validation by removing original input from expected inputs, fix min/max bounds, change test nodes slightly * Remove unnecessary id usage in Autogrow test node outputs * Commented out Switch node + test nodes * Remove commented out code from Autogrow * Make TemplatePrefix max more clear, allow max == 1 * Replace all dict[str] with dict[str, Any] * Renamed add_to_dict_live_inputs to expand_schema_for_dynamic * Fixed typo in DynamicSlot input code * note about live_inputs not being present soon in get_v1_info (internal function anyway) * For now, hide DynamicCombo and Autogrow from public interface * Removed comment --- comfy_api/latest/__init__.py | 4 +- comfy_api/latest/_io.py | 416 ++++++++++++++++++++++++++------- comfy_api/latest/_io_public.py | 1 + comfy_api/latest/_ui_public.py | 1 + comfy_api/v0_0_2/__init__.py | 6 +- comfy_execution/validation.py | 6 + comfy_extras/nodes_logic.py | 155 ++++++++++++ execution.py | 40 ++-- nodes.py | 1 + 9 files changed, 525 insertions(+), 
105 deletions(-) create mode 100644 comfy_api/latest/_io_public.py create mode 100644 comfy_api/latest/_ui_public.py create mode 100644 comfy_extras/nodes_logic.py diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index 176ae36e0..0fa01d1e7 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -8,8 +8,8 @@ from comfy_api.internal.async_to_sync import create_sync_class from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL -from . import _io as io -from . import _ui as ui +from . import _io_public as io +from . import _ui_public as ui # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 79c0722a9..257f07c42 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -4,6 +4,7 @@ import copy import inspect from abc import ABC, abstractmethod from collections import Counter +from collections.abc import Iterable from dataclasses import asdict, dataclass from enum import Enum from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING @@ -150,6 +151,9 @@ class _IO_V3: def __init__(self): pass + def validate(self): + pass + @property def io_type(self): return self.Parent.io_type @@ -182,6 +186,9 @@ class Input(_IO_V3): def get_io_type(self): return _StringIOType(self.io_type) + def get_all(self) -> list[Input]: + return [self] + class WidgetInput(Input): ''' Base class for a V3 Input with widget. 
@@ -814,13 +821,61 @@ class MultiType: else: return super().as_dict() +@comfytype(io_type="COMFY_MATCHTYPE_V3") +class MatchType(ComfyTypeIO): + class Template: + def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType): + self.template_id = template_id + # account for syntactic sugar + if not isinstance(allowed_types, Iterable): + allowed_types = [allowed_types] + for t in allowed_types: + if not isinstance(t, type): + if not isinstance(t, _ComfyType): + raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}") + else: + if not issubclass(t, _ComfyType): + raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}") + self.allowed_types = allowed_types + + def as_dict(self): + return { + "template_id": self.template_id, + "allowed_types": ",".join([t.io_type for t in self.allowed_types]), + } + + class Input(Input): + def __init__(self, id: str, template: MatchType.Template, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + + class Output(Output): + def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, + is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + class DynamicInput(Input, ABC): ''' Abstract class for dynamic input registration. ''' - @abstractmethod def get_dynamic(self) -> list[Input]: - ... 
+ return [] + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + pass + class DynamicOutput(Output, ABC): ''' @@ -830,99 +885,223 @@ class DynamicOutput(Output, ABC): is_output_list=False): super().__init__(id, display_name, tooltip, is_output_list) - @abstractmethod def get_dynamic(self) -> list[Output]: - ... + return [] @comfytype(io_type="COMFY_AUTOGROW_V3") -class AutogrowDynamic(ComfyTypeI): - Type = list[Any] - class Input(DynamicInput): - def __init__(self, id: str, template_input: Input, min: int=1, max: int=None, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) - self.template_input = template_input - if min is not None: - assert(min >= 1) - if max is not None: - assert(max >= 1) +class Autogrow(ComfyTypeI): + Type = dict[str, Any] + _MaxNames = 100 # NOTE: max 100 names for sanity + + class _AutogrowTemplate: + def __init__(self, input: Input): + # dynamic inputs are not allowed as the template input + assert(not isinstance(input, DynamicInput)) + self.input = copy.copy(input) + if isinstance(self.input, WidgetInput): + self.input.force_input = True + self.names: list[str] = [] + self.cached_inputs = {} + + def _create_input(self, input: Input, name: str): + new_input = copy.copy(self.input) + new_input.id = name + return new_input + + def _create_cached_inputs(self): + for name in self.names: + self.cached_inputs[name] = self._create_input(self.input, name) + + def get_all(self) -> list[Input]: + return list(self.cached_inputs.values()) + + def as_dict(self): + return prune_dict({ + "input": create_input_dict_v1([self.input]), + }) + + def validate(self): + self.input.validate() + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + real_inputs = [] + for name, input in self.cached_inputs.items(): + if name in live_inputs: + 
real_inputs.append(input) + add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, real_inputs, curr_prefix) + + class TemplatePrefix(_AutogrowTemplate): + def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): + super().__init__(input) + self.prefix = prefix + assert(min >= 0) + assert(max >= 1) + assert(max <= Autogrow._MaxNames) self.min = min self.max = max + self.names = [f"{self.prefix}{i}" for i in range(self.max)] + self._create_cached_inputs() + + def as_dict(self): + return super().as_dict() | prune_dict({ + "prefix": self.prefix, + "min": self.min, + "max": self.max, + }) + + class TemplateNames(_AutogrowTemplate): + def __init__(self, input: Input, names: list[str], min: int=1): + super().__init__(input) + self.names = names[:Autogrow._MaxNames] + assert(min >= 0) + self.min = min + self._create_cached_inputs() + + def as_dict(self): + return super().as_dict() | prune_dict({ + "names": self.names, + "min": self.min, + }) + + class Input(DynamicInput): + def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) def get_dynamic(self) -> list[Input]: - curr_count = 1 - new_inputs = [] - for i in range(self.min): - new_input = copy.copy(self.template_input) - new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" - if new_input.display_name is not None: - new_input.display_name = f"{new_input.display_name}{curr_count}" - new_input.optional = self.optional or new_input.optional - if isinstance(self.template_input, WidgetInput): - new_input.force_input = True - new_inputs.append(new_input) - curr_count += 1 - # pretend to expand up to max - for i in range(curr_count-1, 
self.max): - new_input = copy.copy(self.template_input) - new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" - if new_input.display_name is not None: - new_input.display_name = f"{new_input.display_name}{curr_count}" - new_input.optional = True - if isinstance(self.template_input, WidgetInput): - new_input.force_input = True - new_inputs.append(new_input) - curr_count += 1 - return new_inputs + return self.template.get_all() -@comfytype(io_type="COMFY_COMBODYNAMIC_V3") -class ComboDynamic(ComfyTypeI): - class Input(DynamicInput): - def __init__(self, id: str): - pass + def get_all(self) -> list[Input]: + return [self] + self.template.get_all() -@comfytype(io_type="COMFY_MATCHTYPE_V3") -class MatchType(ComfyTypeIO): - class Template: - def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]): - self.template_id = template_id - self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types + def validate(self): + self.template.validate() + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + curr_prefix = f"{curr_prefix}{self.id}." 
+ # need to remove self from expected inputs dictionary; replaced by template inputs in frontend + for inner_dict in d.values(): + if self.id in inner_dict: + del inner_dict[self.id] + self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + +@comfytype(io_type="COMFY_DYNAMICCOMBO_V3") +class DynamicCombo(ComfyTypeI): + Type = dict[str, Any] + + class Option: + def __init__(self, key: str, inputs: list[Input]): + self.key = key + self.inputs = inputs def as_dict(self): return { - "template_id": self.template_id, - "allowed_types": "".join(t.io_type for t in self.allowed_types), + "key": self.key, + "inputs": create_input_dict_v1(self.inputs), } class Input(DynamicInput): - def __init__(self, id: str, template: MatchType.Template, + def __init__(self, id: str, options: list[DynamicCombo.Option], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) - self.template = template + self.options = options + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + # check if dynamic input's id is in live_inputs + if self.id in live_inputs: + curr_prefix = f"{curr_prefix}{self.id}." 
+ key = live_inputs[self.id] + selected_option = None + for option in self.options: + if option.key == key: + selected_option = option + break + if selected_option is not None: + add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self) def get_dynamic(self) -> list[Input]: - return [self] + return [input for option in self.options for input in option.inputs] + + def get_all(self) -> list[Input]: + return [self] + [input for option in self.options for input in option.inputs] def as_dict(self): return super().as_dict() | prune_dict({ - "template": self.template.as_dict(), + "options": [o.as_dict() for o in self.options], }) - class Output(DynamicOutput): - def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None, - is_output_list=False): - super().__init__(id, display_name, tooltip, is_output_list) - self.template = template + def validate(self): + # make sure all nested inputs are validated + for option in self.options: + for input in option.inputs: + input.validate() - def get_dynamic(self) -> list[Output]: - return [self] +@comfytype(io_type="COMFY_DYNAMICSLOT_V3") +class DynamicSlot(ComfyTypeI): + Type = dict[str, Any] + + class Input(DynamicInput): + def __init__(self, slot: Input, inputs: list[Input], + display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None): + assert(not isinstance(slot, DynamicInput)) + self.slot = copy.copy(slot) + self.slot.display_name = slot.display_name if slot.display_name is not None else display_name + optional = True + self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip + self.slot.lazy = slot.lazy if slot.lazy is not None else lazy + self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict + super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict) + self.inputs = inputs + 
self.force_input = None + # force widget inputs to have no widgets, otherwise this would be awkward + if isinstance(self.slot, WidgetInput): + self.force_input = True + self.slot.force_input = True + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + if self.id in live_inputs: + curr_prefix = f"{curr_prefix}{self.id}." + add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix) + + def get_dynamic(self) -> list[Input]: + return [self.slot] + self.inputs + + def get_all(self) -> list[Input]: + return [self] + [self.slot] + self.inputs def as_dict(self): return super().as_dict() | prune_dict({ - "template": self.template.as_dict(), + "slotType": str(self.slot.get_io_type()), + "inputs": create_input_dict_v1(self.inputs), + "forceInput": self.force_input, }) + def validate(self): + self.slot.validate() + for input in self.inputs: + input.validate() + +def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): + dynamic = d.setdefault("dynamic_paths", {}) + if self is not None: + dynamic[self.id] = f"{curr_prefix}{self.id}" + for i in inputs: + if not isinstance(i, DynamicInput): + dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" + +class V3Data(TypedDict): + hidden_inputs: dict[str, Any] + dynamic_paths: dict[str, Any] class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, @@ -984,6 +1163,7 @@ class NodeInfoV1: output_is_list: list[bool]=None output_name: list[str]=None output_tooltips: list[str]=None + output_matchtypes: list[str]=None name: str=None display_name: str=None description: str=None @@ -1061,7 +1241,11 @@ class Schema: '''Validate the schema: - verify ids on inputs and outputs are unique - both internally and in relation to each other ''' - input_ids = [i.id for i in self.inputs] if self.inputs is not None else [] + nested_inputs: list[Input] = [] + if self.inputs is not 
None: + for input in self.inputs: + nested_inputs.extend(input.get_all()) + input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] input_set = set(input_ids) output_set = set(output_ids) @@ -1077,6 +1261,13 @@ class Schema: issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") if len(issues) > 0: raise ValueError("\n".join(issues)) + # validate inputs and outputs + if self.inputs is not None: + for input in self.inputs: + input.validate() + if self.outputs is not None: + for output in self.outputs: + output.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" @@ -1102,19 +1293,10 @@ class Schema: if output.id is None: output.id = f"_{i}_{output.io_type}_" - def get_v1_info(self, cls) -> NodeInfoV1: + def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: + # NOTE: live_inputs will not be used anymore very soon and this will be done another way # get V1 inputs - input = { - "required": {} - } - if self.inputs: - for i in self.inputs: - if isinstance(i, DynamicInput): - dynamic_inputs = i.get_dynamic() - for d in dynamic_inputs: - add_to_dict_v1(d, input) - else: - add_to_dict_v1(i, input) + input = create_input_dict_v1(self.inputs, live_inputs) if self.hidden: for hidden in self.hidden: input.setdefault("hidden", {})[hidden.name] = (hidden.value,) @@ -1123,12 +1305,24 @@ class Schema: output_is_list = [] output_name = [] output_tooltips = [] + output_matchtypes = [] + any_matchtypes = False if self.outputs: for o in self.outputs: output.append(o.io_type) output_is_list.append(o.is_output_list) output_name.append(o.display_name if o.display_name else o.io_type) output_tooltips.append(o.tooltip if o.tooltip else None) + # special handling for MatchType + if isinstance(o, MatchType.Output): + output_matchtypes.append(o.template.template_id) 
+ any_matchtypes = True + else: + output_matchtypes.append(None) + + # clear out lists that are all None + if not any_matchtypes: + output_matchtypes = None info = NodeInfoV1( input=input, @@ -1137,6 +1331,7 @@ class Schema: output_is_list=output_is_list, output_name=output_name, output_tooltips=output_tooltips, + output_matchtypes=output_matchtypes, name=self.node_id, display_name=self.display_name, category=self.category, @@ -1182,16 +1377,57 @@ class Schema: return info -def add_to_dict_v1(i: Input, input: dict): +def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: + input = { + "required": {} + } + add_to_input_dict_v1(input, inputs, live_inputs) + return input + +def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): + for i in inputs: + if isinstance(i, DynamicInput): + add_to_dict_v1(i, d) + if live_inputs is not None: + i.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + else: + add_to_dict_v1(i, d) + +def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None): key = "optional" if i.optional else "required" as_dict = i.as_dict() # for v1, we don't want to include the optional key as_dict.pop("optional", None) - input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) + if dynamic_dict is None: + value = (i.get_io_type(), as_dict) + else: + value = (i.get_io_type(), as_dict, dynamic_dict) + d.setdefault(key, {})[i.id] = value def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) +def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): + paths = v3_data.get("dynamic_paths", None) + if paths is None: + return values + values = values.copy() + result = {} + + for key, path in paths.items(): + parts = path.split(".") + current = result + + for i, p in enumerate(parts): + is_last = (i == len(parts) - 1) + + if is_last: + current[p] = values.pop(key, None) + else: + current = current.setdefault(p, {}) + + 
values.update(result) + return values class _ComfyNodeBaseInternal(_ComfyNodeInternal): @@ -1311,12 +1547,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]: + def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]: """Creates clone of real node class to prevent monkey-patching.""" c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(hidden_inputs) + type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"]) return type_clone @final @@ -1433,14 +1669,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]: + def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: schema = cls.FINALIZE_SCHEMA() - info = schema.get_v1_info(cls) + info = schema.get_v1_info(cls, live_inputs) input = info.input if not include_hidden: input.pop("hidden", None) if return_schema: - return input, schema + v3_data: V3Data = {} + dynamic = input.pop("dynamic_paths", None) + if dynamic is not None: + v3_data["dynamic_paths"] = dynamic + return input, schema, v3_data return input @final @@ -1513,7 +1753,7 @@ class ComfyNode(_ComfyNodeBaseInternal): raise NotImplementedError @classmethod - def validate_inputs(cls, **kwargs) -> bool: + def validate_inputs(cls, **kwargs) -> bool | str: """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" raise NotImplementedError @@ -1628,6 +1868,7 @@ __all__ = [ "StyleModel", "Gligen", "UpscaleModel", + "LatentUpscaleModel", "Audio", "Video", "SVG", @@ -1651,6 +1892,10 @@ __all__ = [ "SEGS", "AnyType", "MultiType", + # Dynamic Types + "MatchType", + # "DynamicCombo", 
+ # "Autogrow", # Other classes "HiddenHolder", "Hidden", @@ -1661,4 +1906,5 @@ __all__ = [ "NodeOutput", "add_to_dict_v1", "add_to_dict_v3", + "V3Data", ] diff --git a/comfy_api/latest/_io_public.py b/comfy_api/latest/_io_public.py new file mode 100644 index 000000000..43c7680f3 --- /dev/null +++ b/comfy_api/latest/_io_public.py @@ -0,0 +1 @@ +from ._io import * # noqa: F403 diff --git a/comfy_api/latest/_ui_public.py b/comfy_api/latest/_ui_public.py new file mode 100644 index 000000000..85b11d78b --- /dev/null +++ b/comfy_api/latest/_ui_public.py @@ -0,0 +1 @@ +from ._ui import * # noqa: F403 diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py index de0f95001..c4fa1d971 100644 --- a/comfy_api/v0_0_2/__init__.py +++ b/comfy_api/v0_0_2/__init__.py @@ -6,7 +6,7 @@ from comfy_api.latest import ( ) from typing import Type, TYPE_CHECKING from comfy_api.internal.async_to_sync import create_sync_class -from comfy_api.latest import io, ui, ComfyExtension #noqa: F401 +from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401 class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest): @@ -42,4 +42,8 @@ __all__ = [ "InputImpl", "Types", "ComfyExtension", + "io", + "IO", + "ui", + "UI", ] diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py index cec105fc9..24c0b4ed7 100644 --- a/comfy_execution/validation.py +++ b/comfy_execution/validation.py @@ -1,4 +1,5 @@ from __future__ import annotations +from comfy_api.latest import IO def validate_node_input( @@ -23,6 +24,11 @@ def validate_node_input( if not received_type != input_type: return True + # If the received type or input_type is a MatchType, we can return True immediately; + # validation for this is handled by the frontend + if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: + return True + # Not equal, and not strings if not isinstance(received_type, str) or not isinstance(input_type, str): return False diff --git a/comfy_extras/nodes_logic.py 
b/comfy_extras/nodes_logic.py new file mode 100644 index 000000000..95a6ba788 --- /dev/null +++ b/comfy_extras/nodes_logic.py @@ -0,0 +1,155 @@ +from typing import TypedDict +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_api.latest import _io + + + +class SwitchNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = io.MatchType.Template("switch") + return io.Schema( + node_id="ComfySwitchNode", + display_name="Switch", + category="logic", + is_experimental=True, + inputs=[ + io.Boolean.Input("switch"), + io.MatchType.Input("on_false", template=template, lazy=True, optional=True), + io.MatchType.Input("on_true", template=template, lazy=True, optional=True), + ], + outputs=[ + io.MatchType.Output(template=template, display_name="output"), + ], + ) + + @classmethod + def check_lazy_status(cls, switch, on_false=..., on_true=...): + # We use ... instead of None, as None is passed for connected-but-unevaluated inputs. + # This trick allows us to ignore the value of the switch and still be able to run execute(). + + # One of the inputs may be missing, in which case we need to evaluate the other input + if on_false is ...: + return ["on_true"] + if on_true is ...: + return ["on_false"] + # Normal lazy switch operation + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + @classmethod + def validate_inputs(cls, switch, on_false=..., on_true=...): + # This check happens before check_lazy_status(), so we can eliminate the case where + # both inputs are missing. + if on_false is ... and on_true is ...: + return "At least one of on_false or on_true must be connected to Switch node" + return True + + @classmethod + def execute(cls, switch, on_true=..., on_false=...) 
-> io.NodeOutput: + if on_true is ...: + return io.NodeOutput(on_false) + if on_false is ...: + return io.NodeOutput(on_true) + return io.NodeOutput(on_true if switch else on_false) + + +class DCTestNode(io.ComfyNode): + class DCValues(TypedDict): + combo: str + string: str + integer: int + image: io.Image.Type + subcombo: dict[str] + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DCTestNode", + display_name="DCTest", + category="logic", + is_output_node=True, + inputs=[_io.DynamicCombo.Input("combo", options=[ + _io.DynamicCombo.Option("option1", [io.String.Input("string")]), + _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), + _io.DynamicCombo.Option("option3", [io.Image.Input("image")]), + _io.DynamicCombo.Option("option4", [ + _io.DynamicCombo.Input("subcombo", options=[ + _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), + _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), + ]) + ])] + )], + outputs=[io.AnyType.Output()], + ) + + @classmethod + def execute(cls, combo: DCValues) -> io.NodeOutput: + combo_val = combo["combo"] + if combo_val == "option1": + return io.NodeOutput(combo["string"]) + elif combo_val == "option2": + return io.NodeOutput(combo["integer"]) + elif combo_val == "option3": + return io.NodeOutput(combo["image"]) + elif combo_val == "option4": + return io.NodeOutput(f"{combo['subcombo']}") + else: + raise ValueError(f"Invalid combo: {combo_val}") + + +class AutogrowNamesTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"]) + return io.Schema( + node_id="AutogrowNamesTestNode", + display_name="AutogrowNamesTest", + category="logic", + inputs=[ + _io.Autogrow.Input("autogrow", template=template) + ], + outputs=[io.String.Output()], + ) + + @classmethod + def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput: + vals = 
list(autogrow.values()) + combined = ",".join([str(x) for x in vals]) + return io.NodeOutput(combined) + +class AutogrowPrefixTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10) + return io.Schema( + node_id="AutogrowPrefixTestNode", + display_name="AutogrowPrefixTest", + category="logic", + inputs=[ + _io.Autogrow.Input("autogrow", template=template) + ], + outputs=[io.String.Output()], + ) + + @classmethod + def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput: + vals = list(autogrow.values()) + combined = ",".join([str(x) for x in vals]) + return io.NodeOutput(combined) + +class LogicExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # SwitchNode, + # DCTestNode, + # AutogrowNamesTestNode, + # AutogrowPrefixTestNode, + ] + +async def comfy_entrypoint() -> LogicExtension: + return LogicExtension() diff --git a/execution.py b/execution.py index 17c77beab..c2186ac98 100644 --- a/execution.py +++ b/execution.py @@ -34,7 +34,7 @@ from comfy_execution.validation import validate_node_input from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler from comfy_execution.utils import CurrentNodeContext from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func -from comfy_api.latest import io +from comfy_api.latest import io, _io class ExecutionResult(Enum): @@ -76,7 +76,7 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. 
We only want constants in IS_CHANGED - input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) try: is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) is_changed = await resolve_map_node_over_list_results(is_changed) @@ -146,8 +146,9 @@ SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data={}): is_v3 = issubclass(class_def, _ComfyNodeInternal) + v3_data: io.V3Data = {} if is_v3: - valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) + valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) else: valid_inputs = class_def.INPUT_TYPES() input_data_all = {} @@ -207,7 +208,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt= input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] if h[x] == "API_KEY_COMFY_ORG": input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] - return input_data_all, missing_keys, hidden_inputs_v3 + v3_data["hidden_inputs"] = hidden_inputs_v3 + return input_data_all, missing_keys, v3_data map_node_over_list = None #Don't hook this please @@ -223,7 +225,7 @@ async def resolve_map_node_over_list_results(results): raise exc return [x.result() if isinstance(x, asyncio.Task) else x for x in results] -async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): +async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None): # check if node wants the lists input_is_list = getattr(obj, 
"INPUT_IS_LIST", False) @@ -259,13 +261,16 @@ async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, f if is_class(obj): type_obj = obj obj.VALIDATE_CLASS() - class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs) + class_clone = obj.PREPARE_CLASS_CLONE(v3_data) # otherwise, use class instance to populate/reuse some fields else: type_obj = type(obj) type_obj.VALIDATE_CLASS() - class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs) + class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) f = make_locked_method_func(type_obj, func, class_clone) + # in case of dynamic inputs, restructure inputs to expected nested dict + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) # V1 else: f = getattr(obj, func) @@ -320,8 +325,8 @@ def merge_result_data(results, obj): output.append([o[i] for o in results]) return output -async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): - return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) +async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None): + return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) if has_pending_task: return return_values, {}, False, has_pending_task @@ -460,7 +465,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, 
unique_id, execution_list, dynprompt, extra_data) + input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) @@ -475,7 +480,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None if lazy_status_present: - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs) + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, v3_data=v3_data) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r,list)], [])) required_inputs = [x for x in required_inputs if isinstance(x,str) and ( @@ -507,7 +512,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed, def pre_execute_cb(call_index): # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? 
GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs) + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -745,18 +750,17 @@ async def validate_inputs(prompt_id, prompt, item, validated): class_type = prompt[unique_id]['class_type'] obj_class = nodes.NODE_CLASS_MAPPINGS[class_type] - class_inputs = obj_class.INPUT_TYPES() - valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) - errors = [] valid = True validate_function_inputs = [] validate_has_kwargs = False if issubclass(obj_class, _ComfyNodeInternal): + class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) validate_function_name = "validate_inputs" validate_function = first_real_override(obj_class, validate_function_name) else: + class_inputs = obj_class.INPUT_TYPES() validate_function_name = "VALIDATE_INPUTS" validate_function = getattr(obj_class, validate_function_name, None) if validate_function is not None: @@ -765,6 +769,8 @@ async def validate_inputs(prompt_id, prompt, item, validated): validate_has_kwargs = argspec.varkw is not None received_types = {} + valid_inputs = set(class_inputs.get('required',{})).union(set(class_inputs.get('optional',{}))) + for x in valid_inputs: input_type, input_category, extra_info = get_input_info(obj_class, x, class_inputs) assert extra_info is not None @@ -935,7 +941,7 @@ async def validate_inputs(prompt_id, prompt, item, validated): continue if len(validate_function_inputs) > 
0 or validate_has_kwargs: - input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id) + input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: @@ -943,7 +949,7 @@ async def validate_inputs(prompt_id, prompt, item, validated): if 'input_types' in validate_function_inputs: input_filtered['input_types'] = [received_types] - ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs) + ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data) ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: for i, r in enumerate(ret): diff --git a/nodes.py b/nodes.py index 4c910a34b..356aa63df 100644 --- a/nodes.py +++ b/nodes.py @@ -2355,6 +2355,7 @@ async def init_builtin_extra_nodes(): "nodes_easycache.py", "nodes_audio_encoder.py", "nodes_rope.py", + "nodes_logic.py", "nodes_nop.py", ] From 861817d22d2659099811b56005c9eaea18d64c73 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Tue, 2 Dec 2025 21:47:51 -0800 Subject: [PATCH 5/6] Fix issue with portable updater. (#11070) This should fix the problem with the portable updater not working with portables created from a separate branch on the repo. This does not affect any current portables who were all created on the master branch. 
--- .ci/update_windows/update.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index 51a263203..59ece5130 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -66,8 +66,10 @@ if branch is None: try: ref = repo.lookup_reference('refs/remotes/origin/master') except: - print("pulling.") # noqa: T201 - pull(repo) + print("fetching.") # noqa: T201 + for remote in repo.remotes: + if remote.name == "origin": + remote.fetch() ref = repo.lookup_reference('refs/remotes/origin/master') repo.checkout(ref) branch = repo.lookup_branch('master') @@ -149,3 +151,4 @@ try: shutil.copy(stable_update_script, stable_update_script_to) except: pass + From 519c9411653df99761053c30e101816e0ca3c24b Mon Sep 17 00:00:00 2001 From: rattus <46076784+rattus128@users.noreply.github.com> Date: Wed, 3 Dec 2025 17:28:45 +1000 Subject: [PATCH 6/6] Prs/lora reservations (reduce massive Lora reservations especially on Flux2) (#11069) * mp: only count the offload cost of math once This was previously bundling the combined weight storage and computation cost * ops: put all post async transfer compute on the main stream Some models have massive weights that need either complex dequantization or lora patching. Don't do these patchings on the offload stream, instead do them on the main stream to synchronize the potentially large vram spikes for these compute processes. This avoids having to assume a worst case scenario of multiple offload streams all spiking VRAM in parallel with whatever the main stream is doing.
--- comfy/model_patcher.py | 4 ++-- comfy/ops.py | 39 ++++++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 19 deletions(-) diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3eac77275..df2d8e827 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -704,7 +704,7 @@ class ModelPatcher: lowvram_weight = False - potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1)) + potential_offload = max(offload_buffer, module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem)) lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory weight_key = "{}.weight".format(n) @@ -883,7 +883,7 @@ class ModelPatcher: break module_offload_mem, module_mem, n, m, params = unload - potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem + potential_offload = module_offload_mem + (comfy.model_management.NUM_STREAMS * module_mem) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: diff --git a/comfy/ops.py b/comfy/ops.py index 61a2f0754..eae434e68 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -111,22 +111,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if s.bias is not None: bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) - if bias_has_function: - with wf_context: - for f in s.bias_function: - bias = f(bias) + comfy.model_management.sync_stream(device, offload_stream) + + bias_a = bias + weight_a = weight + + if s.bias is not None: + for f in s.bias_function: + bias = f(bias) if weight_has_function or weight.dtype != dtype: - with wf_context: - weight = weight.to(dtype=dtype) - if isinstance(weight, QuantizedTensor): - weight = weight.dequantize() - for f in s.weight_function: - weight = f(weight) + weight = 
weight.to(dtype=dtype) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + for f in s.weight_function: + weight = f(weight) - comfy.model_management.sync_stream(device, offload_stream) if offloadable: - return weight, bias, offload_stream + return weight, bias, (offload_stream, weight_a, bias_a) else: #Legacy function signature return weight, bias @@ -135,13 +137,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of def uncast_bias_weight(s, weight, bias, offload_stream): if offload_stream is None: return - if weight is not None: - device = weight.device + os, weight_a, bias_a = offload_stream + if os is None: + return + if weight_a is not None: + device = weight_a.device else: - if bias is None: + if bias_a is None: return - device = bias.device - offload_stream.wait_stream(comfy.model_management.current_stream(device)) + device = bias_a.device + os.wait_stream(comfy.model_management.current_stream(device)) class CastWeightBiasOp: