Merge branch 'Comfy-Org:master' into fix/standardize-temp-filenames

This commit is contained in:
Dnak-jb 2026-02-18 20:42:04 -06:00 committed by GitHub
commit 918a670f0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
50 changed files with 2258 additions and 241 deletions

2
.gitignore vendored
View File

@ -11,7 +11,7 @@ extra_model_paths.yaml
/.vs /.vs
.vscode/ .vscode/
.idea/ .idea/
venv/ venv*/
.venv/ .venv/
/web/extensions/* /web/extensions/*
!/web/extensions/logging.js.example !/web/extensions/logging.js.example

105
app/node_replace_manager.py Normal file
View File

@ -0,0 +1,105 @@
from __future__ import annotations
from aiohttp import web
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from comfy_api.latest._io_public import NodeReplace
from comfy_execution.graph_utils import is_link
import nodes
class NodeStruct(TypedDict):
inputs: dict[str, str | int | float | bool | tuple[str, int]]
class_type: str
_meta: dict[str, str]
def copy_node_struct(node_struct: NodeStruct, empty_inputs: bool = False) -> NodeStruct:
new_node_struct = node_struct.copy()
if empty_inputs:
new_node_struct["inputs"] = {}
else:
new_node_struct["inputs"] = node_struct["inputs"].copy()
new_node_struct["_meta"] = node_struct["_meta"].copy()
return new_node_struct
class NodeReplaceManager:
"""Manages node replacement registrations."""
def __init__(self):
self._replacements: dict[str, list[NodeReplace]] = {}
def register(self, node_replace: NodeReplace):
"""Register a node replacement mapping."""
self._replacements.setdefault(node_replace.old_node_id, []).append(node_replace)
def get_replacement(self, old_node_id: str) -> list[NodeReplace] | None:
"""Get replacements for an old node ID."""
return self._replacements.get(old_node_id)
def has_replacement(self, old_node_id: str) -> bool:
"""Check if a replacement exists for an old node ID."""
return old_node_id in self._replacements
def apply_replacements(self, prompt: dict[str, NodeStruct]):
connections: dict[str, list[tuple[str, str, int]]] = {}
need_replacement: set[str] = set()
for node_number, node_struct in prompt.items():
class_type = node_struct["class_type"]
# need replacement if not in NODE_CLASS_MAPPINGS and has replacement
if class_type not in nodes.NODE_CLASS_MAPPINGS.keys() and self.has_replacement(class_type):
need_replacement.add(node_number)
# keep track of connections
for input_id, input_value in node_struct["inputs"].items():
if is_link(input_value):
conn_number = input_value[0]
connections.setdefault(conn_number, []).append((node_number, input_id, input_value[1]))
for node_number in need_replacement:
node_struct = prompt[node_number]
class_type = node_struct["class_type"]
replacements = self.get_replacement(class_type)
if replacements is None:
continue
# just use the first replacement
replacement = replacements[0]
new_node_id = replacement.new_node_id
# if replacement is not a valid node, skip trying to replace it as will only cause confusion
if new_node_id not in nodes.NODE_CLASS_MAPPINGS.keys():
continue
# first, replace node id (class_type)
new_node_struct = copy_node_struct(node_struct, empty_inputs=True)
new_node_struct["class_type"] = new_node_id
# TODO: consider replacing display_name in _meta as well for error reporting purposes; would need to query node schema
# second, replace inputs
if replacement.input_mapping is not None:
for input_map in replacement.input_mapping:
if "set_value" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = input_map["set_value"]
elif "old_id" in input_map:
new_node_struct["inputs"][input_map["new_id"]] = node_struct["inputs"][input_map["old_id"]]
# finalize input replacement
prompt[node_number] = new_node_struct
# third, replace outputs
if replacement.output_mapping is not None:
# re-mapping outputs requires changing the input values of nodes that receive connections from this one
if node_number in connections:
for conns in connections[node_number]:
conn_node_number, conn_input_id, old_output_idx = conns
for output_map in replacement.output_mapping:
if output_map["old_idx"] == old_output_idx:
new_output_idx = output_map["new_idx"]
previous_input = prompt[conn_node_number]["inputs"][conn_input_id]
previous_input[1] = new_output_idx
def as_dict(self):
"""Serialize all replacements to dict."""
return {
k: [v.as_dict() for v in v_list]
for k, v_list in self._replacements.items()
}
def add_routes(self, routes):
@routes.get("/node_replacements")
async def get_node_replacements(request):
return web.json_response(self.as_dict())

View File

@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
if source_attention_mask.ndim == 2: if source_attention_mask.ndim == 2:
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1) source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
x = self.in_proj(self.embed(target_input_ids))
context = source_hidden_states context = source_hidden_states
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0) position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0) position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
position_embeddings = self.rotary_emb(x, position_ids) position_embeddings = self.rotary_emb(x, position_ids)

View File

@ -152,6 +152,7 @@ class Chroma(nn.Module):
transformer_options={}, transformer_options={},
attn_mask: Tensor = None, attn_mask: Tensor = None,
) -> Tensor: ) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
# running on sequences img # running on sequences img
@ -228,6 +229,7 @@ class Chroma(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single" transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks): for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i transformer_options["block_index"] = i
if i not in self.skip_dit: if i not in self.skip_dit:

View File

@ -196,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
else: else:
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
# prepare image for attention # prepare image for attention
img_modulated = self.img_norm1(img) img_modulated = self.img_norm1(img)
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img) img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
@ -224,6 +227,12 @@ class DoubleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v del q, k, v
if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
# calculate the img bloks # calculate the img bloks
@ -303,6 +312,9 @@ class SingleStreamBlock(nn.Module):
else: else:
mod = vec mod = vec
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1) qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
@ -312,6 +324,12 @@ class SingleStreamBlock(nn.Module):
# compute attention # compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v del q, k, v
if "attn1_output_patch" in transformer_patches:
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
# compute activation in mlp stream, cat again and run second linear layer # compute activation in mlp stream, cat again and run second linear layer
if self.yak_mlp: if self.yak_mlp:
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2] mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]

View File

@ -142,6 +142,7 @@ class Flux(nn.Module):
attn_mask: Tensor = None, attn_mask: Tensor = None,
) -> Tensor: ) -> Tensor:
transformer_options = transformer_options.copy()
patches = transformer_options.get("patches", {}) patches = transformer_options.get("patches", {})
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
if img.ndim != 3 or txt.ndim != 3: if img.ndim != 3 or txt.ndim != 3:
@ -231,6 +232,7 @@ class Flux(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single" transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks): for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:

View File

@ -304,6 +304,7 @@ class HunyuanVideo(nn.Module):
control=None, control=None,
transformer_options={}, transformer_options={},
) -> Tensor: ) -> Tensor:
transformer_options = transformer_options.copy()
patches_replace = transformer_options.get("patches_replace", {}) patches_replace = transformer_options.get("patches_replace", {})
initial_shape = list(img.shape) initial_shape = list(img.shape)
@ -416,6 +417,7 @@ class HunyuanVideo(nn.Module):
transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single" transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
for i, block in enumerate(self.single_blocks): for i, block in enumerate(self.single_blocks):
transformer_options["block_index"] = i transformer_options["block_index"] = i
if ("single_block", i) in blocks_replace: if ("single_block", i) in blocks_replace:

View File

@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
return self.conv(x) return self.conv(x)
def interpolate_up(x, scale_factor): def interpolate_up(x, scale_factor):
try:
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest") return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
except: #operation not implemented for bf16
orig_shape = list(x.shape)
out_shape = orig_shape[:2]
for i in range(len(orig_shape) - 2):
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
return out
class Upsample(nn.Module): class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0): def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):

View File

@ -374,6 +374,31 @@ def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Ten
return padded_tensor return padded_tensor
def calculate_shape(patches, weight, key, original_weights=None):
current_shape = weight.shape
for p in patches:
v = p[1]
offset = p[3]
# Offsets restore the old shape; lists force a diff without metadata
if offset is not None or isinstance(v, list):
continue
if isinstance(v, weight_adapter.WeightAdapterBase):
adapter_shape = v.calculate_shape(key)
if adapter_shape is not None:
current_shape = adapter_shape
continue
# Standard diff logic with padding
if len(v) == 2:
patch_type, patch_data = v[0], v[1]
if patch_type == "diff" and len(patch_data) > 1 and patch_data[1]['pad_weight']:
current_shape = patch_data[0].shape
return current_shape
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None): def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, original_weights=None):
for p in patches: for p in patches:
strength = p[0] strength = p[0]

View File

@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1) xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
context = c_crossattn context = c_crossattn
dtype = self.get_dtype() dtype = self.get_dtype_inference()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
xc = xc.to(dtype) xc = xc.to(dtype)
device = xc.device device = xc.device
@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
def get_dtype(self): def get_dtype(self):
return self.diffusion_model.dtype return self.diffusion_model.dtype
def get_dtype_inference(self):
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
return dtype
def encode_adm(self, **kwargs): def encode_adm(self, **kwargs):
return None return None
@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
input_shapes += shape input_shapes += shape
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention(): if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype() dtype = self.get_dtype_inference()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked #TODO: this needs to be tweaked
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes)) area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024) return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
@ -1165,7 +1167,7 @@ class Anima(BaseModel):
t5xxl_ids = t5xxl_ids.unsqueeze(0) t5xxl_ids = t5xxl_ids.unsqueeze(0)
if torch.is_inference_mode_enabled(): # if not we are training if torch.is_inference_mode_enabled(): # if not we are training
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype())) cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
else: else:
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids) out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights) out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)

View File

@ -406,13 +406,16 @@ class ModelPatcher:
def memory_required(self, input_shape): def memory_required(self, input_shape):
return self.model.memory_required(input_shape=input_shape) return self.model.memory_required(input_shape=input_shape)
def disable_model_cfg1_optimization(self):
self.model_options["disable_cfg1_optimization"] = True
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
if len(inspect.signature(sampler_cfg_function).parameters) == 3: if len(inspect.signature(sampler_cfg_function).parameters) == 3:
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
else: else:
self.model_options["sampler_cfg_function"] = sampler_cfg_function self.model_options["sampler_cfg_function"] = sampler_cfg_function
if disable_cfg1_optimization: if disable_cfg1_optimization:
self.model_options["disable_cfg1_optimization"] = True self.disable_model_cfg1_optimization()
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False): def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization) self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
@ -1514,8 +1517,10 @@ class ModelPatcherDynamic(ModelPatcher):
weight, _, _ = get_key_weight(self.model, key) weight, _, _ = get_key_weight(self.model, key)
if weight is None: if weight is None:
return 0 return (False, 0)
if key in self.patches: if key in self.patches:
if comfy.lora.calculate_shape(self.patches[key], weight, key) != weight.shape:
return (True, 0)
setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches)) setattr(m, param_key + "_lowvram_function", LowVramPatch(key, self.patches))
num_patches += 1 num_patches += 1
else: else:
@ -1529,7 +1534,13 @@ class ModelPatcherDynamic(ModelPatcher):
model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype model_dtype = getattr(m, param_key + "_comfy_model_dtype", None) or weight.dtype
weight._model_dtype = model_dtype weight._model_dtype = model_dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype) geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
return comfy.memory_management.vram_aligned_size(geometry) return (False, comfy.memory_management.vram_aligned_size(geometry))
def force_load_param(self, param_key, device_to):
key = key_param_name_to_key(n, param_key)
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
if hasattr(m, "comfy_cast_weights"): if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True m.comfy_cast_weights = True
@ -1537,10 +1548,16 @@ class ModelPatcherDynamic(ModelPatcher):
m.seed_key = n m.seed_key = n
set_dirty(m, dirty) set_dirty(m, dirty)
v_weight_size = 0 force_load, v_weight_size = setup_param(self, m, n, "weight")
v_weight_size += setup_param(self, m, n, "weight") force_load_bias, v_weight_bias = setup_param(self, m, n, "bias")
v_weight_size += setup_param(self, m, n, "bias") force_load = force_load or force_load_bias
v_weight_size += v_weight_bias
if force_load:
logging.info(f"Module {n} has resizing Lora - force loading")
force_load_param(self, "weight", device_to)
force_load_param(self, "bias", device_to)
else:
if vbar is not None and not hasattr(m, "_v"): if vbar is not None and not hasattr(m, "_v"):
m._v = vbar.alloc(v_weight_size) m._v = vbar.alloc(v_weight_size)
allocated_size += v_weight_size allocated_size += v_weight_size
@ -1606,6 +1623,11 @@ class ModelPatcherDynamic(ModelPatcher):
for m in self.model.modules(): for m in self.model.modules():
move_weight_functions(m, device_to) move_weight_functions(m, device_to)
keys = list(self.backup.keys())
for k in keys:
bk = self.backup[k]
comfy.utils.set_attr_param(self.model, k, bk.weight)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False): def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above assert not force_patch_weights #See above
with self.use_ejected(skip_and_inject_on_exit_only=True): with self.use_ejected(skip_and_inject_on_exit_only=True):

View File

@ -21,7 +21,6 @@ import logging
import comfy.model_management import comfy.model_management
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
import comfy.float import comfy.float
import comfy.rmsnorm
import json import json
import comfy.memory_management import comfy.memory_management
import comfy.pinned_memory import comfy.pinned_memory
@ -80,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype): def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
offload_stream = None offload_stream = None
xfer_dest = None xfer_dest = None
@ -171,10 +170,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
#FIXME: this is not accurate, we need to be sensitive to the compute dtype #FIXME: this is not accurate, we need to be sensitive to the compute dtype
x = lowvram_fn(x) x = lowvram_fn(x)
if (isinstance(orig, QuantizedTensor) and if (isinstance(orig, QuantizedTensor) and
(orig.dtype == dtype and len(fns) == 0 or update_weight)): (want_requant and len(fns) == 0 or update_weight)):
seed = comfy.utils.string_to_seed(s.seed_key) seed = comfy.utils.string_to_seed(s.seed_key)
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed) y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
if orig.dtype == dtype and len(fns) == 0: if want_requant and len(fns) == 0:
#The layer actually wants our freshly saved QT #The layer actually wants our freshly saved QT
x = y x = y
elif update_weight: elif update_weight:
@ -195,7 +194,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
return weight, bias, (offload_stream, device if signature is not None else None, None) return weight, bias, (offload_stream, device if signature is not None else None, None)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None): def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
# will add async-offload support to your cast and improve performance. # will add async-offload support to your cast and improve performance.
@ -213,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
non_blocking = comfy.model_management.device_supports_non_blocking(device) non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"): if hasattr(s, "_v"):
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype) return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
if offloadable and (device != s.weight.device or if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)): (s.bias is not None and device != s.bias.device)):
@ -463,7 +462,7 @@ class disable_weight_init:
else: else:
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp): class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
def reset_parameters(self): def reset_parameters(self):
self.bias = None self.bias = None
return None return None
@ -475,8 +474,7 @@ class disable_weight_init:
weight = None weight = None
bias = None bias = None
offload_stream = None offload_stream = None
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
uncast_bias_weight(self, weight, bias, offload_stream) uncast_bias_weight(self, weight, bias, offload_stream)
return x return x
@ -852,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def _forward(self, input, weight, bias): def _forward(self, input, weight, bias):
return torch.nn.functional.linear(input, weight, bias) return torch.nn.functional.linear(input, weight, bias)
def forward_comfy_cast_weights(self, input, compute_dtype=None): def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype) weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
x = self._forward(input, weight, bias) x = self._forward(input, weight, bias)
uncast_bias_weight(self, weight, bias, offload_stream) uncast_bias_weight(self, weight, bias, offload_stream)
return x return x
@ -883,8 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
scale = comfy.model_management.cast_to_device(scale, input.device, None) scale = comfy.model_management.cast_to_device(scale, input.device, None)
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale) input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
output = self.forward_comfy_cast_weights(input, compute_dtype)
# Reshape output back to 3D if input was 3D # Reshape output back to 3D if input was 3D
if reshaped_3d: if reshaped_3d:

View File

@ -1,57 +1,10 @@
import torch import torch
import comfy.model_management import comfy.model_management
import numbers
import logging
RMSNorm = None
try:
rms_norm_torch = torch.nn.functional.rms_norm
RMSNorm = torch.nn.RMSNorm
except:
rms_norm_torch = None
logging.warning("Please update pytorch to use native RMSNorm")
RMSNorm = torch.nn.RMSNorm
def rms_norm(x, weight=None, eps=1e-6): def rms_norm(x, weight=None, eps=1e-6):
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
if weight is None: if weight is None:
return rms_norm_torch(x, (x.shape[-1],), eps=eps) return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
else: else:
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps) return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
else:
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
if weight is None:
return r
else:
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
if RMSNorm is None:
class RMSNorm(torch.nn.Module):
def __init__(
self,
normalized_shape,
eps=1e-6,
elementwise_affine=True,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = torch.nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
else:
self.register_parameter("weight", None)
self.bias = None
def forward(self, x):
return rms_norm(x, self.weight, self.eps)

View File

@ -423,6 +423,19 @@ class CLIP:
def get_key_patches(self): def get_key_patches(self):
return self.patcher.get_key_patches() return self.patcher.get_key_patches()
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()
if self.layer_idx is not None:
self.cond_stage_model.set_clip_options({"layer": self.layer_idx})
self.load_model()
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
class VAE: class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
@ -1182,6 +1195,7 @@ class TEModel(Enum):
JINA_CLIP_2 = 19 JINA_CLIP_2 = 19
QWEN3_8B = 20 QWEN3_8B = 20
QWEN3_06B = 21 QWEN3_06B = 21
GEMMA_3_4B_VISION = 22
def detect_te_model(sd): def detect_te_model(sd):
@ -1210,6 +1224,9 @@ def detect_te_model(sd):
if 'model.layers.47.self_attn.q_norm.weight' in sd: if 'model.layers.47.self_attn.q_norm.weight' in sd:
return TEModel.GEMMA_3_12B return TEModel.GEMMA_3_12B
if 'model.layers.0.self_attn.q_norm.weight' in sd: if 'model.layers.0.self_attn.q_norm.weight' in sd:
if 'vision_model.embeddings.patch_embedding.weight' in sd:
return TEModel.GEMMA_3_4B_VISION
else:
return TEModel.GEMMA_3_4B return TEModel.GEMMA_3_4B
return TEModel.GEMMA_2_2B return TEModel.GEMMA_2_2B
if 'model.layers.0.self_attn.k_proj.bias' in sd: if 'model.layers.0.self_attn.k_proj.bias' in sd:
@ -1270,6 +1287,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
else: else:
if "text_projection" in clip_data[i]: if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
if "lm_head.weight" in clip_data[i]:
clip_data[i]["model.lm_head.weight"] = clip_data[i].pop("lm_head.weight") # prefix missing in some models
tokenizer_data = {} tokenizer_data = {}
clip_target = EmptyClass() clip_target = EmptyClass()
@ -1335,6 +1354,14 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b") clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_4B_VISION:
clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data), model_type="gemma3_4b_vision")
clip_target.tokenizer = comfy.text_encoders.lumina2.NTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.GEMMA_3_12B:
clip_target.clip = comfy.text_encoders.lt.gemma3_te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.Gemma3_12BTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif te_model == TEModel.LLAMA3_8: elif te_model == TEModel.LLAMA3_8:
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)

View File

@ -308,6 +308,15 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def load_sd(self, sd): def load_sd(self, sd):
return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False)) return self.transformer.load_state_dict(sd, strict=False, assign=getattr(self, "can_assign_sd", False))
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[]):
if isinstance(tokens, dict):
tokens_only = next(iter(tokens.values())) # todo: get this better?
else:
tokens_only = tokens
tokens_only = [[t[0] for t in b] for b in tokens_only]
embeds = self.process_tokens(tokens_only, device=self.execution_device)[0]
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens)
def parse_parentheses(string): def parse_parentheses(string):
result = [] result = []
current_item = "" current_item = ""
@ -663,6 +672,9 @@ class SDTokenizer:
def state_dict(self): def state_dict(self):
return {} return {}
def decode(self, token_ids, skip_special_tokens=True):
return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
class SD1Tokenizer: class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None): def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
if name is not None: if name is not None:
@ -686,6 +698,9 @@ class SD1Tokenizer:
def state_dict(self): def state_dict(self):
return getattr(self, self.clip).state_dict() return getattr(self, self.clip).state_dict()
def decode(self, token_ids, skip_special_tokens=True):
return getattr(self, self.clip).decode(token_ids, skip_special_tokens=skip_special_tokens)
class SD1CheckpointClipModel(SDClipModel): class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}): def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options) super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options)
@ -722,3 +737,6 @@ class SD1ClipModel(torch.nn.Module):
def load_sd(self, sd): def load_sd(self, sd):
return getattr(self, self.clip).load_sd(sd) return getattr(self, self.clip).load_sd(sd)
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
return getattr(self, self.clip).generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)

View File

@ -3,6 +3,8 @@ import torch.nn as nn
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Any, Tuple from typing import Optional, Any, Tuple
import math import math
from tqdm import tqdm
import comfy.utils
from comfy.ldm.modules.attention import optimized_attention_for_device from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management import comfy.model_management
@ -313,6 +315,13 @@ class Gemma3_4B_Config:
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
GEMMA3_VISION_CONFIG = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14}
@dataclass
class Gemma3_4B_Vision_Config(Gemma3_4B_Config):
vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256
@dataclass @dataclass
class Gemma3_12B_Config: class Gemma3_12B_Config:
vocab_size: int = 262208 vocab_size: int = 262208
@ -336,7 +345,7 @@ class Gemma3_12B_Config:
rope_scale = [8.0, 1.0] rope_scale = [8.0, 1.0]
final_norm: bool = True final_norm: bool = True
lm_head: bool = False lm_head: bool = False
vision_config = {"num_channels": 3, "hidden_act": "gelu_pytorch_tanh", "hidden_size": 1152, "image_size": 896, "intermediate_size": 4304, "model_type": "siglip_vision_model", "num_attention_heads": 16, "num_hidden_layers": 27, "patch_size": 14} vision_config = GEMMA3_VISION_CONFIG
mm_tokens_per_image = 256 mm_tokens_per_image = 256
class RMSNorm(nn.Module): class RMSNorm(nn.Module):
@ -441,8 +450,10 @@ class Attention(nn.Module):
freqs_cis: Optional[torch.Tensor] = None, freqs_cis: Optional[torch.Tensor] = None,
optimized_attention=None, optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
sliding_window: Optional[int] = None,
): ):
batch_size, seq_length, _ = hidden_states.shape batch_size, seq_length, _ = hidden_states.shape
xq = self.q_proj(hidden_states) xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states) xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states) xv = self.v_proj(hidden_states)
@ -477,6 +488,11 @@ class Attention(nn.Module):
else: else:
present_key_value = (xk, xv, index + num_tokens) present_key_value = (xk, xv, index + num_tokens)
if sliding_window is not None and xk.shape[2] > sliding_window:
xk = xk[:, :, -sliding_window:]
xv = xv[:, :, -sliding_window:]
attention_mask = attention_mask[..., -sliding_window:] if attention_mask is not None else None
xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1) xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
@ -559,10 +575,12 @@ class TransformerBlockGemma2(nn.Module):
optimized_attention=None, optimized_attention=None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
): ):
sliding_window = None
if self.transformer_type == 'gemma3': if self.transformer_type == 'gemma3':
if self.sliding_attention: if self.sliding_attention:
sliding_window = self.sliding_attention
if x.shape[1] > self.sliding_attention: if x.shape[1] > self.sliding_attention:
sliding_mask = torch.full((x.shape[1], x.shape[1]), float("-inf"), device=x.device, dtype=x.dtype) sliding_mask = torch.full((x.shape[1], x.shape[1]), torch.finfo(x.dtype).min, device=x.device, dtype=x.dtype)
sliding_mask.tril_(diagonal=-self.sliding_attention) sliding_mask.tril_(diagonal=-self.sliding_attention)
if attention_mask is not None: if attention_mask is not None:
attention_mask = attention_mask + sliding_mask attention_mask = attention_mask + sliding_mask
@ -581,6 +599,7 @@ class TransformerBlockGemma2(nn.Module):
freqs_cis=freqs_cis, freqs_cis=freqs_cis,
optimized_attention=optimized_attention, optimized_attention=optimized_attention,
past_key_value=past_key_value, past_key_value=past_key_value,
sliding_window=sliding_window,
) )
x = self.post_attention_layernorm(x) x = self.post_attention_layernorm(x)
@ -765,6 +784,104 @@ class BaseLlama:
def forward(self, input_ids, *args, **kwargs): def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs) return self.model(input_ids, *args, **kwargs)
class BaseGenerate:
def logits(self, x):
input = x[:, -1:]
if hasattr(self.model, "lm_head"):
module = self.model.lm_head
else:
module = self.model.embed_tokens
offload_stream = None
if module.comfy_cast_weights:
weight, _, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
else:
weight = self.model.embed_tokens.weight.to(x)
x = torch.nn.functional.linear(input, weight, None)
comfy.ops.uncast_bias_weight(module, weight, None, offload_stream)
return x
def generate(self, embeds=None, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.9, min_p=0.0, repetition_penalty=1.0, seed=42, stop_tokens=[], initial_tokens=[], execution_dtype=None, min_tokens=0):
device = embeds.device
model_config = self.model.config
if execution_dtype is None:
if comfy.model_management.should_use_bf16(device):
execution_dtype = torch.bfloat16
else:
execution_dtype = torch.float32
embeds = embeds.to(execution_dtype)
if embeds.ndim == 2:
embeds = embeds.unsqueeze(0)
past_key_values = [] #kv_cache init
max_cache_len = embeds.shape[1] + max_length
for x in range(model_config.num_hidden_layers):
past_key_values.append((torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype),
torch.empty([embeds.shape[0], model_config.num_key_value_heads, max_cache_len, model_config.head_dim], device=device, dtype=execution_dtype), 0))
generator = torch.Generator(device=device).manual_seed(seed) if do_sample else None
generated_token_ids = []
pbar = comfy.utils.ProgressBar(max_length)
# Generation loop
for step in tqdm(range(max_length), desc="Generating tokens"):
x, _, past_key_values = self.model.forward(None, embeds=embeds, attention_mask=None, past_key_values=past_key_values)
logits = self.logits(x)[:, -1]
next_token = self.sample_token(logits, temperature, top_k, top_p, min_p, repetition_penalty, initial_tokens + generated_token_ids, generator, do_sample=do_sample)
token_id = next_token[0].item()
generated_token_ids.append(token_id)
embeds = self.model.embed_tokens(next_token).to(execution_dtype)
pbar.update(1)
if token_id in stop_tokens:
break
return generated_token_ids
def sample_token(self, logits, temperature, top_k, top_p, min_p, repetition_penalty, token_history, generator, do_sample=True):
if not do_sample or temperature == 0.0:
return torch.argmax(logits, dim=-1, keepdim=True)
# Sampling mode
if repetition_penalty != 1.0:
for i in range(logits.shape[0]):
for token_id in set(token_history):
logits[i, token_id] *= repetition_penalty if logits[i, token_id] < 0 else 1/repetition_penalty
if temperature != 1.0:
logits = logits / temperature
if top_k > 0:
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = torch.finfo(logits.dtype).min
if min_p > 0.0:
probs_before_filter = torch.nn.functional.softmax(logits, dim=-1)
top_probs, _ = probs_before_filter.max(dim=-1, keepdim=True)
min_threshold = min_p * top_probs
indices_to_remove = probs_before_filter < min_threshold
logits[indices_to_remove] = torch.finfo(logits.dtype).min
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[..., 0] = False
indices_to_remove = torch.zeros_like(logits, dtype=torch.bool)
indices_to_remove.scatter_(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = torch.finfo(logits.dtype).min
probs = torch.nn.functional.softmax(logits, dim=-1)
return torch.multinomial(probs, num_samples=1, generator=generator)
class BaseQwen3: class BaseQwen3:
def logits(self, x): def logits(self, x):
input = x[:, -1:] input = x[:, -1:]
@ -871,7 +988,7 @@ class Ovis25_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Qwen25_7BVLI(BaseLlama, torch.nn.Module): class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
config = Qwen25_7BVLI_Config(**config_dict) config = Qwen25_7BVLI_Config(**config_dict)
@ -881,6 +998,9 @@ class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations) self.visual = qwen_vl.Qwen2VLVisionTransformer(hidden_size=1280, output_hidden_size=config.hidden_size, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
# todo: should this be tied or not?
#self.lm_head = operations.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
def preprocess_embed(self, embed, device): def preprocess_embed(self, embed, device):
if embed["type"] == "image": if embed["type"] == "image":
image, grid = qwen_vl.process_qwen2vl_images(embed["data"]) image, grid = qwen_vl.process_qwen2vl_images(embed["data"])
@ -923,7 +1043,7 @@ class Gemma2_2B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Gemma3_4B(BaseLlama, torch.nn.Module): class Gemma3_4B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
config = Gemma3_4B_Config(**config_dict) config = Gemma3_4B_Config(**config_dict)
@ -932,7 +1052,25 @@ class Gemma3_4B(BaseLlama, torch.nn.Module):
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype self.dtype = dtype
class Gemma3_12B(BaseLlama, torch.nn.Module): class Gemma3_4B_Vision(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
config = Gemma3_4B_Vision_Config(**config_dict)
self.num_layers = config.num_hidden_layers
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
self.dtype = dtype
self.multi_modal_projector = Gemma3MultiModalProjector(config, dtype, device, operations)
self.vision_model = comfy.clip_model.CLIPVision(config.vision_config, dtype, device, operations)
self.image_size = config.vision_config["image_size"]
def preprocess_embed(self, embed, device):
if embed["type"] == "image":
image = comfy.clip_model.clip_preprocess(embed["data"], size=self.image_size, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], crop=True)
return self.multi_modal_projector(self.vision_model(image.to(device, dtype=torch.float32))[0]), None
return None, None
class Gemma3_12B(BaseLlama, BaseGenerate, torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations): def __init__(self, config_dict, dtype, device, operations):
super().__init__() super().__init__()
config = Gemma3_12B_Config(**config_dict) config = Gemma3_12B_Config(**config_dict)

View File

@ -6,6 +6,7 @@ import comfy.text_encoders.genmo
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
import torch import torch
import comfy.utils import comfy.utils
import math
class T5XXLTokenizer(sd1_clip.SDTokenizer): class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -22,40 +23,79 @@ def ltxv_te(*args, **kwargs):
return comfy.text_encoders.genmo.mochi_te(*args, **kwargs) return comfy.text_encoders.genmo.mochi_te(*args, **kwargs)
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer): class Gemma3_Tokenizer():
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
def state_dict(self): def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()} return {"spiece_model": self.tokenizer.serialize_model()}
def tokenize_with_weights(self, text, return_word_ids=False, image=None, llama_template=None, skip_template=True, **kwargs):
self.llama_template = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n{}<end_of_turn>\n<start_of_turn>model\n"
self.llama_template_images = "<start_of_turn>system\nYou are a helpful assistant.<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>{}<end_of_turn>\n\n<start_of_turn>model\n"
if image is None:
images = []
else:
samples = image.movedim(-1, 1)
total = int(896 * 896)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled").movedim(1, -1)
images = [s[:, :, :, :3]]
if text.startswith('<start_of_turn>'):
skip_template = True
if skip_template:
llama_text = text
else:
if llama_template is None:
if len(images) > 0:
llama_text = self.llama_template_images.format(text)
else:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
text_tokens = super().tokenize_with_weights(llama_text, return_word_ids)
if len(images) > 0:
embed_count = 0
for r in text_tokens:
for i, token in enumerate(r):
if token[0] == 262144 and embed_count < len(images):
r[i] = ({"type": "image", "data": images[embed_count]},) + token[1:]
embed_count += 1
return text_tokens
class Gemma3_12BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None)
special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer): class LTXAVGemmaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer) super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_12b", tokenizer=Gemma3_12BTokenizer)
class Gemma3_12BModel(sd1_clip.SDClipModel): class Gemma3_12BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}): def __init__(self, device="cpu", layer="all", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None: if llama_quantization_metadata is not None:
model_options = model_options.copy() model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata model_options["quantization_metadata"] = llama_quantization_metadata
self.dtypes = set()
self.dtypes.add(dtype)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_12B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def tokenize_with_weights(self, text, return_word_ids=False, llama_template="{}", image_embeds=None, **kwargs): def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
text = llama_template.format(text) tokens_only = [[t[0] for t in b] for b in tokens]
text_tokens = super().tokenize_with_weights(text, return_word_ids) embeds, _, _, embeds_info = self.process_tokens(tokens_only, self.execution_device)
embed_count = 0 comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
for k in text_tokens: return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
tt = text_tokens[k]
for r in tt:
for i in range(len(r)):
if r[i][0] == 262144:
if image_embeds is not None and embed_count < image_embeds.shape[0]:
r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image"},) + r[i][1:]
embed_count += 1
return text_tokens
class LTXAVTEModel(torch.nn.Module): class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
@ -112,6 +152,9 @@ class LTXAVTEModel(torch.nn.Module):
return out.to(out_device), pooled return out.to(out_device), pooled
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
def load_sd(self, sd): def load_sd(self, sd):
if "model.layers.47.self_attn.q_norm.weight" in sd: if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd) return self.gemma3_12b.load_sd(sd)
@ -152,3 +195,14 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
dtype = dtype_llama dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return LTXAVTEModel_ return LTXAVTEModel_
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
class Gemma3_12BModel_(Gemma3_12BModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Gemma3_12BModel_

View File

@ -1,23 +1,23 @@
from comfy import sd1_clip from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.llama import comfy.text_encoders.llama
from comfy.text_encoders.lt import Gemma3_Tokenizer
import comfy.utils
class Gemma2BTokenizer(sd1_clip.SDTokenizer): class Gemma2BTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None) tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) special_tokens = {"<end_of_turn>": 107}
super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, tokenizer_data=tokenizer_data)
def state_dict(self): def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()} return {"spiece_model": self.tokenizer.serialize_model()}
class Gemma3_4BTokenizer(sd1_clip.SDTokenizer): class Gemma3_4BTokenizer(Gemma3_Tokenizer, sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = tokenizer_data.get("spiece_model", None) tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, disable_weights=True, tokenizer_data=tokenizer_data) special_tokens = {"<image_soft_token>": 262144, "<end_of_turn>": 106}
super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False, "special_tokens": special_tokens}, disable_weights=True, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LuminaTokenizer(sd1_clip.SD1Tokenizer): class LuminaTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
@ -31,6 +31,9 @@ class Gemma2_2BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[107])
class Gemma3_4BModel(sd1_clip.SDClipModel): class Gemma3_4BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
@ -40,6 +43,23 @@ class Gemma3_4BModel(sd1_clip.SDClipModel):
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def generate(self, embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return super().generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106])
class Gemma3_4B_Vision_Model(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma3_4B_Vision, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
def process_tokens(self, tokens, device):
embeds, _, _, embeds_info = super().process_tokens(tokens, device)
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return embeds
class LuminaModel(sd1_clip.SD1ClipModel): class LuminaModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel): def __init__(self, device="cpu", dtype=None, model_options={}, name="gemma2_2b", clip_model=Gemma2_2BModel):
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
@ -50,6 +70,8 @@ def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b
model = Gemma2_2BModel model = Gemma2_2BModel
elif model_type == "gemma3_4b": elif model_type == "gemma3_4b":
model = Gemma3_4BModel model = Gemma3_4BModel
elif model_type == "gemma3_4b_vision":
model = Gemma3_4B_Vision_Model
class LuminaTEModel_(LuminaModel): class LuminaTEModel_(LuminaModel):
def __init__(self, device="cpu", dtype=None, model_options={}): def __init__(self, device="cpu", dtype=None, model_options={}):

View File

@ -6,9 +6,10 @@ class SPieceTokenizer:
def from_pretrained(path, **kwargs): def from_pretrained(path, **kwargs):
return SPieceTokenizer(path, **kwargs) return SPieceTokenizer(path, **kwargs)
def __init__(self, tokenizer_path, add_bos=False, add_eos=True): def __init__(self, tokenizer_path, add_bos=False, add_eos=True, special_tokens=None):
self.add_bos = add_bos self.add_bos = add_bos
self.add_eos = add_eos self.add_eos = add_eos
self.special_tokens = special_tokens
import sentencepiece import sentencepiece
if torch.is_tensor(tokenizer_path): if torch.is_tensor(tokenizer_path):
tokenizer_path = tokenizer_path.numpy().tobytes() tokenizer_path = tokenizer_path.numpy().tobytes()
@ -27,8 +28,32 @@ class SPieceTokenizer:
return out return out
def __call__(self, string): def __call__(self, string):
if self.special_tokens is not None:
import re
special_tokens_pattern = '|'.join(re.escape(token) for token in self.special_tokens.keys())
if special_tokens_pattern and re.search(special_tokens_pattern, string):
parts = re.split(f'({special_tokens_pattern})', string)
result = []
for part in parts:
if not part:
continue
if part in self.special_tokens:
result.append(self.special_tokens[part])
else:
encoded = self.tokenizer.encode(part, add_bos=False, add_eos=False)
result.extend(encoded)
return {"input_ids": result}
out = self.tokenizer.encode(string) out = self.tokenizer.encode(string)
return {"input_ids": out} return {"input_ids": out}
def decode(self, token_ids, skip_special_tokens=False):
if skip_special_tokens and self.special_tokens:
special_token_ids = set(self.special_tokens.values())
token_ids = [tid for tid in token_ids if tid not in special_token_ids]
return self.tokenizer.decode(token_ids)
def serialize_model(self): def serialize_model(self):
return torch.ByteTensor(list(self.tokenizer.serialized_model_proto())) return torch.ByteTensor(list(self.tokenizer.serialized_model_proto()))

View File

@ -1418,3 +1418,11 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res memo[obj_id] = res
return res return res
def normalize_image_embeddings(embeds, embeds_info, scale_factor):
"""Normalize image embeddings to match text embedding scale"""
for info in embeds_info:
if info.get("type") == "image":
start_idx = info["index"]
end_idx = start_idx + info["size"]
embeds[:, start_idx:end_idx, :] /= scale_factor

View File

@ -49,6 +49,12 @@ class WeightAdapterBase:
""" """
raise NotImplementedError raise NotImplementedError
def calculate_shape(
self,
key
):
return None
def calculate_weight( def calculate_weight(
self, self,
weight, weight,

View File

@ -214,6 +214,13 @@ class LoRAAdapter(WeightAdapterBase):
else: else:
return None return None
def calculate_shape(
self,
key
):
reshape = self.weights[5]
return tuple(reshape) if reshape is not None else None
def calculate_weight( def calculate_weight(
self, self,
weight, weight,

View File

@ -14,6 +14,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"supports_preview_metadata": True, "supports_preview_metadata": True,
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}}, "extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
} }

View File

@ -21,6 +21,17 @@ class ComfyAPI_latest(ComfyAPIBase):
VERSION = "latest" VERSION = "latest"
STABLE = False STABLE = False
def __init__(self):
super().__init__()
self.node_replacement = self.NodeReplacement()
self.execution = self.Execution()
class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None:
"""Register a node replacement mapping."""
from server import PromptServer
PromptServer.instance.node_replace_manager.register(node_replace)
class Execution(ProxiedSingleton): class Execution(ProxiedSingleton):
async def set_progress( async def set_progress(
self, self,
@ -73,8 +84,6 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display, image=to_display,
) )
execution: Execution
class ComfyExtension(ABC): class ComfyExtension(ABC):
async def on_load(self) -> None: async def on_load(self) -> None:
""" """

View File

@ -75,6 +75,12 @@ class NumberDisplay(str, Enum):
slider = "slider" slider = "slider"
class ControlAfterGenerate(str, Enum):
fixed = "fixed"
increment = "increment"
decrement = "decrement"
randomize = "randomize"
class _ComfyType(ABC): class _ComfyType(ABC):
Type = Any Type = Any
io_type: str = None io_type: str = None
@ -263,7 +269,7 @@ class Int(ComfyTypeIO):
class Input(WidgetInput): class Input(WidgetInput):
'''Integer input.''' '''Integer input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None, default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min self.min = min
@ -345,7 +351,7 @@ class Combo(ComfyTypeIO):
tooltip: str=None, tooltip: str=None,
lazy: bool=None, lazy: bool=None,
default: str | int | Enum = None, default: str | int | Enum = None,
control_after_generate: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
upload: UploadType=None, upload: UploadType=None,
image_folder: FolderType=None, image_folder: FolderType=None,
remote: RemoteOptions=None, remote: RemoteOptions=None,
@ -389,7 +395,7 @@ class MultiCombo(ComfyTypeI):
Type = list[str] Type = list[str]
class Input(Combo.Input): class Input(Combo.Input):
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None, default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced) super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
self.multiselect = True self.multiselect = True
@ -1203,6 +1209,30 @@ class Color(ComfyTypeIO):
def as_dict(self): def as_dict(self):
return super().as_dict() return super().as_dict()
@comfytype(io_type="BOUNDING_BOX")
class BoundingBox(ComfyTypeIO):
class BoundingBoxDict(TypedDict):
x: int
y: int
width: int
height: int
Type = BoundingBoxDict
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: dict=None, component: str=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless)
self.component = component
if default is None:
self.default = {"x": 0, "y": 0, "width": 512, "height": 512}
def as_dict(self):
d = super().as_dict()
if self.component:
d["component"] = self.component
return d
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {} DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]): def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func DYNAMIC_INPUT_LOOKUP[io_type] = func
@ -2030,11 +2060,74 @@ class _UIOutput(ABC):
... ...
class InputMapOldId(TypedDict):
"""Map an old node input to a new node input by ID."""
new_id: str
old_id: str
class InputMapSetValue(TypedDict):
"""Set a specific value for a new node input."""
new_id: str
set_value: Any
InputMap = InputMapOldId | InputMapSetValue
"""
Input mapping for node replacement. Type is inferred by dictionary keys:
- {"new_id": str, "old_id": str} - maps old input to new input
- {"new_id": str, "set_value": Any} - sets a specific value for new input
"""
class OutputMap(TypedDict):
"""Map outputs of node replacement via indexes."""
new_idx: int
old_idx: int
class NodeReplace:
"""
Defines a possible node replacement, mapping inputs and outputs of the old node to the new node.
Also supports assigning specific values to the input widgets of the new node.
Args:
new_node_id: The class name of the new replacement node.
old_node_id: The class name of the deprecated node.
old_widget_ids: Ordered list of input IDs for widgets that may not have an input slot
connected. The workflow JSON stores widget values by their relative position index,
not by ID. This list maps those positional indexes to input IDs, enabling the
replacement system to correctly identify widget values during node migration.
input_mapping: List of input mappings from old node to new node.
output_mapping: List of output mappings from old node to new node.
"""
def __init__(self,
new_node_id: str,
old_node_id: str,
old_widget_ids: list[str] | None=None,
input_mapping: list[InputMap] | None=None,
output_mapping: list[OutputMap] | None=None,
):
self.new_node_id = new_node_id
self.old_node_id = old_node_id
self.old_widget_ids = old_widget_ids
self.input_mapping = input_mapping
self.output_mapping = output_mapping
def as_dict(self):
"""Create serializable representation of the node replacement."""
return {
"new_node_id": self.new_node_id,
"old_node_id": self.old_node_id,
"old_widget_ids": self.old_widget_ids,
"input_mapping": list(self.input_mapping) if self.input_mapping else None,
"output_mapping": list(self.output_mapping) if self.output_mapping else None,
}
__all__ = [ __all__ = [
"FolderType", "FolderType",
"UploadType", "UploadType",
"RemoteOptions", "RemoteOptions",
"NumberDisplay", "NumberDisplay",
"ControlAfterGenerate",
"comfytype", "comfytype",
"Custom", "Custom",
@ -2121,4 +2214,6 @@ __all__ = [
"ImageCompare", "ImageCompare",
"PriceBadgeDepends", "PriceBadgeDepends",
"PriceBadge", "PriceBadge",
"BoundingBox",
"NodeReplace",
] ]

View File

@ -45,17 +45,55 @@ class BriaEditImageRequest(BaseModel):
) )
class BriaRemoveBackgroundRequest(BaseModel):
image: str = Field(...)
sync: bool = Field(False)
visual_input_content_moderation: bool = Field(
False, description="If true, returns 422 on input image moderation failure."
)
visual_output_content_moderation: bool = Field(
False, description="If true, returns 422 on visual output moderation failure."
)
seed: int = Field(...)
class BriaStatusResponse(BaseModel): class BriaStatusResponse(BaseModel):
request_id: str = Field(...) request_id: str = Field(...)
status_url: str = Field(...) status_url: str = Field(...)
warning: str | None = Field(None) warning: str | None = Field(None)
class BriaResult(BaseModel): class BriaRemoveBackgroundResult(BaseModel):
image_url: str = Field(...)
class BriaRemoveBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveBackgroundResult | None = Field(None)
class BriaImageEditResult(BaseModel):
structured_prompt: str = Field(...) structured_prompt: str = Field(...)
image_url: str = Field(...) image_url: str = Field(...)
class BriaResponse(BaseModel): class BriaImageEditResponse(BaseModel):
status: str = Field(...) status: str = Field(...)
result: BriaResult | None = Field(None) result: BriaImageEditResult | None = Field(None)
class BriaRemoveVideoBackgroundRequest(BaseModel):
video: str = Field(...)
background_color: str = Field(default="transparent", description="Background color for the output video.")
output_container_and_codec: str = Field(...)
preserve_audio: bool = Field(True)
seed: int = Field(...)
class BriaRemoveVideoBackgroundResult(BaseModel):
video_url: str = Field(...)
class BriaRemoveVideoBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveVideoBackgroundResult | None = Field(None)

View File

@ -64,3 +64,23 @@ class To3DProTaskResultResponse(BaseModel):
class To3DProTaskQueryRequest(BaseModel): class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...) JobId: str = Field(...)
class To3DUVFileInput(BaseModel):
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
Url: str = Field(...)
class To3DUVTaskRequest(BaseModel):
File: To3DUVFileInput = Field(...)
class TextureEditImageInfo(BaseModel):
Url: str = Field(...)
class TextureEditTaskRequest(BaseModel):
File3D: To3DUVFileInput = Field(...)
Image: TextureEditImageInfo | None = Field(None)
Prompt: str | None = Field(None)
EnablePBR: bool | None = Field(None)

View File

@ -198,11 +198,6 @@ dict_recraft_substyles_v3 = {
} }
class RecraftModel(str, Enum):
recraftv3 = 'recraftv3'
recraftv2 = 'recraftv2'
class RecraftImageSize(str, Enum): class RecraftImageSize(str, Enum):
res_1024x1024 = '1024x1024' res_1024x1024 = '1024x1024'
res_1365x1024 = '1365x1024' res_1365x1024 = '1365x1024'
@ -221,6 +216,41 @@ class RecraftImageSize(str, Enum):
res_1707x1024 = '1707x1024' res_1707x1024 = '1707x1024'
RECRAFT_V4_SIZES = [
"1024x1024",
"1536x768",
"768x1536",
"1280x832",
"832x1280",
"1216x896",
"896x1216",
"1152x896",
"896x1152",
"832x1344",
"1280x896",
"896x1280",
"1344x768",
"768x1344",
]
RECRAFT_V4_PRO_SIZES = [
"2048x2048",
"3072x1536",
"1536x3072",
"2560x1664",
"1664x2560",
"2432x1792",
"1792x2432",
"2304x1792",
"1792x2304",
"1664x2688",
"1434x1024",
"1024x1434",
"2560x1792",
"1792x2560",
]
class RecraftColorObject(BaseModel): class RecraftColorObject(BaseModel):
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model') rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
@ -234,17 +264,16 @@ class RecraftControlsObject(BaseModel):
class RecraftImageGenerationRequest(BaseModel): class RecraftImageGenerationRequest(BaseModel):
prompt: str = Field(..., description='The text prompt describing the image to generate') prompt: str = Field(..., description='The text prompt describing the image to generate')
size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")') size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
n: int = Field(..., description='The number of images to generate') n: int = Field(..., description='The number of images to generate')
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image') negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")') model: str = Field(...)
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")') style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input') substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process') controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID') style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity') strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
random_seed: int | None = Field(None, description="Seed for video generation") random_seed: int | None = Field(None, description="Seed for video generation")
# text_layout
class RecraftReturnedObject(BaseModel): class RecraftReturnedObject(BaseModel):

View File

@ -3,7 +3,11 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bria import ( from comfy_api_nodes.apis.bria import (
BriaEditImageRequest, BriaEditImageRequest,
BriaResponse, BriaRemoveBackgroundRequest,
BriaRemoveBackgroundResponse,
BriaRemoveVideoBackgroundRequest,
BriaRemoveVideoBackgroundResponse,
BriaImageEditResponse,
BriaStatusResponse, BriaStatusResponse,
InputModerationSettings, InputModerationSettings,
) )
@ -11,10 +15,12 @@ from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
convert_mask_to_image, convert_mask_to_image,
download_url_to_image_tensor, download_url_to_image_tensor,
get_number_of_images, download_url_to_video_output,
poll_op, poll_op,
sync_op, sync_op,
upload_images_to_comfyapi, upload_image_to_comfyapi,
upload_video_to_comfyapi,
validate_video_duration,
) )
@ -73,21 +79,15 @@ class BriaImageEditNode(IO.ComfyNode):
IO.DynamicCombo.Input( IO.DynamicCombo.Input(
"moderation", "moderation",
options=[ options=[
IO.DynamicCombo.Option("false", []),
IO.DynamicCombo.Option( IO.DynamicCombo.Option(
"true", "true",
[ [
IO.Boolean.Input( IO.Boolean.Input("prompt_content_moderation", default=False),
"prompt_content_moderation", default=False IO.Boolean.Input("visual_input_moderation", default=False),
), IO.Boolean.Input("visual_output_moderation", default=True),
IO.Boolean.Input(
"visual_input_moderation", default=False
),
IO.Boolean.Input(
"visual_output_moderation", default=True
),
], ],
), ),
IO.DynamicCombo.Option("false", []),
], ],
tooltip="Moderation settings", tooltip="Moderation settings",
), ),
@ -127,50 +127,26 @@ class BriaImageEditNode(IO.ComfyNode):
mask: Input.Image | None = None, mask: Input.Image | None = None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if not prompt and not structured_prompt: if not prompt and not structured_prompt:
raise ValueError( raise ValueError("One of prompt or structured_prompt is required to be non-empty.")
"One of prompt or structured_prompt is required to be non-empty."
)
if get_number_of_images(image) != 1:
raise ValueError("Exactly one input image is required.")
mask_url = None mask_url = None
if mask is not None: if mask is not None:
mask_url = ( mask_url = await upload_image_to_comfyapi(cls, convert_mask_to_image(mask), wait_label="Uploading mask")
await upload_images_to_comfyapi(
cls,
convert_mask_to_image(mask),
max_images=1,
mime_type="image/png",
wait_label="Uploading mask",
)
)[0]
response = await sync_op( response = await sync_op(
cls, cls,
ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"), ApiEndpoint(path="proxy/bria/v2/image/edit", method="POST"),
data=BriaEditImageRequest( data=BriaEditImageRequest(
instruction=prompt if prompt else None, instruction=prompt if prompt else None,
structured_instruction=structured_prompt if structured_prompt else None, structured_instruction=structured_prompt if structured_prompt else None,
images=await upload_images_to_comfyapi( images=[await upload_image_to_comfyapi(cls, image, wait_label="Uploading image")],
cls,
image,
max_images=1,
mime_type="image/png",
wait_label="Uploading image",
),
mask=mask_url, mask=mask_url,
negative_prompt=negative_prompt if negative_prompt else None, negative_prompt=negative_prompt if negative_prompt else None,
guidance_scale=guidance_scale, guidance_scale=guidance_scale,
seed=seed, seed=seed,
model_version=model, model_version=model,
steps_num=steps, steps_num=steps,
prompt_content_moderation=moderation.get( prompt_content_moderation=moderation.get("prompt_content_moderation", False),
"prompt_content_moderation", False visual_input_content_moderation=moderation.get("visual_input_moderation", False),
), visual_output_content_moderation=moderation.get("visual_output_moderation", False),
visual_input_content_moderation=moderation.get(
"visual_input_moderation", False
),
visual_output_content_moderation=moderation.get(
"visual_output_moderation", False
),
), ),
response_model=BriaStatusResponse, response_model=BriaStatusResponse,
) )
@ -178,7 +154,7 @@ class BriaImageEditNode(IO.ComfyNode):
cls, cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status, status_extractor=lambda r: r.status,
response_model=BriaResponse, response_model=BriaImageEditResponse,
) )
return IO.NodeOutput( return IO.NodeOutput(
await download_url_to_image_tensor(response.result.image_url), await download_url_to_image_tensor(response.result.image_url),
@ -186,11 +162,167 @@ class BriaImageEditNode(IO.ComfyNode):
) )
class BriaRemoveImageBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaRemoveImageBackground",
display_name="Bria Remove Image Background",
category="api node/image/Bria",
description="Remove the background from an image using Bria RMBG 2.0.",
inputs=[
IO.Image.Input("image"),
IO.DynamicCombo.Input(
"moderation",
options=[
IO.DynamicCombo.Option("false", []),
IO.DynamicCombo.Option(
"true",
[
IO.Boolean.Input("visual_input_moderation", default=False),
IO.Boolean.Input("visual_output_moderation", default=True),
],
),
],
tooltip="Moderation settings",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.018}""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
moderation: dict,
seed: int,
) -> IO.NodeOutput:
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/image/edit/remove_background", method="POST"),
data=BriaRemoveBackgroundRequest(
image=await upload_image_to_comfyapi(cls, image, wait_label="Uploading image"),
sync=False,
visual_input_content_moderation=moderation.get("visual_input_moderation", False),
visual_output_content_moderation=moderation.get("visual_output_moderation", False),
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveBackgroundResponse,
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result.image_url))
class BriaRemoveVideoBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaRemoveVideoBackground",
display_name="Bria Remove Video Background",
category="api node/video/Bria",
description="Remove the background from a video using Bria. ",
inputs=[
IO.Video.Input("video"),
IO.Combo.Input(
"background_color",
options=[
"Black",
"White",
"Gray",
"Red",
"Green",
"Blue",
"Yellow",
"Cyan",
"Magenta",
"Orange",
],
tooltip="Background color for the output video.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
background_color: str,
seed: int,
) -> IO.NodeOutput:
validate_video_duration(video, max_duration=60.0)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
data=BriaRemoveVideoBackgroundRequest(
video=await upload_video_to_comfyapi(cls, video),
background_color=background_color,
output_container_and_codec="mp4_h264",
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveVideoBackgroundResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
class BriaExtension(ComfyExtension): class BriaExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ return [
BriaImageEditNode, BriaImageEditNode,
BriaRemoveImageBackground,
BriaRemoveVideoBackground,
] ]

View File

@ -6,6 +6,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
import base64 import base64
import os import os
from enum import Enum from enum import Enum
from fnmatch import fnmatch
from io import BytesIO from io import BytesIO
from typing import Literal from typing import Literal
@ -119,6 +120,13 @@ async def create_image_parts(
return image_parts return image_parts
def _mime_matches(mime: GeminiMimeType | None, pattern: str) -> bool:
"""Check if a MIME type matches a pattern. Supports fnmatch globs (e.g. 'image/*')."""
if mime is None:
return False
return fnmatch(mime.value, pattern)
def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]: def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]:
""" """
Filter response parts by their type. Filter response parts by their type.
@ -151,9 +159,9 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
for part in candidate.content.parts: for part in candidate.content.parts:
if part_type == "text" and part.text: if part_type == "text" and part.text:
parts.append(part) parts.append(part)
elif part.inlineData and part.inlineData.mimeType == part_type: elif part.inlineData and _mime_matches(part.inlineData.mimeType, part_type):
parts.append(part) parts.append(part)
elif part.fileData and part.fileData.mimeType == part_type: elif part.fileData and _mime_matches(part.fileData.mimeType, part_type):
parts.append(part) parts.append(part)
if not parts and blocked_reasons: if not parts and blocked_reasons:
@ -178,7 +186,7 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = [] image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png") parts = get_parts_by_type(response, "image/*")
for part in parts: for part in parts:
if part.inlineData: if part.inlineData:
image_data = base64.b64decode(part.inlineData.data) image_data = base64.b64decode(part.inlineData.data)

View File

@ -1,31 +1,48 @@
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.hunyuan3d import ( from comfy_api_nodes.apis.hunyuan3d import (
Hunyuan3DViewImage, Hunyuan3DViewImage,
InputGenerateType, InputGenerateType,
ResultFile3D, ResultFile3D,
TextureEditTaskRequest,
To3DProTaskCreateResponse, To3DProTaskCreateResponse,
To3DProTaskQueryRequest, To3DProTaskQueryRequest,
To3DProTaskRequest, To3DProTaskRequest,
To3DProTaskResultResponse, To3DProTaskResultResponse,
To3DUVFileInput,
To3DUVTaskRequest,
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
download_url_to_file_3d, download_url_to_file_3d,
download_url_to_image_tensor,
downscale_image_tensor_by_max_side, downscale_image_tensor_by_max_side,
poll_op, poll_op,
sync_op, sync_op,
upload_3d_model_to_comfyapi,
upload_image_to_comfyapi, upload_image_to_comfyapi,
validate_image_dimensions, validate_image_dimensions,
validate_string, validate_string,
) )
def get_file_from_response(response_objs: list[ResultFile3D], file_type: str) -> ResultFile3D | None: def _is_tencent_rate_limited(status: int, body: object) -> bool:
return (
status == 400
and isinstance(body, dict)
and "RequestLimitExceeded" in str(body.get("Response", {}).get("Error", {}).get("Code", ""))
)
def get_file_from_response(
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
) -> ResultFile3D | None:
for i in response_objs: for i in response_objs:
if i.Type.lower() == file_type.lower(): if i.Type.lower() == file_type.lower():
return i return i
if raise_if_not_found:
raise ValueError(f"'{file_type}' file type is not found in the response.")
return None return None
@ -35,7 +52,7 @@ class TencentTextToModelNode(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="TencentTextToModelNode", node_id="TencentTextToModelNode",
display_name="Hunyuan3D: Text to Model (Pro)", display_name="Hunyuan3D: Text to Model",
category="api node/3d/Tencent", category="api node/3d/Tencent",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -120,6 +137,7 @@ class TencentTextToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None), EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None), PolygonType=generate_type.get("polygon_type", None),
), ),
is_rate_limited=_is_tencent_rate_limited,
) )
if response.Error: if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@ -131,11 +149,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse, response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status, status_extractor=lambda r: r.Status,
) )
glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
return IO.NodeOutput( return IO.NodeOutput(
file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
),
) )
@ -145,7 +166,7 @@ class TencentImageToModelNode(IO.ComfyNode):
def define_schema(cls): def define_schema(cls):
return IO.Schema( return IO.Schema(
node_id="TencentImageToModelNode", node_id="TencentImageToModelNode",
display_name="Hunyuan3D: Image(s) to Model (Pro)", display_name="Hunyuan3D: Image(s) to Model",
category="api node/3d/Tencent", category="api node/3d/Tencent",
inputs=[ inputs=[
IO.Combo.Input( IO.Combo.Input(
@ -268,6 +289,7 @@ class TencentImageToModelNode(IO.ComfyNode):
EnablePBR=generate_type.get("pbr", None), EnablePBR=generate_type.get("pbr", None),
PolygonType=generate_type.get("polygon_type", None), PolygonType=generate_type.get("polygon_type", None),
), ),
is_rate_limited=_is_tencent_rate_limited,
) )
if response.Error: if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}") raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
@ -279,11 +301,257 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse, response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status, status_extractor=lambda r: r.Status,
) )
glb_result = get_file_from_response(result.ResultFile3Ds, "glb")
obj_result = get_file_from_response(result.ResultFile3Ds, "obj")
file_glb = await download_url_to_file_3d(glb_result.Url, "glb", task_id=task_id) if glb_result else None
return IO.NodeOutput( return IO.NodeOutput(
file_glb, file_glb, await download_url_to_file_3d(obj_result.Url, "obj", task_id=task_id) if obj_result else None f"{task_id}.glb",
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
),
await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
),
)
class TencentModelTo3DUVNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TencentModelTo3DUVNode",
display_name="Hunyuan3D: Model to UV",
category="api node/3d/Tencent",
description="Perform UV unfolding on a 3D model to generate UV texture. "
"Input model must have less than 30000 faces.",
inputs=[
IO.MultiType.Input(
"model_3d",
types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DFBX, IO.File3DAny],
tooltip="Input 3D model (GLB, OBJ, or FBX)",
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.File3DOBJ.Output(display_name="OBJ"),
IO.File3DFBX.Output(display_name="FBX"),
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.2}'),
)
SUPPORTED_FORMATS = {"glb", "obj", "fbx"}
@classmethod
async def execute(
cls,
model_3d: Types.File3D,
seed: int,
) -> IO.NodeOutput:
_ = seed
file_format = model_3d.format.lower()
if file_format not in cls.SUPPORTED_FORMATS:
raise ValueError(
f"Unsupported file format: '{file_format}'. "
f"Supported formats: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
response_model=To3DProTaskCreateResponse,
data=To3DUVTaskRequest(
File=To3DUVFileInput(
Type=file_format.upper(),
Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
)
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
)
class Tencent3DTextureEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Tencent3DTextureEditNode",
display_name="Hunyuan3D: 3D Texture Edit",
category="api node/3d/Tencent",
description="After inputting the 3D model, perform 3D model texture redrawing.",
inputs=[
IO.MultiType.Input(
"model_3d",
types=[IO.File3DFBX, IO.File3DAny],
tooltip="3D model in FBX format. Model should have less than 100000 faces.",
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Describes texture editing. Supports up to 1024 UTF-8 characters.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.File3DGLB.Output(display_name="GLB"),
IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.6}""",
),
)
@classmethod
async def execute(
cls,
model_3d: Types.File3D,
prompt: str,
seed: int,
) -> IO.NodeOutput:
_ = seed
file_format = model_3d.format.lower()
if file_format != "fbx":
raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
validate_string(prompt, field_name="prompt", min_length=1, max_length=1024)
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
response_model=To3DProTaskCreateResponse,
data=TextureEditTaskRequest(
File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
Prompt=prompt,
EnablePBR=True,
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
)
class Tencent3DPartNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Tencent3DPartNode",
display_name="Hunyuan3D: 3D Part",
category="api node/3d/Tencent",
description="Automatically perform component identification and generation based on the model structure.",
inputs=[
IO.MultiType.Input(
"model_3d",
types=[IO.File3DFBX, IO.File3DAny],
tooltip="3D model in FBX format. Model should have less than 30000 faces.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.File3DFBX.Output(display_name="FBX"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":0.6}'),
)
@classmethod
async def execute(
cls,
model_3d: Types.File3D,
seed: int,
) -> IO.NodeOutput:
_ = seed
file_format = model_3d.format.lower()
if file_format != "fbx":
raise ValueError(f"Unsupported file format: '{file_format}'. Only FBX format is supported.")
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
response_model=To3DProTaskCreateResponse,
data=To3DUVTaskRequest(
File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed with code {response.Error.Code}: {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
) )
@ -293,6 +561,9 @@ class TencentHunyuan3DExtension(ComfyExtension):
return [ return [
TencentTextToModelNode, TencentTextToModelNode,
TencentImageToModelNode, TencentImageToModelNode,
# TencentModelTo3DUVNode,
# Tencent3DTextureEditNode,
Tencent3DPartNode,
] ]

View File

@ -43,7 +43,6 @@ class SupportedOpenAIModel(str, Enum):
o1 = "o1" o1 = "o1"
o3 = "o3" o3 = "o3"
o1_pro = "o1-pro" o1_pro = "o1-pro"
gpt_4o = "gpt-4o"
gpt_4_1 = "gpt-4.1" gpt_4_1 = "gpt-4.1"
gpt_4_1_mini = "gpt-4.1-mini" gpt_4_1_mini = "gpt-4.1-mini"
gpt_4_1_nano = "gpt-4.1-nano" gpt_4_1_nano = "gpt-4.1-nano"
@ -649,11 +648,6 @@ class OpenAIChatNode(IO.ComfyNode):
"usd": [0.01, 0.04], "usd": [0.01, 0.04],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
} }
: $contains($m, "gpt-4o") ? {
"type": "list_usd",
"usd": [0.0025, 0.01],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gpt-4.1-nano") ? { : $contains($m, "gpt-4.1-nano") ? {
"type": "list_usd", "type": "list_usd",
"usd": [0.0001, 0.0004], "usd": [0.0001, 0.0004],

View File

@ -1,5 +1,4 @@
from io import BytesIO from io import BytesIO
from typing import Optional, Union
import aiohttp import aiohttp
import torch import torch
@ -9,6 +8,8 @@ from typing_extensions import override
from comfy.utils import ProgressBar from comfy.utils import ProgressBar
from comfy_api.latest import IO, ComfyExtension from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.recraft import ( from comfy_api_nodes.apis.recraft import (
RECRAFT_V4_PRO_SIZES,
RECRAFT_V4_SIZES,
RecraftColor, RecraftColor,
RecraftColorChain, RecraftColorChain,
RecraftControls, RecraftControls,
@ -18,7 +19,6 @@ from comfy_api_nodes.apis.recraft import (
RecraftImageGenerationResponse, RecraftImageGenerationResponse,
RecraftImageSize, RecraftImageSize,
RecraftIO, RecraftIO,
RecraftModel,
RecraftStyle, RecraftStyle,
RecraftStyleV3, RecraftStyleV3,
get_v3_substyles, get_v3_substyles,
@ -39,7 +39,7 @@ async def handle_recraft_file_request(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
image: torch.Tensor, image: torch.Tensor,
path: str, path: str,
mask: Optional[torch.Tensor] = None, mask: torch.Tensor | None = None,
total_pixels: int = 4096 * 4096, total_pixels: int = 4096 * 4096,
timeout: int = 1024, timeout: int = 1024,
request=None, request=None,
@ -73,11 +73,11 @@ async def handle_recraft_file_request(
def recraft_multipart_parser( def recraft_multipart_parser(
data, data,
parent_key=None, parent_key=None,
formatter: Optional[type[callable]] = None, formatter: type[callable] | None = None,
converted_to_check: Optional[list[list]] = None, converted_to_check: list[list] | None = None,
is_list: bool = False, is_list: bool = False,
return_mode: str = "formdata", # "dict" | "formdata" return_mode: str = "formdata", # "dict" | "formdata"
) -> Union[dict, aiohttp.FormData]: ) -> dict | aiohttp.FormData:
""" """
Formats data such that multipart/form-data will work with aiohttp library when both files and data are present. Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
@ -309,7 +309,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
node_id="RecraftStyleV3InfiniteStyleLibrary", node_id="RecraftStyleV3InfiniteStyleLibrary",
display_name="Recraft Style - Infinite Style Library", display_name="Recraft Style - Infinite Style Library",
category="api node/image/Recraft", category="api node/image/Recraft",
description="Select style based on preexisting UUID from Recraft's Infinite Style Library.", description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
inputs=[ inputs=[
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
], ],
@ -485,7 +485,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
data=RecraftImageGenerationRequest( data=RecraftImageGenerationRequest(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
model=RecraftModel.recraftv3, model="recraftv3",
size=size, size=size,
n=n, n=n,
style=recraft_style.style, style=recraft_style.style,
@ -598,7 +598,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
request = RecraftImageGenerationRequest( request = RecraftImageGenerationRequest(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
model=RecraftModel.recraftv3, model="recraftv3",
n=n, n=n,
strength=round(strength, 2), strength=round(strength, 2),
style=recraft_style.style, style=recraft_style.style,
@ -698,7 +698,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
request = RecraftImageGenerationRequest( request = RecraftImageGenerationRequest(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
model=RecraftModel.recraftv3, model="recraftv3",
n=n, n=n,
style=recraft_style.style, style=recraft_style.style,
substyle=recraft_style.substyle, substyle=recraft_style.substyle,
@ -810,7 +810,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
data=RecraftImageGenerationRequest( data=RecraftImageGenerationRequest(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
model=RecraftModel.recraftv3, model="recraftv3",
size=size, size=size,
n=n, n=n,
style=recraft_style.style, style=recraft_style.style,
@ -933,7 +933,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
request = RecraftImageGenerationRequest( request = RecraftImageGenerationRequest(
prompt=prompt, prompt=prompt,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
model=RecraftModel.recraftv3, model="recraftv3",
n=n, n=n,
style=recraft_style.style, style=recraft_style.style,
substyle=recraft_style.substyle, substyle=recraft_style.substyle,
@ -1078,6 +1078,252 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
) )
class RecraftV4TextToImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RecraftV4TextToImageNode",
display_name="Recraft V4 Text to Image",
category="api node/image/Recraft",
description="Generates images using Recraft V4 or V4 Pro models.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Prompt for the image generation. Maximum 10,000 characters.",
),
IO.String.Input(
"negative_prompt",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"recraftv4",
[
IO.Combo.Input(
"size",
options=RECRAFT_V4_SIZES,
default="1024x1024",
tooltip="The size of the generated image.",
),
],
),
IO.DynamicCombo.Option(
"recraftv4_pro",
[
IO.Combo.Input(
"size",
options=RECRAFT_V4_PRO_SIZES,
default="2048x2048",
tooltip="The size of the generated image.",
),
],
),
],
tooltip="The model to use for generation.",
),
IO.Int.Input(
"n",
default=1,
min=1,
max=6,
tooltip="The number of images to generate.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Custom(RecraftIO.CONTROLS).Input(
"recraft_controls",
tooltip="Optional additional controls over the generation via the Recraft Controls node.",
optional=True,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
expr="""
(
$prices := {"recraftv4": 0.04, "recraftv4_pro": 0.25};
{"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
model: dict,
n: int,
seed: int,
recraft_controls: RecraftControls | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
response_model=RecraftImageGenerationResponse,
data=RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt else None,
model=model["model"],
size=model["size"],
n=n,
controls=recraft_controls.create_api_model() if recraft_controls else None,
),
max_retries=1,
)
images = []
for data in response.data:
with handle_recraft_image_output():
image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024))
if len(image.shape) < 4:
image = image.unsqueeze(0)
images.append(image)
return IO.NodeOutput(torch.cat(images, dim=0))
class RecraftV4TextToVectorNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="RecraftV4TextToVectorNode",
display_name="Recraft V4 Text to Vector",
category="api node/image/Recraft",
description="Generates SVG using Recraft V4 or V4 Pro models.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
tooltip="Prompt for the image generation. Maximum 10,000 characters.",
),
IO.String.Input(
"negative_prompt",
multiline=True,
tooltip="An optional text description of undesired elements on an image.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"recraftv4",
[
IO.Combo.Input(
"size",
options=RECRAFT_V4_SIZES,
default="1024x1024",
tooltip="The size of the generated image.",
),
],
),
IO.DynamicCombo.Option(
"recraftv4_pro",
[
IO.Combo.Input(
"size",
options=RECRAFT_V4_PRO_SIZES,
default="2048x2048",
tooltip="The size of the generated image.",
),
],
),
],
tooltip="The model to use for generation.",
),
IO.Int.Input(
"n",
default=1,
min=1,
max=6,
tooltip="The number of images to generate.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Custom(RecraftIO.CONTROLS).Input(
"recraft_controls",
tooltip="Optional additional controls over the generation via the Recraft Controls node.",
optional=True,
),
],
outputs=[
IO.SVG.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
expr="""
(
$prices := {"recraftv4": 0.08, "recraftv4_pro": 0.30};
{"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
negative_prompt: str,
model: dict,
n: int,
seed: int,
recraft_controls: RecraftControls | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
response_model=RecraftImageGenerationResponse,
data=RecraftImageGenerationRequest(
prompt=prompt,
negative_prompt=negative_prompt if negative_prompt else None,
model=model["model"],
size=model["size"],
n=n,
style="vector_illustration",
substyle=None,
controls=recraft_controls.create_api_model() if recraft_controls else None,
),
max_retries=1,
)
svg_data = []
for data in response.data:
svg_data.append(await download_url_as_bytesio(data.url, timeout=1024))
return IO.NodeOutput(SVG(svg_data))
class RecraftExtension(ComfyExtension): class RecraftExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1098,6 +1344,8 @@ class RecraftExtension(ComfyExtension):
RecraftCreateStyleNode, RecraftCreateStyleNode,
RecraftColorRGBNode, RecraftColorRGBNode,
RecraftControlsNode, RecraftControlsNode,
RecraftV4TextToImageNode,
RecraftV4TextToVectorNode,
] ]

View File

@ -505,6 +505,9 @@ class Rodin3D_Gen2(IO.ComfyNode):
IO.Hidden.unique_id, IO.Hidden.unique_id,
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.4}""",
),
) )
@classmethod @classmethod

View File

@ -54,6 +54,7 @@ async def execute_task(
response_model=TaskStatusResponse, response_model=TaskStatusResponse,
status_extractor=lambda r: r.state, status_extractor=lambda r: r.state,
progress_extractor=lambda r: r.progress, progress_extractor=lambda r: r.progress,
price_extractor=lambda r: r.credits * 0.005 if r.credits is not None else None,
max_poll_attempts=max_poll_attempts, max_poll_attempts=max_poll_attempts,
) )
if not response.creations: if not response.creations:
@ -1306,6 +1307,36 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
), ),
], ],
), ),
IO.DynamicCombo.Option(
"viduq3-turbo",
[
IO.Combo.Input(
"aspect_ratio",
options=["16:9", "9:16", "3:4", "4:3", "1:1"],
tooltip="The aspect ratio of the output video.",
),
IO.Combo.Input(
"resolution",
options=["720p", "1080p"],
tooltip="Resolution of the output video.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=16,
step=1,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of the output video in seconds.",
),
IO.Boolean.Input(
"audio",
default=False,
tooltip="When enabled, outputs video with sound "
"(including dialogue and sound effects).",
),
],
),
], ],
tooltip="Model to use for video generation.", tooltip="Model to use for video generation.",
), ),
@ -1334,13 +1365,20 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]), depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
expr=""" expr="""
( (
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$base := $lookup({"720p": 0.075, "1080p": 0.1}, $res); $d := $lookup(widgets, "model.duration");
$perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res); $contains(widgets.model, "turbo")
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)} ? (
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
{"type":"usd","usd": $rate * $d}
)
: (
$rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
{"type":"usd","usd": $rate * $d}
)
) )
""", """,
), ),
@ -1409,6 +1447,31 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
), ),
], ],
), ),
IO.DynamicCombo.Option(
"viduq3-turbo",
[
IO.Combo.Input(
"resolution",
options=["720p", "1080p"],
tooltip="Resolution of the output video.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=16,
step=1,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of the output video in seconds.",
),
IO.Boolean.Input(
"audio",
default=False,
tooltip="When enabled, outputs video with sound "
"(including dialogue and sound effects).",
),
],
),
], ],
tooltip="Model to use for video generation.", tooltip="Model to use for video generation.",
), ),
@ -1442,13 +1505,20 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
], ],
is_api_node=True, is_api_node=True,
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]), depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
expr=""" expr="""
( (
$res := $lookup(widgets, "model.resolution"); $res := $lookup(widgets, "model.resolution");
$base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res); $d := $lookup(widgets, "model.duration");
$perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res); $contains(widgets.model, "turbo")
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)} ? (
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
{"type":"usd","usd": $rate * $d}
)
: (
$rate := $lookup({"720p": 0.15, "1080p": 0.16, "2k": 0.2}, $res);
{"type":"usd","usd": $rate * $d}
)
) )
""", """,
), ),
@ -1481,6 +1551,145 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(results[0].url)) return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class Vidu3StartEndToVideoNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="Vidu3StartEndToVideoNode",
display_name="Vidu Q3 Start/End Frame-to-Video Generation",
category="api node/video/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(
"viduq3-pro",
[
IO.Combo.Input(
"resolution",
options=["720p", "1080p"],
tooltip="Resolution of the output video.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=16,
step=1,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of the output video in seconds.",
),
IO.Boolean.Input(
"audio",
default=False,
tooltip="When enabled, outputs video with sound "
"(including dialogue and sound effects).",
),
],
),
IO.DynamicCombo.Option(
"viduq3-turbo",
[
IO.Combo.Input(
"resolution",
options=["720p", "1080p"],
tooltip="Resolution of the output video.",
),
IO.Int.Input(
"duration",
default=5,
min=1,
max=16,
step=1,
display_mode=IO.NumberDisplay.slider,
tooltip="Duration of the output video in seconds.",
),
IO.Boolean.Input(
"audio",
default=False,
tooltip="When enabled, outputs video with sound "
"(including dialogue and sound effects).",
),
],
),
],
tooltip="Model to use for video generation.",
),
IO.Image.Input("first_frame"),
IO.Image.Input("end_frame"),
IO.String.Input(
"prompt",
multiline=True,
tooltip="Prompt description (max 2000 characters).",
),
IO.Int.Input(
"seed",
default=1,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
expr="""
(
$res := $lookup(widgets, "model.resolution");
$d := $lookup(widgets, "model.duration");
$contains(widgets.model, "turbo")
? (
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
{"type":"usd","usd": $rate * $d}
)
: (
$rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
{"type":"usd","usd": $rate * $d}
)
)
""",
),
)
@classmethod
async def execute(
cls,
model: dict,
first_frame: Input.Image,
end_frame: Input.Image,
prompt: str,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, max_length=2000)
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
payload = TaskCreationRequest(
model=model["model"],
prompt=prompt,
duration=model["duration"],
seed=seed,
resolution=model["resolution"],
audio=model["audio"],
images=[
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
for frame in (first_frame, end_frame)
],
)
results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
class ViduExtension(ComfyExtension): class ViduExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -1497,6 +1706,7 @@ class ViduExtension(ComfyExtension):
ViduMultiFrameVideoNode, ViduMultiFrameVideoNode,
Vidu3TextToVideoNode, Vidu3TextToVideoNode,
Vidu3ImageToVideoNode, Vidu3ImageToVideoNode,
Vidu3StartEndToVideoNode,
] ]

View File

@ -33,6 +33,7 @@ from .download_helpers import (
download_url_to_video_output, download_url_to_video_output,
) )
from .upload_helpers import ( from .upload_helpers import (
upload_3d_model_to_comfyapi,
upload_audio_to_comfyapi, upload_audio_to_comfyapi,
upload_file_to_comfyapi, upload_file_to_comfyapi,
upload_image_to_comfyapi, upload_image_to_comfyapi,
@ -62,6 +63,7 @@ __all__ = [
"sync_op", "sync_op",
"sync_op_raw", "sync_op_raw",
# Upload helpers # Upload helpers
"upload_3d_model_to_comfyapi",
"upload_audio_to_comfyapi", "upload_audio_to_comfyapi",
"upload_file_to_comfyapi", "upload_file_to_comfyapi",
"upload_image_to_comfyapi", "upload_image_to_comfyapi",

View File

@ -57,7 +57,7 @@ def tensor_to_bytesio(
image: torch.Tensor, image: torch.Tensor,
*, *,
total_pixels: int | None = 2048 * 2048, total_pixels: int | None = 2048 * 2048,
mime_type: str = "image/png", mime_type: str | None = "image/png",
) -> BytesIO: ) -> BytesIO:
"""Converts a torch.Tensor image to a named BytesIO object. """Converts a torch.Tensor image to a named BytesIO object.

View File

@ -164,6 +164,27 @@ async def upload_video_to_comfyapi(
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label) return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
_3D_MIME_TYPES = {
"glb": "model/gltf-binary",
"obj": "model/obj",
"fbx": "application/octet-stream",
}
async def upload_3d_model_to_comfyapi(
cls: type[IO.ComfyNode],
model_3d: Types.File3D,
file_format: str,
) -> str:
"""Uploads a 3D model file to ComfyUI API and returns its download URL."""
return await upload_file_to_comfyapi(
cls,
model_3d.get_data(),
f"{uuid.uuid4()}.{file_format}",
_3D_MIME_TYPES.get(file_format, "application/octet-stream"),
)
async def upload_file_to_comfyapi( async def upload_file_to_comfyapi(
cls: type[IO.ComfyNode], cls: type[IO.ComfyNode],
file_bytes_io: BytesIO, file_bytes_io: BytesIO,

View File

@ -698,6 +698,67 @@ class EmptyAudio(IO.ComfyNode):
create_empty_audio = execute # TODO: remove create_empty_audio = execute # TODO: remove
class AudioEqualizer3Band(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="AudioEqualizer3Band",
search_aliases=["eq", "bass boost", "treble boost", "equalizer"],
display_name="Audio Equalizer (3-Band)",
category="audio",
is_experimental=True,
inputs=[
IO.Audio.Input("audio"),
IO.Float.Input("low_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Low frequencies (Bass)"),
IO.Int.Input("low_freq", default=100, min=20, max=500, tooltip="Cutoff frequency for Low shelf"),
IO.Float.Input("mid_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for Mid frequencies"),
IO.Int.Input("mid_freq", default=1000, min=200, max=4000, tooltip="Center frequency for Mids"),
IO.Float.Input("mid_q", default=0.707, min=0.1, max=10.0, step=0.1, tooltip="Q factor (bandwidth) for Mids"),
IO.Float.Input("high_gain_dB", default=0.0, min=-24.0, max=24.0, step=0.1, tooltip="Gain for High frequencies (Treble)"),
IO.Int.Input("high_freq", default=5000, min=1000, max=15000, tooltip="Cutoff frequency for High shelf"),
],
outputs=[IO.Audio.Output()],
)
@classmethod
def execute(cls, audio, low_gain_dB, low_freq, mid_gain_dB, mid_freq, mid_q, high_gain_dB, high_freq) -> IO.NodeOutput:
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
eq_waveform = waveform.clone()
# 1. Apply Low Shelf (Bass)
if low_gain_dB != 0:
eq_waveform = torchaudio.functional.bass_biquad(
eq_waveform,
sample_rate,
gain=low_gain_dB,
central_freq=float(low_freq),
Q=0.707
)
# 2. Apply Peaking EQ (Mids)
if mid_gain_dB != 0:
eq_waveform = torchaudio.functional.equalizer_biquad(
eq_waveform,
sample_rate,
center_freq=float(mid_freq),
gain=mid_gain_dB,
Q=mid_q
)
# 3. Apply High Shelf (Treble)
if high_gain_dB != 0:
eq_waveform = torchaudio.functional.treble_biquad(
eq_waveform,
sample_rate,
gain=high_gain_dB,
central_freq=float(high_freq),
Q=0.707
)
return IO.NodeOutput({"waveform": eq_waveform, "sample_rate": sample_rate})
class AudioExtension(ComfyExtension): class AudioExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -720,6 +781,7 @@ class AudioExtension(ComfyExtension):
AudioMerge, AudioMerge,
AudioAdjustVolume, AudioAdjustVolume,
EmptyAudio, EmptyAudio,
AudioEqualizer3Band,
] ]
async def comfy_entrypoint() -> AudioExtension: async def comfy_entrypoint() -> AudioExtension:

View File

@ -23,8 +23,9 @@ class ImageCrop(IO.ComfyNode):
return IO.Schema( return IO.Schema(
node_id="ImageCrop", node_id="ImageCrop",
search_aliases=["trim"], search_aliases=["trim"],
display_name="Image Crop", display_name="Image Crop (Deprecated)",
category="image/transform", category="image/transform",
is_deprecated=True,
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1),
@ -47,6 +48,57 @@ class ImageCrop(IO.ComfyNode):
crop = execute # TODO: remove crop = execute # TODO: remove
class ImageCropV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ImageCropV2",
search_aliases=["trim"],
display_name="Image Crop",
category="image/transform",
inputs=[
IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
],
outputs=[IO.Image.Output()],
)
@classmethod
def execute(cls, image, crop_region) -> IO.NodeOutput:
x = crop_region.get("x", 0)
y = crop_region.get("y", 0)
width = crop_region.get("width", 512)
height = crop_region.get("height", 512)
x = min(x, image.shape[2] - 1)
y = min(y, image.shape[1] - 1)
to_x = width + x
to_y = height + y
img = image[:,y:to_y, x:to_x, :]
return IO.NodeOutput(img, ui=UI.PreviewImage(img))
class BoundingBox(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="PrimitiveBoundingBox",
display_name="Bounding Box",
category="utils/primitive",
inputs=[
IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION),
IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION),
IO.Int.Input("width", default=512, min=1, max=MAX_RESOLUTION),
IO.Int.Input("height", default=512, min=1, max=MAX_RESOLUTION),
],
outputs=[IO.BoundingBox.Output()],
)
@classmethod
def execute(cls, x, y, width, height) -> IO.NodeOutput:
return IO.NodeOutput({"x": x, "y": y, "width": width, "height": height})
class RepeatImageBatch(IO.ComfyNode): class RepeatImageBatch(IO.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -632,6 +684,8 @@ class ImagesExtension(ComfyExtension):
async def get_node_list(self) -> list[type[IO.ComfyNode]]: async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [ return [
ImageCrop, ImageCrop,
ImageCropV2,
BoundingBox,
RepeatImageBatch, RepeatImageBatch,
ImageFromBatch, ImageFromBatch,
ImageAddNoise, ImageAddNoise,

View File

@ -7,6 +7,7 @@ import logging
from enum import Enum from enum import Enum
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from tqdm.auto import trange
CLAMP_QUANTILE = 0.99 CLAMP_QUANTILE = 0.99
@ -49,12 +50,22 @@ LORA_TYPES = {"standard": LORAType.STANDARD,
"full_diff": LORAType.FULL_DIFF} "full_diff": LORAType.FULL_DIFF}
def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False): def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora_type, bias_diff=False):
comfy.model_management.load_models_gpu([model_diff], force_patch_weights=True) comfy.model_management.load_models_gpu([model_diff])
sd = model_diff.model_state_dict(filter_prefix=prefix_model) sd = model_diff.model_state_dict(filter_prefix=prefix_model)
for k in sd: sd_keys = list(sd.keys())
if k.endswith(".weight"): for index in trange(len(sd_keys), unit="weight"):
k = sd_keys[index]
op_keys = sd_keys[index].rsplit('.', 1)
if len(op_keys) < 2 or op_keys[1] not in ["weight", "bias"] or (op_keys[1] == "bias" and not bias_diff):
continue
op = comfy.utils.get_attr(model_diff.model, op_keys[0])
if hasattr(op, "comfy_cast_weights") and not getattr(op, "comfy_patched_weights", False):
weight_diff = model_diff.patch_weight_to_device(k, model_diff.load_device, return_weight=True)
else:
weight_diff = sd[k] weight_diff = sd[k]
if op_keys[1] == "weight":
if lora_type == LORAType.STANDARD: if lora_type == LORAType.STANDARD:
if weight_diff.ndim < 2: if weight_diff.ndim < 2:
if bias_diff: if bias_diff:
@ -69,8 +80,8 @@ def calc_lora_model(model_diff, rank, prefix_model, prefix_lora, output_sd, lora
elif lora_type == LORAType.FULL_DIFF: elif lora_type == LORAType.FULL_DIFF:
output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu() output_sd["{}{}.diff".format(prefix_lora, k[len(prefix_model):-7])] = weight_diff.contiguous().half().cpu()
elif bias_diff and k.endswith(".bias"): elif bias_diff and op_keys[1] == "bias":
output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = sd[k].contiguous().half().cpu() output_sd["{}{}.diff_b".format(prefix_lora, k[len(prefix_model):-5])] = weight_diff.contiguous().half().cpu()
return output_sd return output_sd
class LoraSave(io.ComfyNode): class LoraSave(io.ComfyNode):

99
comfy_extras/nodes_nag.py Normal file
View File

@ -0,0 +1,99 @@
import torch
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class NAGuidance(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="NAGuidance",
display_name="Normalized Attention Guidance",
description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
category="",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to apply NAG to."),
io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
# io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
# io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
],
outputs=[
io.Model.Output(tooltip="The patched model with NAG enabled."),
],
)
@classmethod
def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
m = model.clone()
# sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
# sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)
def nag_attention_output_patch(out, extra_options):
cond_or_uncond = extra_options.get("cond_or_uncond", None)
if cond_or_uncond is None:
return out
if not (1 in cond_or_uncond and 0 in cond_or_uncond):
return out
# sigma = extra_options.get("sigmas", None)
# if sigma is not None and len(sigma) > 0:
# sigma = sigma[0].item()
# if sigma > sigma_start or sigma < sigma_end:
# return out
img_slice = extra_options.get("img_slice", None)
if img_slice is not None:
orig_out = out
out = out[:, img_slice[0]:img_slice[1]] # only apply on img part
batch_size = out.shape[0]
half_size = batch_size // len(cond_or_uncond)
ind_neg = cond_or_uncond.index(1)
ind_pos = cond_or_uncond.index(0)
z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]
guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)
eps = 1e-6
norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
ratio = norm_guided / norm_pos
scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio
guided_normalized = guided * scale_factor
z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)
if img_slice is not None:
orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final
orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
return orig_out
else:
out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final
return out
m.set_model_attn1_output_patch(nag_attention_output_patch)
m.disable_model_cfg1_optimization()
return io.NodeOutput(m)
class NagExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
NAGuidance,
]
async def comfy_entrypoint() -> NagExtension:
return NagExtension()

View File

@ -655,6 +655,7 @@ class BatchImagesMasksLatentsNode(io.ComfyNode):
batched = batch_masks(values) batched = batch_masks(values)
return io.NodeOutput(batched) return io.NodeOutput(batched)
class PostProcessingExtension(ComfyExtension): class PostProcessingExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:

View File

@ -0,0 +1,103 @@
from comfy_api.latest import ComfyExtension, io, ComfyAPI
api = ComfyAPI()
async def register_replacements():
"""Register all built-in node replacements."""
await register_replacements_longeredge()
await register_replacements_batchimages()
await register_replacements_upscaleimage()
await register_replacements_controlnet()
await register_replacements_load3d()
await register_replacements_preview3d()
await register_replacements_svdimg2vid()
await register_replacements_conditioningavg()
async def register_replacements_longeredge():
# No dynamic inputs here
await api.node_replacement.register(io.NodeReplace(
new_node_id="ImageScaleToMaxDimension",
old_node_id="ResizeImagesByLongerEdge",
old_widget_ids=["longer_edge"],
input_mapping=[
{"new_id": "image", "old_id": "images"},
{"new_id": "largest_size", "old_id": "longer_edge"},
{"new_id": "upscale_method", "set_value": "lanczos"},
],
# just to test the frontend output_mapping code, does nothing really here
output_mapping=[{"new_idx": 0, "old_idx": 0}],
))
async def register_replacements_batchimages():
# BatchImages node uses Autogrow
await api.node_replacement.register(io.NodeReplace(
new_node_id="BatchImagesNode",
old_node_id="ImageBatch",
input_mapping=[
{"new_id": "images.image0", "old_id": "image1"},
{"new_id": "images.image1", "old_id": "image2"},
],
))
async def register_replacements_upscaleimage():
# ResizeImageMaskNode uses DynamicCombo
await api.node_replacement.register(io.NodeReplace(
new_node_id="ResizeImageMaskNode",
old_node_id="ImageScaleBy",
old_widget_ids=["upscale_method", "scale_by"],
input_mapping=[
{"new_id": "input", "old_id": "image"},
{"new_id": "resize_type", "set_value": "scale by multiplier"},
{"new_id": "resize_type.multiplier", "old_id": "scale_by"},
{"new_id": "scale_method", "old_id": "upscale_method"},
],
))
async def register_replacements_controlnet():
# T2IAdapterLoader → ControlNetLoader
await api.node_replacement.register(io.NodeReplace(
new_node_id="ControlNetLoader",
old_node_id="T2IAdapterLoader",
input_mapping=[
{"new_id": "control_net_name", "old_id": "t2i_adapter_name"},
],
))
async def register_replacements_load3d():
# Load3DAnimation merged into Load3D
await api.node_replacement.register(io.NodeReplace(
new_node_id="Load3D",
old_node_id="Load3DAnimation",
))
async def register_replacements_preview3d():
# Preview3DAnimation merged into Preview3D
await api.node_replacement.register(io.NodeReplace(
new_node_id="Preview3D",
old_node_id="Preview3DAnimation",
))
async def register_replacements_svdimg2vid():
# Typo fix: SDV → SVD
await api.node_replacement.register(io.NodeReplace(
new_node_id="SVD_img2vid_Conditioning",
old_node_id="SDV_img2vid_Conditioning",
))
async def register_replacements_conditioningavg():
# Typo fix: trailing space in node name
await api.node_replacement.register(io.NodeReplace(
new_node_id="ConditioningAverage",
old_node_id="ConditioningAverage ",
))
class NodeReplacementsExtension(ComfyExtension):
async def on_load(self) -> None:
await register_replacements()
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return []
async def comfy_entrypoint() -> NodeReplacementsExtension:
return NodeReplacementsExtension()

View File

@ -0,0 +1,176 @@
from comfy_api.latest import ComfyExtension, io
from typing_extensions import override
class TextGenerate(io.ComfyNode):
@classmethod
def define_schema(cls):
# Define dynamic combo options for sampling mode
sampling_options = [
io.DynamicCombo.Option(
key="on",
inputs=[
io.Float.Input("temperature", default=0.7, min=0.01, max=2.0, step=0.000001),
io.Int.Input("top_k", default=64, min=0, max=1000),
io.Float.Input("top_p", default=0.95, min=0.0, max=1.0, step=0.01),
io.Float.Input("min_p", default=0.05, min=0.0, max=1.0, step=0.01),
io.Float.Input("repetition_penalty", default=1.05, min=0.0, max=5.0, step=0.01),
io.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff),
]
),
io.DynamicCombo.Option(
key="off",
inputs=[]
),
]
return io.Schema(
node_id="TextGenerate",
category="textgen/",
search_aliases=["LLM", "gemma"],
inputs=[
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("image", optional=True),
io.Int.Input("max_length", default=256, min=1, max=2048),
io.DynamicCombo.Input("sampling_mode", options=sampling_options, display_name="Sampling Mode"),
],
outputs=[
io.String.Output(display_name="generated_text"),
],
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
tokens = clip.tokenize(prompt, image=image, skip_template=False)
# Get sampling parameters from dynamic combo
do_sample = sampling_mode.get("sampling_mode") == "on"
temperature = sampling_mode.get("temperature", 1.0)
top_k = sampling_mode.get("top_k", 50)
top_p = sampling_mode.get("top_p", 1.0)
min_p = sampling_mode.get("min_p", 0.0)
seed = sampling_mode.get("seed", None)
repetition_penalty = sampling_mode.get("repetition_penalty", 1.0)
generated_ids = clip.generate(
tokens,
do_sample=do_sample,
max_length=max_length,
temperature=temperature,
top_k=top_k,
top_p=top_p,
min_p=min_p,
repetition_penalty=repetition_penalty,
seed=seed
)
generated_text = clip.decode(generated_ids, skip_special_tokens=True)
return io.NodeOutput(generated_text)
LTX2_T2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
#### Guidelines
- Strictly follow all aspects of the user's raw input: include every element requested (style, visuals, motions, actions, camera movement, audio).
- If the input is vague, invent concrete details: lighting, textures, materials, scene settings, etc.
- For characters: describe gender, clothing, hair, expressions. DO NOT invent unrequested characters.
- Use active language: present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
- Maintain chronological flow: use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape (background audio, ambient sounds, SFX, speech/music when requested). Integrate sounds chronologically alongside actions. Be specific (e.g., "soft footsteps on tile"), not vague (e.g., "ambient sound is present").
- Speech (only when requested):
- For ANY speech-related input (talking, conversation, singing, etc.), ALWAYS include exact words in quotes with voice characteristics (e.g., "The man says in an excited voice: 'You won't believe what I just saw!'").
- Specify language if not English and accent if relevant.
- Style: Include visual style at the beginning: "Style: <style>, <rest of prompt>." Default to cinematic-realistic if unspecified. Omit if unclear.
- Visual and audio only: NO non-visual/auditory senses (smell, taste, touch).
- Restrained language: Avoid dramatic/exaggerated terms. Use mild, natural phrasing.
- Colors: Use plain terms ("red dress"), not intensified ("vibrant blue," "bright red").
- Lighting: Use neutral descriptions ("soft overhead light"), not harsh ("blinding light").
- Facial features: Use delicate modifiers for subtle features (i.e., "subtle freckles").
#### Important notes:
- Analyze the user's raw input carefully. In cases of FPV or POV, exclude the description of the subject whose POV is requested.
- Camera motion: DO NOT invent camera motion unless requested by the user.
- Speech: DO NOT modify user-provided character dialogue unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Format: DO NOT use phrases like "The scene opens with...". Start directly with Style (optional) and chronological scene description.
- Format: DO NOT start your response with special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- If the user's raw input prompt is highly detailed, chronological and in the requested format: DO NOT make major edits or introduce new elements. Add/enhance audio descriptions if missing.
#### Output Format (Strict):
- Single continuous paragraph in natural language (English).
- NO titles, headings, prefaces, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
Your output quality is CRITICAL. Generate visually rich, dynamic prompts with integrated audio for high-quality video generation.
#### Example
Input: "A woman at a coffee shop talking on the phone"
Output:
Style: realistic with cinematic lighting. In a medium close-up, a woman in her early 30s with shoulder-length brown hair sits at a small wooden table by the window. She wears a cream-colored turtleneck sweater, holding a white ceramic coffee cup in one hand and a smartphone to her ear with the other. Ambient cafe sounds fill the spaceespresso machine hiss, quiet conversations, gentle clinking of cups. The woman listens intently, nodding slightly, then takes a sip of her coffee and sets it down with a soft clink. Her face brightens into a warm smile as she speaks in a clear, friendly voice, 'That sounds perfect! I'd love to meet up this weekend. How about Saturday afternoon?' She laughs softly—a genuine chuckle—and shifts in her chair. Behind her, other patrons move subtly in and out of focus. 'Great, I'll see you then,' she concludes cheerfully, lowering the phone.
"""
LTX2_I2V_SYSTEM_PROMPT = """You are a Creative Assistant. Given a user's raw input prompt describing a scene or concept, expand it into a detailed video generation prompt with specific visuals and integrated audio to guide a text-to-video model.
You are a Creative Assistant writing concise, action-focused image-to-video prompts. Given an image (first frame) and user Raw Input Prompt, generate a prompt to guide video generation from that image.
#### Guidelines:
- Analyze the Image: Identify Subject, Setting, Elements, Style and Mood.
- Follow user Raw Input Prompt: Include all requested motion, actions, camera movements, audio, and details. If in conflict with the image, prioritize user request while maintaining visual consistency (describe transition from image to user's scene).
- Describe only changes from the image: Don't reiterate established visual details. Inaccurate descriptions may cause scene cuts.
- Active language: Use present-progressive verbs ("is walking," "speaking"). If no action specified, describe natural movements.
- Chronological flow: Use temporal connectors ("as," "then," "while").
- Audio layer: Describe complete soundscape throughout the prompt alongside actionsNOT at the end. Align audio intensity with action tempo. Include natural background audio, ambient sounds, effects, speech or music (when requested). Be specific (e.g., "soft footsteps on tile") not vague (e.g., "ambient sound").
- Speech (only when requested): Provide exact words in quotes with character's visual/voice characteristics (e.g., "The tall man speaks in a low, gravelly voice"), language if not English and accent if relevant. If general conversation mentioned without text, generate contextual quoted dialogue. (i.e., "The man is talking" input -> the output should include exact spoken words, like: "The man is talking in an excited voice saying: 'You won't believe what I just saw!' His hands gesture expressively as he speaks, eyebrows raised with enthusiasm. The ambient sound of a quiet room underscores his animated speech.")
- Style: Include visual style at beginning: "Style: <style>, <rest of prompt>." If unclear, omit to avoid conflicts.
- Visual and audio only: Describe only what is seen and heard. NO smell, taste, or tactile sensations.
- Restrained language: Avoid dramatic terms. Use mild, natural, understated phrasing.
#### Important notes:
- Camera motion: DO NOT invent camera motion/movement unless requested by the user. Make sure to include camera motion only if specified in the input.
- Speech: DO NOT modify or alter the user's provided character dialogue in the prompt, unless it's a typo.
- No timestamps or cuts: DO NOT use timestamps or describe scene cuts unless explicitly requested.
- Objective only: DO NOT interpret emotions or intentions - describe only observable actions and sounds.
- Format: DO NOT use phrases like "The scene opens with..." / "The video starts...". Start directly with Style (optional) and chronological scene description.
- Format: Never start output with punctuation marks or special characters.
- DO NOT invent dialogue unless the user mentions speech/talking/singing/conversation.
- Your performance is CRITICAL. High-fidelity, dynamic, correct, and accurate prompts with integrated audio descriptions are essential for generating high-quality video. Your goal is flawless execution of these rules.
#### Output Format (Strict):
- Single concise paragraph in natural English. NO titles, headings, prefaces, sections, code fences, or Markdown.
- If unsafe/invalid, return original user prompt. Never ask questions or clarifications.
#### Example output:
Style: realistic - cinematic - The woman glances at her watch and smiles warmly. She speaks in a cheerful, friendly voice, "I think we're right on time!" In the background, a café barista prepares drinks at the counter. The barista calls out in a clear, upbeat tone, "Two cappuccinos ready!" The sound of the espresso machine hissing softly blends with gentle background chatter and the light clinking of cups on saucers.
"""
class TextGenerateLTX2Prompt(TextGenerate):
@classmethod
def define_schema(cls):
parent_schema = super().define_schema()
return io.Schema(
node_id="TextGenerateLTX2Prompt",
category=parent_schema.category,
inputs=parent_schema.inputs,
outputs=parent_schema.outputs,
search_aliases=["prompt enhance", "LLM", "gemma"],
)
@classmethod
def execute(cls, clip, prompt, max_length, sampling_mode, image=None) -> io.NodeOutput:
if image is None:
formatted_prompt = f"<start_of_turn>system\n{LTX2_T2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
else:
formatted_prompt = f"<start_of_turn>system\n{LTX2_I2V_SYSTEM_PROMPT.strip()}<end_of_turn>\n<start_of_turn>user\n\n<image_soft_token>\n\nUser Raw Input Prompt: {prompt}.<end_of_turn>\n<start_of_turn>model\n"
return super().execute(clip, formatted_prompt, max_length, sampling_mode, image)
class TextgenExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextGenerate,
TextGenerateLTX2Prompt,
]
async def comfy_entrypoint() -> TextgenExtension:
return TextgenExtension()

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.13.0" __version__ = "0.14.1"

View File

@ -2264,6 +2264,7 @@ async def load_custom_node(module_path: str, ignore=set(), module_parent="custom
if not isinstance(extension, ComfyExtension): if not isinstance(extension, ComfyExtension):
logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.") logging.warning(f"comfy_entrypoint in {module_path} did not return a ComfyExtension, skipping.")
return False return False
await extension.on_load()
node_list = await extension.get_node_list() node_list = await extension.get_node_list()
if not isinstance(node_list, list): if not isinstance(node_list, list):
logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.") logging.warning(f"comfy_entrypoint in {module_path} did not return a list of nodes, skipping.")
@ -2433,8 +2434,11 @@ async def init_builtin_extra_nodes():
"nodes_image_compare.py", "nodes_image_compare.py",
"nodes_zimage.py", "nodes_zimage.py",
"nodes_lora_debug.py", "nodes_lora_debug.py",
"nodes_textgen.py",
"nodes_color.py", "nodes_color.py",
"nodes_toolkit.py", "nodes_toolkit.py",
"nodes_replacements.py",
"nodes_nag.py",
] ]
import_failed = [] import_failed = []

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.13.0" version = "0.14.1"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.38.14 comfyui-frontend-package==1.39.14
comfyui-workflow-templates==0.8.38 comfyui-workflow-templates==0.8.43
comfyui-embedded-docs==0.4.1 comfyui-embedded-docs==0.4.1
torch torch
torchsde torchsde

View File

@ -40,6 +40,7 @@ from app.user_manager import UserManager
from app.model_manager import ModelFileManager from app.model_manager import ModelFileManager
from app.custom_node_manager import CustomNodeManager from app.custom_node_manager import CustomNodeManager
from app.subgraph_manager import SubgraphManager from app.subgraph_manager import SubgraphManager
from app.node_replace_manager import NodeReplaceManager
from typing import Optional, Union from typing import Optional, Union
from api_server.routes.internal.internal_routes import InternalRoutes from api_server.routes.internal.internal_routes import InternalRoutes
from protocol import BinaryEventTypes from protocol import BinaryEventTypes
@ -204,6 +205,7 @@ class PromptServer():
self.model_file_manager = ModelFileManager() self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager() self.custom_node_manager = CustomNodeManager()
self.subgraph_manager = SubgraphManager() self.subgraph_manager = SubgraphManager()
self.node_replace_manager = NodeReplaceManager()
self.internal_routes = InternalRoutes(self) self.internal_routes = InternalRoutes(self)
self.supports = ["custom_nodes_from_web"] self.supports = ["custom_nodes_from_web"]
self.prompt_queue = execution.PromptQueue(self) self.prompt_queue = execution.PromptQueue(self)
@ -887,6 +889,8 @@ class PromptServer():
if "partial_execution_targets" in json_data: if "partial_execution_targets" in json_data:
partial_execution_targets = json_data["partial_execution_targets"] partial_execution_targets = json_data["partial_execution_targets"]
self.node_replace_manager.apply_replacements(prompt)
valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets) valid = await execution.validate_prompt(prompt_id, prompt, partial_execution_targets)
extra_data = {} extra_data = {}
if "extra_data" in json_data: if "extra_data" in json_data:
@ -995,6 +999,7 @@ class PromptServer():
self.model_file_manager.add_routes(self.routes) self.model_file_manager.add_routes(self.routes)
self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items()) self.custom_node_manager.add_routes(self.routes, self.app, nodes.LOADED_MODULE_DIRS.items())
self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items()) self.subgraph_manager.add_routes(self.routes, nodes.LOADED_MODULE_DIRS.items())
self.node_replace_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app()) self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation. # Prefix every route with /api for easier matching for delegation.