Merge branch 'master' into matt/asset-image-dimensions-metadata
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run

This commit is contained in:
Matt Miller 2026-06-08 14:26:12 -07:00 committed by GitHub
commit 7a2594ac5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
193 changed files with 20167 additions and 12921 deletions

View File

@ -1,5 +1,4 @@
As of the time of writing this you need this driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
As of the time of writing this you need a recent driver. Updating to the latest driver is recommended.
HOW TO RUN:
@ -7,9 +6,9 @@ If you have a AMD gpu:
run_amd_gpu.bat
If you have memory issues you can try disabling the smart memory management by running comfyui with:
If you have memory issues you can try enabling the new dynamic memory management by running comfyui with:
run_amd_gpu_disable_smart_memory.bat
run_amd_gpu_enable_dynamic_vram.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints

View File

@ -17,7 +17,7 @@ jobs:
- name: Check for Windows line endings (CRLF)
run: |
# Get the list of changed files in the PR
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} -- ':!.ci')
# Flag to track if CRLF is found
CRLF_FOUND=false

View File

@ -0,0 +1,24 @@
name: Detect Unreviewed Merge
# SOC 2 compliance — reusable workflow lives in Comfy-Org/github-workflows,
# tracking issues are filed in Comfy-Org/unreviewed-merges.
on:
push:
branches: [master]
concurrency:
group: detect-unreviewed-merge-${{ github.sha }}
cancel-in-progress: false
permissions:
contents: read
pull-requests: read
jobs:
detect:
uses: Comfy-Org/github-workflows/.github/workflows/detect-unreviewed-merge.yml@4d9cb6b87f953bb7cd69954280e1465fb9bd2040 # v1
with:
approval-mode: latest-per-reviewer
secrets:
UNREVIEWED_MERGES_TOKEN: ${{ secrets.UNREVIEWED_MERGES_TOKEN }}

View File

@ -105,7 +105,7 @@ class WindowAttention(nn.Module):
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
relative_position_bias = comfy.ops.cast_to_input(relative_position_bias.permute(2, 0, 1).contiguous(), attn) # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:

View File

@ -55,12 +55,7 @@ class BackgroundRemovalModel():
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
if mask.ndim == 3:
mask = mask.unsqueeze(0)
if mask.shape[1] != 1:
mask = mask.movedim(-1, 1)
return mask
return mask.squeeze(1) # (B, 1, H, W) -> (B, H, W)
def load_background_removal_model(sd):

View File

@ -149,6 +149,7 @@ parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=Non
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
parser.add_argument("--fast-disk", action="store_true", help="Prefer disk-backed dynamic loading and offload over unpinned RAM. Can be faster for users with fast NVME disks.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")

View File

@ -9,6 +9,7 @@ import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
import comfy.image_encoders.dino3
class Output:
def __getitem__(self, key):
@ -23,12 +24,16 @@ IMAGE_ENCODERS = {
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel,
}
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
if isinstance(json_config, dict):
config = json_config
else:
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
@ -134,6 +139,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers)
json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG
else:
return None

View File

@ -1,5 +1,20 @@
import logging
import torch
_CK_STOCHASTIC_ROUNDING_AVAILABLE = False
try:
import comfy_kitchen as ck
_ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8
_CK_STOCHASTIC_ROUNDING_AVAILABLE = True
except (AttributeError, ImportError):
logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.")
if not _CK_STOCHASTIC_ROUNDING_AVAILABLE:
def _ck_stochastic_rounding_fp8(value, rng, dtype):
raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding")
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
mantissa_scaled = torch.where(
normal_mask,
@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0):
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
generator = torch.Generator(device=value.device)
generator.manual_seed(seed)
if _CK_STOCHASTIC_ROUNDING_AVAILABLE:
rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator)
return _ck_stochastic_rounding_fp8(value, rng, dtype)
output = torch.empty_like(value, dtype=dtype)
num_slices = max(1, (value.numel() / (4096 * 4096)))
slice_size = max(1, round(value.shape[0] / num_slices))

View File

@ -0,0 +1,259 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.ops
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
# DINOv3 ViT-H/16+ (SwiGLU)
DINOV3_VITH_CONFIG = {
"model_type": "dinov3",
"num_hidden_layers": 32,
"hidden_size": 1280,
"num_attention_heads": 20,
"num_register_tokens": 4,
"intermediate_size": 5120,
"layer_norm_eps": 1e-5,
"num_channels": 3,
"patch_size": 16,
"rope_theta": 100.0,
"use_gated_mlp": True,
"gated_mlp_act": "silu",
"image_size": 1024,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225],
}
class DINOv3ViTMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
self.act_fn = torch.nn.GELU()
def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)))
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
num_tokens = q.shape[-2]
num_patches = sin.shape[-2]
num_prefix_tokens = num_tokens - num_patches
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
return q, k
class DINOv3ViTAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
super().__init__()
self.embed_dim = hidden_size
self.num_heads = num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs):
batch_size, patches, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
attn = optimized_attention_for_device(query_states.device, mask=False)
attn_output = attn(
query_states, key_states, value_states, self.num_heads, attention_mask,
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
class DINOv3ViTGatedMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
self.act_fn = torch.nn.SiLU() if act == "silu" else torch.nn.GELU()
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device):
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
coords_h = coords_h / num_patches_h
coords_w = coords_w / num_patches_w
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
coords = coords.flatten(0, 1)
coords = 2.0 * coords - 1.0
return coords
class DINOv3ViTRopePositionEmbedding(nn.Module):
inv_freq: torch.Tensor
def __init__(self, rope_theta, hidden_size, num_attention_heads, patch_size, device, dtype):
super().__init__()
self.base = rope_theta
self.head_dim = hidden_size // num_attention_heads
self.patch_size = patch_size
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, pixel_values):
_, _, height, width = pixel_values.shape
num_patches_h = height // self.patch_size
num_patches_w = width // self.patch_size
patch_coords = get_patches_center_coordinates(num_patches_h, num_patches_w, dtype=torch.float32, device=pixel_values.device)
self.inv_freq = self.inv_freq.to(pixel_values.device)
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
angles = angles.flatten(1, 2)
angles = angles.tile(2)
cos = torch.cos(angles).to(dtype=pixel_values.dtype)
sin = torch.sin(angles).to(dtype=pixel_values.dtype)
return cos, sin
class DINOv3ViTEmbeddings(nn.Module):
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
super().__init__()
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
self.patch_embeddings = operations.Conv2d(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
)
def forward(self, pixel_values, bool_masked_pos=None):
batch_size = pixel_values.shape[0]
patch_embeddings = self.patch_embeddings(pixel_values)
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
if bool_masked_pos is not None:
mask_token = comfy.ops.cast_to_input(self.mask_token, patch_embeddings)
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
cls_token = comfy.ops.cast_to_input(self.cls_token.expand(batch_size, -1, -1), patch_embeddings)
register_tokens = comfy.ops.cast_to_input(self.register_tokens.expand(batch_size, -1, -1), patch_embeddings)
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
return embeddings
class DINOv3ViTLayer(nn.Module):
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size,
num_attention_heads, device, dtype, operations, gated_mlp_act="silu"):
super().__init__()
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
if use_gated_mlp:
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations, act=gated_mlp_act)
else:
self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.attention(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
hidden_states = self.layer_scale1(hidden_states)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.layer_scale2(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class DINOv3ViTModel(nn.Module):
def __init__(self, config, dtype, device, operations):
super().__init__()
num_hidden_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"]
num_attention_heads = config["num_attention_heads"]
num_register_tokens = config["num_register_tokens"]
intermediate_size = config["intermediate_size"]
layer_norm_eps = config["layer_norm_eps"]
num_channels = config["num_channels"]
patch_size = config["patch_size"]
rope_theta = config["rope_theta"]
use_gated_mlp = config.get("use_gated_mlp", False)
gated_mlp_act = config.get("gated_mlp_act", "silu")
self.embeddings = DINOv3ViTEmbeddings(
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
dtype=dtype, device=device, operations=operations
)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
)
self.layer = nn.ModuleList([
DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, mlp_bias=True,
intermediate_size=intermediate_size, num_attention_heads=num_attention_heads,
dtype=dtype, device=device, operations=operations, gated_mlp_act=gated_mlp_act)
for _ in range(num_hidden_layers)])
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
position_embeddings = self.rope_embeddings(pixel_values)
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, position_embeddings=position_embeddings)
if kwargs.get("skip_norm_elementwise", False):
sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:])
else:
norm = self.norm.to(hidden_states.device)
sequence_output = norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return sequence_output, None, pooled_output, None

View File

@ -4,6 +4,7 @@ class LatentFormat:
scale_factor = 1.0
latent_channels = 4
latent_dimensions = 2
preserve_empty_channel_multiples = False
latent_rgb_factors = None
latent_rgb_factors_bias = None
latent_rgb_factors_reshape = None
@ -239,6 +240,16 @@ class Flux2(LatentFormat):
def process_out(self, latent):
return latent
class TripoSplat(LatentFormat):
# Sequence latent (B, 8192, 16) the camera token rides alongside as a second nested latent
latent_channels = 16
def process_in(self, latent):
return latent
def process_out(self, latent):
return latent
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3
@ -769,6 +780,10 @@ class ACEAudio(LatentFormat):
latent_channels = 8
latent_dimensions = 2
class SeedVR2(LatentFormat):
latent_channels = 16
preserve_empty_channel_multiples = True
class ACEAudio15(LatentFormat):
latent_channels = 64
latent_dimensions = 1

View File

@ -433,11 +433,11 @@ class Attention(nn.Module):
if self.differential:
q, q_diff = q.unbind(dim=1)
k, k_diff = k.unbind(dim=1)
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = out - out_diff
else:
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
out = self.to_out(out)

View File

@ -138,11 +138,11 @@ class Attention(nn.Module):
k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype)
if self.differential:
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True))
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True, low_precision_attention=False))
del q, k, v, q_diff, k_diff
else:
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
del q, k, v
return self.to_out(out)

View File

@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams):
# None means use the same dtype as the model.
nerf_embedder_dtype: Optional[torch.dtype]
use_x0: bool
# Use sequential txt_ids instead of zeros
use_sequential_txt_ids: bool
class ChromaRadiance(Chroma):
"""
@ -162,6 +164,9 @@ class ChromaRadiance(Chroma):
if params.use_x0:
self.register_buffer("__x0__", torch.tensor([]))
if params.use_sequential_txt_ids:
self.register_buffer("__sequential__", torch.tensor([]))
@property
def _nerf_final_layer(self) -> nn.Module:
if self.params.nerf_final_head_type == "linear":
@ -313,6 +318,9 @@ class ChromaRadiance(Chroma):
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
# Radiance after 2026-05-22 uses sequential txt_ids instead of zeros
if params.use_sequential_txt_ids:
txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1)
img_out = self.forward_orig(
img,

View File

@ -14,15 +14,7 @@ from torchvision import transforms
import comfy.patcher_extension
from comfy.ldm.modules.attention import optimized_attention
import comfy.ldm.common_dit
def apply_rotary_pos_emb(
t: torch.Tensor,
freqs: torch.Tensor,
) -> torch.Tensor:
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
return t_out
import comfy.quant_ops
# ---------------------- Feed Forward Network -----------------------
@ -173,8 +165,7 @@ class Attention(nn.Module):
k = self.k_norm(k)
v = self.v_norm(v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
q = apply_rotary_pos_emb(q, rope_emb)
k = apply_rotary_pos_emb(k, rope_emb)
q, k = comfy.quant_ops.ck.apply_rope_split_half(q, k, rope_emb)
return q, k, v
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)

View File

@ -5,6 +5,7 @@ import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import comfy.quant_ops
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
@ -19,15 +20,6 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
return out.to(dtype=torch.float32, device=pos.device)
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
rot_dim = freqs_cis.shape[-1]
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
cos_ = freqs_cis[0]
sin_ = freqs_cis[1]
x1, x2 = x.chunk(2, dim=-1)
x_rotated = torch.cat((-x2, x1), dim=-1)
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
class ErnieImageEmbedND3(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: tuple):
super().__init__()
@ -37,8 +29,16 @@ class ErnieImageEmbedND3(nn.Module):
def forward(self, ids: torch.Tensor) -> torch.Tensor:
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2]
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim]
cos_ = emb[0]
sin_ = emb[1]
N = cos_.shape[-1]
half = N // 2
cos_top = cos_[..., :half].repeat_interleave(2, dim=-1)
sin_top = sin_[..., :half].repeat_interleave(2, dim=-1)
cos_bot = cos_[..., half:].repeat_interleave(2, dim=-1)
sin_bot = sin_[..., half:].repeat_interleave(2, dim=-1)
rot = torch.stack([cos_top, -sin_top, sin_bot, cos_bot], dim=-1)
return rot.reshape(*rot.shape[:-1], 2, 2).unsqueeze(2)
class ErnieImagePatchEmbedDynamic(nn.Module):
def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
@ -115,8 +115,7 @@ class ErnieImageAttention(nn.Module):
key = self.norm_k(key)
if image_rotary_emb is not None:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
query, key = comfy.quant_ops.ck.apply_rope_split_half(query, key, image_rotary_emb)
q_flat = query.reshape(B, S, -1)
k_flat = key.reshape(B, S, -1)
@ -274,7 +273,7 @@ class ErnieImageModel(nn.Module):
image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
del image_ids, text_ids
sample = self.time_proj(timesteps).to(dtype)

View File

@ -4,7 +4,7 @@ from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import logging
import comfy.quant_ops
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
@ -44,21 +44,15 @@ def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
try:
import comfy.quant_ops
q_apply_rope = comfy.quant_ops.ck.apply_rope
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
def apply_rope(xq, xk, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope(xq, xk, freqs_cis)
else:
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
def apply_rope1(x, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope1(x, freqs_cis)
else:
return q_apply_rope1(x, freqs_cis)
except:
logging.warning("No comfy kitchen, using old apply_rope functions.")
apply_rope = _apply_rope
apply_rope1 = _apply_rope1
def apply_rope(xq, xk, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope(xq, xk, freqs_cis)
else:
return comfy.quant_ops.ck.apply_rope(xq, xk, freqs_cis)
def apply_rope1(x, freqs_cis):
if comfy.model_management.in_training:
return _apply_rope1(x, freqs_cis)
else:
return comfy.quant_ops.ck.apply_rope1(x, freqs_cis)

View File

@ -0,0 +1,297 @@
"""
The Ideogram 4 transformer is a NextDiT/Lumina2-family single-stream model
consumes Qwen3-VL hidden-state features (concatenated from 13 layers -> 53248 dims)
packs ``[text tokens, image tokens]`` into one sequence with block-diagonal segment attention and 3D interleaved MRoPE.
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.patcher_extension
from comfy.ldm.lumina.model import FeedForward
from comfy.ldm.modules.attention import optimized_attention_masked
from comfy.text_encoders.llama import apply_rope, precompute_freqs_cis
# Per-token role indicators
SEQUENCE_PADDING_INDICATOR = -1
OUTPUT_IMAGE_INDICATOR = 2
LLM_TOKEN_INDICATOR = 3
# Image grid coordinates are offset so they never collide with text positions
IMAGE_POSITION_OFFSET = 65536
class Ideogram4Attention(nn.Module):
def __init__(self, hidden_size, num_heads, eps=1e-5, dtype=None, device=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.hidden_size = hidden_size
self.qkv = operations.Linear(hidden_size, hidden_size * 3, bias=False, dtype=dtype, device=device)
self.norm_q = operations.RMSNorm(self.head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.norm_k = operations.RMSNorm(self.head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
self.o = operations.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)
def forward(self, x, attn_mask, freqs_cis, transformer_options={}):
batch_size, seq_len, _ = x.shape
qkv = self.qkv(x).view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(dim=2)
q = self.norm_q(q)
k = self.norm_k(k)
# (B, heads, L, head_dim)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
q, k = apply_rope(q, k, freqs_cis)
out = optimized_attention_masked(q, k, v, self.num_heads, attn_mask, skip_reshape=True, transformer_options=transformer_options)
return self.o(out)
class Ideogram4TransformerBlock(nn.Module):
def __init__(self, hidden_size, intermediate_size, num_heads, norm_eps, adaln_dim, dtype=None, device=None, operations=None):
super().__init__()
self.attention = Ideogram4Attention(hidden_size, num_heads, eps=1e-5, dtype=dtype, device=device, operations=operations)
self.feed_forward = FeedForward(
dim=hidden_size, hidden_dim=intermediate_size, multiple_of=1, ffn_dim_multiplier=None,
operation_settings={"operations": operations, "dtype": dtype, "device": device},
)
self.attention_norm1 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
self.ffn_norm1 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
self.attention_norm2 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
self.ffn_norm2 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
self.adaln_modulation = operations.Linear(adaln_dim, 4 * hidden_size, bias=True, dtype=dtype, device=device)
def forward(self, x, attn_mask, freqs_cis, adaln_input, transformer_options={}):
mod = self.adaln_modulation(adaln_input)
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.chunk(4, dim=-1)
gate_msa = torch.tanh(gate_msa)
gate_mlp = torch.tanh(gate_mlp)
scale_msa = 1.0 + scale_msa
scale_mlp = 1.0 + scale_mlp
attn_out = self.attention(self.attention_norm1(x) * scale_msa, attn_mask, freqs_cis, transformer_options=transformer_options)
x = x + gate_msa * self.attention_norm2(attn_out)
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
return x
def _sinusoidal_embedding(t, dim, scale=1e4):
t = t.to(torch.float32)
half = dim // 2
freq = math.log(scale) / (half - 1)
freq = torch.exp(torch.arange(half, dtype=torch.float32, device=t.device) * -freq)
emb = t.unsqueeze(-1) * freq
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
if dim % 2 == 1:
emb = F.pad(emb, (0, 1))
return emb
class Ideogram4EmbedScalar(nn.Module):
def __init__(self, dim, input_range=(0.0, 1.0), dtype=None, device=None, operations=None):
super().__init__()
self.dim = dim
self.range_min, self.range_max = input_range
self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
def forward(self, x):
x = x.to(torch.float32)
scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min)
emb = _sinusoidal_embedding(scaled, self.dim)
emb = emb.to(self.mlp_in.weight.dtype)
emb = F.silu(self.mlp_in(emb))
return self.mlp_out(emb)
class Ideogram4FinalLayer(nn.Module):
def __init__(self, hidden_size, out_channels, adaln_dim, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False, dtype=dtype, device=device)
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
self.adaln_modulation = operations.Linear(adaln_dim, hidden_size, bias=True, dtype=dtype, device=device)
def forward(self, x, c):
scale = 1.0 + self.adaln_modulation(F.silu(c))
return self.linear(self.norm_final(x) * scale)
class Ideogram4Transformer(nn.Module):
"""A single Ideogram 4 backbone operating on a packed token sequence."""
def __init__(self, emb_dim, num_layers, num_heads, intermediate_size, adaln_dim,
in_channels, llm_features_dim, rope_theta, mrope_section, norm_eps,
dtype=None, device=None, operations=None):
super().__init__()
self.head_dim = emb_dim // num_heads
self.rope_theta = rope_theta
self.mrope_section = tuple(mrope_section)
self.input_proj = operations.Linear(in_channels, emb_dim, bias=True, dtype=dtype, device=device)
self.llm_cond_norm = operations.RMSNorm(llm_features_dim, eps=1e-6, elementwise_affine=True, dtype=dtype, device=device)
self.llm_cond_proj = operations.Linear(llm_features_dim, emb_dim, bias=True, dtype=dtype, device=device)
self.t_embedding = Ideogram4EmbedScalar(emb_dim, input_range=(0.0, 1.0), dtype=dtype, device=device, operations=operations)
self.adaln_proj = operations.Linear(emb_dim, adaln_dim, bias=True, dtype=dtype, device=device)
self.embed_image_indicator = operations.Embedding(2, emb_dim, dtype=dtype, device=device)
self.layers = nn.ModuleList([
Ideogram4TransformerBlock(emb_dim, intermediate_size, num_heads, norm_eps, adaln_dim,
dtype=dtype, device=device, operations=operations)
for _ in range(num_layers)
])
self.final_layer = Ideogram4FinalLayer(emb_dim, in_channels, adaln_dim, dtype=dtype, device=device, operations=operations)
def _backbone(self, llm_features, x, t, position_ids, attn_mask, indicator, transformer_options={}):
indicator = indicator.to(torch.long)
output_image_mask = (indicator == OUTPUT_IMAGE_INDICATOR).to(x.dtype).unsqueeze(-1)
x = x * output_image_mask
h = self.input_proj(x) * output_image_mask
t_cond = self.t_embedding(t)
if t.dim() == 1:
t_cond = t_cond.unsqueeze(1)
adaln_input = F.silu(self.adaln_proj(t_cond))
# h is zero on the text rows (content lives only on image rows), add writes the text features in place
if llm_features is not None:
L_text = llm_features.shape[1]
text_mask = (indicator[:, :L_text] == LLM_TOKEN_INDICATOR).to(x.dtype).unsqueeze(-1)
llm = self.llm_cond_norm(llm_features * text_mask)
llm = self.llm_cond_proj(llm) * text_mask
h[:, :L_text] = h[:, :L_text] + llm
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long), out_dtype=h.dtype)
# Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch).
freqs_cis = precompute_freqs_cis(
self.head_dim, position_ids[0].transpose(0, 1), self.rope_theta,
rope_dims=self.mrope_section, interleaved_mrope=True, device=position_ids.device,
)
if attn_mask is not None and attn_mask.dtype == torch.bool:
attn_mask = torch.zeros_like(attn_mask, dtype=h.dtype).masked_fill_(~attn_mask, -torch.finfo(h.dtype).max)
for layer in self.layers:
h = layer(h, attn_mask, freqs_cis, adaln_input, transformer_options=transformer_options)
return self.final_layer(h, adaln_input)
class Ideogram4Transformer2DModel(Ideogram4Transformer):
"""Ideogram 4 single-stream DiT.
Runs a packed ``[text, image]`` sequence when text context is supplied, or an image-only sequence when ``context is None``.
"""
def __init__(self, image_model=None, in_channels=128, num_layers=34, num_attention_heads=18, attention_head_dim=256, intermediate_size=12288,
adaln_dim=512, llm_features_dim=53248, rope_theta=5000000, mrope_section=(24, 20, 20), norm_eps=1e-5,
dtype=None, device=None, operations=None, **kwargs):
emb_dim = num_attention_heads * attention_head_dim
super().__init__(
emb_dim=emb_dim, num_layers=num_layers, num_heads=num_attention_heads,
intermediate_size=intermediate_size, adaln_dim=adaln_dim, in_channels=in_channels,
llm_features_dim=llm_features_dim, rope_theta=rope_theta, mrope_section=mrope_section,
norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
self.dtype = dtype
self.in_channels = in_channels
self.out_channels = in_channels
# 128-dim token = patch (2x2) * ae_channels (32).
self.patch_size = 2
self.ae_channels = in_channels // (self.patch_size * self.patch_size)
def _img_to_tokens(self, x):
B, C, gh, gw = x.shape
x = x.view(B, self.ae_channels, self.patch_size, self.patch_size, gh, gw)
x = x.permute(0, 4, 5, 2, 3, 1) # (B, gh, gw, pi, pj, c)
return x.reshape(B, gh * gw, C)
def _tokens_to_img(self, tokens, gh, gw):
B = tokens.shape[0]
C = tokens.shape[-1]
x = tokens.reshape(B, gh, gw, self.patch_size, self.patch_size, self.ae_channels)
x = x.permute(0, 5, 3, 4, 1, 2) # (B, c, pi, pj, gh, gw)
return x.reshape(B, C, gh, gw)
def _image_position_ids(self, gh, gw, device):
h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1)
w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1)
t_idx = torch.zeros_like(h_idx)
return torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET # (L_img, 3)
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
B = x_chunk.shape[0]
device = x_chunk.device
img_tokens = self._img_to_tokens(x_chunk)
L_img = img_tokens.shape[1]
L_text = context_chunk.shape[1]
L = L_text + L_img
latent_dim = img_tokens.shape[-1]
x_full = torch.zeros(B, L, latent_dim, dtype=img_tokens.dtype, device=device)
x_full[:, L_text:] = img_tokens
text_pos = torch.arange(L_text, device=device).view(-1, 1).expand(L_text, 3)
img_pos = self._image_position_ids(gh, gw, device)
position_ids = torch.cat([text_pos, img_pos], dim=0).unsqueeze(0).expand(B, L, 3)
indicator = torch.empty(B, L, dtype=torch.long, device=device)
indicator[:, :L_text] = LLM_TOKEN_INDICATOR
indicator[:, L_text:] = OUTPUT_IMAGE_INDICATOR
attn_mask = None
if attn_mask_chunk is not None:
segment_ids = torch.ones(B, L, dtype=torch.long, device=device)
pad = (attn_mask_chunk == 0)
segment_ids[:, :L_text][pad] = SEQUENCE_PADDING_INDICATOR
indicator[:, :L_text][pad] = 0
# Block-diagonal mask from segment ids: (B, 1, L, L), True = attend.
attn_mask = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).unsqueeze(1)
out = self._backbone(context_chunk, x_full, t_chunk, position_ids, attn_mask, indicator,
transformer_options=transformer_options)
return self._tokens_to_img(out[:, L_text:], gh, gw)
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
B = x_chunk.shape[0]
device = x_chunk.device
img_tokens = self._img_to_tokens(x_chunk)
L_img = img_tokens.shape[1]
position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)
indicator = torch.full((B, L_img), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device)
# Image-only sequence is a single segment -> no mask, full attention, no LLM context.
out = self._backbone(None, img_tokens, t_chunk, position_ids, None, indicator, transformer_options=transformer_options)
return self._tokens_to_img(out, gh, gw)
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
bs, c, gh, gw = x.shape
timesteps = 1.0 - timesteps
# unconditional pass
if context is None:
return -self._run_image_only(x, timesteps, gh, gw, transformer_options)
return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options)

View File

@ -735,7 +735,86 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
)
return out
def _var_attention_qkv(q, k, v, heads, skip_reshape):
if skip_reshape:
return q, k, v, q.shape[-1]
total_tokens, embed_dim = q.shape
head_dim = embed_dim // heads
return (
q.view(total_tokens, heads, head_dim),
k.view(k.shape[0], heads, head_dim),
v.view(v.shape[0], heads, head_dim),
head_dim,
)
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
if skip_output_reshape:
return out
return out.reshape(-1, heads * head_dim)
def _use_blackwell_attention():
device = model_management.get_torch_device()
if device.type != "cuda":
return False
major, minor = torch.cuda.get_device_capability(device)
return (major, minor) >= (12, 0)
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
if cu_seqlens.dtype not in (torch.int32, torch.int64):
raise ValueError(f"{name} must use an integer dtype")
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
if cu_seqlens[0].item() != 0:
raise ValueError(f"{name} must start at 0")
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
raise ValueError(f"{name} must be strictly increasing")
if cu_seqlens[-1].item() != token_count:
raise ValueError(f"{name} does not match token count")
def _split_indices(cu_seqlens):
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
if cu_seqlens_k[-1].item() != v.shape[0]:
raise ValueError("cu_seqlens_k does not match v token count")
q_split_indices = _split_indices(cu_seqlens_q)
k_split_indices = _split_indices(cu_seqlens_k)
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
out = []
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
out_dtype = q_i.dtype
if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
q_i = q_i.to(torch.bfloat16)
k_i = k_i.to(torch.bfloat16)
v_i = v_i.to(torch.bfloat16)
out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
if out_i.dtype != out_dtype:
out_i = out_i.to(out_dtype)
out.append(out_i.squeeze(0).permute(1, 0, 2))
out = torch.cat(out, dim=0)
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
optimized_var_attention = var_attention_optimized_split
optimized_attention = attention_basic
if model_management.sage_attention_enabled():
@ -758,6 +837,8 @@ else:
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
optimized_attention = attention_sub_quad
logging.info("Using optimized_attention split-loop for variable-length attention")
optimized_attention_masked = optimized_attention
@ -773,6 +854,7 @@ if model_management.xformers_enabled():
register_attention_function("pytorch", attention_pytorch)
register_attention_function("sub_quad", attention_sub_quad)
register_attention_function("split", attention_split)
register_attention_function("var_attention_optimized_split", var_attention_optimized_split)
def optimized_attention_for_device(device, mask=False, small_input=False):
@ -1209,5 +1291,3 @@ class SpatialVideoTransformer(SpatialTransformer):
x = self.proj_out(x)
out = x + x_in
return out

View File

@ -13,6 +13,7 @@ if model_management.xformers_enabled_vae():
import xformers
import xformers.ops
def torch_cat_if_needed(xl, dim):
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
if len(xl) > 1:
@ -22,7 +23,8 @@ def torch_cat_if_needed(xl, dim):
else:
return None
def get_timestep_embedding(timesteps, embedding_dim):
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
@ -33,11 +35,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = math.log(10000) / (half_dim - downscale_freq_shift)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb

View File

