Mihail Karaev 2026-02-02 14:28:56 +00:00 committed by GitHub
commit 987d843e5a
8 changed files with 349 additions and 29 deletions

View File

@@ -6,6 +6,12 @@ import comfy.ldm.common_dit
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.kandinsky5.utils_nabla import (
    fractal_flatten,
    fractal_unflatten,
    fast_sta_nabla,
    nabla,
)

def attention(q, k, v, heads, transformer_options={}):
    return optimized_attention(
@@ -116,14 +122,17 @@ class SelfAttention(nn.Module):
        result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
        return apply_rope1(norm_fn(result), freqs)

-    def _forward(self, x, freqs, transformer_options={}):
+    def _forward(self, x, freqs, sparse_params=None, transformer_options={}):
        q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
        k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
-        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        if sparse_params is None:
+            out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        else:
+            out = nabla(q, k, v, sparse_params)
        return self.out_layer(out)

-    def _forward_chunked(self, x, freqs, transformer_options={}):
+    def _forward_chunked(self, x, freqs, sparse_params=None, transformer_options={}):
        def process_chunks(proj_fn, norm_fn):
            x_chunks = torch.chunk(x, self.num_chunks, dim=1)
            freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
@@ -135,14 +144,17 @@ class SelfAttention(nn.Module):
        q = process_chunks(self.to_query, self.query_norm)
        k = process_chunks(self.to_key, self.key_norm)
        v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
-        out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        if sparse_params is None:
+            out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
+        else:
+            out = nabla(q, k, v, sparse_params)
        return self.out_layer(out)

-    def forward(self, x, freqs, transformer_options={}):
+    def forward(self, x, freqs, sparse_params=None, transformer_options={}):
        if x.shape[1] > 8192:
-            return self._forward_chunked(x, freqs, transformer_options=transformer_options)
+            return self._forward_chunked(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
        else:
-            return self._forward(x, freqs, transformer_options=transformer_options)
+            return self._forward(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options)

class CrossAttention(SelfAttention):
@@ -251,12 +263,12 @@ class TransformerDecoderBlock(nn.Module):
        self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
        self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)

-    def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
+    def forward(self, visual_embed, text_embed, time_embed, freqs, sparse_params=None, transformer_options={}):
        self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
        # self attention
        shift, scale, gate = get_shift_scale_gate(self_attn_params)
        visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
-        visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
+        visual_out = self.self_attention(visual_out, freqs, sparse_params=sparse_params, transformer_options=transformer_options)
        visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
        # cross attention
        shift, scale, gate = get_shift_scale_gate(cross_attn_params)
@@ -369,21 +381,82 @@ class Kandinsky5(nn.Module):
        visual_embed = self.visual_embeddings(x)
        visual_shape = visual_embed.shape[:-1]
-        visual_embed = visual_embed.flatten(1, -2)

        blocks_replace = patches_replace.get("dit", {})
        transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
        transformer_options["block_type"] = "double"

        B, _, T, H, W = x.shape
        NABLA_THR = 31  # long (10 sec) generation
        if T > NABLA_THR:
            assert self.patch_size[0] == 1
            # pro video model uses lower P at higher resolutions
            P = 0.7 if self.model_dim == 4096 and H * W >= 14080 else 0.9
            freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])
            visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:])
            pt, ph, pw = self.patch_size
            T, H, W = T // pt, H // ph, W // pw
            wT, wW, wH = 11, 3, 3
            sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device)
            sparse_params = dict(
                sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0),
                attention_type="nabla",
                to_fractal=True,
                P=P,
                wT=wT, wW=wW, wH=wH,
                add_sta=True,
                visual_shape=(T, H, W),
                method="topcdf",
            )
        else:
            sparse_params = None
            visual_embed = visual_embed.flatten(1, -2)

        for i, block in enumerate(self.visual_transformer_blocks):
            transformer_options["block_index"] = i
            if ("double_block", i) in blocks_replace:
                def block_wrap(args):
-                    return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
+                    return block(
+                        x=args["x"],
+                        context=args["context"],
+                        time_embed=args["time_embed"],
+                        freqs=args["freqs"],
+                        sparse_params=args.get("sparse_params"),
+                        transformer_options=args.get("transformer_options"),
+                    )
-                visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
+                visual_embed = blocks_replace[("double_block", i)](
+                    {
+                        "x": visual_embed,
+                        "context": context,
+                        "time_embed": time_embed,
+                        "freqs": freqs,
+                        "sparse_params": sparse_params,
+                        "transformer_options": transformer_options,
+                    },
+                    {"original_block": block_wrap},
+                )["x"]
            else:
-                visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
+                visual_embed = block(
+                    visual_embed,
+                    context,
+                    time_embed,
+                    freqs=freqs,
+                    sparse_params=sparse_params,
+                    transformer_options=transformer_options,
+                )

        if T > NABLA_THR:
            visual_embed = fractal_unflatten(
                visual_embed,
                visual_shape[1:],
            )
        else:
            visual_embed = visual_embed.reshape(*visual_shape, -1)
-        visual_embed = visual_embed.reshape(*visual_shape, -1)

        return self.out_layer(visual_embed, time_embed)

    def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
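For orientation, the NABLA branch above only triggers for long generations: with the 4x temporal compression used by the video latent (((length - 1) // 4) + 1 frames, see the Kandinsky5ImageToVideo change below), T > NABLA_THR corresponds to inputs longer than 121 frames, i.e. the 10-second case the comments refer to (assuming 24 fps). A minimal sketch of that arithmetic, not part of the commit:

NABLA_THR = 31  # same threshold as in the model code above

def latent_frames(length):
    # temporal size of the video latent, mirroring the formula in Kandinsky5ImageToVideo
    return ((length - 1) // 4) + 1

for length in (121, 125, 241):
    T = latent_frames(length)
    print(length, T, T > NABLA_THR)
# 121 -> T = 31 -> dense attention
# 125 -> T = 32 -> NABLA sparse attention
# 241 -> T = 61 -> NABLA sparse attention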

View File

@@ -0,0 +1,146 @@
import math
import torch
from torch import Tensor
from torch.nn.attention.flex_attention import BlockMask, flex_attention


def fractal_flatten(x, rope, shape):
    # group each frame's 8x8 spatial tiles into 64 contiguous tokens: (B, T, H, W, C) -> (B, T*H*W, C)
    pixel_size = 8
    x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1)
    rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1)
    x = x.flatten(1, 2)
    rope = rope.flatten(1, 2)
    return x, rope


def fractal_unflatten(x, shape):
    # inverse of fractal_flatten: restore the (B, T, H, W, C) layout
    pixel_size = 8
    x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1])
    x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1)
    return x
def local_patching(x, shape, group_size, dim=0):
    duration, height, width = shape
    g1, g2, g3 = group_size
    x = x.reshape(
        *x.shape[:dim],
        duration // g1,
        g1,
        height // g2,
        g2,
        width // g3,
        g3,
        *x.shape[dim + 3 :]
    )
    x = x.permute(
        *range(len(x.shape[:dim])),
        dim,
        dim + 2,
        dim + 4,
        dim + 1,
        dim + 3,
        dim + 5,
        *range(dim + 6, len(x.shape))
    )
    x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3)
    return x


def local_merge(x, shape, group_size, dim=0):
    duration, height, width = shape
    g1, g2, g3 = group_size
    x = x.reshape(
        *x.shape[:dim],
        duration // g1,
        height // g2,
        width // g3,
        g1,
        g2,
        g3,
        *x.shape[dim + 2 :]
    )
    x = x.permute(
        *range(len(x.shape[:dim])),
        dim,
        dim + 3,
        dim + 1,
        dim + 4,
        dim + 2,
        dim + 5,
        *range(dim + 6, len(x.shape))
    )
    x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3)
    return x
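The helpers above are exact inverses of each other for any latent whose spatial sides are divisible by 8: fractal_flatten groups each frame's 8x8 tiles into 64 contiguous tokens and fractal_unflatten restores the original layout. A small round-trip check with toy shapes, not part of the commit:

import torch
from comfy.ldm.kandinsky5.utils_nabla import fractal_flatten, fractal_unflatten

B, T, H, W, C = 1, 4, 16, 16, 8    # toy latent grid; H and W divisible by 8
x = torch.randn(B, T, H, W, C)
rope = torch.randn(B, T, H, W, C)  # stand-in for the rope frequencies
flat, flat_rope = fractal_flatten(x, rope, (T, H, W))
print(flat.shape)                  # torch.Size([1, 1024, 8]): 4 * 2 * 2 blocks of 64 tokens
assert torch.equal(fractal_unflatten(flat, (T, H, W)), x)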
def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> Tensor:
    # Boolean sliding-tile-attention mask at block granularity: entry (i, j) is True when
    # blocks i and j are within wT//2, wH//2 and wW//2 of each other along t, h and w.
    l = torch.Tensor([T, H, W]).amax()
    r = torch.arange(0, l, 1, dtype=torch.int16, device=device)
    mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs()
    sta_t, sta_h, sta_w = (
        mat[:T, :T].flatten(),
        mat[:H, :H].flatten(),
        mat[:W, :W].flatten(),
    )
    sta_t = sta_t <= wT // 2
    sta_h = sta_h <= wH // 2
    sta_w = sta_w <= wW // 2
    sta_hw = (
        (sta_h.unsqueeze(1) * sta_w.unsqueeze(0))
        .reshape(H, H, W, W)
        .transpose(1, 2)
        .flatten()
    )
    sta = (
        (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0))
        .reshape(T, T, H * W, H * W)
        .transpose(1, 2)
    )
    return sta.reshape(T * H * W, T * H * W)
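fast_sta_nabla is pure index arithmetic, so it can be sanity-checked on CPU. The caller in the model passes the spatial sides already divided by 8, matching the 1x8x8 grouping of fractal_flatten. A quick check with toy sizes, not part of the commit:

import torch
from comfy.ldm.kandinsky5.utils_nabla import fast_sta_nabla

# a 4 x 3 x 3 grid of 64-token blocks (toy sizes)
mask = fast_sta_nabla(4, 3, 3, wT=3, wH=3, wW=3, device="cpu")
print(mask.shape)                  # torch.Size([36, 36])
print(mask.float().mean().item())  # fraction of block pairs kept by the sliding-tile window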
def nablaT_v2(q: Tensor, k: Tensor, sta: Tensor, thr: float = 0.9) -> BlockMask:
    # Map estimation
    B, h, S, D = q.shape
    s1 = S // 64
    qa = q.reshape(B, h, s1, 64, D).mean(-2)
    ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1)
    map = qa @ ka
    map = torch.softmax(map / math.sqrt(D), dim=-1)

    # Map binarization
    vals, inds = map.sort(-1)
    cvals = vals.cumsum_(-1)
    mask = (cvals >= 1 - thr).int()
    mask = mask.gather(-1, inds.argsort(-1))
    mask = torch.logical_or(mask, sta)

    # BlockMask creation
    kv_nb = mask.sum(-1).to(torch.int32)
    kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32)
    return BlockMask.from_kv_blocks(
        torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, mask_mod=None
    )
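The binarization step is the "topcdf" method named in sparse_params: for each query block it sorts the estimated block probabilities, drops the low-probability tail whose cumulative mass stays below 1 - thr, and keeps the rest, so at least a thr fraction of the estimated attention mass survives before the STA mask is OR-ed in. A toy row with made-up numbers, not from the commit, shows the mechanics:

import torch

thr = 0.9
p = torch.tensor([0.04, 0.06, 0.70, 0.02, 0.18])  # block probabilities for one query block
vals, inds = p.sort(-1)                           # ascending: [0.02, 0.04, 0.06, 0.18, 0.70]
keep_sorted = (vals.cumsum(-1) >= 1 - thr).int()  # [0, 0, 1, 1, 1]: only 0.02 and 0.04 are dropped
keep = keep_sorted.gather(-1, inds.argsort(-1))   # back to the original block order
print(keep)                                       # tensor([0, 1, 1, 0, 1], dtype=torch.int32)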
@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True)
def nabla(query, key, value, sparse_params=None):
    # inputs arrive as (B, S, heads, head_dim); flex_attention wants (B, heads, S, head_dim)
    query = query.transpose(1, 2).contiguous()
    key = key.transpose(1, 2).contiguous()
    value = value.transpose(1, 2).contiguous()
    block_mask = nablaT_v2(
        query,
        key,
        sparse_params["sta_mask"],
        thr=sparse_params["P"],
    )
    out = (
        flex_attention(
            query,
            key,
            value,
            block_mask=block_mask
        )
        .transpose(1, 2)
        .contiguous()
    )
    # back to (B, S, heads * head_dim), matching the dense attention() helper
    out = out.flatten(-2, -1)
    return out

View File

@@ -1790,3 +1790,25 @@ class Kandinsky5Image(Kandinsky5):
    def concat_cond(self, **kwargs):
        return None

class Kandinsky5ImageToImage(Kandinsky5):
    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        super().__init__(model_config, model_type, device=device)

    def concat_cond(self, **kwargs):
        noise = kwargs["noise"]
        device = kwargs["device"]
        image = kwargs.get("latent_image", None)
        image = utils.resize_to_batch_size(image, noise.shape[0])
        mask_ones = torch.ones_like(noise)[:, :1].to(device=device)
        return torch.cat((image, mask_ones), dim=1)

    def extra_conds(self, **kwargs):
        out = super().extra_conds(**kwargs)
        attention_mask = kwargs.get("attention_mask", None)
        if attention_mask is not None:
            out["attention_mask"] = comfy.conds.CONDRegular(attention_mask)
        cross_attn = kwargs.get("cross_attn", None)
        if cross_attn is not None:
            out["c_crossattn"] = comfy.conds.CONDRegular(cross_attn)
        return out

View File

@@ -1133,6 +1133,7 @@ class CLIPType(Enum):
    KANDINSKY5_IMAGE = 23
    NEWBIE = 24
    FLUX2 = 25
    KANDINSKY5_I2I = 26

def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):

@@ -1427,6 +1428,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
            clip_data_jina = clip_data[0]
            tokenizer_data["gemma_spiece_model"] = clip_data_gemma.get("spiece_model", None)
            tokenizer_data["jina_spiece_model"] = clip_data_jina.get("spiece_model", None)
        elif clip_type == CLIPType.KANDINSKY5_I2I:
            clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
            clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerI2I
        else:
            clip_target.clip = sdxl_clip.SDXLClipModel
            clip_target.tokenizer = sdxl_clip.SDXLTokenizer

View File

@@ -1595,7 +1595,29 @@ class Kandinsky5Image(Kandinsky5):
        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))

class Kandinsky5ImageToImage(Kandinsky5):
    unet_config = {
        "image_model": "kandinsky5",
        "model_dim": 2560,
        "visual_embed_dim": 132,
    }

    sampling_settings = {
        "shift": 3.0,
    }

    latent_format = latent_formats.Flux
    memory_usage_factor = 1.25  # TODO

    def get_model(self, state_dict, prefix="", device=None):
        out = model_base.Kandinsky5ImageToImage(self, device=device)
        return out

    def clip_target(self, state_dict={}):
        pref = self.text_encoder_key_prefix[0]
        hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
        return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerI2I, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))

-models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5ImageToImage, Kandinsky5Image, Kandinsky5, Anima]

models += [SVD_img2vid]

View File

@@ -6,7 +6,7 @@ from .llama import Qwen25_7BVLI

class Kandinsky5Tokenizer(QwenImageTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-        self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
+        self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):

@@ -21,6 +21,11 @@ class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"

class Kandinsky5TokenizerI2I(Kandinsky5Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
        self.llama_template_images = "<|im_start|>system\nYou are a promt engineer. Based on the provided source image (first image) and target image (second image), create an interesting text prompt that can be used together with the source image to create the target image:<|im_end|>\n<|im_start|>user\n{}<|vision_start|><|image_pad|><|vision_end|><|im_end|>"

class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):

View File

@@ -1,6 +1,7 @@
import nodes
import node_helpers
import torch
import torchvision.transforms.functional as F
import comfy.model_management
import comfy.utils

@@ -34,6 +35,9 @@ class Kandinsky5ImageToVideo(io.ComfyNode):
    @classmethod
    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput:
        if length > 121: # 10 sec generation, for nabla
            height = 128 * round(height / 128)
            width = 128 * round(width / 128)
        latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
        cond_latent_out = {}
        if start_image is not None:

@@ -52,6 +56,48 @@ class Kandinsky5ImageToVideo(io.ComfyNode):
        return io.NodeOutput(positive, negative, out_latent, cond_latent_out)
class Kandinsky5ImageToImage(io.ComfyNode):
    @classmethod
    def define_schema(cls):
        return io.Schema(
            node_id="Kandinsky5ImageToImage",
            category="advanced/conditioning/kandinsky5",
            inputs=[
                io.Vae.Input("vae"),
                io.Int.Input("batch_size", default=1, min=1, max=4096),
                io.Image.Input("start_image"),
            ],
            outputs=[
                io.Latent.Output(display_name="latent", tooltip="Latent of resized source image"),
                io.Image.Output("resized_image", tooltip="Resized source image"),
            ],
        )

    @classmethod
    def execute(cls, vae, batch_size, start_image) -> io.NodeOutput:
        height, width = start_image.shape[1:-1]
        available_res = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)]
        nearest_index = torch.argmin(torch.Tensor([abs((h / w) - (height / width)) for (h, w) in available_res]))
        nh, nw = available_res[nearest_index]
        scale_factor = min(height / nh, width / nw)
        start_image = start_image.permute(0,3,1,2)
        start_image = F.resize(start_image, (int(height / scale_factor), int(width / scale_factor)))
        height, width = start_image.shape[-2:]
        start_image = F.crop(
            start_image,
            (height - nh) // 2,
            (width - nw) // 2,
            nh,
            nw,
        )
        start_image = start_image.permute(0,2,3,1)
        encoded = vae.encode(start_image[:, :, :, :3])
        out_latent = {"samples": encoded.repeat(batch_size, 1, 1, 1)}
        return io.NodeOutput(out_latent, start_image)
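To make the resizing above concrete: for a hypothetical 1080x1920 (height x width) input, the closest aspect-ratio bucket is (768, 1280); the image is resized by 1/1.40625 to 768x1365 and then center-cropped to 768x1280 before VAE encoding. A sketch of just the bucket selection, not part of the commit:

import torch

available_res = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)]
height, width = 1080, 1920                                    # hypothetical input size
ratios = torch.Tensor([abs((h / w) - (height / width)) for (h, w) in available_res])
nh, nw = available_res[torch.argmin(ratios)]
scale_factor = min(height / nh, width / nw)
print(nh, nw, scale_factor)                                   # 768 1280 1.40625
print(int(height / scale_factor), int(width / scale_factor))  # 768 1365 (then cropped to 768x1280)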
def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5):
    source_mean = source.mean(dim=(1, 3, 4), keepdim=True)  # mean over C, H, W
    source_std = source.std(dim=(1, 3, 4), keepdim=True)  # std over C, H, W

@@ -98,7 +144,6 @@ class NormalizeVideoLatentStart(io.ComfyNode):
        s["samples"] = samples
        return io.NodeOutput(s)

class CLIPTextEncodeKandinsky5(io.ComfyNode):
    @classmethod
    def define_schema(cls):

@@ -108,27 +153,30 @@ class CLIPTextEncodeKandinsky5(io.ComfyNode):
            category="advanced/conditioning/kandinsky5",
            inputs=[
                io.Clip.Input("clip"),
-                io.String.Input("clip_l", multiline=True, dynamic_prompts=True),
-                io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True),
-            ],
-            outputs=[
-                io.Conditioning.Output(),
-            ],
+                io.String.Input("prompt", multiline=True, dynamic_prompts=True),
+                io.Image.Input("image", optional=True),
+            ],
+            outputs=[io.Conditioning.Output()],
        )

    @classmethod
-    def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput:
-        tokens = clip.tokenize(clip_l)
-        tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"]
-        return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens))
+    def execute(cls, clip, prompt, image=None) -> io.NodeOutput:
+        images = []
+        if image is not None:
+            image = image.permute(0,3,1,2)
+            height, width = image.shape[-2:]
+            image = F.resize(image, (int(height / 2), int(width / 2))).permute(0,2,3,1)
+            images.append(image)
+        tokens = clip.tokenize(prompt, images=images)
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+        return io.NodeOutput(conditioning)

class Kandinsky5Extension(ComfyExtension):
    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            Kandinsky5ImageToVideo,
            Kandinsky5ImageToImage,
            NormalizeVideoLatentStart,
            CLIPTextEncodeKandinsky5,
        ]

View File

@@ -1001,7 +1001,7 @@ class DualCLIPLoader:
    def INPUT_TYPES(s):
        return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
                              "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "ltxv", "newbie"], ),
+                              "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image", "kandinsky5_i2i", "ltxv", "newbie"], ),
                              },
                "optional": {
                              "device": (["default", "cpu"], {"advanced": True}),