Compare commits

...

20 Commits

Author SHA1 Message Date
Deep Mehta
bc1a1f8833
Merge 5225f109a6 into 025e6792ee 2026-05-03 08:55:16 -07:00
Jukka Seppänen
025e6792ee
Batch broadcasting in JoinImageWithAlpha node (#13686)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Generate Pydantic Stubs from api.comfy.org / generate-models (push) Has been cancelled
* Batch broadcasting in JoinImageWithAlpha node
2026-05-03 16:30:00 +03:00
Luke Mino-Altherr
867b8d2408
fix: gracefully handle port-in-use error on server startup (#13001)
Catch EADDRINUSE OSError when binding the TCP site and exit with a clear error message instead of an unhandled traceback.
2026-05-03 20:44:20 +08:00
Alexis Rolland
d0f0b15cf5
Update ComfyUI screenshot in README (#13683)
Update ComfyUI screenshot to showcase a more modern workflow
2026-05-03 18:48:58 +08:00
Alexis Rolland
b5bb83c964
Fix issue blend images with alpha (#13615)
Make ImageBlend and ImageCompositeMasked nodes handle images with different channel counts
2026-05-03 18:17:08 +08:00
Alexis Rolland
f6d5068ac0
Update README (#13679)
Some checks failed
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Updated the README to include a new screenshot, improved description and add Ernie Image to supported models.
2026-05-03 12:20:17 +08:00
Jukka Seppänen
be95871adc
feat: Gemma4 text generation support (CORE-30) (#13376)
* initial gemma4 support

* parity with reference implementation

outputs can 100% match transformers with same sdpa flags, checkpoint this and then optimize

* Cleanup, video fixes

* cleanup, enable fused rms norm by default

* update comment

* Cleanup

* Update sd.py

* Various fixes

* Add fp8 scaled embedding support

* small fixes

* Translate think tokens

* Fix image encoder attention mask type

So it works with basic attention

* Handle thinking tokens different only for Gemma4

* Code cleanup

* Update nodes_textgen.py

* Use embed scale class instead of buffer

Slight difference to HF, but technically more accurate and simpler code

* Default to fused rms_norm

* Update gemma4.py
2026-05-02 22:46:15 -04:00
Alexander Piskun
f756d801a1
[Partner Nodes] Topaz Astra 2 model (#13672)
* feat(api-nodes): add Topaz Astra 2 model

Signed-off-by: bigcat88 <bigcat88@icloud.com>

* feat(api-nodes): make Astra 2 the default Topaz upscaler model

Reorder UPSCALER_MODELS_MAP and the upscaler_model dynamic combo so
"Astra 2" appears first, surfacing it as the default selection.

---------

Signed-off-by: bigcat88 <bigcat88@icloud.com>
Co-authored-by: Marwan Mostafa <marawan206@gmail.com>
2026-05-02 19:29:00 -07:00
Daxiong (Lin)
1d23a875ed
chore: update workflow templates to v0.9.68 (#13678) 2026-05-03 10:06:55 +08:00
comfyanonymous
ef6722f6be
Some cleanups to the load image node. (#13677) 2026-05-02 20:34:27 -04:00
Jedrzej Kosinski
5225f109a6
Merge branch 'master' into deepme987/auto-register-node-replacements-json
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-04-28 02:53:37 -07:00
Deep Mehta
e35fe5bc09
Merge branch 'master' into deepme987/auto-register-node-replacements-json
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-04-21 05:00:15 +05:30
Deep Mehta
77054cd49e
Merge branch 'master' into deepme987/auto-register-node-replacements-json
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-04-14 19:34:21 -07:00
Deep Mehta
1cd2730b25
Merge branch 'master' into deepme987/auto-register-node-replacements-json
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-04-06 13:13:42 -07:00
Deep Mehta
d4351f77f8
Merge branch 'master' into deepme987/auto-register-node-replacements-json
Some checks failed
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Python Linting / Run Ruff (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
2026-03-25 22:50:44 -07:00
Deep Mehta
9837dd368a refactor: move load_from_json into NodeReplaceManager
Address review feedback from Kosinkadink:
1. Move JSON loading logic from nodes.py into NodeReplaceManager as
   load_from_json() method for better encapsulation and testability
2. Tests now exercise the real NodeReplaceManager (no duplicated logic)
3. Defer `import nodes` in apply_replacements to avoid torch at import
4. nodes.py call site simplified to one line:
   PromptServer.instance.node_replace_manager.load_from_json(...)
2026-03-25 22:12:23 -07:00
Deep Mehta
62ec9a3238 fix: skip single-file nodes and validate new_node_id
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Two fixes from code review:
1. Only load node_replacements.json from directory-based custom nodes.
   Single-file .py nodes share a parent dir (custom_nodes/), so checking
   there would incorrectly pick up a stray file.
2. Skip entries with missing or empty new_node_id instead of registering
   a replacement pointing to nothing.
2026-03-23 14:47:11 -07:00
Deep Mehta
b20cb7892e
Merge branch 'master' into deepme987/auto-register-node-replacements-json
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
2026-03-18 17:14:08 -07:00
Deep Mehta
b9b24d425b
Merge branch 'master' into deepme987/auto-register-node-replacements-json
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
2026-03-17 20:58:06 -07:00
Deep Mehta
d731cb6ae1 feat: auto-register node replacements from custom node JSON files
Custom node authors can now ship a `node_replacements.json` in their
repo root to define replacements declaratively. During node loading,
ComfyUI reads these files and registers entries via the existing
NodeReplaceManager — no Python registration code needed.

This enables two use cases:
1. Authors deprecate/rename nodes with a migration path for old workflows
2. Authors offer their nodes as drop-in replacements for other packs
2026-03-17 20:57:32 -07:00
21 changed files with 2135 additions and 81 deletions

View File

@ -1,7 +1,7 @@
<div align="center">
# ComfyUI
**The most powerful and modular visual AI engine and application.**
**The most powerful and modular AI engine for content creation.**
[![Website][website-shield]][website-url]
@ -31,10 +31,16 @@
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe)
<img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/36e065e0-bfae-4456-8c7f-8369d5ea48a2" />
<br>
</div>
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS.
ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
- ComfyUI natively supports the latest open-source state of the art models.
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
- It integrates seamlessly into production pipelines with our API endpoints.
## Get Started
@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Ernie Image
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)

View File

@ -1,5 +1,9 @@
from __future__ import annotations
import json
import logging
import os
from aiohttp import web
from typing import TYPE_CHECKING, TypedDict
@ -7,7 +11,6 @@ if TYPE_CHECKING:
from comfy_api.latest._io_public import NodeReplace
from comfy_execution.graph_utils import is_link
import nodes
class NodeStruct(TypedDict):
inputs: dict[str, str | int | float | bool | tuple[str, int]]
@ -43,6 +46,7 @@ class NodeReplaceManager:
return old_node_id in self._replacements
def apply_replacements(self, prompt: dict[str, NodeStruct]):
import nodes
connections: dict[str, list[tuple[str, str, int]]] = {}
need_replacement: set[str] = set()
for node_number, node_struct in prompt.items():
@ -94,6 +98,60 @@ class NodeReplaceManager:
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
previous_input[1] = new_output_idx
def load_from_json(self, module_dir: str, module_name: str, _node_replace_class=None):
"""Load node_replacements.json from a custom node directory and register replacements.
Custom node authors can ship a node_replacements.json file in their repo root
to define node replacements declaratively. The file format matches the output
of NodeReplace.as_dict(), keyed by old_node_id.
Fail-open: all errors are logged and skipped so a malformed file never
prevents the custom node from loading.
"""
replacements_path = os.path.join(module_dir, "node_replacements.json")
if not os.path.isfile(replacements_path):
return
try:
with open(replacements_path, "r", encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
logging.warning(f"node_replacements.json in {module_name} must be a JSON object, skipping.")
return
if _node_replace_class is None:
from comfy_api.latest._io import NodeReplace
_node_replace_class = NodeReplace
count = 0
for old_node_id, replacements in data.items():
if not isinstance(replacements, list):
logging.warning(f"node_replacements.json in {module_name}: value for '{old_node_id}' must be a list, skipping.")
continue
for entry in replacements:
if not isinstance(entry, dict):
continue
new_node_id = entry.get("new_node_id", "")
if not new_node_id:
logging.warning(f"node_replacements.json in {module_name}: entry for '{old_node_id}' missing 'new_node_id', skipping.")
continue
self.register(_node_replace_class(
new_node_id=new_node_id,
old_node_id=entry.get("old_node_id", old_node_id),
old_widget_ids=entry.get("old_widget_ids"),
input_mapping=entry.get("input_mapping"),
output_mapping=entry.get("output_mapping"),
))
count += 1
if count > 0:
logging.info(f"Loaded {count} node replacement(s) from {module_name}/node_replacements.json")
except json.JSONDecodeError as e:
logging.warning(f"Failed to parse node_replacements.json in {module_name}: {e}")
except Exception as e:
logging.warning(f"Failed to load node_replacements.json from {module_name}: {e}")
def as_dict(self):
"""Serialize all replacements to dict."""
return {

View File

@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
if model_management.xformers_enabled():
import xformers
import xformers.ops
@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
n_rep = q.shape[-3] // k.shape[-3]
k = k.repeat_interleave(n_rep, dim=-3)
v = v.repeat_interleave(n_rep, dim=-3)
scale = kwargs.get("scale", dim_head ** -0.5)
h = heads
if skip_reshape:
@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape
dim_head //= heads
if "scale" in kwargs:
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
query = query * (kwargs["scale"] * dim_head ** 0.5)
if skip_reshape:
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape
dim_head //= heads
scale = dim_head ** -0.5
scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape:
q, k, v = map(
@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3:
mask = mask.unsqueeze(1)
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
if SDP_BATCH_LIMIT >= b:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape:
out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT],
attn_mask=m,
dropout_p=0.0, is_causal=False
dropout_p=0.0, is_causal=False, **sdpa_extra
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out

View File

@ -1246,6 +1246,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf)
return self
class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
# Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]
scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None)
if scale is not None:
scale = scale.float()
manually_loaded_keys.append(scale_key)
params = layout_cls.Params(
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.num_embeddings, self.embedding_dim),
)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
else:
if layer_conf is not None:
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight') or self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight
# Optimized path: lookup in fp8, dequantize only the selected rows.
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
if isinstance(qdata, QuantizedTensor):
scale = qdata._params.scale
qdata = qdata._qdata
else:
scale = None
x = torch.nn.functional.embedding(
input, qdata, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
uncast_bias_weight(self, qdata, None, offload_stream)
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
x = x.to(dtype=target_dtype)
if scale is not None and scale != 1.0:
x = x * scale.to(dtype=target_dtype)
return x
# Fallback for non-quantized or weight_function (LoRA) case
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):

View File

@ -3,6 +3,7 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
def rms_norm(x, weight=None, eps=1e-6):
if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)

View File

@ -65,6 +65,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.model_patcher
import comfy.lora
@ -1271,6 +1272,9 @@ class TEModel(Enum):
QWEN35_9B = 26
QWEN35_27B = 27
MINISTRAL_3_3B = 28
GEMMA_4_E4B = 29
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
def detect_te_model(sd):
@ -1296,6 +1300,12 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.59.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_4_31B
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E4B
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E2B
if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd:
@ -1435,6 +1445,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer

File diff suppressed because it is too large Load Diff

View File

@ -521,7 +521,7 @@ class Attention(nn.Module):
else:
present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window:
if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
@ -533,12 +533,12 @@ class Attention(nn.Module):
return self.o_proj(output), present_key_value
class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
super().__init__()
ops = ops or nn
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
if config.mlp_activation == "silu":
self.activation = torch.nn.functional.silu
elif config.mlp_activation == "gelu_pytorch_tanh":
@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
return x, present_key_value
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
class ScaledEmbedding(ops.Embedding):
def forward(self, input_ids, out_dtype=None):
return super().forward(input_ids, out_dtype=out_dtype) * scale
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(
config.vocab_size,
config.hidden_size,
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2
self.normalize_in = True
self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
else:
transformer = TransformerBlock
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
@ -690,15 +691,12 @@ class Llama2_(nn.Module):
self.config.rope_dims,
device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None):
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
if embeds is not None:
x = embeds
else:
x = self.embed_tokens(x, out_dtype=dtype)
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
seq_len = x.shape[1]
past_len = 0
if past_key_values is not None and len(past_key_values) > 0:
@ -850,7 +848,7 @@ class BaseGenerate:
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0):
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
device = embeds.device
if stop_tokens is None:
@ -875,14 +873,16 @@ class BaseGenerate:
pbar = comfy.utils.ProgressBar(max_length)
# Generation loop
current_input_ids = initial_input_ids
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item()
generated_token_ids.append(token_id)
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
current_input_ids = next_token if initial_input_ids is not None else None
pbar.update(1)
if token_id in stop_tokens:

View File

@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):

View File

@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
embeds, _, _, _ = super().process_tokens(tokens, device)
return embeds
class LuminaModel(sd1_clip.SD1ClipModel):

View File

@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
nn.Module.__init__(self)
self.config = config
self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)

View File

@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor

View File

@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Optional
from pydantic import BaseModel, Field
@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel):
grain: Optional[float] = Field(None, description="Grain after AI model processing")
grainSize: Optional[float] = Field(None, description="Size of generated grain")
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.")
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)")
sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness")
realism: float | None = Field(None, description="ast-2 realism control")
class OutputInformationVideo(BaseModel):
@ -90,7 +93,7 @@ class Overrides(BaseModel):
class CreateVideoRequest(BaseModel):
source: CreateVideoRequestSource = Field(...)
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...)
output: OutputInformationVideo = Field(...)
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))