@ -207,8 +207,9 @@ class PidNet(PixDiT_T2I):
f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
)
B = x.shape[0]
Hs = x.shape[2] // self.patch_size
Ws = x.shape[3] // self.patch_size
# Match the backbone's pad_to_patch_size (round up) so the LQ grid lines up with the patch stream.
Hs = -(-x.shape[2] // self.patch_size)
Ws = -(-x.shape[3] // self.patch_size)
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
if degrade_sigma.numel() == 1 and B > 1:

View File

@ -51,15 +51,6 @@ class FeedForward(nn.Module):
return hidden_states
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape)
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
super().__init__()

View File

@ -0,0 +1,340 @@
import torch
import torch.nn.functional as F
from torch import Tensor
from comfy.ldm.seedvr.model import safe_pad_operation
from comfy.ldm.seedvr.vae import safe_interpolate_operation
from comfy.ldm.seedvr.constants import (
CIELAB_DELTA,
CIELAB_KAPPA,
D65_WHITE_X,
D65_WHITE_Z,
WAVELET_DECOMP_LEVELS,
)
def wavelet_blur(image: Tensor, radius):
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
if radius > max_safe_radius:
radius = max_safe_radius
num_channels = image.shape[1]
kernel_vals = [
[0.0625, 0.125, 0.0625],
[0.125, 0.25, 0.125],
[0.0625, 0.125, 0.0625],
]
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
return output
def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
high_freq = torch.zeros_like(image)
for i in range(levels):
radius = 2 ** i
low_freq = wavelet_blur(image, radius)
high_freq.add_(image).sub_(low_freq)
image = low_freq
return high_freq, low_freq
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
if content_feat.shape != style_feat.shape:
# Resize style to match content spatial dimensions
if len(content_feat.shape) >= 3:
# safe_interpolate_operation handles FP16 conversion automatically
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
# Decompose both features into frequency components
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
del content_low_freq # Free memory immediately
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
del style_high_freq # Free memory immediately
if content_high_freq.shape != style_low_freq.shape:
style_low_freq = safe_interpolate_operation(
style_low_freq,
size=content_high_freq.shape[-2:],
mode='bilinear',
align_corners=False
)
content_high_freq.add_(style_low_freq)
return content_high_freq.clamp_(-1.0, 1.0)
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
original_shape = source.shape
# Flatten
source_flat = source.flatten()
reference_flat = reference.flatten()
# Sort both arrays
source_sorted, source_indices = torch.sort(source_flat)
reference_sorted, _ = torch.sort(reference_flat)
del reference_flat
# Quantile mapping
n_source = len(source_sorted)
n_reference = len(reference_sorted)
if n_source == n_reference:
matched_sorted = reference_sorted
else:
# Interpolate reference to match source quantiles
source_quantiles = torch.linspace(0, 1, n_source, device=device)
ref_indices = (source_quantiles * (n_reference - 1)).long()
ref_indices.clamp_(0, n_reference - 1)
matched_sorted = reference_sorted[ref_indices]
del source_quantiles, ref_indices, reference_sorted
del source_sorted, source_flat
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
inverse_indices = torch.argsort(source_indices)
del source_indices
matched_flat = matched_sorted[inverse_indices]
del matched_sorted, inverse_indices
return matched_flat.reshape(original_shape)
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of CIELAB images to RGB color space."""
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
# LAB to XYZ
fy = (L + 16.0) / 116.0
fx = a.div(500.0).add_(fy)
fz = fy - b / 200.0
del L, a, b
# XYZ transformation
x = torch.where(
fx > epsilon,
torch.pow(fx, 3.0),
fx.mul(116.0).sub_(16.0).div_(kappa)
)
y = torch.where(
fy > epsilon,
torch.pow(fy, 3.0),
fy.mul(116.0).sub_(16.0).div_(kappa)
)
z = torch.where(
fz > epsilon,
torch.pow(fz, 3.0),
fz.mul(116.0).sub_(16.0).div_(kappa)
)
del fx, fy, fz
# Apply D65 white point (in-place)
x.mul_(D65_WHITE_X)
# y *= 1.00000 # (no-op, skip)
z.mul_(D65_WHITE_Z)
xyz = torch.stack([x, y, z], dim=1)
del x, y, z
# Matrix multiplication: XYZ -> RGB
B, C, H, W = xyz.shape
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
del xyz
# Ensure dtype consistency for matrix multiplication
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
del xyz_flat
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del rgb_linear_flat
# Apply inverse gamma correction (delinearize)
mask = rgb_linear > 0.0031308
rgb = torch.where(
mask,
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
rgb_linear * 12.92
)
del mask, rgb_linear
return torch.clamp(rgb, 0.0, 1.0)
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
# Apply sRGB gamma correction (linearize)
mask = rgb > 0.04045
rgb_linear = torch.where(
mask,
torch.pow((rgb + 0.055) / 1.055, 2.4),
rgb / 12.92
)
del mask
# Matrix multiplication: RGB -> XYZ
B, C, H, W = rgb_linear.shape
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
del rgb_linear
# Ensure dtype consistency for matrix multiplication
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
xyz_flat = torch.matmul(rgb_flat, matrix.T)
del rgb_flat
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
del xyz_flat
# Normalize by D65 white point (in-place)
xyz[:, 0].div_(D65_WHITE_X) # X
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
xyz[:, 2].div_(D65_WHITE_Z) # Z
# XYZ to LAB transformation
epsilon_cubed = epsilon ** 3
mask = xyz > epsilon_cubed
f_xyz = torch.where(
mask,
torch.pow(xyz, 1.0 / 3.0),
xyz.mul(kappa).add_(16.0).div_(116.0)
)
del xyz, mask
# Extract channels and compute LAB
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
del f_xyz
return torch.stack([L, a, b], dim=1)
def lab_color_transfer(
content_feat: Tensor,
style_feat: Tensor,
luminance_weight: float = 0.8
) -> Tensor:
content_feat = wavelet_reconstruction(content_feat, style_feat)
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False
)
device = content_feat.device
def ensure_float32_precision(c):
orig_dtype = c.dtype
c = c.float()
return c, orig_dtype
content_feat, original_dtype = ensure_float32_precision(content_feat)
style_feat, _ = ensure_float32_precision(style_feat)
rgb_to_xyz_matrix = torch.tensor([
[0.4124564, 0.3575761, 0.1804375],
[0.2126729, 0.7151522, 0.0721750],
[0.0193339, 0.1191920, 0.9503041]
], dtype=torch.float32, device=device)
xyz_to_rgb_matrix = torch.tensor([
[ 3.2404542, -1.5371385, -0.4985314],
[-0.9692660, 1.8760108, 0.0415560],
[ 0.0556434, -0.2040259, 1.0572252]
], dtype=torch.float32, device=device)
epsilon = CIELAB_DELTA
kappa = CIELAB_KAPPA
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
# Convert to LAB color space
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del content_feat
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
del style_feat, rgb_to_xyz_matrix
# Match chrominance channels (a*, b*) for accurate color transfer
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
# Handle luminance with weighted blending
if luminance_weight < 1.0:
# Partially match luminance for better overall color accuracy
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
# Blend: preserve some content L* for detail, adopt some style L* for color
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
del matched_L
else:
# Fully preserve content luminance
result_L = content_lab[:, 0]
del content_lab, style_lab
# Reconstruct LAB with corrected channels
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
del result_L, matched_a, matched_b
# Convert back to RGB
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
del result_lab, xyz_to_rgb_matrix
# Convert back to [-1, 1] range (in-place)
result = result_rgb.mul_(2.0).sub_(1.0)
del result_rgb
result = result.to(original_dtype)
return result
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
return wavelet_reconstruction(content_feat, style_feat)
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
if content_feat.shape != style_feat.shape:
style_feat = safe_interpolate_operation(
style_feat,
size=content_feat.shape[-2:],
mode='bilinear',
align_corners=False,
)
original_dtype = content_feat.dtype
content_feat = content_feat.float()
style_feat = style_feat.float()
b, c = content_feat.shape[:2]
content_flat = content_feat.reshape(b, c, -1)
style_flat = style_feat.reshape(b, c, -1)
content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1)
content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1)
style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
del content_flat, style_flat
normalized = (content_feat - content_mean) / content_std
del content_mean, content_std
result = normalized * style_std + style_mean
del normalized, style_mean, style_std
result = result.clamp_(-1.0, 1.0)
if result.dtype != original_dtype:
result = result.to(original_dtype)
return result

View File

@ -0,0 +1,79 @@
"""Named constants for the SeedVR2 integration, grouped by provenance.
Provenance prefixes:
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline.
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites
the upstream config/source path it was lifted from.
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
ISO / CIE values; cite the standard.
"""
# --------------------------------------------------------------------------------------
# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment)
# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN)
# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070
# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT).
# --------------------------------------------------------------------------------------
SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB)
SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB
# --------------------------------------------------------------------------------------
# B. Fork heuristics (SEEDVR2 - this integration)
# --------------------------------------------------------------------------------------
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry.
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels).
SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16).
SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset.
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing)
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk.
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
# --------------------------------------------------------------------------------------
# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
# --------------------------------------------------------------------------------------
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32).
BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t).
BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11.
BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size).
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor.
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor).
BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range.
BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))).
BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling).
BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames).
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function.
BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242.
BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243.
# --------------------------------------------------------------------------------------
# D. Published standards (cite the literature)
# --------------------------------------------------------------------------------------
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
# exact existing coefficients move verbatim rather than being retyped here.

1665
comfy/ldm/seedvr/model.py Normal file

File diff suppressed because it is too large Load Diff

2110
comfy/ldm/seedvr/vae.py Normal file

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,199 @@
# TripoSplat 3D gaussian container. Operates on already-decoded
# tensors and exposes them as render-ready tensors (render_tensors) for the generic SPLAT type.
import torch
import torch.nn.functional as F
import comfy.model_management
class GaussianModel:
def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
scaling_bias: float = 0.01, opacity_bias: float = 0.1,
scaling_activation: str = "exp", device=None):
self.sh_degree = sh_degree
self.mininum_kernel_size = mininum_kernel_size
self.scaling_bias = scaling_bias
self.opacity_bias = opacity_bias
self.device = device
self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
if scaling_activation == "exp":
self._scaling_activation = torch.exp
self._inverse_scaling_activation = torch.log
elif scaling_activation == "softplus":
self._scaling_activation = F.softplus
self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
self._opacity_activation = torch.sigmoid
self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))
self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
self.rots_bias = torch.zeros(4, device=self.device)
self.rots_bias[0] = 1
self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
self._storage = {}
def _get_store(self, name):
return self._storage.get(name)
def _set_store(self, name, value):
self._storage[name] = value
@property
def _xyz(self):
return self._get_store("_xyz")
@_xyz.setter
def _xyz(self, value):
if value is None:
self._set_store("_xyz", None)
self._set_store("xyz", None)
return
self._set_store("_xyz", value)
self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])
@property
def get_xyz(self):
return self._get_store("xyz")
@property
def _features_dc(self):
return self._get_store("_features_dc")
@_features_dc.setter
def _features_dc(self, value):
self._set_store("_features_dc", value)
@property
def _opacity(self):
return self._get_store("_opacity")
@_opacity.setter
def _opacity(self, value):
if value is None:
self._set_store("_opacity", None)
self._set_store("opacity", None)
return
self._set_store("_opacity", value)
self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))
@property
def get_opacity(self):
return self._get_store("opacity")
@property
def _scaling(self):
return self._get_store("_scaling")
@_scaling.setter
def _scaling(self, value):
if value is None:
self._set_store("_scaling", None)
self._set_store("scaling", None)
return
self._set_store("_scaling", value)
s = self._scaling_activation(value + self.scale_bias)
s = torch.square(s) + self.mininum_kernel_size ** 2
self._set_store("scaling", torch.sqrt(s))
@property
def get_scaling(self):
return self._get_store("scaling")
@property
def _rotation(self):
return self._get_store("_rotation")
@_rotation.setter
def _rotation(self, value):
self._set_store("_rotation", value)
_DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
def render_tensors(self):
# Render-ready (activated, world-space) tensors for the generic SPLAT type. The axis transform
# (a 3x3 rotation, object frame -> viewer Y-up) is baked into positions and rotations.
# Returns float tensors on the intermediate device: positions (N,3), scales (N,3) linear,
# rotations (N,4) wxyz, opacities (N,1) in [0,1], sh (N,K,3) coefficients.
xyz = self.get_xyz.float()
scaling = self.get_scaling.float()
opacity = self.get_opacity.float()
rotation = (self._rotation + self.rots_bias[None, :]).float()
sh = self._features_dc.float() # (N, K, 3)
T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device)
xyz = xyz @ T.T
rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation)))
rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True)
out_device = comfy.model_management.intermediate_device()
return (
xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(),
rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(),
sh.to(out_device).contiguous(),
)
def _quat_to_matrix(q):
q = q / torch.linalg.norm(q, dim=-1, keepdim=True)
w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
R = torch.stack([
1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x),
2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y),
], dim=-1).reshape(-1, 3, 3)
return R
def _matrix_to_quat(R):
trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
q = torch.zeros((R.shape[0], 4), dtype=R.dtype, device=R.device)
s = torch.sqrt(torch.clamp(trace + 1, min=0)) * 2
q[:, 0] = 0.25 * s
denom = torch.where(s != 0, s, torch.ones_like(s))
q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / denom
q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / denom
q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / denom
m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0)
s1 = torch.sqrt(torch.clamp(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], min=0)) * 2
q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01]
q[m01, 1] = 0.25 * s1[m01]
q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01]
q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01]
m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0)
s2 = torch.sqrt(torch.clamp(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], min=0)) * 2
q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11]
q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11]
q[m11, 2] = 0.25 * s2[m11]
q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11]
m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0)
s3 = torch.sqrt(torch.clamp(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], min=0)) * 2
q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21]
q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21]
q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21]
q[m21, 3] = 0.25 * s3[m21]
return q / torch.linalg.norm(q, dim=-1, keepdim=True)
def build_gaussian_models(decoder, points_pred: dict, pred: dict):
# Assemble GaussianModels from the elastic decoder layout. decoder is the ElasticGaussianFixedlenDecoder
# (carries layout / rep_config / _get_offset)
x = points_pred
offset = decoder._get_offset(pred['features'])
h = pred["features"]
ret = []
for i in range(h.shape[0]):
g = GaussianModel(
sh_degree=0,
aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'],
scaling_bias=decoder.rep_config['scaling_bias'],
opacity_bias=decoder.rep_config['opacity_bias'],
scaling_activation=decoder.rep_config['scaling_activation'],
device=h.device,
)
_x = x["points"][i, :, None, :]
for k, v in decoder.layout.items():
if k == '_xyz':
setattr(g, k, (offset[i] + _x).flatten(0, 1))
elif k in ('_xyz_center', '_offset_scale'):
continue
else:
feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
setattr(g, k, feats * decoder.rep_config['lr'][k])
ret.append(g)
return ret

View File

@ -0,0 +1,326 @@
# TripoSplat flow-matching denoiser (LatentSeqMMFlowModel). Registered as a ModelType.FLOW arch and
# driven by the standard KSampler; jointly denoises the (B, 8192, 16) latent and a (B, 1, 5) camera token
# carried as a 2-element nested latent.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
import comfy.patcher_extension
import comfy.rmsnorm
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope
class MultiHeadRMSNorm(nn.Module):
def __init__(self, dim, heads, dtype=None, device=None):
super().__init__()
self.gamma = nn.Parameter(torch.empty(heads, dim, dtype=dtype, device=device))
def forward(self, x):
x = comfy.rmsnorm.rms_norm(x)
return x * comfy.model_management.cast_to(self.gamma, x.dtype, x.device)
# Positional embeddings
class RePo3DRotaryEmbedding(nn.Module):
def __init__(self, model_channels, num_heads, head_dim, repo_hidden_ratio=0.125, max_freq=16.0,
dtype=None, device=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
repo_hidden_size = int(model_channels * repo_hidden_ratio)
self.norm = operations.LayerNorm(model_channels, dtype=dtype, device=device)
self.gate_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
self.content_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
self.act = nn.SiLU()
self.final_map = operations.Linear(repo_hidden_size, 3 * num_heads, bias=False, dtype=dtype, device=device)
self.dim_0 = 2 * (head_dim // 6)
self.dim_1 = 2 * (head_dim // 6)
self.dim_2 = head_dim - self.dim_0 - self.dim_1
dims = [self.dim_0, self.dim_1, self.dim_2]
freqs_list = []
for d in dims:
freq_dim = d // 2
freqs_list.append(torch.linspace(1.0, float(max_freq), steps=freq_dim, dtype=torch.float32))
self.freqs_0 = nn.Parameter(freqs_list[0])
self.freqs_1 = nn.Parameter(freqs_list[1])
self.freqs_2 = nn.Parameter(freqs_list[2])
def forward(self, hidden_states):
h = self.norm(hidden_states)
feat = self.act(self.gate_map(h)) * self.content_map(h)
out = self.final_map(feat)
B, L, _ = out.shape
delta_pos = out.reshape(B, L, self.num_heads, 3)
f0 = comfy.model_management.cast_to(self.freqs_0, torch.float32, out.device)
f1 = comfy.model_management.cast_to(self.freqs_1, torch.float32, out.device)
f2 = comfy.model_management.cast_to(self.freqs_2, torch.float32, out.device)
ang_0 = delta_pos[..., 0].unsqueeze(-1) * f0 * torch.pi
ang_1 = delta_pos[..., 1].unsqueeze(-1) * f1 * torch.pi
ang_2 = delta_pos[..., 2].unsqueeze(-1) * f2 * torch.pi
ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # (B, L, heads, head_dim/2)
cos, sin = ang.cos(), ang.sin()
return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*ang.shape, 2, 2)
class PcdAbsolutePositionEmbedder(nn.Module):
# Sinusoidal absolute position embedding. Two fixed schedules are used in TripoSplat:
# "pow2" (flow-model latent anchors) and "log2" (octree / gaussian decoders).
def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"):
super().__init__()
self.channels = channels
self.in_channels = in_channels
self.max_res = max_res
self.schedule = schedule
self.freq_dim = channels // in_channels // 2
def _freqs(self, device):
if self.schedule == "pow2":
freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device)
res_dim = max(0, self.freq_dim - self.max_res)
freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res
if res_dim > 0 else torch.empty(0, device=device))
freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim]
return torch.pow(2.0, freqs) * 2.0 # *2 folds this schedule's 2*pi into the shared *pi below
logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device)
return torch.pow(2.0, logs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
x = x.float()
*dims, D = x.shape
out = torch.outer(x.reshape(-1), self._freqs(x.device)) * torch.pi
out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1)
if out.shape[-1] < self.channels:
out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1],
device=out.device, dtype=out.dtype)], dim=-1)
return out.to(orig_dtype)
def attention(q, k, v, transformer_options=None):
# q, k, v: (B, L, heads, dim) -> (B, L, heads, dim). Shared optimized_attention call convention.
out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads=q.shape[2],
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
transformer_options=transformer_options)
return out.transpose(1, 2)
# Transformer building blocks
class MLP(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(in_channels, hidden_channels, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(hidden_channels, out_channels, dtype=dtype, device=device),
)
def forward(self, x):
return self.mlp(x)
class RopeMultiHeadAttention(nn.Module):
def __init__(self, channels, num_heads, qkv_bias=True, qk_rms_norm=False, use_rope=False,
dtype=None, device=None, operations=None):
super().__init__()
self.channels = channels
self.num_heads = num_heads
self.head_dim = channels // num_heads
self.qk_rms_norm = qk_rms_norm
self.use_rope = use_rope
self.qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
if self.qk_rms_norm:
self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.out = operations.Linear(channels, channels, dtype=dtype, device=device)
def forward(self, x, rope_emb=None, transformer_options=None):
B, L, C = x.shape
qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(2)
if self.use_rope:
q, k = apply_rope(q, k, rope_emb)
if self.qk_rms_norm:
q = self.q_norm(q)
k = self.k_norm(k)
h = attention(q, k, v, transformer_options) # (B, L, heads, dim)
return self.out(h.reshape(B, L, C))
class UnifiedTransformerBlock(nn.Module):
def __init__(self, channels, num_heads, mlp_ratio=4.0,
use_rope=False, qk_rms_norm=False, qkv_bias=True,
modulation=True, share_mod=False,
dtype=None, device=None, operations=None):
super().__init__()
self.modulation = modulation
self.share_mod = share_mod
self.norm1 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
self.attn = RopeMultiHeadAttention(channels, num_heads=num_heads,
qkv_bias=qkv_bias, use_rope=use_rope, qk_rms_norm=qk_rms_norm,
dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
if modulation:
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
self.shift_table = nn.Parameter(torch.empty(1, 6 * channels, dtype=dtype, device=device))
def forward(self, x, mod=None, rotary_emb=None, transformer_options=None):
if self.modulation:
if not self.share_mod:
mod = self.adaLN_modulation(mod)
mod = mod + comfy.model_management.cast_to(self.shift_table, mod.dtype, mod.device)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
x = torch.addcmul(x, self.attn(h, rope_emb=rotary_emb, transformer_options=transformer_options), gate_msa.unsqueeze(1))
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
else:
x = x + self.attn(self.norm1(x), rope_emb=rotary_emb, transformer_options=transformer_options)
x = x + self.mlp(self.norm2(x))
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
emb = self.timestep_embedding(t, self.frequency_embedding_size)
return self.mlp(emb.to(self.mlp[0].weight.dtype))
class LatentSeqMMFlowModel(nn.Module):
def __init__(self, image_model=None, q_token_length=8192, in_channels=16, model_channels=1024,
cond_channels=1280, out_channels=16, num_blocks=24, num_refiner_blocks=2,
num_heads=None, num_head_channels=64, cam_channels=5, cond2_channels=128,
mlp_ratio=4, share_mod=True, qk_rms_norm=True,
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
self.q_token_length = q_token_length
self.in_channels = in_channels
self.cam_channels = cam_channels
self.model_channels = model_channels
self.cond_channels = cond_channels
self.cond2_channels = cond2_channels
self.out_channels = out_channels
self.num_blocks = num_blocks
self.num_refiner_blocks = num_refiner_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.share_mod = share_mod
self.qk_rms_norm = qk_rms_norm
factory_kwargs = dict(dtype=dtype, device=device)
op_kwargs = dict(operations=operations, **factory_kwargs)
self.t_embedder = TimestepEmbedder(model_channels, **op_kwargs)
if share_mod:
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs))
self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs)
self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs)
self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None
# Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding.
# The embedder is parameter-free and the anchors are fixed, precompute once.
sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length)
pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0))
self.register_buffer("pos_emb", pos_emb, persistent=False)
# RePo3DRotaryEmbedding layers for the refiner and main blocks
repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs)
self.noise_repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
self.context_repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
self.repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)])
# Refiner blocks
block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs)
self.noise_refiner = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)])
self.context_refiner = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)])
self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs)
self.blocks = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_blocks)])
self.shift_table = nn.Parameter(torch.empty(1, 2, model_channels, **factory_kwargs))
self.out_layer = operations.Linear(model_channels, out_channels, **factory_kwargs)
self.cam_out_layer = operations.Linear(model_channels, cam_channels, **factory_kwargs)
def forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, t, context, ref_latents, transformer_options, **kwargs)
def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
# x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)].
# context == feature1.
z, camera = x[0], x[1]
feat1 = context
h_x = self.input_layer(z)
h_cond = self.cond_embedder(feat1)
if ref_latents is not None and self.cond_embedder2 is not None:
# Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad to feat1's length
# (the pad count = feat1's prefix tokens: DINOv3 cls + registers), then add to the context.
feat2 = ref_latents[0].flatten(2).transpose(1, 2)
feat2 = F.pad(feat2, (0, 0, feat1.shape[1] - feat2.shape[1], 0))
h_cond = h_cond + self.cond_embedder2(feat2.to(h_cond.dtype))
t_emb = self.t_embedder(t)
t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb
h_x = h_x + self.pos_emb.to(z)
for i, block in enumerate(self.noise_refiner):
h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x), transformer_options=transformer_options)
for i, block in enumerate(self.context_refiner):
h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond), transformer_options=transformer_options)
cam = camera.to(z)
h_cam = self.cam_refiner(cam)
h = torch.cat([h_x, h_cond, h_cam], dim=1)
for i, block in enumerate(self.blocks):
h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h), transformer_options=transformer_options)
h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).to(z)
h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).to(z)
shift, scale = (comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device) + t_emb.unsqueeze(1)).chunk(2, dim=1)
scale = 1 + scale
h_x = torch.addcmul(shift, h_x, scale)
h_cam = torch.addcmul(shift, h_cam, scale)
return self.out_layer(h_x), self.cam_out_layer(h_cam)

View File

@ -0,0 +1,91 @@
# Live preview for TripoSplat: decode an x0 estimate into a coarse gaussian splat and render it with a perspective orbit camera.
import numpy as np
from PIL import Image
_C0 = 0.28209479177387814
_LATENT_TOKENS = 8192 # q_token_length
_LATENT_CH = 16 # in_channels
_OBJECT_TO_VIEWER = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32) # object frame -> viewer Y-up frame
def _view_matrix(yaw_deg, pitch_deg):
y, p = np.radians(yaw_deg), np.radians(pitch_deg)
Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]], np.float32)
Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]], np.float32)
return Rx @ Ry
def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0,
max_px=9, min_opacity=0.0, fov=35.0, dist=2.2):
# Project gaussian centers with a perspective camera and paint each as a filled disk whose screen
# radius follows the gaussian's world-space scale, composited with a nearest-wins z-buffer.
# gain scales the footprint (≈ std spanned), `min_px`/`max_px` clamp the on-screen radius.
pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T
v = pts @ _view_matrix(yaw, pitch).T
zc = v[:, 2] + dist
keep = zc > 1e-2
if opacity is not None and min_opacity > 0.0: # culls gaussians with very low opacity
keep = keep & (opacity > min_opacity)
v, zc, scale = v[keep], zc[keep], scale[keep]
col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep]
if v.shape[0] == 0:
return Image.fromarray(np.zeros((size, size, 3), np.uint8))
f = (size / 2) / np.tan(np.radians(fov) / 2)
cx = size / 2 + f * v[:, 0] / zc
cy = size / 2 + f * v[:, 1] / zc
radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32)
# Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized.
px, py, pz, pc = [], [], [], []
for r in range(int(radius.min()), int(radius.max()) + 1):
m = radius == r
if not m.any():
continue
dy, dx = np.mgrid[-r:r + 1, -r:r + 1]
disk = (dx * dx + dy * dy) <= r * r
ox, oy = dx[disk], dy[disk]
px.append((cx[m, None] + ox).ravel())
py.append((cy[m, None] + oy).ravel())
pz.append(np.repeat(zc[m], ox.size))
pc.append(np.repeat(col[m], ox.size, axis=0))
px, py = np.concatenate(px), np.concatenate(py)
pz, pc = np.concatenate(pz), np.concatenate(pc)
xi = np.clip(px, 0, size - 1).astype(np.int64)
yi = np.clip(py, 0, size - 1).astype(np.int64)
# Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest
# splat, then decode the winning index back to its color.
pid = yi * size + xi
q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1) # near = small
key = (q << 32) | np.arange(pid.size, dtype=np.int64)
buf = np.full(size * size, 1 << 62, np.int64)
np.minimum.at(buf, pid, key)
img = np.zeros((size * size, 3), np.uint8)
hit = buf < (1 << 62)
img[hit] = pc[buf[hit] & 0xFFFFFFFF]
return Image.fromarray(img.reshape(size, size, 3))
def _extract_latent(x0):
# x0 from the sampler callback is the nested latent packed to (B, 1, TOKENS*CH + 1*5);
# the plain single-latent case is (B, TOKENS, CH). Return the (B, TOKENS, CH) latent stream.
if x0.ndim == 3 and x0.shape[1] == _LATENT_TOKENS and x0.shape[2] == _LATENT_CH:
return x0
flat = x0.reshape(x0.shape[0], -1)
return flat[:, :_LATENT_TOKENS * _LATENT_CH].reshape(x0.shape[0], _LATENT_TOKENS, _LATENT_CH)
def decode_x0_to_image(decoder, x0, cfg):
# Decode x0 at a coarse octree level / few gaussians and render a preview image.
latent = _extract_latent(x0)
fsm = decoder.first_stage_model
gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype),
num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5))[0]
xyz = gaussian.get_xyz.float().cpu().numpy()
rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5
scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1) # per-splat world radius (largest axis)
opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0]
return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0),
size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3),
min_opacity=0.01)

382
comfy/ldm/triposplat/vae.py Normal file
View File

@ -0,0 +1,382 @@
# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an
# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns
# a Gaussian. The octree sampler uses the global torch RNG (no generator) like upstream, so seed it for repeatable decodes.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
import comfy.ops
from .gaussian import build_gaussian_models
from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention
# Quasi-random sampling utilities (pure functions, dtype/device-agnostic)
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
def radical_inverse(base, n):
val = 0
inv_base = 1.0 / base
inv_base_n = inv_base
while n > 0:
digit = n % base
val += digit * inv_base_n
n //= base
inv_base_n *= inv_base
return val
def halton_sequence(dim, n):
return [radical_inverse(PRIMES[i], n) for i in range(dim)]
def hammersley_sequence(dim, n, num_samples):
return [n / num_samples] + halton_sequence(dim - 1, n)
def sample_probs(probs, counts, generator=None):
# Systematic resampling: distribute counts[r] draws across the P bins of row r
batch_shape = counts.shape
R = counts.numel()
P = probs.size(-1)
device = probs.device
probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
counts = counts.reshape(R).to(device=device, dtype=torch.long)
row_sums = probs.sum(1, keepdim=True)
probs = torch.where(row_sums == 0, probs.new_tensor(1.0 / P), probs / row_sums.clamp_min(1))
cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)
Nmax = int(counts.max())
if Nmax == 0:
return counts.new_zeros(*batch_shape, P)
cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1)
grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax)
u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded)
idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r]
out = torch.zeros(R, P, dtype=torch.float32, device=device)
out.scatter_add_(1, idx, weight)
return out.to(torch.long).view(*batch_shape, P)
class MultiHeadAttention(nn.Module):
def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False,
dtype=None, device=None, operations=None):
super().__init__()
assert channels % num_heads == 0
self.channels = channels
self.head_dim = channels // num_heads
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
self.num_heads = num_heads
self._type = type
self.qk_rms_norm = qk_rms_norm
if self._type == "self":
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
else:
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, dtype=dtype, device=device)
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, dtype=dtype, device=device)
if self.qk_rms_norm:
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.to_out = operations.Linear(channels, channels, dtype=dtype, device=device)
def forward(self, x, context=None):
B, L, C = x.shape
if self._type == "self":
q, k, v = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1).unbind(dim=2)
else:
Lkv = context.shape[1]
q = self.to_q(x).reshape(B, L, self.num_heads, -1)
k, v = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1).unbind(dim=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
h = attention(q, k, v)
return self.to_out(h.reshape(B, L, -1))
# Octree probability decoder
class LevelEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024,
dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
@staticmethod
def level_embedding(t, dim, max_period=1024):
half = dim // 2
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None] * 2 * torch.pi
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
emb = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
return self.mlp(emb.to(self.mlp[0].weight.dtype))
class ModulatedTransformerCrossOnlyBlock(nn.Module):
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.share_mod = share_mod
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads,
type="cross", qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
def forward(self, x, mod, context):
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
x = torch.addcmul(x, self.cross_attn(h, context), gate_msa.unsqueeze(1))
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
return x
class OctreeProbabilityFixedlenDecoder(nn.Module):
# Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits.
def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16,
num_head_channels=64, mlp_ratio=4.0, share_mod=True,
qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
super().__init__()
self.model_channels = model_channels
self.cond_channels = cond_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.share_mod = share_mod
self.qk_rms_norm_cross = qk_rms_norm_cross
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
self.l_embedder = LevelEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, dtype=dtype, device=device))
if cond_channels is not None:
self.blocks = nn.ModuleList([
ModulatedTransformerCrossOnlyBlock(
model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
share_mod=self.share_mod, dtype=dtype, device=device, operations=operations)
for _ in range(num_blocks)
])
self.out_proj = operations.Linear(model_channels, 8, dtype=dtype, device=device)
self.in_proj = operations.Linear(3, model_channels, dtype=dtype, device=device)
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
def forward(self, x, l, cond):
d = next(self.parameters()).dtype
B, L, _ = x.shape
h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d)
h = self.input_layer(h)
l_emb = self.l_embedder(l)
if self.share_mod:
l_emb = self.adaLN_modulation(l_emb)
cond = cond.to(d)
for block in self.blocks:
h = block(h, l_emb, cond)
h = F.layer_norm(h.float(), h.shape[-1:]).to(d)
logits = self.out_proj(h)
return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
@staticmethod
def sample(model, cond, num_points, level, temperature=1.0, generator=None):
B = cond.shape[0]
device = cond.device
child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
dtype=torch.long, device=device)
prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device)
prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device)
prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device)
batch_indices_range = torch.arange(B, device=device).unsqueeze(1)
for lv in range(1, level + 1):
res_p = 1 << (lv - 1)
res = 1 << lv
parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p
res_tensor = torch.full((B,), res, dtype=torch.long, device=device)
pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
pred_probs = torch.softmax(pred_logits, dim=-1)
pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2)
pred_log_probs = pred_log_probs.flatten(1, 2)
prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
mask = sampled > 0
max_valid = mask.sum(dim=1).max().item()
scatter_indices = mask.cumsum(dim=1) - 1
valid_scatter_indices = scatter_indices[mask]
valid_batch_indices = batch_indices_range.expand_as(mask)[mask]
next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device)
next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask]
next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device)
next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask]
next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device)
next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask]
prev_coords_int = next_prev_coords_int
prev_counts = next_prev_counts
prev_log_probs = next_prev_log_probs
res = 1 << level
prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
coords_norm = (coords_int.to(torch.float32) + rand) / res
return {"points": coords_norm, "log_probs": prev_log_probs}
# Elastic gaussian decoder
class TransformerCrossBlock(nn.Module):
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
dtype=None, device=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm, dtype=dtype, device=device, operations=operations)
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, context):
x = x + self.self_attn(self.norm1(x))
x = x + self.cross_attn(self.norm2(x), context)
x = x + self.mlp(self.norm3(x))
return x
class ElasticGaussianFixedlenDecoder(nn.Module):
# Cross-attention transformer over sampled octree points -> per-point gaussian params.
def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16,
num_head_channels=64, mlp_ratio=4.0, *, representation_config=None,
qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
super().__init__()
self.rep_config = representation_config or dict(
lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
scaling_activation="softplus",
)
self.out_channels = self._calc_layout()
self.model_channels = model_channels
self.cond_channels = cond_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
if cond_channels is not None:
self.blocks = nn.ModuleList([
TransformerCrossBlock(model_channels, ctx_channels=cond_channels,
num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross,
dtype=dtype, device=device, operations=operations)
for _ in range(num_blocks)
])
self.in_proj = operations.Linear(in_channels, model_channels, dtype=dtype, device=device)
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device)
self._build_perturbation()
def _calc_layout(self):
ng = self.rep_config['num_gaussians']
self.layout = {
'_xyz': {'shape': (ng, 3), 'size': ng * 3},
'_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3},
'_scaling': {'shape': (ng, 3), 'size': ng * 3},
'_rotation': {'shape': (ng, 4), 'size': ng * 4},
'_opacity': {'shape': (ng, 1), 'size': ng},
}
self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng}
start = 0
for k, v in self.layout.items():
v['range'] = (start, start + v['size'])
start += v['size']
return start
def _build_perturbation(self):
ng = self.rep_config['num_gaussians']
perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float()
perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size'])
self.register_buffer('points_offset_perturbation', perturbation)
base = torch.tensor(self.rep_config['offset_scale'])
self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))
def _get_offset(self, h):
B = h.shape[0]
r = self.layout['_offset_scale']['range']
_offset_scale = F.softplus(
h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape'])
+ comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device))
r = self.layout['_xyz']['range']
offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape'])
offset = offset * self.rep_config['lr']['_xyz']
if self.rep_config['perturb_offset']:
offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device)
offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
offset = offset * _offset_scale
return offset
def forward(self, x=None, cond=None):
pcd = x["points"]
d = next(self.parameters()).dtype
B, L, _ = pcd.shape
h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d)
h = self.input_layer(h)
cond = cond.to(d)
for block in self.blocks:
h = block(h, cond)
h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype)
return {"features": self.out_proj(h)}
# Combined octree gaussian decoder (comfy first-stage model)
class OctreeGaussianDecoder(nn.Module):
_MAX_VOXEL_LEVEL = 8
def __init__(self, dtype=None, device=None, operations=None):
super().__init__()
if operations is None:
operations = comfy.ops.disable_weight_init
self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=operations)
self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=operations)
@property
def gaussians_per_point(self) -> int:
return self.gs.rep_config['num_gaussians']
def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
# level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
# generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG.
level = self._MAX_VOXEL_LEVEL if level is None else level
num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
points_pred = OctreeProbabilityFixedlenDecoder.sample(
self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator,
)
pred = self.gs(x=points_pred, cond=latent)
return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item

View File

@ -4,6 +4,7 @@ import dataclasses
import torch
from typing import NamedTuple
import comfy_aimdo.host_buffer
from comfy.quant_ops import QuantizedTensor
@ -17,21 +18,18 @@ class TensorFileSlice(NamedTuple):
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
return False
if tensor._layout_cls != destination._layout_cls:
return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
if not read_tensor_file_slice_into(tensor._qdata,
destination._qdata if destination is not None else None, stream=stream,
destination2=(destination2._qdata if destination2 is not None else None)):
return False
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
if destination is not None:
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
if destination2 is not None:
dst_orig_dtype = destination2._params.orig_dtype
destination2._params.copy_from(destination._params, non_blocking=True)
destination2._params.copy_from(destination._params if destination is not None else tensor._params, non_blocking=True)
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
return True
@ -39,10 +37,15 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if info is None:
return False
if destination is not None and destination.device.type != "cpu" and destination2 is None:
destination2 = destination
destination = None
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or destination.numel() * destination.element_size() < info.size
if (file_obj is None
or (destination is None and destination2 is None)
or (destination is not None and (destination.device.type != "cpu" or destination.numel() * destination.element_size() < info.size))
or (destination2 is not None and (destination2.device.type == "cpu" or destination2.numel() * destination2.element_size() < info.size))
or tensor.numel() * tensor.element_size() != info.size
or tensor.storage_offset() != 0
or not tensor.is_contiguous()):
@ -51,6 +54,14 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
if info.size == 0:
return True
if destination is None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
comfy_aimdo.host_buffer.read_file_to_device(file_obj, info.offset, info.size,
stream_ptr, destination2.data_ptr(),
destination2.device.index,
mark_cold=False)
return True
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
if hostbuf is not None:
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
@ -63,6 +74,9 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
device=None if destination2 is None else destination2.device.index)
return True
if not hasattr(file_obj, "seek") or not hasattr(file_obj, "readinto"):
return False
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))

View File

@ -46,6 +46,7 @@ import comfy.ldm.wan.model_animate
import comfy.ldm.wan.ar_model
import comfy.ldm.wan.model_wandancer
import comfy.ldm.hunyuan3d.model
import comfy.ldm.triposplat.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model
@ -53,7 +54,10 @@ import comfy.ldm.pixeldit.model
import comfy.ldm.pixeldit.pid
import comfy.ldm.ace.model
import comfy.ldm.omnigen.omnigen2
import comfy.ldm.seedvr.model
import comfy.ldm.qwen_image.model
import comfy.ldm.ideogram4.model
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
@ -926,6 +930,16 @@ class HunyuanDiT(BaseModel):
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
return out
class SeedVR2(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
condition = kwargs.get("condition", None)
if condition is not None:
out["condition"] = comfy.conds.CONDRegular(condition)
return out
class PixArt(BaseModel):
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
@ -1428,6 +1442,23 @@ class PiD(PixelDiTT2I):
out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key == "lq_latent" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
lq = cond_value.cond
dim = window.dim
if dim >= lq.ndim:
return None
lq_proj = self.diffusion_model.lq_proj
ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor
# Map x window indices -> lq indices (deduplicated, sorted, in-bounds).
lq_size = lq.size(dim)
lq_indices = sorted({i // ratio for i in window.index_list if 0 <= i // ratio < lq_size})
if not lq_indices:
return None
idx = tuple([slice(None)] * dim + [lq_indices])
return cond_value._copy_with(lq[idx].to(device))
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
@ -1789,6 +1820,24 @@ class Hunyuan3Dv2_1(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class TripoSplat(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.triposplat.model.LatentSeqMMFlowModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None) # DINOv3 token sequence -> cross-attention context.
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None) # Flux2 VAE image latent -> additive second conditioning.
if ref_latents is not None:
out['ref_latents'] = comfy.conds.CONDList(list(ref_latents))
latent_shapes = kwargs.get("latent_shapes", None) # {latent, camera} nested latent
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
return out
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
@ -1982,6 +2031,21 @@ class QwenImage(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
return out
class Ideogram4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
attention_mask = kwargs.get("attention_mask", None)
if attention_mask is not None:
if torch.numel(attention_mask) != attention_mask.sum():
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out
class HunyuanImage21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)

View File

@ -313,6 +313,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["use_x0"] = True
else:
dit_config["use_x0"] = False
if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids
dit_config["use_sequential_txt_ids"] = True
else:
dit_config["use_sequential_txt_ids"] = False
else:
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
@ -594,6 +598,56 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
return dit_config
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
# submodules) at EVERY block — verified by inspecting the 7B
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
# ``MMModule.shared_weights=False``). Native NaDiT computes
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
# every block non-shared we set ``mm_layers = num_layers``.
# Without this, blocks at index >= mm_layers (default 10) try to
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
# silently miss-load → all-black output.
dit_config["mm_layers"] = 36
dit_config["norm_eps"] = 1e-5
dit_config["qk_rope"] = True
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "normal"
return dit_config
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 3072
dit_config["heads"] = 24
dit_config["num_layers"] = 36
# This checkpoint layout carries shared ``all.`` MMModule keys.
# Preserve the historical split: the initial blocks use separate
# vid/txt modules, later blocks use shared modules.
dit_config["mm_layers"] = 10
dit_config["norm_eps"] = 1e-5
dit_config["qk_rope"] = True
dit_config["rope_type"] = "rope3d"
dit_config["rope_dim"] = 64
dit_config["mlp_type"] = "swiglu"
return dit_config
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
dit_config = {}
dit_config["image_model"] = "seedvr2"
dit_config["vid_dim"] = 2560
dit_config["heads"] = 20
dit_config["num_layers"] = 32
dit_config["norm_eps"] = 1.0e-05
dit_config["qk_rope"] = None
dit_config["mlp_type"] = "swiglu"
dit_config["vid_out_norm"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
@ -676,6 +730,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if '{}cam_out_layer.weight'.format(key_prefix) in state_dict_keys and '{}repo_layers.0.final_map.weight'.format(key_prefix) in state_dict_keys: # TripoSplat
return {"image_model": "triposplat"}
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
return {"image_model": "hidream_o1"}
@ -808,6 +865,13 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["default_ref_method"] = "negative_index"
return dit_config
if '{}embed_image_indicator.weight'.format(key_prefix) in state_dict_keys: # Ideogram 4
dit_config = {}
dit_config["image_model"] = "ideogram4"
dit_config["in_channels"] = state_dict['{}input_proj.weight'.format(key_prefix)].shape[1]
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
return dit_config
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
dit_config = {}
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]

View File

@ -641,15 +641,17 @@ def free_pins(size, evict_active=False):
return freed_total
def ensure_pin_budget(size, evict_active=False):
shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
if args.fast_disk:
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
else:
shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available
if shortfall <= 0:
return True
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
return free_pins(to_free, evict_active=evict_active) >= shortfall
def ensure_pin_registerable(size, evict_active=False):
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
def free_registrations(shortfall, evict_active=True):
if MAX_PINNED_MEMORY <= 0:
return False
if shortfall <= 0:
@ -658,12 +660,22 @@ def ensure_pin_registerable(size, evict_active=False):
shortfall += REGISTERABLE_PIN_HYSTERESIS
for loaded_model in reversed(current_loaded_models):
model = loaded_model.model
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]:
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return True
if evict_active:
for loaded_model in current_loaded_models:
model = loaded_model.model
if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]:
shortfall -= model.unregister_inactive_pins(shortfall)
if shortfall <= 0:
return True
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
def ensure_pin_registerable(size, evict_active=True):
return free_registrations(TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY, evict_active=evict_active)
class LoadedModel:
def __init__(self, model: ModelPatcher):
self._set_model(model)
@ -803,9 +815,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for x in can_unload_sorted:
i = x[-1]
memory_to_free = 1e32
if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
if not DISABLE_SMART_MEMORY or device is None:
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
if for_dynamic:
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand.
memory_required -= current_loaded_models[i].model.loaded_size()
@ -817,6 +829,10 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i))
if not for_dynamic and pins_required > 0:
ensure_pin_budget(pins_required)
ensure_pin_registerable(pins_required)
if len(unloaded_model) > 0:
soft_empty_cache()
elif device is not None:
@ -879,15 +895,19 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
model_to_unload.model_finalizer.detach()
total_memory_required = {}
total_pins_required = {}
for loaded_model in models_to_load:
device = loaded_model.device
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
if not loaded_model.model.is_dynamic():
total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory()
for device in total_memory_required:
if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic)
for_dynamic=free_for_dynamic,
pins_required=total_pins_required.get(device, 0))
for device in total_memory_required:
if device != torch.device("cpu"):
@ -1283,7 +1303,6 @@ STREAM_CAST_BUFFERS = {}
LARGEST_CASTED_WEIGHT = (None, 0)
STREAM_AIMDO_CAST_BUFFERS = {}
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
STREAM_PIN_BUFFERS = {}
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
@ -1326,42 +1345,13 @@ def get_aimdo_cast_buffer(offload_stream, device):
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
return cast_buffer
def get_pin_buffer(offload_stream):
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
if pin_buffer is None:
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False)
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
elif offload_stream is not None:
event = getattr(pin_buffer, "_comfy_event", None)
if event is not None:
event.synchronize()
delattr(pin_buffer, "_comfy_event")
return pin_buffer
def resize_pin_buffer(pin_buffer, size):
global TOTAL_PINNED_MEMORY
old_size = pin_buffer.size
if size <= old_size:
return True
growth = size - old_size
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
ensure_pin_budget(growth, evict_active=True)
ensure_pin_registerable(growth, evict_active=True)
try:
pin_buffer.extend(size=size, reallocate=True)
except RuntimeError:
return False
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
return True
def reset_cast_buffers():
global TOTAL_PINNED_MEMORY
global LARGEST_CASTED_WEIGHT
global LARGEST_AIMDO_CASTED_WEIGHT
LARGEST_CASTED_WEIGHT = (None, 0)
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
if offload_stream is not None:
offload_stream.synchronize()
synchronize()
@ -1370,20 +1360,24 @@ def reset_cast_buffers():
mmap_obj.bounce()
DIRTY_MMAPS.clear()
for pin_buffer in STREAM_PIN_BUFFERS.values():
TOTAL_PINNED_MEMORY -= pin_buffer.size
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
for loaded_model in current_loaded_models:
model = loaded_model.model
if model is not None and model.is_dynamic():
model.model.dynamic_pins[model.load_device]["active"] = False
pin_state = model.model.dynamic_pins[model.load_device]
if pin_state["active"]:
*_, buckets = pin_state["weights"]
for size, bucket in list(buckets.items()):
bucket[:] = [ entry for entry in bucket if entry[-1] is not None ]
if not bucket:
del buckets[size]
pin_state["active"] = False
model.partially_unload_ram(1e30, subsets=[ "patches" ])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {})
STREAM_CAST_BUFFERS.clear()
STREAM_AIMDO_CAST_BUFFERS.clear()
STREAM_PIN_BUFFERS.clear()
soft_empty_cache()
def get_offload_stream(device):
@ -1436,7 +1430,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
if hasattr(wf_context, "as_context"):
wf_context = wf_context.as_context(stream)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors)
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
with wf_context:
for tensor in tensors:
@ -1448,9 +1442,10 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
mark_mmap_dirty(storage)
dest_view.copy_(tensor, non_blocking=non_blocking)
if dest_view is not None:
dest_view.copy_(tensor, non_blocking=non_blocking)
if dest2_view is not None:
dest2_view.copy_(dest_view, non_blocking=non_blocking)
dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking)
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
@ -1723,6 +1718,13 @@ def is_device_xpu(device):
def is_device_cuda(device):
return is_device_type(device, 'cuda')
def set_torch_device(device):
"""Set the current device for the given torch device. Supports CUDA and XPU."""
if is_device_cuda(device):
torch.cuda.set_device(device)
elif is_device_xpu(device):
torch.xpu.set_device(device)
def is_directml_enabled():
global directml_enabled
if directml_enabled:

View File

@ -1721,8 +1721,8 @@ class ModelPatcherDynamic(ModelPatcher):
"""
if device not in self.model.dynamic_pins:
self.model.dynamic_pins[device] = {
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
"hostbufs_initialized": False,
"failed": False,
"active": False,
@ -1799,8 +1799,8 @@ class ModelPatcherDynamic(ModelPatcher):
pin_state = self.model.dynamic_pins[self.load_device]
if not pin_state["hostbufs_initialized"]:
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
pin_state["hostbufs_initialized"] = True
pin_state["failed"] = False
pin_state["active"] = True
@ -1942,18 +1942,16 @@ class ModelPatcherDynamic(ModelPatcher):
return freed
def loaded_ram_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
self.model.dynamic_pins[self.load_device]["patches"][0].size)
return (self.model.dynamic_pins[self.load_device]["weights"][0].size)
def pinned_memory_size(self):
return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
self.model.dynamic_pins[self.load_device]["patches"][3][0])
return (self.model.dynamic_pins[self.load_device]["weights"][3][0])
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
split = stack_split[0]
while split >= 0:
module, offset = stack[split]
@ -1978,10 +1976,12 @@ class ModelPatcherDynamic(ModelPatcher):
freed = 0
pin_state = self.model.dynamic_pins[self.load_device]
for subset in subsets:
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
while len(stack) > 0:
module, offset = stack.pop()
size = module._pin.numel() * module._pin.element_size()
module._pin_balancer_entry[-1] = None
del module._pin_balancer_entry
del module._pin
hostbuf.truncate(offset, do_unregister=module._pin_registered)
stack_split[0] = min(stack_split[0], len(stack) - 1)

