diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py
index 1509de2f8..24a06da0a 100644
--- a/comfy/ldm/kandinsky5/model.py
+++ b/comfy/ldm/kandinsky5/model.py
@@ -6,6 +6,12 @@ import comfy.ldm.common_dit
 from comfy.ldm.modules.attention import optimized_attention
 from comfy.ldm.flux.math import apply_rope1
 from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.kandinsky5.utils_nabla import (
+    fractal_flatten,
+    fractal_unflatten,
+    fast_sta_nabla,
+    nabla,
+)
 
 def attention(q, k, v, heads, transformer_options={}):
     return optimized_attention(
@@ -116,14 +122,17 @@ class SelfAttention(nn.Module):
         result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
         return apply_rope1(norm_fn(result), freqs)
 
-    def _forward(self, x, freqs, transformer_options={}):
+    def _forward(self, x, freqs, sparse_params=None, transformer_options={}):
         q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
         k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
         v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
-        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        if sparse_params is None:
+            out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        else:
+            out = nabla(q, k, v, sparse_params)
         return self.out_layer(out)
 
-    def _forward_chunked(self, x, freqs, transformer_options={}):
+    def _forward_chunked(self, x, freqs, sparse_params=None, transformer_options={}):
         def process_chunks(proj_fn, norm_fn):
             x_chunks = torch.chunk(x, self.num_chunks, dim=1)
             freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
@@ -135,14 +144,17 @@ class SelfAttention(nn.Module):
         q = process_chunks(self.to_query, self.query_norm)
         k = process_chunks(self.to_key, self.key_norm)
         v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
-        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        if sparse_params is None:
+            out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        else:
+            out = nabla(q, k, v, sparse_params)
         return self.out_layer(out)
 
-    def forward(self, x, freqs, transformer_options={}):
+    def forward(self, x, freqs, sparse_params=None, transformer_options={}):
         if x.shape[1] > 8192:
-            return self._forward_chunked(x, freqs, transformer_options=transformer_options)
+            return self._forward_chunked(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
         else:
-            return self._forward(x, freqs, transformer_options=transformer_options)
+            return self._forward(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
 
 
 class CrossAttention(SelfAttention):
@@ -251,12 +263,12 @@ class TransformerDecoderBlock(nn.Module):
         self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
         self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
 
-    def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
+    def forward(self, visual_embed, text_embed, time_embed, freqs, sparse_params=None, transformer_options={}):
         self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
         # self attention
         shift, scale, gate = get_shift_scale_gate(self_attn_params)
         visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
-        visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
+        visual_out = self.self_attention(visual_out, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
         visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
         # cross attention
         shift, scale, gate = get_shift_scale_gate(cross_attn_params)
@@ -369,21 +381,82 @@ class Kandinsky5(nn.Module):
         visual_embed = self.visual_embeddings(x)
         visual_shape = visual_embed.shape[:-1]
-        visual_embed = visual_embed.flatten(1, -2)
 
         blocks_replace = patches_replace.get("dit", {})
         transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
         transformer_options["block_type"] = "double"
+
+        B, _, T, H, W = x.shape
+        NABLA_THR = 31 # long (10 sec) generation
+        if T > NABLA_THR:
+            assert self.patch_size[0] == 1
+
+            # pro video model uses lower P at higher resolutions
+            P = 0.7 if self.model_dim == 4096 and H * W >= 14080 else 0.9
+
+            freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
+            visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:])
+            pt, ph, pw = self.patch_size
+            T, H, W = T // pt, H // ph, W // pw
+
+            wT, wW, wH = 11, 3, 3
+            sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device)
+
+            sparse_params = dict(
+                sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0),
+                attention_type="nabla",
+                to_fractal=True,
+                P=P,
+                wT=wT, wW=wW, wH=wH,
+                add_sta=True,
+                visual_shape=(T, H, W),
+                method="topcdf",
+            )
+        else:
+            sparse_params = None
+            visual_embed = visual_embed.flatten(1, -2)
+
         for i, block in enumerate(self.visual_transformer_blocks):
             transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
-                    return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
-                visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
+                    return block(
+                        x=args["x"],
+                        context=args["context"],
+                        time_embed=args["time_embed"],
+                        freqs=args["freqs"],
+                        sparse_params=args.get("sparse_params"),
+                        transformer_options=args.get("transformer_options"),
+                    )
+                visual_embed = blocks_replace[("double_block", i)](
+                    {
+                        "x": visual_embed,
+                        "context": context,
+                        "time_embed": time_embed,
+                        "freqs": freqs,
+                        "sparse_params": sparse_params,
+                        "transformer_options": transformer_options,
+                    },
+                    {"original_block": block_wrap},
+                )["x"]
             else:
-                visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
+                visual_embed = block(
+                    visual_embed,
+                    context,
+                    time_embed,
+                    freqs=freqs,
+                    sparse_params=sparse_params,
+                    transformer_options=transformer_options,
+                )
+
+        if T > NABLA_THR:
+            visual_embed = fractal_unflatten(
+                visual_embed,
+                visual_shape[1:],
+            )
+        else:
+            visual_embed = visual_embed.reshape(*visual_shape, -1)
 
-        visual_embed = visual_embed.reshape(*visual_shape, -1)
         return self.out_layer(visual_embed, time_embed)
 
     def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
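Note: the sparse NaBLA path only switches on for long generations. With the 4x temporal compression used elsewhere in this patch, 31 latent frames corresponds to about 121 input frames, roughly 10 seconds (NABLA_THR above). A minimal sketch of the bookkeeping, assuming 8x spatial VAE compression and a (1, 2, 2) patch size for the video DiT; the patch size is an assumption for illustration, not taken from this patch:

    # Rough trace of when the NaBLA path activates and which grid fast_sta_nabla sees.
    def nabla_dims(length, height, width):
        t_lat = ((length - 1) // 4) + 1          # latent frames, as in Kandinsky5ImageToVideo
        h_lat, w_lat = height // 8, width // 8   # latent spatial size
        if t_lat <= 31:                          # NABLA_THR: dense attention below this
            return None
        t, h, w = t_lat // 1, h_lat // 2, w_lat // 2   # token grid after assumed (1, 2, 2) patching
        return t, h // 8, w // 8                 # block grid handed to fast_sta_nabla

    print(nabla_dims(241, 768, 1280))  # -> (61, 6, 10)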
diff --git a/comfy/ldm/kandinsky5/utils_nabla.py b/comfy/ldm/kandinsky5/utils_nabla.py
new file mode 100644
index 000000000..a346736b2
--- /dev/null
+++ b/comfy/ldm/kandinsky5/utils_nabla.py
@@ -0,0 +1,146 @@
+import math
+
+import torch
+from torch import Tensor
+from torch.nn.attention.flex_attention import BlockMask, flex_attention
+
+
+def fractal_flatten(x, rope, shape):
+    pixel_size = 8
+    x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
+    rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
+    x = x.flatten(1, 2)
+    rope = rope.flatten(1, 2)
+    return x, rope
+
+
+def fractal_unflatten(x, shape):
+    pixel_size = 8
+    x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1])
+    x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
+    return x
+
+def local_patching(x, shape, group_size, dim=0):
+    duration, height, width = shape
+    g1, g2, g3 = group_size
+    x = x.reshape(
+        *x.shape[:dim],
+        duration // g1,
+        g1,
+        height // g2,
+        g2,
+        width // g3,
+        g3,
+        *x.shape[dim + 3 :]
+    )
+    x = x.permute(
+        *range(len(x.shape[:dim])),
+        dim,
+        dim + 2,
+        dim + 4,
+        dim + 1,
+        dim + 3,
+        dim + 5,
+        *range(dim + 6, len(x.shape))
+    )
+    x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
+    return x
+
+
+def local_merge(x, shape, group_size, dim=0):
+    duration, height, width = shape
+    g1, g2, g3 = group_size
+    x = x.reshape(
+        *x.shape[:dim],
+        duration // g1,
+        height // g2,
+        width // g3,
+        g1,
+        g2,
+        g3,
+        *x.shape[dim + 2 :]
+    )
+    x = x.permute(
+        *range(len(x.shape[:dim])),
+        dim,
+        dim + 3,
+        dim + 1,
+        dim + 4,
+        dim + 2,
+        dim + 5,
+        *range(dim + 6, len(x.shape))
+    )
+    x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
+    return x
+
+def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> Tensor:
+    l = torch.Tensor([T, H, W]).amax()
+    r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
+    mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
+    sta_t, sta_h, sta_w = (
+        mat[:T, :T].flatten(),
+        mat[:H, :H].flatten(),
+        mat[:W, :W].flatten(),
+    )
+    sta_t = sta_t <= wT // 2
+    sta_h = sta_h <= wH // 2
+    sta_w = sta_w <= wW // 2
+    sta_hw = (
+        (sta_h.unsqueeze(1) * sta_w.unsqueeze(0))
+        .reshape(H, H, W, W)
+        .transpose(1, 2)
+        .flatten()
+    )
+    sta = (
+        (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0))
+        .reshape(T, T, H * W, H * W)
+        .transpose(1, 2)
+    )
+    return sta.reshape(T * H * W, T * H * W)
+
+def nablaT_v2(q: Tensor, k: Tensor, sta: Tensor, thr: float = 0.9) -> BlockMask:
+    # Map estimation
+    B, h, S, D = q.shape
+    s1 = S // 64
+    qa = q.reshape(B, h, s1, 64, D).mean(-2)
+    ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
+    map = qa @ ka
+
+    map = torch.softmax(map / math.sqrt(D), dim=-1)
+    # Map binarization
+    vals, inds = map.sort(-1)
+    cvals = vals.cumsum_(-1)
+    mask = (cvals >= 1 - thr).int()
+    mask = mask.gather(-1, inds.argsort(-1))
+    mask = torch.logical_or(mask, sta)
+
+    # BlockMask creation
+    kv_nb = mask.sum(-1).to(torch.int32)
+    kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
+    return BlockMask.from_kv_blocks(
+        torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
+    )
+
+@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
+def nabla(query, key, value, sparse_params=None):
+    query = query.transpose(1, 2).contiguous()
+    key = key.transpose(1, 2).contiguous()
+    value = value.transpose(1, 2).contiguous()
+    block_mask = nablaT_v2(
+        query,
+        key,
+        sparse_params["sta_mask"],
+        thr=sparse_params["P"],
+    )
+    out = (
+        flex_attention(
+            query,
+            key,
+            value,
+            block_mask=block_mask
+        )
+        .transpose(1, 2)
+        .contiguous()
+    )
+    out = out.flatten(-2, -1)
+    return out
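Note: nablaT_v2 mean-pools queries and keys into 64-token blocks, softmaxes the pooled block-to-block map, and keeps, per query block, the smallest set of key blocks whose probability mass reaches P ("topcdf"), then ORs in the static spatio-temporal window from fast_sta_nabla. A toy, self-contained trace of the top-CDF binarization on a single row; the probabilities are made up for illustration:

    import torch

    # One row of a pooled block-attention map; values are invented.
    row = torch.tensor([0.02, 0.50, 0.05, 0.30, 0.03, 0.10])
    P = 0.9

    vals, inds = row.sort(-1)                        # ascending, as in nablaT_v2
    keep_sorted = (vals.cumsum(-1) >= 1 - P).int()   # drop the low tail whose mass stays below 1 - P
    keep = keep_sorted.gather(-1, inds.argsort(-1))  # scatter back to the original block order
    print(keep.tolist())                             # [0, 1, 1, 1, 0, 1]: kept blocks cover >= 90% of the mass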
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 85acdb66a..94449b69b 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1790,3 +1790,25 @@ class Kandinsky5Image(Kandinsky5):
 
     def concat_cond(self, **kwargs):
         return None
+
+class Kandinsky5ImageToImage(Kandinsky5):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+        super().__init__(model_config, model_type, device=device)
+
+    def concat_cond(self, **kwargs):
+        noise = kwargs["noise"]
+        device = kwargs["device"]
+        image = kwargs.get("latent_image", None)
+        image = utils.resize_to_batch_size(image, noise.shape[0])
+        mask_ones = torch.ones_like(noise)[:, :1].to(device=device)
+        return torch.cat((image, mask_ones), dim=1)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
+        cross_attn = kwargs.get("cross_attn", None)
+        if cross_attn is not None:
+            out["c_crossattn"] = comfy.conds.CONDRegular(cross_attn)
+        return out
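Note: Kandinsky5ImageToImage.concat_cond returns the VAE-encoded source image plus a constant single-channel mask, which ComfyUI then concatenates to the noisy latent along the channel axis. A shape-only sketch, assuming the 16-channel Flux-style latent configured for this model family; the spatial size is just an example:

    import torch

    noise = torch.zeros(1, 16, 128, 128)         # noisy latent for a 1024x1024 bucket
    image = torch.zeros(1, 16, 128, 128)         # VAE encoding of the source image
    mask_ones = torch.ones_like(noise)[:, :1]    # all-ones "keep" mask, one channel
    cond = torch.cat((image, mask_ones), dim=1)  # what concat_cond returns
    print(cond.shape)                            # torch.Size([1, 17, 128, 128])
    # Concatenated with the 16-channel noise, the DiT sees 33 input channels.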
diff --git a/comfy/sd.py b/comfy/sd.py
index fd0ac85e7..43e5eb0c9 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1133,6 +1133,7 @@ class CLIPType(Enum):
     KANDINSKY5_IMAGE = 23
     NEWBIE = 24
     FLUX2 = 25
+    KANDINSKY5_I2I = 26
 
 
 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -1427,6 +1428,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_data_jina = clip_data[0]
             tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
             tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
+        elif clip_type == CLIPType.KANDINSKY5_I2I:
+            clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
+            clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerI2I
         else:
            clip_target.clip = sdxl_clip.SDXLClipModel
            clip_target.tokenizer = sdxl_clip.SDXLTokenizer
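Note: the new CLIPType routes text-encoder loading to the Qwen2.5-VL based Kandinsky5 encoder with the i2i tokenizer, mirroring what DualCLIPLoader does for the new "kandinsky5_i2i" type. A hedged usage sketch; the file names below are placeholders, not files shipped with this patch:

    import comfy.sd

    clip = comfy.sd.load_clip(
        ckpt_paths=["models/text_encoders/clip_l.safetensors",
                    "models/text_encoders/qwen_2.5_vl_7b.safetensors"],
        clip_type=comfy.sd.CLIPType.KANDINSKY5_I2I,
    )
    tokens = clip.tokenize("turn the sketch into a watercolor painting", images=[])
    cond = clip.encode_from_tokens_scheduled(tokens)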
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index d25271d6e..ce515763c 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -1595,7 +1595,29 @@ class Kandinsky5Image(Kandinsky5):
         hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
 
+class Kandinsky5ImageToImage(Kandinsky5):
+    unet_config = {
+        "image_model": "kandinsky5",
+        "model_dim": 2560,
+        "visual_embed_dim": 132,
+    }
 
-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
+    sampling_settings = {
+        "shift": 3.0,
+    }
+    latent_format = latent_formats.Flux
+    memory_usage_factor = 1.25 #TODO
+
+    def get_model(self, state_dict, prefix="", device=None):
+        out = model_base.Kandinsky5ImageToImage(self, device=device)
+        return out
+
+    def clip_target(self, state_dict={}):
+        pref = self.text_encoder_key_prefix[0]
+        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
+        return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerI2I, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
+
+
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5ImageToImage, Kandinsky5Image, Kandinsky5, Anima]
 
 models += [SVD_img2vid]
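Note: detection keys on "visual_embed_dim": 132; with 33 input channels (16 noise + 16 image latent + 1 mask) and an assumed 2x2 spatial patch the patched input projection would be 33 * 2 * 2 = 132, which is presumably what separates the i2i checkpoint from the plain image model, and the class is listed ahead of Kandinsky5Image so the more specific entry matches first. A simplified illustration of that kind of key matching; this is not ComfyUI's actual detection code and the second template is a hypothetical stand-in:

    def matches(template: dict, detected: dict) -> bool:
        # A supported-model entry matches when all of its unet_config keys agree.
        return all(detected.get(k) == v for k, v in template.items())

    detected = {"image_model": "kandinsky5", "model_dim": 2560, "visual_embed_dim": 132}
    i2i = {"image_model": "kandinsky5", "model_dim": 2560, "visual_embed_dim": 132}
    broader = {"image_model": "kandinsky5", "model_dim": 2560}   # hypothetical, for illustration only

    print(matches(i2i, detected), matches(broader, detected))    # True True -> list order decides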
diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py
index be086458c..9a9f6ea76 100644
--- a/comfy/text_encoders/kandinsky5.py
+++ b/comfy/text_encoders/kandinsky5.py
@@ -6,7 +6,7 @@ from .llama import Qwen25_7BVLI
 class Kandinsky5Tokenizer(QwenImageTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-        self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
+        self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
 
     def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
@@ -21,6 +21,11 @@ class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
         super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
 
+class Kandinsky5TokenizerI2I(Kandinsky5Tokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.llama_template_images = "<|im_start|>system\nYou are a promt engineer. Based on the provided source image (first image) and target image (second image), create an interesting text prompt that can be used together with the source image to create the target image:<|im_end|>\n<|im_start|>user\n{}<|vision_start|><|image_pad|><|vision_end|><|im_end|>"
+
 
 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
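Note: Kandinsky5TokenizerI2I overrides llama_template_images rather than llama_template, so the <|vision_start|><|image_pad|><|vision_end|> placeholder is presumably only emitted when a reference image accompanies the prompt, while plain text prompts keep using the inherited template. For illustration, the string handed to the Qwen2.5-VL tokenizer for an example prompt; the prompt text itself is made up:

    template = (
        "<|im_start|>system\nYou are a promt engineer. Based on the provided source image (first image) "
        "and target image (second image), create an interesting text prompt that can be used together "
        "with the source image to create the target image:<|im_end|>\n"
        "<|im_start|>user\n{}<|vision_start|><|image_pad|><|vision_end|><|im_end|>"
    )
    print(template.format("make the sky stormy"))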
diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py
index 346c50cde..89c752053 100644
--- a/comfy_extras/nodes_kandinsky5.py
+++ b/comfy_extras/nodes_kandinsky5.py
@@ -1,6 +1,7 @@
 import nodes
 import node_helpers
 import torch
+import torchvision.transforms.functional as F
 
 import comfy.model_management
 import comfy.utils
@@ -34,6 +35,9 @@ class Kandinsky5ImageToVideo(io.ComfyNode):
 
     @classmethod
     def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
+        if length > 121: # 10 sec generation, for nabla
+            height = 128 * round(height / 128)
+            width = 128 * round(width / 128)
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         cond_latent_out = {}
         if start_image is not None:
@@ -52,6 +56,48 @@ class Kandinsky5ImageToVideo(io.ComfyNode):
 
         return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
 
+
+class Kandinsky5ImageToImage(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="Kandinsky5ImageToImage",
+            category="advanced/conditioning/kandinsky5",
+            inputs=[
+                io.Vae.Input("vae"),
+                io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Image.Input("start_image"),
+            ],
+            outputs=[
+                io.Latent.Output(display_name="latent", tooltip="Latent of resized source image"),
+                io.Image.Output("resized_image", tooltip="Resized source image"),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, vae, batch_size, start_image) -> io.NodeOutput:
+        height, width = start_image.shape[1:-1]
+        available_res = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)]
+        nearest_index = torch.argmin(torch.Tensor([abs((h / w) - (height / width))for (h, w) in available_res]))
+        nh, nw = available_res[nearest_index]
+        scale_factor = min(height / nh, width / nw)
+        start_image = start_image.permute(0,3,1,2)
+        start_image = F.resize(start_image, (int(height / scale_factor), int(width / scale_factor)))
+
+        height, width = start_image.shape[-2:]
+        start_image = F.crop(
+            start_image,
+            (height - nh) // 2,
+            (width - nw) // 2,
+            nh,
+            nw,
+        )
+        start_image = start_image.permute(0,2,3,1)
+        encoded = vae.encode(start_image[:, :, :, :3])
+        out_latent = {"samples": encoded.repeat(batch_size, 1, 1, 1)}
+        return io.NodeOutput(out_latent, start_image)
+
+
 def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
     source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W
     source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W
@@ -98,7 +144,6 @@ class NormalizeVideoLatentStart(io.ComfyNode):
         s["samples"] = samples
         return io.NodeOutput(s)
 
-
 class CLIPTextEncodeKandinsky5(io.ComfyNode):
     @classmethod
     def define_schema(cls):
@@ -108,27 +153,30 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode):
             category="advanced/conditioning/kandinsky5",
             inputs=[
                 io.Clip.Input("clip"),
-                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
-                io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
-            ],
-            outputs=[
-                io.Conditioning.Output(),
+                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
+                io.Image.Input("image", optional=True),
             ],
+            outputs=[io.Conditioning.Output()],
         )
 
     @classmethod
-    def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
-        tokens = clip.tokenize(clip_l)
-        tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
-
-        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
-
+    def execute(cls, clip, prompt, image=None) -> io.NodeOutput:
+        images = []
+        if image is not None:
+            image = image.permute(0,3,1,2)
+            height, width = image.shape[-2:]
+            image = F.resize(image, (int(height / 2), int(width / 2))).permute(0,2,3,1)
+            images.append(image)
+        tokens = clip.tokenize(prompt, images=images)
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+        return io.NodeOutput(conditioning)
 
 class Kandinsky5Extension(ComfyExtension):
     @override
     async def get_node_list(self) -> list[type[io.ComfyNode]]:
         return [
             Kandinsky5ImageToVideo,
+            Kandinsky5ImageToImage,
             NormalizeVideoLatentStart,
             CLIPTextEncodeKandinsky5,
         ]
diff --git a/nodes.py b/nodes.py
index 1cb43d9e2..21bd182b6 100644
--- a/nodes.py
+++ b/nodes.py
@@ -1001,7 +1001,7 @@ class DualCLIPLoader:
     def INPUT_TYPES(s):
         return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                               "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ),
+                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "kandinsky5_i2i", "ltxv", "newbie"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
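Note: the Kandinsky5ImageToImage node snaps the source image to the closest of seven resolution buckets by aspect ratio, resizes so the bucket fits inside the image, then center-crops. A worked trace of that selection for a 1080x1920 (height x width) input, following the node's own arithmetic:

    import torch

    height, width = 1080, 1920
    available_res = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)]
    nearest_index = torch.argmin(torch.Tensor([abs((h / w) - (height / width)) for (h, w) in available_res]))
    nh, nw = available_res[nearest_index]            # -> (768, 1280), the closest aspect ratio to 0.5625
    scale_factor = min(height / nh, width / nw)      # -> 1.40625
    resized = (int(height / scale_factor), int(width / scale_factor))
    print(nh, nw, resized)                           # 768 1280 (768, 1365); then center-cropped to 768x1280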