From 092ee8a5008c8d558b0a72cc7961a31d9cc5400b Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Dec 2025 15:25:31 -0800 Subject: [PATCH 1/9] Fix some custom nodes. (#11134) --- comfy/supported_models_base.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 9fd84d329..0e7a829ba 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -17,6 +17,7 @@ """ import torch +import logging from . import model_base from . import utils from . import latent_formats @@ -117,3 +118,7 @@ class BASE: def set_inference_dtype(self, dtype, manual_cast_dtype): self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype + + def __getattr__(self, name): + logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name)) + return None From bed12674a1d2c4bfdfbdd098686390f807996c90 Mon Sep 17 00:00:00 2001 From: "Dr.Lt.Data" <128333288+ltdrdata@users.noreply.github.com> Date: Sat, 6 Dec 2025 08:45:38 +0900 Subject: [PATCH 2/9] docs: add ComfyUI-Manager documentation and update to v4.0.3b4 (#11133) - Add manager setup instructions and command line options to README - Document --enable-manager, --enable-manager-legacy-ui, and --disable-manager-ui flags - Bump comfyui_manager version from 4.0.3b3 to 4.0.3b4 --- README.md | 26 ++++++++++++++++++++++++++ manager_requirements.txt | 2 +- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ed857df9f..bae955b1b 100644 --- a/README.md +++ b/README.md @@ -320,6 +320,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step 1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536) 2. Launch ComfyUI by running `python main.py` + +## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4) + +**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI. + +### Setup + +1. Install the manager dependencies: + ```bash + pip install -r manager_requirements.txt + ``` + +2. Enable the manager with the `--enable-manager` flag when running ComfyUI: + ```bash + python main.py --enable-manager + ``` + +### Command Line Options + +| Flag | Description | +|------|-------------| +| `--enable-manager` | Enable ComfyUI-Manager | +| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) | +| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) | + + # Running ```python main.py``` diff --git a/manager_requirements.txt b/manager_requirements.txt index 52cc5389c..b95cefb74 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.0.3b3 +comfyui_manager==4.0.3b4 From fd109325db7126f92c2dfb7e6b25310eded8c1f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 6 Dec 2025 05:20:22 +0200 Subject: [PATCH 3/9] Kandinsky5 model support (#10988) * Add Kandinsky5 model support lite and pro T2V tested to work * Update kandinsky5.py * Fix fp8 * Fix fp8_scaled text encoder * Add transformer_options for attention * Code cleanup, optimizations, use fp32 for all layers originally at fp32 * ImageToVideo -node * Fix I2V, add necessary latent post process nodes * Support text to image model * Support block replace patches (SLG mostly) * Support official LoRAs * Don't scale RoPE for lite model as that just doesn't work... * Update supported_models.py * Rever RoPE scaling to simpler one * Fix typo * Handle latent dim difference for image model in the VAE instead * Add node to use different prompts for clip_l and qwen25_7b * Reduce peak VRAM usage a bit * Further reduce peak VRAM consumption by chunking ffn * Update chunking * Update memory_usage_factor * Code cleanup, don't force the fp32 layers as it has minimal effect * Allow for stronger changes with first frames normalization Default values are too weak for any meaningful changes, these should probably be exposed as advanced node options when that's available. * Add image model's own chat template, remove unused image2video template * Remove hard error in ReplaceVideoLatentFrames -node * Update kandinsky5.py * Update supported_models.py * Fix typos in prompt template They were now fixed in the original repository as well * Update ReplaceVideoLatentFrames Add tooltips Make source optional Better handle negative index * Rename NormalizeVideoLatentFrames -node For bit better clarity what it does * Fix NormalizeVideoLatentStart node out on non-op --- comfy/ldm/kandinsky5/model.py | 407 ++++++++++++++++++++++++++++++ comfy/lora.py | 7 + comfy/model_base.py | 47 ++++ comfy/model_detection.py | 18 ++ comfy/sd.py | 11 + comfy/supported_models.py | 56 +++- comfy/text_encoders/kandinsky5.py | 68 +++++ comfy_api/latest/_io.py | 2 + comfy_extras/nodes_kandinsky5.py | 136 ++++++++++ comfy_extras/nodes_latent.py | 39 ++- nodes.py | 3 +- 11 files changed, 791 insertions(+), 3 deletions(-) create mode 100644 comfy/ldm/kandinsky5/model.py create mode 100644 comfy/text_encoders/kandinsky5.py create mode 100644 comfy_extras/nodes_kandinsky5.py diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py new file mode 100644 index 000000000..a653e02fc --- /dev/null +++ b/comfy/ldm/kandinsky5/model.py @@ -0,0 +1,407 @@ +import torch +from torch import nn +import math + +import comfy.ldm.common_dit +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.flux.layers import EmbedND + +def attention(q, k, v, heads, transformer_options={}): + return optimized_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + heads=heads, + skip_reshape=True, + transformer_options=transformer_options + ) + +def apply_scale_shift_norm(norm, x, scale, shift): + return torch.addcmul(shift, norm(x), scale + 1.0) + +def apply_gate_sum(x, out, gate): + return torch.addcmul(x, gate, out) + +def get_shift_scale_gate(params): + shift, scale, gate = torch.chunk(params, 3, dim=-1) + return tuple(x.unsqueeze(1) for x in (shift, scale, gate)) + +def get_freqs(dim, max_period=10000.0): + return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim) + + +class TimeEmbeddings(nn.Module): + def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None): + super().__init__() + assert model_dim % 2 == 0 + self.model_dim = model_dim + self.max_period = max_period + self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False) + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.activation = nn.SiLU() + self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, timestep, dtype): + args = torch.outer(timestep, self.freqs.to(device=timestep.device)) + time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype) + time_embed = self.out_layer(self.activation(self.in_layer(time_embed))) + return time_embed + + +class TextEmbeddings(nn.Module): + def __init__(self, text_dim, model_dim, operation_settings=None): + super().__init__() + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, text_embed): + text_embed = self.in_layer(text_embed) + return self.norm(text_embed).type_as(text_embed) + + +class VisualEmbeddings(nn.Module): + def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None): + super().__init__() + self.patch_size = patch_size + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + x = x.movedim(1, -1) # B C T H W -> B T H W C + B, T, H, W, dim = x.shape + pt, ph, pw = self.patch_size + + x = x.view( + B, + T // pt, pt, + H // ph, ph, + W // pw, pw, + dim, + ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7) + + return self.in_layer(x) + + +class Modulation(nn.Module): + def __init__(self, time_dim, model_dim, num_params, operation_settings=None): + super().__init__() + self.activation = nn.SiLU() + self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + return self.out_layer(self.activation(x)) + + +class SelfAttention(nn.Module): + def __init__(self, num_channels, head_dim, operation_settings=None): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + self.head_dim = head_dim + + operations = operation_settings.get("operations") + self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.num_chunks = 2 + + def _compute_qk(self, x, freqs, proj_fn, norm_fn): + result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1) + return apply_rope1(norm_fn(result), freqs) + + def _forward(self, x, freqs, transformer_options={}): + q = self._compute_qk(x, freqs, self.to_query, self.query_norm) + k = self._compute_qk(x, freqs, self.to_key, self.key_norm) + v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + def _forward_chunked(self, x, freqs, transformer_options={}): + def process_chunks(proj_fn, norm_fn): + x_chunks = torch.chunk(x, self.num_chunks, dim=1) + freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1) + chunks = [] + for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks): + chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn)) + return torch.cat(chunks, dim=1) + + q = process_chunks(self.to_query, self.query_norm) + k = process_chunks(self.to_key, self.key_norm) + v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + def forward(self, x, freqs, transformer_options={}): + if x.shape[1] > 8192: + return self._forward_chunked(x, freqs, transformer_options=transformer_options) + else: + return self._forward(x, freqs, transformer_options=transformer_options) + + +class CrossAttention(SelfAttention): + def get_qkv(self, x, context): + q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1) + k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1) + v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1) + return q, k, v + + def forward(self, x, context, transformer_options={}): + q, k, v = self.get_qkv(x, context) + out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + +class FeedForward(nn.Module): + def __init__(self, dim, ff_dim, operation_settings=None): + super().__init__() + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.activation = nn.GELU() + self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.num_chunks = 4 + + def _forward(self, x): + return self.out_layer(self.activation(self.in_layer(x))) + + def _forward_chunked(self, x): + chunks = torch.chunk(x, self.num_chunks, dim=1) + output_chunks = [] + for chunk in chunks: + output_chunks.append(self._forward(chunk)) + return torch.cat(output_chunks, dim=1) + + def forward(self, x): + if x.shape[1] > 8192: + return self._forward_chunked(x) + else: + return self._forward(x) + + +class OutLayer(nn.Module): + def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None): + super().__init__() + self.patch_size = patch_size + self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings) + operations = operation_settings.get("operations") + self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, visual_embed, time_embed): + B, T, H, W, _ = visual_embed.shape + shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1) + scale = scale[:, None, None, None, :] + shift = shift[:, None, None, None, :] + visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift) + x = self.out_layer(visual_embed) + + out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2]) + x = x.view( + B, T, H, W, + out_dim, + self.patch_size[0], self.patch_size[1], self.patch_size[2] + ) + return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5) + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None): + super().__init__() + self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings) + operations = operation_settings.get("operations") + + self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) + + def forward(self, x, time_embed, freqs, transformer_options={}): + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1) + shift, scale, gate = get_shift_scale_gate(self_attn_params) + out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) + out = self.self_attention(out, freqs, transformer_options=transformer_options) + x = apply_gate_sum(x, out, gate) + + shift, scale, gate = get_shift_scale_gate(ff_params) + out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift) + out = self.feed_forward(out) + x = apply_gate_sum(x, out, gate) + return x + + +class TransformerDecoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None): + super().__init__() + self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings) + + operations = operation_settings.get("operations") + self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) + + def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}): + self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1) + # self attention + shift, scale, gate = get_shift_scale_gate(self_attn_params) + visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) + visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + # cross attention + shift, scale, gate = get_shift_scale_gate(cross_attn_params) + visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift) + visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + # feed forward + shift, scale, gate = get_shift_scale_gate(ff_params) + visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift) + visual_out = self.feed_forward(visual_out) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + return visual_embed + + +class Kandinsky5(nn.Module): + def __init__( + self, + in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512, + model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32, + axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0), + dtype=None, device=None, operations=None, **kwargs + ): + super().__init__() + head_dim = sum(axes_dims) + self.rope_scale_factor = rope_scale_factor + self.in_visual_dim = in_visual_dim + self.model_dim = model_dim + self.patch_size = patch_size + self.visual_embed_dim = visual_embed_dim + self.dtype = dtype + self.device = device + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings) + self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings) + self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings) + self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings) + + self.text_transformer_blocks = nn.ModuleList( + [TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)] + ) + + self.visual_transformer_blocks = nn.ModuleList( + [TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)] + ) + + self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings) + + self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims) + self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim]) + + def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}): + steps = seq_len if steps is None else steps + seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype) + seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1) + freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2) + return freqs + + def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): + + patch_size = self.patch_size + t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) + h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) + w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + + if steps_t is None: + steps_t = t_len + if steps_h is None: + steps_h = h_len + if steps_w is None: + steps_w = w_len + + h_start = 0 + w_start = 0 + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0 + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + t_start += rope_options.get("shift_t", 0.0) + h_start += rope_options.get("shift_y", 0.0) + w_start += rope_options.get("shift_x", 0.0) + else: + rope_scale_factor = self.rope_scale_factor + if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions + if h * w >= 14080: + rope_scale_factor = (1.0, 3.16, 3.16) + + t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0 + h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0 + w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0 + + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) + + freqs = self.rope_embedder_3d(img_ids).movedim(1, 2) + return freqs + + def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs): + patches_replace = transformer_options.get("patches_replace", {}) + context = self.text_embeddings(context) + time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y) + + for block in self.text_transformer_blocks: + context = block(context, time_embed, freqs_text, transformer_options=transformer_options) + + visual_embed = self.visual_embeddings(x) + visual_shape = visual_embed.shape[:-1] + visual_embed = visual_embed.flatten(1, -2) + + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.visual_transformer_blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.visual_transformer_blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options")) + visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"] + else: + visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options) + + visual_embed = visual_embed.reshape(*visual_shape, -1) + return self.out_layer(visual_embed, time_embed) + + def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + bs, c, t_len, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + if time_dim_replace is not None: + time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size) + x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace + + freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) + freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options) + + return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs) + + def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs) diff --git a/comfy/lora.py b/comfy/lora.py index 3a9077869..e7202ce97 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -322,6 +322,13 @@ def model_lora_keys_unet(model, key_map={}): key_map["diffusion_model.{}".format(key_lora)] = to key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + if isinstance(model, comfy.model_base.Kandinsky5): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k + key_map["transformer.{}".format(key_lora)] = k + return key_map diff --git a/comfy/model_base.py b/comfy/model_base.py index 3cedd4f31..0be006cc2 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model +import comfy.ldm.kandinsky5.model import comfy.model_management import comfy.patcher_extension @@ -1630,3 +1631,49 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo15): out = super().extra_conds(**kwargs) out['disable_time_r'] = comfy.conds.CONDConstant(False) return out + +class Kandinsky5(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5) + + def encode_adm(self, **kwargs): + return kwargs["pooled_output"] + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + device = kwargs["device"] + image = torch.zeros_like(noise) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :1] + else: + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + return torch.cat((image, mask), dim=1) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + + time_dim_replace = kwargs.get("time_dim_replace", None) + if time_dim_replace is not None: + out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace)) + + return out + +class Kandinsky5Image(Kandinsky5): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + return None diff --git a/comfy/model_detection.py b/comfy/model_detection.py index fd1907627..30b33a486 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -611,6 +611,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') return dit_config + if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 + dit_config = {} + model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["model_dim"] = model_dim + if model_dim in [4096, 2560]: # pro video and lite image + dit_config["axes_dims"] = (32, 48, 48) + if model_dim == 2560: # lite image + dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0) + elif model_dim == 1792: # lite video + dit_config["axes_dims"] = (16, 24, 24) + dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["image_model"] = "kandinsky5" + dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0] + dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1] + dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.') + dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.') + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None diff --git a/comfy/sd.py b/comfy/sd.py index c350322f8..754b1703d 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -54,6 +54,7 @@ import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image import comfy.text_encoders.ovis +import comfy.text_encoders.kandinsky5 import comfy.model_patcher import comfy.lora @@ -766,6 +767,8 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = None do_tile = False + if self.latent_dim == 2 and samples_in.ndim == 5: + samples_in = samples_in[:, :, 0] try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -983,6 +986,8 @@ class CLIPType(Enum): HUNYUAN_IMAGE = 19 HUNYUAN_VIDEO_15 = 20 OVIS = 21 + KANDINSKY5 = 22 + KANDINSKY5_IMAGE = 23 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1231,6 +1236,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.HUNYUAN_VIDEO_15: clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer + elif clip_type == CLIPType.KANDINSKY5: + clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer + elif clip_type == CLIPType.KANDINSKY5_IMAGE: + clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index afd97160b..91cc4ef08 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -21,6 +21,7 @@ import comfy.text_encoders.ace import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image +import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image from . import supported_models_base @@ -1474,7 +1475,60 @@ class HunyuanVideo15_SR_Distilled(HunyuanVideo): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2] +class Kandinsky5(supported_models_base.BASE): + unet_config = { + "image_model": "kandinsky5", + } + + sampling_settings = { + "shift": 10.0, + } + + unet_extra_config = {} + latent_format = latent_formats.HunyuanVideo + + memory_usage_factor = 1.1 #TODO + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Kandinsky5(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) + + +class Kandinsky5Image(Kandinsky5): + unet_config = { + "image_model": "kandinsky5", + "model_dim": 2560, + "visual_embed_dim": 64, + } + + sampling_settings = { + "shift": 3.0, + } + + latent_format = latent_formats.Flux + memory_usage_factor = 1.1 #TODO + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Kandinsky5Image(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Kandinsky5Image, Kandinsky5] models += [SVD_img2vid] diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py new file mode 100644 index 000000000..22f991c36 --- /dev/null +++ b/comfy/text_encoders/kandinsky5.py @@ -0,0 +1,68 @@ +from comfy import sd1_clip +from .qwen_image import QwenImageTokenizer, QwenImageTEModel +from .llama import Qwen25_7BVLI + + +class Kandinsky5Tokenizer(QwenImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = super().tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + + return out + + +class Kandinsky5TokenizerImage(Kandinsky5Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}): + llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) + if llama_scaled_fp8 is not None: + model_options = model_options.copy() + model_options["scaled_fp8"] = llama_scaled_fp8 + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class Kandinsky5TEModel(QwenImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1) + l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"]) + + return cond, l_pooled, extra + + def set_clip_options(self, options): + super().set_clip_options(options) + self.clip_l.set_clip_options(options) + + def reset_clip_options(self): + super().reset_clip_options() + self.clip_l.reset_clip_options() + + def load_sd(self, sd): + if "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + return self.clip_l.load_sd(sd) + else: + return super().load_sd(sd) + +def te(dtype_llama=None, llama_scaled_fp8=None): + class Kandinsky5TEModel_(Kandinsky5TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Kandinsky5TEModel_ diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 866c3e0eb..d7cbe68cf 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -568,6 +568,8 @@ class Conditioning(ComfyTypeIO): '''Used by WAN Camera.''' time_dim_concat: NotRequired[torch.Tensor] '''Used by WAN Phantom Subject.''' + time_dim_replace: NotRequired[torch.Tensor] + '''Used by Kandinsky5 I2V.''' CondList = list[tuple[torch.Tensor, PooledDict]] Type = CondList diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py new file mode 100644 index 000000000..9cb234be1 --- /dev/null +++ b/comfy_extras/nodes_kandinsky5.py @@ -0,0 +1,136 @@ +import nodes +import node_helpers +import torch +import comfy.model_management +import comfy.utils + +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class Kandinsky5ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Kandinsky5ImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty video latent"), + io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + cond_latent_out = {} + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + encoded = vae.encode(start_image[:, :, :, :3]) + cond_latent_out["samples"] = encoded + + mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent, cond_latent_out) + + +def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5): + source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W + source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W + + reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high) + reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high) + + # normalization + normalized = (source - source_mean) / (source_std + 1e-8) + normalized = normalized * reference_std + reference_mean + + return normalized + + +class NormalizeVideoLatentStart(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="NormalizeVideoLatentStart", + category="conditioning/video_models", + description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.", + inputs=[ + io.Latent.Input("latent"), + io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"), + io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"), + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput: + if latent["samples"].shape[2] <= 1: + return io.NodeOutput(latent) + s = latent.copy() + samples = latent["samples"].clone() + + first_frames = samples[:, :, :start_frame_count] + reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)] + normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data) + + samples[:, :, :start_frame_count] = normalized_first_frames + s["samples"] = samples + return io.NodeOutput(s) + + +class CLIPTextEncodeKandinsky5(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeKandinsky5", + category="advanced/conditioning/kandinsky5", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput: + tokens = clip.tokenize(clip_l) + tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"] + + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + +class Kandinsky5Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Kandinsky5ImageToVideo, + NormalizeVideoLatentStart, + CLIPTextEncodeKandinsky5, + ] + +async def comfy_entrypoint() -> Kandinsky5Extension: + return Kandinsky5Extension() diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index d2df07ff9..e439b18ef 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -4,7 +4,7 @@ import torch import nodes from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import logging def reshape_latent_to(target_shape, latent, repeat_batch=True): if latent.shape[1:] != target_shape[1:]: @@ -388,6 +388,42 @@ class LatentOperationSharpen(io.ComfyNode): return luminance * sharpened return io.NodeOutput(sharpen) +class ReplaceVideoLatentFrames(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ReplaceVideoLatentFrames", + category="latent/batch", + inputs=[ + io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."), + io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."), + io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. Negative values count from the end."), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, destination, index, source=None) -> io.NodeOutput: + if source is None: + return io.NodeOutput(destination) + dest_frames = destination["samples"].shape[2] + source_frames = source["samples"].shape[2] + if index < 0: + index = dest_frames + index + if index > dest_frames: + logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.") + return io.NodeOutput(destination) + if index + source_frames > dest_frames: + logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.") + return io.NodeOutput(destination) + s = source.copy() + s_source = source["samples"] + s_destination = destination["samples"].clone() + s_destination[:, :, index:index + s_source.shape[2]] = s_source + s["samples"] = s_destination + return io.NodeOutput(s) class LatentExtension(ComfyExtension): @override @@ -405,6 +441,7 @@ class LatentExtension(ComfyExtension): LatentApplyOperationCFG, LatentOperationTonemapReinhard, LatentOperationSharpen, + ReplaceVideoLatentFrames ] diff --git a/nodes.py b/nodes.py index 356aa63df..8d28a725d 100644 --- a/nodes.py +++ b/nodes.py @@ -970,7 +970,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -2357,6 +2357,7 @@ async def init_builtin_extra_nodes(): "nodes_rope.py", "nodes_logic.py", "nodes_nop.py", + "nodes_kandinsky5.py", ] import_failed = [] From ae676ed105663bb225153c8bca406f00edf738b4 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Dec 2025 20:01:19 -0800 Subject: [PATCH 4/9] Fix regression. (#11137) --- comfy/supported_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 91cc4ef08..383c82c3e 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1529,6 +1529,6 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Kandinsky5Image, Kandinsky5] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] models += [SVD_img2vid] From 117bf3f2bd9235cb5942a1de10a534c9869c7444 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 6 Dec 2025 06:22:02 +0200 Subject: [PATCH 5/9] convert nodes_freelunch.py to the V3 schema (#10904) --- comfy_extras/nodes_freelunch.py | 89 +++++++++++++++++---------- comfy_extras/nodes_model_downscale.py | 5 -- 2 files changed, 57 insertions(+), 37 deletions(-) diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index e3ac58447..3429b731e 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -2,6 +2,8 @@ import torch import logging +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO def Fourier_filter(x, threshold, scale): # FFT @@ -22,21 +24,26 @@ def Fourier_filter(x, threshold, scale): return x_filtered.to(x.dtype) -class FreeU: +class FreeU(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}), - "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}), - "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), - "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return IO.Schema( + node_id="FreeU", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, b1, b2, s1, s2): + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} on_cpu_devices = {} @@ -59,23 +66,31 @@ class FreeU: m = model.clone() m.set_model_output_block_patch(output_block_patch) - return (m, ) + return IO.NodeOutput(m) -class FreeU_V2: + patch = execute # TODO: remove + + +class FreeU_V2(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}), - "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}), - "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), - "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return IO.Schema( + node_id="FreeU_V2", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, b1, b2, s1, s2): + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} on_cpu_devices = {} @@ -105,9 +120,19 @@ class FreeU_V2: m = model.clone() m.set_model_output_block_patch(output_block_patch) - return (m, ) + return IO.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "FreeU": FreeU, - "FreeU_V2": FreeU_V2, -} + patch = execute # TODO: remove + + +class FreelunchExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + FreeU, + FreeU_V2, + ] + + +async def comfy_entrypoint() -> FreelunchExtension: + return FreelunchExtension() diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index f7ca9699d..dec2ae841 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -53,11 +53,6 @@ class PatchModelAddDownscale(io.ComfyNode): return io.NodeOutput(m) -NODE_DISPLAY_NAME_MAPPINGS = { - # Sampling - "PatchModelAddDownscale": "", -} - class ModelDownscaleExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: From 913f86b72740f84f759786a698108840a09b6498 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 6 Dec 2025 06:24:10 +0200 Subject: [PATCH 6/9] [V3] convert nodes_mask.py to V3 schema (#10669) * convert nodes_mask.py to V3 schema * set "Preview Mask" as display name for MaskPreview --- comfy_extras/nodes_mask.py | 508 +++++++++++++++++++------------------ 1 file changed, 263 insertions(+), 245 deletions(-) diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index a5e405008..290e6f55e 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -3,11 +3,10 @@ import scipy.ndimage import torch import comfy.utils import node_helpers -import folder_paths -import random +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, UI import nodes -from nodes import MAX_RESOLUTION def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): source = source.to(destination.device) @@ -46,202 +45,213 @@ def composite(destination, source, x, y, mask = None, multiplier = 8, resize_sou destination[..., top:bottom, left:right] = source_portion + destination_portion return destination -class LatentCompositeMasked: +class LatentCompositeMasked(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "destination": ("LATENT",), - "source": ("LATENT",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "resize_source": ("BOOLEAN", {"default": False}), - }, - "optional": { - "mask": ("MASK",), - } - } - RETURN_TYPES = ("LATENT",) - FUNCTION = "composite" + def define_schema(cls): + return IO.Schema( + node_id="LatentCompositeMasked", + category="latent", + inputs=[ + IO.Latent.Input("destination"), + IO.Latent.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Boolean.Input("resize_source", default=False), + IO.Mask.Input("mask", optional=True), + ], + outputs=[IO.Latent.Output()], + ) - CATEGORY = "latent" - - def composite(self, destination, source, x, y, resize_source, mask = None): + @classmethod + def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput: output = destination.copy() destination = destination["samples"].clone() source = source["samples"] output["samples"] = composite(destination, source, x, y, mask, 8, resize_source) - return (output,) + return IO.NodeOutput(output) -class ImageCompositeMasked: + composite = execute # TODO: remove + + +class ImageCompositeMasked(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "destination": ("IMAGE",), - "source": ("IMAGE",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "resize_source": ("BOOLEAN", {"default": False}), - }, - "optional": { - "mask": ("MASK",), - } - } - RETURN_TYPES = ("IMAGE",) - FUNCTION = "composite" + def define_schema(cls): + return IO.Schema( + node_id="ImageCompositeMasked", + category="image", + inputs=[ + IO.Image.Input("destination"), + IO.Image.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Boolean.Input("resize_source", default=False), + IO.Mask.Input("mask", optional=True), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "image" - - def composite(self, destination, source, x, y, resize_source, mask = None): + @classmethod + def execute(cls, destination, source, x, y, resize_source, mask = None) -> IO.NodeOutput: destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) - return (output,) + return IO.NodeOutput(output) -class MaskToImage: + composite = execute # TODO: remove + + +class MaskToImage(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="MaskToImage", + display_name="Convert Mask to Image", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "mask_to_image" - - def mask_to_image(self, mask): + @classmethod + def execute(cls, mask) -> IO.NodeOutput: result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - return (result,) + return IO.NodeOutput(result) -class ImageToMask: + mask_to_image = execute # TODO: remove + + +class ImageToMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "channel": (["red", "green", "blue", "alpha"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ImageToMask", + display_name="Convert Image to Mask", + category="mask", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, channel): + @classmethod + def execute(cls, image, channel) -> IO.NodeOutput: channels = ["red", "green", "blue", "alpha"] mask = image[:, :, :, channels.index(channel)] - return (mask,) + return IO.NodeOutput(mask) -class ImageColorToMask: + image_to_mask = execute # TODO: remove + + +class ImageColorToMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ImageColorToMask", + category="mask", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, color): + @classmethod + def execute(cls, image, color) -> IO.NodeOutput: temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int) temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2] mask = torch.where(temp == color, 1.0, 0).float() - return (mask,) + return IO.NodeOutput(mask) -class SolidMask: + image_to_mask = execute # TODO: remove + + +class SolidMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="SolidMask", + category="mask", + inputs=[ + IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "solid" - - def solid(self, value, width, height): + @classmethod + def execute(cls, value, width, height) -> IO.NodeOutput: out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu") - return (out,) + return IO.NodeOutput(out) -class InvertMask: + solid = execute # TODO: remove + + +class InvertMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="InvertMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "invert" - - def invert(self, mask): + @classmethod + def execute(cls, mask) -> IO.NodeOutput: out = 1.0 - mask - return (out,) + return IO.NodeOutput(out) -class CropMask: + invert = execute # TODO: remove + + +class CropMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="CropMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "crop" - - def crop(self, mask, x, y, width, height): + @classmethod + def execute(cls, mask, x, y, width, height) -> IO.NodeOutput: mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) out = mask[:, y:y + height, x:x + width] - return (out,) + return IO.NodeOutput(out) -class MaskComposite: + crop = execute # TODO: remove + + +class MaskComposite(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "destination": ("MASK",), - "source": ("MASK",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="MaskComposite", + category="mask", + inputs=[ + IO.Mask.Input("destination"), + IO.Mask.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "combine" - - def combine(self, destination, source, x, y, operation): + @classmethod + def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput: output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone() source = source.reshape((-1, source.shape[-2], source.shape[-1])) @@ -267,28 +277,29 @@ class MaskComposite: output = torch.clamp(output, 0.0, 1.0) - return (output,) + return IO.NodeOutput(output) -class FeatherMask: + combine = execute # TODO: remove + + +class FeatherMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="FeatherMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "feather" - - def feather(self, mask, left, top, right, bottom): + @classmethod + def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput: output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() left = min(left, output.shape[-1]) @@ -312,26 +323,28 @@ class FeatherMask: feather_rate = (y + 1) / bottom output[:, -y, :] *= feather_rate - return (output,) + return IO.NodeOutput(output) -class GrowMask: + feather = execute # TODO: remove + + +class GrowMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}), - "tapered_corners": ("BOOLEAN", {"default": True}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="GrowMask", + display_name="Grow Mask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1), + IO.Boolean.Input("tapered_corners", default=True), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "expand_mask" - - def expand_mask(self, mask, expand, tapered_corners): + @classmethod + def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput: c = 0 if tapered_corners else 1 kernel = np.array([[c, 1, c], [1, 1, 1], @@ -347,69 +360,74 @@ class GrowMask: output = scipy.ndimage.grey_dilation(output, footprint=kernel) output = torch.from_numpy(output) out.append(output) - return (torch.stack(out, dim=0),) + return IO.NodeOutput(torch.stack(out, dim=0)) -class ThresholdMask: + expand_mask = execute # TODO: remove + + +class ThresholdMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ThresholdMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, mask, value): + @classmethod + def execute(cls, mask, value) -> IO.NodeOutput: mask = (mask > value).float() - return (mask,) + return IO.NodeOutput(mask) + + image_to_mask = execute # TODO: remove + # Mask Preview - original implement from # https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81 # upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes -class MaskPreview(nodes.SaveImage): - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() - self.type = "temp" - self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) - self.compress_level = 4 +class MaskPreview(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MaskPreview", + display_name="Preview Mask", + category="mask", + description="Saves the input images to your ComfyUI output directory.", + inputs=[ + IO.Mask.Input("mask"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": {"mask": ("MASK",), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - FUNCTION = "execute" - CATEGORY = "mask" - - def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - return self.save_images(preview, filename_prefix, prompt, extra_pnginfo) + def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput(ui=UI.PreviewMask(mask)) -NODE_CLASS_MAPPINGS = { - "LatentCompositeMasked": LatentCompositeMasked, - "ImageCompositeMasked": ImageCompositeMasked, - "MaskToImage": MaskToImage, - "ImageToMask": ImageToMask, - "ImageColorToMask": ImageColorToMask, - "SolidMask": SolidMask, - "InvertMask": InvertMask, - "CropMask": CropMask, - "MaskComposite": MaskComposite, - "FeatherMask": FeatherMask, - "GrowMask": GrowMask, - "ThresholdMask": ThresholdMask, - "MaskPreview": MaskPreview -} +class MaskExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + LatentCompositeMasked, + ImageCompositeMasked, + MaskToImage, + ImageToMask, + ImageColorToMask, + SolidMask, + InvertMask, + CropMask, + MaskComposite, + FeatherMask, + GrowMask, + ThresholdMask, + MaskPreview, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "ImageToMask": "Convert Image to Mask", - "MaskToImage": "Convert Mask to Image", -} + +async def comfy_entrypoint() -> MaskExtension: + return MaskExtension() From d7a0aef65033bf0fe56e521577a44fac1830b8b3 Mon Sep 17 00:00:00 2001 From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com> Date: Fri, 5 Dec 2025 21:15:21 -0800 Subject: [PATCH 7/9] Set OCL_SET_SVM_SIZE on AMD. (#11139) --- cuda_malloc.py | 27 +++++++++++++++++---------- main.py | 3 +++ 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/cuda_malloc.py b/cuda_malloc.py index 6520d5123..ee2bc4b69 100644 --- a/cuda_malloc.py +++ b/cuda_malloc.py @@ -63,18 +63,22 @@ def cuda_malloc_supported(): return True +version = "" + +try: + torch_spec = importlib.util.find_spec("torch") + for folder in torch_spec.submodule_search_locations: + ver_file = os.path.join(folder, "version.py") + if os.path.isfile(ver_file): + spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + version = module.__version__ +except: + pass + if not args.cuda_malloc: try: - version = "" - torch_spec = importlib.util.find_spec("torch") - for folder in torch_spec.submodule_search_locations: - ver_file = os.path.join(folder, "version.py") - if os.path.isfile(ver_file): - spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - version = module.__version__ - if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc args.cuda_malloc = cuda_malloc_supported() @@ -90,3 +94,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc: env_var += ",backend:cudaMallocAsync" os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var + +def get_torch_version_noimport(): + return str(version) diff --git a/main.py b/main.py index 0cd815d9e..0d02a087b 100644 --- a/main.py +++ b/main.py @@ -167,6 +167,9 @@ if __name__ == "__main__": os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" import cuda_malloc + if "rocm" in cuda_malloc.get_torch_version_noimport(): + os.environ['OCL_SET_SVM_SIZE'] = '262144' # set at the request of AMD + if 'torch' in sys.modules: logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.") From 76f18e955dcbc88ed13d6802194fd897927f93e5 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Sat, 6 Dec 2025 13:28:08 +0200 Subject: [PATCH 8/9] marked all Pika API nodes a deprecated (#11146) --- comfy_api_nodes/nodes_pika.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 51148211b..acd88c391 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -92,6 +92,7 @@ class PikaImageToVideo(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -152,6 +153,7 @@ class PikaTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -239,6 +241,7 @@ class PikaScenes(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -323,6 +326,7 @@ class PikAdditionsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -399,6 +403,7 @@ class PikaSwapsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -466,6 +471,7 @@ class PikaffectsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -515,6 +521,7 @@ class PikaStartEndFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod From 7ac7d69d948e75c3a230d1262daab84d75aff895 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?= <40791699+kijai@users.noreply.github.com> Date: Sat, 6 Dec 2025 20:09:44 +0200 Subject: [PATCH 9/9] Fix EmptyAudio node input types (#11149) --- comfy_extras/nodes_audio.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 812301fb7..c7916443c 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -573,12 +573,14 @@ class EmptyAudio(IO.ComfyNode): step=0.01, tooltip="Duration of the empty audio clip in seconds", ), - IO.Float.Input( + IO.Int.Input( "sample_rate", default=44100, tooltip="Sample rate of the empty audio clip.", + min=1, + max=192000, ), - IO.Float.Input( + IO.Int.Input( "channels", default=2, min=1,