View File

@ -1,4 +1,5 @@
import comfy_aimdo.model_vbar
import comfy.memory_management
import comfy.model_management
import comfy.ops
@ -50,7 +51,17 @@ def prefetch_queue_pop(queue, device, module):
if hasattr(s, "_v"):
comfy_modules.append(s)
registerable_size = 0
for s in comfy_modules:
registerable_size += comfy.memory_management.vram_aligned_size([s.weight, s.bias])
for param_key in ("weight", "bias"):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
if lowvram_fn is not None:
registerable_size += lowvram_fn.memory_required()
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
if not comfy.model_management.args.fast_disk:
comfy.model_management.ensure_pin_registerable(registerable_size)
comfy.model_management.sync_stream(device, offload_stream)
queue[0] = (offload_stream, (prefetch, comfy_modules))

View File

@ -17,7 +17,7 @@ class MultiGPUThreadPool:
"""Persistent thread pool for multi-GPU work distribution.
Maintains one worker thread per extra GPU device. Each thread calls
torch.cuda.set_device() once at startup so that compiled kernel caches
set_torch_device() once at startup so that compiled kernel caches
(inductor/triton) stay warm across diffusion steps.
"""
@ -37,7 +37,7 @@ class MultiGPUThreadPool:
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
try:
torch.cuda.set_device(device)
comfy.model_management.set_torch_device(device)
except Exception as e:
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
while True:
@ -54,6 +54,8 @@ class MultiGPUThreadPool:
try:
result = fn(*args, **kwargs)
result_q.put((result, None))
except comfy.model_management.InterruptProcessingException as e:
result_q.put((None, e))
except Exception as e:
result_q.put((None, e))

View File

@ -76,8 +76,6 @@ except:
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
def cast_to_input(weight, input, non_blocking=False, copy=True):
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
@ -94,9 +92,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
offload_stream = None
cast_buffer = None
cast_buffer_offset = 0
stream_pin_hostbuf = None
stream_pin_offset = 0
stream_pin_queue = []
def ensure_offload_stream(module, required_size, check_largest):
nonlocal offload_stream
@ -130,22 +125,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
cast_buffer_offset += buffer_size
return buffer
def get_stream_pin_buffer_offset(buffer_size):
nonlocal stream_pin_hostbuf
nonlocal stream_pin_offset
if buffer_size == 0 or offload_stream is None:
return None
if stream_pin_hostbuf is None:
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
if stream_pin_hostbuf is None:
return None
offset = stream_pin_offset
stream_pin_offset += buffer_size
return offset
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
@ -184,12 +163,18 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if xfer_dest is None:
xfer_dest = get_cast_buffer(dest_size)
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream, xfer_dest2=None):
if xfer_source is not None:
if getattr(xfer_source, "is_lowvram_patch", False):
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
else:
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
if xfer_dest is not None:
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
xfer_source = [ xfer_dest ]
xfer_dest = xfer_dest2
xfer_dest2 = None
elif xfer_dest2 is not None:
xfer_source.prepare(xfer_dest2, stream, copy=True, commit=False)
return
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream, r2=xfer_dest2)
def handle_pin(m, pin, source, dest, subset="weights", size=None):
if pin is not None:
@ -198,19 +183,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
if signature is None:
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
pin = comfy.pinned_memory.get_pin(m, subset=subset)
if pin is not None:
if isinstance(source, list):
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
else:
cast_maybe_lowvram_patch(source, pin, None)
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
return
if pin is None:
pin_offset = get_stream_pin_buffer_offset(size)
if pin_offset is not None:
stream_pin_queue.append((source, pin_offset, size, dest))
return
cast_maybe_lowvram_patch(source, dest, offload_stream)
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
@ -232,23 +205,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
if stream_pin_offset > 0:
if stream_pin_hostbuf.size < stream_pin_offset:
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
for xfer_source, _, _, xfer_dest in stream_pin_queue:
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
return offload_stream
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf)
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
if isinstance(xfer_source, list):
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
else:
cast_maybe_lowvram_patch(xfer_source, pin, None)
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
return offload_stream

View File

@ -1,17 +1,55 @@
import bisect
import comfy.model_management
import comfy.memory_management
import comfy.utils
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
import torch
from comfy.cli_args import args
def _add_to_bucket(module, buckets, size, priority):
bucket = buckets.setdefault(size, [])
entry = [-priority, 0, module]
entry[1] = id(entry)
bisect.insort(bucket, entry)
module._pin_balancer_entry = entry
def _steal_pin(module, stack, buckets, size, priority):
bucket = buckets.get(size)
if bucket is None:
return False
while bucket and bucket[-1][-1] is None:
bucket.pop()
if not bucket:
del buckets[size]
return False
if priority <= -bucket[-1][0]:
return False
*_, victim = bucket.pop()
module._pin = victim._pin
module._pin_registered = victim._pin_registered
module._pin_stack_index = victim._pin_stack_index
stack[module._pin_stack_index] = (module, stack[module._pin_stack_index][1])
victim._pin_registered = False
del victim._pin
del victim._pin_stack_index
del victim._pin_balancer_entry
_add_to_bucket(module, buckets, size, priority)
return True
def get_pin(module, subset="weights"):
pin = getattr(module, "_pin", None)
if pin is None or module._pin_registered or args.disable_pinned_memory:
return pin
_, _, stack_split, pinned_size = module._pin_state[subset]
_, _, stack_split, pinned_size, *_ = module._pin_state[subset]
size = pin.nbytes
comfy.model_management.ensure_pin_registerable(size)
@ -31,33 +69,51 @@ def pin_memory(module, subset="weights", size=None):
return
pin = get_pin(module, subset)
if pin is not None or pin_state["failed"]:
if pin is not None:
return
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
hostbuf, stack, stack_split, pinned_size, counter, buckets = pin_state[subset]
if size is None:
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
offset = hostbuf.size
registerable_size = size + max(0, hostbuf.size - pinned_size[0])
registerable_size = size
priority = getattr(module, "_pin_balancer_priority", None)
if priority is None:
priority = comfy.utils.bit_reverse_range(counter[0], 16)
counter[0] += 1
module._pin_balancer_priority = priority
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
if (not comfy.model_management.ensure_pin_budget(size) or
not comfy.model_management.ensure_pin_registerable(registerable_size)):
pin_state["failed"] = True
return False
return _steal_pin(module, stack, buckets, size, priority)
extended = False
try:
hostbuf.extend(size=size)
hostbuf.extend(size=size, register=False)
extended = True
pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
pin.untyped_storage()._comfy_hostbuf = hostbuf
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
comfy.model_management.free_registrations(size)
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
del pin
hostbuf.truncate(offset, do_unregister=False)
return _steal_pin(module, stack, buckets, size, priority)
except RuntimeError:
pin_state["failed"] = True
return False
if extended:
hostbuf.truncate(offset, do_unregister=False)
return _steal_pin(module, stack, buckets, size, priority)
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
module._pin = pin
stack.append((module, offset))
module._pin_registered = True
module._pin_stack_index = len(stack) - 1
stack_split[0] = max(stack_split[0], module._pin_stack_index)
comfy.model_management.TOTAL_PINNED_MEMORY += size
pinned_size[0] += size
_add_to_bucket(module, buckets, size, priority)
return True

View File

@ -44,7 +44,13 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None,
is_empty = torch.count_nonzero(latent_image) == 0
if is_empty:
if latent_format.latent_channels != latent_image.shape[1]:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
preserves_collapsed_channels = (
getattr(latent_format, "preserve_empty_channel_multiples", False)
and latent_image.ndim == 4
and latent_image.shape[1] % latent_format.latent_channels == 0
)
if not preserves_collapsed_channels:
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
if downscale_ratio_spacial is not None:
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio

View File

@ -464,10 +464,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
try:
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
# XPU/NPU/MPS/CPU/DirectML backends.
torch.cuda.set_device(device)
comfy.model_management.set_torch_device(device)
model_current: BaseModel = model_options["multigpu_clones"][device].model
# run every hooked_to_run separately
with torch.no_grad():

View File

@ -1,3 +1,4 @@
import inspect
import json
import torch
from enum import Enum
@ -16,6 +17,8 @@ import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.seedvr.vae
import comfy.ldm.triposplat.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
import comfy.ldm.hunyuan_video.vae
@ -57,6 +60,7 @@ import comfy.text_encoders.omnigen2
import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.z_image
import comfy.text_encoders.ideogram4
import comfy.text_encoders.ovis
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.jina_clip_2
@ -82,6 +86,36 @@ import comfy.latent_formats
import comfy.ldm.flux.redux
SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160
def _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w):
output_t = max(1, (latent_t - 1) * 4 + 1)
return output_t * latent_h * 8 * latent_w * 8
def _seedvr2_vae_decode_memory_used(shape):
if len(shape) == 5:
candidates = []
if shape[1] == 16:
candidates.append((shape[2], shape[3], shape[4]))
if shape[-1] == 16:
candidates.append((shape[1], shape[2], shape[3]))
if len(candidates) == 0:
candidates.append((shape[2], shape[3], shape[4]))
output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates)
elif len(shape) == 4:
latent_t = max(1, (shape[1] + 15) // 16)
latent_h, latent_w = shape[2], shape[3]
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
else:
latent_t, latent_h, latent_w = 1, shape[-2], shape[-1]
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
# SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels
# plus int64 sort indices dominate peak memory, not the VAE weight dtype.
return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
key_map = {}
if model is not None:
@ -465,8 +499,10 @@ class CLIP:
class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
if metadata is None or metadata.get("keep_diffusers_format") != "true":
sd = diffusers_convert.convert_vae_state_dict(sd)
if model_management.is_amd():
VAE_KL_MEM_RATIO = 2.73
@ -538,6 +574,20 @@ class VAE:
self.first_stage_model = StageC_coder()
self.downscale_ratio = 32
self.latent_channels = 16
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
self.latent_channels = 16
self.latent_dim = 3
self.disable_offload = True
self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape)
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.process_input = lambda image: image * 2.0 - 1.0
self.crop_input = False
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
@ -665,6 +715,7 @@ class VAE:
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
self.downscale_index_formula = (8, 32, 32)
self.working_dtypes = [torch.bfloat16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
@ -894,6 +945,16 @@ class VAE:
#Force cast it for --disable-dynamic-vram users until there is a true core fix.
if not comfy.memory_management.aimdo_enabled:
self.disable_offload = True
elif "gs.base_offset_scale" in sd and "octree.out_proj.weight" in sd: # TripoSplat octree gaussian decoder
self.first_stage_model = comfy.ldm.triposplat.vae.OctreeGaussianDecoder()
self.latent_channels = 16
self.latent_dim = 1
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# The generic VAE.encode/decode path isn't used: VAEDecodeTripoSplat calls the gaussian
# decoder directly (structured GaussianSplat objects, not a tensor and reserves VRAM itself from num_gaussians.
def _no_generic_io(*args, **kwargs):
raise RuntimeError("TripoSplat gaussian decoder: use the 'TripoSplat Decode' (VAEDecodeTripoSplat)")
self.memory_used_encode = self.memory_used_decode = _no_generic_io
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@ -994,6 +1055,40 @@ class VAE:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4):
sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8)
sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4)
if tile_t is None:
tile_t = 16
if overlap_t is None:
overlap_t = 4
if tile_t > 0:
temporal_size = tile_t * sf_t
temporal_overlap = max(0, overlap_t) * sf_t
else:
temporal_size = 0
temporal_overlap = 0
args = {
"enable_tiling": True,
"tile_size": (tile_y * sf_s, tile_x * sf_s),
"tile_overlap": (overlap * sf_s, overlap * sf_s),
"temporal_size": temporal_size,
"temporal_overlap": temporal_overlap,
}
output = self.first_stage_model.decode(
samples.to(self.vae_dtype).to(self.device),
seedvr2_tiling=args,
)
return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
def _format_seedvr2_encoded_samples(self, samples):
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
if samples.ndim == 4:
samples = samples.unsqueeze(2)
samples = samples.contiguous()
samples = samples * 0.9152
return samples
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
@ -1030,6 +1125,36 @@ class VAE:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
if tile_y is None:
tile_y = 512
if tile_x is None:
tile_x = 512
if overlap is None:
overlap_y = 64
overlap_x = 64
else:
overlap_y = overlap
overlap_x = overlap
if tile_t is None:
tile_t = 9999
if overlap_t is None:
overlap_t = 0
overlap_y = min(overlap_y, max(0, tile_y - 8))
overlap_x = min(overlap_x, max(0, tile_x - 8))
self.first_stage_model.device = self.device
x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device)
output = comfy.ldm.seedvr.vae.tiled_vae(
x,
self.first_stage_model,
tile_size=(tile_y, tile_x),
tile_overlap=(overlap_y, overlap_x),
temporal_size=tile_t,
temporal_overlap=overlap_t,
encode=True,
)
return output.to(device=self.output_device, dtype=self.vae_output_dtype())
def decode(self, samples_in, vae_options={}):
self.throw_exception_if_invalid()
pixel_samples = None
@ -1077,16 +1202,40 @@ class VAE:
if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
# SeedVR2 latents arrive in 4D collapsed form ``(B, 16*T, H, W)``
# downstream of ``SeedVR2Conditioning`` (which performs the
# ``rearrange(b c t h w -> b (c t) h w)`` collapse). The
# generic ``decode_tiled_`` would treat the channel dim as
# spatial-only and crash on the collapsed (16, T) layout
# under ``tiled_scale``'s mask broadcast; route SeedVR2 4D
# latents to ``decode_tiled_seedvr2`` instead, whose wrapper
# dispatch handles both 4D and 5D inputs.
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_(samples_in)
elif dims == 3:
tile = 256 // self.spacial_compression_decode()
overlap = tile // 4
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
else:
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
def decode_tiled(
self,
samples,
tile_x=None,
tile_y=None,
overlap=None,
tile_t=None,
overlap_t=None,
):
self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
@ -1100,7 +1249,20 @@ class VAE:
args["overlap"] = overlap
with model_management.cuda_device_context(self.device):
if dims == 1 or self.extra_1d_channel is not None:
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3):
seedvr2_args = {}
if tile_x is not None:
seedvr2_args["tile_x"] = tile_x
if tile_y is not None:
seedvr2_args["tile_y"] = tile_y
if overlap is not None:
seedvr2_args["overlap"] = overlap
if tile_t is not None:
seedvr2_args["tile_t"] = tile_t
if overlap_t is not None:
seedvr2_args["overlap_t"] = overlap_t
output = self.decode_tiled_seedvr2(samples, **seedvr2_args)
elif dims == 1 or self.extra_1d_channel is not None:
args.pop("tile_y")
output = self.decode_tiled_1d(samples, **args)
elif dims == 2:
@ -1142,6 +1304,8 @@ class VAE:
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
if isinstance(out, tuple):
out = out[0]
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
@ -1161,20 +1325,23 @@ class VAE:
if self.latent_dim == 3:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
else:
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
return samples
return self._format_seedvr2_encoded_samples(samples)
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
if dims == 3:
if dims == 3 and pixel_samples.ndim < 5:
if not self.not_video:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
else:
@ -1198,22 +1365,47 @@ class VAE:
elif dims == 2:
samples = self.encode_tiled_(pixel_samples, **args)
elif dims == 3:
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
seedvr2_args = {}
if tile_x is not None:
seedvr2_args["tile_x"] = tile_x
else:
seedvr2_args["tile_x"] = 512
if tile_y is not None:
seedvr2_args["tile_y"] = tile_y
else:
seedvr2_args["tile_y"] = 512
if overlap is not None:
seedvr2_args["overlap"] = overlap
else:
seedvr2_args["overlap"] = 64
if tile_t is not None:
seedvr2_args["tile_t"] = tile_t
else:
seedvr2_args["tile_t"] = 9999
if overlap_t is not None:
seedvr2_args["overlap_t"] = overlap_t
else:
seedvr2_args["overlap_t"] = 0
samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args)
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if tile_t is not None:
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
else:
tile_t_latent = 9999
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
if overlap_t is None:
args["overlap"] = (1, overlap, overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
spatial_overlap = overlap if overlap is not None else 64
if overlap_t is None:
args["overlap"] = (1, spatial_overlap, spatial_overlap)
else:
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap)
maximum = pixel_samples.shape[2]
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
return samples
return self._format_seedvr2_encoded_samples(samples)
def get_sd(self):
return self.first_stage_model.state_dict()
@ -1287,6 +1479,7 @@ class CLIPType(Enum):
COGVIDEOX = 27
LENS = 28
PIXELDIT = 29
IDEOGRAM4 = 30
@ -1585,8 +1778,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
elif te_model == TEModel.QWEN3_8B:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
if clip_type == CLIPType.IDEOGRAM4:
clip_target.clip = comfy.text_encoders.ideogram4.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Tokenizer
else:
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
elif te_model == TEModel.JINA_CLIP_2:
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
@ -1735,6 +1932,17 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
set_dtype = model_config.set_inference_dtype
parameters = inspect.signature(set_dtype).parameters
supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
if supports_device:
set_dtype(dtype, manual_cast_dtype, device=device)
else:
set_dtype(dtype, manual_cast_dtype)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
@ -1842,7 +2050,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
if model_config.clip_vision_prefix is not None:
if output_clipvision:
@ -1983,7 +2191,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
else:
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
if custom_operations is not None:
model_config.custom_operations = custom_operations

View File

@ -24,6 +24,7 @@ import comfy.text_encoders.qwen_image
import comfy.text_encoders.hunyuan_image
import comfy.text_encoders.kandinsky5
import comfy.text_encoders.z_image
import comfy.text_encoders.ideogram4
import comfy.text_encoders.anima
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
@ -1538,6 +1539,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini
class TripoSplat(supported_models_base.BASE):
# Image -> 3D gaussian splat flow denoiser
unet_config = {
"image_model": "triposplat",
}
unet_extra_config = {}
sampling_settings = {
"shift": 3.0,
}
memory_usage_factor = 0.6
latent_format = latent_formats.TripoSplat
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
return model_base.TripoSplat(self, device=device)
def clip_target(self, state_dict={}):
return None
class HiDream(supported_models_base.BASE):
unet_config = {
"image_model": "hidream",
@ -1647,6 +1672,35 @@ class Chroma(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
class SeedVR2(supported_models_base.BASE):
unet_config = {
"image_model": "seedvr2"
}
latent_format = comfy.latent_formats.SeedVR2
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
sampling_settings = {
"shift": 1.0,
}
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
if (
dtype == torch.float16
and manual_cast_dtype is None
and comfy.model_management.should_use_bf16(device)
):
manual_cast_dtype = torch.bfloat16
super().set_inference_dtype(dtype, manual_cast_dtype, device=device)
def get_model(self, state_dict, prefix="", device=None):
out = model_base.SeedVR2(self, device=device)
return out
def clip_target(self, state_dict={}):
return None
class ChromaRadiance(Chroma):
unet_config = {
"image_model": "chroma_radiance",
@ -1722,6 +1776,44 @@ class Omnigen2(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
class Ideogram4(supported_models_base.BASE):
unet_config = {
"image_model": "ideogram4",
}
sampling_settings = {
"multiplier": 1.0,
"shift": 1.0,
}
memory_usage_factor = 11.6
unet_extra_config = {
"num_attention_heads": 18,
"attention_head_dim": 256,
"intermediate_size": 12288,
"adaln_dim": 512,
"llm_features_dim": 53248,
"rope_theta": 5000000,
"mrope_section": [24, 20, 20],
"norm_eps": 1e-5,
}
latent_format = latent_formats.Flux2
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
def get_model(self, state_dict, prefix="", device=None):
out = model_base.Ideogram4(self, device=device)
return out
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.ideogram4.Ideogram4Tokenizer, comfy.text_encoders.ideogram4.te(**hunyuan_detect))
class QwenImage(supported_models_base.BASE):
unet_config = {
"image_model": "qwen_image",
@ -1966,7 +2058,6 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
class RT_DETR_v4(supported_models_base.BASE):
unet_config = {
"image_model": "RT_DETR_v4",
@ -2200,14 +2291,17 @@ models = [
Hunyuan3Dv2mini,
Hunyuan3Dv2,
Hunyuan3Dv2_1,
TripoSplat,
HiDream,
HiDreamO1,
Chroma,
SeedVR2,
ChromaRadiance,
ACEStep,
ACEStep15,
Omnigen2,
QwenImage,
Ideogram4,
Flux2,
Lens,
Kandinsky5Image,

View File

@ -115,7 +115,7 @@ class BASE:
replace_prefix = {"": self.vae_key_prefix[0]}
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
def set_inference_dtype(self, dtype, manual_cast_dtype):
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
self.unet_config['dtype'] = dtype
self.manual_cast_dtype = manual_cast_dtype

View File

@ -0,0 +1,77 @@
"""Ideogram 4 text encoder: Qwen3-VL-8B language model, 13-layer tap.
Ideogram 4 conditions on the concatenation of hidden states from 13 layers of
Qwen3-VL (layers 0,3,...,33,35), giving a 4096*13 = 53248-dim feature per token.
"""
import os
from transformers import Qwen2Tokenizer
import comfy.text_encoders.llama
from comfy import sd1_clip
# Reference taps outputs of layers (0,3,...,35); comfy captures layer inputs, offset by +1.
IDEOGRAM4_TAP_LAYERS = [1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 36]
class Qwen3VLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory,
embedding_size=4096, embedding_key='qwen3vl_8b', tokenizer_class=Qwen2Tokenizer,
has_start_token=False, has_end_token=False, pad_to_max_length=False,
max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
name="qwen3vl_8b", tokenizer=Qwen3VLTokenizer)
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
# Qwen3-VL-8B = 5e6 (vs plain Qwen3-8B's 1e6)
# final_norm/lm_head off -> Ideogram only reads raw tapped hidden states
QWEN3VL_8B_CONFIG = {"rope_theta": 5000000.0, "final_norm": False, "lm_head": False}
class Qwen3VL8BModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="hidden", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
super().__init__(device=device, layer=IDEOGRAM4_TAP_LAYERS, layer_idx=None,
textmodel_json_config=dict(QWEN3VL_8B_CONFIG),
dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False,
model_class=comfy.text_encoders.llama.Qwen3_8B,
enable_attention_masks=attention_mask, return_attention_masks=attention_mask,
model_options=model_options)
class Ideogram4TEModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Qwen3VL8BModel, model_options=model_options)
def encode_token_weights(self, token_weight_pairs):
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096) stacked in ascending layer order.
out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13). permute -> (B, seq, H, taps).
return out, pooled, extra
def te(dtype_llama=None, llama_quantization_metadata=None):
class Ideogram4TEModel_(Ideogram4TEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if dtype_llama is not None:
dtype = dtype_llama
if llama_quantization_metadata is not None:
model_options = model_options.copy()
model_options["quantization_metadata"] = llama_quantization_metadata
super().__init__(device=device, dtype=dtype, model_options=model_options)
return Ideogram4TEModel_

View File

@ -85,9 +85,9 @@ _TYPES = {
def load_safetensors(ckpt):
import comfy_aimdo.model_mmap
f = open(ckpt, "rb", buffering=0)
file_lock = threading.Lock()
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
f = model_mmap.get_file_handle()
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
@ -1452,3 +1452,10 @@ def deepcopy_list_dict(obj, memo=None):
memo[obj_id] = res
return res
def bit_reverse_range(index, bits):
result = 0
for _ in range(bits):
result = (result << 1) | (index & 1)
index >>= 1
return result

View File

@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
from comfy_api.internal.async_to_sync import create_sync_class
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
from ._input_impl import VideoFromFile, VideoFromComponents
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D
from . import _io_public as io
from . import _ui_public as ui
from comfy_execution.utils import get_executing_context
@ -143,6 +143,7 @@ class Types:
VideoComponents = VideoComponents
MESH = MESH
VOXEL = VOXEL
SPLAT = SPLAT
File3D = File3D

View File

@ -65,6 +65,12 @@ class VideoInput(ABC):
buffer.seek(0)
return buffer
def get_active_trim_window(self) -> tuple[float, float]:
"""Return the active trim as ``(start_time, duration)`` in seconds (start_time normalized
to ``>= 0``; ``duration == 0`` means "until the end"). Default: no trim; trimmable subclasses override.
"""
return 0.0, 0.0
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:

View File

@ -75,6 +75,12 @@ class VideoFromFile(VideoInput):
self.__file.seek(0)
return self.__file
def get_active_trim_window(self) -> tuple[float, float]:
start_time = self.__start_time
if start_time < 0:
start_time = max(self._get_raw_duration() + start_time, 0.0)
return float(start_time), float(self.__duration)
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.

View File

@ -28,7 +28,7 @@ if TYPE_CHECKING:
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL, SVG as _SVG, File3D
from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D
class FolderType(str, Enum):
@ -684,6 +684,10 @@ class Voxel(ComfyTypeIO):
class Mesh(ComfyTypeIO):
Type = MESH
@comfytype(io_type="SPLAT")
class Splat(ComfyTypeIO):
Type = SPLAT
@comfytype(io_type="FILE_3D")
class File3DAny(ComfyTypeIO):
@ -727,6 +731,42 @@ class File3DUSDZ(ComfyTypeIO):
Type = File3D
@comfytype(io_type="FILE_3D_PLY")
class File3DPLY(ComfyTypeIO):
"""PLY format 3D file - point cloud or Gaussian splat."""
Type = File3D
@comfytype(io_type="FILE_3D_SPLAT")
class File3DSPLAT(ComfyTypeIO):
"""SPLAT format 3D file - 3D Gaussian splat."""
Type = File3D
@comfytype(io_type="FILE_3D_SPZ")
class File3DSPZ(ComfyTypeIO):
"""SPZ format 3D file - compressed 3D Gaussian splat."""
Type = File3D
@comfytype(io_type="FILE_3D_KSPLAT")
class File3DKSPLAT(ComfyTypeIO):
"""KSPLAT format 3D file - 3D Gaussian splat."""
Type = File3D
@comfytype(io_type="FILE_3D_SPLAT_ANY")
class File3DSplatAny(ComfyTypeIO):
"""General 3D Gaussian splat file type - accepts any supported splat container (.ply / .spz / .splat / .ksplat)."""
Type = File3D
@comfytype(io_type="FILE_3D_POINT_CLOUD_ANY")
class File3DPointCloudAny(ComfyTypeIO):
"""General point cloud file type - accepts any supported point cloud container (currently .ply)."""
Type = File3D
@comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO):
if TYPE_CHECKING:
@ -762,14 +802,32 @@ class Accumulation(ComfyTypeIO):
@comfytype(io_type="LOAD3D_CAMERA")
class Load3DCamera(ComfyTypeIO):
class CameraInfo(TypedDict):
position: dict[str, float | int]
target: dict[str, float | int]
zoom: int
cameraType: str
# Coordinate system: right-handed, Y-up, camera looks down -Z
position: dict[str, float | int] # scene units
target: dict[str, float | int] # scene units; OrbitControls focus point
zoom: float | int # dimensionless, 1 = 100%
cameraType: str # 'perspective' | 'orthographic'
quaternion: NotRequired[dict[str, float | int]] # normalized, dimensionless; camera world rotation
fov: NotRequired[float | int] # degrees, vertical FOV (perspective only)
aspect: NotRequired[float | int] # width / height (perspective only)
near: NotRequired[float | int] # scene units
far: NotRequired[float | int] # scene units
frustum: NotRequired[dict[str, float | int]] # orthographic only: {left, right, top, bottom} in scene units
Type = CameraInfo
@comfytype(io_type="LOAD3D_MODEL_INFO")
class Load3DModelInfo(ComfyTypeIO):
class Model3DTransform(TypedDict):
# Coordinate system: right-handed, Y-up, world space
position: dict[str, float | int] # scene units
quaternion: dict[str, float | int] # normalized, dimensionless; world rotation
scale: dict[str, float | int] # dimensionless multiplier
Type = list[Model3DTransform]
@comfytype(io_type="LOAD_3D")
class Load3D(ComfyTypeIO):
"""3D models are stored as a dictionary."""
@ -779,6 +837,7 @@ class Load3D(ComfyTypeIO):
normal: str
camera_info: Load3DCamera.CameraInfo
recording: NotRequired[str]
model_3d_info: NotRequired[list[Load3DModelInfo.Model3DTransform]]
Type = Model3DDict
@ -2277,6 +2336,7 @@ __all__ = [
"LossMap",
"Voxel",
"Mesh",
"Splat",
"File3DAny",
"File3DGLB",
"File3DGLTF",
@ -2284,6 +2344,12 @@ __all__ = [
"File3DOBJ",
"File3DSTL",
"File3DUSDZ",
"File3DPLY",
"File3DSPLAT",
"File3DSPZ",
"File3DKSPLAT",
"File3DSplatAny",
"File3DPointCloudAny",
"Hooks",
"HookKeyframes",
"TimestepsRange",
@ -2291,6 +2357,7 @@ __all__ = [
"FlowControl",
"Accumulation",
"Load3DCamera",
"Load3DModelInfo",
"Load3D",
"Load3DAnimation",
"Photomaker",

View File

@ -285,7 +285,7 @@ class AudioSaveHelper:
results = []
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
file = f"{filename_with_batch_num}_{counter:05}.{format}"
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially
@ -452,6 +452,16 @@ class PreviewUI3D(_UIOutput):
return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
class PreviewUI3DAdvanced(_UIOutput):
def __init__(self, model_file, camera_info, model_3d_info):
self.model_file = model_file
self.camera_info = camera_info
self.model_3d_info = model_3d_info
def as_dict(self):
return {"result": [self.model_file, self.camera_info, self.model_3d_info]}
class PreviewText(_UIOutput):
def __init__(self, value: str, **kwargs):
self.value = value
@ -471,5 +481,6 @@ __all__ = [
"PreviewAudio",
"PreviewVideo",
"PreviewUI3D",
"PreviewUI3DAdvanced",
"PreviewText",
]

View File

@ -1,5 +1,5 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH, File3D
from .geometry_types import VOXEL, MESH, SPLAT, File3D
from .image_types import SVG
__all__ = [
@ -9,6 +9,7 @@ __all__ = [
"VideoComponents",
"VOXEL",
"MESH",
"SPLAT",
"File3D",
"SVG",
]

View File

@ -11,13 +11,32 @@ class VOXEL:
self.data = data
class SPLAT:
"""A batch of 3D Gaussian splats in render-ready (activated, world-space) form.
Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the
real per-item lengths (None when rows are uniform and no slicing is needed). SH coefficients are
stored as (B, N, K, 3) with K = (sh_degree + 1)**2; the DC (diffuse) term is sh[..., 0, :].
"""
def __init__(self, positions: torch.Tensor, scales: torch.Tensor, rotations: torch.Tensor,
opacities: torch.Tensor, sh: torch.Tensor, counts: torch.Tensor | None = None):
self.positions = positions # (B, N, 3) world-space centers
self.scales = scales # (B, N, 3) linear (positive) per-axis std
self.rotations = rotations # (B, N, 4) quaternion wxyz (normalized)
self.opacities = opacities # (B, N, 1) in [0, 1]
self.sh = sh # (B, N, K, 3) spherical-harmonic color coefficients
self.counts = counts # (B,) real lengths, or None
class MESH:
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor,
uvs: torch.Tensor | None = None,
vertex_colors: torch.Tensor | None = None,
texture: torch.Tensor | None = None,
vertex_counts: torch.Tensor | None = None,
face_counts: torch.Tensor | None = None):
face_counts: torch.Tensor | None = None,
unlit: bool = False):
assert (vertex_counts is None) == (face_counts is None), \
"vertex_counts and face_counts must be provided together (both or neither)"
@ -30,6 +49,8 @@ class MESH:
# these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed.
self.vertex_counts = vertex_counts
self.face_counts = face_counts
# Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes.
self.unlit = unlit
class File3D:

View File

@ -0,0 +1,32 @@
from pydantic import BaseModel, Field
class CreateSwitchXRequest(BaseModel):
generation_type: str = Field(...)
source_uri: str = Field(...)
alpha_mode: str = Field(...)
prompt: str | None = Field(None, max_length=2000)
reference_image_uri: str | None = Field(None)
alpha_uri: str | None = Field(None)
max_resolution: int = Field(1080)
callback_url: str | None = Field(None)
idempotency_key: str | None = Field(None, max_length=256, min_length=1)
class SwitchXOutputUrls(BaseModel):
render: str | None = Field(None)
source: str | None = Field(None)
alpha: str | None = Field(None)
class SwitchXStatusResponse(BaseModel):
id: str = Field(...)
status: str = Field(...)
progress: int | None = Field(None)
generation_type: str | None = Field(None)
alpha_mode: str | None = Field(None)
output: SwitchXOutputUrls | None = Field(None)
error: str | None = Field(None)
created_at: str | None = Field(None)
modified_at: str | None = Field(None)
completed_at: str | None = Field(None)

View File

@ -1,71 +1,72 @@
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field, confloat, conint
class BFLOutputFormat(str, Enum):
png = 'png'
jpeg = 'jpeg'
from pydantic import BaseModel, Field
class BFLFluxExpandImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
top: int = Field(...)
bottom: int = Field(...)
left: int = Field(...)
right: int = Field(...)
steps: int = Field(...)
guidance: float = Field(...)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand")
class BFLFluxFillImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
steps: int = Field(...)
guidance: float = Field(...)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image: str = Field(
None, description="Base64-encoded string representing the image to modify. Can contain alpha mask if desired.",
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
mask: str = Field(
None, description="Base64-encoded string representing the mask of the areas you wish to modify."
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
class BFLFluxEraseRequest(BaseModel):
image: str = Field(..., description="A Base64-encoded string representing the image to erase from.")
mask: str = Field(
...,
description="A Base64-encoded black/white mask matching the input dimensions; "
"white (255) marks areas to remove, black (0) marks areas to preserve.",
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
dilate_pixels: int = Field(10)
seed: int | None = Field(None)
output_format: str = Field("png")
class BFLFluxVTORequest(BaseModel):
prompt: str = Field(
..., description="Natural-language styling instruction. Required field, but may be an empty string."
)
person: str = Field(..., description="A Base64-encoded string representing the person image.")
garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.")
seed: int | None = Field(None)
safety_tolerance: int = Field(5)
output_format: str = Field("png")
class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
# None, description='Blend between the prompt and the image prompt.'
# )
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
width: int = Field(1024, description="Must be a multiple of 32.")
height: int = Field(768, description="Must be a multiple of 32.")
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
class Flux2ProGenerateRequest(BaseModel):
@ -83,55 +84,37 @@ class Flux2ProGenerateRequest(BaseModel):
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
safety_tolerance: int | None = Field(
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
)
output_format: str | None = Field(
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
)
safety_tolerance: int = Field(5)
output_format: str = Field("png")
class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
prompt: str = Field(...)
input_image: str | None = Field(None, description="Image to edit in base64 format")
seed: int | None = Field(None)
guidance: float = Field(...)
steps: int = Field(...)
safety_tolerance: int = Field(2)
output_format: str = Field("png")
aspect_ratio: str | None = Field(None)
prompt_upsampling: bool | None = Field(None)
class BFLFluxProUltraGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
None, description='Blend between the prompt and the image prompt.'
)
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
aspect_ratio: str | None = Field(None)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
raw: bool | None = Field(None)
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
image_prompt_strength: float | None = Field(None)
class BFLFluxProGenerateResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
polling_url: str = Field(..., description="URL to poll for the generation result.")
id: str = Field(...)
polling_url: str = Field(...)
cost: float | None = Field(None, description="Price in cents")
@ -145,7 +128,7 @@ class BFLStatus(str, Enum):
class BFLFluxStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
status: BFLStatus = Field(..., description="The status of the task.")
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
id: str = Field(...)
status: BFLStatus = Field(...)
result: dict[str, Any] | None = Field(None)
progress: float | None = Field(None, ge=0.0, le=1.0)

View File

@ -97,3 +97,28 @@ class BriaRemoveVideoBackgroundResult(BaseModel):
class BriaRemoveVideoBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveVideoBackgroundResult | None = Field(None)
class BriaVideoGreenScreenRequest(BaseModel):
video: str = Field(..., description="Publicly accessible URL of the input video.")
green_shade: str = Field(
default="broadcast_green",
description="Solid chroma-key shade applied behind the foreground "
"(broadcast_green, chroma_green, or blue_screen).",
)
output_container_and_codec: str = Field(...)
preserve_audio: bool = Field(True)
seed: int = Field(...)
class BriaVideoReplaceBackgroundRequest(BaseModel):
video: str = Field(..., description="Publicly accessible URL of the input (foreground) video.")
background_url: str = Field(
...,
description="Publicly accessible URL of the background image or video to composite behind "
"the foreground. Stretched to the foreground frame; match its aspect ratio for "
"undistorted results.",
)
output_container_and_codec: str = Field(...)
preserve_audio: bool = Field(True)
seed: int = Field(...)

View File

@ -158,8 +158,9 @@ class SeedanceCreateAssetResponse(BaseModel):
class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
url: str = Field(..., description="Publicly accessible URL of the image asset to upload.")
url: str = Field(..., description="Publicly accessible URL of the asset to upload.")
hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.")
asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
# Dollars per 1K tokens, keyed by (model_id, has_video_input).

View File

@ -108,13 +108,19 @@ class GeminiVideoMetadata(BaseModel):
startOffset: GeminiOffset | None = Field(None)
class GeminiThinkingConfig(BaseModel):
includeThoughts: bool | None = Field(None)
thinkingLevel: str = Field(...)
class GeminiGenerationConfig(BaseModel):
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
maxOutputTokens: int | None = Field(None, ge=16, le=65536)
seed: int | None = Field(None)
stopSequences: list[str] | None = Field(None)
temperature: float | None = Field(None, ge=0.0, le=2.0)
topK: int | None = Field(None, ge=1)
topP: float | None = Field(None, ge=0.0, le=1.0)
thinkingConfig: GeminiThinkingConfig | None = Field(None)
class GeminiImageOutputOptions(BaseModel):
@ -128,11 +134,6 @@ class GeminiImageConfig(BaseModel):
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
class GeminiThinkingConfig(BaseModel):
includeThoughts: bool | None = Field(None)
thinkingLevel: str = Field(...)
class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: list[str] | None = Field(None)
imageConfig: GeminiImageConfig | None = Field(None)

View File

@ -290,3 +290,19 @@ class IdeogramV3Request(BaseModel):
None,
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
)
class IdeogramV4Request(BaseModel):
text_prompt: str | None = Field(
None,
description="Natural-language prompt; Magic Prompt is applied automatically. "
"Supply exactly one of text_prompt or json_prompt.",
)
json_prompt: dict[str, Any] | None = Field(
None,
description="Structured V4 prompt object consumed directly (disables Magic Prompt). "
"Supply exactly one of text_prompt or json_prompt.",
)
resolution: str | None = Field(None, description="Output resolution in WIDTHxHEIGHT (e.g. '2048x2048').")
rendering_speed: str | None = Field(None, description="Rendering speed: 'TURBO', 'DEFAULT', or 'QUALITY'.")
enable_copyright_detection: bool | None = Field(None, description="Opt into post-generation copyright detection.")

View File

@ -0,0 +1,46 @@
"""Pydantic models for the Krea image-generation API."""
from pydantic import BaseModel, Field
class KreaMoodboard(BaseModel):
id: str = Field(...)
strength: float = Field(default=0.35, ge=-0.5, le=1.5)
class KreaImageStyleReference(BaseModel):
strength: float = Field(..., ge=-2.0, le=2.0)
url: str | None = Field(default=None)
class KreaGenerateImageRequest(BaseModel):
prompt: str = Field(...)
aspect_ratio: str = Field(...)
resolution: str = Field(...)
seed: int | None = Field(default=None)
creativity: str = Field(default="medium")
moodboards: list[KreaMoodboard] | None = Field(default=None)
image_style_references: list[KreaImageStyleReference] | None = Field(default=None)
class KreaJobResult(BaseModel):
urls: list[str] | None = Field(default=None)
style_id: str | None = Field(default=None)
class KreaJob(BaseModel):
job_id: str = Field(...)
status: str = Field(...)
created_at: str = Field(...)
completed_at: str | None = Field(default=None)
result: KreaJobResult | None = Field(default=None)
class KreaAssetResponse(BaseModel):
id: str = Field(...)
image_url: str = Field(...)
uploaded_at: str = Field(...)
width: float | None = Field(default=None)
height: float | None = Field(default=None)
size_bytes: float | None = Field(default=None)
mime_type: str | None = Field(default=None)

View File

@ -1,25 +1,25 @@
from enum import Enum
from typing import Optional, Any
from typing import Any
from pydantic import BaseModel, Field, RootModel
class TripoModelVersion(str, Enum):
v3_1_20260211 = 'v3.1-20260211'
v3_0_20250812 = 'v3.0-20250812'
v2_5_20250123 = 'v2.5-20250123'
v2_0_20240919 = 'v2.0-20240919'
v1_4_20240625 = 'v1.4-20240625'
v3_1_20260211 = "v3.1-20260211"
v3_0_20250812 = "v3.0-20250812"
v2_5_20250123 = "v2.5-20250123"
v2_0_20240919 = "v2.0-20240919"
v1_4_20240625 = "v1.4-20240625"
class TripoGeometryQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
standard = "standard"
detailed = "detailed"
class TripoTextureQuality(str, Enum):
standard = 'standard'
detailed = 'detailed'
standard = "standard"
detailed = "detailed"
class TripoStyle(str, Enum):
@ -33,6 +33,7 @@ class TripoStyle(str, Enum):
ANCIENT_BRONZE = "ancient_bronze"
NONE = "None"
class TripoTaskType(str, Enum):
TEXT_TO_MODEL = "text_to_model"
IMAGE_TO_MODEL = "image_to_model"
@ -45,26 +46,27 @@ class TripoTaskType(str, Enum):
STYLIZE_MODEL = "stylize_model"
CONVERT_MODEL = "convert_model"
class TripoTextureAlignment(str, Enum):
ORIGINAL_IMAGE = "original_image"
GEOMETRY = "geometry"
class TripoOrientation(str, Enum):
ALIGN_IMAGE = "align_image"
DEFAULT = "default"
class TripoOutFormat(str, Enum):
GLB = "glb"
FBX = "fbx"
class TripoTopology(str, Enum):
BIP = "bip"
QUAD = "quad"
class TripoSpec(str, Enum):
MIXAMO = "mixamo"
TRIPO = "tripo"
class TripoAnimation(str, Enum):
IDLE = "preset:idle"
WALK = "preset:walk"
@ -83,11 +85,6 @@ class TripoAnimation(str, Enum):
SERPENTINE_MARCH = "preset:serpentine:march"
AQUATIC_MARCH = "preset:aquatic:march"
class TripoStylizeStyle(str, Enum):
LEGO = "lego"
VOXEL = "voxel"
VORONOI = "voronoi"
MINECRAFT = "minecraft"
class TripoConvertFormat(str, Enum):
GLTF = "GLTF"
@ -97,6 +94,7 @@ class TripoConvertFormat(str, Enum):
STL = "STL"
_3MF = "3MF"
class TripoTextureFormat(str, Enum):
BMP = "BMP"
DPX = "DPX"
@ -108,6 +106,7 @@ class TripoTextureFormat(str, Enum):
TIFF = "TIFF"
WEBP = "WEBP"
class TripoTaskStatus(str, Enum):
QUEUED = "queued"
RUNNING = "running"
@ -118,183 +117,223 @@ class TripoTaskStatus(str, Enum):
BANNED = "banned"
EXPIRED = "expired"
class TripoFbxPreset(str, Enum):
BLENDER = "blender"
MIXAMO = "mixamo"
_3DSMAX = "3dsmax"
class TripoFileTokenReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
type: str | None = Field(None, description="The type of the reference")
file_token: str
class TripoUrlReference(BaseModel):
type: Optional[str] = Field(None, description='The type of the reference')
type: str | None = Field(None, description="The type of the reference")
url: str
class TripoObjectStorage(BaseModel):
bucket: str
key: str
class TripoObjectReference(BaseModel):
type: str
object: TripoObjectStorage
class TripoFileEmptyReference(BaseModel):
pass
class TripoFileReference(RootModel):
root: TripoFileTokenReference | TripoUrlReference | TripoObjectReference | TripoFileEmptyReference
class TripoGetStsTokenRequest(BaseModel):
format: str = Field(..., description='The format of the image')
class TripoTextToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
image_seed: Optional[int] = Field(None, description='The seed for the text')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
style: Optional[TripoStyle] = None
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description="Type of task")
prompt: str = Field(..., description="The text prompt describing the model to generate", max_length=1024)
negative_prompt: str | None = Field(None, description="The negative text prompt", max_length=1024)
model_version: TripoModelVersion | None = TripoModelVersion.v2_5_20250123
face_limit: int | None = Field(None, description="The number of faces to limit the generation to")
texture: bool | None = Field(True, description="Whether to apply texture to the generated model")
pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model")
image_seed: int | None = Field(None, description="The seed for the text")
model_seed: int | None = Field(None, description="The seed for the model")
texture_seed: int | None = Field(None, description="The seed for the texture")
texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard
geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard
style: TripoStyle | None = None
auto_size: bool | None = Field(False, description="Whether to auto-size the model")
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
class TripoImageToModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task')
file: TripoFileReference = Field(..., description='The file reference to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description="Type of task")
file: TripoFileReference = Field(..., description="The file reference to convert to a model")
model_version: TripoModelVersion | None = Field(None, description="The model version to use for generation")
face_limit: int | None = Field(None, description="The number of faces to limit the generation to")
texture: bool | None = Field(True, description="Whether to apply texture to the generated model")
pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model")
model_seed: int | None = Field(None, description="The seed for the model")
texture_seed: int | None = Field(None, description="The seed for the texture")
texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard
geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard
texture_alignment: TripoTextureAlignment | None = Field(
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
)
style: TripoStyle | None = Field(None, description="The style to apply to the generated model")
auto_size: bool | None = Field(False, description="Whether to auto-size the model")
orientation: TripoOrientation | None = TripoOrientation.DEFAULT
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
class TripoMultiviewToModelRequest(BaseModel):
type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
files: list[TripoFileReference] = Field(..., description='The file references to convert to a model')
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
files: list[TripoFileReference] = Field(..., description="The file references to convert to a model")
model_version: TripoModelVersion | None = Field(None, description="The model version to use for generation")
orthographic_projection: bool | None = Field(False, description="Whether to use orthographic projection")
face_limit: int | None = Field(None, description="The number of faces to limit the generation to")
texture: bool | None = Field(True, description="Whether to apply texture to the generated model")
pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model")
model_seed: int | None = Field(None, description="The seed for the model")
texture_seed: int | None = Field(None, description="The seed for the texture")
texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard
geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard
texture_alignment: TripoTextureAlignment | None = TripoTextureAlignment.ORIGINAL_IMAGE
auto_size: bool | None = Field(False, description="Whether to auto-size the model")
orientation: TripoOrientation | None = Field(TripoOrientation.DEFAULT, description="The orientation for the model")
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
class TripoTextureModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
texture: Optional[bool] = Field(True, description='Whether to apply texture to the model')
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model')
model_seed: Optional[int] = Field(None, description='The seed for the model')
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture')
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task")
original_model_task_id: str = Field(..., description="The task ID of the original model")
texture: bool | None = Field(True, description="Whether to apply texture to the model")
pbr: bool | None = Field(True, description="Whether to apply PBR to the model")
model_seed: int | None = Field(None, description="The seed for the model")
texture_seed: int | None = Field(None, description="The seed for the texture")
texture_quality: TripoTextureQuality | None = Field(None, description="The quality of the texture")
texture_alignment: TripoTextureAlignment | None = Field(
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
)
class TripoRefineModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task')
draft_model_task_id: str = Field(..., description='The task ID of the draft model')
type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description="Type of task")
draft_model_task_id: str = Field(..., description="The task ID of the draft model")
class TripoAnimatePrerigcheckRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
class TripoAnimateRigRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging')
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description="Type of task")
original_model_task_id: str = Field(..., description="The task ID of the original model")
out_format: TripoOutFormat | None = Field(TripoOutFormat.GLB, description="The output format")
spec: TripoSpec | None = Field(TripoSpec.TRIPO, description="The specification for rigging")
class TripoAnimateRetargetRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task')
original_model_task_id: str = Field(..., description='The task ID of the original model')
animation: TripoAnimation = Field(..., description='The animation to apply')
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation')
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description="Type of task")
original_model_task_id: str = Field(..., description="The task ID of the original model")
animation: TripoAnimation = Field(..., description="The animation to apply")
out_format: TripoOutFormat | None = Field(TripoOutFormat.GLB, description="The output format")
bake_animation: bool | None = Field(True, description="Whether to bake the animation")
class TripoStylizeModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task')
style: TripoStylizeStyle = Field(..., description='The style to apply to the model')
original_model_task_id: str = Field(..., description='The task ID of the original model')
block_size: Optional[int] = Field(80, description='The block size for stylization')
class TripoConvertModelRequest(BaseModel):
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
format: TripoConvertFormat = Field(..., description='The format to convert to')
original_model_task_id: str = Field(..., description='The task ID of the original model')
quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
texture_size: Optional[int] = Field(None, description='The size of the texture')
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
bake: Optional[bool] = Field(None, description='Whether to bake the model')
part_names: Optional[list[str]] = Field(None, description='The names of the parts to include')
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description="Type of task")
format: TripoConvertFormat = Field(..., description="The format to convert to")
original_model_task_id: str = Field(..., description="The task ID of the original model")
quad: bool | None = Field(None, description="Whether to apply quad to the model")
force_symmetry: bool | None = Field(None, description="Whether to force symmetry")
face_limit: int | None = Field(None, description="The number of faces to limit the conversion to")
flatten_bottom: bool | None = Field(None, description="Whether to flatten the bottom of the model")
flatten_bottom_threshold: float | None = Field(None, description="The threshold for flattening the bottom")
texture_size: int | None = Field(None, description="The size of the texture")
texture_format: TripoTextureFormat | None = Field(TripoTextureFormat.JPEG, description="The format of the texture")
pivot_to_center_bottom: bool | None = Field(None, description="Whether to pivot to the center bottom")
scale_factor: float | None = Field(None, description="The scale factor for the model")
with_animation: bool | None = Field(None, description="Whether to include animations")
pack_uv: bool | None = Field(None, description="Whether to pack the UVs")
bake: bool | None = Field(None, description="Whether to bake the model")
part_names: list[str] | None = Field(None, description="The names of the parts to include")
fbx_preset: TripoFbxPreset | None = Field(None, description="The preset for the FBX export")
export_vertex_colors: bool | None = Field(None, description="Whether to export the vertex colors")
export_orientation: TripoOrientation | None = Field(None, description="The orientation for the export")
animate_in_place: bool | None = Field(None, description="Whether to animate in place")
class TripoP1CommonRequest(BaseModel):
"""Fields supported by Tripo P1 across all input types."""
model_version: str = Field("P1-20260311")
model_seed: int | None = Field(None, description="Random seed for geometry generation")
face_limit: int | None = Field(None, ge=48, le=20000, description="Target face count (48-20000)")
texture: bool | None = Field(None, description="Enable texturing; pbr=True forces this true")
pbr: bool | None = Field(None, description="Enable PBR maps; when true, texture is also enabled")
texture_seed: int | None = Field(None, description="Random seed for texture generation")
texture_quality: str | None = Field(None, description='"standard" or "detailed"')
auto_size: bool | None = Field(None, description="Scale to real-world meters")
compress: str | None = Field(None, description='Only "geometry" is supported')
export_uv: bool | None = Field(None, description="Perform UV unwrapping during generation")
class TripoP1TextToModelRequest(TripoP1CommonRequest):
type: str = "text_to_model"
prompt: str = Field(..., max_length=1024)
negative_prompt: str | None = Field(None, max_length=255)
image_seed: int | None = None
class TripoP1ImageToModelRequest(TripoP1CommonRequest):
type: str = "image_to_model"
file: TripoFileReference
enable_image_autofix: bool | None = None
texture_alignment: str | None = Field(None, description='"original_image" or "geometry"')
orientation: str | None = Field(None, description='"default" or "align_image"; needs texture=true')
class TripoP1MultiviewToModelRequest(TripoP1CommonRequest):
"""P1 multiview generation.
Tripo requires `files` to be exactly four entries in [front, left, back, right] order with `{}`
(TripoFileEmptyReference) for omitted slots; front is required and at least two images total must be provided.
"""
type: str = "multiview_to_model"
files: list[TripoFileReference]
texture_alignment: str | None = None
orientation: str | None = None
class TripoTaskOutput(BaseModel):
model: Optional[str] = Field(None, description='URL to the model')
base_model: Optional[str] = Field(None, description='URL to the base model')
pbr_model: Optional[str] = Field(None, description='URL to the PBR model')
rendered_image: Optional[str] = Field(None, description='URL to the rendered image')
riggable: Optional[bool] = Field(None, description='Whether the model is riggable')
model: str | None = Field(None, description="URL to the model")
base_model: str | None = Field(None, description="URL to the base model")
pbr_model: str | None = Field(None, description="URL to the PBR model")
rendered_image: str | None = Field(None, description="URL to the rendered image")
riggable: bool | None = Field(None, description="Whether the model is riggable")
class TripoTask(BaseModel):
task_id: str = Field(..., description='The task ID')
type: Optional[str] = Field(None, description='The type of task')
status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
input: Optional[dict[str, Any]] = Field(None, description='The input parameters for the task')
output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
create_time: Optional[int] = Field(None, description='The creation time of the task')
running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
queue_position: Optional[int] = Field(None, description='The position in the queue')
task_id: str = Field(..., description="The task ID")
type: str | None = Field(None, description="The type of task")
status: TripoTaskStatus | None = Field(None, description="The status of the task")
input: dict[str, Any] | None = Field(None, description="The input parameters for the task")
output: TripoTaskOutput | None = Field(None, description="The output of the task")
progress: int | None = Field(None, description="The progress of the task", ge=0, le=100)
create_time: int | None = Field(None, description="The creation time of the task")
running_left_time: int | None = Field(None, description="The estimated time left for the task")
queue_position: int | None = Field(None, description="The position in the queue")
consumed_credit: int | None = Field(None)
class TripoTaskResponse(BaseModel):
code: int = Field(0, description='The response code')
data: TripoTask = Field(..., description='The task data')
code: int = Field(0, description="The response code")
data: TripoTask = Field(..., description="The task data")
class TripoGeneralResponse(BaseModel):
code: int = Field(0, description='The response code')
data: dict[str, str] = Field(..., description='The task ID data')
class TripoBalanceData(BaseModel):
balance: float = Field(..., description='The account balance')
frozen: float = Field(..., description='The frozen balance')
class TripoBalanceResponse(BaseModel):
code: int = Field(0, description='The response code')
data: TripoBalanceData = Field(..., description='The balance data')
class TripoErrorResponse(BaseModel):
code: int = Field(..., description='The error code')
message: str = Field(..., description='The error message')
suggestion: str = Field(..., description='The suggestion for fixing the error')
code: int = Field(..., description="The error code")
message: str = Field(..., description="The error message")
suggestion: str = Field(..., description="The suggestion for fixing the error")

View File

@ -155,7 +155,7 @@ class ClaudeNode(IO.ComfyNode):
return IO.Schema(
node_id="ClaudeNode",
display_name="Anthropic Claude",
category="api node/text/Anthropic",
category="partner/text/Anthropic",
essentials_category="Text Generation",
description="Generate text responses with Anthropic's Claude models. "
"Provide a text prompt and optionally one or more images for multimodal context.",

View File

@ -0,0 +1,404 @@
from fractions import Fraction
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types
from comfy_api_nodes.apis.beeble import (
CreateSwitchXRequest,
SwitchXStatusResponse,
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
convert_mask_to_image,
download_url_as_bytesio,
download_url_to_image_tensor,
download_url_to_video_output,
downscale_image_tensor,
downscale_video_to_max_pixels,
poll_op,
sync_op,
upload_image_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
validate_video_frame_count,
)
_MAX_PIXELS = 2_770_000
_MAX_FRAMES = 240
_MAX_PROMPT_LEN = 2000
def _validate_inputs(prompt: str | None, reference_image: Input.Image | None) -> str | None:
"""Beeble requires at least one of prompt or reference_image. Returns the cleaned prompt."""
cleaned = prompt.strip() if prompt else ""
if not cleaned and reference_image is None:
raise ValueError("At least one of 'prompt' or 'reference_image' must be provided.")
if cleaned:
validate_string(cleaned, strip_whitespace=False, max_length=_MAX_PROMPT_LEN)
return cleaned or None
async def _upload_mask_as_image(
cls: type[IO.ComfyNode],
mask: Input.Image,
*,
wait_label: str,
) -> str:
"""Encode a single-frame MASK (H, W) or (1, H, W) as a PNG and upload."""
if mask.dim() == 2:
mask = mask.unsqueeze(0)
image = convert_mask_to_image(mask[:1])
return await upload_image_to_comfyapi(
cls,
image,
mime_type="image/png",
wait_label=wait_label,
total_pixels=_MAX_PIXELS,
)
async def _upload_mask_batch_as_video(
cls: type[IO.ComfyNode],
mask: Input.Image,
*,
frame_rate: Fraction,
source_frame_count: int,
wait_label: str,
) -> str:
"""Encode a MASK batch (N, H, W) as a grayscale H.264 MP4 at frame_rate and upload.
The matte is always downscaled to the pixel budget so it stays within Beeble's limit and
keeps the same dimensions as the (similarly downscaled) source both use the same algorithm
from the same starting dimensions, and downscaling is a no-op when already within budget.
"""
if mask.dim() == 2:
mask = mask.unsqueeze(0)
if mask.shape[0] != source_frame_count:
raise ValueError(
f"Custom alpha video frame count ({mask.shape[0]}) does not match the "
f"source video frame count ({source_frame_count}). The Beeble API requires "
"one mask per source frame."
)
images = downscale_image_tensor(convert_mask_to_image(mask), _MAX_PIXELS)
alpha_video = InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=None, frame_rate=frame_rate))
return await upload_video_to_comfyapi(cls, alpha_video, wait_label=wait_label)
def _alpha_mode_input(*, video: bool) -> IO.DynamicCombo.Input:
"""Build the alpha_mode DynamicCombo with mode-specific extra inputs."""
select_keyframe_tooltip = (
"First-frame keyframe mask. Beeble propagates this across the video." if video else "Grayscale keyframe mask."
)
custom_tooltip = (
"Per-frame grayscale mask covering the entire video. "
"Must have the same frame count as the source. "
"Connect a MASK output from SAM3_TrackToMask or similar."
if video
else "Grayscale mask to apply."
)
return IO.DynamicCombo.Input(
"alpha_mode",
tooltip=(
"Controls how SwitchX decides what to keep vs. regenerate. "
"'auto' isolates the main subject automatically. "
"'fill' regenerates the entire frame while preserving geometry. "
"'select' propagates a first-frame keyframe across the clip. "
"'custom' uses a per-frame alpha matte you provide."
),
options=[
IO.DynamicCombo.Option("auto", []),
IO.DynamicCombo.Option("fill", []),
IO.DynamicCombo.Option(
"select",
[IO.Mask.Input("alpha_keyframe", tooltip=select_keyframe_tooltip)],
),
IO.DynamicCombo.Option(
"custom",
[IO.Mask.Input("alpha_mask", tooltip=custom_tooltip)],
),
],
)
def _common_inputs(*, source: IO.Input, video: bool) -> list[IO.Input]:
return [
source,
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip=(
"Text description of the desired output (max 2000 chars). "
"At least one of 'prompt' or 'reference_image' is required."
),
),
IO.Image.Input(
"reference_image",
optional=True,
tooltip=(
"Reference image whose look (background, lighting, costume) the result "
"should adopt. At least one of 'reference_image' or 'prompt' is required."
),
),
_alpha_mode_input(video=video),
IO.Combo.Input(
"max_resolution",
options=["1080p", "720p"],
default="1080p",
tooltip="Maximum output resolution.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip=(
"Seed controls whether the node should re-run; " "results are non-deterministic regardless of seed."
),
),
]
async def _submit_and_poll(
cls: type[IO.ComfyNode],
request: CreateSwitchXRequest,
) -> SwitchXStatusResponse:
initial = await sync_op(
cls,
ApiEndpoint(path="/proxy/beeble/v1/switchx/generations", method="POST"),
response_model=SwitchXStatusResponse,
data=request,
)
return await poll_op(
cls,
ApiEndpoint(path=f"/proxy/beeble/v1/switchx/generations/{initial.id}"),
response_model=SwitchXStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
)
def _require_output_url(response: SwitchXStatusResponse, name: str) -> str:
if response.output is None or getattr(response.output, name) is None:
raise RuntimeError(f"Beeble job {response.id} completed without a {name!r} output URL.")
return getattr(response.output, name)
def _alpha_url(response: SwitchXStatusResponse, mode: str) -> str | None:
"""URL of the alpha matte, or None when the mode produces no separate matte.
'fill' selects the whole frame, so Beeble writes no alpha asset even though the status
response still returns a (dangling) signed URL for it fetching it 403s with S3
AccessDenied. The other three modes ('auto', 'custom', 'select') all produce a real,
downloadable matte.
"""
if mode == "fill" or response.output is None:
return None
return response.output.alpha
class BeebleSwitchXVideoEdit(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="BeebleSwitchXVideoEdit",
display_name="Beeble SwitchX Video Edit",
category="partner/video/Beeble",
description=(
"Edit a video with Beeble SwitchX. Switches anything in the scene (background, "
"lighting, costume) while preserving the original subject's pixels and motion. "
"Provide a reference image and/or text prompt to describe the new look. "
"Max 240 frames, max ~2.77MP per frame."
),
inputs=_common_inputs(source=IO.Video.Input("video"), video=True),
outputs=[
IO.Video.Output(display_name="video"),
IO.Video.Output(
display_name="alpha",
tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.",
),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]),
expr="""
(
$rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143;
{"type":"usd","usd": $rate, "format":{"suffix":"/30 frames"}}
)
""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
prompt: str,
alpha_mode: dict,
max_resolution: str,
seed: int,
reference_image: Input.Image | None = None,
) -> IO.NodeOutput:
cleaned_prompt = _validate_inputs(prompt, reference_image)
validate_video_frame_count(video, max_frame_count=_MAX_FRAMES)
video = downscale_video_to_max_pixels(video, _MAX_PIXELS)
mode = alpha_mode["alpha_mode"]
alpha_uri: str | None = None
if mode == "select":
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe")
elif mode == "custom":
alpha_uri = await _upload_mask_batch_as_video(
cls,
alpha_mode["alpha_mask"],
frame_rate=video.get_frame_rate(),
source_frame_count=video.get_frame_count(),
wait_label="Uploading alpha video",
)
source_uri = await upload_video_to_comfyapi(cls, video, wait_label="Uploading source")
reference_uri: str | None = None
if reference_image is not None:
reference_uri = await upload_image_to_comfyapi(
cls,
reference_image,
mime_type="image/png",
wait_label="Uploading reference",
total_pixels=_MAX_PIXELS,
)
request = CreateSwitchXRequest(
generation_type="video",
source_uri=source_uri,
alpha_mode=mode,
prompt=cleaned_prompt,
reference_image_uri=reference_uri,
alpha_uri=alpha_uri,
max_resolution=1080 if max_resolution == "1080p" else 720,
)
response = await _submit_and_poll(cls, request)
render = await download_url_to_video_output(_require_output_url(response, "render"))
alpha = None
if (alpha_url := _alpha_url(response, mode)) is not None:
alpha = await download_url_to_video_output(alpha_url)
return IO.NodeOutput(render, alpha)
class BeebleSwitchXImageEdit(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="BeebleSwitchXImageEdit",
display_name="Beeble SwitchX Image Edit",
category="partner/image/Beeble",
description=(
"Edit a single image with Beeble SwitchX. Switches anything in the scene "
"(background, lighting, costume) while preserving the original subject's pixels. "
"Provide a reference image and/or text prompt to describe the new look. "
"Max ~2.77MP."
),
inputs=_common_inputs(source=IO.Image.Input("image"), video=False),
outputs=[
IO.Image.Output(display_name="image"),
IO.Mask.Output(
display_name="alpha",
tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.",
),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]),
expr="""
(
$rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143;
{"type":"usd","usd": $rate}
)
""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
prompt: str,
alpha_mode: dict,
max_resolution: str,
seed: int,
reference_image: Input.Image | None = None,
) -> IO.NodeOutput:
cleaned_prompt = _validate_inputs(prompt, reference_image)
image = downscale_image_tensor(image, _MAX_PIXELS)
mode = alpha_mode["alpha_mode"]
alpha_uri: str | None = None
if mode == "select":
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe")
elif mode == "custom":
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_mask"], wait_label="Uploading alpha")
source_uri = await upload_image_to_comfyapi(
cls,
image,
mime_type="image/png",
wait_label="Uploading source",
total_pixels=None,
)
reference_uri: str | None = None
if reference_image is not None:
reference_uri = await upload_image_to_comfyapi(
cls,
reference_image,
mime_type="image/png",
wait_label="Uploading reference",
total_pixels=_MAX_PIXELS,
)
request = CreateSwitchXRequest(
generation_type="image",
source_uri=source_uri,
alpha_mode=mode,
prompt=cleaned_prompt,
reference_image_uri=reference_uri,
alpha_uri=alpha_uri,
max_resolution=1080 if max_resolution == "1080p" else 720,
)
response = await _submit_and_poll(cls, request)
render = await download_url_to_image_tensor(_require_output_url(response, "render"))
alpha_mask = None
if (alpha_url := _alpha_url(response, mode)) is not None:
alpha_image = bytesio_to_image_tensor(await download_url_as_bytesio(alpha_url), mode="L")
alpha_mask = alpha_image.squeeze(-1) if alpha_image.dim() == 4 else alpha_image
return IO.NodeOutput(render, alpha_mask)
class BeebleExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
BeebleSwitchXVideoEdit,
BeebleSwitchXImageEdit,
]
async def comfy_entrypoint() -> BeebleExtension:
return BeebleExtension()

View File

@ -4,17 +4,20 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl import (
BFLFluxEraseRequest,
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
BFLFluxKontextProGenerateRequest,
BFLFluxProGenerateResponse,
BFLFluxProUltraGenerateRequest,
BFLFluxStatusResponse,
BFLFluxVTORequest,
BFLStatus,
Flux2ProGenerateRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor,
get_number_of_images,
poll_op,
@ -22,19 +25,11 @@ from comfy_api_nodes.util import (
sync_op,
tensor_to_base64_string,
validate_aspect_ratio_string,
validate_image_dimensions,
validate_string,
)
def convert_mask_to_image(mask: Input.Image):
"""
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
"""
mask = mask.unsqueeze(-1)
mask = torch.cat([mask] * 3, dim=-1)
return mask
class FluxProUltraImageNode(IO.ComfyNode):
@classmethod
@ -42,7 +37,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
return IO.Schema(
node_id="FluxProUltraImageNode",
display_name="Flux 1.1 [pro] Ultra Image",
category="api node/image/BFL",
category="partner/image/BFL",
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
inputs=[
IO.String.Input(
@ -160,7 +155,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
return IO.Schema(
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="api node/image/BFL",
category="partner/image/BFL",
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
inputs=[
IO.String.Input(
@ -282,7 +277,7 @@ class FluxProExpandNode(IO.ComfyNode):
return IO.Schema(
node_id="FluxProExpandNode",
display_name="Flux.1 Expand Image",
category="api node/image/BFL",
category="partner/image/BFL",
description="Outpaints image based on prompt.",
inputs=[
IO.Image.Input("image"),
@ -419,7 +414,7 @@ class FluxProFillNode(IO.ComfyNode):
return IO.Schema(
node_id="FluxProFillNode",
display_name="Flux.1 Fill Image",
category="api node/image/BFL",
category="partner/image/BFL",
description="Inpaints image based on mask and prompt.",
inputs=[
IO.Image.Input("image"),
@ -519,6 +514,174 @@ class FluxProFillNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxEraseNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxEraseNode",
display_name="Flux Erase Image",
category="partner/image/BFL",
description="Removes the masked object from an image and reconstructs the background. "
"Paint the mask over what you want to erase.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."),
IO.Int.Input(
"dilate_pixels",
default=10,
min=0,
max=25,
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
optional=True,
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
mask: Input.Image,
dilate_pixels: int = 10,
seed: int = 0,
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=256, min_height=256)
mask = resize_mask_to_image(mask, image)
mask = tensor_to_base64_string(convert_mask_to_image(mask))
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxEraseRequest(
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
mask=mask,
dilate_pixels=dilate_pixels,
seed=seed,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxVTONode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxVTONode",
display_name="Flux Virtual Try-On",
category="partner/image/BFL",
description="Virtual try-on: dresses the person in the provided garment.",
inputs=[
IO.Image.Input("person", tooltip="Image of the person to dress."),
IO.Image.Input("garment", tooltip="Image of the garment to apply."),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
person: Input.Image,
garment: Input.Image,
prompt: str = "",
seed: int = 0,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxVTORequest(
prompt=prompt,
person=tensor_to_base64_string(person[:, :, :, :3]),
garment=tensor_to_base64_string(garment[:, :, :, :3]),
seed=seed,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
@ -545,7 +708,7 @@ class Flux2ProImageNode(IO.ComfyNode):
return IO.Schema(
node_id=cls.NODE_ID,
display_name=cls.DISPLAY_NAME,
category="api node/image/BFL",
category="partner/image/BFL",
description="Generates images synchronously based on prompt and resolution.",
inputs=[
IO.String.Input(
@ -716,7 +879,7 @@ class Flux2ImageNode(IO.ComfyNode):
return IO.Schema(
node_id="Flux2ImageNode",
display_name="Flux.2 Image",
category="api node/image/BFL",
category="partner/image/BFL",
description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.",
inputs=[
IO.String.Input(
@ -853,6 +1016,8 @@ class BFLExtension(ComfyExtension):
FluxKontextMaxImageNode,
FluxProExpandNode,
FluxProFillNode,
FluxEraseNode,
FluxVTONode,
Flux2ProImageNode,
Flux2MaxImageNode,
Flux2ImageNode,

View File

@ -1,14 +1,19 @@
import av
import torch
from av.codec import CodecContext
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bria import (
BriaEditImageRequest,
BriaImageEditResponse,
BriaRemoveBackgroundRequest,
BriaRemoveBackgroundResponse,
BriaRemoveVideoBackgroundRequest,
BriaRemoveVideoBackgroundResponse,
BriaImageEditResponse,
BriaStatusResponse,
BriaVideoGreenScreenRequest,
BriaVideoReplaceBackgroundRequest,
InputModerationSettings,
)
from comfy_api_nodes.util import (
@ -31,7 +36,7 @@ class BriaImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="BriaImageEditNode",
display_name="Bria FIBO Image Edit",
category="api node/image/Bria",
category="partner/image/Bria",
description="Edit images using Bria latest model",
inputs=[
IO.Combo.Input("model", options=["FIBO"]),
@ -169,7 +174,7 @@ class BriaRemoveImageBackground(IO.ComfyNode):
return IO.Schema(
node_id="BriaRemoveImageBackground",
display_name="Bria Remove Image Background",
category="api node/image/Bria",
category="partner/image/Bria",
description="Remove the background from an image using Bria RMBG 2.0.",
inputs=[
IO.Image.Input("image"),
@ -245,7 +250,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
return IO.Schema(
node_id="BriaRemoveVideoBackground",
display_name="Bria Remove Video Background",
category="api node/video/Bria",
category="partner/video/Bria",
description="Remove the background from a video using Bria. ",
inputs=[
IO.Video.Input("video"),
@ -316,6 +321,248 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
class BriaVideoGreenScreen(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaVideoGreenScreen",
display_name="Bria Video Green Screen",
category="partner/video/Bria",
description="Replace a video's background with a solid chroma-key screen using Bria.",
inputs=[
IO.Video.Input("video"),
IO.Combo.Input(
"green_shade",
options=["broadcast_green", "chroma_green", "blue_screen"],
tooltip="Solid chroma-key shade applied behind the foreground: "
"broadcast_green (#00B140), chroma_green (#00FF00), or blue_screen (#0000FF).",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
green_shade: str,
seed: int,
) -> IO.NodeOutput:
validate_video_duration(video, max_duration=60.0)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/video/edit/green_screen", method="POST"),
data=BriaVideoGreenScreenRequest(
video=await upload_video_to_comfyapi(cls, video),
green_shade=green_shade,
output_container_and_codec="mp4_h264",
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveVideoBackgroundResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
class BriaVideoReplaceBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaVideoReplaceBackground",
display_name="Bria Video Replace Background",
category="partner/video/Bria",
description="Replace a video's background with a supplied image or video using Bria. "
"The output keeps the foreground's resolution and frame rate; a background with a "
"different aspect ratio is stretched to fit, so match it for undistorted results.",
inputs=[
IO.Video.Input("video", tooltip="Foreground video whose background is replaced."),
IO.Image.Input(
"background_image",
optional=True,
tooltip="Background image to composite behind the foreground. "
"Provide either a background image or a background video, not both.",
),
IO.Video.Input(
"background_video",
optional=True,
tooltip="Background video to composite behind the foreground. "
"Provide either a background image or a background video, not both.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
seed: int,
background_image: Input.Image | None = None,
background_video: Input.Video | None = None,
) -> IO.NodeOutput:
if (background_image is None) == (background_video is None):
raise ValueError("Provide either a background image or a background video, not both.")
validate_video_duration(video, max_duration=60.0)
if background_video is not None:
validate_video_duration(background_video, max_duration=60.0)
background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background")
else:
background_url = await upload_image_to_comfyapi(cls, background_image, wait_label="Uploading background")
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"),
data=BriaVideoReplaceBackgroundRequest(
video=await upload_video_to_comfyapi(cls, video),
background_url=background_url,
output_container_and_codec="mp4_h264",
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveVideoBackgroundResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
def _video_to_images_and_mask(video: Input.Video) -> tuple[Input.Image, Input.Mask]:
"""Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask.
VP9 keeps its alpha in a side layer that PyAV's default vp9 decoder drops, so the frames
are decoded with libvpx-vp9. Returns RGB images [B,H,W,3] in 0..1 and a mask [B,H,W]
following the Load Image convention (1 = transparent) for compositing or Save WEBM.
"""
rgb_frames: list[torch.Tensor] = []
alpha_frames: list[torch.Tensor] = []
with av.open(video.get_stream_source(), mode="r") as container:
stream = container.streams.video[0]
decoder = CodecContext.create("libvpx-vp9", "r") if stream.codec_context.name == "vp9" else None
for packet in container.demux(stream):
for frame in (decoder.decode(packet) if decoder is not None else packet.decode()):
rgba = torch.from_numpy(frame.to_ndarray(format="rgba")).float() / 255.0
rgb_frames.append(rgba[..., :3])
alpha_frames.append(rgba[..., 3])
images = torch.stack(rgb_frames) if rgb_frames else torch.zeros(0, 0, 0, 3)
mask = (1.0 - torch.stack(alpha_frames)) if alpha_frames else torch.zeros((images.shape[0], 64, 64))
return images, mask
class BriaTransparentVideoBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaTransparentVideoBackground",
display_name="Bria Remove Video Background (Transparent)",
category="partner/video/Bria",
description="Remove the background from a video using Bria and return the cut-out frames "
"plus an alpha mask. Connect both to a compositing node, or feed them to Save WEBM to "
"write a transparent video.",
inputs=[
IO.Video.Input("video"),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.Image.Output(display_name="images"),
IO.Mask.Output(display_name="mask"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
seed: int,
) -> IO.NodeOutput:
validate_video_duration(video, max_duration=60.0)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
data=BriaRemoveVideoBackgroundRequest(
video=await upload_video_to_comfyapi(cls, video),
background_color="Transparent",
output_container_and_codec="webm_vp9",
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveVideoBackgroundResponse,
)
video_out = await download_url_to_video_output(response.result.video_url)
images, mask = _video_to_images_and_mask(video_out)
return IO.NodeOutput(images, mask)
class BriaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -323,6 +570,9 @@ class BriaExtension(ComfyExtension):
BriaImageEditNode,
BriaRemoveImageBackground,
BriaRemoveVideoBackground,
BriaVideoGreenScreen,
# BriaVideoReplaceBackground, # server returns Status 500 when we pass background video
BriaTransparentVideoBackground,
]

View File

@ -2,11 +2,13 @@ import hashlib
import logging
import math
import re
from io import BytesIO
import torch
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy.utils import common_upscale
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS,
RECOMMENDED_PRESETS_SEEDREAM_4,
@ -43,6 +45,7 @@ from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
download_url_to_video_output,
downscale_image_tensor_by_max_side,
downscale_video_to_max_pixels,
get_number_of_images,
image_tensor_pair_to_batch,
@ -121,6 +124,52 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st
)
def _prepare_seedance_image(image: Input.Image) -> Input.Image:
"""Auto-downscale a Seedance image input to the per-side limits, then validate it."""
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
image = downscale_image_tensor_by_max_side(image, max_side=6000)
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
return image
# Supported output aspect ratios, used to pre-size FLF frames to matching pixel pair to avoid the 1080p stretch jump.
SEEDANCE2_RATIO_WH = {
"16:9": (16, 9),
"4:3": (4, 3),
"1:1": (1, 1),
"3:4": (3, 4),
"9:16": (9, 16),
"21:9": (21, 9),
}
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080}
def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]:
"""Exact supported output (width, height) for (resolution, ratio).
The shorter side equals the resolution number (e.g. 1080p 16:9 -> 1920x1080). For ratio
"adaptive" (or any unexpected value) the ratio is derived from the image's own aspect, snapped
to the nearest supported ratio, so the output keeps the frame's orientation.
"""
short = SEEDANCE2_RES_SHORT_SIDE[resolution]
if ratio not in SEEDANCE2_RATIO_WH:
aspect = image.shape[-2] / image.shape[-3] # W / H; tensor is (B, H, W, C)
ratio = min(SEEDANCE2_RATIO_WH, key=lambda k: abs(SEEDANCE2_RATIO_WH[k][0] / SEEDANCE2_RATIO_WH[k][1] - aspect))
rw, rh = SEEDANCE2_RATIO_WH[ratio]
if rw >= rh: # landscape or square: shorter side is the height
out_w, out_h = round(short * rw / rh), short
else: # portrait: shorter side is the width
out_w, out_h = short, round(short * rh / rw)
return out_w - out_w % 2, out_h - out_h % 2
def _resize_to_exact(image: torch.Tensor, width: int, height: int) -> torch.Tensor:
"""Center-crop to the target aspect and resize to exactly width x height (lanczos)."""
samples = image.movedim(-1, 1) # (B, H, W, C) -> (B, C, H, W)
resized = common_upscale(samples, width, height, "lanczos", "center")
return resized.movedim(1, -1)
async def _resolve_reference_assets(
cls: type[IO.ComfyNode],
asset_ids: list[str],
@ -308,6 +357,26 @@ async def _seedance_virtual_library_upload_image_asset(
return f"asset://{create_resp.asset_id}"
async def _seedance_virtual_library_upload_video_asset(
cls: type[IO.ComfyNode],
video: Input.Video,
*,
wait_label: str = "Uploading video",
) -> str:
buf = BytesIO()
video.save_to(buf, format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264)
video_hash = hashlib.sha256(buf.getbuffer()).hexdigest()
public_url = await upload_video_to_comfyapi(cls, video, wait_label=wait_label)
create_resp = await sync_op(
cls,
ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST"),
response_model=SeedanceCreateAssetResponse,
data=SeedanceVirtualLibraryCreateAssetRequest(url=public_url, hash=video_hash, asset_type="Video"),
)
await _wait_for_asset_active(cls, create_resp.asset_id, group_id="virtual-library")
return f"asset://{create_resp.asset_id}"
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
@ -338,7 +407,7 @@ class ByteDanceImageNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceImageNode",
display_name="ByteDance Image",
category="api node/image/ByteDance",
category="partner/image/ByteDance",
description="Generate images using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
@ -462,7 +531,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceSeedreamNode",
display_name="ByteDance Seedream 4.5 & 5.0",
category="api node/image/ByteDance",
category="partner/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[
IO.Combo.Input(
@ -724,7 +793,7 @@ class ByteDanceSeedreamNodeV2(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceSeedreamNodeV2",
display_name="ByteDance Seedream 4.5 & 5.0",
category="api node/image/ByteDance",
category="partner/image/ByteDance",
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
inputs=[
IO.String.Input(
@ -890,7 +959,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceTextToVideoNode",
display_name="ByteDance Text to Video",
category="api node/video/ByteDance",
category="partner/video/ByteDance",
description="Generate video using ByteDance models via api based on prompt",
inputs=[
IO.Combo.Input(
@ -1018,7 +1087,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceImageToVideoNode",
display_name="ByteDance Image to Video",
category="api node/video/ByteDance",
category="partner/video/ByteDance",
description="Generate video using ByteDance models via api based on image and prompt",
inputs=[
IO.Combo.Input(
@ -1155,7 +1224,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceFirstLastFrameNode",
display_name="ByteDance First-Last-Frame to Video",
category="api node/video/ByteDance",
category="partner/video/ByteDance",
description="Generate video using prompt and first and last frames.",
inputs=[
IO.Combo.Input(
@ -1303,7 +1372,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceImageReferenceNode",
display_name="ByteDance Reference Images to Video",
category="api node/video/ByteDance",
category="partner/video/ByteDance",
description="Generate video using prompt and reference images.",
inputs=[
IO.Combo.Input(
@ -1546,7 +1615,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDance2TextToVideoNode",
display_name="ByteDance Seedance 2.0 Text to Video",
category="api node/video/ByteDance",
category="partner/video/ByteDance",
description="Generate video using Seedance 2.0 models based on a text prompt.",
inputs=[
IO.DynamicCombo.Input(
@ -1647,7 +1716,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDance2FirstLastFrameNode",
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
category="api node/video/ByteDance",
category="partner/video/ByteDance",
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
inputs=[
IO.DynamicCombo.Input(
@ -1760,6 +1829,29 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
if last_frame is not None and last_frame_asset_id:
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
request_ratio = model["ratio"]
if first_frame_asset_id or last_frame_asset_id:
if first_frame is not None:
first_frame = _prepare_seedance_image(first_frame)
if last_frame is not None:
last_frame = _prepare_seedance_image(last_frame)
else:
# The 1080p FLF stretch fix (pre-size frames to a supported pixel pair + submit ratio="adaptive")
# only applies to local image inputs we can resize.
request_ratio = "adaptive"
target_dims: tuple[int, int] | None = None
if first_frame is not None:
validate_image_aspect_ratio(first_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_dimensions(first_frame, min_width=300, min_height=300)
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], first_frame)
first_frame = _resize_to_exact(first_frame, *target_dims)
if last_frame is not None:
validate_image_aspect_ratio(last_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_dimensions(last_frame, min_width=300, min_height=300)
if target_dims is None:
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], last_frame)
last_frame = _resize_to_exact(last_frame, *target_dims)
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
image_assets: dict[str, str] = {}
if asset_ids_to_resolve:
@ -1809,7 +1901,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
content=content,
generate_audio=model["generate_audio"],
resolution=model["resolution"],
ratio=model["ratio"],
ratio=request_ratio,
duration=model["duration"],
seed=seed,
watermark=watermark,
@ -1866,7 +1958,7 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16
),
IO.Boolean.Input(
"auto_downscale",
default=False,
default=True,
optional=True,
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
@ -1909,7 +2001,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDance2ReferenceNode",
display_name="ByteDance Seedance 2.0 Reference to Video",
category="api node/video/ByteDance",
category="partner/video/ByteDance",
description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
"videos, and audio. Supports multimodal reference, video editing, and video extension.",
inputs=[
@ -2034,6 +2126,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3."
)
for key in reference_images:
reference_images[key] = _prepare_seedance_image(reference_images[key])
model_id = SEEDANCE_MODELS[model["model"]]
has_video_input = total_videos > 0
@ -2106,7 +2201,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
content.append(
TaskVideoContent(
video_url=TaskVideoContentUrl(
url=await upload_video_to_comfyapi(
url=await _seedance_virtual_library_upload_video_asset(
cls,
reference_videos[key],
wait_label=f"Uploading video {i}",
@ -2203,7 +2298,7 @@ class ByteDanceCreateImageAsset(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceCreateImageAsset",
display_name="ByteDance Create Image Asset",
category="api node/image/ByteDance",
category="partner/image/ByteDance",
description=(
"Create a Seedance 2.0 personal image asset. Uploads the input image and "
"registers it in the given asset group. If group_id is empty, runs a real-person "
@ -2270,7 +2365,7 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceCreateVideoAsset",
display_name="ByteDance Create Video Asset",
category="api node/video/ByteDance",
category="partner/video/ByteDance",
description=(
"Create a Seedance 2.0 personal video asset. Uploads the input video and "
"registers it in the given asset group. If group_id is empty, runs a real-person "

View File

@ -144,7 +144,7 @@ class ByteDanceSeedNode(IO.ComfyNode):
return IO.Schema(
node_id="ByteDanceSeedNode",
display_name="ByteDance Seed",
category="api node/text/ByteDance",
category="partner/text/ByteDance",
essentials_category="Text Generation",
description="Generate text responses with ByteDance's Seed 2.0 models. "
"Provide a text prompt and optionally one or more images or videos for multimodal context.",

View File

@ -69,7 +69,7 @@ class ElevenLabsSpeechToText(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsSpeechToText",
display_name="ElevenLabs Speech to Text",
category="api node/audio/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Transcribe audio to text. "
"Supports automatic language detection, speaker diarization, and audio event tagging.",
inputs=[
@ -210,7 +210,7 @@ class ElevenLabsVoiceSelector(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsVoiceSelector",
display_name="ElevenLabs Voice Selector",
category="api node/audio/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Select a predefined ElevenLabs voice for text-to-speech generation.",
inputs=[
IO.Combo.Input(
@ -239,7 +239,7 @@ class ElevenLabsTextToSpeech(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsTextToSpeech",
display_name="ElevenLabs Text to Speech",
category="api node/audio/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Convert text to speech.",
inputs=[
IO.Custom(ELEVENLABS_VOICE).Input(
@ -414,7 +414,7 @@ class ElevenLabsAudioIsolation(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsAudioIsolation",
display_name="ElevenLabs Voice Isolation",
category="api node/audio/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Remove background noise from audio, isolating vocals or speech.",
inputs=[
IO.Audio.Input(
@ -459,7 +459,7 @@ class ElevenLabsTextToSoundEffects(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsTextToSoundEffects",
display_name="ElevenLabs Text to Sound Effects",
category="api node/audio/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Generate sound effects from text descriptions.",
inputs=[
IO.String.Input(
@ -555,7 +555,7 @@ class ElevenLabsInstantVoiceClone(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsInstantVoiceClone",
display_name="ElevenLabs Instant Voice Clone",
category="api node/audio/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Create a cloned voice from audio samples. "
"Provide 1-8 audio recordings of the voice to clone.",
inputs=[
@ -658,7 +658,7 @@ class ElevenLabsSpeechToSpeech(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsSpeechToSpeech",
display_name="ElevenLabs Speech to Speech",
category="api node/audio/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Transform speech from one voice to another while preserving the original content and emotion.",
inputs=[
IO.Custom(ELEVENLABS_VOICE).Input(
@ -793,7 +793,7 @@ class ElevenLabsTextToDialogue(IO.ComfyNode):
return IO.Schema(
node_id="ElevenLabsTextToDialogue",
display_name="ElevenLabs Text to Dialogue",
category="api node/audio/ElevenLabs",
category="partner/audio/ElevenLabs",
description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.",
inputs=[
IO.Float.Input(

View File

@ -8,7 +8,7 @@ import os
from enum import Enum
from fnmatch import fnmatch
from io import BytesIO
from typing import Literal
from typing import Any, Literal
import torch
from typing_extensions import override
@ -19,6 +19,7 @@ from comfy_api_nodes.apis.gemini import (
GeminiContent,
GeminiFileData,
GeminiGenerateContentRequest,
GeminiGenerationConfig,
GeminiGenerateContentResponse,
GeminiImageConfig,
GeminiImageGenerateContentRequest,
@ -40,13 +41,18 @@ from comfy_api_nodes.util import (
get_number_of_images,
sync_op,
tensor_to_base64_string,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
video_to_base64_string,
)
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
GEMINI_URL_INPUT_BUDGET = 10
GEMINI_MAX_INLINE_BYTES = 18 * 1024 * 1024
GEMINI_IMAGE_SYS_PROMPT = (
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
"Interpret all user input—regardless of "
@ -285,6 +291,140 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
return final_price / 1_000_000.0
def create_video_parts(video_input: Input.Video) -> list[GeminiPart]:
"""Convert a single video input to Gemini API compatible parts (inline MP4/H.264)."""
base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.video_mp4,
data=base_64_string,
)
)
]
def create_audio_parts(audio_input: Input.Audio) -> list[GeminiPart]:
"""Convert an audio input to Gemini API compatible parts (one inline MP3 part per batch item)."""
audio_parts: list[GeminiPart] = []
for batch_index in range(audio_input["waveform"].shape[0]):
# Recreate an IO.AUDIO object for the given batch dimension index
audio_at_index = Input.Audio(
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
sample_rate=audio_input["sample_rate"],
)
# Convert to MP3 format for compatibility with Gemini API
audio_bytes = audio_to_base64_string(
audio_at_index,
container_format="mp3",
codec_name="libmp3lame",
)
audio_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.audio_mp3,
data=audio_bytes,
)
)
)
return audio_parts
def _flatten_images(images: list[Input.Image]) -> list[torch.Tensor]:
"""Expand any batched image tensors into individual (H, W, C) frames, preserving order."""
frames: list[torch.Tensor] = []
for img in images:
if len(img.shape) == 4:
frames.extend(img[i] for i in range(img.shape[0]))
else:
frames.append(img)
return frames
def _flatten_audio(audios: list[Input.Audio]) -> list[Input.Audio]:
"""Expand any batched audio inputs into individual single-clip audio inputs, preserving order."""
clips: list[Input.Audio] = []
for audio in audios:
waveform = audio["waveform"]
for i in range(waveform.shape[0]):
clips.append(Input.Audio(waveform=waveform[i].unsqueeze(0), sample_rate=audio["sample_rate"]))
return clips
async def _media_url_part(cls: type[IO.ComfyNode], kind: str, payload: Any) -> GeminiPart:
"""Upload a single media unit to ComfyAPI storage and return a fileData (URL) part."""
if kind == "image":
url = await upload_image_to_comfyapi(cls, payload, mime_type="image/png", wait_label="Uploading image")
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.image_png, fileUri=url))
if kind == "audio":
url = await upload_audio_to_comfyapi(
cls, payload, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mp3"
)
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.audio_mp3, fileUri=url))
url = await upload_video_to_comfyapi(cls, payload, wait_label="Uploading video")
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.video_mp4, fileUri=url))
def _media_inline_part(kind: str, payload: Any) -> tuple[GeminiPart, int]:
"""Encode a single media unit as an inline base64 part; returns (part, base64_length)."""
if kind == "image":
data = tensor_to_base64_string(payload, mime_type="image/webp")
mime = GeminiMimeType.image_webp
elif kind == "audio":
data = audio_to_base64_string(payload, container_format="mp3", codec_name="libmp3lame")
mime = GeminiMimeType.audio_mp3
else:
data = video_to_base64_string(
payload, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
mime = GeminiMimeType.video_mp4
return GeminiPart(inlineData=GeminiInlineData(mimeType=mime, data=data)), len(data)
async def build_gemini_media_parts(
cls: type[IO.ComfyNode],
images: list[Input.Image],
audios: list[Input.Audio],
videos: list[Input.Video],
*,
url_budget: int = GEMINI_URL_INPUT_BUDGET,
max_inline_bytes: int = GEMINI_MAX_INLINE_BYTES,
) -> list[GeminiPart]:
"""Build Gemini parts for multimodal inputs (images, audio, video).
fileData URLs are preferred for every media type: the upload is fetched directly by the
model, keeping the request body tiny regardless of media size. The URL budget is shared
across all media and assigned largest-first (video, then audio, then images), so that if it
is ever exhausted the inline-base64 overflow is limited to the smallest items. Total inline
payload is capped by `max_inline_bytes`.
"""
units: list[tuple[str, Any]] = (
[("video", v) for v in videos]
+ [("audio", a) for a in _flatten_audio(audios)]
+ [("image", f) for f in _flatten_images(images)]
)
parts: list[GeminiPart] = []
url_used = 0
inline_bytes = 0
for kind, payload in units:
if url_used < url_budget:
parts.append(await _media_url_part(cls, kind, payload))
url_used += 1
continue
part, nbytes = _media_inline_part(kind, payload)
inline_bytes += nbytes
if inline_bytes > max_inline_bytes:
raise ValueError(
f"Too much media to send inline (over {max_inline_bytes // (1024 * 1024)}MB after the first "
f"{url_budget} inputs are uploaded as URLs). Reduce the number or size of attached media."
)
parts.append(part)
return parts
class GeminiNode(IO.ComfyNode):
"""
Node to generate text responses from a Gemini model.
@ -300,7 +440,7 @@ class GeminiNode(IO.ComfyNode):
return IO.Schema(
node_id="GeminiNode",
display_name="Google Gemini",
category="api node/text/Gemini",
category="partner/text/Gemini",
description="Generate text responses with Google's Gemini AI model. "
"You can provide multiple types of inputs (text, images, audio, video) "
"as context for generating more relevant and meaningful responses.",
@ -407,58 +547,9 @@ class GeminiNode(IO.ComfyNode):
)
""",
),
is_deprecated=True,
)
@classmethod
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""
base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.video_mp4,
data=base_64_string,
)
)
]
@classmethod
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
"""
Convert audio input to Gemini API compatible parts.
Args:
audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
Returns:
List of GeminiPart objects containing the encoded audio.
"""
audio_parts: list[GeminiPart] = []
for batch_index in range(audio_input["waveform"].shape[0]):
# Recreate an IO.AUDIO object for the given batch dimension index
audio_at_index = Input.Audio(
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
sample_rate=audio_input["sample_rate"],
)
# Convert to MP3 format for compatibility with Gemini API
audio_bytes = audio_to_base64_string(
audio_at_index,
container_format="mp3",
codec_name="libmp3lame",
)
audio_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.audio_mp3,
data=audio_bytes,
)
)
)
return audio_parts
@classmethod
async def execute(
cls,
@ -482,9 +573,9 @@ class GeminiNode(IO.ComfyNode):
if images is not None:
parts.extend(await create_image_parts(cls, images))
if audio is not None:
parts.extend(cls.create_audio_parts(audio))
parts.extend(create_audio_parts(audio))
if video is not None:
parts.extend(cls.create_video_parts(video))
parts.extend(create_video_parts(video))
if files is not None:
parts.extend(files)
@ -512,6 +603,210 @@ class GeminiNode(IO.ComfyNode):
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
GEMINI_V2_MODELS: dict[str, str] = {
"Gemini 3.1 Pro": "gemini-3.1-pro-preview",
"Gemini 3.1 Flash-Lite": "gemini-3.1-flash-lite-preview",
}
def _gemini_text_model_inputs(thinking_default: str) -> list[Input]:
"""Per-model inputs revealed by the model DynamicCombo (shared media + sampling controls)."""
return [
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, 17)],
min=0,
),
tooltip="Optional image(s) to use as context for the model. Up to 16 images.",
),
IO.Autogrow.Input(
"audio",
template=IO.Autogrow.TemplateNames(
IO.Audio.Input("audio"),
names=["audio_1"],
min=0,
),
tooltip="Optional audio clip to use as context for the model.",
),
IO.Autogrow.Input(
"video",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("video"),
names=["video_1"],
min=0,
),
tooltip="Optional video clip to use as context for the model.",
),
IO.Custom("GEMINI_INPUT_FILES").Input(
"files",
optional=True,
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Input Files node.",
),
IO.Combo.Input(
"thinking_level",
options=["LOW", "HIGH"],
default=thinking_default,
tooltip="How hard the model reasons internally before answering. "
"HIGH improves quality on difficult tasks but costs more (thinking) tokens and is slower.",
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.01,
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more creative.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=0.95,
min=0.0,
max=1.0,
step=0.01,
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
advanced=True,
),
IO.Int.Input(
"max_output_tokens",
default=32768,
min=16,
max=65536,
tooltip="Maximum tokens to generate, including the model's internal thinking. "
"With thinking_level HIGH, a low value can leave no room for the answer; raise this if "
"responses come back empty or truncated. The model stops early when finished, so a higher "
"cap costs nothing extra for short replies.",
advanced=True,
),
]
class GeminiNodeV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GeminiNodeV2",
display_name="Google Gemini",
category="partner/text/Gemini",
essentials_category="Text Generation",
description="Generate text responses with Google's Gemini models. Provide a text prompt and, "
"optionally, one or more images, audio clips, videos, or files as multimodal context.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text input to the model. Include detailed instructions, questions, or context.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Gemini 3.1 Pro", _gemini_text_model_inputs("HIGH")),
IO.DynamicCombo.Option("Gemini 3.1 Flash-Lite", _gemini_text_model_inputs("LOW")),
],
tooltip="The Gemini model used to generate the response.",
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for sampling. Set to 0 for a random seed. Deterministic output isn't guaranteed.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
advanced=True,
tooltip="Foundational instructions that dictate the model's behavior.",
),
],
outputs=[
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m, "lite") ? {
"type": "list_usd",
"usd": [0.00025, 0.0015],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
} : {
"type": "list_usd",
"usd": [0.002, 0.012],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_id = GEMINI_V2_MODELS[model["model"]]
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
images = [t for t in (model.get("images") or {}).values() if t is not None]
audios = [a for a in (model.get("audio") or {}).values() if a is not None]
videos = [v for v in (model.get("video") or {}).values() if v is not None]
if images or audios or videos:
parts.extend(await build_gemini_media_parts(cls, images, audios, videos))
files = model.get("files")
if files is not None:
parts.extend(files)
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
data=GeminiGenerateContentRequest(
contents=[
GeminiContent(
role=GeminiRole.user,
parts=parts,
)
],
generationConfig=GeminiGenerationConfig(
temperature=model["temperature"],
topP=model["top_p"],
maxOutputTokens=model["max_output_tokens"],
seed=seed if seed > 0 else None,
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
output_text = get_text_from_response(response)
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
class GeminiInputFiles(IO.ComfyNode):
"""
Loads and formats input files for use with the Gemini API.
@ -541,7 +836,7 @@ class GeminiInputFiles(IO.ComfyNode):
return IO.Schema(
node_id="GeminiInputFiles",
display_name="Gemini Input Files",
category="api node/text/Gemini",
category="partner/text/Gemini",
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
"The files will be read by the Gemini model when generating a response. "
"The contents of the text file count toward the token limit. "
@ -598,7 +893,7 @@ class GeminiImage(IO.ComfyNode):
return IO.Schema(
node_id="GeminiImageNode",
display_name="Nano Banana (Google Gemini Image)",
category="api node/image/Gemini",
category="partner/image/Gemini",
description="Edit images synchronously via Google API.",
inputs=[
IO.String.Input(
@ -731,7 +1026,7 @@ class GeminiImage2(IO.ComfyNode):
return IO.Schema(
node_id="GeminiImage2Node",
display_name="Nano Banana Pro (Google Gemini Image)",
category="api node/image/Gemini",
category="partner/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(
@ -869,7 +1164,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
return IO.Schema(
node_id="GeminiNanoBanana2",
display_name="Nano Banana 2",
category="api node/image/Gemini",
category="partner/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(
@ -1085,7 +1380,7 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
return IO.Schema(
node_id="GeminiNanoBanana2V2",
display_name="Nano Banana 2",
category="api node/image/Gemini",
category="partner/image/Gemini",
description="Generate or edit images synchronously via Google Vertex API.",
inputs=[
IO.String.Input(
@ -1129,6 +1424,26 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
tooltip="Foundational instructions that dictate an AI's behavior.",
advanced=True,
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.01,
optional=True,
tooltip="Controls randomness in generation. Lower is more focused/deterministic.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=0.95,
min=0.0,
max=1.0,
step=0.01,
optional=True,
tooltip="Nucleus sampling threshold. Lower is more focused, higher more diverse.",
advanced=True,
),
],
outputs=[
IO.Image.Output(),
@ -1165,6 +1480,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
seed: int,
response_modalities: str,
system_prompt: str = "",
temperature: float = 1.0,
top_p: float = 0.95,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_choice = model["model"]
@ -1204,6 +1521,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
temperature=temperature,
topP=top_p,
),
systemInstruction=gemini_system_prompt,
),
@ -1222,6 +1541,7 @@ class GeminiExtension(ComfyExtension):
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
GeminiNode,
GeminiNodeV2,
GeminiImage,
GeminiImage2,
GeminiNanoBanana2,

View File

@ -29,6 +29,11 @@ from comfy_api_nodes.util import (
)
_GROK_VIDEO_MODEL_API_IDS = {
"grok-imagine-video-1.5": "grok-imagine-video-1.5-preview",
}
def _extract_grok_price(response) -> float | None:
if response.usage and response.usage.cost_in_usd_ticks is not None:
return response.usage.cost_in_usd_ticks / 10_000_000_000
@ -49,7 +54,7 @@ class GrokImageNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokImageNode",
display_name="Grok Image",
category="api node/image/Grok",
category="partner/image/Grok",
description="Generate images using Grok based on a text prompt",
inputs=[
IO.Combo.Input(
@ -58,7 +63,6 @@ class GrokImageNode(IO.ComfyNode):
"grok-imagine-image-quality",
"grok-imagine-image-pro",
"grok-imagine-image",
"grok-imagine-image-beta",
],
),
IO.String.Input(
@ -224,7 +228,7 @@ class GrokImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokImageEditNode",
display_name="Grok Image Edit",
category="api node/image/Grok",
category="partner/image/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.Combo.Input(
@ -233,7 +237,6 @@ class GrokImageEditNode(IO.ComfyNode):
"grok-imagine-image-quality",
"grok-imagine-image-pro",
"grok-imagine-image",
"grok-imagine-image-beta",
],
),
IO.Image.Input("image", display_name="images"),
@ -366,7 +369,7 @@ class GrokImageEditNodeV2(IO.ComfyNode):
return IO.Schema(
node_id="GrokImageEditNodeV2",
display_name="Grok Image Edit",
category="api node/image/Grok",
category="partner/image/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.String.Input(
@ -503,10 +506,14 @@ class GrokVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoNode",
display_name="Grok Video",
category="api node/video/Grok",
category="partner/video/Grok",
description="Generate video from a prompt or an image",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.Combo.Input(
"model",
options=["grok-imagine-video", "grok-imagine-video-1.5"],
tooltip="grok-imagine-video-1.5 currently always requires an input image.",
),
IO.String.Input(
"prompt",
multiline=True,
@ -542,7 +549,11 @@ class GrokVideoNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Image.Input("image", optional=True),
IO.Image.Input(
"image",
optional=True,
tooltip="Optional starting image for grok-imagine-video. Required for grok-imagine-video-1.5.",
),
],
outputs=[
IO.Video.Output(),
@ -554,12 +565,16 @@ class GrokVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]),
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"], inputs=["image"]),
expr="""
(
$rate := widgets.resolution = "720p" ? 0.07 : 0.05;
$is15 := $contains(widgets.model, "1.5");
$rate := $is15
? (widgets.resolution = "720p" ? 0.2002 : 0.1144)
: (widgets.resolution = "720p" ? 0.07 : 0.05);
$imgCost := $is15 ? 0.0143 : 0.002;
$base := $rate * widgets.duration;
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
{"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base}
)
""",
),
@ -576,8 +591,8 @@ class GrokVideoNode(IO.ComfyNode):
seed: int,
image: Input.Image | None = None,
) -> IO.NodeOutput:
if model == "grok-imagine-video-beta":
model = "grok-imagine-video"
if image is None and model == "grok-imagine-video-1.5":
raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.")
image_url = None
if image is not None:
if get_number_of_images(image) != 1:
@ -588,7 +603,7 @@ class GrokVideoNode(IO.ComfyNode):
cls,
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
data=VideoGenerationRequest(
model=model,
model=_GROK_VIDEO_MODEL_API_IDS.get(model, model),
image=image_url,
prompt=prompt,
resolution=resolution,
@ -603,7 +618,7 @@ class GrokVideoNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
price_extractor=_extract_grok_video_price if model == "grok-imagine-video-1.5" else _extract_grok_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
@ -615,10 +630,10 @@ class GrokVideoEditNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoEditNode",
display_name="Grok Video Edit",
category="api node/video/Grok",
category="partner/video/Grok",
description="Edit an existing video based on a text prompt.",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.Combo.Input("model", options=["grok-imagine-video"]),
IO.String.Input(
"prompt",
multiline=True,
@ -693,7 +708,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoReferenceNode",
display_name="Grok Reference-to-Video",
category="api node/video/Grok",
category="partner/video/Grok",
description="Generate video guided by reference images as style and content references.",
inputs=[
IO.String.Input(
@ -826,7 +841,7 @@ class GrokVideoExtendNode(IO.ComfyNode):
return IO.Schema(
node_id="GrokVideoExtendNode",
display_name="Grok Video Extend",
category="api node/video/Grok",
category="partner/video/Grok",
description="Extend an existing video with a seamless continuation based on a text prompt.",
inputs=[
IO.String.Input(

View File

@ -71,7 +71,7 @@ class HitPawGeneralImageEnhance(IO.ComfyNode):
return IO.Schema(
node_id="HitPawGeneralImageEnhance",
display_name="HitPaw General Image Enhance",
category="api node/image/HitPaw",
category="partner/image/HitPaw",
description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
inputs=[
@ -201,7 +201,7 @@ class HitPawVideoEnhance(IO.ComfyNode):
return IO.Schema(
node_id="HitPawVideoEnhance",
display_name="HitPaw Video Enhance",
category="api node/video/HitPaw",
category="partner/video/HitPaw",
description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
"Prices shown are per second of video.",
inputs=[

View File

@ -123,7 +123,7 @@ class TencentTextToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentTextToModelNode",
display_name="Hunyuan3D: Text to Model",
category="api node/3d/Tencent",
category="partner/3d/Tencent",
essentials_category="3D",
inputs=[
IO.Combo.Input(
@ -242,7 +242,7 @@ class TencentImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentImageToModelNode",
display_name="Hunyuan3D: Image(s) to Model",
category="api node/3d/Tencent",
category="partner/3d/Tencent",
essentials_category="3D",
inputs=[
IO.Combo.Input(
@ -415,7 +415,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentModelTo3DUVNode",
display_name="Hunyuan3D: Model to UV",
category="api node/3d/Tencent",
category="partner/3d/Tencent",
description="Perform UV unfolding on a 3D model to generate UV texture. "
"Input model must have less than 30000 faces.",
inputs=[
@ -505,7 +505,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
return IO.Schema(
node_id="Tencent3DTextureEditNode",
display_name="Hunyuan3D: 3D Texture Edit",
category="api node/3d/Tencent",
category="partner/3d/Tencent",
description="After inputting the 3D model, perform 3D model texture redrawing.",
inputs=[
IO.MultiType.Input(
@ -594,7 +594,7 @@ class Tencent3DPartNode(IO.ComfyNode):
return IO.Schema(
node_id="Tencent3DPartNode",
display_name="Hunyuan3D: 3D Part",
category="api node/3d/Tencent",
category="partner/3d/Tencent",
description="Automatically perform component identification and generation based on the model structure.",
inputs=[
IO.MultiType.Input(
@ -666,7 +666,7 @@ class TencentSmartTopologyNode(IO.ComfyNode):
return IO.Schema(
node_id="TencentSmartTopologyNode",
display_name="Hunyuan3D: Smart Topology",
category="api node/3d/Tencent",
category="partner/3d/Tencent",
description="Perform smart retopology on a 3D model. "
"Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.",
inputs=[

View File

@ -10,6 +10,7 @@ from comfy_api_nodes.apis.ideogram import (
ImageRequest,
IdeogramV3Request,
IdeogramV3EditRequest,
IdeogramV4Request,
)
from comfy_api_nodes.util import (
ApiEndpoint,
@ -17,6 +18,7 @@ from comfy_api_nodes.util import (
download_url_as_bytesio,
resize_mask_to_image,
sync_op,
validate_string,
)
V1_V1_RES_MAP = {
@ -234,7 +236,7 @@ class IdeogramV1(IO.ComfyNode):
return IO.Schema(
node_id="IdeogramV1",
display_name="Ideogram V1",
category="api node/image/Ideogram",
category="partner/image/Ideogram",
description="Generates images using the Ideogram V1 model.",
inputs=[
IO.String.Input(
@ -360,7 +362,7 @@ class IdeogramV2(IO.ComfyNode):
return IO.Schema(
node_id="IdeogramV2",
display_name="Ideogram V2",
category="api node/image/Ideogram",
category="partner/image/Ideogram",
description="Generates images using the Ideogram V2 model.",
inputs=[
IO.String.Input(
@ -526,7 +528,7 @@ class IdeogramV3(IO.ComfyNode):
return IO.Schema(
node_id="IdeogramV3",
display_name="Ideogram V3",
category="api node/image/Ideogram",
category="partner/image/Ideogram",
description="Generates images using the Ideogram V3 model. "
"Supports both regular image generation from text prompts and image editing with mask.",
inputs=[
@ -798,6 +800,119 @@ class IdeogramV3(IO.ComfyNode):
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramV4(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="IdeogramV4",
display_name="Ideogram V4",
category="partner/image/Ideogram",
description="Generates images using the Ideogram 4.0 model from a text prompt.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the image generation.",
),
IO.Combo.Input(
"resolution",
options=[
"Auto",
"2048x2048 (1:1)",
"1440x2880 (1:2)",
"2880x1440 (2:1)",
"1664x2496 (2:3)",
"2496x1664 (3:2)",
"1792x2240 (4:5)",
"2240x1792 (5:4)",
"1440x2560 (9:16)",
"2560x1440 (16:9)",
"1600x2560 (5:8)",
"2560x1600 (8:5)",
"1728x2304 (3:4)",
"2304x1728 (4:3)",
"1296x3168 (9:22)",
"3168x1296 (22:9)",
"1152x2944 (9:23)",
"2944x1152 (23:9)",
"1248x3328 (3:8)",
"3328x1248 (8:3)",
"1280x3072 (5:12)",
"3072x1280 (12:5)",
],
default="Auto",
),
IO.Combo.Input(
"rendering_speed",
options=["DEFAULT", "TURBO", "QUALITY"],
default="DEFAULT",
tooltip="Controls the trade-off between generation speed and quality.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
control_after_generate=True,
display_mode=IO.NumberDisplay.number,
),
],
outputs=[
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed"]),
expr="""
(
$speed := widgets.rendering_speed;
$price :=
$contains($speed,"turbo") ? 0.0429 :
$contains($speed,"quality") ? 0.143 :
0.0858;
{"type":"usd","usd": $price}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
resolution: str,
rendering_speed: str,
seed: int,
):
validate_string(prompt, strip_whitespace=True, min_length=1)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/ideogram/ideogram-v4/generate", method="POST"),
response_model=IdeogramGenerateResponse,
data=IdeogramV4Request(
text_prompt=prompt,
resolution=resolution.split(" ")[0] if resolution != "Auto" else None,
rendering_speed=rendering_speed,
),
max_retries=1,
)
if not response.data or len(response.data) == 0:
raise Exception("No images were generated in the response")
image_urls = [image_data.url for image_data in response.data if image_data.url]
if not image_urls:
raise Exception("No image URLs were generated in the response")
return IO.NodeOutput(await download_and_process_images(image_urls))
class IdeogramExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -805,6 +920,7 @@ class IdeogramExtension(ComfyExtension):
IdeogramV1,
IdeogramV2,
IdeogramV3,
IdeogramV4,
]

View File

@ -642,7 +642,7 @@ class KlingCameraControls(IO.ComfyNode):
return IO.Schema(
node_id="KlingCameraControls",
display_name="Kling Camera Controls",
category="api node/video/Kling",
category="partner/video/Kling",
description="Allows specifying configuration options for Kling Camera Controls and motion control effects.",
inputs=[
IO.Combo.Input("camera_control_type", options=KlingCameraControlType),
@ -762,7 +762,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingTextToVideoNode",
display_name="Kling Text to Video",
category="api node/video/Kling",
category="partner/video/Kling",
description="Kling Text to Video Node",
inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -849,7 +849,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProTextToVideoNode",
display_name="Kling 3.0 Omni Text to Video",
category="api node/video/Kling",
category="partner/video/Kling",
description="Use text prompts to generate videos with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -998,7 +998,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProFirstLastFrameNode",
display_name="Kling 3.0 Omni First-Last-Frame to Video",
category="api node/video/Kling",
category="partner/video/Kling",
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1205,7 +1205,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProImageToVideoNode",
display_name="Kling 3.0 Omni Image to Video",
category="api node/video/Kling",
category="partner/video/Kling",
description="Use up to 7 reference images to generate a video with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1374,7 +1374,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProVideoToVideoNode",
display_name="Kling 3.0 Omni Video to Video",
category="api node/video/Kling",
category="partner/video/Kling",
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
@ -1485,7 +1485,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProEditVideoNode",
display_name="Kling 3.0 Omni Edit Video",
category="api node/video/Kling",
category="partner/video/Kling",
essentials_category="Video Generation",
description="Edit an existing video with the latest model from Kling.",
inputs=[
@ -1593,7 +1593,7 @@ class OmniProImageNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingOmniProImageNode",
display_name="Kling 3.0 Omni Image",
category="api node/image/Kling",
category="partner/image/Kling",
description="Create or edit images with the latest model from Kling.",
inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]),
@ -1721,7 +1721,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingCameraControlT2VNode",
display_name="Kling Text to Video (Camera Control)",
category="api node/video/Kling",
category="partner/video/Kling",
description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.",
inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1783,7 +1783,7 @@ class KlingImage2VideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingImage2VideoNode",
display_name="Kling Image(First Frame) to Video",
category="api node/video/Kling",
category="partner/video/Kling",
inputs=[
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -1882,7 +1882,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingCameraControlI2VNode",
display_name="Kling Image to Video (Camera Control)",
category="api node/video/Kling",
category="partner/video/Kling",
description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.",
inputs=[
IO.Image.Input(
@ -1953,7 +1953,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingStartEndFrameNode",
display_name="Kling Start-End Frame to Video",
category="api node/video/Kling",
category="partner/video/Kling",
description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.",
inputs=[
IO.Image.Input(
@ -2047,7 +2047,7 @@ class KlingVideoExtendNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingVideoExtendNode",
display_name="Kling Video Extend",
category="api node/video/Kling",
category="partner/video/Kling",
description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.",
inputs=[
IO.String.Input(
@ -2128,7 +2128,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingDualCharacterVideoEffectNode",
display_name="Kling Dual Character Video Effects",
category="api node/video/Kling",
category="partner/video/Kling",
description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.",
inputs=[
IO.Image.Input("image_left", tooltip="Left side image"),
@ -2218,7 +2218,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingSingleImageVideoEffectNode",
display_name="Kling Video Effects",
category="api node/video/Kling",
category="partner/video/Kling",
description="Achieve different special effects when generating a video based on the effect_scene.",
inputs=[
IO.Image.Input(
@ -2291,7 +2291,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingLipSyncAudioToVideoNode",
display_name="Kling Lip Sync Video with Audio",
category="api node/video/Kling",
category="partner/video/Kling",
essentials_category="Video Generation",
description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
inputs=[
@ -2343,7 +2343,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingLipSyncTextToVideoNode",
display_name="Kling Lip Sync Video with Text",
category="api node/video/Kling",
category="partner/video/Kling",
description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
inputs=[
IO.Video.Input("video"),
@ -2411,7 +2411,7 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingVirtualTryOnNode",
display_name="Kling Virtual Try On",
category="api node/image/Kling",
category="partner/image/Kling",
description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.",
inputs=[
IO.Image.Input("human_image"),
@ -2478,7 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingImageGenerationNode",
display_name="Kling 3.0 Image",
category="api node/image/Kling",
category="partner/image/Kling",
description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.",
inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
@ -2615,7 +2615,7 @@ class TextToVideoWithAudio(IO.ComfyNode):
return IO.Schema(
node_id="KlingTextToVideoWithAudio",
display_name="Kling 2.6 Text to Video with Audio",
category="api node/video/Kling",
category="partner/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
@ -2683,7 +2683,7 @@ class ImageToVideoWithAudio(IO.ComfyNode):
return IO.Schema(
node_id="KlingImageToVideoWithAudio",
display_name="Kling 2.6 Image(First Frame) to Video with Audio",
category="api node/video/Kling",
category="partner/video/Kling",
inputs=[
IO.Combo.Input("model_name", options=["kling-v2-6"]),
IO.Image.Input("start_frame"),
@ -2753,7 +2753,7 @@ class MotionControl(IO.ComfyNode):
return IO.Schema(
node_id="KlingMotionControl",
display_name="Kling Motion Control",
category="api node/video/Kling",
category="partner/video/Kling",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.Image.Input("reference_image"),
@ -2854,7 +2854,7 @@ class KlingVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingVideoNode",
display_name="Kling 3.0 Video",
category="api node/video/Kling",
category="partner/video/Kling",
description="Generate videos with Kling V3. "
"Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.",
inputs=[
@ -3077,7 +3077,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingFirstLastFrameNode",
display_name="Kling 3.0 First-Last-Frame to Video",
category="api node/video/Kling",
category="partner/video/Kling",
description="Generate videos with Kling V3 using first and last frames.",
inputs=[
IO.String.Input("prompt", multiline=True, default=""),
@ -3202,7 +3202,7 @@ class KlingAvatarNode(IO.ComfyNode):
return IO.Schema(
node_id="KlingAvatarNode",
display_name="Kling Avatar 2.0",
category="api node/video/Kling",
category="partner/video/Kling",
description="Generate broadcast-style digital human videos from a single photo and an audio file.",
inputs=[
IO.Image.Input(

View File

@ -0,0 +1,294 @@
"""Krea image-generation nodes."""
import re
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.krea import (
KreaAssetResponse,
KreaGenerateImageRequest,
KreaImageStyleReference,
KreaJob,
KreaMoodboard,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_image_tensor,
poll_op,
sync_op,
tensor_to_bytesio,
validate_string,
)
class KreaIO:
STYLE_REF = "KREA_STYLE_REF"
async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Image) -> str:
"""Upload an image to Krea's /assets endpoint and return the Krea-hosted image URL."""
img_io = tensor_to_bytesio(image, total_pixels=2048 * 2048, mime_type="image/png")
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/krea/assets", method="POST"),
response_model=KreaAssetResponse,
files=[("file", (img_io.name, img_io, "image/png"))],
content_type="multipart/form-data",
max_retries=1,
wait_label="Uploading reference",
)
return response.image_url
_MODEL_MEDIUM = "Krea 2 Medium"
_MODEL_MEDIUM_TURBO = "Krea 2 Medium Turbo"
_MODEL_LARGE = "Krea 2 Large"
_MODEL_ENDPOINTS: dict[str, str] = {
_MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
_MODEL_MEDIUM_TURBO: "/proxy/krea/generate/image/krea/krea-2/medium-turbo",
_MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
}
_ASPECT_RATIOS = ["1:1", "4:3", "3:2", "16:9", "2.35:1", "4:5", "2:3", "9:16"]
_RESOLUTIONS = ["1K"]
_CREATIVITY_LEVELS = ["raw", "low", "medium", "high"]
_KREA_QUEUED_STATUSES = ["backlogged", "queued", "scheduled"]
_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
def _krea_model_inputs() -> list:
"""Nested inputs shared by Krea 2 Medium, Medium Turbo and Large under the DynamicCombo."""
return [
IO.Combo.Input(
"aspect_ratio",
options=_ASPECT_RATIOS,
tooltip="Output aspect ratio.",
),
IO.Combo.Input(
"resolution",
options=_RESOLUTIONS,
tooltip="Resolution scale.",
),
IO.Combo.Input(
"creativity",
options=_CREATIVITY_LEVELS,
default="medium",
tooltip="Prompt interpretation strength: raw stays closest to the prompt; high is most creative.",
),
IO.String.Input(
"moodboard_id",
default="",
tooltip="Optional Krea moodboard UUID (e.g. from the Krea website). "
"Leave empty to disable. Only one moodboard is supported per request.",
optional=True,
),
IO.Float.Input(
"moodboard_strength",
default=0.35,
min=-0.5,
max=1.5,
step=0.05,
tooltip="Moodboard influence; ignored when moodboard_id is empty.",
optional=True,
),
IO.Custom(KreaIO.STYLE_REF).Input(
"style_reference",
optional=True,
tooltip="Optional chain of style references (max 10) from Krea 2 Style Reference nodes.",
),
]
class Krea2ImageNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Krea2ImageNode",
display_name="Krea 2 Image",
category="partner/image/Krea",
description=(
"Generate images via Krea 2 — pick Medium (expressive illustrations) or "
"Large (expressive photorealism). Supports an optional moodboard and up "
"to 10 chained image style references."
),
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text prompt for the image.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
IO.DynamicCombo.Option(_MODEL_MEDIUM_TURBO, _krea_model_inputs()),
IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
],
tooltip="Krea 2 Medium is best for expressive illustrations; "
"Krea 2 Large is best for expressive photorealism.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Random seed for reproducibility.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model", "model.moodboard_id"],
inputs=["model.style_reference"],
),
expr="""
(
$rates := {
"krea 2 medium turbo": {"text": 0.015, "style": 0.0175, "moodboard": 0.02},
"krea 2 medium": {"text": 0.03, "style": 0.035, "moodboard": 0.04},
"krea 2 large": {"text": 0.06, "style": 0.065, "moodboard": 0.07}
};
$r := $lookup($rates, widgets.model);
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
$hasStyle := $lookup(inputs, "model.style_reference").connected;
$usd := $hasMoodboard ? $r.moodboard : ($hasStyle ? $r.style : $r.text);
{"type":"usd","usd": $usd}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False, min_length=1)
model_choice = model["model"]
endpoint_path = _MODEL_ENDPOINTS.get(model_choice)
if endpoint_path is None:
raise ValueError(f"Unknown Krea 2 model: {model_choice!r}")
moodboards: list[KreaMoodboard] | None = None
mb_id = (model.get("moodboard_id") or "").strip()
if mb_id:
if not _UUID_RE.match(mb_id):
raise ValueError(f"moodboard_id must be a UUID (received {mb_id!r}); copy it from the Krea website.")
mb_strength = model.get("moodboard_strength")
moodboards = [KreaMoodboard(id=mb_id, strength=0.35 if mb_strength is None else float(mb_strength))]
style_reference = model.get("style_reference")
image_style_references: list[KreaImageStyleReference] | None = None
if style_reference:
if len(style_reference) > 10:
raise ValueError(f"Krea 2 accepts at most 10 image_style_references; received {len(style_reference)}.")
image_style_references = [
KreaImageStyleReference(url=ref["url"], strength=float(ref["strength"])) for ref in style_reference
]
initial = await sync_op(
cls,
ApiEndpoint(path=endpoint_path, method="POST"),
response_model=KreaJob,
data=KreaGenerateImageRequest(
prompt=prompt,
aspect_ratio=model["aspect_ratio"],
resolution=model["resolution"],
seed=seed,
creativity=model["creativity"],
moodboards=moodboards,
image_style_references=image_style_references,
),
)
job = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/krea/jobs/{initial.job_id}", method="GET"),
response_model=KreaJob,
status_extractor=lambda r: r.status,
queued_statuses=_KREA_QUEUED_STATUSES,
)
if not job.result or not job.result.urls:
raise RuntimeError(f"Krea 2 job {job.job_id} completed without any image URLs.")
image = await download_url_to_image_tensor(job.result.urls[0])
return IO.NodeOutput(image)
class Krea2StyleReferenceNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="Krea2StyleReferenceNode",
display_name="Krea 2 Style Reference",
category="partner/image/Krea",
description=(
"Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 "
"Style Reference nodes (max 10) and feed the final `style_reference` output "
"into Krea 2 Image. Each image is uploaded to ComfyAPI storage and passed as URL."
),
inputs=[
IO.Image.Input(
"image",
tooltip="Reference image whose style influences the generation.",
),
IO.Float.Input(
"strength",
default=1.0,
min=-2.0,
max=2.0,
step=0.05,
tooltip="Reference strength; negative values invert the style influence.",
),
IO.Custom(KreaIO.STYLE_REF).Input(
"style_reference",
optional=True,
tooltip="Optional incoming chain of style references; this node appends one more.",
),
],
outputs=[IO.Custom(KreaIO.STYLE_REF).Output(display_name="style_reference")],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
)
@classmethod
async def execute(
cls,
image: Input.Image,
strength: float,
style_reference: list[dict] | None = None,
) -> IO.NodeOutput:
chain: list[dict] = list(style_reference) if style_reference else []
if len(chain) >= 10:
raise ValueError("Krea 2 accepts at most 10 image_style_references in one generation.")
url = await _upload_image_to_krea_assets(cls, image)
chain.append({"url": url, "strength": float(strength)})
return IO.NodeOutput(chain)
class KreaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
Krea2ImageNode,
Krea2StyleReferenceNode,
]
async def comfy_entrypoint() -> KreaExtension:
return KreaExtension()

View File

@ -50,7 +50,7 @@ class TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="LtxvApiTextToVideo",
display_name="LTXV Text To Video",
category="api node/video/LTXV",
category="partner/video/LTXV",
description="Professional-quality videos with customizable duration and resolution.",
inputs=[
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
@ -127,7 +127,7 @@ class ImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="LtxvApiImageToVideo",
display_name="LTXV Image To Video",
category="api node/video/LTXV",
category="partner/video/LTXV",
description="Professional-quality videos with customizable duration and resolution based on start image.",
inputs=[
IO.Image.Input("image", tooltip="First frame to be used for the video."),

View File

@ -46,7 +46,7 @@ class LumaReferenceNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaReferenceNode",
display_name="Luma Reference",
category="api node/image/Luma",
category="partner/image/Luma",
description="Holds an image and weight for use with Luma Generate Image node.",
inputs=[
IO.Image.Input(
@ -85,7 +85,7 @@ class LumaConceptsNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaConceptsNode",
display_name="Luma Concepts",
category="api node/video/Luma",
category="partner/video/Luma",
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
inputs=[
IO.Combo.Input(
@ -134,7 +134,7 @@ class LumaImageGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageNode",
display_name="Luma Text to Image",
category="api node/image/Luma",
category="partner/image/Luma",
description="Generates images synchronously based on prompt and aspect ratio.",
inputs=[
IO.String.Input(
@ -278,7 +278,7 @@ class LumaImageModifyNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageModifyNode",
display_name="Luma Image to Image",
category="api node/image/Luma",
category="partner/image/Luma",
description="Modifies images synchronously based on prompt and aspect ratio.",
inputs=[
IO.Image.Input(
@ -371,7 +371,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaVideoNode",
display_name="Luma Text to Video",
category="api node/video/Luma",
category="partner/video/Luma",
description="Generates videos synchronously based on prompt and output_size.",
inputs=[
IO.String.Input(
@ -472,7 +472,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageToVideoNode",
display_name="Luma Image to Video",
category="api node/video/Luma",
category="partner/video/Luma",
description="Generates videos synchronously based on prompt, input images, and output_size.",
inputs=[
IO.String.Input(
@ -724,7 +724,7 @@ class LumaImageNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageNode2",
display_name="Luma UNI-1 Image",
category="api node/image/Luma",
category="partner/image/Luma",
description="Generate images from text using the Luma UNI-1 model.",
inputs=[
IO.String.Input(
@ -853,7 +853,7 @@ class LumaImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="LumaImageEditNode2",
display_name="Luma UNI-1 Image Edit",
category="api node/image/Luma",
category="partner/image/Luma",
description="Edit an existing image with a text prompt using the Luma UNI-1 model.",
inputs=[
IO.Image.Input(

View File

@ -61,7 +61,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageUpscalerCreativeNode",
display_name="Magnific Image Upscale (Creative)",
category="api node/image/Magnific",
category="partner/image/Magnific",
description="Promptguided enhancement, stylization, and 2x/4x/8x/16x upscaling. "
"Maximum output: 25.3 megapixels.",
inputs=[
@ -240,7 +240,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageUpscalerPreciseV2Node",
display_name="Magnific Image Upscale (Precise V2)",
category="api node/image/Magnific",
category="partner/image/Magnific",
description="High-fidelity upscaling with fine control over sharpness, grain, and detail. "
"Maximum output: 10060×10060 pixels.",
inputs=[
@ -400,7 +400,7 @@ class MagnificImageStyleTransferNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageStyleTransferNode",
display_name="Magnific Image Style Transfer",
category="api node/image/Magnific",
category="partner/image/Magnific",
description="Transfer the style from a reference image to your input image.",
inputs=[
IO.Image.Input("image", tooltip="The image to apply style transfer to."),
@ -549,7 +549,7 @@ class MagnificImageRelightNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageRelightNode",
display_name="Magnific Image Relight",
category="api node/image/Magnific",
category="partner/image/Magnific",
description="Relight an image with lighting adjustments and optional reference-based light transfer.",
inputs=[
IO.Image.Input("image", tooltip="The image to relight."),
@ -789,7 +789,7 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode):
return IO.Schema(
node_id="MagnificImageSkinEnhancerNode",
display_name="Magnific Image Skin Enhancer",
category="api node/image/Magnific",
category="partner/image/Magnific",
description="Skin enhancement for portraits with multiple processing modes.",
inputs=[
IO.Image.Input("image", tooltip="The portrait image to enhance."),

View File

@ -33,7 +33,7 @@ class MeshyTextToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyTextToModelNode",
display_name="Meshy: Text to Model",
category="api node/3d/Meshy",
category="partner/3d/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.String.Input("prompt", multiline=True, default=""),
@ -145,7 +145,7 @@ class MeshyRefineNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyRefineNode",
display_name="Meshy: Refine Draft Model",
category="api node/3d/Meshy",
category="partner/3d/Meshy",
description="Refine a previously created draft model.",
inputs=[
IO.Combo.Input("model", options=["latest"]),
@ -240,7 +240,7 @@ class MeshyImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyImageToModelNode",
display_name="Meshy: Image to Model",
category="api node/3d/Meshy",
category="partner/3d/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.Image.Input("image"),
@ -405,7 +405,7 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyMultiImageToModelNode",
display_name="Meshy: Multi-Image to Model",
category="api node/3d/Meshy",
category="partner/3d/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.Autogrow.Input(
@ -575,7 +575,7 @@ class MeshyRigModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyRigModelNode",
display_name="Meshy: Rig Model",
category="api node/3d/Meshy",
category="partner/3d/Meshy",
description="Provides a rigged character in standard formats. "
"Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, "
"or humanoid assets with unclear limb and body structure.",
@ -656,7 +656,7 @@ class MeshyAnimateModelNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyAnimateModelNode",
display_name="Meshy: Animate Model",
category="api node/3d/Meshy",
category="partner/3d/Meshy",
description="Apply a specific animation action to a previously rigged character.",
inputs=[
IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"),
@ -722,7 +722,7 @@ class MeshyTextureNode(IO.ComfyNode):
return IO.Schema(
node_id="MeshyTextureNode",
display_name="Meshy: Texture Model",
category="api node/3d/Meshy",
category="partner/3d/Meshy",
inputs=[
IO.Combo.Input("model", options=["latest"]),
IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"),

View File

@ -101,7 +101,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxTextToVideoNode",
display_name="MiniMax Text to Video",
category="api node/video/MiniMax",
category="partner/video/MiniMax",
description="Generates videos synchronously based on a prompt, and optional parameters.",
inputs=[
IO.String.Input(
@ -163,7 +163,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxImageToVideoNode",
display_name="MiniMax Image to Video",
category="api node/video/MiniMax",
category="partner/video/MiniMax",
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
IO.Image.Input(
@ -230,7 +230,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxSubjectToVideoNode",
display_name="MiniMax Subject to Video",
category="api node/video/MiniMax",
category="partner/video/MiniMax",
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
inputs=[
IO.Image.Input(
@ -294,7 +294,7 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="MinimaxHailuoVideoNode",
display_name="MiniMax Hailuo Video",
category="api node/video/MiniMax",
category="partner/video/MiniMax",
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
inputs=[
IO.String.Input(

View File

@ -99,7 +99,7 @@ class OpenAIDalle2(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIDalle2",
display_name="OpenAI DALL·E 2",
category="api node/image/OpenAI",
category="partner/image/OpenAI",
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
inputs=[
IO.String.Input(
@ -249,7 +249,7 @@ class OpenAIDalle3(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIDalle3",
display_name="OpenAI DALL·E 3",
category="api node/image/OpenAI",
category="partner/image/OpenAI",
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
inputs=[
IO.String.Input(
@ -371,7 +371,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIGPTImage1",
display_name="OpenAI GPT Image 2",
category="api node/image/OpenAI",
category="partner/image/OpenAI",
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
is_deprecated=True,
inputs=[
@ -695,7 +695,7 @@ class OpenAIGPTImageNodeV2(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIGPTImageNodeV2",
display_name="OpenAI GPT Image 2",
category="api node/image/OpenAI",
category="partner/image/OpenAI",
description="Generates images via OpenAI's GPT Image endpoint.",
inputs=[
IO.String.Input(
@ -962,7 +962,7 @@ class OpenAIChatNode(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIChatNode",
display_name="OpenAI ChatGPT",
category="api node/text/OpenAI",
category="partner/text/OpenAI",
essentials_category="Text Generation",
description="Generate text responses from an OpenAI model.",
inputs=[
@ -1201,7 +1201,7 @@ class OpenAIInputFiles(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIInputFiles",
display_name="OpenAI ChatGPT Input Files",
category="api node/text/OpenAI",
category="partner/text/OpenAI",
description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.",
inputs=[
IO.Combo.Input(
@ -1248,7 +1248,7 @@ class OpenAIChatConfig(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIChatConfig",
display_name="OpenAI ChatGPT Advanced Options",
category="api node/text/OpenAI",
category="partner/text/OpenAI",
description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.",
inputs=[
IO.Combo.Input(

View File

@ -265,7 +265,7 @@ class OpenRouterLLMNode(IO.ComfyNode):
return IO.Schema(
node_id="OpenRouterLLMNode",
display_name="OpenRouter LLM",
category="api node/text/OpenRouter",
category="partner/text/OpenRouter",
essentials_category="Text Generation",
description=(
"Generate text responses through OpenRouter. Routes to a curated set of popular "

View File

@ -53,7 +53,7 @@ class PixverseTemplateNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseTemplateNode",
display_name="PixVerse Template",
category="api node/video/PixVerse",
category="partner/video/PixVerse",
inputs=[
IO.Combo.Input("template", options=list(pixverse_templates.keys())),
],
@ -74,7 +74,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseTextToVideoNode",
display_name="PixVerse Text to Video",
category="api node/video/PixVerse",
category="partner/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.String.Input(
@ -192,7 +192,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseImageToVideoNode",
display_name="PixVerse Image to Video",
category="api node/video/PixVerse",
category="partner/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("image"),
@ -310,7 +310,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="PixverseTransitionVideoNode",
display_name="PixVerse Transition Video",
category="api node/video/PixVerse",
category="partner/video/PixVerse",
description="Generates videos based on prompt and output_size.",
inputs=[
IO.Image.Input("first_frame"),

View File

@ -62,7 +62,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
return IO.Schema(
node_id="QuiverTextToSVGNode",
display_name="Quiver Text to SVG",
category="api node/image/Quiver",
category="partner/image/Quiver",
description="Generate an SVG from a text prompt using Quiver AI.",
inputs=[
IO.String.Input(
@ -177,7 +177,7 @@ class QuiverImageToSVGNode(IO.ComfyNode):
return IO.Schema(
node_id="QuiverImageToSVGNode",
display_name="Quiver Image to SVG",
category="api node/image/Quiver",
category="partner/image/Quiver",
description="Vectorize a raster image into SVG using Quiver AI.",
inputs=[
IO.Image.Input(

View File

@ -178,7 +178,7 @@ class RecraftColorRGBNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftColorRGB",
display_name="Recraft Color RGB",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Create Recraft Color by choosing specific RGB values.",
inputs=[
IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."),
@ -204,7 +204,7 @@ class RecraftControlsNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftControls",
display_name="Recraft Controls",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Create Recraft Controls for customizing Recraft generation.",
inputs=[
IO.Custom(RecraftIO.COLOR).Input("colors", optional=True),
@ -228,7 +228,7 @@ class RecraftStyleV3RealisticImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftStyleV3RealisticImage",
display_name="Recraft Style - Realistic Image",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -253,7 +253,7 @@ class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode):
return IO.Schema(
node_id="RecraftStyleV3DigitalIllustration",
display_name="Recraft Style - Digital Illustration",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -272,7 +272,7 @@ class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode):
return IO.Schema(
node_id="RecraftStyleV3VectorIllustrationNode",
display_name="Recraft Style - Realistic Image",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
@ -291,7 +291,7 @@ class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode):
return IO.Schema(
node_id="RecraftStyleV3LogoRaster",
display_name="Recraft Style - Logo Raster",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Select realistic_image style and optional substyle.",
inputs=[
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)),
@ -308,7 +308,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
return IO.Schema(
node_id="RecraftStyleV3InfiniteStyleLibrary",
display_name="Recraft Style - Infinite Style Library",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
inputs=[
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
@ -331,7 +331,7 @@ class RecraftCreateStyleNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftCreateStyleNode",
display_name="Recraft Create Style",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Create a custom style from reference images. "
"Upload 1-5 images to use as style references. "
"Total size of all images is limited to 5 MB.",
@ -400,7 +400,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftTextToImageNode",
display_name="Recraft Text to Image",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Generates images synchronously based on prompt and resolution.",
inputs=[
IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."),
@ -512,7 +512,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftImageToImageNode",
display_name="Recraft Image to Image",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Modify image based on prompt and strength.",
inputs=[
IO.Image.Input("image"),
@ -630,7 +630,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftImageInpaintingNode",
display_name="Recraft Image Inpainting",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Modify image based on prompt and mask.",
inputs=[
IO.Image.Input("image"),
@ -732,7 +732,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftTextToVectorNode",
display_name="Recraft Text to Vector",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Generates SVG synchronously based on prompt and resolution.",
inputs=[
IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True),
@ -832,7 +832,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftVectorizeImageNode",
display_name="Recraft Vectorize Image",
category="api node/image/Recraft",
category="partner/image/Recraft",
essentials_category="Image Tools",
description="Generates SVG synchronously from an input image.",
inputs=[
@ -876,7 +876,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftReplaceBackgroundNode",
display_name="Recraft Replace Background",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Replace background on image, based on provided prompt.",
inputs=[
IO.Image.Input("image"),
@ -963,7 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftRemoveBackgroundNode",
display_name="Recraft Remove Background",
category="api node/image/Recraft",
category="partner/image/Recraft",
essentials_category="Image Tools",
description="Remove background from image, and return processed image and mask.",
inputs=[
@ -1012,7 +1012,7 @@ class RecraftCrispUpscaleNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftCrispUpscaleNode",
display_name="Recraft Crisp Upscale Image",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Upscale image synchronously.\n"
"Enhances a given raster image using crisp upscale tool, "
"increasing image resolution, making the image sharper and cleaner.",
@ -1058,7 +1058,7 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
return IO.Schema(
node_id="RecraftCreativeUpscaleNode",
display_name="Recraft Creative Upscale Image",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Upscale image synchronously.\n"
"Enhances a given raster image using creative upscale tool, "
"boosting resolution with a focus on refining small details and faces.",
@ -1086,7 +1086,7 @@ class RecraftV4TextToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftV4TextToImageNode",
display_name="Recraft V4 Text to Image",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Generates images using Recraft V4 or V4 Pro models.",
inputs=[
IO.String.Input(
@ -1210,7 +1210,7 @@ class RecraftV4TextToVectorNode(IO.ComfyNode):
return IO.Schema(
node_id="RecraftV4TextToVectorNode",
display_name="Recraft V4 Text to Vector",
category="api node/image/Recraft",
category="partner/image/Recraft",
description="Generates SVG using Recraft V4 or V4 Pro models.",
inputs=[
IO.String.Input(

View File

@ -109,7 +109,7 @@ class ReveImageCreateNode(IO.ComfyNode):
return IO.Schema(
node_id="ReveImageCreateNode",
display_name="Reve Image Create",
category="api node/image/Reve",
category="partner/image/Reve",
description="Generate images from text descriptions using Reve.",
inputs=[
IO.String.Input(
@ -200,7 +200,7 @@ class ReveImageEditNode(IO.ComfyNode):
return IO.Schema(
node_id="ReveImageEditNode",
display_name="Reve Image Edit",
category="api node/image/Reve",
category="partner/image/Reve",
description="Edit images using natural language instructions with Reve.",
inputs=[
IO.Image.Input("image", tooltip="The image to edit."),
@ -300,7 +300,7 @@ class ReveImageRemixNode(IO.ComfyNode):
return IO.Schema(
node_id="ReveImageRemixNode",
display_name="Reve Image Remix",
category="api node/image/Reve",
category="partner/image/Reve",
description="Combine reference images with text prompts to create new images using Reve.",
inputs=[
IO.Autogrow.Input(

View File

@ -230,7 +230,7 @@ class Rodin3D_Regular(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Regular",
display_name="Rodin 3D Generate - Regular Generate",
category="api node/3d/Rodin",
category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -289,7 +289,7 @@ class Rodin3D_Detail(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Detail",
display_name="Rodin 3D Generate - Detail Generate",
category="api node/3d/Rodin",
category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -348,7 +348,7 @@ class Rodin3D_Smooth(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Smooth",
display_name="Rodin 3D Generate - Smooth Generate",
category="api node/3d/Rodin",
category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -406,7 +406,7 @@ class Rodin3D_Sketch(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Sketch",
display_name="Rodin 3D Generate - Sketch Generate",
category="api node/3d/Rodin",
category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -468,7 +468,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Gen2",
display_name="Rodin 3D Generate - Gen-2 Generate",
category="api node/3d/Rodin",
category="partner/3d/Rodin",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("Images"),
@ -941,7 +941,7 @@ class Rodin3D_Gen25_Image(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Gen25_Image",
display_name="Rodin 3D Gen-2.5 - Image to 3D",
category="api node/3d/Rodin",
category="partner/3d/Rodin",
description=(
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
@ -1035,7 +1035,7 @@ class Rodin3D_Gen25_Text(IO.ComfyNode):
return IO.Schema(
node_id="Rodin3D_Gen25_Text",
display_name="Rodin 3D Gen-2.5 - Text to 3D",
category="api node/3d/Rodin",
category="partner/3d/Rodin",
description=(
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."

View File

@ -140,7 +140,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
return IO.Schema(
node_id="RunwayImageToVideoNodeGen3a",
display_name="Runway Image to Video (Gen3a Turbo)",
category="api node/video/Runway",
category="partner/video/Runway",
description="Generate a video from a single starting frame using Gen3a Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
@ -234,7 +234,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
return IO.Schema(
node_id="RunwayImageToVideoNodeGen4",
display_name="Runway Image to Video (Gen4 Turbo)",
category="api node/video/Runway",
category="partner/video/Runway",
description="Generate a video from a single starting frame using Gen4 Turbo model. "
"Before diving in, review these best practices to ensure that "
"your input selections will set your generation up for success: "
@ -329,7 +329,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="RunwayFirstLastFrameNode",
display_name="Runway First-Last-Frame to Video",
category="api node/video/Runway",
category="partner/video/Runway",
description="Upload first and last keyframes, draft a prompt, and generate a video. "
"More complex transitions, such as cases where the Last frame is completely different "
"from the First frame, may benefit from the longer 10s duration. "
@ -440,7 +440,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
return IO.Schema(
node_id="RunwayTextToImageNode",
display_name="Runway Text to Image",
category="api node/image/Runway",
category="partner/image/Runway",
description="Generate an image from a text prompt using Runway's Gen 4 model. "
"You can also include reference image to guide the generation.",
inputs=[

View File

@ -34,7 +34,7 @@ class SoniloVideoToMusic(IO.ComfyNode):
return IO.Schema(
node_id="SoniloVideoToMusic",
display_name="Sonilo Video to Music",
category="api node/audio/Sonilo",
category="partner/audio/Sonilo",
description="Generate music from video content using Sonilo's AI model. "
"Analyzes the video and creates matching music.",
inputs=[
@ -99,7 +99,7 @@ class SoniloTextToMusic(IO.ComfyNode):
return IO.Schema(
node_id="SoniloTextToMusic",
display_name="Sonilo Text to Music",
category="api node/audio/Sonilo",
category="partner/audio/Sonilo",
description="Generate music from a text prompt using Sonilo's AI model. "
"Leave duration at 0 to let the model infer it from the prompt.",
inputs=[

View File

@ -34,7 +34,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
return IO.Schema(
node_id="OpenAIVideoSora2",
display_name="OpenAI Sora - Video (DEPRECATED)",
category="api node/video/Sora",
category="partner/video/Sora",
description=(
"OpenAI video and audio generation.\n\n"
"DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. "

View File

@ -62,7 +62,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityStableImageUltraNode",
display_name="Stability AI Stable Image Ultra",
category="api node/image/Stability AI",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
@ -197,7 +197,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
return IO.Schema(
node_id="StabilityStableImageSD_3_5Node",
display_name="Stability AI Stable Diffusion 3.5 Image",
category="api node/image/Stability AI",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.String.Input(
@ -354,7 +354,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityUpscaleConservativeNode",
display_name="Stability AI Upscale Conservative",
category="api node/image/Stability AI",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
@ -457,7 +457,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityUpscaleCreativeNode",
display_name="Stability AI Upscale Creative",
category="api node/image/Stability AI",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
@ -578,7 +578,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
return IO.Schema(
node_id="StabilityUpscaleFastNode",
display_name="Stability AI Upscale Fast",
category="api node/image/Stability AI",
category="partner/image/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Image.Input("image"),
@ -630,7 +630,7 @@ class StabilityTextToAudio(IO.ComfyNode):
return IO.Schema(
node_id="StabilityTextToAudio",
display_name="Stability AI Text To Audio",
category="api node/audio/Stability AI",
category="partner/audio/Stability AI",
essentials_category="Audio",
description=cleandoc(cls.__doc__ or ""),
inputs=[
@ -708,7 +708,7 @@ class StabilityAudioToAudio(IO.ComfyNode):
return IO.Schema(
node_id="StabilityAudioToAudio",
display_name="Stability AI Audio To Audio",
category="api node/audio/Stability AI",
category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(
@ -802,7 +802,7 @@ class StabilityAudioInpaint(IO.ComfyNode):
return IO.Schema(
node_id="StabilityAudioInpaint",
display_name="Stability AI Audio Inpaint",
category="api node/audio/Stability AI",
category="partner/audio/Stability AI",
description=cleandoc(cls.__doc__ or ""),
inputs=[
IO.Combo.Input(

View File

@ -52,7 +52,7 @@ class TopazImageEnhance(IO.ComfyNode):
return IO.Schema(
node_id="TopazImageEnhance",
display_name="Topaz Image Enhance",
category="api node/image/Topaz",
category="partner/image/Topaz",
description="Industry-standard upscaling and image enhancement.",
inputs=[
IO.Combo.Input("model", options=["Reimagine"]),
@ -235,7 +235,7 @@ class TopazVideoEnhance(IO.ComfyNode):
return IO.Schema(
node_id="TopazVideoEnhance",
display_name="Topaz Video Enhance (Legacy)",
category="api node/video/Topaz",
category="partner/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),
@ -475,7 +475,7 @@ class TopazVideoEnhanceV2(IO.ComfyNode):
return IO.Schema(
node_id="TopazVideoEnhanceV2",
display_name="Topaz Video Enhance",
category="api node/video/Topaz",
category="partner/video/Topaz",
description="Breathe new life into video with powerful upscaling and recovery technology.",
inputs=[
IO.Video.Input("video"),

View File

@ -11,6 +11,9 @@ from comfy_api_nodes.apis.tripo import (
TripoModelVersion,
TripoMultiviewToModelRequest,
TripoOrientation,
TripoP1ImageToModelRequest,
TripoP1MultiviewToModelRequest,
TripoP1TextToModelRequest,
TripoRefineModelRequest,
TripoStyle,
TripoTaskResponse,
@ -80,7 +83,7 @@ class TripoTextToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoTextToModelNode",
display_name="Tripo: Text to Model",
category="api node/3d/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.String.Input("prompt", multiline=True),
IO.String.Input("negative_prompt", multiline=True, optional=True),
@ -93,10 +96,22 @@ class TripoTextToModelNode(IO.ComfyNode):
IO.Int.Input("image_seed", default=42, optional=True, advanced=True),
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"texture_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True, advanced=True),
IO.Boolean.Input("quad", default=False, optional=True, advanced=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"geometry_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -195,7 +210,7 @@ class TripoImageToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoImageToModelNode",
display_name="Tripo: Image to Model",
category="api node/3d/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.Image.Input("image"),
IO.Combo.Input(
@ -209,16 +224,36 @@ class TripoImageToModelNode(IO.ComfyNode):
IO.Boolean.Input("pbr", default=True, optional=True),
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
IO.Combo.Input(
"orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True, advanced=True
"orientation",
options=TripoOrientation,
default=TripoOrientation.DEFAULT,
optional=True,
advanced=True,
),
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
"texture_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
IO.Combo.Input(
"texture_alignment",
default="original_image",
options=["original_image", "geometry"],
optional=True,
advanced=True,
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True),
IO.Boolean.Input("quad", default=False, optional=True, advanced=True),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"geometry_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -323,7 +358,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoMultiviewToModelNode",
display_name="Tripo: Multiview to Model",
category="api node/3d/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.Image.Input("image"),
IO.Image.Input("image_left", optional=True),
@ -346,13 +381,35 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
IO.Boolean.Input("pbr", default=True, optional=True),
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
"texture_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
IO.Combo.Input(
"texture_alignment",
default="original_image",
options=["original_image", "geometry"],
optional=True,
advanced=True,
),
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True),
IO.Boolean.Input("quad", default=False, optional=True, advanced=True, tooltip="This parameter is deprecated and does nothing."),
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Boolean.Input(
"quad",
default=False,
optional=True,
advanced=True,
tooltip="This parameter is deprecated and does nothing.",
),
IO.Combo.Input(
"geometry_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -461,15 +518,25 @@ class TripoTextureNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoTextureNode",
display_name="Tripo: Texture model",
category="api node/3d/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.Custom("MODEL_TASK_ID").Input("model_task_id"),
IO.Boolean.Input("texture", default=True, optional=True),
IO.Boolean.Input("pbr", default=True, optional=True),
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
IO.Combo.Input(
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
"texture_quality",
default="standard",
options=["standard", "detailed"],
optional=True,
advanced=True,
),
IO.Combo.Input(
"texture_alignment",
default="original_image",
options=["original_image", "geometry"],
optional=True,
advanced=True,
),
],
outputs=[
@ -528,7 +595,7 @@ class TripoRefineNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoRefineNode",
display_name="Tripo: Refine Draft model",
category="api node/3d/Tripo",
category="partner/3d/Tripo",
description="Refine a draft model created by v1.4 Tripo models only.",
inputs=[
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
@ -568,7 +635,7 @@ class TripoRigNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoRigNode",
display_name="Tripo: Rig model",
category="api node/3d/Tripo",
category="partner/3d/Tripo",
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
@ -605,7 +672,7 @@ class TripoRetargetNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoRetargetNode",
display_name="Tripo: Retarget rigged model",
category="api node/3d/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.Custom("RIG_TASK_ID").Input("original_model_task_id"),
IO.Combo.Input(
@ -626,7 +693,7 @@ class TripoRetargetNode(IO.ComfyNode):
"preset:hexapod:walk",
"preset:octopod:walk",
"preset:serpentine:march",
"preset:aquatic:march"
"preset:aquatic:march",
],
),
],
@ -670,7 +737,7 @@ class TripoConversionNode(IO.ComfyNode):
return IO.Schema(
node_id="TripoConversionNode",
display_name="Tripo: Convert model",
category="api node/3d/Tripo",
category="partner/3d/Tripo",
inputs=[
IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"),
IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]),
@ -817,7 +884,7 @@ class TripoConversionNode(IO.ComfyNode):
# Parse part_names from comma-separated string to list
part_names_list = None
if part_names and part_names.strip():
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
part_names_list = [name.strip() for name in part_names.split(",") if name.strip()]
response = await sync_op(
cls,
@ -848,6 +915,373 @@ class TripoConversionNode(IO.ComfyNode):
return await poll_until_finished(cls, response, average_duration=30)
def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str:
return (
"("
" $mode := widgets.output_mode;"
' $detailed := $lookup(widgets, "output_mode.texture_quality") = "detailed";'
f' $credits := $mode = "geometry only" ? {geometry_credits} : ($detailed ? {detailed_credits} : {textured_credits});'
' {"type":"usd","usd": $credits * 0.01, "format": {"approximate": true}}'
")"
)
def _p1_textured_inputs(*, include_image_alignment: bool) -> list:
"""Inputs shown inside the 'Textured' branch of the P1 output_mode DynamicCombo."""
inputs: list = [
IO.Boolean.Input("pbr", default=True, tooltip="Include PBR maps. When on, base texture is forced on too."),
IO.Combo.Input("texture_quality", options=["standard", "detailed"], default="standard"),
]
if include_image_alignment:
inputs.extend(
[
IO.Combo.Input(
"texture_alignment",
options=["original_image", "geometry"],
default="original_image",
tooltip="Prioritize visual fidelity to the source image, or alignment to the mesh geometry.",
),
IO.Combo.Input(
"orientation",
options=["default", "align_image"],
default="default",
tooltip="Rotate the output to match the source image. Only applies when textured.",
),
]
)
inputs.append(IO.Int.Input("texture_seed", default=42, advanced=True))
return inputs
def _build_p1_output_mode(*, include_image_alignment: bool) -> IO.DynamicCombo.Input:
return IO.DynamicCombo.Input(
"output_mode",
options=[
IO.DynamicCombo.Option("Geometry only", []),
IO.DynamicCombo.Option("Textured", _p1_textured_inputs(include_image_alignment=include_image_alignment)),
],
tooltip='"Geometry only" returns an untextured mesh. "Textured" adds color/PBR maps.',
)
def _resolve_p1_texture_fields(output_mode: dict) -> dict:
"""Translate the output_mode DynamicCombo payload into P1 request fields.
pbr=true forces texture=true server-side, but we send both explicitly so the
intent is visible in the request body and logs.
"""
mode = output_mode["output_mode"]
if mode == "Geometry only":
return {"texture": False, "pbr": False}
out = {
"texture": True,
"pbr": bool(output_mode.get("pbr", True)),
"texture_quality": output_mode.get("texture_quality", "standard"),
"texture_seed": output_mode.get("texture_seed"),
}
if "texture_alignment" in output_mode:
out["texture_alignment"] = output_mode["texture_alignment"]
if "orientation" in output_mode:
out["orientation"] = output_mode["orientation"]
return out
def _p1_common_inputs() -> list:
"""Inputs shared by all P1 nodes (placed after output_mode)."""
return [
IO.Int.Input(
"face_limit",
default=-1,
min=-1,
max=20000,
optional=True,
advanced=True,
tooltip="Target face count, 48-20000. -1 lets Tripo pick adaptively.",
),
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
IO.Boolean.Input(
"auto_size",
default=False,
optional=True,
advanced=True,
tooltip="Scale the output to approximate real-world meters.",
),
IO.Boolean.Input(
"export_uv",
default=True,
optional=True,
advanced=True,
tooltip="UV unwrap during generation. Turn off for faster geometry-only runs.",
),
IO.Boolean.Input(
"compress_geometry",
default=False,
optional=True,
advanced=True,
tooltip="Apply geometry-based compression. Decompress before editing.",
),
]
def _build_p1_request_kwargs(
*,
output_mode: dict,
face_limit: int,
model_seed: int,
auto_size: bool,
export_uv: bool,
compress_geometry: bool,
) -> dict:
"""Common P1 request fields shared by all three node types."""
kwargs: dict = {
"model_seed": model_seed,
"face_limit": face_limit if face_limit != -1 else None,
"auto_size": auto_size,
"export_uv": export_uv,
"compress": "geometry" if compress_geometry else None,
}
kwargs.update(_resolve_p1_texture_fields(output_mode))
return kwargs
class TripoP1TextToModelNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoP1TextToModelNode",
display_name="Tripo P1: Text to Model",
category="partner/3d/Tripo",
description="Tripo P1 text-to-3D. Optimized for low-poly, game-ready meshes with stable topology.",
inputs=[
IO.String.Input("prompt", multiline=True, tooltip="Up to 1024 characters."),
IO.String.Input("negative_prompt", multiline=True, optional=True, tooltip="Up to 255 characters."),
_build_p1_output_mode(include_image_alignment=False),
IO.Int.Input("image_seed", default=42, optional=True, advanced=True),
*_p1_common_inputs(),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]),
expr=_p1_price_expr(geometry_credits=30, textured_credits=40, detailed_credits=50),
),
)
@classmethod
async def execute(
cls,
prompt: str,
output_mode: dict,
negative_prompt: str | None = None,
image_seed: int | None = None,
face_limit: int = -1,
model_seed: int | None = None,
auto_size: bool = False,
export_uv: bool = True,
compress_geometry: bool = False,
) -> IO.NodeOutput:
if not prompt:
raise RuntimeError("Prompt is required")
common = _build_p1_request_kwargs(
output_mode=output_mode,
face_limit=face_limit,
model_seed=model_seed,
auto_size=auto_size,
export_uv=export_uv,
compress_geometry=compress_geometry,
)
request = TripoP1TextToModelRequest(
prompt=prompt,
negative_prompt=negative_prompt or None,
image_seed=image_seed,
**common,
)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=request,
)
return await poll_until_finished(cls, response, average_duration=60)
class TripoP1ImageToModelNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoP1ImageToModelNode",
display_name="Tripo P1: Image to Model",
category="partner/3d/Tripo",
description="Tripo P1 image-to-3D. Optimized for low-poly, game-ready meshes.",
inputs=[
IO.Image.Input("image"),
_build_p1_output_mode(include_image_alignment=True),
IO.Boolean.Input(
"enable_image_autofix",
default=False,
optional=True,
advanced=True,
tooltip="Pre-process the input image for better generation quality.",
),
*_p1_common_inputs(),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]),
expr=_p1_price_expr(geometry_credits=40, textured_credits=50, detailed_credits=60),
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
output_mode: dict,
enable_image_autofix: bool = False,
face_limit: int = -1,
model_seed: int | None = None,
auto_size: bool = False,
export_uv: bool = True,
compress_geometry: bool = False,
) -> IO.NodeOutput:
if image is None:
raise RuntimeError("Image is required")
tripo_file = TripoFileReference(
root=TripoUrlReference(
url=(await upload_images_to_comfyapi(cls, image, max_images=1))[0],
type="jpeg",
)
)
common = _build_p1_request_kwargs(
output_mode=output_mode,
face_limit=face_limit,
model_seed=model_seed,
auto_size=auto_size,
export_uv=export_uv,
compress_geometry=compress_geometry,
)
request = TripoP1ImageToModelRequest(
file=tripo_file,
enable_image_autofix=enable_image_autofix,
**common,
)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=request,
)
return await poll_until_finished(cls, response, average_duration=60)
class TripoP1MultiviewToModelNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoP1MultiviewToModelNode",
display_name="Tripo P1: Multiview to Model",
category="partner/3d/Tripo",
description="Tripo P1 multiview-to-3D from 2-4 reference images in [front, left, back, right] order. "
"Front is required; any combination of the other three may be omitted.",
inputs=[
IO.Image.Input("image", tooltip="Front view (0°). Required."),
IO.Image.Input(
"image_left",
optional=True,
tooltip="Left view (90°), i.e. the subject's left side.",
),
IO.Image.Input("image_back", optional=True, tooltip="Back view (180°)."),
IO.Image.Input(
"image_right",
optional=True,
tooltip="Right view (270°), i.e. the subject's right side.",
),
_build_p1_output_mode(include_image_alignment=True),
*_p1_common_inputs(),
],
outputs=[
IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
IO.File3DGLB.Output(display_name="GLB"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]),
expr=_p1_price_expr(geometry_credits=40, textured_credits=50, detailed_credits=60),
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
output_mode: dict,
image_left: Input.Image | None = None,
image_back: Input.Image | None = None,
image_right: Input.Image | None = None,
face_limit: int = -1,
model_seed: int | None = None,
auto_size: bool = False,
export_uv: bool = True,
compress_geometry: bool = False,
) -> IO.NodeOutput:
views = [image, image_left, image_back, image_right]
if sum(1 for v in views if v is not None) < 2:
raise RuntimeError("Tripo P1 multiview requires at least 2 images (front plus one of left/back/right).")
files: list[TripoFileReference] = []
for view in views:
if view is None:
files.append(TripoFileReference(root=TripoFileEmptyReference()))
continue
url = (await upload_images_to_comfyapi(cls, view, max_images=1))[0]
files.append(TripoFileReference(root=TripoUrlReference(url=url, type="jpeg")))
common = _build_p1_request_kwargs(
output_mode=output_mode,
face_limit=face_limit,
model_seed=model_seed,
auto_size=auto_size,
export_uv=export_uv,
compress_geometry=compress_geometry,
)
request = TripoP1MultiviewToModelRequest(files=files, **common)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
response_model=TripoTaskResponse,
data=request,
)
return await poll_until_finished(cls, response, average_duration=80)
class TripoExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -855,6 +1289,9 @@ class TripoExtension(ComfyExtension):
TripoTextToModelNode,
TripoImageToModelNode,
TripoMultiviewToModelNode,
TripoP1TextToModelNode,
TripoP1ImageToModelNode,
TripoP1MultiviewToModelNode,
TripoTextureNode,
TripoRefineNode,
TripoRigNode,

View File

@ -45,7 +45,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="VeoVideoGenerationNode",
display_name="Google Veo 2 Video Generation",
category="api node/video/Veo",
category="partner/video/Veo",
description="Generates videos from text prompts using Google's Veo 2 API",
inputs=[
IO.String.Input(
@ -256,7 +256,7 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
return IO.Schema(
node_id="Veo3VideoGenerationNode",
display_name="Google Veo 3 Video Generation",
category="api node/video/Veo",
category="partner/video/Veo",
description="Generates videos from text prompts using Google's Veo 3 API",
inputs=[
IO.String.Input(
@ -468,7 +468,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
return IO.Schema(
node_id="Veo3FirstLastFrameNode",
display_name="Google Veo 3 First-Last-Frame to Video",
category="api node/video/Veo",
category="partner/video/Veo",
description="Generate video using prompt and first and last frames.",
inputs=[
IO.String.Input(

View File

@ -71,7 +71,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduTextToVideoNode",
display_name="Vidu Text To Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate video from a text prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -169,7 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduImageToVideoNode",
display_name="Vidu Image To Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate video from image and optional prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -273,7 +273,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduReferenceVideoNode",
display_name="Vidu Reference To Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate video from multiple images and a prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -388,7 +388,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduStartEndToVideoNode",
display_name="Vidu Start End To Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate a video from start and end frames and a prompt",
inputs=[
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
@ -492,7 +492,7 @@ class Vidu2TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2TextToVideoNode",
display_name="Vidu2 Text-to-Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate video from a text prompt",
inputs=[
IO.Combo.Input("model", options=["viduq2"]),
@ -584,7 +584,7 @@ class Vidu2ImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2ImageToVideoNode",
display_name="Vidu2 Image-to-Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate a video from an image and an optional prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
@ -714,7 +714,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2ReferenceVideoNode",
display_name="Vidu2 Reference-to-Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate a video from multiple reference images and a prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2"]),
@ -849,7 +849,7 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu2StartEndToVideoNode",
display_name="Vidu2 Start/End Frame-to-Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
@ -969,7 +969,7 @@ class ViduExtendVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduExtendVideoNode",
display_name="Vidu Video Extension",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Extend an existing video by generating additional frames.",
inputs=[
IO.DynamicCombo.Input(
@ -1138,7 +1138,7 @@ class ViduMultiFrameVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="ViduMultiFrameVideoNode",
display_name="Vidu Multi-Frame Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate a video with multiple keyframe transitions.",
inputs=[
IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]),
@ -1284,7 +1284,7 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu3TextToVideoNode",
display_name="Vidu Q3 Text-to-Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate video from a text prompt.",
inputs=[
IO.DynamicCombo.Input(
@ -1429,7 +1429,7 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu3ImageToVideoNode",
display_name="Vidu Q3 Image-to-Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate a video from an image and an optional prompt.",
inputs=[
IO.DynamicCombo.Input(
@ -1571,7 +1571,7 @@ class Vidu3StartEndToVideoNode(IO.ComfyNode):
return IO.Schema(
node_id="Vidu3StartEndToVideoNode",
display_name="Vidu Q3 Start/End Frame-to-Video Generation",
category="api node/video/Vidu",
category="partner/video/Vidu",
description="Generate a video from a start frame, an end frame, and a prompt.",
inputs=[
IO.DynamicCombo.Input(

View File

@ -61,7 +61,7 @@ class WanTextToImageApi(IO.ComfyNode):
return IO.Schema(
node_id="WanTextToImageApi",
display_name="Wan Text to Image",
category="api node/image/Wan",
category="partner/image/Wan",
description="Generates an image based on a text prompt.",
inputs=[
IO.Combo.Input(
@ -184,7 +184,7 @@ class WanImageToImageApi(IO.ComfyNode):
return IO.Schema(
node_id="WanImageToImageApi",
display_name="Wan Image to Image",
category="api node/image/Wan",
category="partner/image/Wan",
description="Generates an image from one or two input images and a text prompt. "
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
inputs=[
@ -312,7 +312,7 @@ class WanTextToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="WanTextToVideoApi",
display_name="Wan Text to Video",
category="api node/video/Wan",
category="partner/video/Wan",
description="Generates a video based on a text prompt.",
inputs=[
IO.Combo.Input(
@ -495,7 +495,7 @@ class WanImageToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="WanImageToVideoApi",
display_name="Wan Image to Video",
category="api node/video/Wan",
category="partner/video/Wan",
description="Generates a video from the first frame and a text prompt.",
inputs=[
IO.Combo.Input(
@ -674,7 +674,7 @@ class WanReferenceVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="WanReferenceVideoApi",
display_name="Wan Reference to Video",
category="api node/video/Wan",
category="partner/video/Wan",
description="Use the character and voice from input videos, combined with a prompt, "
"to generate a new video that maintains character consistency.",
inputs=[
@ -828,7 +828,7 @@ class Wan2TextToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2TextToVideoApi",
display_name="Wan 2.7 Text to Video",
category="api node/video/Wan",
category="partner/video/Wan",
description="Generates a video based on a text prompt using the Wan 2.7 model.",
inputs=[
IO.DynamicCombo.Input(
@ -981,7 +981,7 @@ class Wan2ImageToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2ImageToVideoApi",
display_name="Wan 2.7 Image to Video",
category="api node/video/Wan",
category="partner/video/Wan",
description="Generate a video from a first-frame image, with optional last-frame image and audio.",
inputs=[
IO.DynamicCombo.Input(
@ -1152,7 +1152,7 @@ class Wan2VideoContinuationApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2VideoContinuationApi",
display_name="Wan 2.7 Video Continuation",
category="api node/video/Wan",
category="partner/video/Wan",
description="Continue a video from where it left off, with optional last-frame control.",
inputs=[
IO.DynamicCombo.Input(
@ -1319,7 +1319,7 @@ class Wan2VideoEditApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2VideoEditApi",
display_name="Wan 2.7 Video Edit",
category="api node/video/Wan",
category="partner/video/Wan",
description="Edit a video using text instructions, reference images, or style transfer.",
inputs=[
IO.DynamicCombo.Input(
@ -1477,7 +1477,7 @@ class Wan2ReferenceVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="Wan2ReferenceVideoApi",
display_name="Wan 2.7 Reference to Video",
category="api node/video/Wan",
category="partner/video/Wan",
description="Generate a video featuring a person or object from reference materials. "
"Supports single-character performances and multi-character interactions.",
inputs=[
@ -1651,7 +1651,7 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseTextToVideoApi",
display_name="HappyHorse Text to Video",
category="api node/video/Wan",
category="partner/video/Wan",
description="Generates a video based on a text prompt using the HappyHorse model.",
inputs=[
IO.DynamicCombo.Input(
@ -1775,7 +1775,7 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseImageToVideoApi",
display_name="HappyHorse Image to Video",
category="api node/video/Wan",
category="partner/video/Wan",
description="Generate a video from a first-frame image using the HappyHorse model.",
inputs=[
IO.DynamicCombo.Input(
@ -1905,7 +1905,7 @@ class HappyHorseVideoEditApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseVideoEditApi",
display_name="HappyHorse Video Edit",
category="api node/video/Wan",
category="partner/video/Wan",
description="Edit a video using text instructions or reference images with the HappyHorse model. "
"Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.",
inputs=[
@ -2046,7 +2046,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
return IO.Schema(
node_id="HappyHorseReferenceVideoApi",
display_name="HappyHorse Reference to Video",
category="api node/video/Wan",
category="partner/video/Wan",
description="Generate a video featuring a person or object from reference materials with the HappyHorse "
"model. Supports single-character performances and multi-character interactions.",
inputs=[

View File

@ -27,7 +27,7 @@ class WavespeedFlashVSRNode(IO.ComfyNode):
return IO.Schema(
node_id="WavespeedFlashVSRNode",
display_name="FlashVSR Video Upscale",
category="api node/video/WaveSpeed",
category="partner/video/WaveSpeed",
description="Fast, high-quality video upscaler that "
"boosts resolution and restores clarity for low-resolution or blurry footage.",
inputs=[
@ -98,7 +98,7 @@ class WavespeedImageUpscaleNode(IO.ComfyNode):
return IO.Schema(
node_id="WavespeedImageUpscaleNode",
display_name="WaveSpeed Image Upscale",
category="api node/image/WaveSpeed",
category="partner/image/WaveSpeed",
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
inputs=[
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),

View File

@ -86,7 +86,7 @@ class _PollUIState:
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"]
async def sync_op(

View File

@ -469,6 +469,11 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
input_container = None
output_container = None
# get_stream_source() is untrimmed, so apply the trim window in this same pass.
# start_time is normalized (>= 0); duration == 0 means "until the end".
start_time, duration = video.get_active_trim_window()
trimming = bool(start_time or duration)
try:
input_source = video.get_stream_source()
input_container = av.open(input_source, mode="r")
@ -487,16 +492,45 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
audio_stream.layout = stream.layout
break
in_video = input_container.streams.video[0]
start_pts = int(start_time / in_video.time_base) if trimming else 0
end_pts = int((start_time + duration) / in_video.time_base) if duration else None
if start_pts:
input_container.seek(start_pts, stream=in_video)
encoded = 0
for frame in input_container.decode(video=0):
if trimming:
if frame.pts is None or frame.pts < start_pts:
continue
if end_pts is not None and frame.pts >= end_pts:
break
frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
# Re-wrap as a fresh frame: dropping irregular source timestamps (VFR/AVI/GIF/...)
# lets the encoder assign clean ones and avoids mp4 muxer errors.
frame = av.VideoFrame.from_ndarray(frame.to_ndarray(format="yuv420p"), format="yuv420p")
for packet in video_stream.encode(frame):
output_container.mux(packet)
encoded += 1
for packet in video_stream.encode():
output_container.mux(packet)
if encoded == 0:
raise ValueError(
f"resize produced no frames (start_time={start_time}, duration={duration} "
"selected nothing from the source)"
)
if audio_stream is not None:
input_container.seek(0)
for audio_frame in input_container.decode(audio=0):
if trimming:
if audio_frame.time is None or audio_frame.time < start_time:
continue
if duration and audio_frame.time > start_time + duration:
break
# Carry odd audio time bases the mp4 muxer rejects; reset pts, encoder assigns clean ones (MP3-in-AVI)
audio_frame.pts = None
for packet in audio_stream.encode(audio_frame):
output_container.mux(packet)
for packet in audio_stream.encode():

View File

@ -11,7 +11,7 @@ class TextEncodeAceStepAudio(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TextEncodeAceStepAudio",
category="conditioning",
category="model/conditioning",
inputs=[
IO.Clip.Input("clip"),
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
@ -33,7 +33,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode):
def define_schema(cls):
return IO.Schema(
node_id="TextEncodeAceStepAudio1.5",
category="conditioning",
category="model/conditioning",
inputs=[
IO.Clip.Input("clip"),
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
@ -67,7 +67,7 @@ class EmptyAceStepLatentAudio(IO.ComfyNode):
return IO.Schema(
node_id="EmptyAceStepLatentAudio",
display_name="Empty Ace Step 1.0 Latent Audio",
category="latent/audio",
category="model/latent/audio",
inputs=[
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
IO.Int.Input(
@ -90,7 +90,7 @@ class EmptyAceStep15LatentAudio(IO.ComfyNode):
return IO.Schema(
node_id="EmptyAceStep1.5LatentAudio",
display_name="Empty Ace Step 1.5 Latent Audio",
category="latent/audio",
category="model/latent/audio",
inputs=[
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
IO.Int.Input(

View File

@ -45,7 +45,7 @@ class SamplerLCMUpscale(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SamplerLCMUpscale",
category="sampling/samplers",
category="model/sampling/samplers",
inputs=[
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01, advanced=True),
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1, advanced=True),
@ -91,7 +91,7 @@ class SamplerLCM(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="SamplerLCM",
category="sampling/samplers",
category="model/sampling/samplers",
description=("LCM sampler with tunable per-step noise. s_noise is a multiplier on the model's training noise scale"),
inputs=[
io.Float.Input("s_noise", default=1.0, min=0.0, max=64.0, step=0.01,

Some files were not shown because too many files have changed in this diff Show More