View File

@ -36,11 +36,15 @@ from comfy_api_nodes.util import (
)
UPSCALER_MODELS_MAP = {
"Astra 2": "ast-2",
"Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1",
"Starlight Precise 2.5": "slp-2.5",
}
AST2_MAX_FRAMES = 9000
AST2_MAX_FRAMES_WITH_PROMPT = 450
class TopazImageEnhance(IO.ComfyNode):
@classmethod
@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance",
display_name="Topaz Video Enhance (Legacy)",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
IO.Combo.Input(
"upscaler_model",
options=[
"Starlight (Astra) Fast",
"Starlight (Astra) Creative",
"Starlight Precise 2.5",
],
),
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"upscaler_creativity",
@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode):
IO.Hidden.unique_id,
],
is_api_node=True,
is_deprecated=True,
)
@classmethod
@ -457,12 +469,357 @@ class TopazVideoEnhance(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazVideoEnhanceV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhanceV2",
display_name="Topaz Video Enhance",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.DynamicCombo.Input(
"upscaler_model",
options=[
IO.DynamicCombo.Option(
"Astra 2",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Float.Input(
"creativity",
default=0.5,
min=0.0,
max=1.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Creative strength of the upscale.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional descriptive (not instructive) scene prompt."
f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.",
),
IO.Float.Input(
"sharp",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pre-enhance sharpness: "
"0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.",
advanced=True,
),
IO.Float.Input(
"realism",
default=0.0,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pulls output toward photographic realism."
"Leave at 0 for the model default.",
advanced=True,
),
],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Fast",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Creative",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"creativity",
options=["low", "middle", "high"],
default="low",
tooltip="Creative strength of the upscale.",
),
],
),
IO.DynamicCombo.Option(
"Starlight Precise 2.5",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])],
),
IO.DynamicCombo.Option("Disabled", []),
],
),
IO.DynamicCombo.Input(
"interpolation_model",
options=[
IO.DynamicCombo.Option("Disabled", []),
IO.DynamicCombo.Option(
"apo-8",
[
IO.Int.Input(
"interpolation_frame_rate",
default=60,
min=15,
max=240,
display_mode=IO.NumberDisplay.number,
tooltip="Output frame rate.",
),
IO.Int.Input(
"interpolation_slowmo",
default=1,
min=1,
max=16,
display_mode=IO.NumberDisplay.number,
tooltip="Slow-motion factor applied to the input video. "
"For example, 2 makes the output twice as slow and doubles the duration.",
advanced=True,
),
IO.Boolean.Input(
"interpolation_duplicate",
default=False,
tooltip="Analyze the input for duplicate frames and remove them.",
advanced=True,
),
IO.Float.Input(
"interpolation_duplicate_threshold",
default=0.01,
min=0.001,
max=0.1,
step=0.001,
display_mode=IO.NumberDisplay.number,
tooltip="Detection sensitivity for duplicate frames.",
advanced=True,
),
],
),
],
),
IO.Combo.Input(
"dynamic_compression_level",
options=["Low", "Mid", "High"],
default="Low",
tooltip="CQP level.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=[
"upscaler_model",
"upscaler_model.upscaler_resolution",
"interpolation_model",
]),
expr="""
(
$model := $lookup(widgets, "upscaler_model");
$res := $lookup(widgets, "upscaler_model.upscaler_resolution");
$interp := $lookup(widgets, "interpolation_model");
$is4k := $contains($res, "4k");
$hasInterp := $interp != "disabled";
$rates := {
"starlight (astra) fast": {"hd": 0.43, "uhd": 0.85},
"starlight precise 2.5": {"hd": 0.70, "uhd": 1.54},
"astra 2": {"hd": 1.72, "uhd": 2.85},
"starlight (astra) creative": {"hd": 2.25, "uhd": 3.99}
};
$surcharge := $is4k ? 0.28 : 0.14;
$entry := $lookup($rates, $model);
$base := $is4k ? $entry.uhd : $entry.hd;
$hi := $base + ($hasInterp ? $surcharge : 0);
$model = "disabled"
? {"type":"text","text":"Interpolation only"}
: ($hasInterp
? {"type":"text","text":"~" & $string($base) & "" & $string($hi) & " credits/src frame"}
: {"type":"text","text":"~" & $string($base) & " credits/src frame"})
)
""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
upscaler_model: dict,
interpolation_model: dict,
dynamic_compression_level: str = "Low",
) -> IO.NodeOutput:
upscaler_choice = upscaler_model["upscaler_model"]
interpolation_choice = interpolation_model["interpolation_model"]
if upscaler_choice == "Disabled" and interpolation_choice == "Disabled":
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
validate_container_format_is_mp4(video)
src_width, src_height = video.get_dimensions()
src_frame_rate = int(video.get_frame_rate())
duration_sec = video.get_duration()
src_video_stream = video.get_stream_source()
target_width = src_width
target_height = src_height
target_frame_rate = src_frame_rate
filters = []
if upscaler_choice != "Disabled":
if "1080p" in upscaler_model["upscaler_resolution"]:
target_pixel_p = 1080
max_long_side = 1920
else:
target_pixel_p = 2160
max_long_side = 3840
ar = src_width / src_height
if src_width >= src_height:
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
target_height = target_pixel_p
target_width = int(target_height * ar)
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
if target_width > max_long_side:
target_width = max_long_side
target_height = int(target_width / ar)
else:
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
target_width = target_pixel_p
target_height = int(target_width / ar)
# Check if height exceeds standard bounds
if target_height > max_long_side:
target_height = max_long_side
target_width = int(target_height * ar)
if target_width % 2 != 0:
target_width += 1
if target_height % 2 != 0:
target_height += 1
model_id = UPSCALER_MODELS_MAP[upscaler_choice]
if model_id == "slc-1":
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
isOptimizedMode=True,
)
)
elif model_id == "ast-2":
n_frames = video.get_frame_count()
ast2_prompt = (upscaler_model["prompt"] or "").strip()
if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT:
raise ValueError(
f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames "
f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip."
)
if n_frames > AST2_MAX_FRAMES:
raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.")
realism = upscaler_model["realism"]
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
prompt=(ast2_prompt or None),
sharp=upscaler_model["sharp"],
realism=(realism if realism > 0 else None),
)
)
else:
filters.append(VideoEnhancementFilter(model=model_id))
if interpolation_choice != "Disabled":
target_frame_rate = interpolation_model["interpolation_frame_rate"]
filters.append(
VideoFrameInterpolationFilter(
model=interpolation_choice,
slowmo=interpolation_model["interpolation_slowmo"],
fps=interpolation_model["interpolation_frame_rate"],
duplicate=interpolation_model["interpolation_duplicate"],
duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"],
),
)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
response_model=CreateVideoResponse,
data=CreateVideoRequest(
source=CreateVideoRequestSource(
container="mp4",
size=get_fs_object_size(src_video_stream),
duration=int(duration_sec),
frameCount=video.get_frame_count(),
frameRate=src_frame_rate,
resolution=Resolution(width=src_width, height=src_height),
),
filters=filters,
output=OutputInformationVideo(
resolution=Resolution(width=target_width, height=target_height),
frameRate=target_frame_rate,
audioCodec="AAC",
audioTransfer="Copy",
dynamicCompressionLevel=dynamic_compression_level,
),
),
wait_label="Creating task",
final_label_on_success="Task created",
)
upload_res = await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
method="PATCH",
),
response_model=VideoAcceptResponse,
wait_label="Preparing upload",
final_label_on_success="Upload started",
)
if len(upload_res.urls) > 1:
raise NotImplementedError(
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
)
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
if isinstance(src_video_stream, BytesIO):
src_video_stream.seek(0)
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
else:
with builtins.open(src_video_stream, "rb") as video_file:
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
method="PATCH",
),
response_model=VideoCompleteUploadResponse,
data=VideoCompleteUploadRequest(
uploadResults=[
VideoCompleteUploadRequestPart(
partNum=1,
eTag=upload_etag,
),
],
),
wait_label="Finalizing upload",
final_label_on_success="Upload completed",
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
response_model=VideoStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
poll_interval=10.0,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TopazImageEnhance,
TopazVideoEnhance,
TopazVideoEnhanceV2,
]

View File

@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode):
@classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = min(len(image), len(alpha))
out_images = []
batch_size = max(len(image), len(alpha))
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
for i in range(batch_size):
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
return io.NodeOutput(torch.stack(out_images))
alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
image = comfy.utils.repeat_to_batch_size(image, batch_size)
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
class CompositingExtension(ComfyExtension):

View File

@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode):
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("image", optional=True),
io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
io.Audio.Input("audio", optional=True),
io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking)
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)
# Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on"
@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode):
seed=seed
)
generated_text = clip.decode(generated_ids, skip_special_tokens=True)
generated_text = clip.decode(generated_ids)
return io.NodeOutput(generated_text)
@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput:
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template)
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)
class TextgenExtension(ComfyExtension):

View File

@ -86,6 +86,6 @@ def image_alpha_fix(destination, source):
if destination.shape[-1] < source.shape[-1]:
source = source[...,:destination.shape[-1]]
elif destination.shape[-1] > source.shape[-1]:
destination = torch.nn.functional.pad(destination, (0, 1))
destination[..., -1] = 1.0
source = torch.nn.functional.pad(source, (0, 1))
source[..., -1] = 1.0
return destination, source

View File

@ -1694,26 +1694,27 @@ class LoadImage:
RETURN_TYPES = ("IMAGE", "MASK")
FUNCTION = "load_image"
def load_image(self, image):
image_path = folder_paths.get_annotated_filepath(image)
dtype = comfy.model_management.intermediate_dtype()
device = comfy.model_management.intermediate_device()
components = InputImpl.VideoFromFile(image_path).get_components()
if components.images.shape[0] > 0:
return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu"))
return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device))
# This code is left here to handle animated webp which pyav does not support loading
img = node_helpers.pillow(Image.open, image_path)
output_images = []
output_masks = []
w, h = None, None
dtype = comfy.model_management.intermediate_dtype()
for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")
if len(output_images) == 0:
@ -1728,25 +1729,15 @@ class LoadImage:
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
elif i.mode == 'P' and 'transparency' in i.info:
mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
output_images.append(image.to(dtype=dtype))
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
if img.format == "MPO":
break # ignore all frames except the first one for MPO format
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
if len(output_images) > 1:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]
return (output_image, output_mask)
return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype))
@classmethod
def IS_CHANGED(s, image):
@ -2213,6 +2204,12 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
LOADED_MODULE_DIRS[module_name] = os.path.abspath(module_dir)
# Only load node_replacements.json from directory-based custom nodes (proper packs).
# Single-file .py nodes share a parent dir, so checking there would be incorrect.
if os.path.isdir(module_path):
from server import PromptServer
PromptServer.instance.node_replace_manager.load_from_json(module_dir, module_name)
try:
from comfy_config import config_parser

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.42.15
comfyui-workflow-templates==0.9.66
comfyui-workflow-templates==0.9.68
comfyui-embedded-docs==0.4.4
torch
torchsde

View File

@ -1,3 +1,4 @@
import errno
import os
import sys
import asyncio
@ -1245,7 +1246,13 @@ class PromptServer():
address = addr[0]
port = addr[1]
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
await site.start()
try:
await site.start()
except OSError as e:
if e.errno == errno.EADDRINUSE:
logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.")
raise SystemExit(1)
raise
if not hasattr(self, 'address'):
self.address = address #TODO: remove this

View File

@ -0,0 +1,217 @@
"""Tests for NodeReplaceManager.load_from_json — auto-registration of
node_replacements.json from custom node directories."""
import json
import os
import tempfile
import unittest
from app.node_replace_manager import NodeReplaceManager
class SimpleNodeReplace:
"""Lightweight stand-in for comfy_api.latest._io.NodeReplace (avoids torch import)."""
def __init__(self, new_node_id, old_node_id, old_widget_ids=None,
input_mapping=None, output_mapping=None):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def as_dict(self):
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
"old_widget_ids": self.old_widget_ids,
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
}
class TestLoadFromJson(unittest.TestCase):
"""Test auto-registration of node_replacements.json from custom node directories."""
def setUp(self):
self.tmpdir = tempfile.mkdtemp()
self.manager = NodeReplaceManager()
def _write_json(self, data):
path = os.path.join(self.tmpdir, "node_replacements.json")
with open(path, "w") as f:
json.dump(data, f)
def _load(self):
self.manager.load_from_json(self.tmpdir, "test-node-pack", _node_replace_class=SimpleNodeReplace)
def test_no_file_does_nothing(self):
"""No node_replacements.json — should silently do nothing."""
self._load()
self.assertEqual(self.manager.as_dict(), {})
def test_empty_object(self):
"""Empty {} — should do nothing."""
self._write_json({})
self._load()
self.assertEqual(self.manager.as_dict(), {})
def test_single_replacement(self):
"""Single replacement entry registers correctly."""
self._write_json({
"OldNode": [{
"new_node_id": "NewNode",
"old_node_id": "OldNode",
"input_mapping": [{"new_id": "model", "old_id": "ckpt_name"}],
"output_mapping": [{"new_idx": 0, "old_idx": 0}],
}]
})
self._load()
result = self.manager.as_dict()
self.assertIn("OldNode", result)
self.assertEqual(len(result["OldNode"]), 1)
entry = result["OldNode"][0]
self.assertEqual(entry["new_node_id"], "NewNode")
self.assertEqual(entry["old_node_id"], "OldNode")
self.assertEqual(entry["input_mapping"], [{"new_id": "model", "old_id": "ckpt_name"}])
self.assertEqual(entry["output_mapping"], [{"new_idx": 0, "old_idx": 0}])
def test_multiple_replacements(self):
"""Multiple old_node_ids each with entries."""
self._write_json({
"NodeA": [{"new_node_id": "NodeB", "old_node_id": "NodeA"}],
"NodeC": [{"new_node_id": "NodeD", "old_node_id": "NodeC"}],
})
self._load()
result = self.manager.as_dict()
self.assertEqual(len(result), 2)
self.assertIn("NodeA", result)
self.assertIn("NodeC", result)
def test_multiple_alternatives_for_same_node(self):
"""Multiple replacement options for the same old node."""
self._write_json({
"OldNode": [
{"new_node_id": "AltA", "old_node_id": "OldNode"},
{"new_node_id": "AltB", "old_node_id": "OldNode"},
]
})
self._load()
result = self.manager.as_dict()
self.assertEqual(len(result["OldNode"]), 2)
def test_null_mappings(self):
"""Null input/output mappings (trivial replacement)."""
self._write_json({
"OldNode": [{
"new_node_id": "NewNode",
"old_node_id": "OldNode",
"input_mapping": None,
"output_mapping": None,
}]
})
self._load()
entry = self.manager.as_dict()["OldNode"][0]
self.assertIsNone(entry["input_mapping"])
self.assertIsNone(entry["output_mapping"])
def test_old_node_id_defaults_to_key(self):
"""If old_node_id is missing from entry, uses the dict key."""
self._write_json({
"OldNode": [{"new_node_id": "NewNode"}]
})
self._load()
entry = self.manager.as_dict()["OldNode"][0]
self.assertEqual(entry["old_node_id"], "OldNode")
def test_invalid_json_skips(self):
"""Invalid JSON file — should warn and skip, not crash."""
path = os.path.join(self.tmpdir, "node_replacements.json")
with open(path, "w") as f:
f.write("{invalid json")
self._load()
self.assertEqual(self.manager.as_dict(), {})
def test_non_object_json_skips(self):
"""JSON array instead of object — should warn and skip."""
self._write_json([1, 2, 3])
self._load()
self.assertEqual(self.manager.as_dict(), {})
def test_non_list_value_skips(self):
"""Value is not a list — should warn and skip that key."""
self._write_json({
"OldNode": "not a list",
"GoodNode": [{"new_node_id": "NewNode", "old_node_id": "GoodNode"}],
})
self._load()
result = self.manager.as_dict()
self.assertNotIn("OldNode", result)
self.assertIn("GoodNode", result)
def test_with_old_widget_ids(self):
"""old_widget_ids are passed through."""
self._write_json({
"OldNode": [{
"new_node_id": "NewNode",
"old_node_id": "OldNode",
"old_widget_ids": ["width", "height"],
}]
})
self._load()
entry = self.manager.as_dict()["OldNode"][0]
self.assertEqual(entry["old_widget_ids"], ["width", "height"])
def test_set_value_in_input_mapping(self):
"""input_mapping with set_value entries."""
self._write_json({
"OldNode": [{
"new_node_id": "NewNode",
"old_node_id": "OldNode",
"input_mapping": [
{"new_id": "method", "set_value": "lanczos"},
{"new_id": "size", "old_id": "dimension"},
],
}]
})
self._load()
entry = self.manager.as_dict()["OldNode"][0]
self.assertEqual(len(entry["input_mapping"]), 2)
def test_missing_new_node_id_skipped(self):
"""Entry without new_node_id is skipped."""
self._write_json({
"OldNode": [
{"old_node_id": "OldNode"},
{"new_node_id": "", "old_node_id": "OldNode"},
{"new_node_id": "ValidNew", "old_node_id": "OldNode"},
]
})
self._load()
result = self.manager.as_dict()
self.assertEqual(len(result["OldNode"]), 1)
self.assertEqual(result["OldNode"][0]["new_node_id"], "ValidNew")
def test_non_dict_entry_skipped(self):
"""Non-dict entries in the list are silently skipped."""
self._write_json({
"OldNode": [
"not a dict",
{"new_node_id": "NewNode", "old_node_id": "OldNode"},
]
})
self._load()
result = self.manager.as_dict()
self.assertEqual(len(result["OldNode"]), 1)
def test_has_replacement_after_load(self):
"""Manager reports has_replacement correctly after JSON load."""
self._write_json({
"OldNode": [{"new_node_id": "NewNode", "old_node_id": "OldNode"}],
})
self.assertFalse(self.manager.has_replacement("OldNode"))
self._load()
self.assertTrue(self.manager.has_replacement("OldNode"))
self.assertFalse(self.manager.has_replacement("UnknownNode"))
if __name__ == "__main__":
unittest.main()