From 62315fbb15861e64b917d0a072dad5dc9a15173c Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Mon, 9 Feb 2026 13:16:08 -0800
Subject: [PATCH 01/85] Dynamic VRAM fixes - Ace 1.5 performance + a VRAM leak
(#12368)
* revert threaded model loader change
This change was only needed to get around the pytorch 2.7 mempool bugs,
and should have been reverted along with #12260. This fixes a different
memory leak where pytorch gets confused about cache emptying.
* load non comfy weights
* MPDynamic: Pre-generate the tensors for vbars
Apparently this is an expensive operation that slows down things.
* bump to aimdo 1.8
New features:
watermark limit feature
logging enhancements
-O2 build on linux
---
comfy/model_management.py | 37 ++++++-------------------------------
comfy/model_patcher.py | 7 ++++++-
comfy/ops.py | 2 +-
execution.py | 7 ++++++-
requirements.txt | 2 +-
5 files changed, 20 insertions(+), 35 deletions(-)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index b6291f340..6018c1ab6 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
-from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
+from comfy.cli_args import args, PerformanceFeature
import threading
import torch
import sys
@@ -651,7 +651,7 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
soft_empty_cache()
return unloaded_models
-def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
+def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
cleanup_models_gc()
global vram_state
@@ -747,26 +747,6 @@ def load_models_gpu_orig(models, memory_required=0, force_patch_weights=False, m
current_loaded_models.insert(0, loaded_model)
return
-def load_models_gpu_thread(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load):
- with torch.inference_mode():
- load_models_gpu_orig(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
- soft_empty_cache()
-
-def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False):
- #Deliberately load models outside of the Aimdo mempool so they can be retained accross
- #nodes. Use a dummy thread to do it as pytorch documents that mempool contexts are
- #thread local. So exploit that to escape context
- if enables_dynamic_vram():
- t = threading.Thread(
- target=load_models_gpu_thread,
- args=(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
- )
- t.start()
- t.join()
- else:
- load_models_gpu_orig(models, memory_required=memory_required, force_patch_weights=force_patch_weights,
- minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
-
def load_model_gpu(model):
return load_models_gpu([model])
@@ -1226,21 +1206,16 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
if dtype is None:
dtype = weight._model_dtype
- r = torch.empty_like(weight, dtype=dtype, device=device)
-
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
- raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
- v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
+ v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0]
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
- #always take a deep copy even if _v is good, as we have no reasonable point to unpin
- #a non comfy weight
- r.copy_(v_tensor)
- comfy_aimdo.model_vbar.vbar_unpin(weight._v)
- return r
+ return v_tensor.to(dtype=dtype)
+
+ r = torch.empty_like(weight, dtype=dtype, device=device)
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index d888dbcfb..b9a117a7c 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1492,7 +1492,9 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
- #We have way more tools for acceleration on comfy weight offloading, so always
+ #We force reserve VRAM for the non comfy-weight so we dont have to deal
+ #with pin and unpin syncrhonization which can be expensive for small weights
+ #with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True)
loading.sort(reverse=True)
@@ -1541,6 +1543,7 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
+ m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
allocated_size += v_weight_size
else:
@@ -1555,8 +1558,10 @@ class ModelPatcherDynamic(ModelPatcher):
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
+ weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
weight._model_dtype = model_dtype
allocated_size += weight_size
+ vbar.set_watermark_limit(allocated_size)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
diff --git a/comfy/ops.py b/comfy/ops.py
index 0f4eca7c7..ea0d70702 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -87,7 +87,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
if signature is not None:
- xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
+ xfer_dest = s._v_tensor
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if not resident:
diff --git a/execution.py b/execution.py
index 3dbab82e6..896862c6b 100644
--- a/execution.py
+++ b/execution.py
@@ -13,8 +13,11 @@ from contextlib import nullcontext
import torch
+from comfy.cli_args import args
import comfy.memory_management
import comfy.model_management
+import comfy_aimdo.model_vbar
+
from latent_preview import set_preview_method
import nodes
from comfy_execution.caching import (
@@ -527,8 +530,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
finally:
if allocator is not None:
+ if args.verbose == "DEBUG":
+ comfy_aimdo.model_vbar.vbars_analyze()
comfy.model_management.reset_cast_buffers()
- torch.cuda.synchronize()
+ comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
if has_pending_tasks:
pending_async_nodes[unique_id] = output_data
diff --git a/requirements.txt b/requirements.txt
index 5e34a2a49..4fda07fde 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -22,7 +22,7 @@ alembic
SQLAlchemy
av>=14.2.0
comfy-kitchen>=0.2.7
-comfy-aimdo>=0.1.7
+comfy-aimdo>=0.1.8
requests
#non essential dependencies:
From baf8c874557f1522a99d47d94faad12b0257c8f1 Mon Sep 17 00:00:00 2001
From: blepping <157360029+blepping@users.noreply.github.com>
Date: Mon, 9 Feb 2026 17:41:49 -0700
Subject: [PATCH 02/85] Iimprovements to ACE-Steps 1.5 text encoding (part 2)
(#12350)
---
comfy/text_encoders/ace15.py | 114 +++++++++++++++++++++++++----------
1 file changed, 81 insertions(+), 33 deletions(-)
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
index 00dd5ba90..5dac644c2 100644
--- a/comfy/text_encoders/ace15.py
+++ b/comfy/text_encoders/ace15.py
@@ -3,6 +3,7 @@ import comfy.text_encoders.llama
from comfy import sd1_clip
import torch
import math
+from tqdm.auto import trange
import yaml
import comfy.utils
@@ -23,6 +24,8 @@ def sample_manual_loop_no_classes(
audio_end_id: int = 215669,
eos_token_id: int = 151645,
):
+ if ids is None:
+ return []
device = model.execution_device
if execution_dtype is None:
@@ -32,6 +35,7 @@ def sample_manual_loop_no_classes(
execution_dtype = torch.float32
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
+ embeds_batch = embeds.shape[0]
for i, t in enumerate(paddings):
attention_mask[i, :t] = 0
attention_mask[i, t:] = 1
@@ -41,22 +45,27 @@ def sample_manual_loop_no_classes(
generator = torch.Generator(device=device)
generator.manual_seed(seed)
model_config = model.transformer.model.config
+ past_kv_shape = [embeds_batch, model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim]
for x in range(model_config.num_hidden_layers):
- past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), torch.empty([embeds.shape[0], model_config.num_key_value_heads, embeds.shape[1] + min_tokens, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+ past_key_values.append((torch.empty(past_kv_shape, device=device, dtype=execution_dtype), torch.empty(past_kv_shape, device=device, dtype=execution_dtype), 0))
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
- for step in range(max_new_tokens):
+ for step in trange(max_new_tokens, desc="LM sampling"):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]
- cond_logits = next_token_logits[0:1]
- uncond_logits = next_token_logits[1:2]
- cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+ if cfg_scale != 1.0:
+ cond_logits = next_token_logits[0:1]
+ uncond_logits = next_token_logits[1:2]
+ cfg_logits = uncond_logits + cfg_scale * (cond_logits - uncond_logits)
+ else:
+ cfg_logits = next_token_logits[0:1]
- if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ use_eos_score = eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step
+ if use_eos_score:
eos_score = cfg_logits[:, eos_token_id].clone()
remove_logit_value = torch.finfo(cfg_logits.dtype).min
@@ -64,7 +73,7 @@ def sample_manual_loop_no_classes(
cfg_logits[:, :audio_start_id] = remove_logit_value
cfg_logits[:, audio_end_id:] = remove_logit_value
- if eos_token_id is not None and eos_token_id < audio_start_id and min_tokens < step:
+ if use_eos_score:
cfg_logits[:, eos_token_id] = eos_score
if top_k is not None and top_k > 0:
@@ -93,8 +102,8 @@ def sample_manual_loop_no_classes(
break
embed, _, _, _ = model.process_tokens([[token]], device)
- embeds = embed.repeat(2, 1, 1)
- attention_mask = torch.cat([attention_mask, torch.ones((2, 1), device=device, dtype=attention_mask.dtype)], dim=1)
+ embeds = embed.repeat(embeds_batch, 1, 1)
+ attention_mask = torch.cat([attention_mask, torch.ones((embeds_batch, 1), device=device, dtype=attention_mask.dtype)], dim=1)
output_audio_codes.append(token - audio_start_id)
progress_bar.update_absolute(step)
@@ -104,22 +113,29 @@ def sample_manual_loop_no_classes(
def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
positive = [[token for token, _ in inner_list] for inner_list in positive]
- negative = [[token for token, _ in inner_list] for inner_list in negative]
positive = positive[0]
- negative = negative[0]
- neg_pad = 0
- if len(negative) < len(positive):
- neg_pad = (len(positive) - len(negative))
- negative = [model.special_tokens["pad"]] * neg_pad + negative
+ if cfg_scale != 1.0:
+ negative = [[token for token, _ in inner_list] for inner_list in negative]
+ negative = negative[0]
- pos_pad = 0
- if len(negative) > len(positive):
- pos_pad = (len(negative) - len(positive))
- positive = [model.special_tokens["pad"]] * pos_pad + positive
+ neg_pad = 0
+ if len(negative) < len(positive):
+ neg_pad = (len(positive) - len(negative))
+ negative = [model.special_tokens["pad"]] * neg_pad + negative
- paddings = [pos_pad, neg_pad]
- return sample_manual_loop_no_classes(model, [positive, negative], paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+ pos_pad = 0
+ if len(negative) > len(positive):
+ pos_pad = (len(negative) - len(positive))
+ positive = [model.special_tokens["pad"]] * pos_pad + positive
+
+ paddings = [pos_pad, neg_pad]
+ ids = [positive, negative]
+ else:
+ paddings = []
+ ids = [positive]
+
+ return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -129,12 +145,12 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
user_metas = {
k: kwargs.pop(k)
- for k in ("bpm", "duration", "keyscale", "timesignature", "language", "caption")
+ for k in ("bpm", "duration", "keyscale", "timesignature", "language")
if k in kwargs
}
timesignature = user_metas.get("timesignature")
if isinstance(timesignature, str) and timesignature.endswith("/4"):
- user_metas["timesignature"] = timesignature.rsplit("/", 1)[0]
+ user_metas["timesignature"] = timesignature[:-2]
user_metas = {
k: v if not isinstance(v, str) or not v.isdigit() else int(v)
for k, v in user_metas.items()
@@ -147,8 +163,11 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return f"\n{meta_yaml}\n" if not return_yaml else meta_yaml
def _metas_to_cap(self, **kwargs) -> str:
- use_keys = ("bpm", "duration", "keyscale", "timesignature")
+ use_keys = ("bpm", "timesignature", "keyscale", "duration")
user_metas = { k: kwargs.pop(k, "N/A") for k in use_keys }
+ timesignature = user_metas.get("timesignature")
+ if isinstance(timesignature, str) and timesignature.endswith("/4"):
+ user_metas["timesignature"] = timesignature[:-2]
duration = user_metas["duration"]
if duration == "N/A":
user_metas["duration"] = "30 seconds"
@@ -159,9 +178,13 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
return "\n".join(f"- {k}: {user_metas[k]}" for k in use_keys)
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs):
- out = {}
+ text = text.strip()
+ text_negative = kwargs.get("caption_negative", text).strip()
lyrics = kwargs.get("lyrics", "")
+ lyrics_negative = kwargs.get("lyrics_negative", lyrics)
duration = kwargs.get("duration", 120)
+ if isinstance(duration, str):
+ duration = float(duration.split(None, 1)[0])
language = kwargs.get("language")
seed = kwargs.get("seed", 0)
@@ -171,21 +194,46 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
-
duration = math.ceil(duration)
kwargs["duration"] = duration
+ tokens_duration = duration * 5
+ min_tokens = int(kwargs.get("min_tokens", tokens_duration))
+ max_tokens = int(kwargs.get("max_tokens", tokens_duration))
+
+ metas_negative = {
+ k.rsplit("_", 1)[0]: kwargs.pop(k)
+ for k in ("bpm_negative", "duration_negative", "keyscale_negative", "timesignature_negative", "language_negative", "caption_negative")
+ if k in kwargs
+ }
+ if not kwargs.get("use_negative_caption"):
+ _ = metas_negative.pop("caption", None)
cot_text = self._metas_to_cot(caption = text, **kwargs)
+ cot_text_negative = "\n" if not metas_negative else self._metas_to_cot(**metas_negative)
meta_cap = self._metas_to_cap(**kwargs)
- lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n<|im_end|>\n"
+ lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
+ lyrics_template = "# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>"
+ qwen3_06b_template = "# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>"
- out["lm_prompt"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, cot_text), disable_weights=True)
- out["lm_prompt_negative"] = self.qwen3_06b.tokenize_with_weights(lm_template.format(text, lyrics, "\n"), disable_weights=True)
+ llm_prompts = {
+ "lm_prompt": lm_template.format(text, lyrics.strip(), cot_text),
+ "lm_prompt_negative": lm_template.format(text_negative, lyrics_negative.strip(), cot_text_negative),
+ "lyrics": lyrics_template.format(language if language is not None else "", lyrics),
+ "qwen3_06b": qwen3_06b_template.format(text, meta_cap),
+ }
- out["lyrics"] = self.qwen3_06b.tokenize_with_weights("# Languages\n{}\n\n# Lyric\n{}<|endoftext|><|endoftext|>".format(language if language is not None else "", lyrics), return_word_ids, disable_weights=True, **kwargs)
- out["qwen3_06b"] = self.qwen3_06b.tokenize_with_weights("# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n# Caption\n{}\n# Metas\n{}\n<|endoftext|>\n<|endoftext|>".format(text, meta_cap), return_word_ids, **kwargs)
- out["lm_metadata"] = {"min_tokens": duration * 5,
+ out = {
+ prompt_key: self.qwen3_06b.tokenize_with_weights(
+ prompt,
+ prompt_key == "qwen3_06b" and return_word_ids,
+ disable_weights = True,
+ **kwargs,
+ )
+ for prompt_key, prompt in llm_prompts.items()
+ }
+ out["lm_metadata"] = {"min_tokens": min_tokens,
+ "max_tokens": max_tokens,
"seed": seed,
"generate_audio_codes": generate_audio_codes,
"cfg_scale": cfg_scale,
@@ -252,7 +300,7 @@ class ACE15TEModel(torch.nn.Module):
lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
- audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
+ audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
out["audio_codes"] = [audio_codes]
return base_out, None, out
From a4be04c5d750cc5d62256f7f86bb5a7c0a78e28d Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Mon, 9 Feb 2026 16:45:56 -0800
Subject: [PATCH 03/85] Ace step prompts match now. (#12376)
---
comfy/text_encoders/ace15.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
index 5dac644c2..73697b3c1 100644
--- a/comfy/text_encoders/ace15.py
+++ b/comfy/text_encoders/ace15.py
@@ -145,7 +145,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
def _metas_to_cot(self, *, return_yaml: bool = False, **kwargs) -> str:
user_metas = {
k: kwargs.pop(k)
- for k in ("bpm", "duration", "keyscale", "timesignature", "language")
+ for k in ("bpm", "duration", "keyscale", "timesignature")
if k in kwargs
}
timesignature = user_metas.get("timesignature")
@@ -208,8 +208,8 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
if not kwargs.get("use_negative_caption"):
_ = metas_negative.pop("caption", None)
- cot_text = self._metas_to_cot(caption = text, **kwargs)
- cot_text_negative = "\n" if not metas_negative else self._metas_to_cot(**metas_negative)
+ cot_text = self._metas_to_cot(caption=text, **kwargs)
+ cot_text_negative = "\n\n" if not metas_negative else self._metas_to_cot(**metas_negative)
meta_cap = self._metas_to_cap(**kwargs)
lm_template = "<|im_start|>system\n# Instruction\nGenerate audio semantic tokens based on the given conditions:\n\n<|im_end|>\n<|im_start|>user\n# Caption\n{}\n\n# Lyric\n{}\n<|im_end|>\n<|im_start|>assistant\n{}\n\n<|im_end|>\n"
From 349a636a2b0f15aba2930b9af905bb805d2fe30b Mon Sep 17 00:00:00 2001
From: ComfyUI Wiki
Date: Tue, 10 Feb 2026 10:25:34 +0800
Subject: [PATCH 04/85] chore: update workflow templates to v0.8.37 (#12377)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 4fda07fde..4e2773f5d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.38.13
-comfyui-workflow-templates==0.8.31
+comfyui-workflow-templates==0.8.37
comfyui-embedded-docs==0.4.1
torch
torchsde
From c1b63a7e78b606bc14cd49a02e9338274db28a60 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Tue, 10 Feb 2026 04:58:27 +0200
Subject: [PATCH 05/85] fix(Moonvalley-API-Nodes): adjust "steps" parameter to
not raise exception (#12370)
---
comfy_api_nodes/nodes_moonvalley.py | 14 +++++++-------
1 file changed, 7 insertions(+), 7 deletions(-)
diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py
index 08315fa2b..78a230529 100644
--- a/comfy_api_nodes/nodes_moonvalley.py
+++ b/comfy_api_nodes/nodes_moonvalley.py
@@ -219,8 +219,8 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
- default=33,
- min=1,
+ default=80,
+ min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Number of denoising steps",
@@ -340,8 +340,8 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
- default=33,
- min=1,
+ default=60,
+ min=60, # steps should be greater or equal to cooldown_steps(36) + warmup_steps(24)
max=100,
step=1,
display_mode=IO.NumberDisplay.number,
@@ -370,7 +370,7 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: int | None = 100,
- steps=33,
+ steps=60,
prompt_adherence=4.5,
) -> IO.NodeOutput:
validated_video = validate_video_to_video_input(video)
@@ -465,8 +465,8 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode):
),
IO.Int.Input(
"steps",
- default=33,
- min=1,
+ default=80,
+ min=75, # steps should be greater or equal to cooldown_steps(75) + warmup_steps(0)
max=100,
step=1,
tooltip="Inference steps",
From 8ca842a8edb26006e730e631ec1153cd42f46d3b Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Tue, 10 Feb 2026 19:34:54 +0200
Subject: [PATCH 06/85] feat(api-nodes-Kling): add new models (V3, O3) (#12389)
* feat(api-nodes-Kling): add new models (V3, O3)
* remove storyboard from VideoToVideo node
* added check for total duration of storyboards
* fixed other small things
* updated display name for nodes
* added "fake" seed
---
comfy_api_nodes/apis/__init__.py | 8 +-
comfy_api_nodes/apis/kling.py | 46 +-
comfy_api_nodes/nodes_kling.py | 764 ++++++++++++++++++++++++++++---
3 files changed, 750 insertions(+), 68 deletions(-)
diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py
index ee2aa1ce6..46a583b5e 100644
--- a/comfy_api_nodes/apis/__init__.py
+++ b/comfy_api_nodes/apis/__init__.py
@@ -1197,12 +1197,6 @@ class KlingImageGenImageReferenceType(str, Enum):
face = 'face'
-class KlingImageGenModelName(str, Enum):
- kling_v1 = 'kling-v1'
- kling_v1_5 = 'kling-v1-5'
- kling_v2 = 'kling-v2'
-
-
class KlingImageGenerationsRequest(BaseModel):
aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9'
callback_url: Optional[AnyUrl] = Field(
@@ -1218,7 +1212,7 @@ class KlingImageGenerationsRequest(BaseModel):
0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0
)
image_reference: Optional[KlingImageGenImageReferenceType] = None
- model_name: Optional[KlingImageGenModelName] = 'kling-v1'
+ model_name: str = Field(...)
n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
negative_prompt: Optional[str] = Field(
None, description='Negative text prompt', max_length=200
diff --git a/comfy_api_nodes/apis/kling.py b/comfy_api_nodes/apis/kling.py
index bf54ede3e..9c0446075 100644
--- a/comfy_api_nodes/apis/kling.py
+++ b/comfy_api_nodes/apis/kling.py
@@ -1,12 +1,22 @@
from pydantic import BaseModel, Field
+class MultiPromptEntry(BaseModel):
+ index: int = Field(...)
+ prompt: str = Field(...)
+ duration: str = Field(...)
+
+
class OmniProText2VideoRequest(BaseModel):
model_name: str = Field(..., description="kling-video-o1")
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
+ sound: str = Field(..., description="'on' or 'off'")
class OmniParamImage(BaseModel):
@@ -26,6 +36,10 @@ class OmniProFirstLastFrameRequest(BaseModel):
duration: str = Field(..., description="'5' or '10'")
prompt: str = Field(...)
mode: str = Field("pro")
+ sound: str | None = Field(None, description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class OmniProReferences2VideoRequest(BaseModel):
@@ -38,6 +52,10 @@ class OmniProReferences2VideoRequest(BaseModel):
duration: str | None = Field(..., description="From 3 to 10.")
prompt: str = Field(...)
mode: str = Field("pro")
+ sound: str | None = Field(None, description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class TaskStatusVideoResult(BaseModel):
@@ -54,6 +72,7 @@ class TaskStatusImageResult(BaseModel):
class TaskStatusResults(BaseModel):
videos: list[TaskStatusVideoResult] | None = Field(None)
images: list[TaskStatusImageResult] | None = Field(None)
+ series_images: list[TaskStatusImageResult] | None = Field(None)
class TaskStatusResponseData(BaseModel):
@@ -77,31 +96,42 @@ class OmniImageParamImage(BaseModel):
class OmniProImageRequest(BaseModel):
- model_name: str = Field(..., description="kling-image-o1")
- resolution: str = Field(..., description="'1k' or '2k'")
+ model_name: str = Field(...)
+ resolution: str = Field(...)
aspect_ratio: str | None = Field(...)
prompt: str = Field(...)
mode: str = Field("pro")
n: int | None = Field(1, le=9)
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
+ result_type: str | None = Field(None, description="Set to 'series' for series generation")
+ series_amount: int | None = Field(None, ge=2, le=9, description="Number of images in a series")
class TextToVideoWithAudioRequest(BaseModel):
- model_name: str = Field(..., description="kling-v2-6")
+ model_name: str = Field(...)
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
- duration: str = Field(..., description="'5' or '10'")
- prompt: str = Field(...)
+ duration: str = Field(...)
+ prompt: str | None = Field(...)
+ negative_prompt: str | None = Field(None)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class ImageToVideoWithAudioRequest(BaseModel):
- model_name: str = Field(..., description="kling-v2-6")
+ model_name: str = Field(...)
image: str = Field(...)
- duration: str = Field(..., description="'5' or '10'")
- prompt: str = Field(...)
+ image_tail: str | None = Field(None)
+ duration: str = Field(...)
+ prompt: str | None = Field(...)
+ negative_prompt: str | None = Field(None)
mode: str = Field("pro")
sound: str = Field(..., description="'on' or 'off'")
+ multi_shot: bool | None = Field(None)
+ multi_prompt: list[MultiPromptEntry] | None = Field(None)
+ shot_type: str | None = Field(None)
class MotionControlRequest(BaseModel):
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
index 739fe1855..b89c85561 100644
--- a/comfy_api_nodes/nodes_kling.py
+++ b/comfy_api_nodes/nodes_kling.py
@@ -38,7 +38,6 @@ from comfy_api_nodes.apis import (
KlingImageGenerationsRequest,
KlingImageGenerationsResponse,
KlingImageGenImageReferenceType,
- KlingImageGenModelName,
KlingImageGenAspectRatio,
KlingVideoEffectsRequest,
KlingVideoEffectsResponse,
@@ -52,6 +51,7 @@ from comfy_api_nodes.apis import (
from comfy_api_nodes.apis.kling import (
ImageToVideoWithAudioRequest,
MotionControlRequest,
+ MultiPromptEntry,
OmniImageParamImage,
OmniParamImage,
OmniParamVideo,
@@ -71,6 +71,7 @@ from comfy_api_nodes.util import (
sync_op,
tensor_to_base64_string,
upload_audio_to_comfyapi,
+ upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_image_aspect_ratio,
@@ -80,6 +81,31 @@ from comfy_api_nodes.util import (
validate_video_duration,
)
+
+def _generate_storyboard_inputs(count: int) -> list:
+ inputs = []
+ for i in range(1, count + 1):
+ inputs.extend(
+ [
+ IO.String.Input(
+ f"storyboard_{i}_prompt",
+ multiline=True,
+ default="",
+ tooltip=f"Prompt for storyboard segment {i}. Max 512 characters.",
+ ),
+ IO.Int.Input(
+ f"storyboard_{i}_duration",
+ default=4,
+ min=1,
+ max=15,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip=f"Duration for storyboard segment {i} in seconds.",
+ ),
+ ]
+ )
+ return inputs
+
+
KLING_API_VERSION = "v1"
PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
PATH_IMAGE_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/image2video"
@@ -820,20 +846,48 @@ class OmniProTextToVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProTextToVideoNode",
- display_name="Kling Omni Text to Video (Pro)",
+ display_name="Kling 3.0 Omni Text to Video",
category="api node/video/Kling",
description="Use text prompts to generate videos with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
- "This can include both positive and negative descriptions.",
+ "This can include both positive and negative descriptions. "
+ "Ignored when storyboards are enabled.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
- IO.Combo.Input("duration", options=[5, 10]),
+ IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.DynamicCombo.Input(
+ "storyboards",
+ options=[
+ IO.DynamicCombo.Option("disabled", []),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations. "
+ "Ignored for o1 model.",
+ optional=True,
+ ),
+ IO.Boolean.Input("generate_audio", default=False, optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -845,11 +899,15 @@ class OmniProTextToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
- $rates := {"std": 0.084, "pro": 0.112};
+ $isV3 := $contains(widgets.model_name, "v3");
+ $audio := $isV3 and widgets.generate_audio;
+ $rates := $audio
+ ? {"std": 0.112, "pro": 0.14}
+ : {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@@ -864,8 +922,45 @@ class OmniProTextToVideoNode(IO.ComfyNode):
aspect_ratio: str,
duration: int,
resolution: str = "1080p",
+ storyboards: dict | None = None,
+ generate_audio: bool = False,
+ seed: int = 0,
) -> IO.NodeOutput:
- validate_string(prompt, min_length=1, max_length=2500)
+ _ = seed
+ if model_name == "kling-video-o1":
+ if duration not in (5, 10):
+ raise ValueError("kling-video-o1 only supports durations of 5 or 10 seconds.")
+ if generate_audio:
+ raise ValueError("kling-video-o1 does not support audio generation.")
+ stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
+ if stories_enabled and model_name == "kling-video-o1":
+ raise ValueError("kling-video-o1 does not support storyboards.")
+ validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500)
+
+ multi_shot = None
+ multi_prompt_list = None
+ if stories_enabled:
+ count = int(storyboards["storyboards"].split()[0])
+ multi_shot = True
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = storyboards[f"storyboard_{i}_prompt"]
+ sb_duration = storyboards[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list)
+ if total_storyboard_duration != duration:
+ raise ValueError(
+ f"Total storyboard duration ({total_storyboard_duration}s) "
+ f"must equal the global duration ({duration}s)."
+ )
+
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
@@ -876,6 +971,10 @@ class OmniProTextToVideoNode(IO.ComfyNode):
aspect_ratio=aspect_ratio,
duration=str(duration),
mode="pro" if resolution == "1080p" else "std",
+ multi_shot=multi_shot,
+ multi_prompt=multi_prompt_list,
+ shot_type="customize" if multi_shot else None,
+ sound="on" if generate_audio else "off",
),
)
return await finish_omni_video_task(cls, response)
@@ -887,24 +986,26 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProFirstLastFrameNode",
- display_name="Kling Omni First-Last-Frame to Video (Pro)",
+ display_name="Kling 3.0 Omni First-Last-Frame to Video",
category="api node/video/Kling",
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
- "This can include both positive and negative descriptions.",
+ "This can include both positive and negative descriptions. "
+ "Ignored when storyboards are enabled.",
),
- IO.Int.Input("duration", default=5, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Image.Input("first_frame"),
IO.Image.Input(
"end_frame",
optional=True,
tooltip="An optional end frame for the video. "
- "This cannot be used simultaneously with 'reference_images'.",
+ "This cannot be used simultaneously with 'reference_images'. "
+ "Does not work with storyboards.",
),
IO.Image.Input(
"reference_images",
@@ -912,6 +1013,38 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
tooltip="Up to 6 additional reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.DynamicCombo.Input(
+ "storyboards",
+ options=[
+ IO.DynamicCombo.Option("disabled", []),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations. "
+ "Only supported for kling-v3-omni.",
+ optional=True,
+ ),
+ IO.Boolean.Input(
+ "generate_audio",
+ default=False,
+ optional=True,
+ tooltip="Generate audio for the video. Only supported for kling-v3-omni.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -923,11 +1056,15 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
- $rates := {"std": 0.084, "pro": 0.112};
+ $isV3 := $contains(widgets.model_name, "v3");
+ $audio := $isV3 and widgets.generate_audio;
+ $rates := $audio
+ ? {"std": 0.112, "pro": 0.14}
+ : {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@@ -944,15 +1081,59 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
end_frame: Input.Image | None = None,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
+ storyboards: dict | None = None,
+ generate_audio: bool = False,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
+ if model_name == "kling-video-o1":
+ if duration > 10:
+ raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
+ if generate_audio:
+ raise ValueError("kling-video-o1 does not support audio generation.")
+ stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
+ if stories_enabled and model_name == "kling-video-o1":
+ raise ValueError("kling-video-o1 does not support storyboards.")
prompt = normalize_omni_prompt_references(prompt)
- validate_string(prompt, min_length=1, max_length=2500)
+ validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500)
if end_frame is not None and reference_images is not None:
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
- if duration not in (5, 10) and end_frame is None and reference_images is None:
+ if end_frame is not None and stories_enabled:
+ raise ValueError("The 'end_frame' input cannot be used simultaneously with storyboards.")
+ if (
+ model_name == "kling-video-o1"
+ and duration not in (5, 10)
+ and end_frame is None
+ and reference_images is None
+ ):
raise ValueError(
"Duration is only supported for 5 or 10 seconds if there is no end frame or reference images."
)
+
+ multi_shot = None
+ multi_prompt_list = None
+ if stories_enabled:
+ count = int(storyboards["storyboards"].split()[0])
+ multi_shot = True
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = storyboards[f"storyboard_{i}_prompt"]
+ sb_duration = storyboards[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list)
+ if total_storyboard_duration != duration:
+ raise ValueError(
+ f"Total storyboard duration ({total_storyboard_duration}s) "
+ f"must equal the global duration ({duration}s)."
+ )
+
validate_image_dimensions(first_frame, min_width=300, min_height=300)
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
image_list: list[OmniParamImage] = [
@@ -988,6 +1169,10 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
+ sound="on" if generate_audio else "off",
+ multi_shot=multi_shot,
+ multi_prompt=multi_prompt_list,
+ shot_type="customize" if multi_shot else None,
),
)
return await finish_omni_video_task(cls, response)
@@ -999,24 +1184,57 @@ class OmniProImageToVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProImageToVideoNode",
- display_name="Kling Omni Image to Video (Pro)",
+ display_name="Kling 3.0 Omni Image to Video",
category="api node/video/Kling",
description="Use up to 7 reference images to generate a video with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the video content. "
- "This can include both positive and negative descriptions.",
+ "This can include both positive and negative descriptions. "
+ "Ignored when storyboards are enabled.",
),
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
- IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
+ IO.Int.Input("duration", default=5, min=3, max=15, display_mode=IO.NumberDisplay.slider),
IO.Image.Input(
"reference_images",
tooltip="Up to 7 reference images.",
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.DynamicCombo.Input(
+ "storyboards",
+ options=[
+ IO.DynamicCombo.Option("disabled", []),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations. "
+ "Only supported for kling-v3-omni.",
+ optional=True,
+ ),
+ IO.Boolean.Input(
+ "generate_audio",
+ default=False,
+ optional=True,
+ tooltip="Generate audio for the video. Only supported for kling-v3-omni.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -1028,11 +1246,15 @@ class OmniProImageToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution", "model_name", "generate_audio"]),
expr="""
(
$mode := (widgets.resolution = "720p") ? "std" : "pro";
- $rates := {"std": 0.084, "pro": 0.112};
+ $isV3 := $contains(widgets.model_name, "v3");
+ $audio := $isV3 and widgets.generate_audio;
+ $rates := $audio
+ ? {"std": 0.112, "pro": 0.14}
+ : {"std": 0.084, "pro": 0.112};
{"type":"usd","usd": $lookup($rates, $mode) * widgets.duration}
)
""",
@@ -1048,9 +1270,46 @@ class OmniProImageToVideoNode(IO.ComfyNode):
duration: int,
reference_images: Input.Image,
resolution: str = "1080p",
+ storyboards: dict | None = None,
+ generate_audio: bool = False,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
+ if model_name == "kling-video-o1":
+ if duration > 10:
+ raise ValueError("kling-video-o1 does not support durations greater than 10 seconds.")
+ if generate_audio:
+ raise ValueError("kling-video-o1 does not support audio generation.")
+ stories_enabled = storyboards is not None and storyboards["storyboards"] != "disabled"
+ if stories_enabled and model_name == "kling-video-o1":
+ raise ValueError("kling-video-o1 does not support storyboards.")
prompt = normalize_omni_prompt_references(prompt)
- validate_string(prompt, min_length=1, max_length=2500)
+ validate_string(prompt, strip_whitespace=True, min_length=0 if stories_enabled else 1, max_length=2500)
+
+ multi_shot = None
+ multi_prompt_list = None
+ if stories_enabled:
+ count = int(storyboards["storyboards"].split()[0])
+ multi_shot = True
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = storyboards[f"storyboard_{i}_prompt"]
+ sb_duration = storyboards[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ total_storyboard_duration = sum(int(e.duration) for e in multi_prompt_list)
+ if total_storyboard_duration != duration:
+ raise ValueError(
+ f"Total storyboard duration ({total_storyboard_duration}s) "
+ f"must equal the global duration ({duration}s)."
+ )
+
if get_number_of_images(reference_images) > 7:
raise ValueError("The maximum number of reference images is 7.")
for i in reference_images:
@@ -1070,6 +1329,10 @@ class OmniProImageToVideoNode(IO.ComfyNode):
duration=str(duration),
image_list=image_list,
mode="pro" if resolution == "1080p" else "std",
+ sound="on" if generate_audio else "off",
+ multi_shot=multi_shot,
+ multi_prompt=multi_prompt_list,
+ shot_type="customize" if multi_shot else None,
),
)
return await finish_omni_video_task(cls, response)
@@ -1081,11 +1344,11 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProVideoToVideoNode",
- display_name="Kling Omni Video to Video (Pro)",
+ display_name="Kling 3.0 Omni Video to Video",
category="api node/video/Kling",
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -1102,6 +1365,17 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -1135,7 +1409,9 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
@@ -1179,11 +1455,11 @@ class OmniProEditVideoNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProEditVideoNode",
- display_name="Kling Omni Edit Video (Pro)",
+ display_name="Kling 3.0 Omni Edit Video",
category="api node/video/Kling",
description="Edit an existing video with the latest model from Kling.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-video-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
IO.String.Input(
"prompt",
multiline=True,
@@ -1198,6 +1474,17 @@ class OmniProEditVideoNode(IO.ComfyNode):
optional=True,
),
IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Video.Output(),
@@ -1229,7 +1516,9 @@ class OmniProEditVideoNode(IO.ComfyNode):
keep_original_sound: bool,
reference_images: Input.Image | None = None,
resolution: str = "1080p",
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
validate_video_duration(video, min_duration=3.0, max_duration=10.05)
@@ -1273,27 +1562,43 @@ class OmniProImageNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingOmniProImageNode",
- display_name="Kling Omni Image (Pro)",
+ display_name="Kling 3.0 Omni Image",
category="api node/image/Kling",
description="Create or edit images with the latest model from Kling.",
inputs=[
- IO.Combo.Input("model_name", options=["kling-image-o1"]),
+ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]),
IO.String.Input(
"prompt",
multiline=True,
tooltip="A text prompt describing the image content. "
"This can include both positive and negative descriptions.",
),
- IO.Combo.Input("resolution", options=["1K", "2K"]),
+ IO.Combo.Input("resolution", options=["1K", "2K", "4K"]),
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
),
+ IO.Combo.Input(
+ "series_amount",
+ options=["disabled", "2", "3", "4", "5", "6", "7", "8", "9"],
+ tooltip="Generate a series of images. Not supported for kling-image-o1.",
+ ),
IO.Image.Input(
"reference_images",
tooltip="Up to 10 additional reference images.",
optional=True,
),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Image.Output(),
@@ -1305,7 +1610,16 @@ class OmniProImageNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- expr="""{"type":"usd","usd":0.028}""",
+ depends_on=IO.PriceBadgeDepends(widgets=["resolution", "series_amount", "model_name"]),
+ expr="""
+ (
+ $prices := {"1k": 0.028, "2k": 0.028, "4k": 0.056};
+ $base := $lookup($prices, widgets.resolution);
+ $isO1 := widgets.model_name = "kling-image-o1";
+ $mult := ($isO1 or widgets.series_amount = "disabled") ? 1 : $number(widgets.series_amount);
+ {"type":"usd","usd": $base * $mult}
+ )
+ """,
),
)
@@ -1316,8 +1630,13 @@ class OmniProImageNode(IO.ComfyNode):
prompt: str,
resolution: str,
aspect_ratio: str,
+ series_amount: str = "disabled",
reference_images: Input.Image | None = None,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
+ if model_name == "kling-image-o1" and resolution == "4K":
+ raise ValueError("4K resolution is not supported for kling-image-o1 model.")
prompt = normalize_omni_prompt_references(prompt)
validate_string(prompt, min_length=1, max_length=2500)
image_list: list[OmniImageParamImage] = []
@@ -1329,6 +1648,9 @@ class OmniProImageNode(IO.ComfyNode):
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
image_list.append(OmniImageParamImage(image=i))
+ use_series = series_amount != "disabled"
+ if use_series and model_name == "kling-image-o1":
+ raise ValueError("kling-image-o1 does not support series generation.")
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
@@ -1339,6 +1661,8 @@ class OmniProImageNode(IO.ComfyNode):
resolution=resolution.lower(),
aspect_ratio=aspect_ratio,
image_list=image_list if image_list else None,
+ result_type="series" if use_series else None,
+ series_amount=int(series_amount) if use_series else None,
),
)
if response.code:
@@ -1351,7 +1675,9 @@ class OmniProImageNode(IO.ComfyNode):
response_model=TaskStatusResponse,
status_extractor=lambda r: (r.data.task_status if r.data else None),
)
- return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
+ images = final_response.data.task_result.series_images or final_response.data.task_result.images
+ tensors = [await download_url_to_image_tensor(img.url) for img in images]
+ return IO.NodeOutput(torch.cat(tensors, dim=0))
class KlingCameraControlT2VNode(IO.ComfyNode):
@@ -2119,7 +2445,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageGenerationNode",
- display_name="Kling Image Generation",
+ display_name="Kling 3.0 Image",
category="api node/image/Kling",
description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.",
inputs=[
@@ -2147,11 +2473,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
display_mode=IO.NumberDisplay.slider,
tooltip="Subject reference similarity",
),
- IO.Combo.Input(
- "model_name",
- options=[i.value for i in KlingImageGenModelName],
- default="kling-v2",
- ),
+ IO.Combo.Input("model_name", options=["kling-v3", "kling-v2", "kling-v1-5"]),
IO.Combo.Input(
"aspect_ratio",
options=[i.value for i in KlingImageGenAspectRatio],
@@ -2165,6 +2487,17 @@ class KlingImageGenerationNode(IO.ComfyNode):
tooltip="Number of generated images",
),
IO.Image.Input("image", optional=True),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ optional=True,
+ ),
],
outputs=[
IO.Image.Output(),
@@ -2183,7 +2516,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
$base :=
$contains($m,"kling-v1-5")
? (inputs.image.connected ? 0.028 : 0.014)
- : ($contains($m,"kling-v1") ? 0.0035 : 0.014);
+ : $contains($m,"kling-v3") ? 0.028 : 0.014;
{"type":"usd","usd": $base * widgets.n}
)
""",
@@ -2193,7 +2526,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
@classmethod
async def execute(
cls,
- model_name: KlingImageGenModelName,
+ model_name: str,
prompt: str,
negative_prompt: str,
image_type: KlingImageGenImageReferenceType,
@@ -2202,17 +2535,11 @@ class KlingImageGenerationNode(IO.ComfyNode):
n: int,
aspect_ratio: KlingImageGenAspectRatio,
image: torch.Tensor | None = None,
+ seed: int = 0,
) -> IO.NodeOutput:
+ _ = seed
validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
-
- if image is None:
- image_type = None
- elif model_name == KlingImageGenModelName.kling_v1:
- raise ValueError(f"The model {KlingImageGenModelName.kling_v1.value} does not support reference images.")
- else:
- image = tensor_to_base64_string(image)
-
task_creation_response = await sync_op(
cls,
ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"),
@@ -2221,8 +2548,8 @@ class KlingImageGenerationNode(IO.ComfyNode):
model_name=model_name,
prompt=prompt,
negative_prompt=negative_prompt,
- image=image,
- image_reference=image_type,
+ image=tensor_to_base64_string(image) if image is not None else None,
+ image_reference=image_type if image is not None else None,
image_fidelity=image_fidelity,
human_fidelity=human_fidelity,
n=n,
@@ -2252,7 +2579,7 @@ class TextToVideoWithAudio(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
- display_name="Kling Text to Video with Audio",
+ display_name="Kling 2.6 Text to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
@@ -2320,7 +2647,7 @@ class ImageToVideoWithAudio(IO.ComfyNode):
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
- display_name="Kling Image(First Frame) to Video with Audio",
+ display_name="Kling 2.6 Image(First Frame) to Video with Audio",
category="api node/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
@@ -2478,6 +2805,335 @@ class MotionControl(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+class KlingVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingVideoNode",
+ display_name="Kling 3.0 Video",
+ category="api node/video/Kling",
+ description="Generate videos with Kling V3. "
+ "Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.",
+ inputs=[
+ IO.DynamicCombo.Input(
+ "multi_shot",
+ options=[
+ IO.DynamicCombo.Option(
+ "disabled",
+ [
+ IO.String.Input("prompt", multiline=True, default=""),
+ IO.String.Input("negative_prompt", multiline=True, default=""),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=3,
+ max=15,
+ display_mode=IO.NumberDisplay.slider,
+ ),
+ ],
+ ),
+ IO.DynamicCombo.Option("1 storyboard", _generate_storyboard_inputs(1)),
+ IO.DynamicCombo.Option("2 storyboards", _generate_storyboard_inputs(2)),
+ IO.DynamicCombo.Option("3 storyboards", _generate_storyboard_inputs(3)),
+ IO.DynamicCombo.Option("4 storyboards", _generate_storyboard_inputs(4)),
+ IO.DynamicCombo.Option("5 storyboards", _generate_storyboard_inputs(5)),
+ IO.DynamicCombo.Option("6 storyboards", _generate_storyboard_inputs(6)),
+ ],
+ tooltip="Generate a series of video segments with individual prompts and durations.",
+ ),
+ IO.Boolean.Input("generate_audio", default=True),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "kling-v3",
+ [
+ IO.Combo.Input("resolution", options=["1080p", "720p"]),
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=["16:9", "9:16", "1:1"],
+ tooltip="Ignored in image-to-video mode.",
+ ),
+ ],
+ ),
+ ],
+ tooltip="Model and generation settings.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ IO.Image.Input(
+ "start_frame",
+ optional=True,
+ tooltip="Optional start frame image. When connected, switches to image-to-video mode.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(
+ widgets=[
+ "model.resolution",
+ "generate_audio",
+ "multi_shot",
+ "multi_shot.duration",
+ "multi_shot.storyboard_1_duration",
+ "multi_shot.storyboard_2_duration",
+ "multi_shot.storyboard_3_duration",
+ "multi_shot.storyboard_4_duration",
+ "multi_shot.storyboard_5_duration",
+ "multi_shot.storyboard_6_duration",
+ ],
+ ),
+ expr="""
+ (
+ $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}};
+ $res := $lookup(widgets, "model.resolution");
+ $audio := widgets.generate_audio ? "on" : "off";
+ $rate := $lookup($lookup($rates, $res), $audio);
+ $ms := widgets.multi_shot;
+ $isSb := $ms != "disabled";
+ $n := $isSb ? $number($substring($ms, 0, 1)) : 0;
+ $d1 := $lookup(widgets, "multi_shot.storyboard_1_duration");
+ $d2 := $n >= 2 ? $lookup(widgets, "multi_shot.storyboard_2_duration") : 0;
+ $d3 := $n >= 3 ? $lookup(widgets, "multi_shot.storyboard_3_duration") : 0;
+ $d4 := $n >= 4 ? $lookup(widgets, "multi_shot.storyboard_4_duration") : 0;
+ $d5 := $n >= 5 ? $lookup(widgets, "multi_shot.storyboard_5_duration") : 0;
+ $d6 := $n >= 6 ? $lookup(widgets, "multi_shot.storyboard_6_duration") : 0;
+ $dur := $isSb ? $d1 + $d2 + $d3 + $d4 + $d5 + $d6 : $lookup(widgets, "multi_shot.duration");
+ {"type":"usd","usd": $rate * $dur}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ multi_shot: dict,
+ generate_audio: bool,
+ model: dict,
+ seed: int,
+ start_frame: Input.Image | None = None,
+ ) -> IO.NodeOutput:
+ _ = seed
+ mode = "pro" if model["resolution"] == "1080p" else "std"
+ custom_multi_shot = False
+ if multi_shot["multi_shot"] == "disabled":
+ shot_type = None
+ else:
+ shot_type = "customize"
+ custom_multi_shot = True
+
+ multi_prompt_list = None
+ if shot_type == "customize":
+ count = int(multi_shot["multi_shot"].split()[0])
+ multi_prompt_list = []
+ for i in range(1, count + 1):
+ sb_prompt = multi_shot[f"storyboard_{i}_prompt"]
+ sb_duration = multi_shot[f"storyboard_{i}_duration"]
+ validate_string(sb_prompt, field_name=f"storyboard_{i}_prompt", min_length=1, max_length=512)
+ multi_prompt_list.append(
+ MultiPromptEntry(
+ index=i,
+ prompt=sb_prompt,
+ duration=str(sb_duration),
+ )
+ )
+ duration = sum(int(e.duration) for e in multi_prompt_list)
+ if duration < 3 or duration > 15:
+ raise ValueError(
+ f"Total storyboard duration ({duration}s) must be between 3 and 15 seconds."
+ )
+ else:
+ duration = multi_shot["duration"]
+ validate_string(multi_shot["prompt"], min_length=1, max_length=2500)
+
+ if start_frame is not None:
+ validate_image_dimensions(start_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
+ image_url = await upload_image_to_comfyapi(cls, start_frame, wait_label="Uploading start frame")
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
+ response_model=TaskStatusResponse,
+ data=ImageToVideoWithAudioRequest(
+ model_name=model["model"],
+ image=image_url,
+ prompt=None if custom_multi_shot else multi_shot["prompt"],
+ negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"],
+ mode=mode,
+ duration=str(duration),
+ sound="on" if generate_audio else "off",
+ multi_shot=True if shot_type else None,
+ multi_prompt=multi_prompt_list,
+ shot_type=shot_type,
+ ),
+ )
+ poll_path = f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"
+ else:
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
+ response_model=TaskStatusResponse,
+ data=TextToVideoWithAudioRequest(
+ model_name=model["model"],
+ aspect_ratio=model["aspect_ratio"],
+ prompt=None if custom_multi_shot else multi_shot["prompt"],
+ negative_prompt=None if custom_multi_shot else multi_shot["negative_prompt"],
+ mode=mode,
+ duration=str(duration),
+ sound="on" if generate_audio else "off",
+ multi_shot=True if shot_type else None,
+ multi_prompt=multi_prompt_list,
+ shot_type=shot_type,
+ ),
+ )
+ poll_path = f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"
+
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=poll_path),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
+class KlingFirstLastFrameNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls) -> IO.Schema:
+ return IO.Schema(
+ node_id="KlingFirstLastFrameNode",
+ display_name="Kling 3.0 First-Last-Frame to Video",
+ category="api node/video/Kling",
+ description="Generate videos with Kling V3 using first and last frames.",
+ inputs=[
+ IO.String.Input("prompt", multiline=True, default=""),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=3,
+ max=15,
+ display_mode=IO.NumberDisplay.slider,
+ ),
+ IO.Image.Input("first_frame"),
+ IO.Image.Input("end_frame"),
+ IO.Boolean.Input("generate_audio", default=True),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "kling-v3",
+ [
+ IO.Combo.Input("resolution", options=["1080p", "720p"]),
+ ],
+ ),
+ ],
+ tooltip="Model and generation settings.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(
+ widgets=["model.resolution", "generate_audio", "duration"],
+ ),
+ expr="""
+ (
+ $rates := {"1080p": {"off": 0.112, "on": 0.168}, "720p": {"off": 0.084, "on": 0.126}};
+ $res := $lookup(widgets, "model.resolution");
+ $audio := widgets.generate_audio ? "on" : "off";
+ $rate := $lookup($lookup($rates, $res), $audio);
+ {"type":"usd","usd": $rate * widgets.duration}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ duration: int,
+ first_frame: Input.Image,
+ end_frame: Input.Image,
+ generate_audio: bool,
+ model: dict,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ validate_string(prompt, min_length=1, max_length=2500)
+ validate_image_dimensions(first_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
+ validate_image_dimensions(end_frame, min_width=300, min_height=300)
+ validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
+ image_url = await upload_image_to_comfyapi(cls, first_frame, wait_label="Uploading first frame")
+ image_tail_url = await upload_image_to_comfyapi(cls, end_frame, wait_label="Uploading end frame")
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
+ response_model=TaskStatusResponse,
+ data=ImageToVideoWithAudioRequest(
+ model_name=model["model"],
+ image=image_url,
+ image_tail=image_tail_url,
+ prompt=prompt,
+ mode="pro" if model["resolution"] == "1080p" else "std",
+ duration=str(duration),
+ sound="on" if generate_audio else "off",
+ ),
+ )
+ if response.code:
+ raise RuntimeError(
+ f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ )
+ final_response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
+ response_model=TaskStatusResponse,
+ status_extractor=lambda r: (r.data.task_status if r.data else None),
+ )
+ return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
+
+
class KlingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -2504,6 +3160,8 @@ class KlingExtension(ComfyExtension):
TextToVideoWithAudio,
ImageToVideoWithAudio,
MotionControl,
+ KlingVideoNode,
+ KlingFirstLastFrameNode,
]
From 6615db925c9f84843e29db118852e14b643a1a03 Mon Sep 17 00:00:00 2001
From: ComfyUI Wiki
Date: Wed, 11 Feb 2026 02:24:56 +0800
Subject: [PATCH 07/85] chore: update workflow templates to v0.8.38 (#12394)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 4e2773f5d..7de6a413c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.38.13
-comfyui-workflow-templates==0.8.37
+comfyui-workflow-templates==0.8.38
comfyui-embedded-docs==0.4.1
torch
torchsde
From 6648ab68bc934a185c90a2a872c87dc64d093751 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 10 Feb 2026 13:26:29 -0500
Subject: [PATCH 08/85] ComfyUI v0.13.0
---
comfyui_version.py | 2 +-
pyproject.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfyui_version.py b/comfyui_version.py
index 706b37763..cf4e89816 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.12.3"
+__version__ = "0.13.0"
diff --git a/pyproject.toml b/pyproject.toml
index f7925b92a..9dab9a50c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.12.3"
+version = "0.13.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
From fe053ba5eb34c8abcc5d17a25c114340af1833aa Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 10 Feb 2026 10:37:17 -0800
Subject: [PATCH 09/85] mp: dont deep-clone objects from model_options (#12382)
If there are non-trivial python objects nested in the model_options, this
causes all sorts of issues. Traverse lists and dicts so clones can safely
overide settings and BYO objects but stop there on the deepclone.
---
comfy/model_patcher.py | 3 +--
comfy/utils.py | 18 ++++++++++++++++++
2 files changed, 19 insertions(+), 2 deletions(-)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index b9a117a7c..19c9031ea 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -19,7 +19,6 @@
from __future__ import annotations
import collections
-import copy
import inspect
import logging
import math
@@ -317,7 +316,7 @@ class ModelPatcher:
n.object_patches = self.object_patches.copy()
n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
- n.model_options = copy.deepcopy(self.model_options)
+ n.model_options = comfy.utils.deepcopy_list_dict(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
diff --git a/comfy/utils.py b/comfy/utils.py
index 1337e2205..edd80cebe 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -1376,3 +1376,21 @@ def string_to_seed(data):
else:
crc >>= 1
return crc ^ 0xFFFFFFFF
+
+def deepcopy_list_dict(obj, memo=None):
+ if memo is None:
+ memo = {}
+
+ obj_id = id(obj)
+ if obj_id in memo:
+ return memo[obj_id]
+
+ if isinstance(obj, dict):
+ res = {deepcopy_list_dict(k, memo): deepcopy_list_dict(v, memo) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ res = [deepcopy_list_dict(i, memo) for i in obj]
+ else:
+ res = obj
+
+ memo[obj_id] = res
+ return res
From f719f9c06266e7944683009b403e995d4c61d5f0 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 10 Feb 2026 10:37:46 -0800
Subject: [PATCH 10/85] sd: delay VAE dtype archive until after override
(#12388)
VAEs have host specific dtype logic that should override the dynamic
_model_dtype. Defer the archiving of model dtypes until after.
---
comfy/sd.py | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/comfy/sd.py b/comfy/sd.py
index bc9407405..f65e7cadd 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -793,8 +793,6 @@ class VAE:
self.first_stage_model = AutoencoderKL(**(config['params']))
self.first_stage_model = self.first_stage_model.eval()
- model_management.archive_model_dtypes(self.first_stage_model)
-
if device is None:
device = model_management.vae_device()
self.device = device
@@ -803,6 +801,7 @@ class VAE:
dtype = model_management.vae_dtype(self.device, self.working_dtypes)
self.vae_dtype = dtype
self.first_stage_model.to(self.vae_dtype)
+ model_management.archive_model_dtypes(self.first_stage_model)
self.output_device = model_management.intermediate_device()
mp = comfy.model_patcher.CoreModelPatcher
From 123a7874a97c4a8b8f06d4b7c2b1a566b8f0d057 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 10 Feb 2026 10:38:28 -0800
Subject: [PATCH 11/85] ops: Fix vanilla-fp8 loaded lora quality (#12390)
This was missing the stochastic rounding required for fp8 downcast
to be consistent with model_patcher.patch_weight_to_device.
Missed in testing as I spend too much time with quantized tensors
and overlooked the simpler ones.
---
comfy/ops.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfy/ops.py b/comfy/ops.py
index ea0d70702..33803b223 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -169,8 +169,8 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if orig.dtype == dtype and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
- else:
- y = x
+ elif update_weight:
+ y = comfy.float.stochastic_rounding(x, orig.dtype, seed = comfy.utils.string_to_seed(s.seed_key))
if update_weight:
orig.copy_(y)
for f in fns:
From 00fff6019ecf0f4306005579e93cef0cd51a3a1c Mon Sep 17 00:00:00 2001
From: guill
Date: Tue, 10 Feb 2026 14:37:14 -0800
Subject: [PATCH 12/85] feat(jobs): add 3d to PREVIEWABLE_MEDIA_TYPES for
first-class 3D output support (#12381)
Co-authored-by: Jedrzej Kosinski
---
comfy_execution/jobs.py | 79 +++++++++++--
tests/execution/test_jobs.py | 208 ++++++++++++++++++++++++++++++++++-
2 files changed, 271 insertions(+), 16 deletions(-)
diff --git a/comfy_execution/jobs.py b/comfy_execution/jobs.py
index bf091a448..370014fb6 100644
--- a/comfy_execution/jobs.py
+++ b/comfy_execution/jobs.py
@@ -20,10 +20,60 @@ class JobStatus:
# Media types that can be previewed in the frontend
-PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio'})
+PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d'})
# 3D file extensions for preview fallback (no dedicated media_type exists)
-THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb'})
+THREE_D_EXTENSIONS = frozenset({'.obj', '.fbx', '.gltf', '.glb', '.usdz'})
+
+
+def has_3d_extension(filename: str) -> bool:
+ lower = filename.lower()
+ return any(lower.endswith(ext) for ext in THREE_D_EXTENSIONS)
+
+
+def normalize_output_item(item):
+ """Normalize a single output list item for the jobs API.
+
+ Returns the normalized item, or None to exclude it.
+ String items with 3D extensions become {filename, type, subfolder} dicts.
+ """
+ if item is None:
+ return None
+ if isinstance(item, str):
+ if has_3d_extension(item):
+ return {'filename': item, 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
+ return None
+ if isinstance(item, dict):
+ return item
+ return None
+
+
+def normalize_outputs(outputs: dict) -> dict:
+ """Normalize raw node outputs for the jobs API.
+
+ Transforms string 3D filenames into file output dicts and removes
+ None items. All other items (non-3D strings, dicts, etc.) are
+ preserved as-is.
+ """
+ normalized = {}
+ for node_id, node_outputs in outputs.items():
+ if not isinstance(node_outputs, dict):
+ normalized[node_id] = node_outputs
+ continue
+ normalized_node = {}
+ for media_type, items in node_outputs.items():
+ if media_type == 'animated' or not isinstance(items, list):
+ normalized_node[media_type] = items
+ continue
+ normalized_items = []
+ for item in items:
+ if item is None:
+ continue
+ norm = normalize_output_item(item)
+ normalized_items.append(norm if norm is not None else item)
+ normalized_node[media_type] = normalized_items
+ normalized[node_id] = normalized_node
+ return normalized
def _extract_job_metadata(extra_data: dict) -> tuple[Optional[int], Optional[str]]:
@@ -45,9 +95,9 @@ def is_previewable(media_type: str, item: dict) -> bool:
Maintains backwards compatibility with existing logic.
Priority:
- 1. media_type is 'images', 'video', or 'audio'
+ 1. media_type is 'images', 'video', 'audio', or '3d'
2. format field starts with 'video/' or 'audio/'
- 3. filename has a 3D extension (.obj, .fbx, .gltf, .glb)
+ 3. filename has a 3D extension (.obj, .fbx, .gltf, .glb, .usdz)
"""
if media_type in PREVIEWABLE_MEDIA_TYPES:
return True
@@ -139,7 +189,7 @@ def normalize_history_item(prompt_id: str, history_item: dict, include_outputs:
})
if include_outputs:
- job['outputs'] = outputs
+ job['outputs'] = normalize_outputs(outputs)
job['execution_status'] = status_info
job['workflow'] = {
'prompt': prompt,
@@ -171,18 +221,23 @@ def get_outputs_summary(outputs: dict) -> tuple[int, Optional[dict]]:
continue
for item in items:
- count += 1
-
- if not isinstance(item, dict):
+ normalized = normalize_output_item(item)
+ if normalized is None:
continue
- if preview_output is None and is_previewable(media_type, item):
+ count += 1
+
+ if preview_output is not None:
+ continue
+
+ if isinstance(normalized, dict) and is_previewable(media_type, normalized):
enriched = {
- **item,
+ **normalized,
'nodeId': node_id,
- 'mediaType': media_type
}
- if item.get('type') == 'output':
+ if 'mediaType' not in normalized:
+ enriched['mediaType'] = media_type
+ if normalized.get('type') == 'output':
preview_output = enriched
elif fallback_preview is None:
fallback_preview = enriched
diff --git a/tests/execution/test_jobs.py b/tests/execution/test_jobs.py
index 4d2f9ed36..83c36fe48 100644
--- a/tests/execution/test_jobs.py
+++ b/tests/execution/test_jobs.py
@@ -5,8 +5,11 @@ from comfy_execution.jobs import (
is_previewable,
normalize_queue_item,
normalize_history_item,
+ normalize_output_item,
+ normalize_outputs,
get_outputs_summary,
apply_sorting,
+ has_3d_extension,
)
@@ -35,8 +38,8 @@ class TestIsPreviewable:
"""Unit tests for is_previewable()"""
def test_previewable_media_types(self):
- """Images, video, audio media types should be previewable."""
- for media_type in ['images', 'video', 'audio']:
+ """Images, video, audio, 3d media types should be previewable."""
+ for media_type in ['images', 'video', 'audio', '3d']:
assert is_previewable(media_type, {}) is True
def test_non_previewable_media_types(self):
@@ -46,7 +49,7 @@ class TestIsPreviewable:
def test_3d_extensions_previewable(self):
"""3D file extensions should be previewable regardless of media_type."""
- for ext in ['.obj', '.fbx', '.gltf', '.glb']:
+ for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
item = {'filename': f'model{ext}'}
assert is_previewable('files', item) is True
@@ -160,7 +163,7 @@ class TestGetOutputsSummary:
def test_3d_files_previewable(self):
"""3D file extensions should be previewable."""
- for ext in ['.obj', '.fbx', '.gltf', '.glb']:
+ for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
outputs = {
'node1': {
'files': [{'filename': f'model{ext}', 'type': 'output'}]
@@ -192,6 +195,64 @@ class TestGetOutputsSummary:
assert preview['mediaType'] == 'images'
assert preview['subfolder'] == 'outputs'
+ def test_string_3d_filename_creates_preview(self):
+ """String items with 3D extensions should synthesize a preview (Preview3D node output).
+ Only the .glb counts — nulls and non-file strings are excluded."""
+ outputs = {
+ 'node1': {
+ 'result': ['preview3d_abc123.glb', None, None]
+ }
+ }
+ count, preview = get_outputs_summary(outputs)
+ assert count == 1
+ assert preview is not None
+ assert preview['filename'] == 'preview3d_abc123.glb'
+ assert preview['mediaType'] == '3d'
+ assert preview['nodeId'] == 'node1'
+ assert preview['type'] == 'output'
+
+ def test_string_non_3d_filename_no_preview(self):
+ """String items without 3D extensions should not create a preview."""
+ outputs = {
+ 'node1': {
+ 'result': ['data.json', None]
+ }
+ }
+ count, preview = get_outputs_summary(outputs)
+ assert count == 0
+ assert preview is None
+
+ def test_string_3d_filename_used_as_fallback(self):
+ """String 3D preview should be used when no dict items are previewable."""
+ outputs = {
+ 'node1': {
+ 'latents': [{'filename': 'latent.safetensors'}],
+ },
+ 'node2': {
+ 'result': ['model.glb', None]
+ }
+ }
+ count, preview = get_outputs_summary(outputs)
+ assert preview is not None
+ assert preview['filename'] == 'model.glb'
+ assert preview['mediaType'] == '3d'
+
+
+class TestHas3DExtension:
+ """Unit tests for has_3d_extension()"""
+
+ def test_recognized_extensions(self):
+ for ext in ['.obj', '.fbx', '.gltf', '.glb', '.usdz']:
+ assert has_3d_extension(f'model{ext}') is True
+
+ def test_case_insensitive(self):
+ assert has_3d_extension('MODEL.GLB') is True
+ assert has_3d_extension('Scene.GLTF') is True
+
+ def test_non_3d_extensions(self):
+ for name in ['photo.png', 'video.mp4', 'data.json', 'model']:
+ assert has_3d_extension(name) is False
+
class TestApplySorting:
"""Unit tests for apply_sorting()"""
@@ -395,3 +456,142 @@ class TestNormalizeHistoryItem:
'prompt': {'nodes': {'1': {}}},
'extra_data': {'create_time': 1234567890, 'client_id': 'abc'},
}
+
+ def test_include_outputs_normalizes_3d_strings(self):
+ """Detail view should transform string 3D filenames into file output dicts."""
+ history_item = {
+ 'prompt': (
+ 5,
+ 'prompt-3d',
+ {'nodes': {}},
+ {'create_time': 1234567890},
+ ['node1'],
+ ),
+ 'status': {'status_str': 'success', 'completed': True, 'messages': []},
+ 'outputs': {
+ 'node1': {
+ 'result': ['preview3d_abc123.glb', None, None]
+ }
+ },
+ }
+ job = normalize_history_item('prompt-3d', history_item, include_outputs=True)
+
+ assert job['outputs_count'] == 1
+ result_items = job['outputs']['node1']['result']
+ assert len(result_items) == 1
+ assert result_items[0] == {
+ 'filename': 'preview3d_abc123.glb',
+ 'type': 'output',
+ 'subfolder': '',
+ 'mediaType': '3d',
+ }
+
+ def test_include_outputs_preserves_dict_items(self):
+ """Detail view normalization should pass dict items through unchanged."""
+ history_item = {
+ 'prompt': (
+ 5,
+ 'prompt-img',
+ {'nodes': {}},
+ {'create_time': 1234567890},
+ ['node1'],
+ ),
+ 'status': {'status_str': 'success', 'completed': True, 'messages': []},
+ 'outputs': {
+ 'node1': {
+ 'images': [
+ {'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
+ ]
+ }
+ },
+ }
+ job = normalize_history_item('prompt-img', history_item, include_outputs=True)
+
+ assert job['outputs_count'] == 1
+ assert job['outputs']['node1']['images'] == [
+ {'filename': 'photo.png', 'type': 'output', 'subfolder': ''},
+ ]
+
+
+class TestNormalizeOutputItem:
+ """Unit tests for normalize_output_item()"""
+
+ def test_none_returns_none(self):
+ assert normalize_output_item(None) is None
+
+ def test_string_3d_extension_synthesizes_dict(self):
+ result = normalize_output_item('model.glb')
+ assert result == {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'}
+
+ def test_string_non_3d_extension_returns_none(self):
+ assert normalize_output_item('data.json') is None
+
+ def test_string_no_extension_returns_none(self):
+ assert normalize_output_item('camera_info_string') is None
+
+ def test_dict_passes_through(self):
+ item = {'filename': 'test.png', 'type': 'output'}
+ assert normalize_output_item(item) is item
+
+ def test_other_types_return_none(self):
+ assert normalize_output_item(42) is None
+ assert normalize_output_item(True) is None
+
+
+class TestNormalizeOutputs:
+ """Unit tests for normalize_outputs()"""
+
+ def test_empty_outputs(self):
+ assert normalize_outputs({}) == {}
+
+ def test_dict_items_pass_through(self):
+ outputs = {
+ 'node1': {
+ 'images': [{'filename': 'a.png', 'type': 'output'}],
+ }
+ }
+ result = normalize_outputs(outputs)
+ assert result == outputs
+
+ def test_3d_string_synthesized(self):
+ outputs = {
+ 'node1': {
+ 'result': ['model.glb', None, None],
+ }
+ }
+ result = normalize_outputs(outputs)
+ assert result == {
+ 'node1': {
+ 'result': [
+ {'filename': 'model.glb', 'type': 'output', 'subfolder': '', 'mediaType': '3d'},
+ ],
+ }
+ }
+
+ def test_animated_key_preserved(self):
+ outputs = {
+ 'node1': {
+ 'images': [{'filename': 'a.png', 'type': 'output'}],
+ 'animated': [True],
+ }
+ }
+ result = normalize_outputs(outputs)
+ assert result['node1']['animated'] == [True]
+
+ def test_non_dict_node_outputs_preserved(self):
+ outputs = {'node1': 'unexpected_value'}
+ result = normalize_outputs(outputs)
+ assert result == {'node1': 'unexpected_value'}
+
+ def test_none_items_filtered_but_other_types_preserved(self):
+ outputs = {
+ 'node1': {
+ 'result': ['data.json', None, [1, 2, 3]],
+ }
+ }
+ result = normalize_outputs(outputs)
+ assert result == {
+ 'node1': {
+ 'result': ['data.json', [1, 2, 3]],
+ }
+ }
From dbe70b6821994ce92d9cf211cc685862d0b6c0ca Mon Sep 17 00:00:00 2001
From: AustinMroz
Date: Tue, 10 Feb 2026 14:42:21 -0800
Subject: [PATCH 13/85] Add a VideoSlice node (#12107)
* Base TrimVideo implementation
* Raise error if as_trimmed call fails
* Bigger max start_time, tooltips, and formatting
* Count packets unless codec has subframes
* Remove incorrect nested decode
* Add null check for audio streams
* Support non-strict duration
* Added strict_duration bool to node definition
* Empty commit for approval
* Fix duration
* Support 5.1 audio layout on save
---------
Co-authored-by: Jedrzej Kosinski
---
comfy_api/latest/_input/video_types.py | 15 ++
comfy_api/latest/_input_impl/video_types.py | 201 ++++++++++++++------
comfy_extras/nodes_video.py | 51 +++++
3 files changed, 207 insertions(+), 60 deletions(-)
diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py
index e634a0311..451e9526e 100644
--- a/comfy_api/latest/_input/video_types.py
+++ b/comfy_api/latest/_input/video_types.py
@@ -34,6 +34,21 @@ class VideoInput(ABC):
"""
pass
+ @abstractmethod
+ def as_trimmed(
+ self,
+ start_time: float | None = None,
+ duration: float | None = None,
+ strict_duration: bool = False,
+ ) -> VideoInput | None:
+ """
+ Create a new VideoInput which is trimmed to have the corresponding start_time and duration
+
+ Returns:
+ A new VideoInput, or None if the result would have negative duration
+ """
+ pass
+
def get_stream_source(self) -> Union[str, io.BytesIO]:
"""
Get a streamable source for the video. This allows processing without
diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py
index 1405d0b81..3463ed1c9 100644
--- a/comfy_api/latest/_input_impl/video_types.py
+++ b/comfy_api/latest/_input_impl/video_types.py
@@ -6,6 +6,7 @@ from typing import Optional
from .._input import AudioInput, VideoInput
import av
import io
+import itertools
import json
import numpy as np
import math
@@ -29,7 +30,6 @@ def container_to_output_format(container_format: str | None) -> str | None:
formats = container_format.split(",")
return formats[0]
-
def get_open_write_kwargs(
dest: str | io.BytesIO, container_format: str, to_format: str | None
) -> dict:
@@ -57,12 +57,14 @@ class VideoFromFile(VideoInput):
Class representing video input from a file.
"""
- def __init__(self, file: str | io.BytesIO):
+ def __init__(self, file: str | io.BytesIO, *, start_time: float=0, duration: float=0):
"""
Initialize the VideoFromFile object based off of either a path on disk or a BytesIO object
containing the file contents.
"""
self.__file = file
+ self.__start_time = start_time
+ self.__duration = duration
def get_stream_source(self) -> str | io.BytesIO:
"""
@@ -96,6 +98,16 @@ class VideoFromFile(VideoInput):
Returns:
Duration in seconds
"""
+ raw_duration = self._get_raw_duration()
+ if self.__start_time < 0:
+ duration_from_start = min(raw_duration, -self.__start_time)
+ else:
+ duration_from_start = raw_duration - self.__start_time
+ if self.__duration:
+ return min(self.__duration, duration_from_start)
+ return duration_from_start
+
+ def _get_raw_duration(self) -> float:
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0)
with av.open(self.__file, mode="r") as container:
@@ -113,9 +125,13 @@ class VideoFromFile(VideoInput):
if video_stream and video_stream.average_rate:
frame_count = 0
container.seek(0)
- for packet in container.demux(video_stream):
- for _ in packet.decode():
- frame_count += 1
+ frame_iterator = (
+ container.decode(video_stream)
+ if video_stream.codec.capabilities & 0x100
+ else container.demux(video_stream)
+ )
+ for packet in frame_iterator:
+ frame_count += 1
if frame_count > 0:
return float(frame_count / video_stream.average_rate)
@@ -131,36 +147,54 @@ class VideoFromFile(VideoInput):
with av.open(self.__file, mode="r") as container:
video_stream = self._get_first_video_stream(container)
- # 1. Prefer the frames field if available
- if video_stream.frames and video_stream.frames > 0:
+ # 1. Prefer the frames field if available and usable
+ if (
+ video_stream.frames
+ and video_stream.frames > 0
+ and not self.__start_time
+ and not self.__duration
+ ):
return int(video_stream.frames)
# 2. Try to estimate from duration and average_rate using only metadata
- if container.duration is not None and video_stream.average_rate:
- duration_seconds = float(container.duration / av.time_base)
- estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
- if estimated_frames > 0:
- return estimated_frames
-
if (
getattr(video_stream, "duration", None) is not None
and getattr(video_stream, "time_base", None) is not None
and video_stream.average_rate
):
- duration_seconds = float(video_stream.duration * video_stream.time_base)
+ raw_duration = float(video_stream.duration * video_stream.time_base)
+ if self.__start_time < 0:
+ duration_from_start = min(raw_duration, -self.__start_time)
+ else:
+ duration_from_start = raw_duration - self.__start_time
+ duration_seconds = min(self.__duration, duration_from_start)
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
if estimated_frames > 0:
return estimated_frames
# 3. Last resort: decode frames and count them (streaming)
- frame_count = 0
- container.seek(0)
- for packet in container.demux(video_stream):
- for _ in packet.decode():
- frame_count += 1
-
- if frame_count == 0:
- raise ValueError(f"Could not determine frame count for file '{self.__file}'")
+ if self.__start_time < 0:
+ start_time = max(self._get_raw_duration() + self.__start_time, 0)
+ else:
+ start_time = self.__start_time
+ frame_count = 1
+ start_pts = int(start_time / video_stream.time_base)
+ end_pts = int((start_time + self.__duration) / video_stream.time_base)
+ container.seek(start_pts, stream=video_stream)
+ frame_iterator = (
+ container.decode(video_stream)
+ if video_stream.codec.capabilities & 0x100
+ else container.demux(video_stream)
+ )
+ for frame in frame_iterator:
+ if frame.pts >= start_pts:
+ break
+ else:
+ raise ValueError(f"Could not determine frame count for file '{self.__file}'\nNo frames exist for start_time {self.__start_time}")
+ for frame in frame_iterator:
+ if frame.pts >= end_pts:
+ break
+ frame_count += 1
return frame_count
def get_frame_rate(self) -> Fraction:
@@ -199,9 +233,21 @@ class VideoFromFile(VideoInput):
return container.format.name
def get_components_internal(self, container: InputContainer) -> VideoComponents:
+ video_stream = self._get_first_video_stream(container)
+ if self.__start_time < 0:
+ start_time = max(self._get_raw_duration() + self.__start_time, 0)
+ else:
+ start_time = self.__start_time
# Get video frames
frames = []
- for frame in container.decode(video=0):
+ start_pts = int(start_time / video_stream.time_base)
+ end_pts = int((start_time + self.__duration) / video_stream.time_base)
+ container.seek(start_pts, stream=video_stream)
+ for frame in container.decode(video_stream):
+ if frame.pts < start_pts:
+ continue
+ if self.__duration and frame.pts >= end_pts:
+ break
img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
frames.append(img)
@@ -209,31 +255,44 @@ class VideoFromFile(VideoInput):
images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
# Get frame rate
- video_stream = next(s for s in container.streams if s.type == 'video')
- frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
+ frame_rate = Fraction(video_stream.average_rate) if video_stream.average_rate else Fraction(1)
# Get audio if available
audio = None
- try:
- container.seek(0) # Reset the container to the beginning
- for stream in container.streams:
- if stream.type != 'audio':
- continue
- assert isinstance(stream, av.AudioStream)
- audio_frames = []
- for packet in container.demux(stream):
- for frame in packet.decode():
- assert isinstance(frame, av.AudioFrame)
- audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
- if len(audio_frames) > 0:
- audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
- audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
- audio = AudioInput({
- "waveform": audio_tensor,
- "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
- })
- except StopIteration:
- pass # No audio stream
+ container.seek(start_pts, stream=video_stream)
+ # Use last stream for consistency
+ if len(container.streams.audio):
+ audio_stream = container.streams.audio[-1]
+ audio_frames = []
+ resample = av.audio.resampler.AudioResampler(format='fltp').resample
+ frames = itertools.chain.from_iterable(
+ map(resample, container.decode(audio_stream))
+ )
+
+ has_first_frame = False
+ for frame in frames:
+ offset_seconds = start_time - frame.pts * audio_stream.time_base
+ to_skip = int(offset_seconds * audio_stream.sample_rate)
+ if to_skip < frame.samples:
+ has_first_frame = True
+ break
+ if has_first_frame:
+ audio_frames.append(frame.to_ndarray()[..., to_skip:])
+
+ for frame in frames:
+ if frame.time > start_time + self.__duration:
+ break
+ audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
+ if len(audio_frames) > 0:
+ audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
+ if self.__duration:
+ audio_data = audio_data[..., :int(self.__duration * audio_stream.sample_rate)]
+
+ audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
+ audio = AudioInput({
+ "waveform": audio_tensor,
+ "sample_rate": int(audio_stream.sample_rate) if audio_stream.sample_rate else 1,
+ })
metadata = container.metadata
return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
@@ -250,7 +309,7 @@ class VideoFromFile(VideoInput):
path: str | io.BytesIO,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
- metadata: Optional[dict] = None
+ metadata: Optional[dict] = None,
):
if isinstance(self.__file, io.BytesIO):
self.__file.seek(0) # Reset the BytesIO object to the beginning
@@ -262,15 +321,14 @@ class VideoFromFile(VideoInput):
reuse_streams = False
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
reuse_streams = False
+ if self.__start_time or self.__duration:
+ reuse_streams = False
if not reuse_streams:
components = self.get_components_internal(container)
video = VideoFromComponents(components)
return video.save_to(
- path,
- format=format,
- codec=codec,
- metadata=metadata
+ path, format=format, codec=codec, metadata=metadata
)
streams = container.streams
@@ -304,10 +362,21 @@ class VideoFromFile(VideoInput):
output_container.mux(packet)
def _get_first_video_stream(self, container: InputContainer):
- video_stream = next((s for s in container.streams if s.type == "video"), None)
- if video_stream is None:
- raise ValueError(f"No video stream found in file '{self.__file}'")
- return video_stream
+ if len(container.streams.video):
+ return container.streams.video[0]
+ raise ValueError(f"No video stream found in file '{self.__file}'")
+
+ def as_trimmed(
+ self, start_time: float = 0, duration: float = 0, strict_duration: bool = True
+ ) -> VideoInput | None:
+ trimmed = VideoFromFile(
+ self.get_stream_source(),
+ start_time=start_time + self.__start_time,
+ duration=duration,
+ )
+ if trimmed.get_duration() < duration and strict_duration:
+ return None
+ return trimmed
class VideoFromComponents(VideoInput):
@@ -322,7 +391,7 @@ class VideoFromComponents(VideoInput):
return VideoComponents(
images=self.__components.images,
audio=self.__components.audio,
- frame_rate=self.__components.frame_rate
+ frame_rate=self.__components.frame_rate,
)
def save_to(
@@ -330,7 +399,7 @@ class VideoFromComponents(VideoInput):
path: str,
format: VideoContainer = VideoContainer.AUTO,
codec: VideoCodec = VideoCodec.AUTO,
- metadata: Optional[dict] = None
+ metadata: Optional[dict] = None,
):
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
@@ -357,7 +426,10 @@ class VideoFromComponents(VideoInput):
audio_stream: Optional[av.AudioStream] = None
if self.__components.audio:
audio_sample_rate = int(self.__components.audio['sample_rate'])
- audio_stream = output.add_stream('aac', rate=audio_sample_rate)
+ waveform = self.__components.audio['waveform']
+ waveform = waveform[0, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
+ layout = {1: 'mono', 2: 'stereo', 6: '5.1'}.get(waveform.shape[0], 'stereo')
+ audio_stream = output.add_stream('aac', rate=audio_sample_rate, layout=layout)
# Encode video
for i, frame in enumerate(self.__components.images):
@@ -372,12 +444,21 @@ class VideoFromComponents(VideoInput):
output.mux(packet)
if audio_stream and self.__components.audio:
- waveform = self.__components.audio['waveform']
- waveform = waveform[:, :, :math.ceil((audio_sample_rate / frame_rate) * self.__components.images.shape[0])]
- frame = av.AudioFrame.from_ndarray(waveform.movedim(2, 1).reshape(1, -1).float().cpu().numpy(), format='flt', layout='mono' if waveform.shape[1] == 1 else 'stereo')
+ frame = av.AudioFrame.from_ndarray(waveform.float().cpu().numpy(), format='fltp', layout=layout)
frame.sample_rate = audio_sample_rate
frame.pts = 0
output.mux(audio_stream.encode(frame))
# Flush encoder
output.mux(audio_stream.encode(None))
+
+ def as_trimmed(
+ self,
+ start_time: float | None = None,
+ duration: float | None = None,
+ strict_duration: bool = True,
+ ) -> VideoInput | None:
+ if self.get_duration() < start_time + duration:
+ return None
+ #TODO Consider tracking duration and trimming at time of save?
+ return VideoFromFile(self.get_stream_source(), start_time=start_time, duration=duration)
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py
index ccf7b63d3..cd765a7c1 100644
--- a/comfy_extras/nodes_video.py
+++ b/comfy_extras/nodes_video.py
@@ -202,6 +202,56 @@ class LoadVideo(io.ComfyNode):
return True
+class VideoSlice(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return io.Schema(
+ node_id="Video Slice",
+ display_name="Video Slice",
+ search_aliases=[
+ "trim video duration",
+ "skip first frames",
+ "frame load cap",
+ "start time",
+ ],
+ category="image/video",
+ inputs=[
+ io.Video.Input("video"),
+ io.Float.Input(
+ "start_time",
+ default=0.0,
+ max=1e5,
+ min=-1e5,
+ step=0.001,
+ tooltip="Start time in seconds",
+ ),
+ io.Float.Input(
+ "duration",
+ default=0.0,
+ min=0.0,
+ step=0.001,
+ tooltip="Duration in seconds, or 0 for unlimited duration",
+ ),
+ io.Boolean.Input(
+ "strict_duration",
+ default=False,
+ tooltip="If True, when the specified duration is not possible, an error will be raised.",
+ ),
+ ],
+ outputs=[
+ io.Video.Output(),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, video: io.Video.Type, start_time: float, duration: float, strict_duration: bool) -> io.NodeOutput:
+ trimmed = video.as_trimmed(start_time, duration, strict_duration=strict_duration)
+ if trimmed is not None:
+ return io.NodeOutput(trimmed)
+ raise ValueError(
+ f"Failed to slice video:\nSource duration: {video.get_duration()}\nStart time: {start_time}\nTarget duration: {duration}"
+ )
+
class VideoExtension(ComfyExtension):
@override
@@ -212,6 +262,7 @@ class VideoExtension(ComfyExtension):
CreateVideo,
GetVideoComponents,
LoadVideo,
+ VideoSlice,
]
async def comfy_entrypoint() -> VideoExtension:
From cdcf4119b3e826bd69fa986772485fb5b44a54cd Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Wed, 11 Feb 2026 10:45:19 +0800
Subject: [PATCH 14/85] [Trainer] training with proper offloading (#12189)
* Fix bypass dtype/device moving
* Force offloading mode for training
* training context var
* offloading implementation in training node
* fix wrong input type
* Support bypass load lora model, correct adapter/offloading handling
---
comfy/ldm/flux/math.py | 39 +++++---
comfy/model_management.py | 5 +
comfy/sampler_helpers.py | 16 +++-
comfy/weight_adapter/bypass.py | 20 ++--
comfy_extras/nodes_train.py | 162 ++++++++++++++++++++++++++++-----
5 files changed, 196 insertions(+), 46 deletions(-)
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index f9597de5b..5e764bb46 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -29,19 +29,34 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
return out.to(dtype=torch.float32, device=pos.device)
+def _apply_rope1(x: Tensor, freqs_cis: Tensor):
+ x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
+
+ x_out = freqs_cis[..., 0] * x_[..., 0]
+ x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
+
+ return x_out.reshape(*x.shape).type_as(x)
+
+
+def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
+ return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+
+
try:
import comfy.quant_ops
- apply_rope = comfy.quant_ops.ck.apply_rope
- apply_rope1 = comfy.quant_ops.ck.apply_rope1
+ q_apply_rope = comfy.quant_ops.ck.apply_rope
+ q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
+ def apply_rope(xq, xk, freqs_cis):
+ if comfy.model_management.in_training:
+ return _apply_rope(xq, xk, freqs_cis)
+ else:
+ return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+ def apply_rope1(x, freqs_cis):
+ if comfy.model_management.in_training:
+ return _apply_rope1(x, freqs_cis)
+ else:
+ return q_apply_rope1(x, freqs_cis)
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
- def apply_rope1(x: Tensor, freqs_cis: Tensor):
- x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
-
- x_out = freqs_cis[..., 0] * x_[..., 0]
- x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
-
- return x_out.reshape(*x.shape).type_as(x)
-
- def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
- return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
+ apply_rope = _apply_rope
+ apply_rope1 = _apply_rope1
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 6018c1ab6..304931eb0 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -55,6 +55,11 @@ cpu_state = CPUState.GPU
total_vram = 0
+
+# Training Related State
+in_training = False
+
+
def get_supported_float8_types():
float8_types = []
try:
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index 9134e6d71..1f75f2ba7 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -122,20 +122,26 @@ def estimate_memory(model, noise_shape, conds):
minimum_memory_required = model.model.memory_required([noise_shape[0]] + list(noise_shape[1:]), cond_shapes=cond_shapes_min)
return memory_required, minimum_memory_required
-def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
executor = comfy.patcher_extension.WrapperExecutor.new_executor(
_prepare_sampling,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
)
- return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load)
+ return executor.execute(model, noise_shape, conds, model_options=model_options, force_full_load=force_full_load, force_offload=force_offload)
-def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False):
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None, force_full_load=False, force_offload=False):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
- memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
- comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required + inference_memory, minimum_memory_required=minimum_memory_required + inference_memory, force_full_load=force_full_load)
+ if force_offload: # In training + offload enabled, we want to force prepare sampling to trigger partial load
+ memory_required = 1e20
+ minimum_memory_required = None
+ else:
+ memory_required, minimum_memory_required = estimate_memory(model, noise_shape, conds)
+ memory_required += inference_memory
+ minimum_memory_required += inference_memory
+ comfy.model_management.load_models_gpu([model] + models, memory_required=memory_required, minimum_memory_required=minimum_memory_required, force_full_load=force_full_load)
real_model = model.model
return real_model, conds, models
diff --git a/comfy/weight_adapter/bypass.py b/comfy/weight_adapter/bypass.py
index d4aaf98ca..b9d5ec7d9 100644
--- a/comfy/weight_adapter/bypass.py
+++ b/comfy/weight_adapter/bypass.py
@@ -21,6 +21,7 @@ from typing import Optional, Union
import torch
import torch.nn as nn
+import comfy.model_management
from .base import WeightAdapterBase, WeightAdapterTrainBase
from comfy.patcher_extension import PatcherInjection
@@ -181,18 +182,21 @@ class BypassForwardHook:
)
return # Already injected
- # Move adapter weights to module's device to avoid CPU-GPU transfer on every forward
- device = None
+ # Move adapter weights to compute device (GPU)
+ # Use get_torch_device() instead of module.weight.device because
+ # with offloading, module weights may be on CPU while compute happens on GPU
+ device = comfy.model_management.get_torch_device()
+
+ # Get dtype from module weight if available
dtype = None
if hasattr(self.module, "weight") and self.module.weight is not None:
- device = self.module.weight.device
dtype = self.module.weight.dtype
- elif hasattr(self.module, "W_q"): # Quantized layers might use different attr
- device = self.module.W_q.device
- dtype = self.module.W_q.dtype
- if device is not None:
- self._move_adapter_weights_to_device(device, dtype)
+ # Only use dtype if it's a standard float type, not quantized
+ if dtype is not None and dtype not in (torch.float32, torch.float16, torch.bfloat16):
+ dtype = None
+
+ self._move_adapter_weights_to_device(device, dtype)
self.original_forward = self.module.forward
self.module.forward = self._bypass_forward
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index 024a89391..630eedc9f 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -4,6 +4,7 @@ import os
import numpy as np
import safetensors
import torch
+import torch.nn as nn
import torch.utils.checkpoint
from tqdm.auto import trange
from PIL import Image, ImageDraw, ImageFont
@@ -27,6 +28,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
"""
CFGGuider with modifications for training specific logic
"""
+
+ def __init__(self, *args, offloading=False, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.offloading = offloading
+
def outer_sample(
self,
noise,
@@ -45,9 +51,11 @@ class TrainGuider(comfy_extras.nodes_custom_sampler.Guider_Basic):
noise.shape,
self.conds,
self.model_options,
- force_full_load=True, # mirror behavior in TrainLoraNode.execute() to keep model loaded
+ force_full_load=not self.offloading,
+ force_offload=self.offloading,
)
)
+ torch.cuda.empty_cache()
device = self.model_patcher.load_device
if denoise_mask is not None:
@@ -404,16 +412,97 @@ def find_all_highest_child_module_with_forward(
return result
-def patch(m):
+def find_modules_at_depth(
+ model: nn.Module, depth: int = 1, result=None, current_depth=0, name=None
+) -> list[nn.Module]:
+ """
+ Find modules at a specific depth level for gradient checkpointing.
+
+ Args:
+ model: The model to search
+ depth: Target depth level (1 = top-level blocks, 2 = their children, etc.)
+ result: Accumulator for results
+ current_depth: Current recursion depth
+ name: Current module name for logging
+
+ Returns:
+ List of modules at the target depth
+ """
+ if result is None:
+ result = []
+ name = name or "root"
+
+ # Skip container modules (they don't have meaningful forward)
+ is_container = isinstance(model, (nn.ModuleList, nn.Sequential, nn.ModuleDict))
+ has_forward = hasattr(model, "forward") and not is_container
+
+ if has_forward:
+ current_depth += 1
+ if current_depth == depth:
+ result.append(model)
+ logging.debug(f"Found module at depth {depth}: {name} ({model.__class__.__name__})")
+ return result
+
+ # Recurse into children
+ for next_name, child in model.named_children():
+ find_modules_at_depth(child, depth, result, current_depth, f"{name}.{next_name}")
+
+ return result
+
+
+class OffloadCheckpointFunction(torch.autograd.Function):
+ """
+ Gradient checkpointing that works with weight offloading.
+
+ Forward: no_grad -> compute -> weights can be freed
+ Backward: enable_grad -> recompute -> backward -> weights can be freed
+
+ For single input, single output modules (Linear, Conv*).
+ """
+
+ @staticmethod
+ def forward(ctx, x: torch.Tensor, forward_fn):
+ ctx.save_for_backward(x)
+ ctx.forward_fn = forward_fn
+ with torch.no_grad():
+ return forward_fn(x)
+
+ @staticmethod
+ def backward(ctx, grad_out: torch.Tensor):
+ x, = ctx.saved_tensors
+ forward_fn = ctx.forward_fn
+
+ # Clear context early
+ ctx.forward_fn = None
+
+ with torch.enable_grad():
+ x_detached = x.detach().requires_grad_(True)
+ y = forward_fn(x_detached)
+ y.backward(grad_out)
+ grad_x = x_detached.grad
+
+ # Explicit cleanup
+ del y, x_detached, forward_fn
+
+ return grad_x, None
+
+
+def patch(m, offloading=False):
if not hasattr(m, "forward"):
return
org_forward = m.forward
- def fwd(args, kwargs):
- return org_forward(*args, **kwargs)
+ # Branch 1: Linear/Conv* -> offload-compatible checkpoint (single input/output)
+ if offloading and isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ def checkpointing_fwd(x):
+ return OffloadCheckpointFunction.apply(x, org_forward)
+ # Branch 2: Others -> standard checkpoint
+ else:
+ def fwd(args, kwargs):
+ return org_forward(*args, **kwargs)
- def checkpointing_fwd(*args, **kwargs):
- return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
+ def checkpointing_fwd(*args, **kwargs):
+ return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False)
m.org_forward = org_forward
m.forward = checkpointing_fwd
@@ -936,6 +1025,18 @@ class TrainLoraNode(io.ComfyNode):
default=True,
tooltip="Use gradient checkpointing for training.",
),
+ io.Int.Input(
+ "checkpoint_depth",
+ default=1,
+ min=1,
+ max=5,
+ tooltip="Depth level for gradient checkpointing.",
+ ),
+ io.Boolean.Input(
+ "offloading",
+ default=False,
+ tooltip="Depth level for gradient checkpointing.",
+ ),
io.Combo.Input(
"existing_lora",
options=folder_paths.get_filename_list("loras") + ["[None]"],
@@ -982,6 +1083,8 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype,
algorithm,
gradient_checkpointing,
+ checkpoint_depth,
+ offloading,
existing_lora,
bucket_mode,
bypass_mode,
@@ -1000,6 +1103,8 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = lora_dtype[0]
algorithm = algorithm[0]
gradient_checkpointing = gradient_checkpointing[0]
+ offloading = offloading[0]
+ checkpoint_depth = checkpoint_depth[0]
existing_lora = existing_lora[0]
bucket_mode = bucket_mode[0]
bypass_mode = bypass_mode[0]
@@ -1054,16 +1159,18 @@ class TrainLoraNode(io.ComfyNode):
# Setup gradient checkpointing
if gradient_checkpointing:
- for m in find_all_highest_child_module_with_forward(
- mp.model.diffusion_model
- ):
- patch(m)
+ modules_to_patch = find_modules_at_depth(
+ mp.model.diffusion_model, depth=checkpoint_depth
+ )
+ logging.info(f"Gradient checkpointing: patching {len(modules_to_patch)} modules at depth {checkpoint_depth}")
+ for m in modules_to_patch:
+ patch(m, offloading=offloading)
torch.cuda.empty_cache()
# With force_full_load=False we should be able to have offloading
# But for offloading in training we need custom AutoGrad hooks for fwd/bwd
comfy.model_management.load_models_gpu(
- [mp], memory_required=1e20, force_full_load=True
+ [mp], memory_required=1e20, force_full_load=not offloading
)
torch.cuda.empty_cache()
@@ -1100,7 +1207,7 @@ class TrainLoraNode(io.ComfyNode):
)
# Setup guider
- guider = TrainGuider(mp)
+ guider = TrainGuider(mp, offloading=offloading)
guider.set_conds(positive)
# Inject bypass hooks if bypass mode is enabled
@@ -1113,6 +1220,7 @@ class TrainLoraNode(io.ComfyNode):
# Run training loop
try:
+ comfy.model_management.in_training = True
_run_training_loop(
guider,
train_sampler,
@@ -1123,6 +1231,7 @@ class TrainLoraNode(io.ComfyNode):
multi_res,
)
finally:
+ comfy.model_management.in_training = False
# Eject bypass hooks if they were injected
if bypass_injections is not None:
for injection in bypass_injections:
@@ -1132,19 +1241,20 @@ class TrainLoraNode(io.ComfyNode):
unpatch(m)
del train_sampler, optimizer
- # Finalize adapters
+ for param in lora_sd:
+ lora_sd[param] = lora_sd[param].to(lora_dtype).detach()
+
for adapter in all_weight_adapters:
adapter.requires_grad_(False)
-
- for param in lora_sd:
- lora_sd[param] = lora_sd[param].to(lora_dtype)
+ del adapter
+ del all_weight_adapters
# mp in train node is highly specialized for training
# use it in inference will result in bad behavior so we don't return it
return io.NodeOutput(lora_sd, loss_map, steps + existing_steps)
-class LoraModelLoader(io.ComfyNode):#
+class LoraModelLoader(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
@@ -1166,6 +1276,11 @@ class LoraModelLoader(io.ComfyNode):#
max=100.0,
tooltip="How strongly to modify the diffusion model. This value can be negative.",
),
+ io.Boolean.Input(
+ "bypass",
+ default=False,
+ tooltip="When enabled, applies LoRA in bypass mode without modifying base model weights. Useful for training and when model weights are offloaded.",
+ ),
],
outputs=[
io.Model.Output(
@@ -1175,13 +1290,18 @@ class LoraModelLoader(io.ComfyNode):#
)
@classmethod
- def execute(cls, model, lora, strength_model):
+ def execute(cls, model, lora, strength_model, bypass=False):
if strength_model == 0:
return io.NodeOutput(model)
- model_lora, _ = comfy.sd.load_lora_for_models(
- model, None, lora, strength_model, 0
- )
+ if bypass:
+ model_lora, _ = comfy.sd.load_bypass_lora_for_models(
+ model, None, lora, strength_model, 0
+ )
+ else:
+ model_lora, _ = comfy.sd.load_lora_for_models(
+ model, None, lora, strength_model, 0
+ )
return io.NodeOutput(model_lora)
From 76a7fa96dbdc2eda89218601fe3aed5997df055f Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Tue, 10 Feb 2026 19:04:32 -0800
Subject: [PATCH 15/85] Make built in lora training work on anima. (#12402)
---
comfy/ldm/anima/model.py | 16 ++++++++++++++--
comfy/model_base.py | 12 ++++++++----
2 files changed, 22 insertions(+), 6 deletions(-)
diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py
index 2e6ed58fa..6fb51c4a4 100644
--- a/comfy/ldm/anima/model.py
+++ b/comfy/ldm/anima/model.py
@@ -195,8 +195,20 @@ class Anima(MiniTrainDIT):
super().__init__(*args, **kwargs)
self.llm_adapter = LLMAdapter(device=kwargs.get("device"), dtype=kwargs.get("dtype"), operations=kwargs.get("operations"))
- def preprocess_text_embeds(self, text_embeds, text_ids):
+ def preprocess_text_embeds(self, text_embeds, text_ids, t5xxl_weights=None):
if text_ids is not None:
- return self.llm_adapter(text_embeds, text_ids)
+ out = self.llm_adapter(text_embeds, text_ids)
+ if t5xxl_weights is not None:
+ out = out * t5xxl_weights
+
+ if out.shape[1] < 512:
+ out = torch.nn.functional.pad(out, (0, 0, 0, 512 - out.shape[1]))
+ return out
else:
return text_embeds
+
+ def forward(self, x, timesteps, context, **kwargs):
+ t5xxl_ids = kwargs.pop("t5xxl_ids", None)
+ if t5xxl_ids is not None:
+ context = self.preprocess_text_embeds(context, t5xxl_ids, t5xxl_weights=kwargs.pop("t5xxl_weights", None))
+ return super().forward(x, timesteps, context, **kwargs)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 858789b30..4a74cb1ce 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -1160,12 +1160,16 @@ class Anima(BaseModel):
device = kwargs["device"]
if cross_attn is not None:
if t5xxl_ids is not None:
- cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.unsqueeze(0).to(device=device))
if t5xxl_weights is not None:
- cross_attn *= t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
+ t5xxl_weights = t5xxl_weights.unsqueeze(0).unsqueeze(-1).to(cross_attn)
+ t5xxl_ids = t5xxl_ids.unsqueeze(0)
+
+ if torch.is_inference_mode_enabled(): # if not we are training
+ cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
+ else:
+ out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
+ out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
- if cross_attn.shape[1] < 512:
- cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, 0, 512 - cross_attn.shape[1]))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
From 2c7cef4a23c08e3f02a33c693d927158a15a11f6 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Wed, 11 Feb 2026 20:51:49 +0200
Subject: [PATCH 16/85] fix(api-nodes): retry on connection errors during
polling instead of aborting (#12393)
---
comfy_api_nodes/util/client.py | 8 ++++----
1 file changed, 4 insertions(+), 4 deletions(-)
diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py
index 8a1259506..391748e7a 100644
--- a/comfy_api_nodes/util/client.py
+++ b/comfy_api_nodes/util/client.py
@@ -143,9 +143,9 @@ async def poll_op(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
- max_retries_per_poll: int = 3,
+ max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,
- retry_backoff_per_poll: float = 2.0,
+ retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
@@ -240,9 +240,9 @@ async def poll_op_raw(
poll_interval: float = 5.0,
max_poll_attempts: int = 160,
timeout_per_poll: float = 120.0,
- max_retries_per_poll: int = 3,
+ max_retries_per_poll: int = 10,
retry_delay_per_poll: float = 1.0,
- retry_backoff_per_poll: float = 2.0,
+ retry_backoff_per_poll: float = 1.4,
estimated_duration: int | None = None,
cancel_endpoint: ApiEndpoint | None = None,
cancel_timeout: float = 10.0,
From 4993411fd9a43d642971925272c3748d9e058131 Mon Sep 17 00:00:00 2001
From: Benjamin Lu
Date: Wed, 11 Feb 2026 11:15:13 -0800
Subject: [PATCH 17/85] Dispatch desktop auto-bump when a ComfyUI release is
published (#12398)
* Dispatch desktop auto-bump on ComfyUI release publish
* Fix release webhook secret checks in step conditions
* Require desktop dispatch token in release webhook
* Apply suggestion from @Copilot
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
---------
Co-authored-by: Luke Mino-Altherr
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Jedrzej Kosinski
---
.github/workflows/release-webhook.yml | 36 +++++++++++++++++++++++++++
1 file changed, 36 insertions(+)
diff --git a/.github/workflows/release-webhook.yml b/.github/workflows/release-webhook.yml
index 6fceb7560..737e4c488 100644
--- a/.github/workflows/release-webhook.yml
+++ b/.github/workflows/release-webhook.yml
@@ -7,6 +7,8 @@ on:
jobs:
send-webhook:
runs-on: ubuntu-latest
+ env:
+ DESKTOP_REPO_DISPATCH_TOKEN: ${{ secrets.DESKTOP_REPO_DISPATCH_TOKEN }}
steps:
- name: Send release webhook
env:
@@ -106,3 +108,37 @@ jobs:
--fail --silent --show-error
echo "✅ Release webhook sent successfully"
+
+ - name: Send repository dispatch to desktop
+ env:
+ DISPATCH_TOKEN: ${{ env.DESKTOP_REPO_DISPATCH_TOKEN }}
+ RELEASE_TAG: ${{ github.event.release.tag_name }}
+ RELEASE_URL: ${{ github.event.release.html_url }}
+ run: |
+ set -euo pipefail
+
+ if [ -z "${DISPATCH_TOKEN:-}" ]; then
+ echo "::error::DESKTOP_REPO_DISPATCH_TOKEN is required but not set."
+ exit 1
+ fi
+
+ PAYLOAD="$(jq -n \
+ --arg release_tag "$RELEASE_TAG" \
+ --arg release_url "$RELEASE_URL" \
+ '{
+ event_type: "comfyui_release_published",
+ client_payload: {
+ release_tag: $release_tag,
+ release_url: $release_url
+ }
+ }')"
+
+ curl -fsSL \
+ -X POST \
+ -H "Accept: application/vnd.github+json" \
+ -H "Content-Type: application/json" \
+ -H "Authorization: Bearer ${DISPATCH_TOKEN}" \
+ https://api.github.com/repos/Comfy-Org/desktop/dispatches \
+ -d "$PAYLOAD"
+
+ echo "✅ Dispatched ComfyUI release ${RELEASE_TAG} to Comfy-Org/desktop"
From 2b7cc7e3b69127a81b9232d4e8305eb678fa3d0c Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Wed, 11 Feb 2026 21:30:19 +0200
Subject: [PATCH 18/85] [API Nodes] enable Magnific Upscalers (#12179)
* feat(api-nodes): enable Magnific Upscalers
* update price badges
---------
Co-authored-by: Jedrzej Kosinski
---
comfy_api_nodes/nodes_magnific.py | 62 +++++++++++++++++++++++++++----
1 file changed, 55 insertions(+), 7 deletions(-)
diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py
index 013e71cc8..83a581c5d 100644
--- a/comfy_api_nodes/nodes_magnific.py
+++ b/comfy_api_nodes/nodes_magnific.py
@@ -30,6 +30,30 @@ from comfy_api_nodes.util import (
validate_image_dimensions,
)
+_EUR_TO_USD = 1.19
+
+
+def _tier_price_eur(megapixels: float) -> float:
+ """Price in EUR for a single Magnific upscaling step based on input megapixels."""
+ if megapixels <= 1.3:
+ return 0.143
+ if megapixels <= 3.0:
+ return 0.286
+ if megapixels <= 6.4:
+ return 0.429
+ return 1.716
+
+
+def _calculate_magnific_upscale_price_usd(width: int, height: int, scale: int) -> float:
+ """Calculate total Magnific upscale price in USD for given input dimensions and scale factor."""
+ num_steps = int(math.log2(scale))
+ total_eur = 0.0
+ pixels = width * height
+ for _ in range(num_steps):
+ total_eur += _tier_price_eur(pixels / 1_000_000)
+ pixels *= 4
+ return round(total_eur * _EUR_TO_USD, 2)
+
class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
@classmethod
@@ -103,11 +127,20 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["scale_factor", "auto_downscale"]),
expr="""
(
- $max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
- {"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
+ $ad := widgets.auto_downscale;
+ $mins := $ad
+ ? {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.515}
+ : {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
+ $maxs := {"2x": 0.515, "4x": 0.844, "8x": 1.015, "16x": 1.187};
+ {
+ "type": "range_usd",
+ "min_usd": $lookup($mins, widgets.scale_factor),
+ "max_usd": $lookup($maxs, widgets.scale_factor),
+ "format": { "approximate": true }
+ }
)
""",
),
@@ -168,6 +201,10 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
+ final_height, final_width = get_image_dimensions(image)
+ actual_scale = int(scale_factor.rstrip("x"))
+ price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, actual_scale)
+
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler", method="POST"),
@@ -189,6 +226,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
+ price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -257,8 +295,14 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
depends_on=IO.PriceBadgeDepends(widgets=["scale_factor"]),
expr="""
(
- $max := widgets.scale_factor = "2x" ? 1.326 : 1.657;
- {"type": "range_usd", "min_usd": 0.11, "max_usd": $max}
+ $mins := {"2x": 0.172, "4x": 0.343, "8x": 0.515, "16x": 0.844};
+ $maxs := {"2x": 2.045, "4x": 2.545, "8x": 2.889, "16x": 3.06};
+ {
+ "type": "range_usd",
+ "min_usd": $lookup($mins, widgets.scale_factor),
+ "max_usd": $lookup($maxs, widgets.scale_factor),
+ "format": { "approximate": true }
+ }
)
""",
),
@@ -321,6 +365,9 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
f"Use a smaller input image or lower scale factor."
)
+ final_height, final_width = get_image_dimensions(image)
+ price_usd = _calculate_magnific_upscale_price_usd(final_width, final_height, requested_scale)
+
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/freepik/v1/ai/image-upscaler-precision-v2", method="POST"),
@@ -339,6 +386,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/freepik/v1/ai/image-upscaler-precision-v2/{initial_res.task_id}"),
response_model=TaskResponse,
status_extractor=lambda x: x.status,
+ price_extractor=lambda _: price_usd,
poll_interval=10.0,
max_poll_attempts=480,
)
@@ -877,8 +925,8 @@ class MagnificExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
- # MagnificImageUpscalerCreativeNode,
- # MagnificImageUpscalerPreciseV2Node,
+ MagnificImageUpscalerCreativeNode,
+ MagnificImageUpscalerPreciseV2Node,
MagnificImageStyleTransferNode,
MagnificImageRelightNode,
MagnificImageSkinEnhancerNode,
From d297a749a2fa3a34ebff898797feef161bcd64c6 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 11 Feb 2026 11:50:16 -0800
Subject: [PATCH 19/85] dynamic_vram: Fix windows Aimdo crash + Fix LLM
performance (#12408)
* model_management: lazy-cache aimdo_tensor
These tensors cosntructed from aimdo-allocations are CPU expensive to
make on the pytorch side. Add a cache version that will be valid with
signature match to fast path past whatever torch is doing.
* dynamic_vram: Minimize fast path CPU work
Move as much as possible inside the not resident if block and cache
the formed weight and bias rather than the flat intermediates. In
extreme layer weight rates this adds up.
---
comfy/model_management.py | 8 ++++++--
comfy/model_patcher.py | 2 --
comfy/ops.py | 21 ++++++++++++++-------
3 files changed, 20 insertions(+), 11 deletions(-)
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 304931eb0..38c3e482b 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1213,8 +1213,12 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
- v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, weight._v_tensor)[0]
- if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+ if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
+ v_tensor = weight._v_tensor
+ else:
+ raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
+ v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
+ weight._v_tensor = v_tensor
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 19c9031ea..224e218e3 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1542,7 +1542,6 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size)
- m._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(m._v, device_to)
allocated_size += v_weight_size
else:
@@ -1557,7 +1556,6 @@ class ModelPatcherDynamic(ModelPatcher):
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
- weight._v_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device_to)
weight._model_dtype = model_dtype
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
diff --git a/comfy/ops.py b/comfy/ops.py
index 33803b223..688937e43 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -83,14 +83,18 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
offload_stream = None
xfer_dest = None
- cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
- if signature is not None:
- xfer_dest = s._v_tensor
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
+ if signature is not None:
+ if resident:
+ weight = s._v_weight
+ bias = s._v_bias
+ else:
+ xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident:
+ cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
xfer_source = [ s.weight, s.bias ]
@@ -140,9 +144,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
- params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
- weight = params[0]
- bias = params[1]
+ params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
+ weight = params[0]
+ bias = params[1]
+ if signature is not None:
+ s._v_weight = weight
+ s._v_bias = bias
+ s._v_signature=signature
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
@@ -182,7 +190,6 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
- s._v_signature=signature
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)
From 2a4328d639810858aa625c7bfedb974a13a57abe Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 11 Feb 2026 11:53:42 -0800
Subject: [PATCH 20/85] ace15: Use dynamic_vram friendly trange (#12409)
Factor out the ksampler trange and use it in ACE LLM to prevent the
silent stall at 0 and rate distortion due to first-step model load.
---
comfy/k_diffusion/sampling.py | 32 ++------------------------------
comfy/text_encoders/ace15.py | 3 +--
comfy/utils.py | 27 +++++++++++++++++++++++++++
3 files changed, 30 insertions(+), 32 deletions(-)
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index c0c51d51a..6978eb717 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -1,12 +1,11 @@
import math
-import time
from functools import partial
from scipy import integrate
import torch
from torch import nn
import torchsde
-from tqdm.auto import trange as trange_, tqdm
+from tqdm.auto import tqdm
from . import utils
from . import deis
@@ -15,34 +14,7 @@ import comfy.model_patcher
import comfy.model_sampling
import comfy.memory_management
-
-
-def trange(*args, **kwargs):
- if comfy.memory_management.aimdo_allocator is None:
- return trange_(*args, **kwargs)
-
- pbar = trange_(*args, **kwargs, smoothing=1.0)
- pbar._i = 0
- pbar.set_postfix_str(" Model Initializing ... ")
-
- _update = pbar.update
-
- def warmup_update(n=1):
- pbar._i += 1
- if pbar._i == 1:
- pbar.i1_time = time.time()
- pbar.set_postfix_str(" Model Initialization complete! ")
- elif pbar._i == 2:
- #bring forward the effective start time based the the diff between first and second iteration
- #to attempt to remove load overhead from the final step rate estimate.
- pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
- pbar.set_postfix_str("")
-
- _update(n)
-
- pbar.update = warmup_update
- return pbar
-
+from comfy.utils import model_trange as trange
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
index 73697b3c1..b8198a820 100644
--- a/comfy/text_encoders/ace15.py
+++ b/comfy/text_encoders/ace15.py
@@ -3,7 +3,6 @@ import comfy.text_encoders.llama
from comfy import sd1_clip
import torch
import math
-from tqdm.auto import trange
import yaml
import comfy.utils
@@ -52,7 +51,7 @@ def sample_manual_loop_no_classes(
progress_bar = comfy.utils.ProgressBar(max_new_tokens)
- for step in trange(max_new_tokens, desc="LM sampling"):
+ for step in comfy.utils.model_trange(max_new_tokens, desc="LM sampling"):
outputs = model.transformer(None, attention_mask, embeds=embeds.to(execution_dtype), num_tokens=num_tokens, intermediate_output=None, dtype=execution_dtype, embeds_info=embeds_info, past_key_values=past_key_values)
next_token_logits = model.transformer.logits(outputs[0])[:, -1]
past_key_values = outputs[2]
diff --git a/comfy/utils.py b/comfy/utils.py
index edd80cebe..e0a94e2e1 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -27,6 +27,7 @@ from PIL import Image
import logging
import itertools
from torch.nn.functional import interpolate
+from tqdm.auto import trange
from einops import rearrange
from comfy.cli_args import args, enables_dynamic_vram
import json
@@ -1155,6 +1156,32 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)
+def model_trange(*args, **kwargs):
+ if comfy.memory_management.aimdo_allocator is None:
+ return trange(*args, **kwargs)
+
+ pbar = trange(*args, **kwargs, smoothing=1.0)
+ pbar._i = 0
+ pbar.set_postfix_str(" Model Initializing ... ")
+
+ _update = pbar.update
+
+ def warmup_update(n=1):
+ pbar._i += 1
+ if pbar._i == 1:
+ pbar.i1_time = time.time()
+ pbar.set_postfix_str(" Model Initialization complete! ")
+ elif pbar._i == 2:
+ #bring forward the effective start time based the the diff between first and second iteration
+ #to attempt to remove load overhead from the final step rate estimate.
+ pbar.start_t = pbar.i1_time - (time.time() - pbar.i1_time)
+ pbar.set_postfix_str("")
+
+ _update(n)
+
+ pbar.update = warmup_update
+ return pbar
+
PROGRESS_BAR_ENABLED = True
def set_progress_bar_enabled(enabled):
global PROGRESS_BAR_ENABLED
From 3fe61cedda090c744dcf6f579ed48744fa66ef5f Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Wed, 11 Feb 2026 11:54:02 -0800
Subject: [PATCH 21/85] model_patcher: guard against none model_dtype (#12410)
Handle the case where the _model_dtype exists but is none with the
intended fallback.
---
comfy/model_patcher.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 224e218e3..f278fccac 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1525,7 +1525,7 @@ class ModelPatcherDynamic(ModelPatcher):
setattr(m, param_key + "_function", weight_function)
geometry = weight
if not isinstance(weight, QuantizedTensor):
- model_dtype = getattr(m, param_key + "_comfy_model_dtype", weight.dtype)
+ model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry)
@@ -1551,7 +1551,7 @@ class ModelPatcherDynamic(ModelPatcher):
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
- model_dtype = getattr(m, param + "_comfy_model_dtype", weight.dtype)
+ model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
From e5ae670a4016d3698a806e7f840fcecc50639848 Mon Sep 17 00:00:00 2001
From: askmyteapot <62238146+askmyteapot@users.noreply.github.com>
Date: Thu, 12 Feb 2026 11:28:48 +1000
Subject: [PATCH 22/85] Update ace15.py to allow min_p sampling (#12373)
---
comfy/text_encoders/ace15.py | 15 ++++++++++++---
comfy_extras/nodes_ace.py | 5 +++--
2 files changed, 15 insertions(+), 5 deletions(-)
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
index b8198a820..0fdd4669f 100644
--- a/comfy/text_encoders/ace15.py
+++ b/comfy/text_encoders/ace15.py
@@ -16,6 +16,7 @@ def sample_manual_loop_no_classes(
temperature: float = 0.85,
top_p: float = 0.9,
top_k: int = None,
+ min_p: float = 0.000,
seed: int = 1,
min_tokens: int = 1,
max_new_tokens: int = 2048,
@@ -80,6 +81,12 @@ def sample_manual_loop_no_classes(
min_val = top_k_vals[..., -1, None]
cfg_logits[cfg_logits < min_val] = remove_logit_value
+ if min_p is not None and min_p > 0:
+ probs = torch.softmax(cfg_logits, dim=-1)
+ p_max = probs.max(dim=-1, keepdim=True).values
+ indices_to_remove = probs < (min_p * p_max)
+ cfg_logits[indices_to_remove] = remove_logit_value
+
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(cfg_logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
@@ -110,7 +117,7 @@ def sample_manual_loop_no_classes(
return output_audio_codes
-def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0):
+def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=1024, seed=0, cfg_scale=2.0, temperature=0.85, top_p=0.9, top_k=0, min_p=0.000):
positive = [[token for token, _ in inner_list] for inner_list in positive]
positive = positive[0]
@@ -134,7 +141,7 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
paddings = []
ids = [positive]
- return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+ return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
@@ -192,6 +199,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
temperature = kwargs.get("temperature", 0.85)
top_p = kwargs.get("top_p", 0.9)
top_k = kwargs.get("top_k", 0.0)
+ min_p = kwargs.get("min_p", 0.000)
duration = math.ceil(duration)
kwargs["duration"] = duration
@@ -239,6 +247,7 @@ class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
+ "min_p": min_p,
}
return out
@@ -299,7 +308,7 @@ class ACE15TEModel(torch.nn.Module):
lm_metadata = token_weight_pairs["lm_metadata"]
if lm_metadata["generate_audio_codes"]:
- audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["max_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"])
+ audio_codes = generate_audio_codes(getattr(self, self.lm_model, self.qwen3_06b), token_weight_pairs["lm_prompt"], token_weight_pairs["lm_prompt_negative"], min_tokens=lm_metadata["min_tokens"], max_tokens=lm_metadata["min_tokens"], seed=lm_metadata["seed"], cfg_scale=lm_metadata["cfg_scale"], temperature=lm_metadata["temperature"], top_p=lm_metadata["top_p"], top_k=lm_metadata["top_k"], min_p=lm_metadata["min_p"])
out["audio_codes"] = [audio_codes]
return base_out, None, out
diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py
index dde5bbd2a..9cf84ab4d 100644
--- a/comfy_extras/nodes_ace.py
+++ b/comfy_extras/nodes_ace.py
@@ -49,13 +49,14 @@ class TextEncodeAceStepAudio15(io.ComfyNode):
io.Float.Input("temperature", default=0.85, min=0.0, max=2.0, step=0.01, advanced=True),
io.Float.Input("top_p", default=0.9, min=0.0, max=2000.0, step=0.01, advanced=True),
io.Int.Input("top_k", default=0, min=0, max=100, advanced=True),
+ io.Float.Input("min_p", default=0.000, min=0.0, max=1.0, step=0.001, advanced=True),
],
outputs=[io.Conditioning.Output()],
)
@classmethod
- def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k) -> io.NodeOutput:
- tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k)
+ def execute(cls, clip, tags, lyrics, seed, bpm, duration, timesignature, language, keyscale, generate_audio_codes, cfg_scale, temperature, top_p, top_k, min_p) -> io.NodeOutput:
+ tokens = clip.tokenize(tags, lyrics=lyrics, bpm=bpm, duration=duration, timesignature=int(timesignature), language=language, keyscale=keyscale, seed=seed, generate_audio_codes=generate_audio_codes, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p)
conditioning = clip.encode_from_tokens_scheduled(tokens)
return io.NodeOutput(conditioning)
From 66c18522fbcde5b62731e3fb080a84b14e3dacfc Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Wed, 11 Feb 2026 19:12:16 -0800
Subject: [PATCH 23/85] Add a tip for common error. (#12414)
---
execution.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/execution.py b/execution.py
index 896862c6b..f549a2f0f 100644
--- a/execution.py
+++ b/execution.py
@@ -623,6 +623,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.")
comfy.model_management.unload_all_models()
+ elif isinstance(ex, RuntimeError) and ("mat1 and mat2 shapes" in str(ex)) and "Sampler" in class_type:
+ tips = "\n\nTIPS: If you have any \"Load CLIP\" or \"*CLIP Loader\" nodes in your workflow connected to this sampler node make sure the correct file(s) and type is selected."
error_details = {
"node_id": real_node_id,
From 4a93a62371b64f9d11a140a09faf985c48902d2e Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Thu, 12 Feb 2026 11:38:51 +0200
Subject: [PATCH 24/85] fix(api-nodes): add separate retry budget for 429 rate
limit responses (#12421)
---
comfy_api_nodes/util/client.py | 207 ++++++++++++-----------
comfy_api_nodes/util/download_helpers.py | 30 ++--
comfy_api_nodes/util/request_logger.py | 66 ++++----
comfy_api_nodes/util/upload_helpers.py | 55 +++---
4 files changed, 177 insertions(+), 181 deletions(-)
diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py
index 391748e7a..94886af7b 100644
--- a/comfy_api_nodes/util/client.py
+++ b/comfy_api_nodes/util/client.py
@@ -57,6 +57,7 @@ class _RequestConfig:
files: dict[str, Any] | list[tuple[str, Any]] | None
multipart_parser: Callable | None
max_retries: int
+ max_retries_on_rate_limit: int
retry_delay: float
retry_backoff: float
wait_label: str = "Waiting"
@@ -65,6 +66,7 @@ class _RequestConfig:
final_label_on_success: str | None = "Completed"
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
+ is_rate_limited: Callable[[int, Any], bool] | None = None
@dataclass
@@ -78,7 +80,7 @@ class _PollUIState:
active_since: float | None = None # start time of current active interval (None if queued)
-_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
+_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
@@ -103,6 +105,8 @@ async def sync_op(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
+ max_retries_on_rate_limit: int = 16,
+ is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> M:
raw = await sync_op_raw(
cls,
@@ -122,6 +126,8 @@ async def sync_op(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
monitor_progress=monitor_progress,
+ max_retries_on_rate_limit=max_retries_on_rate_limit,
+ is_rate_limited=is_rate_limited,
)
if not isinstance(raw, dict):
raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).")
@@ -194,6 +200,8 @@ async def sync_op_raw(
final_label_on_success: str | None = "Completed",
progress_origin_ts: float | None = None,
monitor_progress: bool = True,
+ max_retries_on_rate_limit: int = 16,
+ is_rate_limited: Callable[[int, Any], bool] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
@@ -222,6 +230,8 @@ async def sync_op_raw(
final_label_on_success=final_label_on_success,
progress_origin_ts=progress_origin_ts,
price_extractor=price_extractor,
+ max_retries_on_rate_limit=max_retries_on_rate_limit,
+ is_rate_limited=is_rate_limited,
)
return await _request_base(cfg, expect_binary=as_binary)
@@ -506,7 +516,7 @@ def _friendly_http_message(status: int, body: Any) -> str:
if status == 409:
return "There is a problem with your account. Please contact support@comfy.org."
if status == 429:
- return "Rate Limit Exceeded: Please try again later."
+ return "Rate Limit Exceeded: The server returned 429 after all retry attempts. Please wait and try again."
try:
if isinstance(body, dict):
err = body.get("error")
@@ -586,6 +596,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic()
attempt = 0
delay = cfg.retry_delay
+ rate_limit_attempts = 0
+ rate_limit_delay = cfg.retry_delay
operation_succeeded: bool = False
final_elapsed_seconds: int | None = None
extracted_price: float | None = None
@@ -653,17 +665,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
payload_headers["Content-Type"] = "application/json"
payload_kw["json"] = cfg.data or {}
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- request_headers=dict(payload_headers) if payload_headers else None,
- request_params=dict(params) if params else None,
- request_data=request_body_log,
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] request logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ request_headers=dict(payload_headers) if payload_headers else None,
+ request_params=dict(params) if params else None,
+ request_data=request_body_log,
+ )
req_coro = sess.request(method, url, params=params, **payload_kw)
req_task = asyncio.create_task(req_coro)
@@ -688,41 +697,33 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
body = await resp.json()
except (ContentTypeError, json.JSONDecodeError):
body = await resp.text()
- if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries:
+ should_retry = False
+ wait_time = 0.0
+ retry_label = ""
+ is_rl = resp.status == 429 or (
+ cfg.is_rate_limited is not None and cfg.is_rate_limited(resp.status, body)
+ )
+ if is_rl and rate_limit_attempts < cfg.max_retries_on_rate_limit:
+ rate_limit_attempts += 1
+ wait_time = min(rate_limit_delay, 30.0)
+ rate_limit_delay *= cfg.retry_backoff
+ retry_label = f"rate-limit retry {rate_limit_attempts} of {cfg.max_retries_on_rate_limit}"
+ should_retry = True
+ elif resp.status in _RETRY_STATUS and (attempt - rate_limit_attempts) <= cfg.max_retries:
+ wait_time = delay
+ delay *= cfg.retry_backoff
+ retry_label = f"retry {attempt - rate_limit_attempts} of {cfg.max_retries}"
+ should_retry = True
+
+ if should_retry:
logging.warning(
- "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).",
+ "HTTP %s %s -> %s. Waiting %.2fs (%s).",
method,
url,
resp.status,
- delay,
- attempt,
- cfg.max_retries,
+ wait_time,
+ retry_label,
)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=body,
- error_message=_friendly_http_message(resp.status, body),
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
-
- await sleep_with_interrupt(
- delay,
- cfg.node_cls,
- cfg.wait_label if cfg.monitor_progress else None,
- start_time if cfg.monitor_progress else None,
- cfg.estimated_total,
- display_callback=_display_time_progress if cfg.monitor_progress else None,
- )
- delay *= cfg.retry_backoff
- continue
- msg = _friendly_http_message(resp.status, body)
- try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -730,10 +731,27 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_content=body,
- error_message=msg,
+ error_message=f"HTTP {resp.status} ({retry_label}, will retry in {wait_time:.1f}s)",
)
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
+ await sleep_with_interrupt(
+ wait_time,
+ cfg.node_cls,
+ cfg.wait_label if cfg.monitor_progress else None,
+ start_time if cfg.monitor_progress else None,
+ cfg.estimated_total,
+ display_callback=_display_time_progress if cfg.monitor_progress else None,
+ )
+ continue
+ msg = _friendly_http_message(resp.status, body)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=body,
+ error_message=msg,
+ )
raise Exception(msg)
if expect_binary:
@@ -753,17 +771,14 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
bytes_payload = bytes(buff)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=bytes_payload,
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=bytes_payload,
+ )
return bytes_payload
else:
try:
@@ -780,45 +795,39 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=response_content_to_log,
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] response logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=response_content_to_log,
+ )
return payload
except ProcessingInterrupted:
logging.debug("Polling was interrupted by user")
raise
except (ClientError, OSError) as e:
- if attempt <= cfg.max_retries:
+ if (attempt - rate_limit_attempts) <= cfg.max_retries:
logging.warning(
"Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s",
method,
url,
delay,
- attempt,
+ attempt - rate_limit_attempts,
cfg.max_retries,
str(e),
)
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- request_headers=dict(payload_headers) if payload_headers else None,
- request_params=dict(params) if params else None,
- request_data=request_body_log,
- error_message=f"{type(e).__name__}: {str(e)} (will retry)",
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] request error logging failed: %s", _log_e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ request_headers=dict(payload_headers) if payload_headers else None,
+ request_params=dict(params) if params else None,
+ request_data=request_body_log,
+ error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+ )
await sleep_with_interrupt(
delay,
cfg.node_cls,
@@ -831,23 +840,6 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
continue
diag = await _diagnose_connectivity()
if not diag["internet_accessible"]:
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method=method,
- request_url=url,
- request_headers=dict(payload_headers) if payload_headers else None,
- request_params=dict(params) if params else None,
- request_data=request_body_log,
- error_message=f"LocalNetworkError: {str(e)}",
- )
- except Exception as _log_e:
- logging.debug("[DEBUG] final error logging failed: %s", _log_e)
- raise LocalNetworkError(
- "Unable to connect to the API server due to local network issues. "
- "Please check your internet connection and try again."
- ) from e
- try:
request_logger.log_request_response(
operation_id=operation_id,
request_method=method,
@@ -855,10 +847,21 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_headers=dict(payload_headers) if payload_headers else None,
request_params=dict(params) if params else None,
request_data=request_body_log,
- error_message=f"ApiServerError: {str(e)}",
+ error_message=f"LocalNetworkError: {str(e)}",
)
- except Exception as _log_e:
- logging.debug("[DEBUG] final error logging failed: %s", _log_e)
+ raise LocalNetworkError(
+ "Unable to connect to the API server due to local network issues. "
+ "Please check your internet connection and try again."
+ ) from e
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ request_headers=dict(payload_headers) if payload_headers else None,
+ request_params=dict(params) if params else None,
+ request_data=request_body_log,
+ error_message=f"ApiServerError: {str(e)}",
+ )
raise ApiServerError(
f"The API server at {default_base_url()} is currently unreachable. "
f"The service may be experiencing issues."
diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py
index 78bcf1fa1..aa588d038 100644
--- a/comfy_api_nodes/util/download_helpers.py
+++ b/comfy_api_nodes/util/download_helpers.py
@@ -167,27 +167,25 @@ async def download_url_to_bytesio(
with contextlib.suppress(Exception):
dest.seek(0)
- with contextlib.suppress(Exception):
- request_logger.log_request_response(
- operation_id=op_id,
- request_method="GET",
- request_url=url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content=f"[streamed {written} bytes to dest]",
- )
+ request_logger.log_request_response(
+ operation_id=op_id,
+ request_method="GET",
+ request_url=url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content=f"[streamed {written} bytes to dest]",
+ )
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (ClientError, OSError) as e:
if attempt <= max_retries:
- with contextlib.suppress(Exception):
- request_logger.log_request_response(
- operation_id=op_id,
- request_method="GET",
- request_url=url,
- error_message=f"{type(e).__name__}: {str(e)} (will retry)",
- )
+ request_logger.log_request_response(
+ operation_id=op_id,
+ request_method="GET",
+ request_url=url,
+ error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+ )
await sleep_with_interrupt(delay, cls, None, None, None)
delay *= retry_backoff
continue
diff --git a/comfy_api_nodes/util/request_logger.py b/comfy_api_nodes/util/request_logger.py
index e0cb4428d..fe0543d9b 100644
--- a/comfy_api_nodes/util/request_logger.py
+++ b/comfy_api_nodes/util/request_logger.py
@@ -8,7 +8,6 @@ from typing import Any
import folder_paths
-# Get the logger instance
logger = logging.getLogger(__name__)
@@ -91,38 +90,41 @@ def log_request_response(
Filenames are sanitized and length-limited for cross-platform safety.
If we still fail to write, we fall back to appending into api.log.
"""
- log_dir = get_log_directory()
- filepath = _build_log_filepath(log_dir, operation_id, request_url)
-
- log_content: list[str] = []
- log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
- log_content.append(f"Operation ID: {operation_id}")
- log_content.append("-" * 30 + " REQUEST " + "-" * 30)
- log_content.append(f"Method: {request_method}")
- log_content.append(f"URL: {request_url}")
- if request_headers:
- log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
- if request_params:
- log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
- if request_data is not None:
- log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
-
- log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
- if response_status_code is not None:
- log_content.append(f"Status Code: {response_status_code}")
- if response_headers:
- log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
- if response_content is not None:
- log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
- if error_message:
- log_content.append(f"Error:\n{error_message}")
-
try:
- with open(filepath, "w", encoding="utf-8") as f:
- f.write("\n".join(log_content))
- logger.debug("API log saved to: %s", filepath)
- except Exception as e:
- logger.error("Error writing API log to %s: %s", filepath, str(e))
+ log_dir = get_log_directory()
+ filepath = _build_log_filepath(log_dir, operation_id, request_url)
+
+ log_content: list[str] = []
+ log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
+ log_content.append(f"Operation ID: {operation_id}")
+ log_content.append("-" * 30 + " REQUEST " + "-" * 30)
+ log_content.append(f"Method: {request_method}")
+ log_content.append(f"URL: {request_url}")
+ if request_headers:
+ log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
+ if request_params:
+ log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
+ if request_data is not None:
+ log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
+
+ log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
+ if response_status_code is not None:
+ log_content.append(f"Status Code: {response_status_code}")
+ if response_headers:
+ log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
+ if response_content is not None:
+ log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
+ if error_message:
+ log_content.append(f"Error:\n{error_message}")
+
+ try:
+ with open(filepath, "w", encoding="utf-8") as f:
+ f.write("\n".join(log_content))
+ logger.debug("API log saved to: %s", filepath)
+ except Exception as e:
+ logger.error("Error writing API log to %s: %s", filepath, str(e))
+ except Exception as _log_e:
+ logging.debug("[DEBUG] log_request_response failed: %s", _log_e)
if __name__ == '__main__':
diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py
index 83d936ce1..7cc565263 100644
--- a/comfy_api_nodes/util/upload_helpers.py
+++ b/comfy_api_nodes/util/upload_helpers.py
@@ -255,17 +255,14 @@ async def upload_file(
monitor_task = asyncio.create_task(_monitor())
sess: aiohttp.ClientSession | None = None
try:
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method="PUT",
- request_url=upload_url,
- request_headers=headers or None,
- request_params=None,
- request_data=f"[File data {len(data)} bytes]",
- )
- except Exception as e:
- logging.debug("[DEBUG] upload request logging failed: %s", e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT",
+ request_url=upload_url,
+ request_headers=headers or None,
+ request_params=None,
+ request_data=f"[File data {len(data)} bytes]",
+ )
sess = aiohttp.ClientSession(timeout=timeout)
req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers)
@@ -311,31 +308,27 @@ async def upload_file(
delay *= retry_backoff
continue
raise Exception(f"Failed to upload (HTTP {resp.status}).")
- try:
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method="PUT",
- request_url=upload_url,
- response_status_code=resp.status,
- response_headers=dict(resp.headers),
- response_content="File uploaded successfully.",
- )
- except Exception as e:
- logging.debug("[DEBUG] upload response logging failed: %s", e)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT",
+ request_url=upload_url,
+ response_status_code=resp.status,
+ response_headers=dict(resp.headers),
+ response_content="File uploaded successfully.",
+ )
return
except asyncio.CancelledError:
raise ProcessingInterrupted("Task cancelled") from None
except (aiohttp.ClientError, OSError) as e:
if attempt <= max_retries:
- with contextlib.suppress(Exception):
- request_logger.log_request_response(
- operation_id=operation_id,
- request_method="PUT",
- request_url=upload_url,
- request_headers=headers or None,
- request_data=f"[File data {len(data)} bytes]",
- error_message=f"{type(e).__name__}: {str(e)} (will retry)",
- )
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT",
+ request_url=upload_url,
+ request_headers=headers or None,
+ request_data=f"[File data {len(data)} bytes]",
+ error_message=f"{type(e).__name__}: {str(e)} (will retry)",
+ )
await sleep_with_interrupt(
delay,
cls,
From 117e2143543dd649d47345e183748a82d48d12d3 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Thu, 12 Feb 2026 16:51:50 -0800
Subject: [PATCH 25/85] ModelPatcherDynamic: force load non leaf weights
(#12433)
The current behaviour of the default ModelPatcher is to .to a model
only if its fully loaded, which is how random non-leaf weights get
loaded in non-LowVRAM conditions.
The however means they never get loaded in dynamic_vram. In the
dynamic_vram case, force load them to the GPU.
---
comfy/model_patcher.py | 19 ++++++++++---------
1 file changed, 10 insertions(+), 9 deletions(-)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index f278fccac..b1d907ba4 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -679,18 +679,19 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
- def _load_list(self, prio_comfy_cast_weights=False):
+ def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
- params = []
- skip = False
- for name, param in m.named_parameters(recurse=False):
- params.append(name)
+ default = False
+ params = { name: param for name, param in m.named_parameters(recurse=False) }
for name, param in m.named_parameters(recurse=True):
if name not in params:
- skip = True # skip random weights in non leaf modules
+ default = True # default random weights in non leaf modules
break
- if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
+ if default and default_device is not None:
+ for param in params.values():
+ param.data = param.data.to(device=default_device)
+ if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
if hasattr(m, "comfy_cast_weights"):
@@ -1495,7 +1496,7 @@ class ModelPatcherDynamic(ModelPatcher):
#with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
- loading = self._load_list(prio_comfy_cast_weights=True)
+ loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)
for x in loading:
@@ -1579,7 +1580,7 @@ class ModelPatcherDynamic(ModelPatcher):
return 0 if vbar is None else vbar.free_memory(memory_to_free)
def partially_unload_ram(self, ram_to_unload):
- loading = self._load_list(prio_comfy_cast_weights=True)
+ loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
From ae79e33345dc893f8e0632c380c0e91dc09ac6e8 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Thu, 12 Feb 2026 16:56:42 -0800
Subject: [PATCH 26/85] llama: use a more efficient rope implementation
(#12434)
Get rid of the cat and unary negation and inplace add-cmul the two
halves of the rope. Precompute -sin once at the start of the model
rather than every transformer block.
This is slightly faster on both GPU and CPU bound setups.
---
comfy/text_encoders/llama.py | 25 ++++++++++++++-----------
1 file changed, 14 insertions(+), 11 deletions(-)
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index b6735d210..54f3d5595 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -355,13 +355,6 @@ class RMSNorm(nn.Module):
-def rotate_half(x):
- """Rotates half the hidden dims of the input."""
- x1 = x[..., : x.shape[-1] // 2]
- x2 = x[..., x.shape[-1] // 2 :]
- return torch.cat((-x2, x1), dim=-1)
-
-
def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_dims=None, device=None):
if not isinstance(theta, list):
theta = [theta]
@@ -390,20 +383,30 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_scale=None, rope_di
else:
cos = cos.unsqueeze(1)
sin = sin.unsqueeze(1)
- out.append((cos, sin))
+ sin_split = sin.shape[-1] // 2
+ out.append((cos, sin[..., : sin_split], -sin[..., sin_split :]))
if len(out) == 1:
return out[0]
return out
-
def apply_rope(xq, xk, freqs_cis):
org_dtype = xq.dtype
cos = freqs_cis[0]
sin = freqs_cis[1]
- q_embed = (xq * cos) + (rotate_half(xq) * sin)
- k_embed = (xk * cos) + (rotate_half(xk) * sin)
+ nsin = freqs_cis[2]
+
+ q_embed = (xq * cos)
+ q_split = q_embed.shape[-1] // 2
+ q_embed[..., : q_split].addcmul_(xq[..., q_split :], nsin)
+ q_embed[..., q_split :].addcmul_(xq[..., : q_split], sin)
+
+ k_embed = (xk * cos)
+ k_split = k_embed.shape[-1] // 2
+ k_embed[..., : k_split].addcmul_(xk[..., k_split :], nsin)
+ k_embed[..., k_split :].addcmul_(xk[..., : k_split], sin)
+
return q_embed.to(org_dtype), k_embed.to(org_dtype)
From e03fe8b5919a23a473cea6e53f916f7403c082a5 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Thu, 12 Feb 2026 20:29:12 -0800
Subject: [PATCH 27/85] Update command to install AMD stable linux pytorch.
(#12437)
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 96dc2904b..3ccdc9c19 100644
--- a/README.md
+++ b/README.md
@@ -227,7 +227,7 @@ Put your VAE in: models/vae
AMD users can install rocm and pytorch with pip if you don't have it already installed, this is the command to install the stable version:
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.4```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm7.1```
This is the command to install the nightly with ROCm 7.1 which might have some performance improvements:
From 8902907d7ab949ce42dd9b658b4a4582ed9fb630 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Fri, 13 Feb 2026 12:29:37 -0800
Subject: [PATCH 28/85] dynamic_vram: Training fixes (#12442)
---
comfy/model_patcher.py | 4 ++++
comfy_extras/nodes_train.py | 11 ++++++++++-
2 files changed, 14 insertions(+), 1 deletion(-)
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index b1d907ba4..67dce088e 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1561,6 +1561,8 @@ class ModelPatcherDynamic(ModelPatcher):
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
+ move_weight_functions(m, device_to)
+
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
self.model.device = device_to
@@ -1601,6 +1603,8 @@ class ModelPatcherDynamic(ModelPatcher):
if unpatch_weights:
self.partially_unload_ram(1e32)
self.partially_unload(None, 1e32)
+ for m in self.model.modules():
+ move_weight_functions(m, device_to)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py
index 630eedc9f..aa2d88673 100644
--- a/comfy_extras/nodes_train.py
+++ b/comfy_extras/nodes_train.py
@@ -1035,7 +1035,7 @@ class TrainLoraNode(io.ComfyNode):
io.Boolean.Input(
"offloading",
default=False,
- tooltip="Depth level for gradient checkpointing.",
+ tooltip="Offload the Model to RAM. Requires Bypass Mode.",
),
io.Combo.Input(
"existing_lora",
@@ -1124,6 +1124,15 @@ class TrainLoraNode(io.ComfyNode):
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
+ if mp.is_dynamic():
+ if not bypass_mode:
+ logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
+ bypass_mode = True
+ offloading = True
+ elif offloading:
+ if not bypass_mode:
+ logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
+
# Prepare latents and compute counts
latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode
From e1add563f9e89026e8c4e8825a2b279fbd67d23a Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 13 Feb 2026 12:35:13 -0800
Subject: [PATCH 29/85] Use torch RMSNorm for flux models and refactor hunyuan
video code. (#12432)
---
comfy/controlnet.py | 1 +
comfy/ldm/chroma/layers.py | 3 +-
comfy/ldm/chroma_radiance/layers.py | 8 ++---
comfy/ldm/flux/layers.py | 53 +++++++----------------------
comfy/ldm/flux/model.py | 3 +-
comfy/ldm/hunyuan_video/model.py | 11 +++---
comfy/lora_convert.py | 2 +-
comfy/model_detection.py | 18 +++++++---
comfy/supported_models.py | 32 +++++++++++++++--
comfy/utils.py | 12 +++----
10 files changed, 74 insertions(+), 69 deletions(-)
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 9e1e704e0..8336412f2 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -560,6 +560,7 @@ def load_controlnet_hunyuandit(controlnet_data, model_options={}):
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options={}):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = comfy.ldm.flux.controlnet.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+ sd = model_config.process_unet_state_dict(sd)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
index 2d5684348..df348a8ed 100644
--- a/comfy/ldm/chroma/layers.py
+++ b/comfy/ldm/chroma/layers.py
@@ -3,7 +3,6 @@ from torch import Tensor, nn
from comfy.ldm.flux.layers import (
MLPEmbedder,
- RMSNorm,
ModulationOut,
)
@@ -29,7 +28,7 @@ class Approximator(nn.Module):
super().__init__()
self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
- self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)])
+ self.norms = nn.ModuleList([operations.RMSNorm(hidden_dim, dtype=dtype, device=device) for x in range( n_layers)])
self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
@property
diff --git a/comfy/ldm/chroma_radiance/layers.py b/comfy/ldm/chroma_radiance/layers.py
index 3c7bc9b6b..08d31e0ba 100644
--- a/comfy/ldm/chroma_radiance/layers.py
+++ b/comfy/ldm/chroma_radiance/layers.py
@@ -4,8 +4,6 @@ from functools import lru_cache
import torch
from torch import nn
-from comfy.ldm.flux.layers import RMSNorm
-
class NerfEmbedder(nn.Module):
"""
@@ -145,7 +143,7 @@ class NerfGLUBlock(nn.Module):
# We now need to generate parameters for 3 matrices.
total_params = 3 * hidden_size_x**2 * mlp_ratio
self.param_generator = operations.Linear(hidden_size_s, total_params, dtype=dtype, device=device)
- self.norm = RMSNorm(hidden_size_x, dtype=dtype, device=device, operations=operations)
+ self.norm = operations.RMSNorm(hidden_size_x, dtype=dtype, device=device)
self.mlp_ratio = mlp_ratio
@@ -178,7 +176,7 @@ class NerfGLUBlock(nn.Module):
class NerfFinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, dtype=None, device=None, operations=None):
super().__init__()
- self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+ self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -190,7 +188,7 @@ class NerfFinalLayer(nn.Module):
class NerfFinalLayerConv(nn.Module):
def __init__(self, hidden_size: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
- self.norm = RMSNorm(hidden_size, dtype=dtype, device=device, operations=operations)
+ self.norm = operations.RMSNorm(hidden_size, dtype=dtype, device=device)
self.conv = operations.Conv2d(
in_channels=hidden_size,
out_channels=out_channels,
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 60f2bdae2..1f2975fb1 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -5,8 +5,6 @@ import torch
from torch import Tensor, nn
from .math import attention, rope
-import comfy.ops
-import comfy.ldm.common_dit
class EmbedND(nn.Module):
@@ -87,20 +85,12 @@ def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dt
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
)
-class RMSNorm(torch.nn.Module):
- def __init__(self, dim: int, dtype=None, device=None, operations=None):
- super().__init__()
- self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))
-
- def forward(self, x: Tensor):
- return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)
-
class QKNorm(torch.nn.Module):
def __init__(self, dim: int, dtype=None, device=None, operations=None):
super().__init__()
- self.query_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
- self.key_norm = RMSNorm(dim, dtype=dtype, device=device, operations=operations)
+ self.query_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
+ self.key_norm = operations.RMSNorm(dim, dtype=dtype, device=device)
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple:
q = self.query_norm(q)
@@ -169,7 +159,7 @@ class SiLUActivation(nn.Module):
class DoubleStreamBlock(nn.Module):
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -197,8 +187,6 @@ class DoubleStreamBlock(nn.Module):
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
- self.flipped_img_txt = flipped_img_txt
-
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
if self.modulation:
img_mod1, img_mod2 = self.img_mod(vec)
@@ -224,32 +212,17 @@ class DoubleStreamBlock(nn.Module):
del txt_qkv
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
- if self.flipped_img_txt:
- q = torch.cat((img_q, txt_q), dim=2)
- del img_q, txt_q
- k = torch.cat((img_k, txt_k), dim=2)
- del img_k, txt_k
- v = torch.cat((img_v, txt_v), dim=2)
- del img_v, txt_v
- # run actual attention
- attn = attention(q, k, v,
- pe=pe, mask=attn_mask, transformer_options=transformer_options)
- del q, k, v
+ q = torch.cat((txt_q, img_q), dim=2)
+ del txt_q, img_q
+ k = torch.cat((txt_k, img_k), dim=2)
+ del txt_k, img_k
+ v = torch.cat((txt_v, img_v), dim=2)
+ del txt_v, img_v
+ # run actual attention
+ attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
+ del q, k, v
- img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
- else:
- q = torch.cat((txt_q, img_q), dim=2)
- del txt_q, img_q
- k = torch.cat((txt_k, img_k), dim=2)
- del txt_k, img_k
- v = torch.cat((txt_v, img_v), dim=2)
- del txt_v, img_v
- # run actual attention
- attn = attention(q, k, v,
- pe=pe, mask=attn_mask, transformer_options=transformer_options)
- del q, k, v
-
- txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index f40c2a7a9..260ccad7e 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -16,7 +16,6 @@ from .layers import (
SingleStreamBlock,
timestep_embedding,
Modulation,
- RMSNorm
)
@dataclass
@@ -81,7 +80,7 @@ class Flux(nn.Module):
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
if params.txt_norm:
- self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
+ self.txt_norm = operations.RMSNorm(params.context_in_dim, dtype=dtype, device=device)
else:
self.txt_norm = None
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 55ab550f8..563f28f6b 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -241,7 +241,6 @@ class HunyuanVideo(nn.Module):
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
- flipped_img_txt=True,
dtype=dtype, device=device, operations=operations
)
for _ in range(params.depth)
@@ -378,14 +377,14 @@ class HunyuanVideo(nn.Module):
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
- ids = torch.cat((img_ids, txt_ids), dim=1)
+ ids = torch.cat((txt_ids, img_ids), dim=1)
pe = self.pe_embedder(ids)
img_len = img.shape[1]
if txt_mask is not None:
attn_mask_len = img_len + txt.shape[1]
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
- attn_mask[:, 0, img_len:] = txt_mask
+ attn_mask[:, 0, :txt.shape[1]] = txt_mask
else:
attn_mask = None
@@ -413,7 +412,7 @@ class HunyuanVideo(nn.Module):
if add is not None:
img += add
- img = torch.cat((img, txt), 1)
+ img = torch.cat((txt, img), 1)
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
@@ -435,9 +434,9 @@ class HunyuanVideo(nn.Module):
if i < len(control_o):
add = control_o[i]
if add is not None:
- img[:, : img_len] += add
+ img[:, txt.shape[1]: img_len + txt.shape[1]] += add
- img = img[:, : img_len]
+ img = img[:, txt.shape[1]: img_len + txt.shape[1]]
if ref_latent is not None:
img = img[:, ref_latent.shape[1]:]
diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py
index 9d8d21efe..749e81df3 100644
--- a/comfy/lora_convert.py
+++ b/comfy/lora_convert.py
@@ -5,7 +5,7 @@ import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux
sd_out = {}
for k in sd:
- k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.scale.set_weight"))
+ k_to = "diffusion_model.{}".format(k.replace(".lora_B.bias", ".diff_b").replace("_norm.scale", "_norm.set_weight"))
sd_out[k_to] = sd[k]
sd_out["diffusion_model.img_in.reshape_weight"] = torch.tensor([sd["img_in.lora_B.weight"].shape[0], sd["img_in.lora_A.weight"].shape[1]])
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index e8ad725df..30ea03e8e 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -19,6 +19,12 @@ def count_blocks(state_dict_keys, prefix_string):
count += 1
return count
+def any_suffix_in(keys, prefix, main, suffix_list=[]):
+ for x in suffix_list:
+ if "{}{}{}".format(prefix, main, x) in keys:
+ return True
+ return False
+
def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
context_dim = None
use_linear_in_transformer = False
@@ -186,7 +192,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["meanflow_sum"] = False
return dit_config
- if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
+ if any_suffix_in(state_dict_keys, key_prefix, 'double_blocks.0.img_attn.norm.key_norm.', ["weight", "scale"]) and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"])): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
dit_config = {}
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
dit_config["image_model"] = "flux2"
@@ -241,7 +247,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
- if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
+
+ if any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.0.norms.0.', ["weight", "scale"]) or any_suffix_in(state_dict_keys, key_prefix, 'distilled_guidance_layer.norms.0.', ["weight", "scale"]): #Chroma
dit_config["image_model"] = "chroma"
dit_config["in_channels"] = 64
dit_config["out_channels"] = 64
@@ -249,7 +256,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["out_dim"] = 3072
dit_config["hidden_dim"] = 5120
dit_config["n_layers"] = 5
- if f"{key_prefix}nerf_blocks.0.norm.scale" in state_dict_keys: #Chroma Radiance
+
+ if any_suffix_in(state_dict_keys, key_prefix, 'nerf_blocks.0.norm.', ["weight", "scale"]): #Chroma Radiance
dit_config["image_model"] = "chroma_radiance"
dit_config["in_channels"] = 3
dit_config["out_channels"] = 3
@@ -259,7 +267,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["nerf_depth"] = 4
dit_config["nerf_max_freqs"] = 8
dit_config["nerf_tile_size"] = 512
- dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
+ dit_config["nerf_final_head_type"] = "conv" if any_suffix_in(state_dict_keys, key_prefix, 'nerf_final_layer_conv.norm.', ["weight", "scale"]) else "linear"
dit_config["nerf_embedder_dtype"] = torch.float32
if "{}__x0__".format(key_prefix) in state_dict_keys: # x0 pred
dit_config["use_x0"] = True
@@ -268,7 +276,7 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
- dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
+ dit_config["txt_norm"] = any_suffix_in(state_dict_keys, key_prefix, 'txt_norm.', ["weight", "scale"])
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
dit_config["txt_ids_dims"] = [1, 2]
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index d33db7507..c28be1716 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -710,6 +710,15 @@ class Flux(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+ def process_unet_state_dict(self, state_dict):
+ out_sd = {}
+ for k in list(state_dict.keys()):
+ key_out = k
+ if key_out.endswith("_norm.scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
+ out_sd[key_out] = state_dict[k]
+ return out_sd
+
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
@@ -898,11 +907,13 @@ class HunyuanVideo(supported_models_base.BASE):
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
- key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
- key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
+ key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.weight").replace("_attn_k_norm.weight", "_attn.norm.key_norm.weight")
+ key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.weight").replace(".k_norm.weight", ".norm.key_norm.weight")
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
+ if key_out.endswith(".scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
out_sd[key_out] = state_dict[k]
return out_sd
@@ -1264,6 +1275,15 @@ class Hunyuan3Dv2(supported_models_base.BASE):
latent_format = latent_formats.Hunyuan3Dv2
+ def process_unet_state_dict(self, state_dict):
+ out_sd = {}
+ for k in list(state_dict.keys()):
+ key_out = k
+ if key_out.endswith(".scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
+ out_sd[key_out] = state_dict[k]
+ return out_sd
+
def process_unet_state_dict_for_saving(self, state_dict):
replace_prefix = {"": "model."}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
@@ -1341,6 +1361,14 @@ class Chroma(supported_models_base.BASE):
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+ def process_unet_state_dict(self, state_dict):
+ out_sd = {}
+ for k in list(state_dict.keys()):
+ key_out = k
+ if key_out.endswith(".scale"):
+ key_out = "{}.weight".format(key_out[:-len(".scale")])
+ out_sd[key_out] = state_dict[k]
+ return out_sd
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Chroma(self, device=device)
diff --git a/comfy/utils.py b/comfy/utils.py
index e0a94e2e1..d553a7c1b 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -675,10 +675,10 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"ff_context.linear_in.bias": "txt_mlp.0.bias",
"ff_context.linear_out.weight": "txt_mlp.2.weight",
"ff_context.linear_out.bias": "txt_mlp.2.bias",
- "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
- "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
- "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
- "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
+ "attn.norm_q.weight": "img_attn.norm.query_norm.weight",
+ "attn.norm_k.weight": "img_attn.norm.key_norm.weight",
+ "attn.norm_added_q.weight": "txt_attn.norm.query_norm.weight",
+ "attn.norm_added_k.weight": "txt_attn.norm.key_norm.weight",
}
for k in block_map:
@@ -701,8 +701,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
- "attn.norm_q.weight": "norm.query_norm.scale",
- "attn.norm_k.weight": "norm.key_norm.scale",
+ "attn.norm_q.weight": "norm.query_norm.weight",
+ "attn.norm_k.weight": "norm.key_norm.weight",
"attn.to_qkv_mlp_proj.weight": "linear1.weight", # Flux 2
"attn.to_out.weight": "linear2.weight", # Flux 2
}
From 831351a29e91ea758437227c2f3c915a6be6d1a6 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 13 Feb 2026 17:15:23 -0800
Subject: [PATCH 30/85] Support generating attention masks for left padded text
encoders. (#12454)
---
comfy/sd1_clip.py | 15 +++++++++++----
comfy/text_encoders/ace15.py | 8 +-------
2 files changed, 12 insertions(+), 11 deletions(-)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 4c817d468..b564d1529 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def process_tokens(self, tokens, device):
end_token = self.special_tokens.get("end", None)
+ pad_token = self.special_tokens.get("pad", -1)
if end_token is None:
- cmp_token = self.special_tokens.get("pad", -1)
+ cmp_token = pad_token
else:
cmp_token = end_token
@@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
other_embeds = []
eos = False
index = 0
+ left_pad = False
for y in x:
if isinstance(y, numbers.Integral):
- if eos:
+ token = int(y)
+ if index == 0 and token == pad_token:
+ left_pad = True
+
+ if eos or (left_pad and token == pad_token):
attention_mask.append(0)
else:
attention_mask.append(1)
- token = int(y)
+ left_pad = False
+
tokens_temp += [token]
- if not eos and token == cmp_token:
+ if not eos and token == cmp_token and not left_pad:
if end_token is None:
attention_mask[-1] = 0
eos = True
diff --git a/comfy/text_encoders/ace15.py b/comfy/text_encoders/ace15.py
index 0fdd4669f..f135d74c1 100644
--- a/comfy/text_encoders/ace15.py
+++ b/comfy/text_encoders/ace15.py
@@ -10,7 +10,6 @@ import comfy.utils
def sample_manual_loop_no_classes(
model,
ids=None,
- paddings=[],
execution_dtype=None,
cfg_scale: float = 2.0,
temperature: float = 0.85,
@@ -36,9 +35,6 @@ def sample_manual_loop_no_classes(
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
embeds_batch = embeds.shape[0]
- for i, t in enumerate(paddings):
- attention_mask[i, :t] = 0
- attention_mask[i, t:] = 1
output_audio_codes = []
past_key_values = []
@@ -135,13 +131,11 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
pos_pad = (len(negative) - len(positive))
positive = [model.special_tokens["pad"]] * pos_pad + positive
- paddings = [pos_pad, neg_pad]
ids = [positive, negative]
else:
- paddings = []
ids = [positive]
- return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
+ return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
From 726af73867c18c5ca8b980a2c28401d77e5b365a Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 13 Feb 2026 17:21:10 -0800
Subject: [PATCH 31/85] Fix some custom nodes. (#12455)
---
comfy/ldm/flux/layers.py | 2 ++
1 file changed, 2 insertions(+)
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 1f2975fb1..3518a1922 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -6,6 +6,8 @@ from torch import Tensor, nn
from .math import attention, rope
+# Fix import for some custom nodes, TODO: delete eventually.
+RMSNorm = None
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: list):
From 712efb466b9379e6761802c44027783d37d96a87 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Fri, 13 Feb 2026 18:56:54 -0800
Subject: [PATCH 32/85] Add left padding to LTXAV text encoder. (#12456)
---
comfy/text_encoders/lt.py | 4 +++-
1 file changed, 3 insertions(+), 1 deletion(-)
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index 3f87dfd6a..9cf87c0b2 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
- super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
@@ -97,6 +97,7 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs["gemma3_12b"]
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
+ out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
@@ -138,6 +139,7 @@ class LTXAVTEModel(torch.nn.Module):
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
+ num_tokens = max(num_tokens, 64)
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
From dc9822b7df4785e77690e93b0e09feaff01e2e12 Mon Sep 17 00:00:00 2001
From: krigeta <75309361+krigeta@users.noreply.github.com>
Date: Sat, 14 Feb 2026 08:53:52 +0530
Subject: [PATCH 33/85] Add working Qwen 2512 ControlNet (Fun ControlNet)
support (#12359)
---
comfy/controlnet.py | 73 +++++++++++
comfy/ldm/qwen_image/controlnet.py | 190 +++++++++++++++++++++++++++++
2 files changed, 263 insertions(+)
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 8336412f2..ba670b16d 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -297,6 +297,30 @@ class ControlNet(ControlBase):
self.model_sampling_current = None
super().cleanup()
+
+class QwenFunControlNet(ControlNet):
+ def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
+ # Fun checkpoints are more sensitive to high strengths in the generic
+ # ControlNet merge path. Use a soft response curve so strength=1.0 stays
+ # unchanged while >1 grows more gently.
+ original_strength = self.strength
+ self.strength = math.sqrt(max(self.strength, 0.0))
+ try:
+ return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
+ finally:
+ self.strength = original_strength
+
+ def pre_run(self, model, percent_to_timestep_function):
+ super().pre_run(model, percent_to_timestep_function)
+ self.set_extra_arg("base_model", model.diffusion_model)
+
+ def copy(self):
+ c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+ c.control_model = self.control_model
+ c.control_model_wrapped = self.control_model_wrapped
+ self.copy_to(c)
+ return c
+
class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
@@ -606,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
+
+def load_controlnet_qwen_fun(sd, model_options={}):
+ load_device = comfy.model_management.get_torch_device()
+ weight_dtype = comfy.utils.weight_dtype(sd)
+ unet_dtype = model_options.get("dtype", weight_dtype)
+ manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+ operations = model_options.get("custom_operations", None)
+ if operations is None:
+ operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+ in_features = sd["control_img_in.weight"].shape[1]
+ inner_dim = sd["control_img_in.weight"].shape[0]
+
+ block_weight = sd["control_blocks.0.attn.to_q.weight"]
+ attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
+ num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
+
+ model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
+ control_in_features=in_features,
+ inner_dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ num_control_blocks=5,
+ main_model_double=60,
+ injection_layers=(0, 12, 24, 36, 48),
+ operations=operations,
+ device=comfy.model_management.unet_offload_device(),
+ dtype=unet_dtype,
+ )
+ model = controlnet_load_state_dict(model, sd)
+
+ latent_format = comfy.latent_formats.Wan21()
+ control = QwenFunControlNet(
+ model,
+ compression_ratio=1,
+ latent_format=latent_format,
+ # Fun checkpoints already expect their own 33-channel context handling.
+ # Enabling generic concat_mask injects an extra mask channel at apply-time
+ # and breaks the intended fallback packing path.
+ concat_mask=False,
+ load_device=load_device,
+ manual_cast_dtype=manual_cast_dtype,
+ extra_conds=[],
+ )
+ return control
+
def convert_mistoline(sd):
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
@@ -683,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
+ elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
+ return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py
index a6d408104..c0aae9240 100644
--- a/comfy/ldm/qwen_image/controlnet.py
+++ b/comfy/ldm/qwen_image/controlnet.py
@@ -2,6 +2,196 @@ import torch
import math
from .model import QwenImageTransformer2DModel
+from .model import QwenImageTransformerBlock
+
+
+class QwenImageFunControlBlock(QwenImageTransformerBlock):
+ def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
+ super().__init__(
+ dim=dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ self.has_before_proj = has_before_proj
+ if has_before_proj:
+ self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+ self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
+
+
+class QwenImageFunControlNetModel(torch.nn.Module):
+ def __init__(
+ self,
+ control_in_features=132,
+ inner_dim=3072,
+ num_attention_heads=24,
+ attention_head_dim=128,
+ num_control_blocks=5,
+ main_model_double=60,
+ injection_layers=(0, 12, 24, 36, 48),
+ dtype=None,
+ device=None,
+ operations=None,
+ ):
+ super().__init__()
+ self.dtype = dtype
+ self.main_model_double = main_model_double
+ self.injection_layers = tuple(injection_layers)
+ # Keep base hint scaling at 1.0 so user-facing strength behaves similarly
+ # to the reference Gen2/VideoX implementation around strength=1.
+ self.hint_scale = 1.0
+ self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
+
+ self.control_blocks = torch.nn.ModuleList([])
+ for i in range(num_control_blocks):
+ self.control_blocks.append(
+ QwenImageFunControlBlock(
+ dim=inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ has_before_proj=(i == 0),
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ )
+
+ def _process_hint_tokens(self, hint):
+ if hint is None:
+ return None
+ if hint.ndim == 4:
+ hint = hint.unsqueeze(2)
+
+ # Fun checkpoints are trained with 33 latent channels before 2x2 packing:
+ # [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
+ # Default behavior (no inpaint input in stock Apply ControlNet) should use
+ # zeros for mask/inpaint branches, matching VideoX fallback semantics.
+ expected_c = self.control_img_in.weight.shape[1] // 4
+ if hint.shape[1] == 16 and expected_c == 33:
+ zeros_mask = torch.zeros_like(hint[:, :1])
+ zeros_inpaint = torch.zeros_like(hint)
+ hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
+
+ bs, c, t, h, w = hint.shape
+ hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
+ orig_shape = hidden_states.shape
+ hidden_states = hidden_states.view(
+ orig_shape[0],
+ orig_shape[1],
+ orig_shape[-3],
+ orig_shape[-2] // 2,
+ 2,
+ orig_shape[-1] // 2,
+ 2,
+ )
+ hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
+ hidden_states = hidden_states.reshape(
+ bs,
+ t * ((h + 1) // 2) * ((w + 1) // 2),
+ c * 4,
+ )
+
+ expected_in = self.control_img_in.weight.shape[1]
+ cur_in = hidden_states.shape[-1]
+ if cur_in < expected_in:
+ pad = torch.zeros(
+ (hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
+ device=hidden_states.device,
+ dtype=hidden_states.dtype,
+ )
+ hidden_states = torch.cat([hidden_states, pad], dim=-1)
+ elif cur_in > expected_in:
+ hidden_states = hidden_states[:, :, :expected_in]
+
+ return hidden_states
+
+ def forward(
+ self,
+ x,
+ timesteps,
+ context,
+ attention_mask=None,
+ guidance: torch.Tensor = None,
+ hint=None,
+ transformer_options={},
+ base_model=None,
+ **kwargs,
+ ):
+ if base_model is None:
+ raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
+
+ encoder_hidden_states_mask = attention_mask
+ # Keep attention mask disabled inside Fun control blocks to mirror
+ # VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
+ encoder_hidden_states_mask = None
+
+ hidden_states, img_ids, _ = base_model.process_img(x)
+ hint_tokens = self._process_hint_tokens(hint)
+ if hint_tokens is None:
+ raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
+
+ if hint_tokens.shape[1] != hidden_states.shape[1]:
+ max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
+ hint_tokens = hint_tokens[:, :max_tokens]
+ hidden_states = hidden_states[:, :max_tokens]
+ img_ids = img_ids[:, :max_tokens]
+
+ txt_start = round(
+ max(
+ ((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+ ((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
+ )
+ )
+ txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
+
+ hidden_states = base_model.img_in(hidden_states)
+ encoder_hidden_states = base_model.txt_norm(context)
+ encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
+
+ if guidance is not None:
+ guidance = guidance * 1000
+
+ temb = (
+ base_model.time_text_embed(timesteps, hidden_states)
+ if guidance is None
+ else base_model.time_text_embed(timesteps, guidance, hidden_states)
+ )
+
+ c = self.control_img_in(hint_tokens)
+
+ for i, block in enumerate(self.control_blocks):
+ if i == 0:
+ c_in = block.before_proj(c) + hidden_states
+ all_c = []
+ else:
+ all_c = list(torch.unbind(c, dim=0))
+ c_in = all_c.pop(-1)
+
+ encoder_hidden_states, c_out = block(
+ hidden_states=c_in,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_states_mask=encoder_hidden_states_mask,
+ temb=temb,
+ image_rotary_emb=image_rotary_emb,
+ transformer_options=transformer_options,
+ )
+
+ c_skip = block.after_proj(c_out) * self.hint_scale
+ all_c += [c_skip, c_out]
+ c = torch.stack(all_c, dim=0)
+
+ hints = torch.unbind(c, dim=0)[:-1]
+
+ controlnet_block_samples = [None] * self.main_model_double
+ for local_idx, base_idx in enumerate(self.injection_layers):
+ if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
+ controlnet_block_samples[base_idx] = hints[local_idx]
+
+ return {"input": controlnet_block_samples}
class QwenImageControlNetModel(QwenImageTransformer2DModel):
From df1e5e85142746a745a56572b705406b273a594c Mon Sep 17 00:00:00 2001
From: Christian Byrne
Date: Sat, 14 Feb 2026 11:01:10 -0800
Subject: [PATCH 34/85] Update frontend package to 1.38.14 (#12469)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 7de6a413c..e939e486a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.38.13
+comfyui-frontend-package==1.38.14
comfyui-workflow-templates==0.8.38
comfyui-embedded-docs==0.4.1
torch
From e1ede29d827d573262caede8aeb6cbc98c323c81 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sat, 14 Feb 2026 19:53:52 -0800
Subject: [PATCH 35/85] Remove unsafe pickle loading code that was used on
pytorch older than 2.4 (#12473)
ComfyUI hasn't started on pytorch 2.4 since last month.
---
comfy/checkpoint_pickle.py | 13 -------------
comfy/utils.py | 25 +++++++++++--------------
2 files changed, 11 insertions(+), 27 deletions(-)
delete mode 100644 comfy/checkpoint_pickle.py
diff --git a/comfy/checkpoint_pickle.py b/comfy/checkpoint_pickle.py
deleted file mode 100644
index 206551d3c..000000000
--- a/comfy/checkpoint_pickle.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import pickle
-
-load = pickle.load
-
-class Empty:
- pass
-
-class Unpickler(pickle.Unpickler):
- def find_class(self, module, name):
- #TODO: safe unpickle
- if module.startswith("pytorch_lightning"):
- return Empty
- return super().find_class(module, name)
diff --git a/comfy/utils.py b/comfy/utils.py
index d553a7c1b..c1ce540b5 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -20,7 +20,7 @@
import torch
import math
import struct
-import comfy.checkpoint_pickle
+import comfy.memory_management
import safetensors.torch
import numpy as np
from PIL import Image
@@ -38,26 +38,26 @@ import warnings
MMAP_TORCH_FILES = args.mmap_torch_files
DISABLE_MMAP = args.disable_mmap
-ALWAYS_SAFE_LOAD = False
-if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
+
+if True: # ckpt/pt file whitelist for safe loading of old sd files
class ModelCheckpoint:
pass
ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
def scalar(*args, **kwargs):
- from numpy.core.multiarray import scalar as sc
- return sc(*args, **kwargs)
+ return None
scalar.__module__ = "numpy.core.multiarray"
from numpy import dtype
from numpy.dtypes import Float64DType
- from _codecs import encode
+
+ def encode(*args, **kwargs): # no longer necessary on newer torch
+ return None
+ encode.__module__ = "_codecs"
torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
- ALWAYS_SAFE_LOAD = True
logging.info("Checkpoint files will always be loaded safely.")
-else:
- logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
+
# Current as of safetensors 0.7.0
_TYPES = {
@@ -140,11 +140,8 @@ def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if MMAP_TORCH_FILES:
torch_args["mmap"] = True
- if safe_load or ALWAYS_SAFE_LOAD:
- pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
- else:
- logging.warning("WARNING: loading {} unsafely, upgrade your pytorch to 2.4 or newer to load this file safely.".format(ckpt))
- pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
+ pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
+
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
From ce4a1ab48d9f723eeaac37f88dde55086b1f233f Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Sun, 15 Feb 2026 11:31:59 +0200
Subject: [PATCH 36/85] chore(api-nodes): remove "gpt-4o" model (#12467)
---
comfy_api_nodes/nodes_openai.py | 6 ------
1 file changed, 6 deletions(-)
diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py
index f05aaab7b..332107a82 100644
--- a/comfy_api_nodes/nodes_openai.py
+++ b/comfy_api_nodes/nodes_openai.py
@@ -43,7 +43,6 @@ class SupportedOpenAIModel(str, Enum):
o1 = "o1"
o3 = "o3"
o1_pro = "o1-pro"
- gpt_4o = "gpt-4o"
gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano"
@@ -649,11 +648,6 @@ class OpenAIChatNode(IO.ComfyNode):
"usd": [0.01, 0.04],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
- : $contains($m, "gpt-4o") ? {
- "type": "list_usd",
- "usd": [0.0025, 0.01],
- "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
- }
: $contains($m, "gpt-4.1-nano") ? {
"type": "list_usd",
"usd": [0.0001, 0.0004],
From 596ed686919f11f75be3cf9a79977a07d64002c5 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Sun, 15 Feb 2026 02:12:30 -0800
Subject: [PATCH 37/85] Node Replacement API (#12014)
---
app/node_replace_manager.py | 105 ++++++++++++++++++++++++++
comfy_api/feature_flags.py | 1 +
comfy_api/latest/__init__.py | 13 +++-
comfy_api/latest/_io.py | 63 ++++++++++++++++
comfy_extras/nodes_post_processing.py | 1 +
comfy_extras/nodes_replacements.py | 103 +++++++++++++++++++++++++
nodes.py | 2 +
server.py | 5 ++
8 files changed, 291 insertions(+), 2 deletions(-)
create mode 100644 app/node_replace_manager.py
create mode 100644 comfy_extras/nodes_replacements.py
diff --git a/app/node_replace_manager.py b/app/node_replace_manager.py
new file mode 100644
index 000000000..03b603c70
--- /dev/null
+++ b/app/node_replace_manager.py
@@ -0,0 +1,105 @@
+from __future__ import annotations
+
+from aiohttp import web
+
+from typing import TYPE_CHECKING, TypedDict
+if TYPE_CHECKING:
+ from comfy_api.latest._io_public import NodeReplace
+
+from comfy_execution.graph_utils import is_link
+import nodes
+
+class NodeStruct(TypedDict):
+ inputs: dict[str, str | int | float | bool | tuple[str, int]]
+ class_type: str
+ _meta: dict[str, str]
+
+def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
+ new_node_struct = node_struct.copy()
+ if empty_inputs:
+ new_node_struct["inputs"] = {}
+ else:
+ new_node_struct["inputs"] = node_struct["inputs"].copy()
+ new_node_struct["_meta"] = node_struct["_meta"].copy()
+ return new_node_struct
+
+
+class NodeReplaceManager:
+ """Manages node replacement registrations."""
+
+ def __init__(self):
+ self._replacements: dict[str, list[NodeReplace]] = {}
+
+ def register(self, node_replace: NodeReplace):
+ """Register a node replacement mapping."""
+ self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
+
+ def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
+ """Get replacements for an old node ID."""
+ return self._replacements.get(old_node_id)
+
+ def has_replacement(self, old_node_id: str) -> bool:
+ """Check if a replacement exists for an old node ID."""
+ return old_node_id in self._replacements
+
+ def apply_replacements(self, prompt: dict[str, NodeStruct]):
+ connections: dict[str, list[tuple[str, str, int]]] = {}
+ need_replacement: set[str] = set()
+ for node_number, node_struct in prompt.items():
+ class_type = node_struct["class_type"]
+ # need replacement if not in NODE_CLASS_MAPPINGS and has replacement
+ if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
+ need_replacement.add(node_number)
+ # keep track of connections
+ for input_id, input_value in node_struct["inputs"].items():
+ if is_link(input_value):
+ conn_number = input_value[0]
+ connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
+ for node_number in need_replacement:
+ node_struct = prompt[node_number]
+ class_type = node_struct["class_type"]
+ replacements = self.get_replacement(class_type)
+ if replacements is None:
+ continue
+ # just use the first replacement
+ replacement = replacements[0]
+ new_node_id = replacement.new_node_id
+ # if replacement is not a valid node, skip trying to replace it as will only cause confusion
+ if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
+ continue
+ # first, replace node id (class_type)
+ new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
+ new_node_struct["class_type"] = new_node_id
+ # TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
+ # second, replace inputs
+ if replacement.input_mapping is not None:
+ for input_map in replacement.input_mapping:
+ if "set_value" in input_map:
+ new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
+ elif "old_id" in input_map:
+ new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
+ # finalize input replacement
+ prompt[node_number] = new_node_struct
+ # third, replace outputs
+ if replacement.output_mapping is not None:
+ # re-mapping outputs requires changing the input values of nodes that receive connections from this one
+ if node_number in connections:
+ for conns in connections[node_number]:
+ conn_node_number, conn_input_id, old_output_idx = conns
+ for output_map in replacement.output_mapping:
+ if output_map["old_idx"] == old_output_idx:
+ new_output_idx = output_map["new_idx"]
+ previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
+ previous_input[1] = new_output_idx
+
+ def as_dict(self):
+ """Serialize all replacements to dict."""
+ return {
+ k: [v.as_dict() for v in v_list]
+ for k, v_list in self._replacements.items()
+ }
+
+ def add_routes(self, routes):
+ @routes.get("/node_replacements")
+ async def get_node_replacements(request):
+ return web.json_response(self.as_dict())
diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py
index de167f037..a90a5ca40 100644
--- a/comfy_api/feature_flags.py
+++ b/comfy_api/feature_flags.py
@@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
+ "node_replacements": True,
}
diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py
index 8542a1dbc..f2399422b 100644
--- a/comfy_api/latest/__init__.py
+++ b/comfy_api/latest/__init__.py
@@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest"
STABLE = False
+ def __init__(self):
+ super().__init__()
+ self.node_replacement = self.NodeReplacement()
+ self.execution = self.Execution()
+
+ class NodeReplacement(ProxiedSingleton):
+ async def register(self, node_replace: io.NodeReplace) -> None:
+ """Register a node replacement mapping."""
+ from server import PromptServer
+ PromptServer.instance.node_replace_manager.register(node_replace)
+
class Execution(ProxiedSingleton):
async def set_progress(
self,
@@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display,
)
- execution: Execution
-
class ComfyExtension(ABC):
async def on_load(self) -> None:
"""
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 93cf482ca..95d79c035 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -2030,6 +2030,68 @@ class _UIOutput(ABC):
...
+class InputMapOldId(TypedDict):
+ """Map an old node input to a new node input by ID."""
+ new_id: str
+ old_id: str
+
+class InputMapSetValue(TypedDict):
+ """Set a specific value for a new node input."""
+ new_id: str
+ set_value: Any
+
+InputMap = InputMapOldId | InputMapSetValue
+"""
+Input mapping for node replacement. Type is inferred by dictionary keys:
+- {"new_id": str, "old_id": str} - maps old input to new input
+- {"new_id": str, "set_value": Any} - sets a specific value for new input
+"""
+
+class OutputMap(TypedDict):
+ """Map outputs of node replacement via indexes."""
+ new_idx: int
+ old_idx: int
+
+class NodeReplace:
+ """
+ Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
+
+ Also supports assigning specific values to the input widgets of the new node.
+
+ Args:
+ new_node_id: The class name of the new replacement node.
+ old_node_id: The class name of the deprecated node.
+ old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
+ connected. The workflow JSON stores widget values by their relative position index,
+ not by ID. This list maps those positional indexes to input IDs, enabling the
+ replacement system to correctly identify widget values during node migration.
+ input_mapping: List of input mappings from old node to new node.
+ output_mapping: List of output mappings from old node to new node.
+ """
+ def __init__(self,
+ new_node_id: str,
+ old_node_id: str,
+ old_widget_ids: list[str] | None=None,
+ input_mapping: list[InputMap] | None=None,
+ output_mapping: list[OutputMap] | None=None,
+ ):
+ self.new_node_id = new_node_id
+ self.old_node_id = old_node_id
+ self.old_widget_ids = old_widget_ids
+ self.input_mapping = input_mapping
+ self.output_mapping = output_mapping
+
+ def as_dict(self):
+ """Create serializable representation of the node replacement."""
+ return {
+ "new_node_id": self.new_node_id,
+ "old_node_id": self.old_node_id,
+ "old_widget_ids": self.old_widget_ids,
+ "input_mapping": list(self.input_mapping) if self.input_mapping else None,
+ "output_mapping": list(self.output_mapping) if self.output_mapping else None,
+ }
+
+
__all__ = [
"FolderType",
"UploadType",
@@ -2121,4 +2183,5 @@ __all__ = [
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
+ "NodeReplace",
]
diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index a52a90e2c..66dac10b1 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -655,6 +655,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
batched = batch_masks(values)
return io.NodeOutput(batched)
+
class PostProcessingExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
diff --git a/comfy_extras/nodes_replacements.py b/comfy_extras/nodes_replacements.py
new file mode 100644
index 000000000..7684e854c
--- /dev/null
+++ b/comfy_extras/nodes_replacements.py
@@ -0,0 +1,103 @@
+from comfy_api.latest import ComfyExtension, io, ComfyAPI
+
+api = ComfyAPI()
+
+
+async def register_replacements():
+ """Register all built-in node replacements."""
+ await register_replacements_longeredge()
+ await register_replacements_batchimages()
+ await register_replacements_upscaleimage()
+ await register_replacements_controlnet()
+ await register_replacements_load3d()
+ await register_replacements_preview3d()
+ await register_replacements_svdimg2vid()
+ await register_replacements_conditioningavg()
+
+async def register_replacements_longeredge():
+ # No dynamic inputs here
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="ImageScaleToMaxDimension",
+ old_node_id="ResizeImagesByLongerEdge",
+ old_widget_ids=["longer_edge"],
+ input_mapping=[
+ {"new_id": "image", "old_id": "images"},
+ {"new_id": "largest_size", "old_id": "longer_edge"},
+ {"new_id": "upscale_method", "set_value": "lanczos"},
+ ],
+ # just to test the frontend output_mapping code, does nothing really here
+ output_mapping=[{"new_idx": 0, "old_idx": 0}],
+ ))
+
+async def register_replacements_batchimages():
+ # BatchImages node uses Autogrow
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="BatchImagesNode",
+ old_node_id="ImageBatch",
+ input_mapping=[
+ {"new_id": "images.image0", "old_id": "image1"},
+ {"new_id": "images.image1", "old_id": "image2"},
+ ],
+ ))
+
+async def register_replacements_upscaleimage():
+ # ResizeImageMaskNode uses DynamicCombo
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="ResizeImageMaskNode",
+ old_node_id="ImageScaleBy",
+ old_widget_ids=["upscale_method", "scale_by"],
+ input_mapping=[
+ {"new_id": "input", "old_id": "image"},
+ {"new_id": "resize_type", "set_value": "scale by multiplier"},
+ {"new_id": "resize_type.multiplier", "old_id": "scale_by"},
+ {"new_id": "scale_method", "old_id": "upscale_method"},
+ ],
+ ))
+
+async def register_replacements_controlnet():
+ # T2IAdapterLoader → ControlNetLoader
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="ControlNetLoader",
+ old_node_id="T2IAdapterLoader",
+ input_mapping=[
+ {"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
+ ],
+ ))
+
+async def register_replacements_load3d():
+ # Load3DAnimation merged into Load3D
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="Load3D",
+ old_node_id="Load3DAnimation",
+ ))
+
+async def register_replacements_preview3d():
+ # Preview3DAnimation merged into Preview3D
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="Preview3D",
+ old_node_id="Preview3DAnimation",
+ ))
+
+async def register_replacements_svdimg2vid():
+ # Typo fix: SDV → SVD
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="SVD_img2vid_Conditioning",
+ old_node_id="SDV_img2vid_Conditioning",
+ ))
+
+async def register_replacements_conditioningavg():
+ # Typo fix: trailing space in node name
+ await api.node_replacement.register(io.NodeReplace(
+ new_node_id="ConditioningAverage",
+ old_node_id="ConditioningAverage ",
+ ))
+
+class NodeReplacementsExtension(ComfyExtension):
+ async def on_load(self) -> None:
+ await register_replacements()
+
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return []
+
+async def comfy_entrypoint() -> NodeReplacementsExtension:
+ return NodeReplacementsExtension()
diff --git a/nodes.py b/nodes.py
index 91de7a9d7..db5f98408 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2264,6 +2264,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
if not isinstance(extension, ComfyExtension):
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
return False
+ await extension.on_load()
node_list = await extension.get_node_list()
if not isinstance(node_list, list):
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
@@ -2435,6 +2436,7 @@ async def init_builtin_extra_nodes():
"nodes_lora_debug.py",
"nodes_color.py",
"nodes_toolkit.py",
+ "nodes_replacements.py",
]
import_failed = []
diff --git a/server.py b/server.py
index 2300393b2..8882e43c4 100644
--- a/server.py
+++ b/server.py
@@ -40,6 +40,7 @@ from app.user_manager import UserManager
from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager
+from app.node_replace_manager import NodeReplaceManager
from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes
@@ -204,6 +205,7 @@ class PromptServer():
self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager()
+ self.node_replace_manager = NodeReplaceManager()
self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"]
self.prompt_queue = execution.PromptQueue(self)
@@ -887,6 +889,8 @@ class PromptServer():
if "partial_execution_targets" in json_data:
partial_execution_targets = json_data["partial_execution_targets"]
+ self.node_replace_manager.apply_replacements(prompt)
+
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
extra_data = {}
if "extra_data" in json_data:
@@ -995,6 +999,7 @@ class PromptServer():
self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
+ self.node_replace_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation.
From e2c71ceb0004da0d8a33dc9e79b31c2324241173 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Sun, 15 Feb 2026 15:33:18 +0200
Subject: [PATCH 38/85] feat(api-nodes-Tencent): add ModelTo3DUV,
3DTextureEdit, 3DParts nodes (#12428)
* feat(api-nodes-Tencent): add ModelTo3DUV, 3DTextureEdit, 3DParts nodes
* add image output to TencentModelTo3DUV node
* commented out two nodes
* added rate_limit check to other hunyuan3d nodes
---
comfy_api_nodes/apis/hunyuan3d.py | 20 ++
comfy_api_nodes/nodes_hunyuan3d.py | 295 ++++++++++++++++++++++++-
comfy_api_nodes/util/__init__.py | 2 +
comfy_api_nodes/util/upload_helpers.py | 21 ++
4 files changed, 326 insertions(+), 12 deletions(-)
diff --git a/comfy_api_nodes/apis/hunyuan3d.py b/comfy_api_nodes/apis/hunyuan3d.py
index 6421c9bd5..e84eba31e 100644
--- a/comfy_api_nodes/apis/hunyuan3d.py
+++ b/comfy_api_nodes/apis/hunyuan3d.py
@@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel):
class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...)
+
+
+class To3DUVFileInput(BaseModel):
+ Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
+ Url: str = Field(...)
+
+
+class To3DUVTaskRequest(BaseModel):
+ File: To3DUVFileInput = Field(...)
+
+
+class TextureEditImageInfo(BaseModel):
+ Url: str = Field(...)
+
+
+class TextureEditTaskRequest(BaseModel):
+ File3D: To3DUVFileInput = Field(...)
+ Image: TextureEditImageInfo | None = Field(None)
+ Prompt: str | None = Field(None)
+ EnablePBR: bool | None = Field(None)
diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py
index 813a7c809..ca002cc60 100644
--- a/comfy_api_nodes/nodes_hunyuan3d.py
+++ b/comfy_api_nodes/nodes_hunyuan3d.py
@@ -1,31 +1,48 @@
from typing_extensions import override
-from comfy_api.latest import IO, ComfyExtension, Input
+from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.hunyuan3d import (
Hunyuan3DViewImage,
InputGenerateType,
ResultFile3D,
+ TextureEditTaskRequest,
To3DProTaskCreateResponse,
To3DProTaskQueryRequest,
To3DProTaskRequest,
To3DProTaskResultResponse,
+ To3DUVFileInput,
+ To3DUVTaskRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_file_3d,
+ download_url_to_image_tensor,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
+ upload_3d_model_to_comfyapi,
upload_image_to_comfyapi,
validate_image_dimensions,
validate_string,
)
-def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None:
+def _is_tencent_rate_limited(status: int, body: object) -> bool:
+ return (
+ status == 400
+ and isinstance(body, dict)
+ and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", ""))
+ )
+
+
+def get_file_from_response(
+ response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
+) -> ResultFile3D | None:
for i in response_objs:
if i.Type.lower() == file_type.lower():
return i
+ if raise_if_not_found:
+ raise ValueError(f"'{file_type}' file type is not found in the response.")
return None
@@ -35,7 +52,7 @@ class TencentTextToModelNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TencentTextToModelNode",
- display_name="Hunyuan3D: Text to Model (Pro)",
+ display_name="Hunyuan3D: Text to Model",
category="api node/3d/Tencent",
inputs=[
IO.Combo.Input(
@@ -120,6 +137,7 @@ class TencentTextToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
+ is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -131,11 +149,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
- glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
- obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
- file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
return IO.NodeOutput(
- file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
+ f"{task_id}.glb",
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
+ ),
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
+ ),
)
@@ -145,7 +166,7 @@ class TencentImageToModelNode(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TencentImageToModelNode",
- display_name="Hunyuan3D: Image(s) to Model (Pro)",
+ display_name="Hunyuan3D: Image(s) to Model",
category="api node/3d/Tencent",
inputs=[
IO.Combo.Input(
@@ -268,6 +289,7 @@ class TencentImageToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None),
),
+ is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@@ -279,11 +301,257 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
- glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
- obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
- file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
return IO.NodeOutput(
- file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None
+ f"{task_id}.glb",
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
+ ),
+ await download_url_to_file_3d(
+ get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
+ ),
+ )
+
+
+class TencentModelTo3DUVNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="TencentModelTo3DUVNode",
+ display_name="Hunyuan3D: Model to UV",
+ category="api node/3d/Tencent",
+ description="Perform UV unfolding on a 3D model to generate UV texture. "
+ "Input model must have less than 30000 faces.",
+ inputs=[
+ IO.MultiType.Input(
+ "model_3d",
+ types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny],
+ tooltip="Input 3D model (GLB, OBJ, or FBX)",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=1,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.File3DOBJ.Output(display_name="OBJ"),
+ IO.File3DFBX.Output(display_name="FBX"),
+ IO.Image.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'),
+ )
+
+ SUPPORTED_FORMATS = {"glb", "obj", "fbx"}
+
+ @classmethod
+ async def execute(
+ cls,
+ model_3d: Types.File3D,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ file_format = model_3d.format.lower()
+ if file_format not in cls.SUPPORTED_FORMATS:
+ raise ValueError(
+ f"Unsupported file format: '{file_format}'. "
+ f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
+ )
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
+ response_model=To3DProTaskCreateResponse,
+ data=To3DUVTaskRequest(
+ File=To3DUVFileInput(
+ Type=file_format.upper(),
+ Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
+ )
+ ),
+ is_rate_limited=_is_tencent_rate_limited,
+ )
+ if response.Error:
+ raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+ result = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"),
+ data=To3DProTaskQueryRequest(JobId=response.JobId),
+ response_model=To3DProTaskResultResponse,
+ status_extractor=lambda r: r.Status,
+ )
+ return IO.NodeOutput(
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
+ await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
+ )
+
+
+class Tencent3DTextureEditNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Tencent3DTextureEditNode",
+ display_name="Hunyuan3D: 3D Texture Edit",
+ category="api node/3d/Tencent",
+ description="After inputting the 3D model, perform 3D model texture redrawing.",
+ inputs=[
+ IO.MultiType.Input(
+ "model_3d",
+ types=[IO.File3DFBX, IO.File3DAny],
+ tooltip="3D model in FBX format. Model should have less than 100000 faces.",
+ ),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ default="",
+ tooltip="Describes texture editing. Supports up to 1024 UTF-8 characters.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.File3DGLB.Output(display_name="GLB"),
+ IO.File3DFBX.Output(display_name="FBX"),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd": 0.6}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_3d: Types.File3D,
+ prompt: str,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ file_format = model_3d.format.lower()
+ if file_format != "fbx":
+ raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
+ validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
+ model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
+ response_model=To3DProTaskCreateResponse,
+ data=TextureEditTaskRequest(
+ File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
+ Prompt=prompt,
+ EnablePBR=True,
+ ),
+ is_rate_limited=_is_tencent_rate_limited,
+ )
+ if response.Error:
+ raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+
+ result = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"),
+ data=To3DProTaskQueryRequest(JobId=response.JobId),
+ response_model=To3DProTaskResultResponse,
+ status_extractor=lambda r: r.Status,
+ )
+ return IO.NodeOutput(
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
+ )
+
+
+class Tencent3DPartNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Tencent3DPartNode",
+ display_name="Hunyuan3D: 3D Part",
+ category="api node/3d/Tencent",
+ description="Automatically perform component identification and generation based on the model structure.",
+ inputs=[
+ IO.MultiType.Input(
+ "model_3d",
+ types=[IO.File3DFBX, IO.File3DAny],
+ tooltip="3D model in FBX format. Model should have less than 30000 faces.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[
+ IO.File3DFBX.Output(display_name="FBX"),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model_3d: Types.File3D,
+ seed: int,
+ ) -> IO.NodeOutput:
+ _ = seed
+ file_format = model_3d.format.lower()
+ if file_format != "fbx":
+ raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
+ model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
+ response_model=To3DProTaskCreateResponse,
+ data=To3DUVTaskRequest(
+ File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
+ ),
+ is_rate_limited=_is_tencent_rate_limited,
+ )
+ if response.Error:
+ raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
+ result = await poll_op(
+ cls,
+ ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"),
+ data=To3DProTaskQueryRequest(JobId=response.JobId),
+ response_model=To3DProTaskResultResponse,
+ status_extractor=lambda r: r.Status,
+ )
+ return IO.NodeOutput(
+ await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
)
@@ -293,6 +561,9 @@ class TencentHunyuan3DExtension(ComfyExtension):
return [
TencentTextToModelNode,
TencentImageToModelNode,
+ # TencentModelTo3DUVNode,
+ # Tencent3DTextureEditNode,
+ Tencent3DPartNode,
]
diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py
index 18b020eef..f8a0ba8af 100644
--- a/comfy_api_nodes/util/__init__.py
+++ b/comfy_api_nodes/util/__init__.py
@@ -33,6 +33,7 @@ from .download_helpers import (
download_url_to_video_output,
)
from .upload_helpers import (
+ upload_3d_model_to_comfyapi,
upload_audio_to_comfyapi,
upload_file_to_comfyapi,
upload_image_to_comfyapi,
@@ -62,6 +63,7 @@ __all__ = [
"sync_op",
"sync_op_raw",
# Upload helpers
+ "upload_3d_model_to_comfyapi",
"upload_audio_to_comfyapi",
"upload_file_to_comfyapi",
"upload_image_to_comfyapi",
diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py
index 7cc565263..6d1d107a1 100644
--- a/comfy_api_nodes/util/upload_helpers.py
+++ b/comfy_api_nodes/util/upload_helpers.py
@@ -164,6 +164,27 @@ async def upload_video_to_comfyapi(
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
+_3D_MIME_TYPES = {
+ "glb": "model/gltf-binary",
+ "obj": "model/obj",
+ "fbx": "application/octet-stream",
+}
+
+
+async def upload_3d_model_to_comfyapi(
+ cls: type[IO.ComfyNode],
+ model_3d: Types.File3D,
+ file_format: str,
+) -> str:
+ """Uploads a 3D model file to ComfyUI API and returns its download URL."""
+ return await upload_file_to_comfyapi(
+ cls,
+ model_3d.get_data(),
+ f"{uuid.uuid4()}.{file_format}",
+ _3D_MIME_TYPES.get(file_format, "application/octet-stream"),
+ )
+
+
async def upload_file_to_comfyapi(
cls: type[IO.ComfyNode],
file_bytes_io: BytesIO,
From 2c1d06a4e32900c260bcb3d0888f20edc1e3e5ab Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Mon, 16 Feb 2026 03:22:30 +0200
Subject: [PATCH 39/85] feat(api-nodes): add Bria RMBG nodes (#12465)
Co-authored-by: Jedrzej Kosinski
---
comfy_api_nodes/apis/bria.py | 44 +++++-
comfy_api_nodes/nodes_bria.py | 220 ++++++++++++++++++++++------
comfy_api_nodes/util/conversions.py | 2 +-
3 files changed, 218 insertions(+), 48 deletions(-)
diff --git a/comfy_api_nodes/apis/bria.py b/comfy_api_nodes/apis/bria.py
index 9119cacc6..8c496b56c 100644
--- a/comfy_api_nodes/apis/bria.py
+++ b/comfy_api_nodes/apis/bria.py
@@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
)
+class BriaRemoveBackgroundRequest(BaseModel):
+ image: str = Field(...)
+ sync: bool = Field(False)
+ visual_input_content_moderation: bool = Field(
+ False, description="If true, returns 422 on input image moderation failure."
+ )
+ visual_output_content_moderation: bool = Field(
+ False, description="If true, returns 422 on visual output moderation failure."
+ )
+ seed: int = Field(...)
+
+
class BriaStatusResponse(BaseModel):
request_id: str = Field(...)
status_url: str = Field(...)
warning: str | None = Field(None)
-class BriaResult(BaseModel):
+class BriaRemoveBackgroundResult(BaseModel):
+ image_url: str = Field(...)
+
+
+class BriaRemoveBackgroundResponse(BaseModel):
+ status: str = Field(...)
+ result: BriaRemoveBackgroundResult | None = Field(None)
+
+
+class BriaImageEditResult(BaseModel):
structured_prompt: str = Field(...)
image_url: str = Field(...)
-class BriaResponse(BaseModel):
+class BriaImageEditResponse(BaseModel):
status: str = Field(...)
- result: BriaResult | None = Field(None)
+ result: BriaImageEditResult | None = Field(None)
+
+
+class BriaRemoveVideoBackgroundRequest(BaseModel):
+ video: str = Field(...)
+ background_color: str = Field(default="transparent", description="Background color for the output video.")
+ output_container_and_codec: str = Field(...)
+ preserve_audio: bool = Field(True)
+ seed: int = Field(...)
+
+
+class BriaRemoveVideoBackgroundResult(BaseModel):
+ video_url: str = Field(...)
+
+
+class BriaRemoveVideoBackgroundResponse(BaseModel):
+ status: str = Field(...)
+ result: BriaRemoveVideoBackgroundResult | None = Field(None)
diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py
index d3a52bc1b..4044ee3ea 100644
--- a/comfy_api_nodes/nodes_bria.py
+++ b/comfy_api_nodes/nodes_bria.py
@@ -3,7 +3,11 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bria import (
BriaEditImageRequest,
- BriaResponse,
+ BriaRemoveBackgroundRequest,
+ BriaRemoveBackgroundResponse,
+ BriaRemoveVideoBackgroundRequest,
+ BriaRemoveVideoBackgroundResponse,
+ BriaImageEditResponse,
BriaStatusResponse,
InputModerationSettings,
)
@@ -11,10 +15,12 @@ from comfy_api_nodes.util import (
ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor,
- get_number_of_images,
+ download_url_to_video_output,
poll_op,
sync_op,
- upload_images_to_comfyapi,
+ upload_image_to_comfyapi,
+ upload_video_to_comfyapi,
+ validate_video_duration,
)
@@ -73,21 +79,15 @@ class BriaImageEditNode(IO.ComfyNode):
IO.DynamicCombo.Input(
"moderation",
options=[
+ IO.DynamicCombo.Option("false", []),
IO.DynamicCombo.Option(
"true",
[
- IO.Boolean.Input(
- "prompt_content_moderation", default=False
- ),
- IO.Boolean.Input(
- "visual_input_moderation", default=False
- ),
- IO.Boolean.Input(
- "visual_output_moderation", default=True
- ),
+ IO.Boolean.Input("prompt_content_moderation", default=False),
+ IO.Boolean.Input("visual_input_moderation", default=False),
+ IO.Boolean.Input("visual_output_moderation", default=True),
],
),
- IO.DynamicCombo.Option("false", []),
],
tooltip="Moderation settings",
),
@@ -127,50 +127,26 @@ class BriaImageEditNode(IO.ComfyNode):
mask: Input.Image | None = None,
) -> IO.NodeOutput:
if not prompt and not structured_prompt:
- raise ValueError(
- "One of prompt or structured_prompt is required to be non-empty."
- )
- if get_number_of_images(image) != 1:
- raise ValueError("Exactly one input image is required.")
+ raise ValueError("One of prompt or structured_prompt is required to be non-empty.")
mask_url = None
if mask is not None:
- mask_url = (
- await upload_images_to_comfyapi(
- cls,
- convert_mask_to_image(mask),
- max_images=1,
- mime_type="image/png",
- wait_label="Uploading mask",
- )
- )[0]
+ mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask")
response = await sync_op(
cls,
ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
data=BriaEditImageRequest(
instruction=prompt if prompt else None,
structured_instruction=structured_prompt if structured_prompt else None,
- images=await upload_images_to_comfyapi(
- cls,
- image,
- max_images=1,
- mime_type="image/png",
- wait_label="Uploading image",
- ),
+ images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")],
mask=mask_url,
negative_prompt=negative_prompt if negative_prompt else None,
guidance_scale=guidance_scale,
seed=seed,
model_version=model,
steps_num=steps,
- prompt_content_moderation=moderation.get(
- "prompt_content_moderation", False
- ),
- visual_input_content_moderation=moderation.get(
- "visual_input_moderation", False
- ),
- visual_output_content_moderation=moderation.get(
- "visual_output_moderation", False
- ),
+ prompt_content_moderation=moderation.get("prompt_content_moderation", False),
+ visual_input_content_moderation=moderation.get("visual_input_moderation", False),
+ visual_output_content_moderation=moderation.get("visual_output_moderation", False),
),
response_model=BriaStatusResponse,
)
@@ -178,7 +154,7 @@ class BriaImageEditNode(IO.ComfyNode):
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
- response_model=BriaResponse,
+ response_model=BriaImageEditResponse,
)
return IO.NodeOutput(
await download_url_to_image_tensor(response.result.image_url),
@@ -186,11 +162,167 @@ class BriaImageEditNode(IO.ComfyNode):
)
+class BriaRemoveImageBackground(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="BriaRemoveImageBackground",
+ display_name="Bria Remove Image Background",
+ category="api node/image/Bria",
+ description="Remove the background from an image using Bria RMBG 2.0.",
+ inputs=[
+ IO.Image.Input("image"),
+ IO.DynamicCombo.Input(
+ "moderation",
+ options=[
+ IO.DynamicCombo.Option("false", []),
+ IO.DynamicCombo.Option(
+ "true",
+ [
+ IO.Boolean.Input("visual_input_moderation", default=False),
+ IO.Boolean.Input("visual_output_moderation", default=True),
+ ],
+ ),
+ ],
+ tooltip="Moderation settings",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Image.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.018}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ image: Input.Image,
+ moderation: dict,
+ seed: int,
+ ) -> IO.NodeOutput:
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"),
+ data=BriaRemoveBackgroundRequest(
+ image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"),
+ sync=False,
+ visual_input_content_moderation=moderation.get("visual_input_moderation", False),
+ visual_output_content_moderation=moderation.get("visual_output_moderation", False),
+ seed=seed,
+ ),
+ response_model=BriaStatusResponse,
+ )
+ response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
+ status_extractor=lambda r: r.status,
+ response_model=BriaRemoveBackgroundResponse,
+ )
+ return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url))
+
+
+class BriaRemoveVideoBackground(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="BriaRemoveVideoBackground",
+ display_name="Bria Remove Video Background",
+ category="api node/video/Bria",
+ description="Remove the background from a video using Bria. ",
+ inputs=[
+ IO.Video.Input("video"),
+ IO.Combo.Input(
+ "background_color",
+ options=[
+ "Black",
+ "White",
+ "Gray",
+ "Red",
+ "Green",
+ "Blue",
+ "Yellow",
+ "Cyan",
+ "Magenta",
+ "Orange",
+ ],
+ tooltip="Background color for the output video.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=2147483647,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ tooltip="Seed controls whether the node should re-run; "
+ "results are non-deterministic regardless of seed.",
+ ),
+ ],
+ outputs=[IO.Video.Output()],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ video: Input.Video,
+ background_color: str,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_video_duration(video, max_duration=60.0)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
+ data=BriaRemoveVideoBackgroundRequest(
+ video=await upload_video_to_comfyapi(cls, video),
+ background_color=background_color,
+ output_container_and_codec="mp4_h264",
+ seed=seed,
+ ),
+ response_model=BriaStatusResponse,
+ )
+ response = await poll_op(
+ cls,
+ ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
+ status_extractor=lambda r: r.status,
+ response_model=BriaRemoveVideoBackgroundResponse,
+ )
+ return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
+
+
class BriaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
BriaImageEditNode,
+ BriaRemoveImageBackground,
+ BriaRemoveVideoBackground,
]
diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py
index 3e37e8a8c..82b6d22a5 100644
--- a/comfy_api_nodes/util/conversions.py
+++ b/comfy_api_nodes/util/conversions.py
@@ -57,7 +57,7 @@ def tensor_to_bytesio(
image: torch.Tensor,
*,
total_pixels: int | None = 2048 * 2048,
- mime_type: str = "image/png",
+ mime_type: str | None = "image/png",
) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object.
From ecd2a19661ecccd96e26f111af21781f3e613f59 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Sun, 15 Feb 2026 17:28:51 -0800
Subject: [PATCH 40/85] Fix lora Extraction in offload conditions (+
dynamic_vram mode) (#12479)
* lora_extract: Add a trange
If you bite off more than your GPU can chew, this kinda just hangs.
Give a rough indication of progress counting the weights in a trange.
* lora_extract: Support on-the-fly patching
Use the on-the-fly approach from the regular model saving logic for
lora extraction too. Switch off force_cast_weights accordingly.
This gets extraction working in dynamic vram while also supporting
extraction on GPU offloaded.
---
comfy_extras/nodes_lora_extract.py | 21 ++++++++++++++++-----
1 file changed, 16 insertions(+), 5 deletions(-)
diff --git a/comfy_extras/nodes_lora_extract.py b/comfy_extras/nodes_lora_extract.py
index fb89e03f4..1542d0a88 100644
--- a/comfy_extras/nodes_lora_extract.py
+++ b/comfy_extras/nodes_lora_extract.py
@@ -7,6 +7,7 @@ import logging
from enum import Enum
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
+from tqdm.auto import trange
CLAMP_QUANTILE = 0.99
@@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
- comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True)
+ comfy.model_management.load_models_gpu([model_diff])
sd = model_diff.model_state_dict(filter_prefix=prefix_model)
- for k in sd:
- if k.endswith(".weight"):
+ sd_keys = list(sd.keys())
+ for index in trange(len(sd_keys), unit="weight"):
+ k = sd_keys[index]
+ op_keys = sd_keys[index].rsplit('.', 1)
+ if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
+ continue
+ op = comfy.utils.get_attr(model_diff.model, op_keys[0])
+ if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
+ weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
+ else:
weight_diff = sd[k]
+
+ if op_keys[1] == "weight":
if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2:
if bias_diff:
@@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
- elif bias_diff and k.endswith(".bias"):
- output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu()
+ elif bias_diff and op_keys[1] == "bias":
+ output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
return output_sd
class LoraSave(io.ComfyNode):
From c0370044cd467b92f4db63b88029ebc700388d36 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Sun, 15 Feb 2026 17:30:09 -0800
Subject: [PATCH 41/85] MPDynamic: force load flux img_in weight (Fixes flux1
canny+depth lora crash) (#12446)
* lora: add weight shape calculations.
This lets the loader know if a lora will change the shape of a weight
so it can take appropriate action.
* MPDynamic: force load flux img_in weight
This weight is a bit special, in that the lora changes its geometry.
This is rather unique, not handled by existing estimate and doesn't
work for either offloading or dynamic_vram.
Fix for dynamic_vram as a special case. Ideally we can fully precalculate
these lora geometry changes at load time, but just get these models
working first.
---
comfy/lora.py | 25 +++++++++++++++++++++++++
comfy/model_patcher.py | 35 +++++++++++++++++++++++++++--------
comfy/weight_adapter/base.py | 6 ++++++
comfy/weight_adapter/lora.py | 7 +++++++
4 files changed, 65 insertions(+), 8 deletions(-)
diff --git a/comfy/lora.py b/comfy/lora.py
index 44030bcab..279cf38bb 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor
+def calculate_shape(patches, weight, key, original_weights=None):
+ current_shape = weight.shape
+
+ for p in patches:
+ v = p[1]
+ offset = p[3]
+
+ # Offsets restore the old shape; lists force a diff without metadata
+ if offset is not None or isinstance(v, list):
+ continue
+
+ if isinstance(v, weight_adapter.WeightAdapterBase):
+ adapter_shape = v.calculate_shape(key)
+ if adapter_shape is not None:
+ current_shape = adapter_shape
+ continue
+
+ # Standard diff logic with padding
+ if len(v) == 2:
+ patch_type, patch_data = v[0], v[1]
+ if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
+ current_shape = patch_data[0].shape
+
+ return current_shape
+
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches:
strength = p[0]
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 67dce088e..f01818f50 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -1514,8 +1514,10 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key)
if weight is None:
- return 0
+ return (False, 0)
if key in self.patches:
+ if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
+ return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1
else:
@@ -1529,7 +1531,13 @@ class ModelPatcherDynamic(ModelPatcher):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
- return comfy.memory_management.vram_aligned_size(geometry)
+ return (False, comfy.memory_management.vram_aligned_size(geometry))
+
+ def force_load_param(self, param_key, device_to):
+ key = key_param_name_to_key(n, param_key)
+ if key in self.backup:
+ comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
+ self.patch_weight_to_device(key, device_to=device_to)
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@@ -1537,13 +1545,19 @@ class ModelPatcherDynamic(ModelPatcher):
m.seed_key = n
set_dirty(m, dirty)
- v_weight_size = 0
- v_weight_size += setup_param(self, m, n, "weight")
- v_weight_size += setup_param(self, m, n, "bias")
+ force_load, v_weight_size = setup_param(self, m, n, "weight")
+ force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
+ force_load = force_load or force_load_bias
+ v_weight_size += v_weight_bias
- if vbar is not None and not hasattr(m, "_v"):
- m._v = vbar.alloc(v_weight_size)
- allocated_size += v_weight_size
+ if force_load:
+ logging.info(f"Module {n} has resizing Lora - force loading")
+ force_load_param(self, "weight", device_to)
+ force_load_param(self, "bias", device_to)
+ else:
+ if vbar is not None and not hasattr(m, "_v"):
+ m._v = vbar.alloc(v_weight_size)
+ allocated_size += v_weight_size
else:
for param in params:
@@ -1606,6 +1620,11 @@ class ModelPatcherDynamic(ModelPatcher):
for m in self.model.modules():
move_weight_functions(m, device_to)
+ keys = list(self.backup.keys())
+ for k in keys:
+ bk = self.backup[k]
+ comfy.utils.set_attr_param(self.model, k, bk.weight)
+
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
with self.use_ejected(skip_and_inject_on_exit_only=True):
diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py
index bce89a0e2..d352e066b 100644
--- a/comfy/weight_adapter/base.py
+++ b/comfy/weight_adapter/base.py
@@ -49,6 +49,12 @@ class WeightAdapterBase:
"""
raise NotImplementedError
+ def calculate_shape(
+ self,
+ key
+ ):
+ return None
+
def calculate_weight(
self,
weight,
diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py
index bc4260a8f..8e1261a12 100644
--- a/comfy/weight_adapter/lora.py
+++ b/comfy/weight_adapter/lora.py
@@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
else:
return None
+ def calculate_shape(
+ self,
+ key
+ ):
+ reshape = self.weights[5]
+ return tuple(reshape) if reshape is not None else None
+
def calculate_weight(
self,
weight,
From 88e6370527dbd602851de07d957a8f17b3ca9447 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Sun, 15 Feb 2026 17:43:53 -0800
Subject: [PATCH 42/85] Remove workaround for old pytorch. (#12480)
---
comfy/ldm/modules/diffusionmodules/model.py | 14 +-------------
1 file changed, 1 insertion(+), 13 deletions(-)
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index 5a22ef030..805592aa5 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
return self.conv(x)
def interpolate_up(x, scale_factor):
- try:
- return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
- except: #operation not implemented for bf16
- orig_shape = list(x.shape)
- out_shape = orig_shape[:2]
- for i in range(len(orig_shape) - 2):
- out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
- out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
- split = 8
- l = out.shape[1] // split
- for i in range(0, out.shape[1], l):
- out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
- return out
+ return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
From 1978f59ffdf242389ded3eec76274a4cbed9cc3d Mon Sep 17 00:00:00 2001
From: ComfyUI Wiki
Date: Tue, 17 Feb 2026 06:33:43 +0800
Subject: [PATCH 43/85] chore: update workflow templates to v0.8.42 (#12491)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index e939e486a..0930bbbb8 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.38.14
-comfyui-workflow-templates==0.8.38
+comfyui-workflow-templates==0.8.42
comfyui-embedded-docs==0.4.1
torch
torchsde
From 4454fab7f003c655e07f059c315e2aae0e5fb087 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Mon, 16 Feb 2026 17:09:24 -0800
Subject: [PATCH 44/85] Remove code to support RMSNorm on old pytorch. (#12499)
---
comfy/ops.py | 6 ++----
comfy/rmsnorm.py | 55 ++++--------------------------------------------
2 files changed, 6 insertions(+), 55 deletions(-)
diff --git a/comfy/ops.py b/comfy/ops.py
index 688937e43..026062f56 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -21,7 +21,6 @@ import logging
import comfy.model_management
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import comfy.float
-import comfy.rmsnorm
import json
import comfy.memory_management
import comfy.pinned_memory
@@ -463,7 +462,7 @@ class disable_weight_init:
else:
return super().forward(*args, **kwargs)
- class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
+ class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
def reset_parameters(self):
self.bias = None
return None
@@ -475,8 +474,7 @@ class disable_weight_init:
weight = None
bias = None
offload_stream = None
- x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
- # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+ x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py
index 555542a46..ab7cf14fa 100644
--- a/comfy/rmsnorm.py
+++ b/comfy/rmsnorm.py
@@ -1,57 +1,10 @@
import torch
import comfy.model_management
-import numbers
-import logging
-
-RMSNorm = None
-
-try:
- rms_norm_torch = torch.nn.functional.rms_norm
- RMSNorm = torch.nn.RMSNorm
-except:
- rms_norm_torch = None
- logging.warning("Please update pytorch to use native RMSNorm")
+RMSNorm = torch.nn.RMSNorm
def rms_norm(x, weight=None, eps=1e-6):
- if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
- if weight is None:
- return rms_norm_torch(x, (x.shape[-1],), eps=eps)
- else:
- return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+ if weight is None:
+ return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
else:
- r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
- if weight is None:
- return r
- else:
- return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
-
-
-if RMSNorm is None:
- class RMSNorm(torch.nn.Module):
- def __init__(
- self,
- normalized_shape,
- eps=1e-6,
- elementwise_affine=True,
- device=None,
- dtype=None,
- ):
- factory_kwargs = {"device": device, "dtype": dtype}
- super().__init__()
- if isinstance(normalized_shape, numbers.Integral):
- # mypy error: incompatible types in assignment
- normalized_shape = (normalized_shape,) # type: ignore[assignment]
- self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
- self.eps = eps
- self.elementwise_affine = elementwise_affine
- if self.elementwise_affine:
- self.weight = torch.nn.Parameter(
- torch.empty(self.normalized_shape, **factory_kwargs)
- )
- else:
- self.register_parameter("weight", None)
- self.bias = None
-
- def forward(self, x):
- return rms_norm(x, self.weight, self.eps)
+ return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
From b44fc4c589c66e39686239d6eff7d6088668c9a8 Mon Sep 17 00:00:00 2001
From: Alex Butler
Date: Tue, 17 Feb 2026 03:16:19 +0000
Subject: [PATCH 45/85] add venv* to gitignore (#12431)
---
.gitignore | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/.gitignore b/.gitignore
index 4e8cea71e..2700ad5c2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,7 +11,7 @@ extra_model_paths.yaml
/.vs
.vscode/
.idea/
-venv/
+venv*/
.venv/
/web/extensions/*
!/web/extensions/logging.js.example
From 8a6fbc2dc29d0b15c1e9655c24e7501829249995 Mon Sep 17 00:00:00 2001
From: Jedrzej Kosinski
Date: Mon, 16 Feb 2026 19:20:21 -0800
Subject: [PATCH 46/85] Allow control_after_generate to be type
ControlAfterGenerate in v3 schema (#12187)
---
comfy_api/latest/_io.py | 13 ++++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index 95d79c035..d18330d0b 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -75,6 +75,12 @@ class NumberDisplay(str, Enum):
slider = "slider"
+class ControlAfterGenerate(str, Enum):
+ fixed = "fixed"
+ increment = "increment"
+ decrement = "decrement"
+ randomize = "randomize"
+
class _ComfyType(ABC):
Type = Any
io_type: str = None
@@ -263,7 +269,7 @@ class Int(ComfyTypeIO):
class Input(WidgetInput):
'''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
- default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
+ default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min
@@ -345,7 +351,7 @@ class Combo(ComfyTypeIO):
tooltip: str=None,
lazy: bool=None,
default: str | int | Enum = None,
- control_after_generate: bool=None,
+ control_after_generate: bool | ControlAfterGenerate=None,
upload: UploadType=None,
image_folder: FolderType=None,
remote: RemoteOptions=None,
@@ -389,7 +395,7 @@ class MultiCombo(ComfyTypeI):
Type = list[str]
class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
- default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
+ default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
self.multiselect = True
@@ -2097,6 +2103,7 @@ __all__ = [
"UploadType",
"RemoteOptions",
"NumberDisplay",
+ "ControlAfterGenerate",
"comfytype",
"Custom",
From 18927538a15d44c734653513e9fdbbe1e79a9f0c Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Mon, 16 Feb 2026 20:30:34 -0800
Subject: [PATCH 47/85] Implement NAG on all the models based on the Flux code.
(#12500)
Use the Normalized Attention Guidance node.
Flux, Flux2, Klein, Chroma, Chroma radiance, Hunyuan Video, etc..
---
comfy/ldm/chroma/model.py | 2 +
comfy/ldm/flux/layers.py | 18 ++++++
comfy/ldm/flux/model.py | 2 +
comfy/ldm/hunyuan_video/model.py | 2 +
comfy/model_patcher.py | 5 +-
comfy_extras/nodes_nag.py | 99 ++++++++++++++++++++++++++++++++
nodes.py | 1 +
7 files changed, 128 insertions(+), 1 deletion(-)
create mode 100644 comfy_extras/nodes_nag.py
diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
index 2e8ef0687..9fd865f20 100644
--- a/comfy/ldm/chroma/model.py
+++ b/comfy/ldm/chroma/model.py
@@ -152,6 +152,7 @@ class Chroma(nn.Module):
transformer_options={},
attn_mask: Tensor = None,
) -> Tensor:
+ transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
# running on sequences img
@@ -228,6 +229,7 @@ class Chroma(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
+ transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if i not in self.skip_dit:
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 3518a1922..8b3f500d7 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -196,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
+ transformer_patches = transformer_options.get("patches", {})
+ extra_options = transformer_options.copy()
+
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@@ -224,6 +227,12 @@ class DoubleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
+ if "attn1_output_patch" in transformer_patches:
+ extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
+ patch = transformer_patches["attn1_output_patch"]
+ for p in patch:
+ attn = p(attn, extra_options)
+
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks
@@ -303,6 +312,9 @@ class SingleStreamBlock(nn.Module):
else:
mod = vec
+ transformer_patches = transformer_options.get("patches", {})
+ extra_options = transformer_options.copy()
+
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@@ -312,6 +324,12 @@ class SingleStreamBlock(nn.Module):
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
+
+ if "attn1_output_patch" in transformer_patches:
+ patch = transformer_patches["attn1_output_patch"]
+ for p in patch:
+ attn = p(attn, extra_options)
+
# compute activation in mlp stream, cat again and run second linear layer
if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index 260ccad7e..ef4dcf7c5 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -142,6 +142,7 @@ class Flux(nn.Module):
attn_mask: Tensor = None,
) -> Tensor:
+ transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3:
@@ -231,6 +232,7 @@ class Flux(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
+ transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index 563f28f6b..b94cdfa87 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -304,6 +304,7 @@ class HunyuanVideo(nn.Module):
control=None,
transformer_options={},
) -> Tensor:
+ transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape)
@@ -416,6 +417,7 @@ class HunyuanVideo(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single"
+ transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace:
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index f01818f50..21b4ce53e 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -406,13 +406,16 @@ class ModelPatcher:
def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape)
+ def disable_model_cfg1_optimization(self):
+ self.model_options["disable_cfg1_optimization"] = True
+
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization:
- self.model_options["disable_cfg1_optimization"] = True
+ self.disable_model_cfg1_optimization()
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
diff --git a/comfy_extras/nodes_nag.py b/comfy_extras/nodes_nag.py
new file mode 100644
index 000000000..033e40eb9
--- /dev/null
+++ b/comfy_extras/nodes_nag.py
@@ -0,0 +1,99 @@
+import torch
+from comfy_api.latest import ComfyExtension, io
+from typing_extensions import override
+
+
+class NAGuidance(io.ComfyNode):
+ @classmethod
+ def define_schema(cls) -> io.Schema:
+ return io.Schema(
+ node_id="NAGuidance",
+ display_name="Normalized Attention Guidance",
+ description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
+ category="",
+ is_experimental=True,
+ inputs=[
+ io.Model.Input("model", tooltip="The model to apply NAG to."),
+ io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
+ io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
+ io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
+ # io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
+ # io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
+ ],
+ outputs=[
+ io.Model.Output(tooltip="The patched model with NAG enabled."),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
+ m = model.clone()
+
+ # sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
+ # sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)
+
+ def nag_attention_output_patch(out, extra_options):
+ cond_or_uncond = extra_options.get("cond_or_uncond", None)
+ if cond_or_uncond is None:
+ return out
+
+ if not (1 in cond_or_uncond and 0 in cond_or_uncond):
+ return out
+
+ # sigma = extra_options.get("sigmas", None)
+ # if sigma is not None and len(sigma) > 0:
+ # sigma = sigma[0].item()
+ # if sigma > sigma_start or sigma < sigma_end:
+ # return out
+
+ img_slice = extra_options.get("img_slice", None)
+
+ if img_slice is not None:
+ orig_out = out
+ out = out[:, img_slice[0]:img_slice[1]] # only apply on img part
+
+ batch_size = out.shape[0]
+ half_size = batch_size // len(cond_or_uncond)
+
+ ind_neg = cond_or_uncond.index(1)
+ ind_pos = cond_or_uncond.index(0)
+ z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
+ z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]
+
+ guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)
+
+ eps = 1e-6
+ norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
+ norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
+
+ ratio = norm_guided / norm_pos
+ scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio
+
+ guided_normalized = guided * scale_factor
+
+ z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)
+
+ if img_slice is not None:
+ orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final
+ orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
+ return orig_out
+ else:
+ out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final
+ return out
+
+ m.set_model_attn1_output_patch(nag_attention_output_patch)
+ m.disable_model_cfg1_optimization()
+
+ return io.NodeOutput(m)
+
+
+class NagExtension(ComfyExtension):
+ @override
+ async def get_node_list(self) -> list[type[io.ComfyNode]]:
+ return [
+ NAGuidance,
+ ]
+
+
+async def comfy_entrypoint() -> NagExtension:
+ return NagExtension()
diff --git a/nodes.py b/nodes.py
index db5f98408..dff56b79c 100644
--- a/nodes.py
+++ b/nodes.py
@@ -2437,6 +2437,7 @@ async def init_builtin_extra_nodes():
"nodes_color.py",
"nodes_toolkit.py",
"nodes_replacements.py",
+ "nodes_nag.py",
]
import_failed = []
From c39653163d77161b2df2d57419129a4d6d081aa1 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Mon, 16 Feb 2026 21:29:20 -0800
Subject: [PATCH 48/85] Fix anima preprocess text embeds not using right
inference dtype. (#12501)
---
comfy/model_base.py | 18 ++++++++++--------
1 file changed, 10 insertions(+), 8 deletions(-)
diff --git a/comfy/model_base.py b/comfy/model_base.py
index 4a74cb1ce..9dcef8741 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn
- dtype = self.get_dtype()
-
- if self.manual_cast_dtype is not None:
- dtype = self.manual_cast_dtype
+ dtype = self.get_dtype_inference()
xc = xc.to(dtype)
device = xc.device
@@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
def get_dtype(self):
return self.diffusion_model.dtype
+ def get_dtype_inference(self):
+ dtype = self.get_dtype()
+
+ if self.manual_cast_dtype is not None:
+ dtype = self.manual_cast_dtype
+ return dtype
+
def encode_adm(self, **kwargs):
return None
@@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
- dtype = self.get_dtype()
- if self.manual_cast_dtype is not None:
- dtype = self.manual_cast_dtype
+ dtype = self.get_dtype_inference()
#TODO: this needs to be tweaked
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@@ -1165,7 +1167,7 @@ class Anima(BaseModel):
t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training
- cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
+ cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
From fe52843fe55b92dedaabff684294dd7a115d2204 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 17 Feb 2026 00:39:54 -0500
Subject: [PATCH 49/85] ComfyUI v0.14.0
---
comfyui_version.py | 2 +-
pyproject.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfyui_version.py b/comfyui_version.py
index cf4e89816..8f7f3228e 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.13.0"
+__version__ = "0.14.0"
diff --git a/pyproject.toml b/pyproject.toml
index 9dab9a50c..b132bb9c4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.13.0"
+version = "0.14.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
From 44f85985215b4d819665e4cec84c00ef87aa9a7a Mon Sep 17 00:00:00 2001
From: chaObserv <154517000+chaObserv@users.noreply.github.com>
Date: Tue, 17 Feb 2026 23:56:44 +0800
Subject: [PATCH 50/85] Fix anima LLM adapter forward when manual cast (#12504)
---
comfy/ldm/anima/model.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/comfy/ldm/anima/model.py b/comfy/ldm/anima/model.py
index 6fb51c4a4..6fcf8df90 100644
--- a/comfy/ldm/anima/model.py
+++ b/comfy/ldm/anima/model.py
@@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
- x = self.in_proj(self.embed(target_input_ids))
context = source_hidden_states
+ x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids)
From 5284e6bf69b6e2e856c672595fd413fd505377ee Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Tue, 17 Feb 2026 20:07:14 +0200
Subject: [PATCH 51/85] feat(api-nodes): add "viduq3-turbo" model and
Vidu3StartEnd node; fix the price badges (#12482)
---
comfy_api_nodes/nodes_vidu.py | 226 ++++++++++++++++++++++++++++++++--
1 file changed, 218 insertions(+), 8 deletions(-)
diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py
index 80de14dfe..bbe7ebba2 100644
--- a/comfy_api_nodes/nodes_vidu.py
+++ b/comfy_api_nodes/nodes_vidu.py
@@ -54,6 +54,7 @@ async def execute_task(
response_model=TaskStatusResponse,
status_extractor=lambda r: r.state,
progress_extractor=lambda r: r.progress,
+ price_extractor=lambda r: r.credits * 0.005 if r.credits is not None else None,
max_poll_attempts=max_poll_attempts,
)
if not response.creations:
@@ -1306,6 +1307,36 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
),
],
),
+ IO.DynamicCombo.Option(
+ "viduq3-turbo",
+ [
+ IO.Combo.Input(
+ "aspect_ratio",
+ options=["16:9", "9:16", "3:4", "4:3", "1:1"],
+ tooltip="The aspect ratio of the output video.",
+ ),
+ IO.Combo.Input(
+ "resolution",
+ options=["720p", "1080p"],
+ tooltip="Resolution of the output video.",
+ ),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=1,
+ max=16,
+ step=1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Duration of the output video in seconds.",
+ ),
+ IO.Boolean.Input(
+ "audio",
+ default=False,
+ tooltip="When enabled, outputs video with sound "
+ "(including dialogue and sound effects).",
+ ),
+ ],
+ ),
],
tooltip="Model to use for video generation.",
),
@@ -1334,13 +1365,20 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
- $base := $lookup({"720p": 0.075, "1080p": 0.1}, $res);
- $perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res);
- {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
+ $d := $lookup(widgets, "model.duration");
+ $contains(widgets.model, "turbo")
+ ? (
+ $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
+ : (
+ $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
)
""",
),
@@ -1409,6 +1447,31 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
),
],
),
+ IO.DynamicCombo.Option(
+ "viduq3-turbo",
+ [
+ IO.Combo.Input(
+ "resolution",
+ options=["720p", "1080p"],
+ tooltip="Resolution of the output video.",
+ ),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=1,
+ max=16,
+ step=1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Duration of the output video in seconds.",
+ ),
+ IO.Boolean.Input(
+ "audio",
+ default=False,
+ tooltip="When enabled, outputs video with sound "
+ "(including dialogue and sound effects).",
+ ),
+ ],
+ ),
],
tooltip="Model to use for video generation.",
),
@@ -1442,13 +1505,20 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
- depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
- $base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res);
- $perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res);
- {"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
+ $d := $lookup(widgets, "model.duration");
+ $contains(widgets.model, "turbo")
+ ? (
+ $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
+ : (
+ $rate := $lookup({"720p": 0.15, "1080p": 0.16, "2k": 0.2}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
)
""",
),
@@ -1481,6 +1551,145 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
+class Vidu3StartEndToVideoNode(IO.ComfyNode):
+
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="Vidu3StartEndToVideoNode",
+ display_name="Vidu Q3 Start/End Frame-to-Video Generation",
+ category="api node/video/Vidu",
+ description="Generate a video from a start frame, an end frame, and a prompt.",
+ inputs=[
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "viduq3-pro",
+ [
+ IO.Combo.Input(
+ "resolution",
+ options=["720p", "1080p"],
+ tooltip="Resolution of the output video.",
+ ),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=1,
+ max=16,
+ step=1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Duration of the output video in seconds.",
+ ),
+ IO.Boolean.Input(
+ "audio",
+ default=False,
+ tooltip="When enabled, outputs video with sound "
+ "(including dialogue and sound effects).",
+ ),
+ ],
+ ),
+ IO.DynamicCombo.Option(
+ "viduq3-turbo",
+ [
+ IO.Combo.Input(
+ "resolution",
+ options=["720p", "1080p"],
+ tooltip="Resolution of the output video.",
+ ),
+ IO.Int.Input(
+ "duration",
+ default=5,
+ min=1,
+ max=16,
+ step=1,
+ display_mode=IO.NumberDisplay.slider,
+ tooltip="Duration of the output video in seconds.",
+ ),
+ IO.Boolean.Input(
+ "audio",
+ default=False,
+ tooltip="When enabled, outputs video with sound "
+ "(including dialogue and sound effects).",
+ ),
+ ],
+ ),
+ ],
+ tooltip="Model to use for video generation.",
+ ),
+ IO.Image.Input("first_frame"),
+ IO.Image.Input("end_frame"),
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="Prompt description (max 2000 characters).",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=1,
+ min=0,
+ max=2147483647,
+ step=1,
+ display_mode=IO.NumberDisplay.number,
+ control_after_generate=True,
+ ),
+ ],
+ outputs=[
+ IO.Video.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
+ expr="""
+ (
+ $res := $lookup(widgets, "model.resolution");
+ $d := $lookup(widgets, "model.duration");
+ $contains(widgets.model, "turbo")
+ ? (
+ $rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
+ : (
+ $rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
+ {"type":"usd","usd": $rate * $d}
+ )
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ model: dict,
+ first_frame: Input.Image,
+ end_frame: Input.Image,
+ prompt: str,
+ seed: int,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, max_length=2000)
+ validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
+ payload = TaskCreationRequest(
+ model=model["model"],
+ prompt=prompt,
+ duration=model["duration"],
+ seed=seed,
+ resolution=model["resolution"],
+ audio=model["audio"],
+ images=[
+ (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
+ for frame in (first_frame, end_frame)
+ ],
+ )
+ results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
+ return IO.NodeOutput(await download_url_to_video_output(results[0].url))
+
+
class ViduExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -1497,6 +1706,7 @@ class ViduExtension(ComfyExtension):
ViduMultiFrameVideoNode,
Vidu3TextToVideoNode,
Vidu3ImageToVideoNode,
+ Vidu3StartEndToVideoNode,
]
From 262abf437b0666f3d00d1f335a526073503e59e4 Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Tue, 17 Feb 2026 20:25:44 +0200
Subject: [PATCH 52/85] feat(api-nodes): add Recraft V4 nodes (#12502)
---
comfy_api_nodes/apis/recraft.py | 45 ++++-
comfy_api_nodes/nodes_recraft.py | 272 +++++++++++++++++++++++++++++--
2 files changed, 297 insertions(+), 20 deletions(-)
diff --git a/comfy_api_nodes/apis/recraft.py b/comfy_api_nodes/apis/recraft.py
index 0bd7d23b3..78ededd94 100644
--- a/comfy_api_nodes/apis/recraft.py
+++ b/comfy_api_nodes/apis/recraft.py
@@ -198,11 +198,6 @@ dict_recraft_substyles_v3 = {
}
-class RecraftModel(str, Enum):
- recraftv3 = 'recraftv3'
- recraftv2 = 'recraftv2'
-
-
class RecraftImageSize(str, Enum):
res_1024x1024 = '1024x1024'
res_1365x1024 = '1365x1024'
@@ -221,6 +216,41 @@ class RecraftImageSize(str, Enum):
res_1707x1024 = '1707x1024'
+RECRAFT_V4_SIZES = [
+ "1024x1024",
+ "1536x768",
+ "768x1536",
+ "1280x832",
+ "832x1280",
+ "1216x896",
+ "896x1216",
+ "1152x896",
+ "896x1152",
+ "832x1344",
+ "1280x896",
+ "896x1280",
+ "1344x768",
+ "768x1344",
+]
+
+RECRAFT_V4_PRO_SIZES = [
+ "2048x2048",
+ "3072x1536",
+ "1536x3072",
+ "2560x1664",
+ "1664x2560",
+ "2432x1792",
+ "1792x2432",
+ "2304x1792",
+ "1792x2304",
+ "1664x2688",
+ "1434x1024",
+ "1024x1434",
+ "2560x1792",
+ "1792x2560",
+]
+
+
class RecraftColorObject(BaseModel):
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
@@ -234,17 +264,16 @@ class RecraftControlsObject(BaseModel):
class RecraftImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The text prompt describing the image to generate')
- size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
+ size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
n: int = Field(..., description='The number of images to generate')
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
- model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
+ model: str = Field(...)
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
random_seed: int | None = Field(None, description="Seed for video generation")
- # text_layout
class RecraftReturnedObject(BaseModel):
diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py
index 3a1f32263..773cb7dbe 100644
--- a/comfy_api_nodes/nodes_recraft.py
+++ b/comfy_api_nodes/nodes_recraft.py
@@ -1,5 +1,4 @@
from io import BytesIO
-from typing import Optional, Union
import aiohttp
import torch
@@ -9,6 +8,8 @@ from typing_extensions import override
from comfy.utils import ProgressBar
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.recraft import (
+ RECRAFT_V4_PRO_SIZES,
+ RECRAFT_V4_SIZES,
RecraftColor,
RecraftColorChain,
RecraftControls,
@@ -18,7 +19,6 @@ from comfy_api_nodes.apis.recraft import (
RecraftImageGenerationResponse,
RecraftImageSize,
RecraftIO,
- RecraftModel,
RecraftStyle,
RecraftStyleV3,
get_v3_substyles,
@@ -39,7 +39,7 @@ async def handle_recraft_file_request(
cls: type[IO.ComfyNode],
image: torch.Tensor,
path: str,
- mask: Optional[torch.Tensor] = None,
+ mask: torch.Tensor | None = None,
total_pixels: int = 4096 * 4096,
timeout: int = 1024,
request=None,
@@ -73,11 +73,11 @@ async def handle_recraft_file_request(
def recraft_multipart_parser(
data,
parent_key=None,
- formatter: Optional[type[callable]] = None,
- converted_to_check: Optional[list[list]] = None,
+ formatter: type[callable] | None = None,
+ converted_to_check: list[list] | None = None,
is_list: bool = False,
return_mode: str = "formdata", # "dict" | "formdata"
-) -> Union[dict, aiohttp.FormData]:
+) -> dict | aiohttp.FormData:
"""
Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
@@ -309,7 +309,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
node_id="RecraftStyleV3InfiniteStyleLibrary",
display_name="Recraft Style - Infinite Style Library",
category="api node/image/Recraft",
- description="Select style based on preexisting UUID from Recraft's Infinite Style Library.",
+ description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
inputs=[
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
],
@@ -485,7 +485,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
data=RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
- model=RecraftModel.recraftv3,
+ model="recraftv3",
size=size,
n=n,
style=recraft_style.style,
@@ -598,7 +598,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
request = RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
- model=RecraftModel.recraftv3,
+ model="recraftv3",
n=n,
strength=round(strength, 2),
style=recraft_style.style,
@@ -698,7 +698,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
request = RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
- model=RecraftModel.recraftv3,
+ model="recraftv3",
n=n,
style=recraft_style.style,
substyle=recraft_style.substyle,
@@ -810,7 +810,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
data=RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
- model=RecraftModel.recraftv3,
+ model="recraftv3",
size=size,
n=n,
style=recraft_style.style,
@@ -933,7 +933,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
request = RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt,
- model=RecraftModel.recraftv3,
+ model="recraftv3",
n=n,
style=recraft_style.style,
substyle=recraft_style.substyle,
@@ -1078,6 +1078,252 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
)
+class RecraftV4TextToImageNode(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="RecraftV4TextToImageNode",
+ display_name="Recraft V4 Text to Image",
+ category="api node/image/Recraft",
+ description="Generates images using Recraft V4 or V4 Pro models.",
+ inputs=[
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="Prompt for the image generation. Maximum 10,000 characters.",
+ ),
+ IO.String.Input(
+ "negative_prompt",
+ multiline=True,
+ tooltip="An optional text description of undesired elements on an image.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "recraftv4",
+ [
+ IO.Combo.Input(
+ "size",
+ options=RECRAFT_V4_SIZES,
+ default="1024x1024",
+ tooltip="The size of the generated image.",
+ ),
+ ],
+ ),
+ IO.DynamicCombo.Option(
+ "recraftv4_pro",
+ [
+ IO.Combo.Input(
+ "size",
+ options=RECRAFT_V4_PRO_SIZES,
+ default="2048x2048",
+ tooltip="The size of the generated image.",
+ ),
+ ],
+ ),
+ ],
+ tooltip="The model to use for generation.",
+ ),
+ IO.Int.Input(
+ "n",
+ default=1,
+ min=1,
+ max=6,
+ tooltip="The number of images to generate.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=0xFFFFFFFFFFFFFFFF,
+ control_after_generate=True,
+ tooltip="Seed to determine if node should re-run; "
+ "actual results are nondeterministic regardless of seed.",
+ ),
+ IO.Custom(RecraftIO.CONTROLS).Input(
+ "recraft_controls",
+ tooltip="Optional additional controls over the generation via the Recraft Controls node.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.Image.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
+ expr="""
+ (
+ $prices := {"recraftv4": 0.04, "recraftv4_pro": 0.25};
+ {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ negative_prompt: str,
+ model: dict,
+ n: int,
+ seed: int,
+ recraft_controls: RecraftControls | None = None,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
+ response_model=RecraftImageGenerationResponse,
+ data=RecraftImageGenerationRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ model=model["model"],
+ size=model["size"],
+ n=n,
+ controls=recraft_controls.create_api_model() if recraft_controls else None,
+ ),
+ max_retries=1,
+ )
+ images = []
+ for data in response.data:
+ with handle_recraft_image_output():
+ image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024))
+ if len(image.shape) < 4:
+ image = image.unsqueeze(0)
+ images.append(image)
+ return IO.NodeOutput(torch.cat(images, dim=0))
+
+
+class RecraftV4TextToVectorNode(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="RecraftV4TextToVectorNode",
+ display_name="Recraft V4 Text to Vector",
+ category="api node/image/Recraft",
+ description="Generates SVG using Recraft V4 or V4 Pro models.",
+ inputs=[
+ IO.String.Input(
+ "prompt",
+ multiline=True,
+ tooltip="Prompt for the image generation. Maximum 10,000 characters.",
+ ),
+ IO.String.Input(
+ "negative_prompt",
+ multiline=True,
+ tooltip="An optional text description of undesired elements on an image.",
+ ),
+ IO.DynamicCombo.Input(
+ "model",
+ options=[
+ IO.DynamicCombo.Option(
+ "recraftv4",
+ [
+ IO.Combo.Input(
+ "size",
+ options=RECRAFT_V4_SIZES,
+ default="1024x1024",
+ tooltip="The size of the generated image.",
+ ),
+ ],
+ ),
+ IO.DynamicCombo.Option(
+ "recraftv4_pro",
+ [
+ IO.Combo.Input(
+ "size",
+ options=RECRAFT_V4_PRO_SIZES,
+ default="2048x2048",
+ tooltip="The size of the generated image.",
+ ),
+ ],
+ ),
+ ],
+ tooltip="The model to use for generation.",
+ ),
+ IO.Int.Input(
+ "n",
+ default=1,
+ min=1,
+ max=6,
+ tooltip="The number of images to generate.",
+ ),
+ IO.Int.Input(
+ "seed",
+ default=0,
+ min=0,
+ max=0xFFFFFFFFFFFFFFFF,
+ control_after_generate=True,
+ tooltip="Seed to determine if node should re-run; "
+ "actual results are nondeterministic regardless of seed.",
+ ),
+ IO.Custom(RecraftIO.CONTROLS).Input(
+ "recraft_controls",
+ tooltip="Optional additional controls over the generation via the Recraft Controls node.",
+ optional=True,
+ ),
+ ],
+ outputs=[
+ IO.SVG.Output(),
+ ],
+ hidden=[
+ IO.Hidden.auth_token_comfy_org,
+ IO.Hidden.api_key_comfy_org,
+ IO.Hidden.unique_id,
+ ],
+ is_api_node=True,
+ price_badge=IO.PriceBadge(
+ depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
+ expr="""
+ (
+ $prices := {"recraftv4": 0.08, "recraftv4_pro": 0.30};
+ {"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
+ )
+ """,
+ ),
+ )
+
+ @classmethod
+ async def execute(
+ cls,
+ prompt: str,
+ negative_prompt: str,
+ model: dict,
+ n: int,
+ seed: int,
+ recraft_controls: RecraftControls | None = None,
+ ) -> IO.NodeOutput:
+ validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
+ response = await sync_op(
+ cls,
+ ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
+ response_model=RecraftImageGenerationResponse,
+ data=RecraftImageGenerationRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ model=model["model"],
+ size=model["size"],
+ n=n,
+ style="vector_illustration",
+ substyle=None,
+ controls=recraft_controls.create_api_model() if recraft_controls else None,
+ ),
+ max_retries=1,
+ )
+ svg_data = []
+ for data in response.data:
+ svg_data.append(await download_url_as_bytesio(data.url, timeout=1024))
+ return IO.NodeOutput(SVG(svg_data))
+
+
class RecraftExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -1098,6 +1344,8 @@ class RecraftExtension(ComfyExtension):
RecraftCreateStyleNode,
RecraftColorRGBNode,
RecraftControlsNode,
+ RecraftV4TextToImageNode,
+ RecraftV4TextToVectorNode,
]
From 73c3f869737bbb1035f6b72b2e1068a1a5642764 Mon Sep 17 00:00:00 2001
From: ComfyUI Wiki
Date: Wed, 18 Feb 2026 02:25:55 +0800
Subject: [PATCH 53/85] chore: update workflow templates to v0.8.43 (#12507)
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 0930bbbb8..881d6bd58 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
comfyui-frontend-package==1.38.14
-comfyui-workflow-templates==0.8.42
+comfyui-workflow-templates==0.8.43
comfyui-embedded-docs==0.4.1
torch
torchsde
From 19236edfa4d2f66070d66a6b3aee592c9c2ad574 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 17 Feb 2026 13:28:06 -0500
Subject: [PATCH 54/85] ComfyUI v0.14.1
---
comfyui_version.py | 2 +-
pyproject.toml | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/comfyui_version.py b/comfyui_version.py
index 8f7f3228e..f24c15cc5 100644
--- a/comfyui_version.py
+++ b/comfyui_version.py
@@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
-__version__ = "0.14.0"
+__version__ = "0.14.1"
diff --git a/pyproject.toml b/pyproject.toml
index b132bb9c4..51c3d224d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
-version = "0.14.0"
+version = "0.14.1"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"
From 58dcc97dcfadc548ac8d8d5e80741ddfb807d213 Mon Sep 17 00:00:00 2001
From: rattus <46076784+rattus128@users.noreply.github.com>
Date: Tue, 17 Feb 2026 12:32:27 -0800
Subject: [PATCH 55/85] ops: limit return of requants (#12506)
This check was far too broad and the dtype is not a reliable indicator
of wanting the requant (as QT returns the compute dtype as the dtype).
So explictly plumb whether fp8mm wants the requant or not.
---
comfy/ops.py | 17 ++++++++---------
1 file changed, 8 insertions(+), 9 deletions(-)
diff --git a/comfy/ops.py b/comfy/ops.py
index 026062f56..a6c642795 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -79,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
-def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
+def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
offload_stream = None
xfer_dest = None
@@ -170,10 +170,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and
- (orig.dtype == dtype and len(fns) == 0 or update_weight)):
+ (want_requant and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
- if orig.dtype == dtype and len(fns) == 0:
+ if want_requant and len(fns) == 0:
#The layer actually wants our freshly saved QT
x = y
elif update_weight:
@@ -194,7 +194,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
return weight, bias, (offload_stream, device if signature is not None else None, None)
-def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance.
@@ -212,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
- return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
+ return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@@ -850,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias)
- def forward_comfy_cast_weights(self, input, compute_dtype=None):
- weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
+ def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
+ weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
@@ -881,8 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
-
- output = self.forward_comfy_cast_weights(input, compute_dtype)
+ output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
# Reshape output back to 3D if input was 3D
if reshaped_3d:
From 6c14f129af4fd94c4197644e6950bddbba0c9e51 Mon Sep 17 00:00:00 2001
From: Comfy Org PR Bot
Date: Wed, 18 Feb 2026 06:41:34 +0900
Subject: [PATCH 56/85] Bump comfyui-frontend-package to 1.39.14 (#12494)
* Bump comfyui-frontend-package to 1.39.13
* Update requirements.txt
---------
Co-authored-by: Christian Byrne
---
requirements.txt | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/requirements.txt b/requirements.txt
index 881d6bd58..807fea5e0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,4 @@
-comfyui-frontend-package==1.38.14
+comfyui-frontend-package==1.39.14
comfyui-workflow-templates==0.8.43
comfyui-embedded-docs==0.4.1
torch
From 8ad38d2073b019a204f730182dcf5456fb260858 Mon Sep 17 00:00:00 2001
From: Terry Jia
Date: Tue, 17 Feb 2026 20:13:39 -0500
Subject: [PATCH 57/85] BBox widget (#11594)
* Boundingbox widget
* code improve
---------
Co-authored-by: Jedrzej Kosinski
Co-authored-by: Christian Byrne
---
comfy_api/latest/_io.py | 25 ++++++++++++++++
comfy_extras/nodes_images.py | 56 +++++++++++++++++++++++++++++++++++-
2 files changed, 80 insertions(+), 1 deletion(-)
diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py
index d18330d0b..312681249 100644
--- a/comfy_api/latest/_io.py
+++ b/comfy_api/latest/_io.py
@@ -1209,6 +1209,30 @@ class Color(ComfyTypeIO):
def as_dict(self):
return super().as_dict()
+@comfytype(io_type="BOUNDING_BOX")
+class BoundingBox(ComfyTypeIO):
+ class BoundingBoxDict(TypedDict):
+ x: int
+ y: int
+ width: int
+ height: int
+ Type = BoundingBoxDict
+
+ class Input(WidgetInput):
+ def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
+ socketless: bool=True, default: dict=None, component: str=None):
+ super().__init__(id, display_name, optional, tooltip, None, default, socketless)
+ self.component = component
+ if default is None:
+ self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
+
+ def as_dict(self):
+ d = super().as_dict()
+ if self.component:
+ d["component"] = self.component
+ return d
+
+
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@@ -2190,5 +2214,6 @@ __all__ = [
"ImageCompare",
"PriceBadgeDepends",
"PriceBadge",
+ "BoundingBox",
"NodeReplace",
]
diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py
index cb4fb24a1..23419a65d 100644
--- a/comfy_extras/nodes_images.py
+++ b/comfy_extras/nodes_images.py
@@ -23,8 +23,9 @@ class ImageCrop(IO.ComfyNode):
return IO.Schema(
node_id="ImageCrop",
search_aliases=["trim"],
- display_name="Image Crop",
+ display_name="Image Crop (Deprecated)",
category="image/transform",
+ is_deprecated=True,
inputs=[
IO.Image.Input("image"),
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
@@ -47,6 +48,57 @@ class ImageCrop(IO.ComfyNode):
crop = execute # TODO: remove
+class ImageCropV2(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="ImageCropV2",
+ search_aliases=["trim"],
+ display_name="Image Crop",
+ category="image/transform",
+ inputs=[
+ IO.Image.Input("image"),
+ IO.BoundingBox.Input("crop_region", component="ImageCrop"),
+ ],
+ outputs=[IO.Image.Output()],
+ )
+
+ @classmethod
+ def execute(cls, image, crop_region) -> IO.NodeOutput:
+ x = crop_region.get("x", 0)
+ y = crop_region.get("y", 0)
+ width = crop_region.get("width", 512)
+ height = crop_region.get("height", 512)
+
+ x = min(x, image.shape[2] - 1)
+ y = min(y, image.shape[1] - 1)
+ to_x = width + x
+ to_y = height + y
+ img = image[:,y:to_y, x:to_x, :]
+ return IO.NodeOutput(img, ui=UI.PreviewImage(img))
+
+
+class BoundingBox(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="PrimitiveBoundingBox",
+ display_name="Bounding Box",
+ category="utils/primitive",
+ inputs=[
+ IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
+ IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),
+ IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION),
+ IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION),
+ ],
+ outputs=[IO.BoundingBox.Output()],
+ )
+
+ @classmethod
+ def execute(cls, x, y, width, height) -> IO.NodeOutput:
+ return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height})
+
+
class RepeatImageBatch(IO.ComfyNode):
@classmethod
def define_schema(cls):
@@ -632,6 +684,8 @@ class ImagesExtension(ComfyExtension):
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ImageCrop,
+ ImageCropV2,
+ BoundingBox,
RepeatImageBatch,
ImageFromBatch,
ImageAddNoise,
From 83dd65f23ae78186df5be7f579af5c0cdb61f0f9 Mon Sep 17 00:00:00 2001
From: Hunter
Date: Wed, 18 Feb 2026 00:03:54 -0500
Subject: [PATCH 58/85] fix: use glob matching for Gemini image MIME types
(#12511)
gemini-3-pro-image-preview nondeterministically returns image/jpeg
instead of image/png. get_image_from_response() hardcoded
get_parts_by_type(response, "image/png"), silently dropping JPEG
responses and falling back to torch.zeros (all-black output).
Add _mime_matches() helper using fnmatch for glob-style MIME matching.
Change get_image_from_response() to request "image/*" so any image
format returned by the API is correctly captured.
---
comfy_api_nodes/nodes_gemini.py | 14 +++++++++++---
1 file changed, 11 insertions(+), 3 deletions(-)
diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py
index 3b31caa7b..5287a777a 100644
--- a/comfy_api_nodes/nodes_gemini.py
+++ b/comfy_api_nodes/nodes_gemini.py
@@ -6,6 +6,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
import base64
import os
from enum import Enum
+from fnmatch import fnmatch
from io import BytesIO
from typing import Literal
@@ -119,6 +120,13 @@ async def create_image_parts(
return image_parts
+def _mime_matches(mime: GeminiMimeType | None, pattern: str) -> bool:
+ """Check if a MIME type matches a pattern. Supports fnmatch globs (e.g. 'image/*')."""
+ if mime is None:
+ return False
+ return fnmatch(mime.value, pattern)
+
+
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
"""
Filter response parts by their type.
@@ -151,9 +159,9 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
for part in candidate.content.parts:
if part_type == "text" and part.text:
parts.append(part)
- elif part.inlineData and part.inlineData.mimeType == part_type:
+ elif part.inlineData and _mime_matches(part.inlineData.mimeType, part_type):
parts.append(part)
- elif part.fileData and part.fileData.mimeType == part_type:
+ elif part.fileData and _mime_matches(part.fileData.mimeType, part_type):
parts.append(part)
if not parts and blocked_reasons:
@@ -178,7 +186,7 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
- parts = get_parts_by_type(response, "image/png")
+ parts = get_parts_by_type(response, "image/*")
for part in parts:
if part.inlineData:
image_data = base64.b64decode(part.inlineData.data)
From 239ddd332724c63934bf517cfc6d0026214d8aee Mon Sep 17 00:00:00 2001
From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com>
Date: Wed, 18 Feb 2026 09:15:23 +0200
Subject: [PATCH 59/85] fix(api-nodes): add price badge for Rodin Gen-2 node
(#12512)
---
comfy_api_nodes/nodes_rodin.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py
index f9cff121f..9c1adaa51 100644
--- a/comfy_api_nodes/nodes_rodin.py
+++ b/comfy_api_nodes/nodes_rodin.py
@@ -505,6 +505,9 @@ class Rodin3D_Gen2(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
+ price_badge=IO.PriceBadge(
+ expr="""{"type":"usd","usd":0.4}""",
+ ),
)
@classmethod
From f262444dd4818b6acdbc1350856679dd6245f7f5 Mon Sep 17 00:00:00 2001
From: comfyanonymous <121283862+comfyanonymous@users.noreply.github.com>
Date: Wed, 18 Feb 2026 15:36:35 -0800
Subject: [PATCH 60/85] Add simple 3 band equalizer node for audio. (#12519)
---
comfy_extras/nodes_audio.py | 62 +++++++++++++++++++++++++++++++++++++
1 file changed, 62 insertions(+)
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index b63dd8e97..7e74169f2 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -698,6 +698,67 @@ class EmptyAudio(IO.ComfyNode):
create_empty_audio = execute # TODO: remove
+class AudioEqualizer3Band(IO.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ return IO.Schema(
+ node_id="AudioEqualizer3Band",
+ search_aliases=["eq", "bass boost", "treble boost", "equalizer"],
+ display_name="Audio Equalizer (3-Band)",
+ category="audio",
+ is_experimental=True,
+ inputs=[
+ IO.Audio.Input("audio"),
+ IO.Float.Input("low_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Low frequencies (Bass)"),
+ IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"),
+ IO.Float.Input("mid_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Mid frequencies"),
+ IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"),
+ IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"),
+ IO.Float.Input("high_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for High frequencies (Treble)"),
+ IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"),
+ ],
+ outputs=[IO.Audio.Output()],
+ )
+
+ @classmethod
+ def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
+ waveform = audio["waveform"]
+ sample_rate = audio["sample_rate"]
+ eq_waveform = waveform.clone()
+
+ # 1. Apply Low Shelf (Bass)
+ if low_gain_dB != 0:
+ eq_waveform = torchaudio.functional.bass_biquad(
+ eq_waveform,
+ sample_rate,
+ gain=low_gain_dB,
+ central_freq=float(low_freq),
+ Q=0.707
+ )
+
+ # 2. Apply Peaking EQ (Mids)
+ if mid_gain_dB != 0:
+ eq_waveform = torchaudio.functional.equalizer_biquad(
+ eq_waveform,
+ sample_rate,
+ center_freq=float(mid_freq),
+ gain=mid_gain_dB,
+ Q=mid_q
+ )
+
+ # 3. Apply High Shelf (Treble)
+ if high_gain_dB != 0:
+ eq_waveform = torchaudio.functional.treble_biquad(
+ eq_waveform,
+ sample_rate,
+ gain=high_gain_dB,
+ central_freq=float(high_freq),
+ Q=0.707
+ )
+
+ return IO.NodeOutput({"waveform": eq_waveform, "sample_rate": sample_rate})
+
+
class AudioExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@@ -720,6 +781,7 @@ class AudioExtension(ComfyExtension):
AudioMerge,
AudioAdjustVolume,
EmptyAudio,
+ AudioEqualizer3Band,
]
async def comfy_entrypoint() -> AudioExtension:
From 6d11cc73549e14a0a31e9ff8c90bfd71b380fe2c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jukka=20Sepp=C3=A4nen?=
<40791699+kijai@users.noreply.github.com>
Date: Thu, 19 Feb 2026 03:49:43 +0200
Subject: [PATCH 61/85] feat: Add basic text generation support with native
models, initially supporting Gemma3 (#12392)
---
comfy/sd.py | 29 +++-
comfy/sd1_clip.py | 18 +++
comfy/text_encoders/llama.py | 148 +++++++++++++++++++-
comfy/text_encoders/lt.py | 92 ++++++++++---
comfy/text_encoders/lumina2.py | 36 ++++-
comfy/text_encoders/spiece_tokenizer.py | 27 +++-
comfy/utils.py | 8 ++
comfy_extras/nodes_textgen.py | 176 ++++++++++++++++++++++++
nodes.py | 1 +
9 files changed, 502 insertions(+), 33 deletions(-)
create mode 100644 comfy_extras/nodes_textgen.py
diff --git a/comfy/sd.py b/comfy/sd.py
index f65e7cadd..164f30803 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -423,6 +423,19 @@ class CLIP:
def get_key_patches(self):
return self.patcher.get_key_patches()
+ def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
+ self.cond_stage_model.reset_clip_options()
+
+ if self.layer_idx is not None:
+ self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
+
+ self.load_model()
+ self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
+ return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
+
+ def decode(self, token_ids, skip_special_tokens=True):
+ return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
@@ -1182,6 +1195,7 @@ class TEModel(Enum):
JINA_CLIP_2 = 19
QWEN3_8B = 20
QWEN3_06B = 21
+ GEMMA_3_4B_VISION = 22
def detect_te_model(sd):
@@ -1210,7 +1224,10 @@ def detect_te_model(sd):
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
- return TEModel.GEMMA_3_4B
+ if 'vision_model.embeddings.patch_embedding.weight' in sd:
+ return TEModel.GEMMA_3_4B_VISION
+ else:
+ return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd:
weight = sd['model.layers.0.self_attn.k_proj.bias']
@@ -1270,6 +1287,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
+ if "lm_head.weight" in clip_data[i]:
+ clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
tokenizer_data = {}
clip_target = EmptyClass()
@@ -1335,6 +1354,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+ elif te_model == TEModel.GEMMA_3_4B_VISION:
+ clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
+ clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
+ tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+ elif te_model == TEModel.GEMMA_3_12B:
+ clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
+ tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index b564d1529..d9d014055 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -308,6 +308,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
+ def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]):
+ if isinstance(tokens, dict):
+ tokens_only = next(iter(tokens.values())) # todo: get this better?
+ else:
+ tokens_only = tokens
+ tokens_only = [[t[0] for t in b] for b in tokens_only]
+ embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
+ return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens)
+
def parse_parentheses(string):
result = []
current_item = ""
@@ -663,6 +672,9 @@ class SDTokenizer:
def state_dict(self):
return {}
+ def decode(self, token_ids, skip_special_tokens=True):
+ return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
+
class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
if name is not None:
@@ -686,6 +698,9 @@ class SD1Tokenizer:
def state_dict(self):
return getattr(self, self.clip).state_dict()
+ def decode(self, token_ids, skip_special_tokens=True):
+ return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens)
+
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
@@ -722,3 +737,6 @@ class SD1ClipModel(torch.nn.Module):
def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd)
+
+ def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
+ return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index 54f3d5595..e5d21fa74 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -3,6 +3,8 @@ import torch.nn as nn
from dataclasses import dataclass
from typing import Optional, Any, Tuple
import math
+from tqdm import tqdm
+import comfy.utils
from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
@@ -313,6 +315,13 @@ class Gemma3_4B_Config:
final_norm: bool = True
lm_head: bool = False
+GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
+
+@dataclass
+class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
+ vision_config = GEMMA3_VISION_CONFIG
+ mm_tokens_per_image = 256
+
@dataclass
class Gemma3_12B_Config:
vocab_size: int = 262208
@@ -336,7 +345,7 @@ class Gemma3_12B_Config:
rope_scale = [8.0, 1.0]
final_norm: bool = True
lm_head: bool = False
- vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
+ vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256
class RMSNorm(nn.Module):
@@ -441,8 +450,10 @@ class Attention(nn.Module):
freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ sliding_window: Optional[int] = None,
):
batch_size, seq_length, _ = hidden_states.shape
+
xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states)
@@ -477,6 +488,11 @@ class Attention(nn.Module):
else:
present_key_value = (xk, xv, index + num_tokens)
+ if sliding_window is not None and xk.shape[2] > sliding_window:
+ xk = xk[:, :, -sliding_window:]
+ xv = xv[:, :, -sliding_window:]
+ attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
+
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@@ -559,10 +575,12 @@ class TransformerBlockGemma2(nn.Module):
optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
+ sliding_window = None
if self.transformer_type == 'gemma3':
if self.sliding_attention:
+ sliding_window = self.sliding_attention
if x.shape[1] > self.sliding_attention:
- sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype)
+ sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention)
if attention_mask is not None:
attention_mask = attention_mask + sliding_mask
@@ -581,6 +599,7 @@ class TransformerBlockGemma2(nn.Module):
freqs_cis=freqs_cis,
optimized_attention=optimized_attention,
past_key_value=past_key_value,
+ sliding_window=sliding_window,
)
x = self.post_attention_layernorm(x)
@@ -765,6 +784,104 @@ class BaseLlama:
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
+class BaseGenerate:
+ def logits(self, x):
+ input = x[:, -1:]
+ if hasattr(self.model, "lm_head"):
+ module = self.model.lm_head
+ else:
+ module = self.model.embed_tokens
+
+ offload_stream = None
+ if module.comfy_cast_weights:
+ weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
+ else:
+ weight = self.model.embed_tokens.weight.to(x)
+
+ x = torch.nn.functional.linear(input, weight, None)
+
+ comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
+ return x
+
+ def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0):
+ device = embeds.device
+ model_config = self.model.config
+
+ if execution_dtype is None:
+ if comfy.model_management.should_use_bf16(device):
+ execution_dtype = torch.bfloat16
+ else:
+ execution_dtype = torch.float32
+ embeds = embeds.to(execution_dtype)
+
+ if embeds.ndim == 2:
+ embeds = embeds.unsqueeze(0)
+
+ past_key_values = [] #kv_cache init
+ max_cache_len = embeds.shape[1] + max_length
+ for x in range(model_config.num_hidden_layers):
+ past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
+ torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
+
+ generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
+
+ generated_token_ids = []
+ pbar = comfy.utils.ProgressBar(max_length)
+
+ # Generation loop
+ for step in tqdm(range(max_length), desc="Generating tokens"):
+ x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
+ logits = self.logits(x)[:, -1]
+ next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
+ token_id = next_token[0].item()
+ generated_token_ids.append(token_id)
+
+ embeds = self.model.embed_tokens(next_token).to(execution_dtype)
+ pbar.update(1)
+
+ if token_id in stop_tokens:
+ break
+
+ return generated_token_ids
+
+ def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
+
+ if not do_sample or temperature == 0.0:
+ return torch.argmax(logits, dim=-1, keepdim=True)
+
+ # Sampling mode
+ if repetition_penalty != 1.0:
+ for i in range(logits.shape[0]):
+ for token_id in set(token_history):
+ logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
+
+ if temperature != 1.0:
+ logits = logits / temperature
+
+ if top_k > 0:
+ indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+ logits[indices_to_remove] = torch.finfo(logits.dtype).min
+
+ if min_p > 0.0:
+ probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
+ top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
+ min_threshold = min_p * top_probs
+ indices_to_remove = probs_before_filter < min_threshold
+ logits[indices_to_remove] = torch.finfo(logits.dtype).min
+
+ if top_p < 1.0:
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+ cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
+ sorted_indices_to_remove = cumulative_probs > top_p
+ sorted_indices_to_remove[..., 0] = False
+ indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
+ indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
+ logits[indices_to_remove] = torch.finfo(logits.dtype).min
+
+ probs = torch.nn.functional.softmax(logits, dim=-1)
+
+ return torch.multinomial(probs, num_samples=1, generator=generator)
+
class BaseQwen3:
def logits(self, x):
input = x[:, -1:]
@@ -871,7 +988,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
+class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Qwen25_7BVLI_Config(**config_dict)
@@ -881,6 +998,9 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
+ # todo: should this be tied or not?
+ #self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
+
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
@@ -923,7 +1043,7 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Gemma3_4B(BaseLlama, torch.nn.Module):
+class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Config(**config_dict)
@@ -932,7 +1052,25 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
-class Gemma3_12B(BaseLlama, torch.nn.Module):
+class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Gemma3_4B_Vision_Config(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+ self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
+ self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
+ self.image_size = config.vision_config["image_size"]
+
+ def preprocess_embed(self, embed, device):
+ if embed["type"] == "image":
+ image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
+ return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
+ return None, None
+
+class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_12B_Config(**config_dict)
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index 9cf87c0b2..82fbacf59 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -6,6 +6,7 @@ import comfy.text_encoders.genmo
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import torch
import comfy.utils
+import math
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -22,40 +23,79 @@ def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
-class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
- def __init__(self, embedding_directory=None, tokenizer_data={}):
- tokenizer = tokenizer_data.get("spiece_model", None)
- super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
-
+class Gemma3_Tokenizer():
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
+ def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs):
+ self.llama_template = "system\nYou are a helpful assistant.\nuser\n{}\nmodel\n"
+ self.llama_template_images = "system\nYou are a helpful assistant.\nuser\n\n{}\n\nmodel\n"
+
+ if image is None:
+ images = []
+ else:
+ samples = image.movedim(-1, 1)
+ total = int(896 * 896)
+
+ scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
+ width = round(samples.shape[3] * scale_by)
+ height = round(samples.shape[2] * scale_by)
+
+ s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
+ images = [s[:, :, :, :3]]
+
+ if text.startswith(''):
+ skip_template = True
+
+ if skip_template:
+ llama_text = text
+ else:
+ if llama_template is None:
+ if len(images) > 0:
+ llama_text = self.llama_template_images.format(text)
+ else:
+ llama_text = self.llama_template.format(text)
+ else:
+ llama_text = llama_template.format(text)
+
+ text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
+
+ if len(images) > 0:
+ embed_count = 0
+ for r in text_tokens:
+ for i, token in enumerate(r):
+ if token[0] == 262144 and embed_count < len(images):
+ r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:]
+ embed_count += 1
+ return text_tokens
+
+class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ tokenizer = tokenizer_data.get("spiece_model", None)
+ special_tokens = {"": 262144, "": 106}
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
+
+
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
+
class Gemma3_12BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
-
+ self.dtypes = set()
+ self.dtypes.add(dtype)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
- def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs):
- text = llama_template.format(text)
- text_tokens = super().tokenize_with_weights(text, return_word_ids)
- embed_count = 0
- for k in text_tokens:
- tt = text_tokens[k]
- for r in tt:
- for i in range(len(r)):
- if r[i][0] == 262144:
- if image_embeds is not None and embed_count < image_embeds.shape[0]:
- r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
- embed_count += 1
- return text_tokens
+ def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
+ tokens_only = [[t[0] for t in b] for b in tokens]
+ embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
+ comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
+ return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is
class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
@@ -112,6 +152,9 @@ class LTXAVTEModel(torch.nn.Module):
return out.to(out_device), pooled
+ def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
+ return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
+
def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd)
@@ -152,3 +195,14 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return LTXAVTEModel_
+
+def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
+ class Gemma3_12BModel_(Gemma3_12BModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["llama_quantization_metadata"] = llama_quantization_metadata
+ if dtype_llama is not None:
+ dtype = dtype_llama
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return Gemma3_12BModel_
diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py
index b29a7cc87..1b731e094 100644
--- a/comfy/text_encoders/lumina2.py
+++ b/comfy/text_encoders/lumina2.py
@@ -1,23 +1,23 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.llama
-
+from comfy.text_encoders.lt import Gemma3_Tokenizer
+import comfy.utils
class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
- super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+ special_tokens = {"": 107}
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
-class Gemma3_4BTokenizer(sd1_clip.SDTokenizer):
+class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
- super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data)
-
- def state_dict(self):
- return {"spiece_model": self.tokenizer.serialize_model()}
+ special_tokens = {"": 262144, "": 106}
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data)
class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -31,6 +31,9 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+ def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
+ return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[107])
+
class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
@@ -40,6 +43,23 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+ def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
+ return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106])
+
+class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+ llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+ if llama_quantization_metadata is not None:
+ model_options = model_options.copy()
+ model_options["quantization_metadata"] = llama_quantization_metadata
+
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+ def process_tokens(self, tokens, device):
+ embeds, _, _, embeds_info = super().process_tokens(tokens, device)
+ comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
+ return embeds
+
class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
@@ -50,6 +70,8 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
model = Gemma2_2BModel
elif model_type == "gemma3_4b":
model = Gemma3_4BModel
+ elif model_type == "gemma3_4b_vision":
+ model = Gemma3_4B_Vision_Model
class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py
index caccb3ca2..099d8d2d9 100644
--- a/comfy/text_encoders/spiece_tokenizer.py
+++ b/comfy/text_encoders/spiece_tokenizer.py
@@ -6,9 +6,10 @@ class SPieceTokenizer:
def from_pretrained(path, **kwargs):
return SPieceTokenizer(path, **kwargs)
- def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
+ def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
self.add_bos = add_bos
self.add_eos = add_eos
+ self.special_tokens = special_tokens
import sentencepiece
if torch.is_tensor(tokenizer_path):
tokenizer_path = tokenizer_path.numpy().tobytes()
@@ -27,8 +28,32 @@ class SPieceTokenizer:
return out
def __call__(self, string):
+ if self.special_tokens is not None:
+ import re
+ special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys())
+ if special_tokens_pattern and re.search(special_tokens_pattern, string):
+ parts = re.split(f'({special_tokens_pattern})', string)
+ result = []
+ for part in parts:
+ if not part:
+ continue
+ if part in self.special_tokens:
+ result.append(self.special_tokens[part])
+ else:
+ encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False)
+ result.extend(encoded)
+ return {"input_ids": result}
+
out = self.tokenizer.encode(string)
return {"input_ids": out}
+ def decode(self, token_ids, skip_special_tokens=False):
+
+ if skip_special_tokens and self.special_tokens:
+ special_token_ids = set(self.special_tokens.values())
+ token_ids = [tid for tid in token_ids if tid not in special_token_ids]
+
+ return self.tokenizer.decode(token_ids)
+
def serialize_model(self):
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))
diff --git a/comfy/utils.py b/comfy/utils.py
index c1ce540b5..17443b4cc 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -1418,3 +1418,11 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res
+
+def normalize_image_embeddings(embeds, embeds_info, scale_factor):
+ """Normalize image embeddings to match text embedding scale"""
+ for info in embeds_info:
+ if info.get("type") == "image":
+ start_idx = info["index"]
+ end_idx = start_idx + info["size"]
+ embeds[:, start_idx:end_idx, :] /= scale_factor
diff --git a/comfy_extras/nodes_textgen.py b/comfy_extras/nodes_textgen.py
new file mode 100644
index 000000000..dd4f6b0d3
--- /dev/null
+++ b/comfy_extras/nodes_textgen.py
@@ -0,0 +1,176 @@
+from comfy_api.latest import ComfyExtension, io
+from typing_extensions import override
+
+class TextGenerate(io.ComfyNode):
+ @classmethod
+ def define_schema(cls):
+ # Define dynamic combo options for sampling mode
+ sampling_options = [
+ io.DynamicCombo.Option(
+ key="on",
+ inputs=[
+ io.Float.Input("temperature", default=0.7, min=0.01, max=2.0, step=0.000001),
+ io.Int.Input("top_k", default=64, min=0, max=1000),
+ io.Float.Input("top_p", default=0.95, min=0.0, max=1.0, step=0.01),
+ io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
+ io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
+ io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
+ ]
+ ),
+ io.DynamicCombo.Option(
+ key="off",
+ inputs=[]
+ ),
+ ]
+
+ return io.Schema(
+ node_id="TextGenerate",
+ category="textgen/",
+ search_aliases=["LLM", "gemma"],
+ inputs=[
+ io.Clip.Input("clip"),
+ io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
+ io.Image.Input("image", optional=True),
+ io.Int.Input("max_length", default=256, min=1, max=2048),
+ io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
+ ],
+ outputs=[
+ io.String.Output(display_name="generated_text"),
+ ],
+ )
+
+ @classmethod
+ def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
+
+ tokens = clip.tokenize(prompt, image=image, skip_template=False)
+
+ # Get sampling parameters from dynamic combo
+ do_sample = sampling_mode.get("sampling_mode") == "on"
+ temperature = sampling_mode.get("temperature", 1.0)
+ top_k = sampling_mode.get("top_k", 50)
+ top_p = sampling_mode.get("top_p", 1.0)
+ min_p = sampling_mode.get("min_p", 0.0)
+ seed = sampling_mode.get("seed", None)
+ repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
+
+ generated_ids = clip.generate(
+ tokens,
+ do_sample=do_sample,
+ max_length=max_length,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ min_p=min_p,
+ repetition_penalty=repetition_penalty,
+ seed=seed
+ )
+
+ generated_text = clip.decode(generated_ids, skip_special_tokens=True)
+ return io.NodeOutput(generated_text)
+
+
+LTX2_T2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
+#### Guidelines
+- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
+ - If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
+ - For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
+- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
+- Maintain chronological flow: use temporal connectors ("as," "then," "while").
+- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
+- Speech (only when requested):
+ - For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
+ - Specify language if not English and accent if relevant.
+- Style: Include visual style at the beginning: "Style: