Merge branch 'master' into feature/generic-feature-flag-cli

This commit is contained in:
Jedrzej Kosinski 2026-05-04 07:37:45 -07:00 committed by GitHub
commit 0992141135
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 1887 additions and 101 deletions

View File

@ -1,7 +1,7 @@
<div align="center"> <div align="center">
# ComfyUI # ComfyUI
**The most powerful and modular visual AI engine and application.** **The most powerful and modular AI engine for content creation.**
[![Website][website-shield]][website-url] [![Website][website-shield]][website-url]
@ -31,10 +31,16 @@
[github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest [github-downloads-latest-shield]: https://img.shields.io/github/downloads/comfyanonymous/ComfyUI/latest/total?style=flat&label=downloads%40latest
[github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases [github-downloads-link]: https://github.com/comfyanonymous/ComfyUI/releases
![ComfyUI Screenshot](https://github.com/user-attachments/assets/7ccaf2c1-9b72-41ae-9a89-5688c94b7abe) <img width="1590" height="795" alt="ComfyUI Screenshot" src="https://github.com/user-attachments/assets/36e065e0-bfae-4456-8c7f-8369d5ea48a2" />
<br>
</div> </div>
ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. Available on Windows, Linux, and macOS. ComfyUI is the AI creation engine for visual professionals who demand control over every model, every parameter, and every output. Its powerful and modular node graph interface empowers creatives to generate images, videos, 3D models, audio, and more...
- ComfyUI natively supports the latest open-source state of the art models.
- API nodes provide access to the best closed source models such as Nano Banana, Seedance, Hunyuan3D, etc.
- It is available on Windows, Linux, and macOS, locally with our desktop application or on our cloud.
- The most sophisticated workflows can be exposed through a simple UI thanks to App Mode.
- It integrates seamlessly into production pipelines with our API endpoints.
## Get Started ## Get Started
@ -77,6 +83,7 @@ See what ComfyUI can do with the [newer template workflows](https://comfy.org/wo
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/) - [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/) - [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/) - [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
- Ernie Image
- Image Editing Models - Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/) - [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model) - [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)

View File

@ -91,6 +91,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.") parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.") parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
parser.add_argument("--enable-triton-backend", action="store_true", help="ComfyUI will enable the use of Triton backend in comfy-kitchen. Is disabled at launch by default.")
class LatentPreviewMethod(enum.Enum): class LatentPreviewMethod(enum.Enum):
NoPreviews = "none" NoPreviews = "none"

View File

@ -14,6 +14,8 @@ from .sub_quadratic_attention import efficient_dot_product_attention
from comfy import model_management from comfy import model_management
TORCH_HAS_GQA = model_management.torch_version_numeric >= (2, 5)
if model_management.xformers_enabled(): if model_management.xformers_enabled():
import xformers import xformers
import xformers.ops import xformers.ops
@ -150,7 +152,12 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 if kwargs.get("enable_gqa", False) and q.shape[-3] != k.shape[-3]:
n_rep = q.shape[-3] // k.shape[-3]
k = k.repeat_interleave(n_rep, dim=-3)
v = v.repeat_interleave(n_rep, dim=-3)
scale = kwargs.get("scale", dim_head ** -0.5)
h = heads h = heads
if skip_reshape: if skip_reshape:
@ -219,6 +226,10 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape b, _, dim_head = query.shape
dim_head //= heads dim_head //= heads
if "scale" in kwargs:
# Pre-scale query to match requested scale (cancels internal 1/sqrt(dim_head))
query = query * (kwargs["scale"] * dim_head ** 0.5)
if skip_reshape: if skip_reshape:
query = query.reshape(b * heads, -1, dim_head) query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head) value = value.reshape(b * heads, -1, dim_head)
@ -290,7 +301,7 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
b, _, dim_head = q.shape b, _, dim_head = q.shape
dim_head //= heads dim_head //= heads
scale = dim_head ** -0.5 scale = kwargs.get("scale", dim_head ** -0.5)
if skip_reshape: if skip_reshape:
q, k, v = map( q, k, v = map(
@ -500,8 +511,13 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
if mask.ndim == 3: if mask.ndim == 3:
mask = mask.unsqueeze(1) mask = mask.unsqueeze(1)
# Pass through extra SDPA kwargs (scale, enable_gqa) if provided
# enable_gqa requires PyTorch 2.5+; older versions use manual KV expansion above
sdpa_keys = ("scale", "enable_gqa") if TORCH_HAS_GQA else ("scale",)
sdpa_extra = {k: v for k, v in kwargs.items() if k in sdpa_keys}
if SDP_BATCH_LIMIT >= b: if SDP_BATCH_LIMIT >= b:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False, **sdpa_extra)
if not skip_output_reshape: if not skip_output_reshape:
out = ( out = (
out.transpose(1, 2).reshape(b, -1, heads * dim_head) out.transpose(1, 2).reshape(b, -1, heads * dim_head)
@ -519,7 +535,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
k[i : i + SDP_BATCH_LIMIT], k[i : i + SDP_BATCH_LIMIT],
v[i : i + SDP_BATCH_LIMIT], v[i : i + SDP_BATCH_LIMIT],
attn_mask=m, attn_mask=m,
dropout_p=0.0, is_causal=False dropout_p=0.0, is_causal=False, **sdpa_extra
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head) ).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
return out return out

View File

@ -1246,6 +1246,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
self._buffers[key] = fn(buf) self._buffers[key] = fn(buf)
return self return self
class Embedding(manual_cast.Embedding):
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
weight_key = f"{prefix}weight"
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
if layer_conf is not None:
layer_conf = json.loads(layer_conf.numpy().tobytes())
# Only fp8 makes sense for embeddings (per-row dequant via index select).
# Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
quant_format = layer_conf.get("format", None) if layer_conf is not None else None
if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
self.quant_format = quant_format
qconfig = QUANT_ALGOS[quant_format]
layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
weight = state_dict.pop(weight_key)
manually_loaded_keys = [weight_key]
scale_key = f"{prefix}weight_scale"
scale = state_dict.pop(scale_key, None)
if scale is not None:
scale = scale.float()
manually_loaded_keys.append(scale_key)
params = layout_cls.Params(
scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.num_embeddings, self.embedding_dim),
)
self.weight = torch.nn.Parameter(
QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
requires_grad=False)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
for k in manually_loaded_keys:
if k in missing_keys:
missing_keys.remove(k)
else:
if layer_conf is not None:
state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
if destination is not None:
sd = destination
else:
sd = {}
if not hasattr(self, 'weight') or self.weight is None:
return sd
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def forward_comfy_cast_weights(self, input, out_dtype=None):
weight = self.weight
# Optimized path: lookup in fp8, dequantize only the selected rows.
if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
if isinstance(qdata, QuantizedTensor):
scale = qdata._params.scale
qdata = qdata._qdata
else:
scale = None
x = torch.nn.functional.embedding(
input, qdata, self.padding_idx, self.max_norm,
self.norm_type, self.scale_grad_by_freq, self.sparse)
uncast_bias_weight(self, qdata, None, offload_stream)
target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
x = x.to(dtype=target_dtype)
if scale is not None and scale != 1.0:
x = x * scale.to(dtype=target_dtype)
return x
# Fallback for non-quantized or weight_function (LoRA) case
return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)
return MixedPrecisionOps return MixedPrecisionOps
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):

View File

@ -1,6 +1,8 @@
import torch import torch
import logging import logging
from comfy.cli_args import args
try: try:
import comfy_kitchen as ck import comfy_kitchen as ck
from comfy_kitchen.tensor import ( from comfy_kitchen.tensor import (
@ -21,6 +23,14 @@ try:
ck.registry.disable("cuda") ck.registry.disable("cuda")
logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.") logging.warning("WARNING: You need pytorch with cu130 or higher to use optimized CUDA operations.")
if args.enable_triton_backend:
try:
import triton
logging.info("Found triton %s. Enabling comfy-kitchen triton backend.", triton.__version__)
except ImportError as e:
logging.error(f"Failed to import triton, Error: {e}, the comfy-kitchen triton backend will not be available.")
ck.registry.disable("triton")
else:
ck.registry.disable("triton") ck.registry.disable("triton")
for k, v in ck.list_backends().items(): for k, v in ck.list_backends().items():
logging.info(f"Found comfy_kitchen backend {k}: {v}") logging.info(f"Found comfy_kitchen backend {k}: {v}")

View File

@ -3,6 +3,7 @@ import comfy.model_management
RMSNorm = torch.nn.RMSNorm RMSNorm = torch.nn.RMSNorm
# Note: torch's fused F.rms_norm is faster but produces slightly different output than manual implementations (rsqrt/reduction rounding).
def rms_norm(x, weight=None, eps=1e-6): def rms_norm(x, weight=None, eps=1e-6):
if weight is None: if weight is None:
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps) return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)

View File

@ -65,6 +65,7 @@ import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image import comfy.text_encoders.longcat_image
import comfy.text_encoders.qwen35 import comfy.text_encoders.qwen35
import comfy.text_encoders.ernie import comfy.text_encoders.ernie
import comfy.text_encoders.gemma4
import comfy.model_patcher import comfy.model_patcher
import comfy.lora import comfy.lora
@ -1271,6 +1272,9 @@ class TEModel(Enum):
QWEN35_9B = 26 QWEN35_9B = 26
QWEN35_27B = 27 QWEN35_27B = 27
MINISTRAL_3_3B = 28 MINISTRAL_3_3B = 28
GEMMA_4_E4B = 29
GEMMA_4_E2B = 30
GEMMA_4_31B = 31
def detect_te_model(sd): def detect_te_model(sd):
@ -1296,6 +1300,12 @@ def detect_te_model(sd):
return TEModel.BYT5_SMALL_GLYPH return TEModel.BYT5_SMALL_GLYPH
return TEModel.T5_BASE return TEModel.T5_BASE
if 'model.layers.0.post_feedforward_layernorm.weight' in sd: if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
if 'model.layers.59.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_4_31B
if 'model.layers.41.self_attn.q_norm.weight' in sd and 'model.layers.47.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E4B
if 'model.layers.34.self_attn.q_norm.weight' in sd and 'model.layers.41.self_attn.q_norm.weight' not in sd:
return TEModel.GEMMA_4_E2B
if 'model.layers.47.self_attn.q_norm.weight' in sd: if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd: if 'model.layers.0.self_attn.q_norm.weight' in sd:
@ -1435,6 +1445,13 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else: else:
clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
elif te_model in (TEModel.GEMMA_4_E4B, TEModel.GEMMA_4_E2B, TEModel.GEMMA_4_31B):
variant = {TEModel.GEMMA_4_E4B: comfy.text_encoders.gemma4.Gemma4_E4B,
TEModel.GEMMA_4_E2B: comfy.text_encoders.gemma4.Gemma4_E2B,
TEModel.GEMMA_4_31B: comfy.text_encoders.gemma4.Gemma4_31B}[te_model]
clip_target.clip = comfy.text_encoders.gemma4.gemma4_te(**llama_detect(clip_data), model_class=variant)
clip_target.tokenizer = variant.tokenizer
tokenizer_data["tokenizer_json"] = clip_data[0].get("tokenizer_json", None)
elif te_model == TEModel.GEMMA_2_2B: elif te_model == TEModel.GEMMA_2_2B:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data)) clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer

File diff suppressed because it is too large Load Diff

View File

@ -521,7 +521,7 @@ class Attention(nn.Module):
else: else:
present_key_value = (xk, xv, index + num_tokens) present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window: if sliding_window is not None and xk.shape[2] > sliding_window and seq_length == 1:
xk = xk[:, :, -sliding_window:] xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:] xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
@ -533,12 +533,12 @@ class Attention(nn.Module):
return self.o_proj(output), present_key_value return self.o_proj(output), present_key_value
class MLP(nn.Module): class MLP(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None): def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None, intermediate_size=None):
super().__init__() super().__init__()
ops = ops or nn intermediate_size = intermediate_size or config.intermediate_size
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) self.gate_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype) self.up_proj = ops.Linear(config.hidden_size, intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype) self.down_proj = ops.Linear(intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
if config.mlp_activation == "silu": if config.mlp_activation == "silu":
self.activation = torch.nn.functional.silu self.activation = torch.nn.functional.silu
elif config.mlp_activation == "gelu_pytorch_tanh": elif config.mlp_activation == "gelu_pytorch_tanh":
@ -647,24 +647,25 @@ class TransformerBlockGemma2(nn.Module):
return x, present_key_value return x, present_key_value
def _make_scaled_embedding(ops, vocab_size, hidden_size, scale, device, dtype):
class ScaledEmbedding(ops.Embedding):
def forward(self, input_ids, out_dtype=None):
return super().forward(input_ids, out_dtype=out_dtype) * scale
return ScaledEmbedding(vocab_size, hidden_size, device=device, dtype=dtype)
class Llama2_(nn.Module): class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None): def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__() super().__init__()
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.embed_tokens = ops.Embedding(
config.vocab_size,
config.hidden_size,
device=device,
dtype=dtype
)
if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3": if self.config.transformer_type == "gemma2" or self.config.transformer_type == "gemma3":
transformer = TransformerBlockGemma2 transformer = TransformerBlockGemma2
self.normalize_in = True self.embed_tokens = _make_scaled_embedding(ops, config.vocab_size, config.hidden_size, config.hidden_size ** 0.5, device, dtype)
else: else:
transformer = TransformerBlock transformer = TransformerBlock
self.normalize_in = False self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
transformer(config, index=i, device=device, dtype=dtype, ops=ops) transformer(config, index=i, device=device, dtype=dtype, ops=ops)
@ -690,15 +691,12 @@ class Llama2_(nn.Module):
self.config.rope_dims, self.config.rope_dims,
device=device) device=device)
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None): def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[], past_key_values=None, input_ids=None):
if embeds is not None: if embeds is not None:
x = embeds x = embeds
else: else:
x = self.embed_tokens(x, out_dtype=dtype) x = self.embed_tokens(x, out_dtype=dtype)
if self.normalize_in:
x *= self.config.hidden_size ** 0.5
seq_len = x.shape[1] seq_len = x.shape[1]
past_len = 0 past_len = 0
if past_key_values is not None and len(past_key_values) > 0: if past_key_values is not None and len(past_key_values) > 0:
@ -850,7 +848,7 @@ class BaseGenerate:
torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0)) torch.empty([batch, model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
return past_key_values return past_key_values
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0): def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=None, initial_tokens=[], execution_dtype=None, min_tokens=0, presence_penalty=0.0, initial_input_ids=None):
device = embeds.device device = embeds.device
if stop_tokens is None: if stop_tokens is None:
@ -875,14 +873,16 @@ class BaseGenerate:
pbar = comfy.utils.ProgressBar(max_length) pbar = comfy.utils.ProgressBar(max_length)
# Generation loop # Generation loop
current_input_ids = initial_input_ids
for step in tqdm(range(max_length), desc="Generating tokens"): for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values) x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values, input_ids=current_input_ids)
logits = self.logits(x)[:, -1] logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty) next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample, presence_penalty=presence_penalty)
token_id = next_token[0].item() token_id = next_token[0].item()
generated_token_ids.append(token_id) generated_token_ids.append(token_id)
embeds = self.model.embed_tokens(next_token).to(execution_dtype) embeds = self.model.embed_tokens(next_token).to(execution_dtype)
current_input_ids = next_token if initial_input_ids is not None else None
pbar.update(1) pbar.update(1)
if token_id in stop_tokens: if token_id in stop_tokens:

View File

@ -93,8 +93,7 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty): def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, presence_penalty):
tokens_only = [[t[0] for t in b] for b in tokens] tokens_only = [[t[0] for t in b] for b in tokens]
embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device) embeds, _, _, _ = self.process_tokens(tokens_only, self.execution_device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn> return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106], presence_penalty=presence_penalty) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module): class DualLinearProjection(torch.nn.Module):

View File

@ -50,8 +50,7 @@ class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device): def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device) embeds, _, _, _ = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return embeds return embeds
class LuminaModel(sd1_clip.SD1ClipModel): class LuminaModel(sd1_clip.SD1ClipModel):

View File

@ -408,8 +408,6 @@ class Qwen35Transformer(Llama2_):
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = config self.config = config
self.vocab_size = config.vocab_size self.vocab_size = config.vocab_size
self.normalize_in = False
self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype) self.embed_tokens = ops.Embedding(config.vocab_size, config.hidden_size, device=device, dtype=dtype)
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops) Qwen35TransformerBlock(config, index=i, device=device, dtype=dtype, ops=ops)

View File

@ -1446,10 +1446,3 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res memo[obj_id] = res
return res return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor

View File

@ -1,4 +1,4 @@
from typing import Optional, Union from typing import Optional
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -72,8 +72,11 @@ class VideoEnhancementFilter(BaseModel):
grain: Optional[float] = Field(None, description="Grain after AI model processing") grain: Optional[float] = Field(None, description="Grain after AI model processing")
grainSize: Optional[float] = Field(None, description="Size of generated grain") grainSize: Optional[float] = Field(None, description="Size of generated grain")
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video") recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only") creativity: float | str | None = Field(None, description="slc-1/slp-2.5: enum (low/middle/high). ast-2: decimal 0.0-1.0.")
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only") isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
prompt: str | None = Field(None, description="Descriptive scene prompt (ast-2 only)")
sharp: float | None = Field(None, description="ast-2 pre-enhance sharpness")
realism: float | None = Field(None, description="ast-2 realism control")
class OutputInformationVideo(BaseModel): class OutputInformationVideo(BaseModel):
@ -90,7 +93,7 @@ class Overrides(BaseModel):
class CreateVideoRequest(BaseModel): class CreateVideoRequest(BaseModel):
source: CreateVideoRequestSource = Field(...) source: CreateVideoRequestSource = Field(...)
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) filters: list[VideoFrameInterpolationFilter | VideoEnhancementFilter] = Field(...)
output: OutputInformationVideo = Field(...) output: OutputInformationVideo = Field(...)
overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) overrides: Overrides = Field(Overrides(isPaidDiffusion=True))

View File

@ -36,11 +36,15 @@ from comfy_api_nodes.util import (
) )
UPSCALER_MODELS_MAP = { UPSCALER_MODELS_MAP = {
"Astra 2": "ast-2",
"Starlight (Astra) Fast": "slf-1", "Starlight (Astra) Fast": "slf-1",
"Starlight (Astra) Creative": "slc-1", "Starlight (Astra) Creative": "slc-1",
"Starlight Precise 2.5": "slp-2.5", "Starlight Precise 2.5": "slp-2.5",
} }
AST2_MAX_FRAMES = 9000
AST2_MAX_FRAMES_WITH_PROMPT = 450
class TopazImageEnhance(IO.ComfyNode): class TopazImageEnhance(IO.ComfyNode):
@classmethod @classmethod
@ -230,13 +234,20 @@ class TopazVideoEnhance(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="TopazVideoEnhance", node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance", display_name="Topaz Video Enhance (Legacy)",
category="api node/video/Topaz", category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.", description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[ inputs=[
IO.Video.Input("video"), IO.Video.Input("video"),
IO.Boolean.Input("upscaler_enabled", default=True), IO.Boolean.Input("upscaler_enabled", default=True),
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), IO.Combo.Input(
"upscaler_model",
options=[
"Starlight (Astra) Fast",
"Starlight (Astra) Creative",
"Starlight Precise 2.5",
],
),
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]), IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input( IO.Combo.Input(
"upscaler_creativity", "upscaler_creativity",
@ -304,6 +315,7 @@ class TopazVideoEnhance(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
is_deprecated=True,
) )
@classmethod @classmethod
@ -457,12 +469,357 @@ class TopazVideoEnhance(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazVideoEnhanceV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TopazVideoEnhanceV2",
display_name="Topaz Video Enhance",
category="api node/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
IO.DynamicCombo.Input(
"upscaler_model",
options=[
IO.DynamicCombo.Option(
"Astra 2",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Float.Input(
"creativity",
default=0.5,
min=0.0,
max=1.0,
step=0.1,
display_mode=IO.NumberDisplay.slider,
tooltip="Creative strength of the upscale.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional descriptive (not instructive) scene prompt."
f"Capping input at {AST2_MAX_FRAMES_WITH_PROMPT} frames (~15s @ 30fps) when set.",
),
IO.Float.Input(
"sharp",
default=0.5,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pre-enhance sharpness: "
"0.0=Gaussian blur, 0.5=passthrough (default), 1.0=USM sharpening.",
advanced=True,
),
IO.Float.Input(
"realism",
default=0.0,
min=0.0,
max=1.0,
step=0.01,
display_mode=IO.NumberDisplay.slider,
tooltip="Pulls output toward photographic realism."
"Leave at 0 for the model default.",
advanced=True,
),
],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Fast",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),],
),
IO.DynamicCombo.Option(
"Starlight (Astra) Creative",
[
IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"]),
IO.Combo.Input(
"creativity",
options=["low", "middle", "high"],
default="low",
tooltip="Creative strength of the upscale.",
),
],
),
IO.DynamicCombo.Option(
"Starlight Precise 2.5",
[IO.Combo.Input("upscaler_resolution", options=["FullHD (1080p)", "4K (2160p)"])],
),
IO.DynamicCombo.Option("Disabled", []),
],
),
IO.DynamicCombo.Input(
"interpolation_model",
options=[
IO.DynamicCombo.Option("Disabled", []),
IO.DynamicCombo.Option(
"apo-8",
[
IO.Int.Input(
"interpolation_frame_rate",
default=60,
min=15,
max=240,
display_mode=IO.NumberDisplay.number,
tooltip="Output frame rate.",
),
IO.Int.Input(
"interpolation_slowmo",
default=1,
min=1,
max=16,
display_mode=IO.NumberDisplay.number,
tooltip="Slow-motion factor applied to the input video. "
"For example, 2 makes the output twice as slow and doubles the duration.",
advanced=True,
),
IO.Boolean.Input(
"interpolation_duplicate",
default=False,
tooltip="Analyze the input for duplicate frames and remove them.",
advanced=True,
),
IO.Float.Input(
"interpolation_duplicate_threshold",
default=0.01,
min=0.001,
max=0.1,
step=0.001,
display_mode=IO.NumberDisplay.number,
tooltip="Detection sensitivity for duplicate frames.",
advanced=True,
),
],
),
],
),
IO.Combo.Input(
"dynamic_compression_level",
options=["Low", "Mid", "High"],
default="Low",
tooltip="CQP level.",
optional=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=[
"upscaler_model",
"upscaler_model.upscaler_resolution",
"interpolation_model",
]),
expr="""
(
$model := $lookup(widgets, "upscaler_model");
$res := $lookup(widgets, "upscaler_model.upscaler_resolution");
$interp := $lookup(widgets, "interpolation_model");
$is4k := $contains($res, "4k");
$hasInterp := $interp != "disabled";
$rates := {
"starlight (astra) fast": {"hd": 0.43, "uhd": 0.85},
"starlight precise 2.5": {"hd": 0.70, "uhd": 1.54},
"astra 2": {"hd": 1.72, "uhd": 2.85},
"starlight (astra) creative": {"hd": 2.25, "uhd": 3.99}
};
$surcharge := $is4k ? 0.28 : 0.14;
$entry := $lookup($rates, $model);
$base := $is4k ? $entry.uhd : $entry.hd;
$hi := $base + ($hasInterp ? $surcharge : 0);
$model = "disabled"
? {"type":"text","text":"Interpolation only"}
: ($hasInterp
? {"type":"text","text":"~" & $string($base) & "" & $string($hi) & " credits/src frame"}
: {"type":"text","text":"~" & $string($base) & " credits/src frame"})
)
""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
upscaler_model: dict,
interpolation_model: dict,
dynamic_compression_level: str = "Low",
) -> IO.NodeOutput:
upscaler_choice = upscaler_model["upscaler_model"]
interpolation_choice = interpolation_model["interpolation_model"]
if upscaler_choice == "Disabled" and interpolation_choice == "Disabled":
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
validate_container_format_is_mp4(video)
src_width, src_height = video.get_dimensions()
src_frame_rate = int(video.get_frame_rate())
duration_sec = video.get_duration()
src_video_stream = video.get_stream_source()
target_width = src_width
target_height = src_height
target_frame_rate = src_frame_rate
filters = []
if upscaler_choice != "Disabled":
if "1080p" in upscaler_model["upscaler_resolution"]:
target_pixel_p = 1080
max_long_side = 1920
else:
target_pixel_p = 2160
max_long_side = 3840
ar = src_width / src_height
if src_width >= src_height:
# Landscape or Square; Attempt to set height to target (e.g., 2160), calculate width
target_height = target_pixel_p
target_width = int(target_height * ar)
# Check if width exceeds standard bounds (for ultra-wide e.g., 21:9 ARs)
if target_width > max_long_side:
target_width = max_long_side
target_height = int(target_width / ar)
else:
# Portrait; Attempt to set width to target (e.g., 2160), calculate height
target_width = target_pixel_p
target_height = int(target_width / ar)
# Check if height exceeds standard bounds
if target_height > max_long_side:
target_height = max_long_side
target_width = int(target_height * ar)
if target_width % 2 != 0:
target_width += 1
if target_height % 2 != 0:
target_height += 1
model_id = UPSCALER_MODELS_MAP[upscaler_choice]
if model_id == "slc-1":
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
isOptimizedMode=True,
)
)
elif model_id == "ast-2":
n_frames = video.get_frame_count()
ast2_prompt = (upscaler_model["prompt"] or "").strip()
if ast2_prompt and n_frames > AST2_MAX_FRAMES_WITH_PROMPT:
raise ValueError(
f"Astra 2 with a prompt is limited to {AST2_MAX_FRAMES_WITH_PROMPT} input frames "
f"(~15s @ 30fps); video has {n_frames}. Clear the prompt or shorten the clip."
)
if n_frames > AST2_MAX_FRAMES:
raise ValueError(f"Astra 2 is limited to {AST2_MAX_FRAMES} input frames; video has {n_frames}.")
realism = upscaler_model["realism"]
filters.append(
VideoEnhancementFilter(
model=model_id,
creativity=upscaler_model["creativity"],
prompt=(ast2_prompt or None),
sharp=upscaler_model["sharp"],
realism=(realism if realism > 0 else None),
)
)
else:
filters.append(VideoEnhancementFilter(model=model_id))
if interpolation_choice != "Disabled":
target_frame_rate = interpolation_model["interpolation_frame_rate"]
filters.append(
VideoFrameInterpolationFilter(
model=interpolation_choice,
slowmo=interpolation_model["interpolation_slowmo"],
fps=interpolation_model["interpolation_frame_rate"],
duplicate=interpolation_model["interpolation_duplicate"],
duplicate_threshold=interpolation_model["interpolation_duplicate_threshold"],
),
)
initial_res = await sync_op(
cls,
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
response_model=CreateVideoResponse,
data=CreateVideoRequest(
source=CreateVideoRequestSource(
container="mp4",
size=get_fs_object_size(src_video_stream),
duration=int(duration_sec),
frameCount=video.get_frame_count(),
frameRate=src_frame_rate,
resolution=Resolution(width=src_width, height=src_height),
),
filters=filters,
output=OutputInformationVideo(
resolution=Resolution(width=target_width, height=target_height),
frameRate=target_frame_rate,
audioCodec="AAC",
audioTransfer="Copy",
dynamicCompressionLevel=dynamic_compression_level,
),
),
wait_label="Creating task",
final_label_on_success="Task created",
)
upload_res = await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
method="PATCH",
),
response_model=VideoAcceptResponse,
wait_label="Preparing upload",
final_label_on_success="Upload started",
)
if len(upload_res.urls) > 1:
raise NotImplementedError(
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
)
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
if isinstance(src_video_stream, BytesIO):
src_video_stream.seek(0)
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
else:
with builtins.open(src_video_stream, "rb") as video_file:
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
upload_etag = res.headers["Etag"]
await sync_op(
cls,
ApiEndpoint(
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
method="PATCH",
),
response_model=VideoCompleteUploadResponse,
data=VideoCompleteUploadRequest(
uploadResults=[
VideoCompleteUploadRequestPart(
partNum=1,
eTag=upload_etag,
),
],
),
wait_label="Finalizing upload",
final_label_on_success="Upload completed",
)
final_response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
response_model=VideoStatusResponse,
status_extractor=lambda x: x.status,
progress_extractor=lambda x: getattr(x, "progress", 0),
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
poll_interval=10.0,
)
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
class TopazExtension(ComfyExtension): class TopazExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ return [
TopazImageEnhance, TopazImageEnhance,
TopazVideoEnhance, TopazVideoEnhance,
TopazVideoEnhanceV2,
] ]

View File

@ -202,14 +202,11 @@ class JoinImageWithAlpha(io.ComfyNode):
@classmethod @classmethod
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
batch_size = min(len(image), len(alpha)) batch_size = max(len(image), len(alpha))
out_images = []
alpha = 1.0 - resize_mask(alpha, image.shape[1:]) alpha = 1.0 - resize_mask(alpha, image.shape[1:])
for i in range(batch_size): alpha = comfy.utils.repeat_to_batch_size(alpha, batch_size)
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) image = comfy.utils.repeat_to_batch_size(image, batch_size)
return io.NodeOutput(torch.cat((image[..., :3], alpha.unsqueeze(-1)), dim=-1))
return io.NodeOutput(torch.stack(out_images))
class CompositingExtension(ComfyExtension): class CompositingExtension(ComfyExtension):

View File

@ -666,12 +666,13 @@ class ColorTransfer(io.ComfyNode):
def define_schema(cls): def define_schema(cls):
return io.Schema( return io.Schema(
node_id="ColorTransfer", node_id="ColorTransfer",
display_name="Color Transfer",
category="image/postprocessing", category="image/postprocessing",
description="Match the colors of one image to another using various algorithms.", description="Match the colors of one image to another using various algorithms.",
search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"], search_aliases=["color match", "color grading", "color correction", "match colors", "color transform", "mkl", "reinhard", "histogram"],
inputs=[ inputs=[
io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."), io.Image.Input("image_target", tooltip="Image(s) to apply the color transform to."),
io.Image.Input("image_ref", optional=True, tooltip="Reference image(s) to match colors to. If not provided, processing is skipped"), io.Image.Input("image_ref", tooltip="Reference image(s) to match colors to."),
io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],), io.Combo.Input("method", options=['reinhard_lab', 'mkl_lab', 'histogram'],),
io.DynamicCombo.Input("source_stats", io.DynamicCombo.Input("source_stats",
tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)", tooltip="per_frame: each frame matched to image_ref individually. uniform: pool stats across all source frames as baseline, match to image_ref. target_frame: use one chosen frame as the baseline for the transform to image_ref, applied uniformly to all frames (preserves relative differences)",

View File

@ -49,7 +49,7 @@ class Int(io.ComfyNode):
display_name="Int", display_name="Int",
category="utils/primitive", category="utils/primitive",
inputs=[ inputs=[
io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=True), io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed),
], ],
outputs=[io.Int.Output()], outputs=[io.Int.Output()],
) )

View File

@ -32,6 +32,8 @@ class TextGenerate(io.ComfyNode):
io.Clip.Input("clip"), io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""), io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("image", optional=True), io.Image.Input("image", optional=True),
io.Image.Input("video", optional=True, tooltip="Video frames as image batch. Assumed to be 24 FPS; subsampled to 1 FPS internally."),
io.Audio.Input("audio", optional=True),
io.Int.Input("max_length", default=256, min=1, max=2048), io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"), io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."), io.Boolean.Input("thinking", optional=True, default=False, tooltip="Operate in thinking mode if the model supports it."),
@ -43,9 +45,9 @@ class TextGenerate(io.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput: def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking) tokens = clip.tokenize(prompt, image=image, skip_template=not use_default_template, min_length=1, thinking=thinking, video=video, audio=audio)
# Get sampling parameters from dynamic combo # Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on" do_sample = sampling_mode.get("sampling_mode") == "on"
@ -70,7 +72,8 @@ class TextGenerate(io.ComfyNode):
seed=seed seed=seed
) )
generated_text = clip.decode(generated_ids, skip_special_tokens=True) generated_text = clip.decode(generated_ids)
return io.NodeOutput(generated_text) return io.NodeOutput(generated_text)
@ -161,12 +164,12 @@ class TextGenerateLTX2Prompt(TextGenerate):
) )
@classmethod @classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True) -> io.NodeOutput: def execute(cls, clip, prompt, max_length, sampling_mode, image=None, thinking=False, use_default_template=True, video=None, audio=None) -> io.NodeOutput:
if image is None: if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n" formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else: else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n" formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image, thinking, use_default_template) return super().execute(clip, formatted_prompt, max_length, sampling_mode, image=image, thinking=thinking, use_default_template=use_default_template, video=video, audio=audio)
class TextgenExtension(ComfyExtension): class TextgenExtension(ComfyExtension):

View File

@ -28,7 +28,7 @@
#config for a1111 ui #config for a1111 ui
#all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed #all you have to do is uncomment this (remove the #) and change the base_path to where yours is installed
#a111: #a1111:
# base_path: path/to/stable-diffusion-webui/ # base_path: path/to/stable-diffusion-webui/
# checkpoints: models/Stable-diffusion # checkpoints: models/Stable-diffusion
# configs: models/Stable-diffusion # configs: models/Stable-diffusion

View File

@ -86,6 +86,6 @@ def image_alpha_fix(destination, source):
if destination.shape[-1] < source.shape[-1]: if destination.shape[-1] < source.shape[-1]:
source = source[...,:destination.shape[-1]] source = source[...,:destination.shape[-1]]
elif destination.shape[-1] > source.shape[-1]: elif destination.shape[-1] > source.shape[-1]:
destination = torch.nn.functional.pad(destination, (0, 1)) source = torch.nn.functional.pad(source, (0, 1))
destination[..., -1] = 1.0 source[..., -1] = 1.0
return destination, source return destination, source

View File

@ -1754,57 +1754,49 @@ class LoadImage:
return True return True
class LoadImageMask:
class LoadImageMask(LoadImage):
ESSENTIALS_CATEGORY = "Image Tools" ESSENTIALS_CATEGORY = "Image Tools"
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
_color_channels = ["alpha", "red", "green", "blue"] _color_channels = ["alpha", "red", "green", "blue"]
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory() types = super().INPUT_TYPES()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))] return {
return {"required": "required": {
{"image": (sorted(files), {"image_upload": True}), **types["required"],
"channel": (s._color_channels, ), } "channel": (s._color_channels, )
}
} }
CATEGORY = "mask" CATEGORY = "mask"
RETURN_TYPES = ("MASK",) RETURN_TYPES = ("MASK",)
FUNCTION = "load_image" FUNCTION = "load_image_mask"
def load_image(self, image, channel):
image_path = folder_paths.get_annotated_filepath(image) def load_image_mask(self, image, channel):
i = node_helpers.pillow(Image.open, image_path) image_tensor, mask_tensor = super().load_image(image)
i = node_helpers.pillow(ImageOps.exif_transpose, i)
if i.getbands() != ("R", "G", "B", "A"):
if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
i = i.convert("RGBA")
mask = None
c = channel[0].upper() c = channel[0].upper()
if c in i.getbands():
mask = np.array(i.getchannel(c)).astype(np.float32) / 255.0
mask = torch.from_numpy(mask)
if c == 'A': if c == 'A':
mask = 1. - mask return (mask_tensor,)
channel_idx = {'R': 0, 'G': 1, 'B': 2}.get(c, 0)
if channel_idx < image_tensor.shape[-1]:
return (image_tensor[..., channel_idx].clone(),)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") empty_mask = torch.zeros(
return (mask.unsqueeze(0),) image_tensor.shape[:-1],
dtype=image_tensor.dtype,
device=image_tensor.device
)
return (empty_mask,)
@classmethod @classmethod
def IS_CHANGED(s, image, channel): def IS_CHANGED(s, image, channel):
image_path = folder_paths.get_annotated_filepath(image) return super().IS_CHANGED(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()
@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)
return True
class LoadImageOutput(LoadImage): class LoadImageOutput(LoadImage):

View File

@ -1,3 +1,4 @@
import errno
import os import os
import sys import sys
import asyncio import asyncio
@ -1245,7 +1246,13 @@ class PromptServer():
address = addr[0] address = addr[0]
port = addr[1] port = addr[1]
site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx) site = web.TCPSite(runner, address, port, ssl_context=ssl_ctx)
try:
await site.start() await site.start()
except OSError as e:
if e.errno == errno.EADDRINUSE:
logging.error(f"Port {port} is already in use on address {address}. Please close the other application or use a different port with --port.")
raise SystemExit(1)
raise
if not hasattr(self, 'address'): if not hasattr(self, 'address'):
self.address = address #TODO: remove this self.address = address #TODO: remove this