From 2bff3c520fa2dd5dd0fdb311594ae11764b72f12 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Wed, 10 Dec 2025 14:06:16 +0000 Subject: [PATCH 01/13] Add nabla support --- comfy/ldm/kandinsky5/model.py | 101 ++++++++++++++++--- comfy/ldm/kandinsky5/utils_nabla.py | 147 ++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+), 14 deletions(-) create mode 100644 comfy/ldm/kandinsky5/utils_nabla.py diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index 1509de2f8..df10e4496 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -6,6 +6,12 @@ import comfy.ldm.common_dit from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.math import apply_rope1 from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.kandinsky5.utils_nabla import ( + fractal_flatten, + fractal_unflatten, + fast_sta_nabla, + nabla, +) def attention(q, k, v, heads, transformer_options={}): return optimized_attention( @@ -116,14 +122,17 @@ class SelfAttention(nn.Module): result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1) return apply_rope1(norm_fn(result), freqs) - def _forward(self, x, freqs, transformer_options={}): + def _forward(self, x, freqs, sparse_params=None, transformer_options={}): q = self._compute_qk(x, freqs, self.to_query, self.query_norm) k = self._compute_qk(x, freqs, self.to_key, self.key_norm) v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) - out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + if sparse_params is None: + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + else: + out = nabla(q, k, v, sparse_params) return self.out_layer(out) - def _forward_chunked(self, x, freqs, transformer_options={}): + def _forward_chunked(self, x, freqs, sparse_params=None, transformer_options={}): def process_chunks(proj_fn, norm_fn): x_chunks = torch.chunk(x, self.num_chunks, dim=1) freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1) @@ -135,14 +144,17 @@ class SelfAttention(nn.Module): q = process_chunks(self.to_query, self.query_norm) k = process_chunks(self.to_key, self.key_norm) v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) - out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + if sparse_params is None: + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + else: + out = nabla(q, k, v, sparse_params) return self.out_layer(out) - def forward(self, x, freqs, transformer_options={}): + def forward(self, x, freqs, sparse_params=None, transformer_options={}): if x.shape[1] > 8192: - return self._forward_chunked(x, freqs, transformer_options=transformer_options) + return self._forward_chunked(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options) else: - return self._forward(x, freqs, transformer_options=transformer_options) + return self._forward(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options) class CrossAttention(SelfAttention): @@ -251,12 +263,12 @@ class TransformerDecoderBlock(nn.Module): self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) - def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}): + def forward(self, visual_embed, text_embed, time_embed, freqs, sparse_params=None, transformer_options={}): self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1) # self attention shift, scale, gate = get_shift_scale_gate(self_attn_params) visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) - visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options) + visual_out = self.self_attention(visual_out, freqs, sparse_params=sparse_params, transformer_options=transformer_options) visual_embed = apply_gate_sum(visual_embed, visual_out, gate) # cross attention shift, scale, gate = get_shift_scale_gate(cross_attn_params) @@ -369,21 +381,80 @@ class Kandinsky5(nn.Module): visual_embed = self.visual_embeddings(x) visual_shape = visual_embed.shape[:-1] - visual_embed = visual_embed.flatten(1, -2) blocks_replace = patches_replace.get("dit", {}) transformer_options["total_blocks"] = len(self.visual_transformer_blocks) transformer_options["block_type"] = "double" + + B, _, T, H, W = x.shape + if T > 30: # 10 sec generation + assert self.patch_size[0] == 1 + + freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])[0] + visual_embed_4d, freqs = fractal_flatten(visual_embed[0], freqs, visual_shape[1:]) + visual_embed, freqs = visual_embed_4d.unsqueeze(0), freqs.unsqueeze(0) + + pt, ph, pw = self.patch_size + T, H, W = T // pt, H // ph, W // pw + + wT, wW, wH = 11, 11, 3 + sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device) + + sparse_params = dict( + sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0), + attention_type="nabla", + to_fractal=True, + P=0.8, + wT=wT, wW=wW, wH=wH, + add_sta=True, + visual_shape=(T, H, W), + method="topcdf", + ) + else: + sparse_params = None + visual_embed = visual_embed.flatten(1, -2) + for i, block in enumerate(self.visual_transformer_blocks): transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): - return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options")) - visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"] + return block( + x=args["x"], + context=args["context"], + time_embed=args["time_embed"], + freqs=args["freqs"], + sparse_params=args.get("sparse_params"), + transformer_options=args.get("transformer_options"), + ) + visual_embed = blocks_replace[("double_block", i)]( + { + "x": visual_embed, + "context": context, + "time_embed": time_embed, + "freqs": freqs, + "sparse_params": sparse_params, + "transformer_options": transformer_options, + }, + {"original_block": block_wrap}, + )["x"] else: - visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options) + visual_embed = block( + visual_embed, + context, + time_embed, + freqs=freqs, + sparse_params=sparse_params, + transformer_options=transformer_options, + ) + + if T > 30: + visual_embed = fractal_unflatten( + visual_embed[0], + visual_shape[1:], + ).unsqueeze(0) + else: + visual_embed = visual_embed.reshape(*visual_shape, -1) - visual_embed = visual_embed.reshape(*visual_shape, -1) return self.out_layer(visual_embed, time_embed) def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): @@ -411,3 +482,5 @@ class Kandinsky5(nn.Module): self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs) + + diff --git a/comfy/ldm/kandinsky5/utils_nabla.py b/comfy/ldm/kandinsky5/utils_nabla.py new file mode 100644 index 000000000..705b1d75e --- /dev/null +++ b/comfy/ldm/kandinsky5/utils_nabla.py @@ -0,0 +1,147 @@ +import math + +import torch +from torch import Tensor +from torch.nn.attention.flex_attention import BlockMask, flex_attention + + +def fractal_flatten(x, rope, shape): + pixel_size = 8 + x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0) + rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0) + x = x.flatten(0, 1) + rope = rope.flatten(0, 1) + return x, rope + + +def fractal_unflatten(x, shape): + pixel_size = 8 + x = x.reshape(-1, pixel_size**2, x.shape[-1]) + x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0) + return x + + +def local_patching(x, shape, group_size, dim=0): + duration, height, width = shape + g1, g2, g3 = group_size + x = x.reshape( + *x.shape[:dim], + duration // g1, + g1, + height // g2, + g2, + width // g3, + g3, + *x.shape[dim + 3 :] + ) + x = x.permute( + *range(len(x.shape[:dim])), + dim, + dim + 2, + dim + 4, + dim + 1, + dim + 3, + dim + 5, + *range(dim + 6, len(x.shape)) + ) + x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3) + return x + + +def local_merge(x, shape, group_size, dim=0): + duration, height, width = shape + g1, g2, g3 = group_size + x = x.reshape( + *x.shape[:dim], + duration // g1, + height // g2, + width // g3, + g1, + g2, + g3, + *x.shape[dim + 2 :] + ) + x = x.permute( + *range(len(x.shape[:dim])), + dim, + dim + 3, + dim + 1, + dim + 4, + dim + 2, + dim + 5, + *range(dim + 6, len(x.shape)) + ) + x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3) + return x + +def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> Tensor: + l = torch.Tensor([T, H, W]).amax() + r = torch.arange(0, l, 1, dtype=torch.int16, device=device) + mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs() + sta_t, sta_h, sta_w = ( + mat[:T, :T].flatten(), + mat[:H, :H].flatten(), + mat[:W, :W].flatten(), + ) + sta_t = sta_t <= wT // 2 + sta_h = sta_h <= wH // 2 + sta_w = sta_w <= wW // 2 + sta_hw = ( + (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)) + .reshape(H, H, W, W) + .transpose(1, 2) + .flatten() + ) + sta = ( + (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)) + .reshape(T, T, H * W, H * W) + .transpose(1, 2) + ) + return sta.reshape(T * H * W, T * H * W) + +def nablaT_v2(q: Tensor, k: Tensor, sta: Tensor, thr: float = 0.9) -> BlockMask: + # Map estimation + B, h, S, D = q.shape + s1 = S // 64 + qa = q.reshape(B, h, s1, 64, D).mean(-2) + ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1) + map = qa @ ka + + map = torch.softmax(map / math.sqrt(D), dim=-1) + # Map binarization + vals, inds = map.sort(-1) + cvals = vals.cumsum_(-1) + mask = (cvals >= 1 - thr).int() + mask = mask.gather(-1, inds.argsort(-1)) + mask = torch.logical_or(mask, sta) + + # BlockMask creation + kv_nb = mask.sum(-1).to(torch.int32) + kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32) + return BlockMask.from_kv_blocks( + torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None + ) + +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) +def nabla(query, key, value, sparse_params=None): + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + block_mask = nablaT_v2( + query, + key, + sparse_params["sta_mask"], + thr=sparse_params["P"], + ) + out = ( + flex_attention( + query, + key, + value, + block_mask=block_mask + ) + .transpose(1, 2) + .contiguous() + ) + out = out.flatten(-2, -1) + return out \ No newline at end of file From 0c84b7650fca765162a33028237603c3ada81664 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Tue, 16 Dec 2025 11:15:59 +0000 Subject: [PATCH 02/13] Add batch support for nabla --- comfy/ldm/kandinsky5/model.py | 15 +++++++-------- comfy/ldm/kandinsky5/utils_nabla.py | 13 ++++++------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index df10e4496..ac9352867 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -387,13 +387,12 @@ class Kandinsky5(nn.Module): transformer_options["block_type"] = "double" B, _, T, H, W = x.shape - if T > 30: # 10 sec generation + NABLA_THR = 40 # long (10 sec) generation + if T > NABLA_THR: assert self.patch_size[0] == 1 - freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])[0] - visual_embed_4d, freqs = fractal_flatten(visual_embed[0], freqs, visual_shape[1:]) - visual_embed, freqs = visual_embed_4d.unsqueeze(0), freqs.unsqueeze(0) - + freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:]) + visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:]) pt, ph, pw = self.patch_size T, H, W = T // pt, H // ph, W // pw @@ -447,11 +446,11 @@ class Kandinsky5(nn.Module): transformer_options=transformer_options, ) - if T > 30: + if T > NABLA_THR: visual_embed = fractal_unflatten( - visual_embed[0], + visual_embed, visual_shape[1:], - ).unsqueeze(0) + ) else: visual_embed = visual_embed.reshape(*visual_shape, -1) diff --git a/comfy/ldm/kandinsky5/utils_nabla.py b/comfy/ldm/kandinsky5/utils_nabla.py index 705b1d75e..5e2bc4076 100644 --- a/comfy/ldm/kandinsky5/utils_nabla.py +++ b/comfy/ldm/kandinsky5/utils_nabla.py @@ -7,20 +7,19 @@ from torch.nn.attention.flex_attention import BlockMask, flex_attention def fractal_flatten(x, rope, shape): pixel_size = 8 - x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0) - rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0) - x = x.flatten(0, 1) - rope = rope.flatten(0, 1) + x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1) + rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1) + x = x.flatten(1, 2) + rope = rope.flatten(1, 2) return x, rope def fractal_unflatten(x, shape): pixel_size = 8 - x = x.reshape(-1, pixel_size**2, x.shape[-1]) - x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0) + x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1]) + x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1) return x - def local_patching(x, shape, group_size, dim=0): duration, height, width = shape g1, g2, g3 = group_size From a3f78be5c27b9b5a2c292b5a0edf3b65c3dd1121 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Wed, 17 Dec 2025 07:37:46 +0000 Subject: [PATCH 03/13] Add 128 divisibility for nabla --- comfy/ldm/kandinsky5/model.py | 4 +--- comfy_extras/nodes_kandinsky5.py | 3 +++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index ac9352867..f4e02af70 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -387,7 +387,7 @@ class Kandinsky5(nn.Module): transformer_options["block_type"] = "double" B, _, T, H, W = x.shape - NABLA_THR = 40 # long (10 sec) generation + NABLA_THR = 31 # long (10 sec) generation if T > NABLA_THR: assert self.patch_size[0] == 1 @@ -481,5 +481,3 @@ class Kandinsky5(nn.Module): self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs) - - diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index 9cb234be1..aaaf83566 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -34,6 +34,9 @@ class Kandinsky5ImageToVideo(io.ComfyNode): @classmethod def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: + if length > 121: # 10 sec generation, for nabla + height = 128 * round(height / 128) + width = 128 * round(width / 128) latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) cond_latent_out = {} if start_image is not None: From 296b7c7b6d7111c5ee07a09576e7e1e6255aeeff Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Wed, 17 Dec 2025 11:40:14 +0000 Subject: [PATCH 04/13] Small fixes --- comfy/ldm/kandinsky5/model.py | 7 +++++-- comfy/ldm/kandinsky5/utils_nabla.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index f4e02af70..24a06da0a 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -391,19 +391,22 @@ class Kandinsky5(nn.Module): if T > NABLA_THR: assert self.patch_size[0] == 1 + # pro video model uses lower P at higher resolutions + P = 0.7 if self.model_dim == 4096 and H * W >= 14080 else 0.9 + freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:]) visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:]) pt, ph, pw = self.patch_size T, H, W = T // pt, H // ph, W // pw - wT, wW, wH = 11, 11, 3 + wT, wW, wH = 11, 3, 3 sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device) sparse_params = dict( sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0), attention_type="nabla", to_fractal=True, - P=0.8, + P=P, wT=wT, wW=wW, wH=wH, add_sta=True, visual_shape=(T, H, W), diff --git a/comfy/ldm/kandinsky5/utils_nabla.py b/comfy/ldm/kandinsky5/utils_nabla.py index 5e2bc4076..a346736b2 100644 --- a/comfy/ldm/kandinsky5/utils_nabla.py +++ b/comfy/ldm/kandinsky5/utils_nabla.py @@ -143,4 +143,4 @@ def nabla(query, key, value, sparse_params=None): .contiguous() ) out = out.flatten(-2, -1) - return out \ No newline at end of file + return out From bfe4b31a3295cf93dea14c8a8aa6723a15a85af1 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Tue, 23 Dec 2025 15:22:48 +0000 Subject: [PATCH 05/13] Add i2i pipeline --- comfy/model_base.py | 30 ++++++++++++++++++++ comfy/sd.py | 4 +++ comfy/supported_models.py | 25 +++++++++++++++- comfy/text_encoders/kandinsky5.py | 5 ++++ comfy_extras/nodes_kandinsky5.py | 47 +++++++++++++++++++++++++++++++ nodes.py | 2 +- 6 files changed, 111 insertions(+), 2 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 6b8a8454d..d69328deb 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1681,3 +1681,33 @@ class Kandinsky5Image(Kandinsky5): def concat_cond(self, **kwargs): return None + +class Kandinsky5ImageToImage(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__( + model_config, + model_type, + device=device, + unet_model=comfy.ldm.kandinsky5.model.Kandinsky5 + ) + + def encode_adm(self, **kwargs): + return kwargs["pooled_output"] + + def concat_cond(self, **kwargs): + noise = kwargs["noise"] + device = kwargs["device"] + image = kwargs.get("latent_image", None) + image = utils.resize_to_batch_size(image, noise.shape[0]) + mask_ones = torch.ones_like(noise)[:, :1].to(device=device) + return torch.cat((image, mask_ones), dim=1) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out["attention_mask"] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out["c_crossattn"] = comfy.conds.CONDRegular(cross_attn) + return out \ No newline at end of file diff --git a/comfy/sd.py b/comfy/sd.py index 1cad98aef..bf9180b21 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -992,6 +992,7 @@ class CLIPType(Enum): OVIS = 21 KANDINSKY5 = 22 KANDINSKY5_IMAGE = 23 + KANDINSKY5_I2I = 24 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): @@ -1246,6 +1247,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.KANDINSKY5_IMAGE: clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage + elif clip_type == CLIPType.KANDINSKY5_I2I: + clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerI2I else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 1888f35ba..ad3c3b40b 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1535,7 +1535,30 @@ class Kandinsky5Image(Kandinsky5): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) +class Kandinsky5ImageToImage(supported_models_base.BASE): + unet_config = { + "image_model": "kandinsky5", + "model_dim": 2560, + "visual_embed_dim": 132, + } -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] + sampling_settings = { + "shift": 3.0, + } + + latent_format = latent_formats.Flux + memory_usage_factor = 1.25 #TODO + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Kandinsky5ImageToImage(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerI2I, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5ImageToImage, Kandinsky5Image, Kandinsky5] models += [SVD_img2vid] diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py index be086458c..ec5a0d5f7 100644 --- a/comfy/text_encoders/kandinsky5.py +++ b/comfy/text_encoders/kandinsky5.py @@ -21,6 +21,11 @@ class Kandinsky5TokenizerImage(Kandinsky5Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" +class Kandinsky5TokenizerI2I(Kandinsky5Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a promt engineer. Based on the provided source image (first image) and target image (second image), create an interesting text prompt that can be used together with the source image to create the target image:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + class Qwen25_7BVLIModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}): diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index aaaf83566..5c46296c0 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -1,6 +1,7 @@ import nodes import node_helpers import torch +import torchvision.transforms.functional as F import comfy.model_management import comfy.utils @@ -55,6 +56,51 @@ class Kandinsky5ImageToVideo(io.ComfyNode): return io.NodeOutput(positive, negative, out_latent, cond_latent_out) + +class Kandinsky5ImageToImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Kandinsky5ImageToImage", + category="image", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image"), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty video latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, batch_size, start_image) -> io.NodeOutput: + height, width = start_image.shape[1:-1] + + available_res = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)] + nearest_index = torch.argmin(torch.Tensor([abs((w / h) - (width / height))for (w, h) in available_res])) + nw, nh = available_res[nearest_index] + scale_factor = min(height / nh, width / nw) + start_image = start_image.permute(0,3,1,2) + start_image = F.resize(start_image, (int(height / scale_factor), int(width / scale_factor))) + start_image = F.crop( + start_image, + (height - nh) // 2, + (width - nw) // 2, + nh, + nw, + ) + print(start_image.shape) + start_image = start_image.permute(0,2,3,1) + encoded = vae.encode(start_image[:, :, :, :3]) + out_latent = {"samples": encoded.repeat(batch_size, 1, 1, 1)} + return io.NodeOutput(positive, negative, out_latent) + + def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5): source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W @@ -131,6 +177,7 @@ class Kandinsky5Extension(ComfyExtension): async def get_node_list(self) -> list[type[io.ComfyNode]]: return [ Kandinsky5ImageToVideo, + Kandinsky5ImageToImage, NormalizeVideoLatentStart, CLIPTextEncodeKandinsky5, ] diff --git a/nodes.py b/nodes.py index 3fa543294..d422a3b00 100644 --- a/nodes.py +++ b/nodes.py @@ -970,7 +970,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ), "clip_name2": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"], ), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "kandinsky5_i2i"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), From a78d870d499833dbbadceb556c1315e5f882e429 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Mon, 29 Dec 2025 15:03:59 +0000 Subject: [PATCH 06/13] Add image to text encoders --- comfy/text_encoders/kandinsky5.py | 2 +- comfy_extras/nodes_kandinsky5.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py index ec5a0d5f7..d4351f5f6 100644 --- a/comfy/text_encoders/kandinsky5.py +++ b/comfy/text_encoders/kandinsky5.py @@ -24,7 +24,7 @@ class Kandinsky5TokenizerImage(Kandinsky5Tokenizer): class Kandinsky5TokenizerI2I(Kandinsky5Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.llama_template = "<|im_start|>system\nYou are a promt engineer. Based on the provided source image (first image) and target image (second image), create an interesting text prompt that can be used together with the source image to create the target image:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + self.llama_template_images = "<|im_start|>system\nYou are a promt engineer. Based on the provided source image (first image) and target image (second image), create an interesting text prompt that can be used together with the source image to create the target image:<|im_end|>\n<|im_start|>user\n{}<|im_end|><|im_start|>assistant\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>" class Qwen25_7BVLIModel(sd1_clip.SDClipModel): diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index 5c46296c0..3ac1238ab 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -171,6 +171,26 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode): return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) +class TextEncodeQwenKandinskyI2I(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="TextEncodeQwenKandinskyI2I", + category="advanced/conditioning", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("prompt", multiline=True, dynamic_prompts=True), + io.Image.Input("image", optional=True), + ], + outputs=[io.Conditioning.Output()], + ) + + @classmethod + def execute(cls, clip, prompt, image=None) -> io.NodeOutput: + images = [image,] if image is not None else [] + tokens = clip.tokenize(prompt, images=images) + conditioning = clip.encode_from_tokens_scheduled(tokens) + return io.NodeOutput(conditioning) class Kandinsky5Extension(ComfyExtension): @override @@ -180,6 +200,7 @@ class Kandinsky5Extension(ComfyExtension): Kandinsky5ImageToImage, NormalizeVideoLatentStart, CLIPTextEncodeKandinsky5, + TextEncodeQwenKandinskyI2I, ] async def comfy_entrypoint() -> Kandinsky5Extension: From b2593317bfedd7ad4f5bd2df680e1258053518d6 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Mon, 12 Jan 2026 12:17:25 +0000 Subject: [PATCH 07/13] Change dit format to bf16 --- comfy/supported_models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index ad3c3b40b..1f928bb49 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1535,7 +1535,7 @@ class Kandinsky5Image(Kandinsky5): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) -class Kandinsky5ImageToImage(supported_models_base.BASE): +class Kandinsky5ImageToImage(Kandinsky5): unet_config = { "image_model": "kandinsky5", "model_dim": 2560, @@ -1545,7 +1545,6 @@ class Kandinsky5ImageToImage(supported_models_base.BASE): sampling_settings = { "shift": 3.0, } - latent_format = latent_formats.Flux memory_usage_factor = 1.25 #TODO From 646cf6600a6633f511deb871135ad058a297d257 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Tue, 13 Jan 2026 12:53:27 +0000 Subject: [PATCH 08/13] Fix resize in I2I Node --- comfy_extras/nodes_kandinsky5.py | 38 +++++--------------------------- 1 file changed, 5 insertions(+), 33 deletions(-) diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index 3ac1238ab..22edeb3f7 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -64,29 +64,27 @@ class Kandinsky5ImageToImage(io.ComfyNode): node_id="Kandinsky5ImageToImage", category="image", inputs=[ - io.Conditioning.Input("positive"), - io.Conditioning.Input("negative"), io.Vae.Input("vae"), io.Int.Input("batch_size", default=1, min=1, max=4096), io.Image.Input("start_image"), ], outputs=[ - io.Conditioning.Output(display_name="positive"), - io.Conditioning.Output(display_name="negative"), io.Latent.Output(display_name="latent", tooltip="Empty video latent"), + io.Image.Output("resized_image"), ], ) @classmethod - def execute(cls, positive, negative, vae, batch_size, start_image) -> io.NodeOutput: + def execute(cls, vae, batch_size, start_image) -> io.NodeOutput: height, width = start_image.shape[1:-1] - available_res = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)] nearest_index = torch.argmin(torch.Tensor([abs((w / h) - (width / height))for (w, h) in available_res])) nw, nh = available_res[nearest_index] scale_factor = min(height / nh, width / nw) start_image = start_image.permute(0,3,1,2) start_image = F.resize(start_image, (int(height / scale_factor), int(width / scale_factor))) + + height, width = start_image.shape[-2:] start_image = F.crop( start_image, (height - nh) // 2, @@ -94,11 +92,10 @@ class Kandinsky5ImageToImage(io.ComfyNode): nh, nw, ) - print(start_image.shape) start_image = start_image.permute(0,2,3,1) encoded = vae.encode(start_image[:, :, :, :3]) out_latent = {"samples": encoded.repeat(batch_size, 1, 1, 1)} - return io.NodeOutput(positive, negative, out_latent) + return io.NodeOutput(out_latent, start_image) def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5): @@ -147,35 +144,11 @@ class NormalizeVideoLatentStart(io.ComfyNode): s["samples"] = samples return io.NodeOutput(s) - class CLIPTextEncodeKandinsky5(io.ComfyNode): @classmethod def define_schema(cls): return io.Schema( node_id="CLIPTextEncodeKandinsky5", - category="advanced/conditioning/kandinsky5", - inputs=[ - io.Clip.Input("clip"), - io.String.Input("clip_l", multiline=True, dynamic_prompts=True), - io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True), - ], - outputs=[ - io.Conditioning.Output(), - ], - ) - - @classmethod - def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput: - tokens = clip.tokenize(clip_l) - tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"] - - return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) - -class TextEncodeQwenKandinskyI2I(io.ComfyNode): - @classmethod - def define_schema(cls): - return io.Schema( - node_id="TextEncodeQwenKandinskyI2I", category="advanced/conditioning", inputs=[ io.Clip.Input("clip"), @@ -200,7 +173,6 @@ class Kandinsky5Extension(ComfyExtension): Kandinsky5ImageToImage, NormalizeVideoLatentStart, CLIPTextEncodeKandinsky5, - TextEncodeQwenKandinskyI2I, ] async def comfy_entrypoint() -> Kandinsky5Extension: From 04f2e27c40c069dfcd8ad9bf42ba10462842a080 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Thu, 15 Jan 2026 13:18:19 +0000 Subject: [PATCH 09/13] Add resize in CLIPTextEncodeKandinsky5 --- comfy_extras/nodes_kandinsky5.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index 22edeb3f7..79fb7eccc 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -160,7 +160,12 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode): @classmethod def execute(cls, clip, prompt, image=None) -> io.NodeOutput: - images = [image,] if image is not None else [] + images = [] + if image is not None: + image = image.permute(0,3,1,2) + height, width = image.shape[-2:] + image = F.resize(image, (int(height / 2), int(width / 2))).permute(0,2,3,1) + images.append(image) tokens = clip.tokenize(prompt, images=images) conditioning = clip.encode_from_tokens_scheduled(tokens) return io.NodeOutput(conditioning) From 56841ba4a71323035d5c0f0e5557cea39f1ea16d Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Wed, 21 Jan 2026 09:22:36 +0000 Subject: [PATCH 10/13] Fix llama_template I2I --- comfy/text_encoders/kandinsky5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py index d4351f5f6..2ecb2f183 100644 --- a/comfy/text_encoders/kandinsky5.py +++ b/comfy/text_encoders/kandinsky5.py @@ -24,7 +24,7 @@ class Kandinsky5TokenizerImage(Kandinsky5Tokenizer): class Kandinsky5TokenizerI2I(Kandinsky5Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.llama_template_images = "<|im_start|>system\nYou are a promt engineer. Based on the provided source image (first image) and target image (second image), create an interesting text prompt that can be used together with the source image to create the target image:<|im_end|>\n<|im_start|>user\n{}<|im_end|><|im_start|>assistant\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>" + self.llama_template_images = "<|im_start|>system\nYou are a promt engineer. Based on the provided source image (first image) and target image (second image), create an interesting text prompt that can be used together with the source image to create the target image:<|im_end|>\n<|im_start|>user\n{}<|vision_start|><|image_pad|><|vision_end|><|im_end|>" class Qwen25_7BVLIModel(sd1_clip.SDClipModel): From c55a6fb2710a8ea46b8d19e16d08b86938f0e0e0 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Mon, 2 Feb 2026 09:04:35 +0000 Subject: [PATCH 11/13] Return typo 'prompt' --- comfy/text_encoders/kandinsky5.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py index 2ecb2f183..9a9f6ea76 100644 --- a/comfy/text_encoders/kandinsky5.py +++ b/comfy/text_encoders/kandinsky5.py @@ -6,7 +6,7 @@ from .llama import Qwen25_7BVLI class Kandinsky5Tokenizer(QwenImageTokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) - self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>" self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): From 94d1df4b830562fc81b5a72237d4133dad949098 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Mon, 2 Feb 2026 09:26:09 +0000 Subject: [PATCH 12/13] Small fixes Kandinsky5 --- comfy/model_base.py | 12 ++---------- comfy_extras/nodes_kandinsky5.py | 12 ++++++------ 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index a78a7bbcc..4539f6495 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1791,17 +1791,9 @@ class Kandinsky5Image(Kandinsky5): def concat_cond(self, **kwargs): return None -class Kandinsky5ImageToImage(BaseModel): +class Kandinsky5ImageToImage(Kandinsky5): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): - super().__init__( - model_config, - model_type, - device=device, - unet_model=comfy.ldm.kandinsky5.model.Kandinsky5 - ) - - def encode_adm(self, **kwargs): - return kwargs["pooled_output"] + super().__init__(model_config, model_type, device=device) def concat_cond(self, **kwargs): noise = kwargs["noise"] diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index e3bfd8d99..111cfbb61 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -62,15 +62,15 @@ class Kandinsky5ImageToImage(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Kandinsky5ImageToImage", - category="image", + category="advanced/conditioning/kandinsky5", inputs=[ io.Vae.Input("vae"), io.Int.Input("batch_size", default=1, min=1, max=4096), io.Image.Input("start_image"), ], outputs=[ - io.Latent.Output(display_name="latent", tooltip="Empty video latent"), - io.Image.Output("resized_image"), + io.Latent.Output(display_name="latent", tooltip="Latent of resized source image"), + io.Image.Output("resized_image", tooltip="Resized source image"), ], ) @@ -78,8 +78,8 @@ class Kandinsky5ImageToImage(io.ComfyNode): def execute(cls, vae, batch_size, start_image) -> io.NodeOutput: height, width = start_image.shape[1:-1] available_res = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)] - nearest_index = torch.argmin(torch.Tensor([abs((w / h) - (width / height))for (w, h) in available_res])) - nw, nh = available_res[nearest_index] + nearest_index = torch.argmin(torch.Tensor([abs((h / w) - (height / width))for (h, w) in available_res])) + nh, nw = available_res[nearest_index] scale_factor = min(height / nh, width / nw) start_image = start_image.permute(0,3,1,2) start_image = F.resize(start_image, (int(height / scale_factor), int(width / scale_factor))) @@ -150,7 +150,7 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode): return io.Schema( node_id="CLIPTextEncodeKandinsky5", search_aliases=["kandinsky prompt"], - category="advanced/conditioning", + category="advanced/conditioning/kandinsky5", inputs=[ io.Clip.Input("clip"), io.String.Input("prompt", multiline=True, dynamic_prompts=True), From 18f773dcded8c34a5dad388018330ff50df4035e Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Mon, 2 Feb 2026 14:26:34 +0000 Subject: [PATCH 13/13] fix ruff errors --- comfy/model_base.py | 4 ++-- comfy_extras/nodes_kandinsky5.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/comfy/model_base.py b/comfy/model_base.py index 4539f6495..94449b69b 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1790,7 +1790,7 @@ class Kandinsky5Image(Kandinsky5): def concat_cond(self, **kwargs): return None - + class Kandinsky5ImageToImage(Kandinsky5): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) @@ -1811,4 +1811,4 @@ class Kandinsky5ImageToImage(Kandinsky5): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out["c_crossattn"] = comfy.conds.CONDRegular(cross_attn) - return out \ No newline at end of file + return out diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index 111cfbb61..89c752053 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -96,7 +96,7 @@ class Kandinsky5ImageToImage(io.ComfyNode): encoded = vae.encode(start_image[:, :, :, :3]) out_latent = {"samples": encoded.repeat(batch_size, 1, 1, 1)} return io.NodeOutput(out_latent, start_image) - + def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5): source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W