diff --git a/comfy/ldm/krea2/model.py b/comfy/ldm/krea2/model.py new file mode 100644 index 000000000..ecb16254f --- /dev/null +++ b/comfy/ldm/krea2/model.py @@ -0,0 +1,290 @@ +"""Krea 2 (K2) — single-stream MMDiT. + +Text tokens produced by a Qwen3-VL-4B 12-layer ``txtfusion`` adapter and patchified image tokens are +concatenated into one sequence and run through ``layers`` shared transformer blocks with +AdaLN-single modulation, GQA + per-head QK-norm + sigmoid-gated attention, SwiGLU MLP, and 3-axis RoPE. +""" + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +import comfy.model_management +import comfy.patcher_extension +import comfy.ldm.common_dit +from comfy.ldm.flux.layers import EmbedND, timestep_embedding +from comfy.ldm.flux.math import apply_rope +from comfy.ldm.modules.attention import optimized_attention_masked + + +class RMSNorm(nn.Module): + """RMSNorm with the reference ``(1 + scale)`` weight convention (scale stored zero-centered).""" + + def __init__(self, features: int, eps: float = 1e-5, device=None, dtype=None, operations=None): + super().__init__() + self.eps = eps + self.scale = nn.Parameter(torch.empty(features, device=device, dtype=dtype)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = x.dtype + weight = comfy.model_management.cast_to(self.scale, dtype=torch.float32, device=x.device) + 1.0 + return F.rms_norm(x.float(), (x.shape[-1],), weight=weight, eps=self.eps).to(dtype) + + +class QKNorm(nn.Module): + def __init__(self, dim: int, device=None, dtype=None, operations=None): + super().__init__() + self.qnorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations) + self.knorm = RMSNorm(dim, device=device, dtype=dtype, operations=operations) + + def forward(self, q, k): + return self.qnorm(q), self.knorm(k) + + +class SwiGLU(nn.Module): + def __init__(self, features: int, multiplier: int, bias: bool = False, multiple: int = 128, 
+ device=None, dtype=None, operations=None): + super().__init__() + mlpdim = int(2 * features / 3) * multiplier + mlpdim = multiple * ((mlpdim + multiple - 1) // multiple) + self.gate = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype) + self.up = operations.Linear(features, mlpdim, bias=bias, device=device, dtype=dtype) + self.down = operations.Linear(mlpdim, features, bias=bias, device=device, dtype=dtype) + + def forward(self, x): + return self.down(F.silu(self.gate(x)).mul_(self.up(x))) + + +class Attention(nn.Module): + def __init__(self, dim: int, heads: int, kvheads: Optional[int] = None, bias: bool = False, + device=None, dtype=None, operations=None): + super().__init__() + self.heads = heads + self.kvheads = kvheads if kvheads is not None else heads + self.headdim = dim // self.heads + self.wq = operations.Linear(dim, self.headdim * self.heads, bias=bias, device=device, dtype=dtype) + self.wk = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype) + self.wv = operations.Linear(dim, self.headdim * self.kvheads, bias=bias, device=device, dtype=dtype) + self.gate = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype) + self.qknorm = QKNorm(self.headdim, device=device, dtype=dtype, operations=operations) + self.wo = operations.Linear(dim, dim, bias=bias, device=device, dtype=dtype) + + def forward(self, x, freqs=None, mask=None, transformer_options={}): + q, k, v, gate = self.wq(x), self.wk(x), self.wv(x), self.gate(x) + q = rearrange(q, "B L (H D) -> B H L D", H=self.heads) + k = rearrange(k, "B L (H D) -> B H L D", H=self.kvheads) + v = rearrange(v, "B L (H D) -> B H L D", H=self.kvheads) + q, k = self.qknorm(q, k) + if freqs is not None: + q, k = apply_rope(q, k, freqs) + if self.kvheads != self.heads: + rep = self.heads // self.kvheads + k = k.repeat_interleave(rep, dim=1) + v = v.repeat_interleave(rep, dim=1) + out = optimized_attention_masked(q, k, v, self.heads, mask=mask, 
skip_reshape=True, + transformer_options=transformer_options) + return self.wo(out * F.sigmoid(gate)) + + +class SimpleModulation(nn.Module): + def __init__(self, dim: int, device=None, dtype=None, operations=None): + super().__init__() + self.lin = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype)) + + def forward(self, vec): + out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device).unsqueeze(0) + scale, shift = out.chunk(2, dim=1) + return scale, shift + + +class DoubleSharedModulation(nn.Module): + def __init__(self, dim: int, device=None, dtype=None, operations=None): + super().__init__() + self.lin = nn.Parameter(torch.empty(6 * dim, device=device, dtype=dtype)) + + def forward(self, vec): + out = vec + comfy.model_management.cast_to(self.lin, dtype=vec.dtype, device=vec.device) + return out.chunk(6, dim=-1) + + +class TextFusionBlock(nn.Module): + def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None): + super().__init__() + self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations) + self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations) + self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations) + self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations) + + def forward(self, x, mask=None, transformer_options={}): + x = x + self.attn(self.prenorm(x), mask=mask, transformer_options=transformer_options) + x = x + self.mlp(self.postnorm(x)) + return x + + +class TextFusionTransformer(nn.Module): + def __init__(self, num_txt_layers, txt_dim, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None): + super().__init__() + self.layerwise_blocks = nn.ModuleList([ + TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations) + for _ 
in range(2) + ]) + self.projector = operations.Linear(num_txt_layers, 1, bias=False, device=device, dtype=dtype) + self.refiner_blocks = nn.ModuleList([ + TextFusionBlock(txt_dim, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations) + for _ in range(2) + ]) + + def forward(self, x, mask=None, transformer_options={}): + b, l, n, d = x.shape + x = x.reshape(b * l, n, d) + for block in self.layerwise_blocks: + x = block(x.contiguous(), mask=None, transformer_options=transformer_options) + x = rearrange(x, "(b l) n d -> b l d n", b=b, l=l) + x = self.projector(x).squeeze(-1) + for block in self.refiner_blocks: + x = block(x, mask=mask, transformer_options=transformer_options) + return x + + +class SingleStreamBlock(nn.Module): + def __init__(self, features, heads, multiplier, bias=False, kvheads=None, device=None, dtype=None, operations=None): + super().__init__() + self.mod = DoubleSharedModulation(features, device=device, dtype=dtype, operations=operations) + self.prenorm = RMSNorm(features, device=device, dtype=dtype, operations=operations) + self.postnorm = RMSNorm(features, device=device, dtype=dtype, operations=operations) + self.attn = Attention(features, heads, kvheads=kvheads, bias=bias, device=device, dtype=dtype, operations=operations) + self.mlp = SwiGLU(features, multiplier, bias, device=device, dtype=dtype, operations=operations) + + def forward(self, x, vec, freqs, mask=None, transformer_options={}): + prescale, preshift, pregate, postscale, postshift, postgate = self.mod(vec) + x = x + pregate * self.attn((1 + prescale) * self.prenorm(x) + preshift, freqs, mask, transformer_options=transformer_options) + x = x + postgate * self.mlp((1 + postscale) * self.postnorm(x) + postshift) + return x + + +class LastLayer(nn.Module): + def __init__(self, features, patch, channels, device=None, dtype=None, operations=None): + super().__init__() + self.norm = RMSNorm(features, device=device, dtype=dtype, operations=operations) + 
self.linear = operations.Linear(features, patch * patch * channels, bias=True, device=device, dtype=dtype) + self.modulation = SimpleModulation(features, device=device, dtype=dtype, operations=operations) + + def forward(self, x, tvec): + scale, shift = self.modulation(tvec) + x = (1 + scale) * self.norm(x) + shift + return self.linear(x) + + +class SingleStreamDiT(nn.Module): + def __init__(self, features=6144, tdim=256, txtdim=2560, heads=48, kvheads=12, multiplier=4, + layers=28, patch=2, channels=16, bias=False, theta=1e3, txtlayers=12, + txtheads=20, txtkvheads=20, image_model=None, + device=None, dtype=None, operations=None, **kwargs): + super().__init__() + self.dtype = dtype + self.patch = patch + self.channels = channels + self.tdim = tdim + self.heads = heads + self.txtdim = txtdim + self.txtlayers = txtlayers + + headdim = features // heads + axes = [headdim - 12 * (headdim // 16), 6 * (headdim // 16), 6 * (headdim // 16)] + assert sum(axes) == headdim, f"axes {axes} sum != headdim {headdim}" + self.pe_embedder = EmbedND(dim=headdim, theta=int(theta), axes_dim=axes) + + self.first = operations.Linear(channels * patch ** 2, features, bias=True, device=device, dtype=dtype) + self.blocks = nn.ModuleList([ + SingleStreamBlock(features, heads, multiplier, bias, kvheads, device=device, dtype=dtype, operations=operations) + for _ in range(layers) + ]) + self.tmlp = nn.Sequential( + operations.Linear(tdim, features, device=device, dtype=dtype), + nn.GELU(approximate="tanh"), + operations.Linear(features, features, device=device, dtype=dtype), + ) + self.txtfusion = TextFusionTransformer(txtlayers, txtdim, txtheads, multiplier, bias, txtkvheads, + device=device, dtype=dtype, operations=operations) + self.txtmlp = nn.Sequential( + RMSNorm(txtdim, device=device, dtype=dtype, operations=operations), + operations.Linear(txtdim, features, device=device, dtype=dtype), + nn.GELU(approximate="tanh"), + operations.Linear(features, features, device=device, dtype=dtype), +
) + self.last = LastLayer(features, patch, channels, device=device, dtype=dtype, operations=operations) + self.tproj = nn.Sequential( + nn.GELU(approximate="tanh"), + operations.Linear(features, features * 6, device=device, dtype=dtype), + ) + + def forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options), + ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs) + + def _forward(self, x, timesteps, context, attention_mask=None, transformer_options={}, **kwargs): + temporal = x.ndim == 5 + if temporal: + b5, c5, t5, h5, w5 = x.shape + x = x.reshape(b5 * t5, c5, h5, w5) + bs, c, H_orig, W_orig = x.shape + patch = self.patch + # Pad the latent up to a multiple of patch (as Flux/Lumina/QwenImage do); crop back at the end. + x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch, patch)) + H, W = x.shape[-2], x.shape[-1] + h_, w_ = H // patch, W // patch + + # context arrives as (B, seq, txtlayers*txtdim); reshape to (B, seq, txtlayers, txtdim). + context = self._unpack_context(context) + + img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch, pw=patch) + img = self.first(img) + + t = self.tmlp(timestep_embedding(timesteps, self.tdim).unsqueeze(1).to(img.dtype)) + tvec = self.tproj(t) + + context = self.txtfusion(context, mask=None, transformer_options=transformer_options) + context = self.txtmlp(context) + + txtlen, imglen = context.shape[1], img.shape[1] + combined = torch.cat((context, img), dim=1) + + # Position ids: text at 0, image at (0, h_idx, w_idx). 
+ device = combined.device + txtpos = torch.zeros(bs, txtlen, 3, device=device, dtype=torch.float32) + imgids = torch.zeros(h_, w_, 3, device=device, dtype=torch.float32) + imgids[..., 1] = torch.arange(h_, device=device, dtype=torch.float32)[:, None] + imgids[..., 2] = torch.arange(w_, device=device, dtype=torch.float32)[None, :] + imgpos = imgids.reshape(1, h_ * w_, 3).repeat(bs, 1, 1) + pos = torch.cat((txtpos, imgpos), dim=1) + + freqs = self.pe_embedder(pos) + + for block in self.blocks: + combined = block(combined, tvec, freqs, None, transformer_options=transformer_options) + + final = self.last(combined, t) + out = final[:, txtlen:txtlen + imglen, :] + out = rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", + h=h_, w=w_, ph=patch, pw=patch, c=self.channels) + out = out[:, :, :H_orig, :W_orig] # crop padding back off + if temporal: + out = out.reshape(b5, t5, self.channels, H_orig, W_orig).movedim(1, 2) + return out + + def _unpack_context(self, context): + # context: (B, seq, txtlayers*txtdim) -> (B, seq, txtlayers, txtdim). + b, seq, fused = context.shape + if fused != self.txtlayers * self.txtdim: + raise ValueError( + f"Krea2 expects conditioning with {self.txtlayers}x{self.txtdim}={self.txtlayers * self.txtdim} " + f"features (a {self.txtlayers}-layer Qwen3-VL stack) but got {fused}. " + f"Load the text encoder with CLIPLoader type 'krea2'." 
+ ) + return context.reshape(b, seq, self.txtlayers, self.txtdim) diff --git a/comfy/lora.py b/comfy/lora.py index 2c8d0f0bf..427cf98aa 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -326,6 +326,17 @@ def model_lora_keys_unet(model, key_map={}): key_map["transformer.{}".format(key_lora)] = k key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format + if isinstance(model, comfy.model_base.Krea2): + diffusers_keys = comfy.utils.krea2_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = k[:-len(".weight")] + key_map["diffusion_model.{}".format(key_lora)] = to + key_map["transformer.{}".format(key_lora)] = to + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + key_map[key_lora] = to + if isinstance(model, comfy.model_base.Lumina2): diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") for k in diffusers_keys: diff --git a/comfy/model_base.py b/comfy/model_base.py index 264dbb9b3..dcfa555dc 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -58,6 +58,7 @@ import comfy.ldm.omnigen.omnigen2 import comfy.ldm.boogu.model import comfy.ldm.qwen_image.model import comfy.ldm.ideogram4.model +import comfy.ldm.krea2.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 @@ -2278,6 +2279,17 @@ class Ideogram4(BaseModel): out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) return out +class Krea2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLUX, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.krea2.model.SingleStreamDiT) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) 
+ return out + class HunyuanImage21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index b773f0393..e53d848c9 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -834,6 +834,21 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') return dit_config + if '{}txtfusion.projector.weight'.format(key_prefix) in state_dict_keys: # Krea 2 (K2) + dit_config = {} + dit_config["image_model"] = "krea2" + head_dim = 128 + first_w = state_dict['{}first.weight'.format(key_prefix)] # (features, channels*patch^2) + dit_config["features"] = first_w.shape[0] + dit_config["channels"] = first_w.shape[1] // (2 * 2) # patch=2 + dit_config["patch"] = 2 + dit_config["layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') + dit_config["heads"] = state_dict['{}blocks.0.attn.wq.weight'.format(key_prefix)].shape[0] // head_dim + dit_config["kvheads"] = state_dict['{}blocks.0.attn.wk.weight'.format(key_prefix)].shape[0] // head_dim + dit_config["txtlayers"] = state_dict['{}txtfusion.projector.weight'.format(key_prefix)].shape[1] + dit_config["txtdim"] = state_dict['{}txtfusion.layerwise_blocks.0.prenorm.scale'.format(key_prefix)].shape[0] + return dit_config + if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 dit_config = {} model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0] diff --git a/comfy/sd.py b/comfy/sd.py index d9b1c0553..610c4e2b8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -58,6 +58,7 @@ import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import 
comfy.text_encoders.z_image +import comfy.text_encoders.krea2 import comfy.text_encoders.ideogram4 import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 @@ -1303,6 +1304,7 @@ class CLIPType(Enum): PIXELDIT = 29 IDEOGRAM4 = 30 BOOGU = 31 + KREA2 = 32 @@ -1628,6 +1630,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) clip_target.clip = comfy.text_encoders.boogu.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.boogu.BooguTokenizer + elif clip_type == CLIPType.KREA2 and te_model == TEModel.QWEN3VL_4B: # Krea2: full Qwen3-VL-4B (12-layer tap for conditioning + multimodal generate). + clip_data[0] = comfy.utils.state_dict_prefix_replace(clip_data[0], {"model.language_model.": "model.", "model.visual.": "visual.", "lm_head.": "model.lm_head."}) + clip_target.clip = comfy.text_encoders.krea2.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.krea2.Krea2Tokenizer elif clip_type in (CLIPType.FLUX, CLIPType.FLUX2): # Flux2 Klein reuses the Qwen3-VL LM (3-layer tap -> 12288); visual unused. 
klein_model_type = "qwen3_8b" if te_model == TEModel.QWEN3VL_8B else "qwen3_4b" clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type=klein_model_type) diff --git a/comfy/supported_models.py b/comfy/supported_models.py index cc05908ee..afb66e6f3 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -26,6 +26,7 @@ import comfy.text_encoders.kandinsky5 import comfy.text_encoders.z_image import comfy.text_encoders.ideogram4 import comfy.text_encoders.boogu +import comfy.text_encoders.krea2 import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image @@ -1818,6 +1819,35 @@ class Ideogram4(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.ideogram4.Ideogram4Tokenizer, comfy.text_encoders.ideogram4.te(**hunyuan_detect)) + +class Krea2(supported_models_base.BASE): + unet_config = { + "image_model": "krea2", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 1.15, + } + + memory_usage_factor = 2.2 + + latent_format = latent_formats.Wan21 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Krea2(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_4b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.krea2.Krea2Tokenizer, comfy.text_encoders.krea2.te(**hunyuan_detect)) + class QwenImage(supported_models_base.BASE): unet_config = { "image_model": "qwen_image", @@ -2325,6 +2355,7 @@ models = [ Boogu, QwenImage, Ideogram4, + Krea2, Flux2, 
Lens, Kandinsky5Image, diff --git a/comfy/text_encoders/krea2.py b/comfy/text_encoders/krea2.py new file mode 100644 index 000000000..408a03566 --- /dev/null +++ b/comfy/text_encoders/krea2.py @@ -0,0 +1,84 @@ +"""Krea 2 (K2) text encoder: Qwen3-VL-4B, 12-layer tap. + +K2 conditions on a stack of hidden states from 12 layers of Qwen3-VL-4B +(reference taps ``hidden_states[2,5,8,...,35]``), kept as a ``(B, 12, seq, 2560)`` tensor and +consumed by the DiT's internal ``txtfusion`` adapter. Comfy carries conditioning as a 3D tensor, +so the 12-layer stack is flattened to ``(B, seq, 12*2560)`` here and unpacked inside the model. +""" + +import numbers + +import torch + +import comfy.text_encoders.qwen3vl +from comfy import sd1_clip + +# tap k == hidden_states[k] (no offset). +KREA2_TAP_LAYERS = [2, 5, 8, 11, 14, 17, 20, 23, 26, 29, 32, 35] + +# Identical system template to Qwen-Image; Krea2 strips the system+user-opening prefix. +KREA2_TEMPLATE = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + +class Krea2Tokenizer(comfy.text_encoders.qwen3vl.Qwen3VLTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, model_type="qwen3vl_4b") + self.llama_template = KREA2_TEMPLATE # conditioning template; image text-gen uses qwen3vl's default image template. + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, thinking=True, **kwargs): + # Krea2 conditions on the no-think template; thinking=True drops the empty block qwen3vl adds. 
+ return super().tokenize_with_weights(text, return_word_ids=return_word_ids, llama_template=llama_template, images=images, prevent_empty_text=prevent_empty_text, thinking=thinking, **kwargs) + + +class Krea2Qwen3VLClipModel(comfy.text_encoders.qwen3vl.Qwen3VLClipModel): + def __init__(self, device="cpu", dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=KREA2_TAP_LAYERS, layer_idx=None, dtype=dtype, + attention_mask=attention_mask, model_options=model_options, model_type="qwen3vl_4b") + + +class Krea2TEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3vl_4b", clip_model=Krea2Qwen3VLClipModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) # out: (B, 12, seq, 2560) + tok_pairs = token_weight_pairs["qwen3vl_4b"][0] + + # Strip the system + user-opening prefix + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem) and isinstance(elem, numbers.Integral): + if elem == 151644 and count_im_start < 2: + template_end = i + count_im_start += 1 + if out.shape[2] > (template_end + 3): + if tok_pairs[template_end + 1][0] == 872: # "user" + if tok_pairs[template_end + 2][0] == 198: # "\n" + template_end += 3 + + out = out[:, :, template_end:] + + b, n, seq, h = out.shape + # Flatten the 12-layer axis into the feature dim: (B, seq, 12*2560). Unpacked in the model. 
+ out = out.permute(0, 2, 1, 3).reshape(b, seq, n * h) + + if "attention_mask" in extra: + extra["attention_mask"] = extra["attention_mask"][:, template_end:] + if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]): + extra.pop("attention_mask") + + return out, pooled, extra + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class Krea2TEModel_(Krea2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Krea2TEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index 09d783fff..61c2a22dd 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -818,6 +818,44 @@ def z_image_to_diffusers(mmdit_config, output_prefix=""): return key_map +def krea2_to_diffusers(mmdit_config, output_prefix=""): + n_layers = mmdit_config.get("layers", 0) + n_txt_layerwise = 2 # TextFusionTransformer hardcodes 2 layerwise + 2 refiner blocks + n_txt_refiner = 2 + key_map = {} + + def add_block(prefix_to, prefix_from): + block_map = { + "attn.to_q": "attn.wq", "attn.to_k": "attn.wk", "attn.to_v": "attn.wv", + "attn.to_gate": "attn.gate", "attn.to_out.0": "attn.wo", + "attn.to_out": "attn.wo", # some tools drop the ".0" on to_out + "ff.gate": "mlp.gate", "ff.up": "mlp.up", "ff.down": "mlp.down", + } + for d, c in block_map.items(): + key_map["{}.{}.weight".format(prefix_to, d)] = "{}{}.{}.weight".format(output_prefix, prefix_from, c) + + for i in range(n_layers): + add_block("transformer_blocks.{}".format(i), "blocks.{}".format(i)) + for i in range(n_txt_layerwise): + add_block("text_fusion.layerwise_blocks.{}".format(i), "txtfusion.layerwise_blocks.{}".format(i)) + for i in range(n_txt_refiner): + add_block("text_fusion.refiner_blocks.{}".format(i), 
"txtfusion.refiner_blocks.{}".format(i)) + + MAP_BASIC = [ + ("img_in", "first"), + ("time_embed.linear_1", "tmlp.0"), + ("time_embed.linear_2", "tmlp.2"), + ("time_mod_proj", "tproj.1"), + ("txt_in.linear_1", "txtmlp.1"), + ("txt_in.linear_2", "txtmlp.3"), + ("text_fusion.projector", "txtfusion.projector"), + ("final_layer.linear", "last.linear"), + ] + for d, c in MAP_BASIC: + key_map["{}.weight".format(d)] = "{}{}.weight".format(output_prefix, c) + + return key_map + def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: return tensor.narrow(dim, 0, batch_size) diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index b7b97d70f..1782739fd 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -48,10 +48,13 @@ from comfy_api_nodes.util import ( upload_image_to_comfyapi, upload_video_to_comfyapi, validate_audio_duration, + validate_image_aspect_ratio, + validate_image_dimensions, validate_string, validate_video_duration, ) + RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)") @@ -1657,6 +1660,44 @@ class HappyHorseTextToVideoApi(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ + IO.DynamicCombo.Option( + "happyhorse-1.1-t2v", + [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. 
" + "Supports English and Chinese.", + ), + IO.Combo.Input( + "resolution", + options=["720P", "1080P"], + ), + IO.Combo.Input( + "ratio", + options=[ + "16:9", + "9:16", + "1:1", + "4:3", + "3:4", + "21:9", + "9:21", + "5:4", + "4:5", + ], + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + step=1, + display_mode=IO.NumberDisplay.number, + ), + ], + ), IO.DynamicCombo.Option( "happyhorse-1.0-t2v", [ @@ -1719,7 +1760,9 @@ class HappyHorseTextToVideoApi(IO.ComfyNode): ( $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $ppsTable := { "720p": 0.14, "1080p": 0.24 }; + $ppsTable := $contains(widgets.model, "1.1") + ? { "720p": 0.2002, "1080p": 0.2574 } + : { "720p": 0.14, "1080p": 0.24 }; $pps := $lookup($ppsTable, $res); { "type": "usd", "usd": $pps * $dur } ) @@ -1781,6 +1824,30 @@ class HappyHorseImageToVideoApi(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ + IO.DynamicCombo.Option( + "happyhorse-1.1-i2v", + [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the elements and visual features. " + "Supports English and Chinese.", + ), + IO.Combo.Input( + "resolution", + options=["720P", "1080P"], + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + step=1, + display_mode=IO.NumberDisplay.number, + ), + ], + ), IO.DynamicCombo.Option( "happyhorse-1.0-i2v", [ @@ -1843,7 +1910,9 @@ class HappyHorseImageToVideoApi(IO.ComfyNode): ( $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $ppsTable := { "720p": 0.14, "1080p": 0.24 }; + $ppsTable := $contains(widgets.model, "1.1") + ? 
{ "720p": 0.2002, "1080p": 0.2574 } + : { "720p": 0.14, "1080p": 0.24 }; $pps := $lookup($ppsTable, $res); { "type": "usd", "usd": $pps * $dur } ) @@ -1859,6 +1928,8 @@ class HappyHorseImageToVideoApi(IO.ComfyNode): seed: int, watermark: bool, ): + validate_image_dimensions(first_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1), strict=False) media = [ Wan27MediaItem( type="first_frame", @@ -2053,6 +2124,62 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): IO.DynamicCombo.Input( "model", options=[ + IO.DynamicCombo.Option( + "happyhorse-1.1-r2v", + [ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Prompt describing the video. Use identifiers such as 'character1' and " + "'character2' to refer to the reference characters.", + ), + IO.Combo.Input( + "resolution", + options=["720P", "1080P"], + ), + IO.Combo.Input( + "ratio", + options=[ + "16:9", + "9:16", + "1:1", + "4:3", + "3:4", + "21:9", + "9:21", + "5:4", + "4:5", + ], + ), + IO.Int.Input( + "duration", + default=5, + min=3, + max=15, + step=1, + display_mode=IO.NumberDisplay.number, + ), + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("reference_image"), + names=[ + "image1", + "image2", + "image3", + "image4", + "image5", + "image6", + "image7", + "image8", + "image9", + ], + min=1, + ), + ), + ], + ), IO.DynamicCombo.Option( "happyhorse-1.0-r2v", [ @@ -2133,7 +2260,9 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): ( $res := $lookup(widgets, "model.resolution"); $dur := $lookup(widgets, "model.duration"); - $ppsTable := { "720p": 0.14, "1080p": 0.24 }; + $ppsTable := $contains(widgets.model, "1.1") + ? 
{ "720p": 0.2002, "1080p": 0.2574 } + : { "720p": 0.14, "1080p": 0.24 }; $pps := $lookup($ppsTable, $res); { "type": "usd", "usd": $pps * $dur } ) @@ -2149,8 +2278,11 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): watermark: bool, ): validate_string(model["prompt"], strip_whitespace=False, min_length=1) - media = [] reference_images = model.get("reference_images", {}) + for key in reference_images: + validate_image_dimensions(reference_images[key], min_width=400, min_height=400) + validate_image_aspect_ratio(reference_images[key], (1, 2.5), (2.5, 1), strict=False) + media = [] for key in reference_images: media.append( Wan27MediaItem( @@ -2159,7 +2291,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): ) ) if not media: - raise ValueError("At least one reference reference image must be provided.") + raise ValueError("At least one reference image must be provided.") initial_response = await sync_op( cls, diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 77f124e28..6adcc95fa 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -158,7 +158,7 @@ class SaveAudio(IO.ComfyNode): return IO.Schema( node_id="SaveAudio", search_aliases=["export flac"], - display_name="Save Audio (FLAC) (Deprecated)", + display_name="Save Audio (FLAC) (DEPRECATED)", category="audio", essentials_category="Audio", inputs=[ @@ -166,8 +166,9 @@ class SaveAudio(IO.ComfyNode): IO.String.Input("filename_prefix", default="audio/ComfyUI"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, is_deprecated=True, + is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod @@ -175,11 +176,10 @@ class SaveAudio(IO.ComfyNode): if audio is None: raise ValueError("SaveAudio: input audio is None (source video may have no audio track).") return IO.NodeOutput( + audio, ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format) ) - save_flac = execute # TODO: remove - 
class SaveAudioMP3(IO.ComfyNode): @classmethod @@ -187,7 +187,7 @@ class SaveAudioMP3(IO.ComfyNode): return IO.Schema( node_id="SaveAudioMP3", search_aliases=["export mp3"], - display_name="Save Audio (MP3) (Deprecated)", + display_name="Save Audio (MP3) (DEPRECATED)", category="audio", essentials_category="Audio", inputs=[ @@ -196,8 +196,9 @@ class SaveAudioMP3(IO.ComfyNode): IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, is_deprecated=True, + is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod @@ -205,13 +206,12 @@ class SaveAudioMP3(IO.ComfyNode): if audio is None: raise ValueError("SaveAudioMP3: input audio is None (source video may have no audio track).") return IO.NodeOutput( + audio, ui=UI.AudioSaveHelper.get_save_audio_ui( audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality ) ) - save_mp3 = execute # TODO: remove - class SaveAudioOpus(IO.ComfyNode): @classmethod @@ -219,7 +219,7 @@ class SaveAudioOpus(IO.ComfyNode): return IO.Schema( node_id="SaveAudioOpus", search_aliases=["export opus"], - display_name="Save Audio (Opus) (Deprecated)", + display_name="Save Audio (Opus) (DEPRECATED)", category="audio", inputs=[ IO.Audio.Input("audio"), @@ -227,8 +227,9 @@ class SaveAudioOpus(IO.ComfyNode): IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], - is_output_node=True, is_deprecated=True, + is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod @@ -236,13 +237,12 @@ class SaveAudioOpus(IO.ComfyNode): if audio is None: raise ValueError("SaveAudioOpus: input audio is None (source video may have no audio track).") return IO.NodeOutput( + audio, ui=UI.AudioSaveHelper.get_save_audio_ui( audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality ) ) - save_opus = execute # TODO: remove - 
class SaveAudioAdvanced(IO.ComfyNode): @classmethod @@ -258,10 +258,7 @@ class SaveAudioAdvanced(IO.ComfyNode): IO.String.Input( "filename_prefix", default="audio/ComfyUI", - tooltip=( - "The prefix for the file to save. May include formatting tokens " - "such as %date:yyyy-MM-dd%." - ), + tooltip=("The prefix for the file to save. May include formatting tokens such as %date:yyyy-MM-dd%."), ), IO.DynamicCombo.Input( "format", @@ -279,6 +276,7 @@ class SaveAudioAdvanced(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Audio.Output("audio")], ) @classmethod @@ -289,7 +287,7 @@ class SaveAudioAdvanced(IO.ComfyNode): ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format, quality=quality) else: ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format) - return IO.NodeOutput(ui=ui) + return IO.NodeOutput(audio, ui=ui) class PreviewAudio(IO.ComfyNode): @@ -305,13 +303,14 @@ class PreviewAudio(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Audio.Output("audio")] ) @classmethod def execute(cls, audio) -> IO.NodeOutput: if audio is None: raise ValueError("PreviewAudio: input audio is None (source video may have no audio track).") - return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls)) + return IO.NodeOutput(audio, ui=UI.PreviewAudio(audio, cls=cls)) save_flac = execute # TODO: remove diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 469a7be55..fe1937ba5 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -214,11 +214,13 @@ class SaveAnimatedWEBP(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Image.Output(display_name="images")] ) @classmethod def execute(cls, images, fps, filename_prefix, lossless, quality, method, num_frames=0) -> 
IO.NodeOutput: return IO.NodeOutput( + images, ui=UI.ImageSaveHelper.get_save_animated_webp_ui( images=images, filename_prefix=filename_prefix, @@ -230,8 +232,6 @@ class SaveAnimatedWEBP(IO.ComfyNode): ) ) - save_images = execute # TODO: remove - class SaveAnimatedPNG(IO.ComfyNode): @@ -249,11 +249,13 @@ class SaveAnimatedPNG(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Image.Output(display_name="images")] ) @classmethod def execute(cls, images, fps, compress_level, filename_prefix="ComfyUI") -> IO.NodeOutput: return IO.NodeOutput( + images, ui=UI.ImageSaveHelper.get_save_animated_png_ui( images=images, filename_prefix=filename_prefix, @@ -263,8 +265,6 @@ class SaveAnimatedPNG(IO.ComfyNode): ) ) - save_images = execute # TODO: remove - class ImageStitch(IO.ComfyNode): """Upstreamed from https://github.com/kijai/ComfyUI-KJNodes""" @@ -513,6 +513,7 @@ class SaveSVGNode(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.SVG.Output("svg")], ) @classmethod @@ -562,9 +563,7 @@ class SaveSVGNode(IO.ComfyNode): results.append(UI.SavedResult(filename=file, subfolder=subfolder, type=IO.FolderType.output)) counter += 1 - return IO.NodeOutput(ui={"images": results}) - - save_svg = execute # TODO: remove + return IO.NodeOutput(svg, ui={"images": results}) class GetImageSize(IO.ComfyNode): @@ -1157,40 +1156,27 @@ class SaveImageAdvanced(IO.ComfyNode): IO.String.Input( "filename_prefix", default="ComfyUI", - tooltip=( - "The prefix for the file to save. May include formatting tokens " - "such as %date:yyyy-MM-dd% or %Empty Latent Image.width%." - ), + tooltip=("The prefix for the file to save. 
May include formatting tokens such as %date:yyyy-MM-dd% or %Empty Latent Image.width%."), ), IO.DynamicCombo.Input( "format", options=[ IO.DynamicCombo.Option("png", [ - IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"], - default="8-bit", advanced=True), - IO.Combo.Input("input_color_space", options=["sRGB"], - default="sRGB", advanced=True), + IO.Combo.Input("bit_depth", options=["8-bit", "16-bit"], default="8-bit", advanced=True), + IO.Combo.Input("input_color_space", options=["sRGB"], default="sRGB", advanced=True), ]), IO.DynamicCombo.Option("exr", [ - IO.Combo.Input("bit_depth", options=["32-bit float"], - default="32-bit float", advanced=True), + IO.Combo.Input("bit_depth", options=["32-bit float"], default="32-bit float", advanced=True), IO.Combo.Input( "input_color_space", options=["sRGB", "HDR", "linear"], default="sRGB", advanced=True, tooltip=( - "Colorspace of the input tensor. The EXR is " - "always written as scene-linear in the matching " - "gamut.\n" - " 'sRGB' — input is sRGB-encoded Rec.709; " - "the inverse sRGB EOTF is applied.\n" - " 'HDR' — input is HLG-encoded Rec.2020 " - "(BT.2100); the inverse HLG OETF is applied " - "to get scene-linear light.\n" - " 'linear' — input is already scene-linear " - "(Rec.709 primaries); written through unchanged. " - "Use this for renderer/compositor output." + "Colorspace of the input tensor. The EXR is always written as scene-linear in the matching gamut.\n" + "sRGB — input is sRGB-encoded Rec.709; the inverse sRGB EOTF is applied.\n" + "HDR — input is HLG-encoded Rec.2020 (BT.2100); the inverse HLG OETF is applied to get scene-linear light.\n" + "linear — input is already scene-linear (Rec.709 primaries); written through unchanged. Use this for renderer/compositor output." 
), ), ]), @@ -1200,6 +1186,7 @@ class SaveImageAdvanced(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + outputs=[IO.Image.Output(display_name="images")] ) @classmethod @@ -1237,7 +1224,7 @@ class SaveImageAdvanced(IO.ComfyNode): results.append({"filename": file, "subfolder": subfolder, "type": "output"}) counter += 1 - return IO.NodeOutput(ui={"images": results}) + return IO.NodeOutput(images, ui={"images": results}) class ImagesExtension(ComfyExtension): diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index 8d76af1c1..d3acc9ad0 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -27,6 +27,7 @@ class SaveWEBM(io.ComfyNode): ], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, + outputs=[io.Image.Output(display_name="images")] ) @classmethod @@ -69,7 +70,7 @@ class SaveWEBM(io.ComfyNode): container.mux(stream.encode()) container.close() - return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) + return io.NodeOutput(images, ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) class SaveVideo(io.ComfyNode): @classmethod @@ -89,6 +90,7 @@ class SaveVideo(io.ComfyNode): ], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, + outputs=[io.Video.Output("video")], ) @classmethod @@ -117,7 +119,7 @@ class SaveVideo(io.ComfyNode): metadata=saved_metadata ) - return io.NodeOutput(ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) + return io.NodeOutput(video, ui=ui.PreviewVideo([ui.SavedResult(file, subfolder, io.FolderType.output)])) class CreateVideo(io.ComfyNode): diff --git a/extra_model_paths.yaml.example b/extra_model_paths.yaml.example index 9c395c0b2..6a31d8a63 100644 --- a/extra_model_paths.yaml.example +++ b/extra_model_paths.yaml.example @@ -8,21 +8,37 @@ # # You can use is_default to mark that these folders should be listed first, and 
used as the default dirs for eg downloads # #is_default: true # checkpoints: models/checkpoints/ +# configs: models/configs/ +# loras: models/loras/ +# vae: models/vae/ # text_encoders: | # models/text_encoders/ -# models/clip/ # legacy location still supported -# clip_vision: models/clip_vision/ -# configs: models/configs/ -# controlnet: models/controlnet/ +# models/clip/ # diffusion_models: | -# models/diffusion_models -# models/unet +# models/unet/ +# models/diffusion_models/ +# clip_vision: models/clip_vision/ +# style_models: models/style_models/ # embeddings: models/embeddings/ -# loras: models/loras/ +# diffusers: models/diffusers/ +# vae_approx: models/vae_approx/ +# controlnet: | +# models/controlnet/ +# models/t2i_adapter/ +# gligen: models/gligen/ # upscale_models: models/upscale_models/ -# vae: models/vae/ -# audio_encoders: models/audio_encoders/ +# latent_upscale_models: models/latent_upscale_models/ +# custom_nodes: custom_nodes/ +# hypernetworks: models/hypernetworks/ +# photomaker: models/photomaker/ +# classifiers: models/classifiers/ # model_patches: models/model_patches/ +# audio_encoders: models/audio_encoders/ +# background_removal: models/background_removal/ +# frame_interpolation: models/frame_interpolation/ +# geometry_estimation: models/geometry_estimation/ +# optical_flow: models/optical_flow/ +# detection: models/detection/ #config for a1111 ui @@ -45,8 +61,7 @@ # controlnet: models/ControlNet -# For a full list of supported keys (style_models, vae_approx, hypernetworks, photomaker, -# model_patches, audio_encoders, classifiers, etc.) see folder_paths.py. +# For the canonical list of supported keys and extensions, see folder_paths.py. #other_ui: # base_path: path/to/ui diff --git a/main.py b/main.py index ad5c11e16..aa4ee2adb 100644 --- a/main.py +++ b/main.py @@ -557,8 +557,13 @@ if __name__ == "__main__": logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 
3.12 and above is recommended.") if args.disable_dynamic_vram: - logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.") - + logging.warning( + "Dynamic vram disabled with argument. If you have any issues with " + "dynamic vram enabled please give us a detailed reports as this " + "argument will be removed soon. If you use gguf we recommend keeping " + "dynamic vram enabled and using native ComfyUI model formats instead. " + "ComfyUI native formats like fp8 will be faster even if they are larger than your memory." + ) event_loop, _, start_all_func = start_comfyui() try: x = start_all_func() diff --git a/nodes.py b/nodes.py index c7fbd3475..51be2eef2 100644 --- a/nodes.py +++ b/nodes.py @@ -480,11 +480,13 @@ class SaveLatent: @classmethod def INPUT_TYPES(s): - return {"required": { "samples": ("LATENT", ), - "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - RETURN_TYPES = () + return { "required": { + "samples": ("LATENT",), + "filename_prefix": ("STRING", {"default": "latents/ComfyUI"})}, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + RETURN_TYPES = ("LATENT",) + RETURN_NAMES = ("samples",) FUNCTION = "save" OUTPUT_NODE = True @@ -522,7 +524,7 @@ class SaveLatent: output["latent_format_version_0"] = torch.tensor([]) comfy.utils.save_torch_file(output, file, metadata=metadata) - return { "ui": { "latents": results } } + return { "ui": { "latents": results }, "result": (samples,) } class LoadLatent: @@ -967,7 +969,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", 
"hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit", "ideogram4", "boogu", "krea2"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -1627,14 +1629,18 @@ class SaveImage: return { "required": { "images": ("IMAGE", {"tooltip": "The images to save."}), - "filename_prefix": ("STRING", {"default": "ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."}) + "filename_prefix": ("STRING", { + "default": "ComfyUI", + "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes." + }) }, "hidden": { "prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO" }, } - RETURN_TYPES = () + RETURN_TYPES = ("IMAGE",) + RETURN_NAMES = ("images",) FUNCTION = "save_images" OUTPUT_NODE = True @@ -1670,7 +1676,7 @@ class SaveImage: }) counter += 1 - return { "ui": { "images": results } } + return { "ui": { "images": results }, "result" : (images,) } class PreviewImage(SaveImage): def __init__(self): diff --git a/requirements.txt b/requirements.txt index ad8b1c2ee..0c8b1888e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ comfyui-frontend-package==1.45.19 comfyui-workflow-templates==0.10.0 -comfyui-embedded-docs==0.5.4 +comfyui-embedded-docs==0.5.5 torch torchsde torchvision