mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-10 00:07:33 +08:00
Merge branch 'master' into matt/asset-image-dimensions-metadata
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
This commit is contained in:
commit
7a2594ac5c
@ -1,5 +1,4 @@
|
||||
As of the time of writing this you need this driver for best results:
|
||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
|
||||
As of the time of writing this you need a recent driver. Updating to the latest driver is recommended.
|
||||
|
||||
HOW TO RUN:
|
||||
|
||||
@ -7,9 +6,9 @@ If you have a AMD gpu:
|
||||
|
||||
run_amd_gpu.bat
|
||||
|
||||
If you have memory issues you can try disabling the smart memory management by running comfyui with:
|
||||
If you have memory issues you can try enabling the new dynamic memory management by running comfyui with:
|
||||
|
||||
run_amd_gpu_disable_smart_memory.bat
|
||||
run_amd_gpu_enable_dynamic_vram.bat
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
|
||||
2
.github/workflows/check-line-endings.yml
vendored
2
.github/workflows/check-line-endings.yml
vendored
@ -17,7 +17,7 @@ jobs:
|
||||
- name: Check for Windows line endings (CRLF)
|
||||
run: |
|
||||
# Get the list of changed files in the PR
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} -- ':!.ci')
|
||||
|
||||
# Flag to track if CRLF is found
|
||||
CRLF_FOUND=false
|
||||
|
||||
24
.github/workflows/detect-unreviewed-merge.yml
vendored
Normal file
24
.github/workflows/detect-unreviewed-merge.yml
vendored
Normal file
@ -0,0 +1,24 @@
|
||||
name: Detect Unreviewed Merge
|
||||
|
||||
# SOC 2 compliance — reusable workflow lives in Comfy-Org/github-workflows,
|
||||
# tracking issues are filed in Comfy-Org/unreviewed-merges.
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [master]
|
||||
|
||||
concurrency:
|
||||
group: detect-unreviewed-merge-${{ github.sha }}
|
||||
cancel-in-progress: false
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
pull-requests: read
|
||||
|
||||
jobs:
|
||||
detect:
|
||||
uses: Comfy-Org/github-workflows/.github/workflows/detect-unreviewed-merge.yml@4d9cb6b87f953bb7cd69954280e1465fb9bd2040 # v1
|
||||
with:
|
||||
approval-mode: latest-per-reviewer
|
||||
secrets:
|
||||
UNREVIEWED_MERGES_TOKEN: ${{ secrets.UNREVIEWED_MERGES_TOKEN }}
|
||||
@ -105,7 +105,7 @@ class WindowAttention(nn.Module):
|
||||
|
||||
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view(
|
||||
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
relative_position_bias = comfy.ops.cast_to_input(relative_position_bias.permute(2, 0, 1).contiguous(), attn) # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if mask is not None:
|
||||
|
||||
@ -55,12 +55,7 @@ class BackgroundRemovalModel():
|
||||
out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False)
|
||||
|
||||
mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1] != 1:
|
||||
mask = mask.movedim(-1, 1)
|
||||
|
||||
return mask
|
||||
return mask.squeeze(1) # (B, 1, H, W) -> (B, H, W)
|
||||
|
||||
|
||||
def load_background_removal_model(sd):
|
||||
|
||||
@ -149,6 +149,7 @@ parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=Non
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
|
||||
parser.add_argument("--fast-disk", action="store_true", help="Prefer disk-backed dynamic loading and offload over unpinned RAM. Can be faster for users with fast NVME disks.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
import comfy.image_encoders.dino3
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@ -23,12 +24,16 @@ IMAGE_ENCODERS = {
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel,
|
||||
}
|
||||
|
||||
class ClipVisionModel():
|
||||
def __init__(self, json_config):
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
if isinstance(json_config, dict):
|
||||
config = json_config
|
||||
else:
|
||||
with open(json_config) as f:
|
||||
config = json.load(f)
|
||||
|
||||
self.image_size = config.get("image_size", 224)
|
||||
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
|
||||
@ -134,6 +139,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
||||
elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers)
|
||||
json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
@ -1,5 +1,20 @@
|
||||
import logging
|
||||
|
||||
import torch
|
||||
|
||||
_CK_STOCHASTIC_ROUNDING_AVAILABLE = False
|
||||
try:
|
||||
import comfy_kitchen as ck
|
||||
_ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8
|
||||
_CK_STOCHASTIC_ROUNDING_AVAILABLE = True
|
||||
except (AttributeError, ImportError):
|
||||
logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.")
|
||||
|
||||
if not _CK_STOCHASTIC_ROUNDING_AVAILABLE:
|
||||
def _ck_stochastic_rounding_fp8(value, rng, dtype):
|
||||
raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding")
|
||||
|
||||
|
||||
def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None):
|
||||
mantissa_scaled = torch.where(
|
||||
normal_mask,
|
||||
@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0):
|
||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||
generator = torch.Generator(device=value.device)
|
||||
generator.manual_seed(seed)
|
||||
if _CK_STOCHASTIC_ROUNDING_AVAILABLE:
|
||||
rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator)
|
||||
return _ck_stochastic_rounding_fp8(value, rng, dtype)
|
||||
|
||||
output = torch.empty_like(value, dtype=dtype)
|
||||
num_slices = max(1, (value.numel() / (4096 * 4096)))
|
||||
slice_size = max(1, round(value.shape[0] / num_slices))
|
||||
|
||||
259
comfy/image_encoders/dino3.py
Normal file
259
comfy/image_encoders/dino3.py
Normal file
@ -0,0 +1,259 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.ops
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
||||
|
||||
|
||||
# DINOv3 ViT-H/16+ (SwiGLU)
|
||||
DINOV3_VITH_CONFIG = {
|
||||
"model_type": "dinov3",
|
||||
"num_hidden_layers": 32,
|
||||
"hidden_size": 1280,
|
||||
"num_attention_heads": 20,
|
||||
"num_register_tokens": 4,
|
||||
"intermediate_size": 5120,
|
||||
"layer_norm_eps": 1e-5,
|
||||
"num_channels": 3,
|
||||
"patch_size": 16,
|
||||
"rope_theta": 100.0,
|
||||
"use_gated_mlp": True,
|
||||
"gated_mlp_act": "silu",
|
||||
"image_size": 1024,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225],
|
||||
}
|
||||
|
||||
|
||||
class DINOv3ViTMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.up_proj(x)))
|
||||
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
|
||||
num_tokens = q.shape[-2]
|
||||
num_patches = sin.shape[-2]
|
||||
num_prefix_tokens = num_tokens - num_patches
|
||||
|
||||
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
|
||||
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
|
||||
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
|
||||
|
||||
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
|
||||
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
|
||||
|
||||
return q, k
|
||||
|
||||
|
||||
class DINOv3ViTAttention(nn.Module):
|
||||
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.embed_dim = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
|
||||
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs):
|
||||
batch_size, patches, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
attn = optimized_attention_for_device(query_states.device, mask=False)
|
||||
attn_output = attn(
|
||||
query_states, key_states, value_states, self.num_heads, attention_mask,
|
||||
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
return attn_output
|
||||
|
||||
|
||||
class DINOv3ViTGatedMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.SiLU() if act == "silu" else torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device):
|
||||
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
|
||||
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
|
||||
coords_h = coords_h / num_patches_h
|
||||
coords_w = coords_w / num_patches_w
|
||||
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
|
||||
coords = coords.flatten(0, 1)
|
||||
coords = 2.0 * coords - 1.0
|
||||
return coords
|
||||
|
||||
|
||||
class DINOv3ViTRopePositionEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor
|
||||
|
||||
def __init__(self, rope_theta, hidden_size, num_attention_heads, patch_size, device, dtype):
|
||||
super().__init__()
|
||||
self.base = rope_theta
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.patch_size = patch_size
|
||||
|
||||
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, pixel_values):
|
||||
_, _, height, width = pixel_values.shape
|
||||
num_patches_h = height // self.patch_size
|
||||
num_patches_w = width // self.patch_size
|
||||
|
||||
patch_coords = get_patches_center_coordinates(num_patches_h, num_patches_w, dtype=torch.float32, device=pixel_values.device)
|
||||
self.inv_freq = self.inv_freq.to(pixel_values.device)
|
||||
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
|
||||
angles = angles.flatten(1, 2)
|
||||
angles = angles.tile(2)
|
||||
cos = torch.cos(angles).to(dtype=pixel_values.dtype)
|
||||
sin = torch.sin(angles).to(dtype=pixel_values.dtype)
|
||||
return cos, sin
|
||||
|
||||
|
||||
class DINOv3ViTEmbeddings(nn.Module):
|
||||
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
|
||||
self.patch_embeddings = operations.Conv2d(
|
||||
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, pixel_values, bool_masked_pos=None):
|
||||
batch_size = pixel_values.shape[0]
|
||||
|
||||
patch_embeddings = self.patch_embeddings(pixel_values)
|
||||
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
if bool_masked_pos is not None:
|
||||
mask_token = comfy.ops.cast_to_input(self.mask_token, patch_embeddings)
|
||||
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
|
||||
|
||||
cls_token = comfy.ops.cast_to_input(self.cls_token.expand(batch_size, -1, -1), patch_embeddings)
|
||||
register_tokens = comfy.ops.cast_to_input(self.register_tokens.expand(batch_size, -1, -1), patch_embeddings)
|
||||
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
|
||||
return embeddings
|
||||
|
||||
|
||||
class DINOv3ViTLayer(nn.Module):
|
||||
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size,
|
||||
num_attention_heads, device, dtype, operations, gated_mlp_act="silu"):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
if use_gated_mlp:
|
||||
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations, act=gated_mlp_act)
|
||||
else:
|
||||
self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.attention(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
|
||||
hidden_states = self.layer_scale1(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.layer_scale2(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DINOv3ViTModel(nn.Module):
|
||||
def __init__(self, config, dtype, device, operations):
|
||||
super().__init__()
|
||||
num_hidden_layers = config["num_hidden_layers"]
|
||||
hidden_size = config["hidden_size"]
|
||||
num_attention_heads = config["num_attention_heads"]
|
||||
num_register_tokens = config["num_register_tokens"]
|
||||
intermediate_size = config["intermediate_size"]
|
||||
layer_norm_eps = config["layer_norm_eps"]
|
||||
num_channels = config["num_channels"]
|
||||
patch_size = config["patch_size"]
|
||||
rope_theta = config["rope_theta"]
|
||||
use_gated_mlp = config.get("use_gated_mlp", False)
|
||||
gated_mlp_act = config.get("gated_mlp_act", "silu")
|
||||
|
||||
self.embeddings = DINOv3ViTEmbeddings(
|
||||
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
|
||||
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
|
||||
)
|
||||
self.layer = nn.ModuleList([
|
||||
DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, mlp_bias=True,
|
||||
intermediate_size=intermediate_size, num_attention_heads=num_attention_heads,
|
||||
dtype=dtype, device=device, operations=operations, gated_mlp_act=gated_mlp_act)
|
||||
for _ in range(num_hidden_layers)])
|
||||
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
|
||||
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
for layer_module in self.layer:
|
||||
hidden_states = layer_module(hidden_states, position_embeddings=position_embeddings)
|
||||
|
||||
if kwargs.get("skip_norm_elementwise", False):
|
||||
sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:])
|
||||
else:
|
||||
norm = self.norm.to(hidden_states.device)
|
||||
sequence_output = norm(hidden_states)
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
return sequence_output, None, pooled_output, None
|
||||
@ -4,6 +4,7 @@ class LatentFormat:
|
||||
scale_factor = 1.0
|
||||
latent_channels = 4
|
||||
latent_dimensions = 2
|
||||
preserve_empty_channel_multiples = False
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
@ -239,6 +240,16 @@ class Flux2(LatentFormat):
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class TripoSplat(LatentFormat):
|
||||
# Sequence latent (B, 8192, 16) the camera token rides alongside as a second nested latent
|
||||
latent_channels = 16
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
@ -769,6 +780,10 @@ class ACEAudio(LatentFormat):
|
||||
latent_channels = 8
|
||||
latent_dimensions = 2
|
||||
|
||||
class SeedVR2(LatentFormat):
|
||||
latent_channels = 16
|
||||
preserve_empty_channel_multiples = True
|
||||
|
||||
class ACEAudio15(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@ -433,11 +433,11 @@ class Attention(nn.Module):
|
||||
if self.differential:
|
||||
q, q_diff = q.unbind(dim=1)
|
||||
k, k_diff = k.unbind(dim=1)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
out = out - out_diff
|
||||
else:
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options)
|
||||
out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options)
|
||||
|
||||
out = self.to_out(out)
|
||||
|
||||
|
||||
@ -138,11 +138,11 @@ class Attention(nn.Module):
|
||||
k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype)
|
||||
|
||||
if self.differential:
|
||||
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
|
||||
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True))
|
||||
out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
|
||||
- optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True, low_precision_attention=False))
|
||||
del q, k, v, q_diff, k_diff
|
||||
else:
|
||||
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
|
||||
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)
|
||||
del q, k, v
|
||||
|
||||
return self.to_out(out)
|
||||
|
||||
@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams):
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
use_x0: bool
|
||||
# Use sequential txt_ids instead of zeros
|
||||
use_sequential_txt_ids: bool
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
@ -162,6 +164,9 @@ class ChromaRadiance(Chroma):
|
||||
if params.use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
if params.use_sequential_txt_ids:
|
||||
self.register_buffer("__sequential__", torch.tensor([]))
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
@ -313,6 +318,9 @@ class ChromaRadiance(Chroma):
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
# Radiance after 2026-05-22 uses sequential txt_ids instead of zeros
|
||||
if params.use_sequential_txt_ids:
|
||||
txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1)
|
||||
|
||||
img_out = self.forward_orig(
|
||||
img,
|
||||
|
||||
@ -14,15 +14,7 @@ from torchvision import transforms
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
def apply_rotary_pos_emb(
|
||||
t: torch.Tensor,
|
||||
freqs: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float()
|
||||
t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1]
|
||||
t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t)
|
||||
return t_out
|
||||
import comfy.quant_ops
|
||||
|
||||
|
||||
# ---------------------- Feed Forward Network -----------------------
|
||||
@ -173,8 +165,7 @@ class Attention(nn.Module):
|
||||
k = self.k_norm(k)
|
||||
v = self.v_norm(v)
|
||||
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
|
||||
q = apply_rotary_pos_emb(q, rope_emb)
|
||||
k = apply_rotary_pos_emb(k, rope_emb)
|
||||
q, k = comfy.quant_ops.ck.apply_rope_split_half(q, k, rope_emb)
|
||||
return q, k, v
|
||||
|
||||
q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb)
|
||||
|
||||
@ -5,6 +5,7 @@ import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import comfy.quant_ops
|
||||
|
||||
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
@ -19,15 +20,6 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
out = torch.stack([torch.cos(out), torch.sin(out)], dim=0)
|
||||
return out.to(dtype=torch.float32, device=pos.device)
|
||||
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
rot_dim = freqs_cis.shape[-1]
|
||||
x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:]
|
||||
cos_ = freqs_cis[0]
|
||||
sin_ = freqs_cis[1]
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
x_rotated = torch.cat((-x2, x1), dim=-1)
|
||||
return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1)
|
||||
|
||||
class ErnieImageEmbedND3(nn.Module):
|
||||
def __init__(self, dim: int, theta: int, axes_dim: tuple):
|
||||
super().__init__()
|
||||
@ -37,8 +29,16 @@ class ErnieImageEmbedND3(nn.Module):
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1)
|
||||
emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2]
|
||||
return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim]
|
||||
cos_ = emb[0]
|
||||
sin_ = emb[1]
|
||||
N = cos_.shape[-1]
|
||||
half = N // 2
|
||||
cos_top = cos_[..., :half].repeat_interleave(2, dim=-1)
|
||||
sin_top = sin_[..., :half].repeat_interleave(2, dim=-1)
|
||||
cos_bot = cos_[..., half:].repeat_interleave(2, dim=-1)
|
||||
sin_bot = sin_[..., half:].repeat_interleave(2, dim=-1)
|
||||
rot = torch.stack([cos_top, -sin_top, sin_bot, cos_bot], dim=-1)
|
||||
return rot.reshape(*rot.shape[:-1], 2, 2).unsqueeze(2)
|
||||
|
||||
class ErnieImagePatchEmbedDynamic(nn.Module):
|
||||
def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None):
|
||||
@ -115,8 +115,7 @@ class ErnieImageAttention(nn.Module):
|
||||
key = self.norm_k(key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
query = apply_rotary_emb(query, image_rotary_emb)
|
||||
key = apply_rotary_emb(key, image_rotary_emb)
|
||||
query, key = comfy.quant_ops.ck.apply_rope_split_half(query, key, image_rotary_emb)
|
||||
|
||||
q_flat = query.reshape(B, S, -1)
|
||||
k_flat = key.reshape(B, S, -1)
|
||||
@ -274,7 +273,7 @@ class ErnieImageModel(nn.Module):
|
||||
|
||||
image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1)
|
||||
|
||||
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype)
|
||||
rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1))
|
||||
del image_ids, text_ids
|
||||
|
||||
sample = self.time_proj(timesteps).to(dtype)
|
||||
|
||||
@ -4,7 +4,7 @@ from torch import Tensor
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
import comfy.model_management
|
||||
import logging
|
||||
import comfy.quant_ops
|
||||
|
||||
|
||||
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
|
||||
@ -44,21 +44,15 @@ def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
|
||||
|
||||
try:
|
||||
import comfy.quant_ops
|
||||
q_apply_rope = comfy.quant_ops.ck.apply_rope
|
||||
q_apply_rope1 = comfy.quant_ops.ck.apply_rope1
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope(xq, xk, freqs_cis)
|
||||
else:
|
||||
return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis)
|
||||
def apply_rope1(x, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope1(x, freqs_cis)
|
||||
else:
|
||||
return q_apply_rope1(x, freqs_cis)
|
||||
except:
|
||||
logging.warning("No comfy kitchen, using old apply_rope functions.")
|
||||
apply_rope = _apply_rope
|
||||
apply_rope1 = _apply_rope1
|
||||
def apply_rope(xq, xk, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope(xq, xk, freqs_cis)
|
||||
else:
|
||||
return comfy.quant_ops.ck.apply_rope(xq, xk, freqs_cis)
|
||||
|
||||
|
||||
def apply_rope1(x, freqs_cis):
|
||||
if comfy.model_management.in_training:
|
||||
return _apply_rope1(x, freqs_cis)
|
||||
else:
|
||||
return comfy.quant_ops.ck.apply_rope1(x, freqs_cis)
|
||||
|
||||
297
comfy/ldm/ideogram4/model.py
Normal file
297
comfy/ldm/ideogram4/model.py
Normal file
@ -0,0 +1,297 @@
|
||||
"""
|
||||
The Ideogram 4 transformer is a NextDiT/Lumina2-family single-stream model
|
||||
consumes Qwen3-VL hidden-state features (concatenated from 13 layers -> 53248 dims)
|
||||
packs ``[text tokens, image tokens]`` into one sequence with block-diagonal segment attention and 3D interleaved MRoPE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.patcher_extension
|
||||
from comfy.ldm.lumina.model import FeedForward
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.text_encoders.llama import apply_rope, precompute_freqs_cis
|
||||
|
||||
# Per-token role indicators
|
||||
SEQUENCE_PADDING_INDICATOR = -1
|
||||
OUTPUT_IMAGE_INDICATOR = 2
|
||||
LLM_TOKEN_INDICATOR = 3
|
||||
# Image grid coordinates are offset so they never collide with text positions
|
||||
IMAGE_POSITION_OFFSET = 65536
|
||||
|
||||
|
||||
class Ideogram4Attention(nn.Module):
|
||||
def __init__(self, hidden_size, num_heads, eps=1e-5, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
self.qkv = operations.Linear(hidden_size, hidden_size * 3, bias=False, dtype=dtype, device=device)
|
||||
self.norm_q = operations.RMSNorm(self.head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||
self.norm_k = operations.RMSNorm(self.head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||
self.o = operations.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, attn_mask, freqs_cis, transformer_options={}):
|
||||
batch_size, seq_len, _ = x.shape
|
||||
qkv = self.qkv(x).view(batch_size, seq_len, 3, self.num_heads, self.head_dim)
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
|
||||
q = self.norm_q(q)
|
||||
k = self.norm_k(k)
|
||||
|
||||
# (B, heads, L, head_dim)
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
q, k = apply_rope(q, k, freqs_cis)
|
||||
|
||||
out = optimized_attention_masked(q, k, v, self.num_heads, attn_mask, skip_reshape=True, transformer_options=transformer_options)
|
||||
return self.o(out)
|
||||
|
||||
|
||||
class Ideogram4TransformerBlock(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, num_heads, norm_eps, adaln_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.attention = Ideogram4Attention(hidden_size, num_heads, eps=1e-5, dtype=dtype, device=device, operations=operations)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=hidden_size, hidden_dim=intermediate_size, multiple_of=1, ffn_dim_multiplier=None,
|
||||
operation_settings={"operations": operations, "dtype": dtype, "device": device},
|
||||
)
|
||||
|
||||
self.attention_norm1 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||
self.ffn_norm1 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||
self.attention_norm2 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||
self.ffn_norm2 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device)
|
||||
|
||||
self.adaln_modulation = operations.Linear(adaln_dim, 4 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, attn_mask, freqs_cis, adaln_input, transformer_options={}):
|
||||
mod = self.adaln_modulation(adaln_input)
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.chunk(4, dim=-1)
|
||||
gate_msa = torch.tanh(gate_msa)
|
||||
gate_mlp = torch.tanh(gate_mlp)
|
||||
scale_msa = 1.0 + scale_msa
|
||||
scale_mlp = 1.0 + scale_mlp
|
||||
|
||||
attn_out = self.attention(self.attention_norm1(x) * scale_msa, attn_mask, freqs_cis, transformer_options=transformer_options)
|
||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
|
||||
return x
|
||||
|
||||
|
||||
def _sinusoidal_embedding(t, dim, scale=1e4):
|
||||
t = t.to(torch.float32)
|
||||
half = dim // 2
|
||||
freq = math.log(scale) / (half - 1)
|
||||
freq = torch.exp(torch.arange(half, dtype=torch.float32, device=t.device) * -freq)
|
||||
emb = t.unsqueeze(-1) * freq
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
||||
if dim % 2 == 1:
|
||||
emb = F.pad(emb, (0, 1))
|
||||
return emb
|
||||
|
||||
|
||||
class Ideogram4EmbedScalar(nn.Module):
|
||||
def __init__(self, dim, input_range=(0.0, 1.0), dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.range_min, self.range_max = input_range
|
||||
self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.to(torch.float32)
|
||||
scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min)
|
||||
emb = _sinusoidal_embedding(scaled, self.dim)
|
||||
emb = emb.to(self.mlp_in.weight.dtype)
|
||||
emb = F.silu(self.mlp_in(emb))
|
||||
return self.mlp_out(emb)
|
||||
|
||||
|
||||
class Ideogram4FinalLayer(nn.Module):
|
||||
def __init__(self, hidden_size, out_channels, adaln_dim, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaln_modulation = operations.Linear(adaln_dim, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, c):
|
||||
scale = 1.0 + self.adaln_modulation(F.silu(c))
|
||||
return self.linear(self.norm_final(x) * scale)
|
||||
|
||||
|
||||
class Ideogram4Transformer(nn.Module):
|
||||
"""A single Ideogram 4 backbone operating on a packed token sequence."""
|
||||
|
||||
def __init__(self, emb_dim, num_layers, num_heads, intermediate_size, adaln_dim,
|
||||
in_channels, llm_features_dim, rope_theta, mrope_section, norm_eps,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.head_dim = emb_dim // num_heads
|
||||
self.rope_theta = rope_theta
|
||||
self.mrope_section = tuple(mrope_section)
|
||||
|
||||
self.input_proj = operations.Linear(in_channels, emb_dim, bias=True, dtype=dtype, device=device)
|
||||
self.llm_cond_norm = operations.RMSNorm(llm_features_dim, eps=1e-6, elementwise_affine=True, dtype=dtype, device=device)
|
||||
self.llm_cond_proj = operations.Linear(llm_features_dim, emb_dim, bias=True, dtype=dtype, device=device)
|
||||
self.t_embedding = Ideogram4EmbedScalar(emb_dim, input_range=(0.0, 1.0), dtype=dtype, device=device, operations=operations)
|
||||
self.adaln_proj = operations.Linear(emb_dim, adaln_dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
self.embed_image_indicator = operations.Embedding(2, emb_dim, dtype=dtype, device=device)
|
||||
|
||||
self.layers = nn.ModuleList([
|
||||
Ideogram4TransformerBlock(emb_dim, intermediate_size, num_heads, norm_eps, adaln_dim,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_layers)
|
||||
])
|
||||
|
||||
self.final_layer = Ideogram4FinalLayer(emb_dim, in_channels, adaln_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def _backbone(self, llm_features, x, t, position_ids, attn_mask, indicator, transformer_options={}):
|
||||
indicator = indicator.to(torch.long)
|
||||
output_image_mask = (indicator == OUTPUT_IMAGE_INDICATOR).to(x.dtype).unsqueeze(-1)
|
||||
|
||||
x = x * output_image_mask
|
||||
h = self.input_proj(x) * output_image_mask
|
||||
|
||||
t_cond = self.t_embedding(t)
|
||||
if t.dim() == 1:
|
||||
t_cond = t_cond.unsqueeze(1)
|
||||
adaln_input = F.silu(self.adaln_proj(t_cond))
|
||||
|
||||
# h is zero on the text rows (content lives only on image rows), add writes the text features in place
|
||||
if llm_features is not None:
|
||||
L_text = llm_features.shape[1]
|
||||
text_mask = (indicator[:, :L_text] == LLM_TOKEN_INDICATOR).to(x.dtype).unsqueeze(-1)
|
||||
llm = self.llm_cond_norm(llm_features * text_mask)
|
||||
llm = self.llm_cond_proj(llm) * text_mask
|
||||
h[:, :L_text] = h[:, :L_text] + llm
|
||||
|
||||
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long), out_dtype=h.dtype)
|
||||
|
||||
# Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch).
|
||||
freqs_cis = precompute_freqs_cis(
|
||||
self.head_dim, position_ids[0].transpose(0, 1), self.rope_theta,
|
||||
rope_dims=self.mrope_section, interleaved_mrope=True, device=position_ids.device,
|
||||
)
|
||||
|
||||
if attn_mask is not None and attn_mask.dtype == torch.bool:
|
||||
attn_mask = torch.zeros_like(attn_mask, dtype=h.dtype).masked_fill_(~attn_mask, -torch.finfo(h.dtype).max)
|
||||
|
||||
for layer in self.layers:
|
||||
h = layer(h, attn_mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
|
||||
return self.final_layer(h, adaln_input)
|
||||
|
||||
|
||||
class Ideogram4Transformer2DModel(Ideogram4Transformer):
|
||||
"""Ideogram 4 single-stream DiT.
|
||||
|
||||
Runs a packed ``[text, image]`` sequence when text context is supplied, or an image-only sequence when ``context is None``.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, in_channels=128, num_layers=34, num_attention_heads=18, attention_head_dim=256, intermediate_size=12288,
|
||||
adaln_dim=512, llm_features_dim=53248, rope_theta=5000000, mrope_section=(24, 20, 20), norm_eps=1e-5,
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
emb_dim = num_attention_heads * attention_head_dim
|
||||
super().__init__(
|
||||
emb_dim=emb_dim, num_layers=num_layers, num_heads=num_attention_heads,
|
||||
intermediate_size=intermediate_size, adaln_dim=adaln_dim, in_channels=in_channels,
|
||||
llm_features_dim=llm_features_dim, rope_theta=rope_theta, mrope_section=mrope_section,
|
||||
norm_eps=norm_eps, dtype=dtype, device=device, operations=operations)
|
||||
self.dtype = dtype
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
# 128-dim token = patch (2x2) * ae_channels (32).
|
||||
self.patch_size = 2
|
||||
self.ae_channels = in_channels // (self.patch_size * self.patch_size)
|
||||
|
||||
def _img_to_tokens(self, x):
|
||||
B, C, gh, gw = x.shape
|
||||
x = x.view(B, self.ae_channels, self.patch_size, self.patch_size, gh, gw)
|
||||
x = x.permute(0, 4, 5, 2, 3, 1) # (B, gh, gw, pi, pj, c)
|
||||
return x.reshape(B, gh * gw, C)
|
||||
|
||||
def _tokens_to_img(self, tokens, gh, gw):
|
||||
B = tokens.shape[0]
|
||||
C = tokens.shape[-1]
|
||||
x = tokens.reshape(B, gh, gw, self.patch_size, self.patch_size, self.ae_channels)
|
||||
x = x.permute(0, 5, 3, 4, 1, 2) # (B, c, pi, pj, gh, gw)
|
||||
return x.reshape(B, C, gh, gw)
|
||||
|
||||
def _image_position_ids(self, gh, gw, device):
|
||||
h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1)
|
||||
w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1)
|
||||
t_idx = torch.zeros_like(h_idx)
|
||||
return torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET # (L_img, 3)
|
||||
|
||||
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
|
||||
B = x_chunk.shape[0]
|
||||
device = x_chunk.device
|
||||
img_tokens = self._img_to_tokens(x_chunk)
|
||||
L_img = img_tokens.shape[1]
|
||||
L_text = context_chunk.shape[1]
|
||||
L = L_text + L_img
|
||||
latent_dim = img_tokens.shape[-1]
|
||||
|
||||
x_full = torch.zeros(B, L, latent_dim, dtype=img_tokens.dtype, device=device)
|
||||
x_full[:, L_text:] = img_tokens
|
||||
|
||||
text_pos = torch.arange(L_text, device=device).view(-1, 1).expand(L_text, 3)
|
||||
img_pos = self._image_position_ids(gh, gw, device)
|
||||
position_ids = torch.cat([text_pos, img_pos], dim=0).unsqueeze(0).expand(B, L, 3)
|
||||
|
||||
indicator = torch.empty(B, L, dtype=torch.long, device=device)
|
||||
indicator[:, :L_text] = LLM_TOKEN_INDICATOR
|
||||
indicator[:, L_text:] = OUTPUT_IMAGE_INDICATOR
|
||||
|
||||
attn_mask = None
|
||||
if attn_mask_chunk is not None:
|
||||
segment_ids = torch.ones(B, L, dtype=torch.long, device=device)
|
||||
pad = (attn_mask_chunk == 0)
|
||||
segment_ids[:, :L_text][pad] = SEQUENCE_PADDING_INDICATOR
|
||||
indicator[:, :L_text][pad] = 0
|
||||
# Block-diagonal mask from segment ids: (B, 1, L, L), True = attend.
|
||||
attn_mask = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).unsqueeze(1)
|
||||
|
||||
out = self._backbone(context_chunk, x_full, t_chunk, position_ids, attn_mask, indicator,
|
||||
transformer_options=transformer_options)
|
||||
return self._tokens_to_img(out[:, L_text:], gh, gw)
|
||||
|
||||
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
|
||||
B = x_chunk.shape[0]
|
||||
device = x_chunk.device
|
||||
img_tokens = self._img_to_tokens(x_chunk)
|
||||
L_img = img_tokens.shape[1]
|
||||
|
||||
position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)
|
||||
indicator = torch.full((B, L_img), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device)
|
||||
|
||||
# Image-only sequence is a single segment -> no mask, full attention, no LLM context.
|
||||
out = self._backbone(None, img_tokens, t_chunk, position_ids, None, indicator, transformer_options=transformer_options)
|
||||
return self._tokens_to_img(out, gh, gw)
|
||||
|
||||
def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options),
|
||||
).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs):
|
||||
bs, c, gh, gw = x.shape
|
||||
|
||||
timesteps = 1.0 - timesteps
|
||||
|
||||
# unconditional pass
|
||||
if context is None:
|
||||
return -self._run_image_only(x, timesteps, gh, gw, transformer_options)
|
||||
|
||||
return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options)
|
||||
@ -735,7 +735,86 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
)
|
||||
return out
|
||||
|
||||
def _var_attention_qkv(q, k, v, heads, skip_reshape):
|
||||
if skip_reshape:
|
||||
return q, k, v, q.shape[-1]
|
||||
total_tokens, embed_dim = q.shape
|
||||
head_dim = embed_dim // heads
|
||||
return (
|
||||
q.view(total_tokens, heads, head_dim),
|
||||
k.view(k.shape[0], heads, head_dim),
|
||||
v.view(v.shape[0], heads, head_dim),
|
||||
head_dim,
|
||||
)
|
||||
|
||||
|
||||
def _var_attention_output(out, heads, head_dim, skip_output_reshape):
|
||||
if skip_output_reshape:
|
||||
return out
|
||||
return out.reshape(-1, heads * head_dim)
|
||||
|
||||
|
||||
def _use_blackwell_attention():
|
||||
device = model_management.get_torch_device()
|
||||
if device.type != "cuda":
|
||||
return False
|
||||
major, minor = torch.cuda.get_device_capability(device)
|
||||
return (major, minor) >= (12, 0)
|
||||
|
||||
|
||||
def _validate_split_cu_seqlens(name, cu_seqlens, token_count):
|
||||
if cu_seqlens.dtype not in (torch.int32, torch.int64):
|
||||
raise ValueError(f"{name} must use an integer dtype")
|
||||
if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2:
|
||||
raise ValueError(f"{name} must be a 1D tensor with at least two offsets")
|
||||
if cu_seqlens[0].item() != 0:
|
||||
raise ValueError(f"{name} must start at 0")
|
||||
if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item():
|
||||
raise ValueError(f"{name} must be strictly increasing")
|
||||
if cu_seqlens[-1].item() != token_count:
|
||||
raise ValueError(f"{name} does not match token count")
|
||||
|
||||
|
||||
def _split_indices(cu_seqlens):
|
||||
return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long)
|
||||
|
||||
|
||||
def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape)
|
||||
|
||||
_validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0])
|
||||
_validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0])
|
||||
if cu_seqlens_k[-1].item() != v.shape[0]:
|
||||
raise ValueError("cu_seqlens_k does not match v token count")
|
||||
|
||||
q_split_indices = _split_indices(cu_seqlens_q)
|
||||
k_split_indices = _split_indices(cu_seqlens_k)
|
||||
q_splits = torch.tensor_split(q, q_split_indices, dim=0)
|
||||
k_splits = torch.tensor_split(k, k_split_indices, dim=0)
|
||||
v_splits = torch.tensor_split(v, k_split_indices, dim=0)
|
||||
if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits):
|
||||
raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count")
|
||||
|
||||
out = []
|
||||
for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits):
|
||||
q_i = q_i.permute(1, 0, 2).unsqueeze(0)
|
||||
k_i = k_i.permute(1, 0, 2).unsqueeze(0)
|
||||
v_i = v_i.permute(1, 0, 2).unsqueeze(0)
|
||||
out_dtype = q_i.dtype
|
||||
if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16):
|
||||
q_i = q_i.to(torch.bfloat16)
|
||||
k_i = k_i.to(torch.bfloat16)
|
||||
v_i = v_i.to(torch.bfloat16)
|
||||
out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True)
|
||||
if out_i.dtype != out_dtype:
|
||||
out_i = out_i.to(out_dtype)
|
||||
out.append(out_i.squeeze(0).permute(1, 0, 2))
|
||||
|
||||
out = torch.cat(out, dim=0)
|
||||
return _var_attention_output(out, heads, head_dim, skip_output_reshape)
|
||||
|
||||
|
||||
optimized_var_attention = var_attention_optimized_split
|
||||
optimized_attention = attention_basic
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
@ -758,6 +837,8 @@ else:
|
||||
logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
|
||||
logging.info("Using optimized_attention split-loop for variable-length attention")
|
||||
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
|
||||
@ -773,6 +854,7 @@ if model_management.xformers_enabled():
|
||||
register_attention_function("pytorch", attention_pytorch)
|
||||
register_attention_function("sub_quad", attention_sub_quad)
|
||||
register_attention_function("split", attention_split)
|
||||
register_attention_function("var_attention_optimized_split", var_attention_optimized_split)
|
||||
|
||||
|
||||
def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||
@ -1209,5 +1291,3 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
x = self.proj_out(x)
|
||||
out = x + x_in
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ if model_management.xformers_enabled_vae():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
|
||||
def torch_cat_if_needed(xl, dim):
|
||||
xl = [x for x in xl if x is not None and x.shape[dim] > 0]
|
||||
if len(xl) > 1:
|
||||
@ -22,7 +23,8 @@ def torch_cat_if_needed(xl, dim):
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
@ -33,11 +35,13 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = math.log(10000) / (half_dim - downscale_freq_shift)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if flip_sin_to_cos:
|
||||
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||
return emb
|
||||
|
||||
@ -207,8 +207,9 @@ class PidNet(PixDiT_T2I):
|
||||
f"Flux1/SD3 = 16 channels, Flux2 = 128 channels."
|
||||
)
|
||||
B = x.shape[0]
|
||||
Hs = x.shape[2] // self.patch_size
|
||||
Ws = x.shape[3] // self.patch_size
|
||||
# Match the backbone's pad_to_patch_size (round up) so the LQ grid lines up with the patch stream.
|
||||
Hs = -(-x.shape[2] // self.patch_size)
|
||||
Ws = -(-x.shape[3] // self.patch_size)
|
||||
|
||||
degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1)
|
||||
if degrade_sigma.numel() == 1 and B > 1:
|
||||
|
||||
@ -51,15 +51,6 @@ class FeedForward(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape)
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
340
comfy/ldm/seedvr/color_fix.py
Normal file
340
comfy/ldm/seedvr/color_fix.py
Normal file
@ -0,0 +1,340 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
from comfy.ldm.seedvr.model import safe_pad_operation
|
||||
from comfy.ldm.seedvr.vae import safe_interpolate_operation
|
||||
from comfy.ldm.seedvr.constants import (
|
||||
CIELAB_DELTA,
|
||||
CIELAB_KAPPA,
|
||||
D65_WHITE_X,
|
||||
D65_WHITE_Z,
|
||||
WAVELET_DECOMP_LEVELS,
|
||||
)
|
||||
|
||||
|
||||
def wavelet_blur(image: Tensor, radius):
|
||||
max_safe_radius = max(1, min(image.shape[-2:]) // 8)
|
||||
if radius > max_safe_radius:
|
||||
radius = max_safe_radius
|
||||
|
||||
num_channels = image.shape[1]
|
||||
|
||||
kernel_vals = [
|
||||
[0.0625, 0.125, 0.0625],
|
||||
[0.125, 0.25, 0.125],
|
||||
[0.0625, 0.125, 0.0625],
|
||||
]
|
||||
kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device)
|
||||
kernel = kernel[None, None].repeat(num_channels, 1, 1, 1)
|
||||
|
||||
image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate')
|
||||
output = F.conv2d(image, kernel, groups=num_channels, dilation=radius)
|
||||
|
||||
return output
|
||||
|
||||
def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS):
|
||||
high_freq = torch.zeros_like(image)
|
||||
|
||||
for i in range(levels):
|
||||
radius = 2 ** i
|
||||
low_freq = wavelet_blur(image, radius)
|
||||
high_freq.add_(image).sub_(low_freq)
|
||||
image = low_freq
|
||||
|
||||
return high_freq, low_freq
|
||||
|
||||
def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
|
||||
if content_feat.shape != style_feat.shape:
|
||||
# Resize style to match content spatial dimensions
|
||||
if len(content_feat.shape) >= 3:
|
||||
# safe_interpolate_operation handles FP16 conversion automatically
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
# Decompose both features into frequency components
|
||||
content_high_freq, content_low_freq = wavelet_decomposition(content_feat)
|
||||
del content_low_freq # Free memory immediately
|
||||
|
||||
style_high_freq, style_low_freq = wavelet_decomposition(style_feat)
|
||||
del style_high_freq # Free memory immediately
|
||||
|
||||
if content_high_freq.shape != style_low_freq.shape:
|
||||
style_low_freq = safe_interpolate_operation(
|
||||
style_low_freq,
|
||||
size=content_high_freq.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
content_high_freq.add_(style_low_freq)
|
||||
|
||||
return content_high_freq.clamp_(-1.0, 1.0)
|
||||
|
||||
def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor:
|
||||
original_shape = source.shape
|
||||
|
||||
# Flatten
|
||||
source_flat = source.flatten()
|
||||
reference_flat = reference.flatten()
|
||||
|
||||
# Sort both arrays
|
||||
source_sorted, source_indices = torch.sort(source_flat)
|
||||
reference_sorted, _ = torch.sort(reference_flat)
|
||||
del reference_flat
|
||||
|
||||
# Quantile mapping
|
||||
n_source = len(source_sorted)
|
||||
n_reference = len(reference_sorted)
|
||||
|
||||
if n_source == n_reference:
|
||||
matched_sorted = reference_sorted
|
||||
else:
|
||||
# Interpolate reference to match source quantiles
|
||||
source_quantiles = torch.linspace(0, 1, n_source, device=device)
|
||||
ref_indices = (source_quantiles * (n_reference - 1)).long()
|
||||
ref_indices.clamp_(0, n_reference - 1)
|
||||
matched_sorted = reference_sorted[ref_indices]
|
||||
del source_quantiles, ref_indices, reference_sorted
|
||||
|
||||
del source_sorted, source_flat
|
||||
|
||||
# Reconstruct using argsort (portable across CUDA/ROCm/MPS)
|
||||
inverse_indices = torch.argsort(source_indices)
|
||||
del source_indices
|
||||
matched_flat = matched_sorted[inverse_indices]
|
||||
del matched_sorted, inverse_indices
|
||||
|
||||
return matched_flat.reshape(original_shape)
|
||||
|
||||
def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of CIELAB images to RGB color space."""
|
||||
L, a, b = lab[:, 0], lab[:, 1], lab[:, 2]
|
||||
|
||||
# LAB to XYZ
|
||||
fy = (L + 16.0) / 116.0
|
||||
fx = a.div(500.0).add_(fy)
|
||||
fz = fy - b / 200.0
|
||||
del L, a, b
|
||||
|
||||
# XYZ transformation
|
||||
x = torch.where(
|
||||
fx > epsilon,
|
||||
torch.pow(fx, 3.0),
|
||||
fx.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
y = torch.where(
|
||||
fy > epsilon,
|
||||
torch.pow(fy, 3.0),
|
||||
fy.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
z = torch.where(
|
||||
fz > epsilon,
|
||||
torch.pow(fz, 3.0),
|
||||
fz.mul(116.0).sub_(16.0).div_(kappa)
|
||||
)
|
||||
del fx, fy, fz
|
||||
|
||||
# Apply D65 white point (in-place)
|
||||
x.mul_(D65_WHITE_X)
|
||||
# y *= 1.00000 # (no-op, skip)
|
||||
z.mul_(D65_WHITE_Z)
|
||||
|
||||
xyz = torch.stack([x, y, z], dim=1)
|
||||
del x, y, z
|
||||
|
||||
# Matrix multiplication: XYZ -> RGB
|
||||
B, C, H, W = xyz.shape
|
||||
xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del xyz
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype)
|
||||
rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T)
|
||||
del xyz_flat
|
||||
|
||||
rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del rgb_linear_flat
|
||||
|
||||
# Apply inverse gamma correction (delinearize)
|
||||
mask = rgb_linear > 0.0031308
|
||||
rgb = torch.where(
|
||||
mask,
|
||||
torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055),
|
||||
rgb_linear * 12.92
|
||||
)
|
||||
del mask, rgb_linear
|
||||
|
||||
return torch.clamp(rgb, 0.0, 1.0)
|
||||
|
||||
def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor:
|
||||
"""Convert batch of RGB images to CIELAB color space using D65 illuminant."""
|
||||
# Apply sRGB gamma correction (linearize)
|
||||
mask = rgb > 0.04045
|
||||
rgb_linear = torch.where(
|
||||
mask,
|
||||
torch.pow((rgb + 0.055) / 1.055, 2.4),
|
||||
rgb / 12.92
|
||||
)
|
||||
del mask
|
||||
|
||||
# Matrix multiplication: RGB -> XYZ
|
||||
B, C, H, W = rgb_linear.shape
|
||||
rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3)
|
||||
del rgb_linear
|
||||
|
||||
# Ensure dtype consistency for matrix multiplication
|
||||
rgb_flat = rgb_flat.to(dtype=matrix.dtype)
|
||||
xyz_flat = torch.matmul(rgb_flat, matrix.T)
|
||||
del rgb_flat
|
||||
|
||||
xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2)
|
||||
del xyz_flat
|
||||
|
||||
# Normalize by D65 white point (in-place)
|
||||
xyz[:, 0].div_(D65_WHITE_X) # X
|
||||
# xyz[:, 1] /= 1.00000 # Y (no-op, skip)
|
||||
xyz[:, 2].div_(D65_WHITE_Z) # Z
|
||||
|
||||
# XYZ to LAB transformation
|
||||
epsilon_cubed = epsilon ** 3
|
||||
mask = xyz > epsilon_cubed
|
||||
f_xyz = torch.where(
|
||||
mask,
|
||||
torch.pow(xyz, 1.0 / 3.0),
|
||||
xyz.mul(kappa).add_(16.0).div_(116.0)
|
||||
)
|
||||
del xyz, mask
|
||||
|
||||
# Extract channels and compute LAB
|
||||
L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100]
|
||||
a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127]
|
||||
b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127]
|
||||
del f_xyz
|
||||
|
||||
return torch.stack([L, a, b], dim=1)
|
||||
|
||||
def lab_color_transfer(
|
||||
content_feat: Tensor,
|
||||
style_feat: Tensor,
|
||||
luminance_weight: float = 0.8
|
||||
) -> Tensor:
|
||||
content_feat = wavelet_reconstruction(content_feat, style_feat)
|
||||
|
||||
if content_feat.shape != style_feat.shape:
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False
|
||||
)
|
||||
|
||||
device = content_feat.device
|
||||
|
||||
def ensure_float32_precision(c):
|
||||
orig_dtype = c.dtype
|
||||
c = c.float()
|
||||
return c, orig_dtype
|
||||
content_feat, original_dtype = ensure_float32_precision(content_feat)
|
||||
style_feat, _ = ensure_float32_precision(style_feat)
|
||||
|
||||
rgb_to_xyz_matrix = torch.tensor([
|
||||
[0.4124564, 0.3575761, 0.1804375],
|
||||
[0.2126729, 0.7151522, 0.0721750],
|
||||
[0.0193339, 0.1191920, 0.9503041]
|
||||
], dtype=torch.float32, device=device)
|
||||
|
||||
xyz_to_rgb_matrix = torch.tensor([
|
||||
[ 3.2404542, -1.5371385, -0.4985314],
|
||||
[-0.9692660, 1.8760108, 0.0415560],
|
||||
[ 0.0556434, -0.2040259, 1.0572252]
|
||||
], dtype=torch.float32, device=device)
|
||||
|
||||
epsilon = CIELAB_DELTA
|
||||
kappa = CIELAB_KAPPA
|
||||
|
||||
content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0)
|
||||
|
||||
# Convert to LAB color space
|
||||
content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del content_feat
|
||||
|
||||
style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa)
|
||||
del style_feat, rgb_to_xyz_matrix
|
||||
|
||||
# Match chrominance channels (a*, b*) for accurate color transfer
|
||||
matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device)
|
||||
matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device)
|
||||
|
||||
# Handle luminance with weighted blending
|
||||
if luminance_weight < 1.0:
|
||||
# Partially match luminance for better overall color accuracy
|
||||
matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device)
|
||||
# Blend: preserve some content L* for detail, adopt some style L* for color
|
||||
result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight))
|
||||
del matched_L
|
||||
else:
|
||||
# Fully preserve content luminance
|
||||
result_L = content_lab[:, 0]
|
||||
|
||||
del content_lab, style_lab
|
||||
|
||||
# Reconstruct LAB with corrected channels
|
||||
result_lab = torch.stack([result_L, matched_a, matched_b], dim=1)
|
||||
del result_L, matched_a, matched_b
|
||||
|
||||
# Convert back to RGB
|
||||
result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa)
|
||||
del result_lab, xyz_to_rgb_matrix
|
||||
|
||||
# Convert back to [-1, 1] range (in-place)
|
||||
result = result_rgb.mul_(2.0).sub_(1.0)
|
||||
del result_rgb
|
||||
|
||||
result = result.to(original_dtype)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor:
|
||||
return wavelet_reconstruction(content_feat, style_feat)
|
||||
|
||||
|
||||
def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor:
|
||||
if content_feat.shape != style_feat.shape:
|
||||
style_feat = safe_interpolate_operation(
|
||||
style_feat,
|
||||
size=content_feat.shape[-2:],
|
||||
mode='bilinear',
|
||||
align_corners=False,
|
||||
)
|
||||
|
||||
original_dtype = content_feat.dtype
|
||||
content_feat = content_feat.float()
|
||||
style_feat = style_feat.float()
|
||||
|
||||
b, c = content_feat.shape[:2]
|
||||
content_flat = content_feat.reshape(b, c, -1)
|
||||
style_flat = style_feat.reshape(b, c, -1)
|
||||
|
||||
content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1)
|
||||
content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
|
||||
style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1)
|
||||
style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1)
|
||||
del content_flat, style_flat
|
||||
|
||||
normalized = (content_feat - content_mean) / content_std
|
||||
del content_mean, content_std
|
||||
result = normalized * style_std + style_mean
|
||||
del normalized, style_mean, style_std
|
||||
|
||||
result = result.clamp_(-1.0, 1.0)
|
||||
if result.dtype != original_dtype:
|
||||
result = result.to(original_dtype)
|
||||
return result
|
||||
79
comfy/ldm/seedvr/constants.py
Normal file
79
comfy/ldm/seedvr/constants.py
Normal file
@ -0,0 +1,79 @@
|
||||
"""Named constants for the SeedVR2 integration, grouped by provenance.
|
||||
|
||||
Provenance prefixes:
|
||||
- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline.
|
||||
- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites
|
||||
the upstream config/source path it was lifted from.
|
||||
- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature /
|
||||
ISO / CIE values; cite the standard.
|
||||
"""
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment)
|
||||
# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN)
|
||||
# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070
|
||||
# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT).
|
||||
# --------------------------------------------------------------------------------------
|
||||
SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB)
|
||||
SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# B. Fork heuristics (SEEDVR2 - this integration)
|
||||
# --------------------------------------------------------------------------------------
|
||||
SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim.
|
||||
# (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.)
|
||||
SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry.
|
||||
SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case).
|
||||
SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM.
|
||||
SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk.
|
||||
SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels).
|
||||
SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16).
|
||||
SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset.
|
||||
|
||||
# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing)
|
||||
SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk.
|
||||
SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path.
|
||||
SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path.
|
||||
SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path.
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR)
|
||||
# --------------------------------------------------------------------------------------
|
||||
BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm.
|
||||
BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift.
|
||||
BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem).
|
||||
BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem).
|
||||
BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28.
|
||||
BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28.
|
||||
BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16).
|
||||
BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32).
|
||||
BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t).
|
||||
BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11.
|
||||
BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size).
|
||||
BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor.
|
||||
BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor).
|
||||
BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range.
|
||||
BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))).
|
||||
BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling).
|
||||
BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames).
|
||||
BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency).
|
||||
BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim).
|
||||
# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function.
|
||||
BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242.
|
||||
BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243.
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# D. Published standards (cite the literature)
|
||||
# --------------------------------------------------------------------------------------
|
||||
ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864.
|
||||
|
||||
# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65).
|
||||
CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta).
|
||||
CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa).
|
||||
D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1).
|
||||
D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn.
|
||||
WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR).
|
||||
|
||||
# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and
|
||||
# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the
|
||||
# exact existing coefficients move verbatim rather than being retyped here.
|
||||
1665
comfy/ldm/seedvr/model.py
Normal file
1665
comfy/ldm/seedvr/model.py
Normal file
File diff suppressed because it is too large
Load Diff
2110
comfy/ldm/seedvr/vae.py
Normal file
2110
comfy/ldm/seedvr/vae.py
Normal file
File diff suppressed because it is too large
Load Diff
199
comfy/ldm/triposplat/gaussian.py
Normal file
199
comfy/ldm/triposplat/gaussian.py
Normal file
@ -0,0 +1,199 @@
|
||||
# TripoSplat 3D gaussian container. Operates on already-decoded
|
||||
# tensors and exposes them as render-ready tensors (render_tensors) for the generic SPLAT type.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
|
||||
|
||||
class GaussianModel:
|
||||
def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
|
||||
scaling_bias: float = 0.01, opacity_bias: float = 0.1,
|
||||
scaling_activation: str = "exp", device=None):
|
||||
self.sh_degree = sh_degree
|
||||
self.mininum_kernel_size = mininum_kernel_size
|
||||
self.scaling_bias = scaling_bias
|
||||
self.opacity_bias = opacity_bias
|
||||
self.device = device
|
||||
self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
|
||||
|
||||
if scaling_activation == "exp":
|
||||
self._scaling_activation = torch.exp
|
||||
self._inverse_scaling_activation = torch.log
|
||||
elif scaling_activation == "softplus":
|
||||
self._scaling_activation = F.softplus
|
||||
self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
|
||||
|
||||
self._opacity_activation = torch.sigmoid
|
||||
self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))
|
||||
|
||||
self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
|
||||
self.rots_bias = torch.zeros(4, device=self.device)
|
||||
self.rots_bias[0] = 1
|
||||
self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
|
||||
|
||||
self._storage = {}
|
||||
|
||||
def _get_store(self, name):
|
||||
return self._storage.get(name)
|
||||
|
||||
def _set_store(self, name, value):
|
||||
self._storage[name] = value
|
||||
|
||||
@property
|
||||
def _xyz(self):
|
||||
return self._get_store("_xyz")
|
||||
@_xyz.setter
|
||||
def _xyz(self, value):
|
||||
if value is None:
|
||||
self._set_store("_xyz", None)
|
||||
self._set_store("xyz", None)
|
||||
return
|
||||
self._set_store("_xyz", value)
|
||||
self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])
|
||||
|
||||
@property
|
||||
def get_xyz(self):
|
||||
return self._get_store("xyz")
|
||||
|
||||
@property
|
||||
def _features_dc(self):
|
||||
return self._get_store("_features_dc")
|
||||
@_features_dc.setter
|
||||
def _features_dc(self, value):
|
||||
self._set_store("_features_dc", value)
|
||||
|
||||
@property
|
||||
def _opacity(self):
|
||||
return self._get_store("_opacity")
|
||||
@_opacity.setter
|
||||
def _opacity(self, value):
|
||||
if value is None:
|
||||
self._set_store("_opacity", None)
|
||||
self._set_store("opacity", None)
|
||||
return
|
||||
self._set_store("_opacity", value)
|
||||
self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))
|
||||
|
||||
@property
|
||||
def get_opacity(self):
|
||||
return self._get_store("opacity")
|
||||
|
||||
@property
|
||||
def _scaling(self):
|
||||
return self._get_store("_scaling")
|
||||
@_scaling.setter
|
||||
def _scaling(self, value):
|
||||
if value is None:
|
||||
self._set_store("_scaling", None)
|
||||
self._set_store("scaling", None)
|
||||
return
|
||||
self._set_store("_scaling", value)
|
||||
s = self._scaling_activation(value + self.scale_bias)
|
||||
s = torch.square(s) + self.mininum_kernel_size ** 2
|
||||
self._set_store("scaling", torch.sqrt(s))
|
||||
|
||||
@property
|
||||
def get_scaling(self):
|
||||
return self._get_store("scaling")
|
||||
|
||||
@property
|
||||
def _rotation(self):
|
||||
return self._get_store("_rotation")
|
||||
@_rotation.setter
|
||||
def _rotation(self, value):
|
||||
self._set_store("_rotation", value)
|
||||
|
||||
_DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
|
||||
|
||||
def render_tensors(self):
|
||||
# Render-ready (activated, world-space) tensors for the generic SPLAT type. The axis transform
|
||||
# (a 3x3 rotation, object frame -> viewer Y-up) is baked into positions and rotations.
|
||||
# Returns float tensors on the intermediate device: positions (N,3), scales (N,3) linear,
|
||||
# rotations (N,4) wxyz, opacities (N,1) in [0,1], sh (N,K,3) coefficients.
|
||||
xyz = self.get_xyz.float()
|
||||
scaling = self.get_scaling.float()
|
||||
opacity = self.get_opacity.float()
|
||||
rotation = (self._rotation + self.rots_bias[None, :]).float()
|
||||
sh = self._features_dc.float() # (N, K, 3)
|
||||
T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device)
|
||||
xyz = xyz @ T.T
|
||||
rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation)))
|
||||
rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True)
|
||||
out_device = comfy.model_management.intermediate_device()
|
||||
return (
|
||||
xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(),
|
||||
rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(),
|
||||
sh.to(out_device).contiguous(),
|
||||
)
|
||||
|
||||
|
||||
def _quat_to_matrix(q):
|
||||
q = q / torch.linalg.norm(q, dim=-1, keepdim=True)
|
||||
w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
|
||||
R = torch.stack([
|
||||
1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
|
||||
2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x),
|
||||
2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y),
|
||||
], dim=-1).reshape(-1, 3, 3)
|
||||
return R
|
||||
|
||||
|
||||
def _matrix_to_quat(R):
|
||||
trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
|
||||
q = torch.zeros((R.shape[0], 4), dtype=R.dtype, device=R.device)
|
||||
s = torch.sqrt(torch.clamp(trace + 1, min=0)) * 2
|
||||
q[:, 0] = 0.25 * s
|
||||
denom = torch.where(s != 0, s, torch.ones_like(s))
|
||||
q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / denom
|
||||
q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / denom
|
||||
q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / denom
|
||||
m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0)
|
||||
s1 = torch.sqrt(torch.clamp(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], min=0)) * 2
|
||||
q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01]
|
||||
q[m01, 1] = 0.25 * s1[m01]
|
||||
q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01]
|
||||
q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01]
|
||||
m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0)
|
||||
s2 = torch.sqrt(torch.clamp(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], min=0)) * 2
|
||||
q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11]
|
||||
q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11]
|
||||
q[m11, 2] = 0.25 * s2[m11]
|
||||
q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11]
|
||||
m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0)
|
||||
s3 = torch.sqrt(torch.clamp(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], min=0)) * 2
|
||||
q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21]
|
||||
q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21]
|
||||
q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21]
|
||||
q[m21, 3] = 0.25 * s3[m21]
|
||||
return q / torch.linalg.norm(q, dim=-1, keepdim=True)
|
||||
|
||||
|
||||
def build_gaussian_models(decoder, points_pred: dict, pred: dict):
|
||||
# Assemble GaussianModels from the elastic decoder layout. decoder is the ElasticGaussianFixedlenDecoder
|
||||
# (carries layout / rep_config / _get_offset)
|
||||
x = points_pred
|
||||
offset = decoder._get_offset(pred['features'])
|
||||
h = pred["features"]
|
||||
ret = []
|
||||
for i in range(h.shape[0]):
|
||||
g = GaussianModel(
|
||||
sh_degree=0,
|
||||
aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
|
||||
mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'],
|
||||
scaling_bias=decoder.rep_config['scaling_bias'],
|
||||
opacity_bias=decoder.rep_config['opacity_bias'],
|
||||
scaling_activation=decoder.rep_config['scaling_activation'],
|
||||
device=h.device,
|
||||
)
|
||||
_x = x["points"][i, :, None, :]
|
||||
for k, v in decoder.layout.items():
|
||||
if k == '_xyz':
|
||||
setattr(g, k, (offset[i] + _x).flatten(0, 1))
|
||||
elif k in ('_xyz_center', '_offset_scale'):
|
||||
continue
|
||||
else:
|
||||
feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
|
||||
setattr(g, k, feats * decoder.rep_config['lr'][k])
|
||||
ret.append(g)
|
||||
return ret
|
||||
326
comfy/ldm/triposplat/model.py
Normal file
326
comfy/ldm/triposplat/model.py
Normal file
@ -0,0 +1,326 @@
|
||||
# TripoSplat flow-matching denoiser (LatentSeqMMFlowModel). Registered as a ModelType.FLOW arch and
|
||||
# driven by the standard KSampler; jointly denoises the (B, 8192, 16) latent and a (B, 1, 5) camera token
|
||||
# carried as a 2-element nested latent.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.rmsnorm
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
|
||||
|
||||
class MultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim, heads, dtype=None, device=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.empty(heads, dim, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x):
|
||||
x = comfy.rmsnorm.rms_norm(x)
|
||||
return x * comfy.model_management.cast_to(self.gamma, x.dtype, x.device)
|
||||
|
||||
|
||||
# Positional embeddings
|
||||
|
||||
class RePo3DRotaryEmbedding(nn.Module):
|
||||
def __init__(self, model_channels, num_heads, head_dim, repo_hidden_ratio=0.125, max_freq=16.0,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = head_dim
|
||||
repo_hidden_size = int(model_channels * repo_hidden_ratio)
|
||||
self.norm = operations.LayerNorm(model_channels, dtype=dtype, device=device)
|
||||
self.gate_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
|
||||
self.content_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
|
||||
self.act = nn.SiLU()
|
||||
self.final_map = operations.Linear(repo_hidden_size, 3 * num_heads, bias=False, dtype=dtype, device=device)
|
||||
self.dim_0 = 2 * (head_dim // 6)
|
||||
self.dim_1 = 2 * (head_dim // 6)
|
||||
self.dim_2 = head_dim - self.dim_0 - self.dim_1
|
||||
dims = [self.dim_0, self.dim_1, self.dim_2]
|
||||
freqs_list = []
|
||||
for d in dims:
|
||||
freq_dim = d // 2
|
||||
freqs_list.append(torch.linspace(1.0, float(max_freq), steps=freq_dim, dtype=torch.float32))
|
||||
self.freqs_0 = nn.Parameter(freqs_list[0])
|
||||
self.freqs_1 = nn.Parameter(freqs_list[1])
|
||||
self.freqs_2 = nn.Parameter(freqs_list[2])
|
||||
|
||||
def forward(self, hidden_states):
|
||||
h = self.norm(hidden_states)
|
||||
feat = self.act(self.gate_map(h)) * self.content_map(h)
|
||||
out = self.final_map(feat)
|
||||
B, L, _ = out.shape
|
||||
delta_pos = out.reshape(B, L, self.num_heads, 3)
|
||||
f0 = comfy.model_management.cast_to(self.freqs_0, torch.float32, out.device)
|
||||
f1 = comfy.model_management.cast_to(self.freqs_1, torch.float32, out.device)
|
||||
f2 = comfy.model_management.cast_to(self.freqs_2, torch.float32, out.device)
|
||||
ang_0 = delta_pos[..., 0].unsqueeze(-1) * f0 * torch.pi
|
||||
ang_1 = delta_pos[..., 1].unsqueeze(-1) * f1 * torch.pi
|
||||
ang_2 = delta_pos[..., 2].unsqueeze(-1) * f2 * torch.pi
|
||||
ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # (B, L, heads, head_dim/2)
|
||||
cos, sin = ang.cos(), ang.sin()
|
||||
return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*ang.shape, 2, 2)
|
||||
|
||||
|
||||
class PcdAbsolutePositionEmbedder(nn.Module):
|
||||
# Sinusoidal absolute position embedding. Two fixed schedules are used in TripoSplat:
|
||||
# "pow2" (flow-model latent anchors) and "log2" (octree / gaussian decoders).
|
||||
def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.in_channels = in_channels
|
||||
self.max_res = max_res
|
||||
self.schedule = schedule
|
||||
self.freq_dim = channels // in_channels // 2
|
||||
|
||||
def _freqs(self, device):
|
||||
if self.schedule == "pow2":
|
||||
freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device)
|
||||
res_dim = max(0, self.freq_dim - self.max_res)
|
||||
freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res
|
||||
if res_dim > 0 else torch.empty(0, device=device))
|
||||
freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim]
|
||||
return torch.pow(2.0, freqs) * 2.0 # *2 folds this schedule's 2*pi into the shared *pi below
|
||||
logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device)
|
||||
return torch.pow(2.0, logs)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
orig_dtype = x.dtype
|
||||
x = x.float()
|
||||
*dims, D = x.shape
|
||||
out = torch.outer(x.reshape(-1), self._freqs(x.device)) * torch.pi
|
||||
out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1)
|
||||
if out.shape[-1] < self.channels:
|
||||
out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1],
|
||||
device=out.device, dtype=out.dtype)], dim=-1)
|
||||
return out.to(orig_dtype)
|
||||
|
||||
|
||||
def attention(q, k, v, transformer_options=None):
|
||||
# q, k, v: (B, L, heads, dim) -> (B, L, heads, dim). Shared optimized_attention call convention.
|
||||
out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads=q.shape[2],
|
||||
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
|
||||
transformer_options=transformer_options)
|
||||
return out.transpose(1, 2)
|
||||
|
||||
|
||||
# Transformer building blocks
|
||||
|
||||
class MLP(nn.Module):
|
||||
def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(in_channels, hidden_channels, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(hidden_channels, out_channels, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.mlp(x)
|
||||
|
||||
|
||||
class RopeMultiHeadAttention(nn.Module):
|
||||
def __init__(self, channels, num_heads, qkv_bias=True, qk_rms_norm=False, use_rope=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = channels // num_heads
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.use_rope = use_rope
|
||||
self.qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if self.qk_rms_norm:
|
||||
self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.out = operations.Linear(channels, channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, rope_emb=None, transformer_options=None):
|
||||
B, L, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim)
|
||||
q, k, v = qkv.unbind(2)
|
||||
if self.use_rope:
|
||||
q, k = apply_rope(q, k, rope_emb)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
h = attention(q, k, v, transformer_options) # (B, L, heads, dim)
|
||||
return self.out(h.reshape(B, L, C))
|
||||
|
||||
|
||||
class UnifiedTransformerBlock(nn.Module):
|
||||
def __init__(self, channels, num_heads, mlp_ratio=4.0,
|
||||
use_rope=False, qk_rms_norm=False, qkv_bias=True,
|
||||
modulation=True, share_mod=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.modulation = modulation
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = RopeMultiHeadAttention(channels, num_heads=num_heads,
|
||||
qkv_bias=qkv_bias, use_rope=use_rope, qk_rms_norm=qk_rms_norm,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
if modulation:
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
||||
self.shift_table = nn.Parameter(torch.empty(1, 6 * channels, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, mod=None, rotary_emb=None, transformer_options=None):
|
||||
if self.modulation:
|
||||
if not self.share_mod:
|
||||
mod = self.adaLN_modulation(mod)
|
||||
mod = mod + comfy.model_management.cast_to(self.shift_table, mod.dtype, mod.device)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.attn(h, rope_emb=rotary_emb, transformer_options=transformer_options), gate_msa.unsqueeze(1))
|
||||
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
|
||||
else:
|
||||
x = x + self.attn(self.norm1(x), rope_emb=rotary_emb, transformer_options=transformer_options)
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
emb = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
return self.mlp(emb.to(self.mlp[0].weight.dtype))
|
||||
|
||||
|
||||
class LatentSeqMMFlowModel(nn.Module):
|
||||
def __init__(self, image_model=None, q_token_length=8192, in_channels=16, model_channels=1024,
|
||||
cond_channels=1280, out_channels=16, num_blocks=24, num_refiner_blocks=2,
|
||||
num_heads=None, num_head_channels=64, cam_channels=5, cond2_channels=128,
|
||||
mlp_ratio=4, share_mod=True, qk_rms_norm=True,
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
self.q_token_length = q_token_length
|
||||
self.in_channels = in_channels
|
||||
self.cam_channels = cam_channels
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.cond2_channels = cond2_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_refiner_blocks = num_refiner_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.share_mod = share_mod
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
|
||||
factory_kwargs = dict(dtype=dtype, device=device)
|
||||
op_kwargs = dict(operations=operations, **factory_kwargs)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels, **op_kwargs)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs))
|
||||
|
||||
self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs)
|
||||
self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs)
|
||||
self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None
|
||||
|
||||
# Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding.
|
||||
# The embedder is parameter-free and the anchors are fixed, precompute once.
|
||||
sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length)
|
||||
pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0))
|
||||
self.register_buffer("pos_emb", pos_emb, persistent=False)
|
||||
|
||||
# RePo3DRotaryEmbedding layers for the refiner and main blocks
|
||||
repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs)
|
||||
self.noise_repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.context_repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.repo_layers = nn.ModuleList(
|
||||
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)])
|
||||
|
||||
# Refiner blocks
|
||||
block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs)
|
||||
self.noise_refiner = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)])
|
||||
self.context_refiner = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)])
|
||||
|
||||
self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_blocks)])
|
||||
|
||||
self.shift_table = nn.Parameter(torch.empty(1, 2, model_channels, **factory_kwargs))
|
||||
self.out_layer = operations.Linear(model_channels, out_channels, **factory_kwargs)
|
||||
self.cam_out_layer = operations.Linear(model_channels, cam_channels, **factory_kwargs)
|
||||
|
||||
def forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, t, context, ref_latents, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
|
||||
# x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)].
|
||||
# context == feature1.
|
||||
z, camera = x[0], x[1]
|
||||
feat1 = context
|
||||
|
||||
h_x = self.input_layer(z)
|
||||
h_cond = self.cond_embedder(feat1)
|
||||
if ref_latents is not None and self.cond_embedder2 is not None:
|
||||
# Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad to feat1's length
|
||||
# (the pad count = feat1's prefix tokens: DINOv3 cls + registers), then add to the context.
|
||||
feat2 = ref_latents[0].flatten(2).transpose(1, 2)
|
||||
feat2 = F.pad(feat2, (0, 0, feat1.shape[1] - feat2.shape[1], 0))
|
||||
h_cond = h_cond + self.cond_embedder2(feat2.to(h_cond.dtype))
|
||||
t_emb = self.t_embedder(t)
|
||||
t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb
|
||||
|
||||
h_x = h_x + self.pos_emb.to(z)
|
||||
|
||||
for i, block in enumerate(self.noise_refiner):
|
||||
h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x), transformer_options=transformer_options)
|
||||
|
||||
for i, block in enumerate(self.context_refiner):
|
||||
h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond), transformer_options=transformer_options)
|
||||
|
||||
cam = camera.to(z)
|
||||
h_cam = self.cam_refiner(cam)
|
||||
h = torch.cat([h_x, h_cond, h_cam], dim=1)
|
||||
|
||||
for i, block in enumerate(self.blocks):
|
||||
h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h), transformer_options=transformer_options)
|
||||
|
||||
h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).to(z)
|
||||
h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).to(z)
|
||||
|
||||
shift, scale = (comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device) + t_emb.unsqueeze(1)).chunk(2, dim=1)
|
||||
scale = 1 + scale
|
||||
h_x = torch.addcmul(shift, h_x, scale)
|
||||
h_cam = torch.addcmul(shift, h_cam, scale)
|
||||
|
||||
return self.out_layer(h_x), self.cam_out_layer(h_cam)
|
||||
91
comfy/ldm/triposplat/preview.py
Normal file
91
comfy/ldm/triposplat/preview.py
Normal file
@ -0,0 +1,91 @@
|
||||
# Live preview for TripoSplat: decode an x0 estimate into a coarse gaussian splat and render it with a perspective orbit camera.
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
_C0 = 0.28209479177387814
|
||||
_LATENT_TOKENS = 8192 # q_token_length
|
||||
_LATENT_CH = 16 # in_channels
|
||||
_OBJECT_TO_VIEWER = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32) # object frame -> viewer Y-up frame
|
||||
|
||||
|
||||
def _view_matrix(yaw_deg, pitch_deg):
|
||||
y, p = np.radians(yaw_deg), np.radians(pitch_deg)
|
||||
Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]], np.float32)
|
||||
Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]], np.float32)
|
||||
return Rx @ Ry
|
||||
|
||||
|
||||
def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0,
|
||||
max_px=9, min_opacity=0.0, fov=35.0, dist=2.2):
|
||||
# Project gaussian centers with a perspective camera and paint each as a filled disk whose screen
|
||||
# radius follows the gaussian's world-space scale, composited with a nearest-wins z-buffer.
|
||||
# gain scales the footprint (≈ std spanned), `min_px`/`max_px` clamp the on-screen radius.
|
||||
|
||||
pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T
|
||||
v = pts @ _view_matrix(yaw, pitch).T
|
||||
zc = v[:, 2] + dist
|
||||
keep = zc > 1e-2
|
||||
if opacity is not None and min_opacity > 0.0: # culls gaussians with very low opacity
|
||||
keep = keep & (opacity > min_opacity)
|
||||
v, zc, scale = v[keep], zc[keep], scale[keep]
|
||||
col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep]
|
||||
if v.shape[0] == 0:
|
||||
return Image.fromarray(np.zeros((size, size, 3), np.uint8))
|
||||
f = (size / 2) / np.tan(np.radians(fov) / 2)
|
||||
cx = size / 2 + f * v[:, 0] / zc
|
||||
cy = size / 2 + f * v[:, 1] / zc
|
||||
radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32)
|
||||
|
||||
# Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized.
|
||||
px, py, pz, pc = [], [], [], []
|
||||
for r in range(int(radius.min()), int(radius.max()) + 1):
|
||||
m = radius == r
|
||||
if not m.any():
|
||||
continue
|
||||
dy, dx = np.mgrid[-r:r + 1, -r:r + 1]
|
||||
disk = (dx * dx + dy * dy) <= r * r
|
||||
ox, oy = dx[disk], dy[disk]
|
||||
px.append((cx[m, None] + ox).ravel())
|
||||
py.append((cy[m, None] + oy).ravel())
|
||||
pz.append(np.repeat(zc[m], ox.size))
|
||||
pc.append(np.repeat(col[m], ox.size, axis=0))
|
||||
px, py = np.concatenate(px), np.concatenate(py)
|
||||
pz, pc = np.concatenate(pz), np.concatenate(pc)
|
||||
xi = np.clip(px, 0, size - 1).astype(np.int64)
|
||||
yi = np.clip(py, 0, size - 1).astype(np.int64)
|
||||
|
||||
# Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest
|
||||
# splat, then decode the winning index back to its color.
|
||||
pid = yi * size + xi
|
||||
q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1) # near = small
|
||||
key = (q << 32) | np.arange(pid.size, dtype=np.int64)
|
||||
buf = np.full(size * size, 1 << 62, np.int64)
|
||||
np.minimum.at(buf, pid, key)
|
||||
img = np.zeros((size * size, 3), np.uint8)
|
||||
hit = buf < (1 << 62)
|
||||
img[hit] = pc[buf[hit] & 0xFFFFFFFF]
|
||||
return Image.fromarray(img.reshape(size, size, 3))
|
||||
|
||||
|
||||
def _extract_latent(x0):
|
||||
# x0 from the sampler callback is the nested latent packed to (B, 1, TOKENS*CH + 1*5);
|
||||
# the plain single-latent case is (B, TOKENS, CH). Return the (B, TOKENS, CH) latent stream.
|
||||
if x0.ndim == 3 and x0.shape[1] == _LATENT_TOKENS and x0.shape[2] == _LATENT_CH:
|
||||
return x0
|
||||
flat = x0.reshape(x0.shape[0], -1)
|
||||
return flat[:, :_LATENT_TOKENS * _LATENT_CH].reshape(x0.shape[0], _LATENT_TOKENS, _LATENT_CH)
|
||||
|
||||
|
||||
def decode_x0_to_image(decoder, x0, cfg):
|
||||
# Decode x0 at a coarse octree level / few gaussians and render a preview image.
|
||||
latent = _extract_latent(x0)
|
||||
fsm = decoder.first_stage_model
|
||||
gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype),
|
||||
num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5))[0]
|
||||
xyz = gaussian.get_xyz.float().cpu().numpy()
|
||||
rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5
|
||||
scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1) # per-splat world radius (largest axis)
|
||||
opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0]
|
||||
return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0),
|
||||
size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3),
|
||||
min_opacity=0.01)
|
||||
382
comfy/ldm/triposplat/vae.py
Normal file
382
comfy/ldm/triposplat/vae.py
Normal file
@ -0,0 +1,382 @@
|
||||
# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an
|
||||
# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns
|
||||
# a Gaussian. The octree sampler uses the global torch RNG (no generator) like upstream, so seed it for repeatable decodes.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
from .gaussian import build_gaussian_models
|
||||
from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention
|
||||
|
||||
|
||||
# Quasi-random sampling utilities (pure functions, dtype/device-agnostic)
|
||||
|
||||
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
|
||||
|
||||
|
||||
def radical_inverse(base, n):
|
||||
val = 0
|
||||
inv_base = 1.0 / base
|
||||
inv_base_n = inv_base
|
||||
while n > 0:
|
||||
digit = n % base
|
||||
val += digit * inv_base_n
|
||||
n //= base
|
||||
inv_base_n *= inv_base
|
||||
return val
|
||||
|
||||
|
||||
def halton_sequence(dim, n):
|
||||
return [radical_inverse(PRIMES[i], n) for i in range(dim)]
|
||||
|
||||
|
||||
def hammersley_sequence(dim, n, num_samples):
|
||||
return [n / num_samples] + halton_sequence(dim - 1, n)
|
||||
|
||||
|
||||
def sample_probs(probs, counts, generator=None):
|
||||
# Systematic resampling: distribute counts[r] draws across the P bins of row r
|
||||
batch_shape = counts.shape
|
||||
R = counts.numel()
|
||||
P = probs.size(-1)
|
||||
device = probs.device
|
||||
probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
|
||||
counts = counts.reshape(R).to(device=device, dtype=torch.long)
|
||||
|
||||
row_sums = probs.sum(1, keepdim=True)
|
||||
probs = torch.where(row_sums == 0, probs.new_tensor(1.0 / P), probs / row_sums.clamp_min(1))
|
||||
cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)
|
||||
|
||||
Nmax = int(counts.max())
|
||||
if Nmax == 0:
|
||||
return counts.new_zeros(*batch_shape, P)
|
||||
cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1)
|
||||
grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax)
|
||||
u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded)
|
||||
idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
|
||||
weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r]
|
||||
out = torch.zeros(R, P, dtype=torch.float32, device=device)
|
||||
out.scatter_add_(1, idx, weight)
|
||||
return out.to(torch.long).view(*batch_shape, P)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
assert channels % num_heads == 0
|
||||
self.channels = channels
|
||||
self.head_dim = channels // num_heads
|
||||
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||
self.num_heads = num_heads
|
||||
self._type = type
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
if self._type == "self":
|
||||
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, dtype=dtype, device=device)
|
||||
if self.qk_rms_norm:
|
||||
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
|
||||
self.to_out = operations.Linear(channels, channels, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, context=None):
|
||||
B, L, C = x.shape
|
||||
if self._type == "self":
|
||||
q, k, v = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1).unbind(dim=2)
|
||||
else:
|
||||
Lkv = context.shape[1]
|
||||
q = self.to_q(x).reshape(B, L, self.num_heads, -1)
|
||||
k, v = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1).unbind(dim=2)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
h = attention(q, k, v)
|
||||
return self.to_out(h.reshape(B, L, -1))
|
||||
|
||||
|
||||
# Octree probability decoder
|
||||
|
||||
class LevelEmbedder(nn.Module):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.max_period = max_period
|
||||
|
||||
@staticmethod
|
||||
def level_embedding(t, dim, max_period=1024):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
|
||||
args = t[:, None].float() * freqs[None] * 2 * torch.pi
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
emb = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
|
||||
return self.mlp(emb.to(self.mlp[0].weight.dtype))
|
||||
|
||||
|
||||
class ModulatedTransformerCrossOnlyBlock(nn.Module):
|
||||
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
|
||||
qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads,
|
||||
type="cross", qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x, mod, context):
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.cross_attn(h, context), gate_msa.unsqueeze(1))
|
||||
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
|
||||
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
|
||||
return x
|
||||
|
||||
|
||||
class OctreeProbabilityFixedlenDecoder(nn.Module):
|
||||
# Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits.
|
||||
def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16,
|
||||
num_head_channels=64, mlp_ratio=4.0, share_mod=True,
|
||||
qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.share_mod = share_mod
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
|
||||
self.l_embedder = LevelEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, dtype=dtype, device=device))
|
||||
if cond_channels is not None:
|
||||
self.blocks = nn.ModuleList([
|
||||
ModulatedTransformerCrossOnlyBlock(
|
||||
model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
share_mod=self.share_mod, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
self.out_proj = operations.Linear(model_channels, 8, dtype=dtype, device=device)
|
||||
self.in_proj = operations.Linear(3, model_channels, dtype=dtype, device=device)
|
||||
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
|
||||
|
||||
def forward(self, x, l, cond):
|
||||
d = next(self.parameters()).dtype
|
||||
B, L, _ = x.shape
|
||||
h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d)
|
||||
h = self.input_layer(h)
|
||||
l_emb = self.l_embedder(l)
|
||||
if self.share_mod:
|
||||
l_emb = self.adaLN_modulation(l_emb)
|
||||
cond = cond.to(d)
|
||||
for block in self.blocks:
|
||||
h = block(h, l_emb, cond)
|
||||
h = F.layer_norm(h.float(), h.shape[-1:]).to(d)
|
||||
logits = self.out_proj(h)
|
||||
return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
|
||||
|
||||
@staticmethod
|
||||
def sample(model, cond, num_points, level, temperature=1.0, generator=None):
|
||||
B = cond.shape[0]
|
||||
device = cond.device
|
||||
child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
|
||||
dtype=torch.long, device=device)
|
||||
prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device)
|
||||
prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device)
|
||||
prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device)
|
||||
batch_indices_range = torch.arange(B, device=device).unsqueeze(1)
|
||||
|
||||
for lv in range(1, level + 1):
|
||||
res_p = 1 << (lv - 1)
|
||||
res = 1 << lv
|
||||
parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p
|
||||
res_tensor = torch.full((B,), res, dtype=torch.long, device=device)
|
||||
pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
|
||||
pred_probs = torch.softmax(pred_logits, dim=-1)
|
||||
pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
|
||||
sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2)
|
||||
pred_log_probs = pred_log_probs.flatten(1, 2)
|
||||
prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
|
||||
child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
|
||||
mask = sampled > 0
|
||||
max_valid = mask.sum(dim=1).max().item()
|
||||
scatter_indices = mask.cumsum(dim=1) - 1
|
||||
valid_scatter_indices = scatter_indices[mask]
|
||||
valid_batch_indices = batch_indices_range.expand_as(mask)[mask]
|
||||
next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device)
|
||||
next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask]
|
||||
next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device)
|
||||
next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask]
|
||||
next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device)
|
||||
next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask]
|
||||
prev_coords_int = next_prev_coords_int
|
||||
prev_counts = next_prev_counts
|
||||
prev_log_probs = next_prev_log_probs
|
||||
|
||||
res = 1 << level
|
||||
prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
|
||||
coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
|
||||
rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
|
||||
coords_norm = (coords_int.to(torch.float32) + rand) / res
|
||||
return {"points": coords_norm, "log_probs": prev_log_probs}
|
||||
|
||||
|
||||
# Elastic gaussian decoder
|
||||
|
||||
class TransformerCrossBlock(nn.Module):
|
||||
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
|
||||
qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
|
||||
dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm, dtype=dtype, device=device, operations=operations)
|
||||
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
|
||||
qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
|
||||
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(self, x, context):
|
||||
x = x + self.self_attn(self.norm1(x))
|
||||
x = x + self.cross_attn(self.norm2(x), context)
|
||||
x = x + self.mlp(self.norm3(x))
|
||||
return x
|
||||
|
||||
|
||||
class ElasticGaussianFixedlenDecoder(nn.Module):
|
||||
# Cross-attention transformer over sampled octree points -> per-point gaussian params.
|
||||
def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16,
|
||||
num_head_channels=64, mlp_ratio=4.0, *, representation_config=None,
|
||||
qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.rep_config = representation_config or dict(
|
||||
lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
|
||||
perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
|
||||
filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
|
||||
scaling_activation="softplus",
|
||||
)
|
||||
self.out_channels = self._calc_layout()
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
|
||||
if cond_channels is not None:
|
||||
self.blocks = nn.ModuleList([
|
||||
TransformerCrossBlock(model_channels, ctx_channels=cond_channels,
|
||||
num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
|
||||
qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
self.in_proj = operations.Linear(in_channels, model_channels, dtype=dtype, device=device)
|
||||
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
|
||||
self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device)
|
||||
self._build_perturbation()
|
||||
|
||||
def _calc_layout(self):
|
||||
ng = self.rep_config['num_gaussians']
|
||||
self.layout = {
|
||||
'_xyz': {'shape': (ng, 3), 'size': ng * 3},
|
||||
'_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3},
|
||||
'_scaling': {'shape': (ng, 3), 'size': ng * 3},
|
||||
'_rotation': {'shape': (ng, 4), 'size': ng * 4},
|
||||
'_opacity': {'shape': (ng, 1), 'size': ng},
|
||||
}
|
||||
self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng}
|
||||
start = 0
|
||||
for k, v in self.layout.items():
|
||||
v['range'] = (start, start + v['size'])
|
||||
start += v['size']
|
||||
return start
|
||||
|
||||
def _build_perturbation(self):
|
||||
ng = self.rep_config['num_gaussians']
|
||||
perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float()
|
||||
perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size'])
|
||||
self.register_buffer('points_offset_perturbation', perturbation)
|
||||
base = torch.tensor(self.rep_config['offset_scale'])
|
||||
self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))
|
||||
|
||||
def _get_offset(self, h):
|
||||
B = h.shape[0]
|
||||
r = self.layout['_offset_scale']['range']
|
||||
_offset_scale = F.softplus(
|
||||
h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape'])
|
||||
+ comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device))
|
||||
|
||||
r = self.layout['_xyz']['range']
|
||||
offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape'])
|
||||
offset = offset * self.rep_config['lr']['_xyz']
|
||||
if self.rep_config['perturb_offset']:
|
||||
offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device)
|
||||
offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
|
||||
offset = offset * _offset_scale
|
||||
return offset
|
||||
|
||||
def forward(self, x=None, cond=None):
|
||||
pcd = x["points"]
|
||||
d = next(self.parameters()).dtype
|
||||
B, L, _ = pcd.shape
|
||||
h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d)
|
||||
h = self.input_layer(h)
|
||||
cond = cond.to(d)
|
||||
for block in self.blocks:
|
||||
h = block(h, cond)
|
||||
h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype)
|
||||
return {"features": self.out_proj(h)}
|
||||
|
||||
|
||||
# Combined octree gaussian decoder (comfy first-stage model)
|
||||
|
||||
class OctreeGaussianDecoder(nn.Module):
|
||||
_MAX_VOXEL_LEVEL = 8
|
||||
|
||||
def __init__(self, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if operations is None:
|
||||
operations = comfy.ops.disable_weight_init
|
||||
self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=operations)
|
||||
self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=operations)
|
||||
|
||||
@property
|
||||
def gaussians_per_point(self) -> int:
|
||||
return self.gs.rep_config['num_gaussians']
|
||||
|
||||
def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
|
||||
# level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
|
||||
# generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG.
|
||||
level = self._MAX_VOXEL_LEVEL if level is None else level
|
||||
num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
|
||||
points_pred = OctreeProbabilityFixedlenDecoder.sample(
|
||||
self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator,
|
||||
)
|
||||
pred = self.gs(x=points_pred, cond=latent)
|
||||
return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item
|
||||
@ -4,6 +4,7 @@ import dataclasses
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
|
||||
import comfy_aimdo.host_buffer
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
|
||||
@ -17,21 +18,18 @@ class TensorFileSlice(NamedTuple):
|
||||
def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None):
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
if not isinstance(destination, QuantizedTensor):
|
||||
return False
|
||||
if tensor._layout_cls != destination._layout_cls:
|
||||
return False
|
||||
|
||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream,
|
||||
if not read_tensor_file_slice_into(tensor._qdata,
|
||||
destination._qdata if destination is not None else None, stream=stream,
|
||||
destination2=(destination2._qdata if destination2 is not None else None)):
|
||||
return False
|
||||
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
if destination is not None:
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
if destination2 is not None:
|
||||
dst_orig_dtype = destination2._params.orig_dtype
|
||||
destination2._params.copy_from(destination._params, non_blocking=True)
|
||||
destination2._params.copy_from(destination._params if destination is not None else tensor._params, non_blocking=True)
|
||||
destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype)
|
||||
return True
|
||||
|
||||
@ -39,10 +37,15 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
if info is None:
|
||||
return False
|
||||
|
||||
if destination is not None and destination.device.type != "cpu" and destination2 is None:
|
||||
destination2 = destination
|
||||
destination = None
|
||||
|
||||
file_obj = info.file_ref
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or destination.numel() * destination.element_size() < info.size
|
||||
if (file_obj is None
|
||||
or (destination is None and destination2 is None)
|
||||
or (destination is not None and (destination.device.type != "cpu" or destination.numel() * destination.element_size() < info.size))
|
||||
or (destination2 is not None and (destination2.device.type == "cpu" or destination2.numel() * destination2.element_size() < info.size))
|
||||
or tensor.numel() * tensor.element_size() != info.size
|
||||
or tensor.storage_offset() != 0
|
||||
or not tensor.is_contiguous()):
|
||||
@ -51,6 +54,14 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
if info.size == 0:
|
||||
return True
|
||||
|
||||
if destination is None:
|
||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||
comfy_aimdo.host_buffer.read_file_to_device(file_obj, info.offset, info.size,
|
||||
stream_ptr, destination2.data_ptr(),
|
||||
destination2.device.index,
|
||||
mark_cold=False)
|
||||
return True
|
||||
|
||||
hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None)
|
||||
if hostbuf is not None:
|
||||
stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0
|
||||
@ -63,6 +74,9 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N
|
||||
device=None if destination2 is None else destination2.device.index)
|
||||
return True
|
||||
|
||||
if not hasattr(file_obj, "seek") or not hasattr(file_obj, "readinto"):
|
||||
return False
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
|
||||
@ -46,6 +46,7 @@ import comfy.ldm.wan.model_animate
|
||||
import comfy.ldm.wan.ar_model
|
||||
import comfy.ldm.wan.model_wandancer
|
||||
import comfy.ldm.hunyuan3d.model
|
||||
import comfy.ldm.triposplat.model
|
||||
import comfy.ldm.hidream.model
|
||||
import comfy.ldm.chroma.model
|
||||
import comfy.ldm.chroma_radiance.model
|
||||
@ -53,7 +54,10 @@ import comfy.ldm.pixeldit.model
|
||||
import comfy.ldm.pixeldit.pid
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.seedvr.model
|
||||
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.ideogram4.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
@ -926,6 +930,16 @@ class HunyuanDiT(BaseModel):
|
||||
out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]]))
|
||||
return out
|
||||
|
||||
class SeedVR2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
condition = kwargs.get("condition", None)
|
||||
if condition is not None:
|
||||
out["condition"] = comfy.conds.CONDRegular(condition)
|
||||
return out
|
||||
|
||||
class PixArt(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS)
|
||||
@ -1428,6 +1442,23 @@ class PiD(PixelDiTT2I):
|
||||
out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma)
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key == "lq_latent" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
lq = cond_value.cond
|
||||
dim = window.dim
|
||||
if dim >= lq.ndim:
|
||||
return None
|
||||
lq_proj = self.diffusion_model.lq_proj
|
||||
ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor
|
||||
# Map x window indices -> lq indices (deduplicated, sorted, in-bounds).
|
||||
lq_size = lq.size(dim)
|
||||
lq_indices = sorted({i // ratio for i in window.index_list if 0 <= i // ratio < lq_size})
|
||||
if not lq_indices:
|
||||
return None
|
||||
idx = tuple([slice(None)] * dim + [lq_indices])
|
||||
return cond_value._copy_with(lq[idx].to(device))
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
|
||||
class WAN21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
@ -1789,6 +1820,24 @@ class Hunyuan3Dv2_1(BaseModel):
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
return out
|
||||
|
||||
class TripoSplat(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.triposplat.model.LatentSeqMMFlowModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None) # DINOv3 token sequence -> cross-attention context.
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
ref_latents = kwargs.get("reference_latents", None) # Flux2 VAE image latent -> additive second conditioning.
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = comfy.conds.CONDList(list(ref_latents))
|
||||
latent_shapes = kwargs.get("latent_shapes", None) # {latent, camera} nested latent
|
||||
if latent_shapes is not None:
|
||||
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
|
||||
return out
|
||||
|
||||
|
||||
class HiDream(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
|
||||
@ -1982,6 +2031,21 @@ class QwenImage(BaseModel):
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
return out
|
||||
|
||||
class Ideogram4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class HunyuanImage21(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
|
||||
|
||||
@ -313,6 +313,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["use_x0"] = True
|
||||
else:
|
||||
dit_config["use_x0"] = False
|
||||
if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids
|
||||
dit_config["use_sequential_txt_ids"] = True
|
||||
else:
|
||||
dit_config["use_sequential_txt_ids"] = False
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
@ -594,6 +598,56 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
return dit_config
|
||||
|
||||
if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 3072
|
||||
dit_config["heads"] = 24
|
||||
dit_config["num_layers"] = 36
|
||||
# 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.``
|
||||
# submodules) at EVERY block — verified by inspecting the 7B
|
||||
# state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means
|
||||
# ``MMModule.shared_weights=False``). Native NaDiT computes
|
||||
# per-block ``shared_weights = not (i < mm_layers)``, so to keep
|
||||
# every block non-shared we set ``mm_layers = num_layers``.
|
||||
# Without this, blocks at index >= mm_layers (default 10) try to
|
||||
# load ``blocks.N.*.all.*`` keys that don't exist in the file,
|
||||
# silently miss-load → all-black output.
|
||||
dit_config["mm_layers"] = 36
|
||||
dit_config["norm_eps"] = 1e-5
|
||||
dit_config["qk_rope"] = True
|
||||
dit_config["rope_type"] = "rope3d"
|
||||
dit_config["rope_dim"] = 64
|
||||
dit_config["mlp_type"] = "normal"
|
||||
return dit_config
|
||||
elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 3072
|
||||
dit_config["heads"] = 24
|
||||
dit_config["num_layers"] = 36
|
||||
# This checkpoint layout carries shared ``all.`` MMModule keys.
|
||||
# Preserve the historical split: the initial blocks use separate
|
||||
# vid/txt modules, later blocks use shared modules.
|
||||
dit_config["mm_layers"] = 10
|
||||
dit_config["norm_eps"] = 1e-5
|
||||
dit_config["qk_rope"] = True
|
||||
dit_config["rope_type"] = "rope3d"
|
||||
dit_config["rope_dim"] = 64
|
||||
dit_config["mlp_type"] = "swiglu"
|
||||
return dit_config
|
||||
elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "seedvr2"
|
||||
dit_config["vid_dim"] = 2560
|
||||
dit_config["heads"] = 20
|
||||
dit_config["num_layers"] = 32
|
||||
dit_config["norm_eps"] = 1.0e-05
|
||||
dit_config["qk_rope"] = None
|
||||
dit_config["mlp_type"] = "swiglu"
|
||||
dit_config["vid_out_norm"] = True
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "wan2.1"
|
||||
@ -676,6 +730,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
|
||||
return dit_config
|
||||
|
||||
if '{}cam_out_layer.weight'.format(key_prefix) in state_dict_keys and '{}repo_layers.0.final_map.weight'.format(key_prefix) in state_dict_keys: # TripoSplat
|
||||
return {"image_model": "triposplat"}
|
||||
|
||||
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
|
||||
return {"image_model": "hidream_o1"}
|
||||
|
||||
@ -808,6 +865,13 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["default_ref_method"] = "negative_index"
|
||||
return dit_config
|
||||
|
||||
if '{}embed_image_indicator.weight'.format(key_prefix) in state_dict_keys: # Ideogram 4
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ideogram4"
|
||||
dit_config["in_channels"] = state_dict['{}input_proj.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
dit_config = {}
|
||||
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
|
||||
@ -641,15 +641,17 @@ def free_pins(size, evict_active=False):
|
||||
return freed_total
|
||||
|
||||
def ensure_pin_budget(size, evict_active=False):
|
||||
shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available
|
||||
if args.fast_disk:
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
else:
|
||||
shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
|
||||
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
|
||||
return free_pins(to_free, evict_active=evict_active) >= shortfall
|
||||
|
||||
def ensure_pin_registerable(size, evict_active=False):
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
def free_registrations(shortfall, evict_active=True):
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
if shortfall <= 0:
|
||||
@ -658,12 +660,22 @@ def ensure_pin_registerable(size, evict_active=False):
|
||||
shortfall += REGISTERABLE_PIN_HYSTERESIS
|
||||
for loaded_model in reversed(current_loaded_models):
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]):
|
||||
if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]:
|
||||
shortfall -= model.unregister_inactive_pins(shortfall)
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
if evict_active:
|
||||
for loaded_model in current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]:
|
||||
shortfall -= model.unregister_inactive_pins(shortfall)
|
||||
if shortfall <= 0:
|
||||
return True
|
||||
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||
|
||||
def ensure_pin_registerable(size, evict_active=True):
|
||||
return free_registrations(TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY, evict_active=evict_active)
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model: ModelPatcher):
|
||||
self._set_model(model)
|
||||
@ -803,9 +815,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None):
|
||||
if not DISABLE_SMART_MEMORY or device is None:
|
||||
memory_to_free = 0 if device is None else memory_required - get_free_memory(device)
|
||||
if for_dynamic:
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
memory_required -= current_loaded_models[i].model.loaded_size()
|
||||
@ -817,6 +829,10 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
|
||||
if not for_dynamic and pins_required > 0:
|
||||
ensure_pin_budget(pins_required)
|
||||
ensure_pin_registerable(pins_required)
|
||||
|
||||
if len(unloaded_model) > 0:
|
||||
soft_empty_cache()
|
||||
elif device is not None:
|
||||
@ -879,15 +895,19 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
model_to_unload.model_finalizer.detach()
|
||||
|
||||
total_memory_required = {}
|
||||
total_pins_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
device = loaded_model.device
|
||||
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||
if not loaded_model.model.is_dynamic():
|
||||
total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory()
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||
device,
|
||||
for_dynamic=free_for_dynamic)
|
||||
for_dynamic=free_for_dynamic,
|
||||
pins_required=total_pins_required.get(device, 0))
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
@ -1283,7 +1303,6 @@ STREAM_CAST_BUFFERS = {}
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
STREAM_PIN_BUFFERS = {}
|
||||
|
||||
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||
|
||||
@ -1326,42 +1345,13 @@ def get_aimdo_cast_buffer(offload_stream, device):
|
||||
STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
|
||||
return cast_buffer
|
||||
|
||||
def get_pin_buffer(offload_stream):
|
||||
pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None)
|
||||
if pin_buffer is None:
|
||||
pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False)
|
||||
STREAM_PIN_BUFFERS[offload_stream] = pin_buffer
|
||||
elif offload_stream is not None:
|
||||
event = getattr(pin_buffer, "_comfy_event", None)
|
||||
if event is not None:
|
||||
event.synchronize()
|
||||
delattr(pin_buffer, "_comfy_event")
|
||||
return pin_buffer
|
||||
|
||||
def resize_pin_buffer(pin_buffer, size):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
old_size = pin_buffer.size
|
||||
if size <= old_size:
|
||||
return True
|
||||
growth = size - old_size
|
||||
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
||||
ensure_pin_budget(growth, evict_active=True)
|
||||
ensure_pin_registerable(growth, evict_active=True)
|
||||
try:
|
||||
pin_buffer.extend(size=size, reallocate=True)
|
||||
except RuntimeError:
|
||||
return False
|
||||
TOTAL_PINNED_MEMORY += pin_buffer.size - old_size
|
||||
return True
|
||||
|
||||
def reset_cast_buffers():
|
||||
global TOTAL_PINNED_MEMORY
|
||||
global LARGEST_CASTED_WEIGHT
|
||||
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS):
|
||||
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
|
||||
if offload_stream is not None:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
@ -1370,20 +1360,24 @@ def reset_cast_buffers():
|
||||
mmap_obj.bounce()
|
||||
DIRTY_MMAPS.clear()
|
||||
|
||||
for pin_buffer in STREAM_PIN_BUFFERS.values():
|
||||
TOTAL_PINNED_MEMORY -= pin_buffer.size
|
||||
TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY)
|
||||
|
||||
for loaded_model in current_loaded_models:
|
||||
model = loaded_model.model
|
||||
if model is not None and model.is_dynamic():
|
||||
model.model.dynamic_pins[model.load_device]["active"] = False
|
||||
pin_state = model.model.dynamic_pins[model.load_device]
|
||||
|
||||
if pin_state["active"]:
|
||||
*_, buckets = pin_state["weights"]
|
||||
for size, bucket in list(buckets.items()):
|
||||
bucket[:] = [ entry for entry in bucket if entry[-1] is not None ]
|
||||
if not bucket:
|
||||
del buckets[size]
|
||||
|
||||
pin_state["active"] = False
|
||||
model.partially_unload_ram(1e30, subsets=[ "patches" ])
|
||||
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0])
|
||||
model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {})
|
||||
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||
STREAM_PIN_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
def get_offload_stream(device):
|
||||
@ -1436,7 +1430,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
|
||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r)
|
||||
dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors)
|
||||
dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None
|
||||
with wf_context:
|
||||
for tensor in tensors:
|
||||
@ -1448,9 +1442,10 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None):
|
||||
continue
|
||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||
mark_mmap_dirty(storage)
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
if dest_view is not None:
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
if dest2_view is not None:
|
||||
dest2_view.copy_(dest_view, non_blocking=non_blocking)
|
||||
dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking)
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
|
||||
@ -1723,6 +1718,13 @@ def is_device_xpu(device):
|
||||
def is_device_cuda(device):
|
||||
return is_device_type(device, 'cuda')
|
||||
|
||||
def set_torch_device(device):
|
||||
"""Set the current device for the given torch device. Supports CUDA and XPU."""
|
||||
if is_device_cuda(device):
|
||||
torch.cuda.set_device(device)
|
||||
elif is_device_xpu(device):
|
||||
torch.xpu.set_device(device)
|
||||
|
||||
def is_directml_enabled():
|
||||
global directml_enabled
|
||||
if directml_enabled:
|
||||
|
||||
@ -1721,8 +1721,8 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
"""
|
||||
if device not in self.model.dynamic_pins:
|
||||
self.model.dynamic_pins[device] = {
|
||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]),
|
||||
"weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
|
||||
"patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}),
|
||||
"hostbufs_initialized": False,
|
||||
"failed": False,
|
||||
"active": False,
|
||||
@ -1799,8 +1799,8 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
if not pin_state["hostbufs_initialized"]:
|
||||
hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size())
|
||||
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0])
|
||||
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0])
|
||||
pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
|
||||
pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {})
|
||||
pin_state["hostbufs_initialized"] = True
|
||||
pin_state["failed"] = False
|
||||
pin_state["active"] = True
|
||||
@ -1942,18 +1942,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
return freed
|
||||
|
||||
def loaded_ram_size(self):
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][0].size +
|
||||
self.model.dynamic_pins[self.load_device]["patches"][0].size)
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][0].size)
|
||||
|
||||
def pinned_memory_size(self):
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][3][0] +
|
||||
self.model.dynamic_pins[self.load_device]["patches"][3][0])
|
||||
return (self.model.dynamic_pins[self.load_device]["weights"][3][0])
|
||||
|
||||
def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]):
|
||||
freed = 0
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
for subset in subsets:
|
||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
||||
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
|
||||
split = stack_split[0]
|
||||
while split >= 0:
|
||||
module, offset = stack[split]
|
||||
@ -1978,10 +1976,12 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
freed = 0
|
||||
pin_state = self.model.dynamic_pins[self.load_device]
|
||||
for subset in subsets:
|
||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
||||
hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset]
|
||||
while len(stack) > 0:
|
||||
module, offset = stack.pop()
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
module._pin_balancer_entry[-1] = None
|
||||
del module._pin_balancer_entry
|
||||
del module._pin
|
||||
hostbuf.truncate(offset, do_unregister=module._pin_registered)
|
||||
stack_split[0] = min(stack_split[0], len(stack) - 1)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import comfy_aimdo.model_vbar
|
||||
import comfy.memory_management
|
||||
import comfy.model_management
|
||||
import comfy.ops
|
||||
|
||||
@ -50,7 +51,17 @@ def prefetch_queue_pop(queue, device, module):
|
||||
if hasattr(s, "_v"):
|
||||
comfy_modules.append(s)
|
||||
|
||||
registerable_size = 0
|
||||
for s in comfy_modules:
|
||||
registerable_size += comfy.memory_management.vram_aligned_size([s.weight, s.bias])
|
||||
for param_key in ("weight", "bias"):
|
||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||
if lowvram_fn is not None:
|
||||
registerable_size += lowvram_fn.memory_required()
|
||||
|
||||
offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
|
||||
if not comfy.model_management.args.fast_disk:
|
||||
comfy.model_management.ensure_pin_registerable(registerable_size)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
queue[0] = (offload_stream, (prefetch, comfy_modules))
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ class MultiGPUThreadPool:
|
||||
"""Persistent thread pool for multi-GPU work distribution.
|
||||
|
||||
Maintains one worker thread per extra GPU device. Each thread calls
|
||||
torch.cuda.set_device() once at startup so that compiled kernel caches
|
||||
set_torch_device() once at startup so that compiled kernel caches
|
||||
(inductor/triton) stay warm across diffusion steps.
|
||||
"""
|
||||
|
||||
@ -37,7 +37,7 @@ class MultiGPUThreadPool:
|
||||
|
||||
def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue):
|
||||
try:
|
||||
torch.cuda.set_device(device)
|
||||
comfy.model_management.set_torch_device(device)
|
||||
except Exception as e:
|
||||
logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}")
|
||||
while True:
|
||||
@ -54,6 +54,8 @@ class MultiGPUThreadPool:
|
||||
try:
|
||||
result = fn(*args, **kwargs)
|
||||
result_q.put((result, None))
|
||||
except comfy.model_management.InterruptProcessingException as e:
|
||||
result_q.put((None, e))
|
||||
except Exception as e:
|
||||
result_q.put((None, e))
|
||||
|
||||
|
||||
66
comfy/ops.py
66
comfy/ops.py
@ -76,8 +76,6 @@ except:
|
||||
|
||||
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
|
||||
|
||||
STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
@ -94,9 +92,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
offload_stream = None
|
||||
cast_buffer = None
|
||||
cast_buffer_offset = 0
|
||||
stream_pin_hostbuf = None
|
||||
stream_pin_offset = 0
|
||||
stream_pin_queue = []
|
||||
|
||||
def ensure_offload_stream(module, required_size, check_largest):
|
||||
nonlocal offload_stream
|
||||
@ -130,22 +125,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
cast_buffer_offset += buffer_size
|
||||
return buffer
|
||||
|
||||
def get_stream_pin_buffer_offset(buffer_size):
|
||||
nonlocal stream_pin_hostbuf
|
||||
nonlocal stream_pin_offset
|
||||
|
||||
if buffer_size == 0 or offload_stream is None:
|
||||
return None
|
||||
|
||||
if stream_pin_hostbuf is None:
|
||||
stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream)
|
||||
if stream_pin_hostbuf is None:
|
||||
return None
|
||||
|
||||
offset = stream_pin_offset
|
||||
stream_pin_offset += buffer_size
|
||||
return offset
|
||||
|
||||
for s in comfy_modules:
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||
@ -184,12 +163,18 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if xfer_dest is None:
|
||||
xfer_dest = get_cast_buffer(dest_size)
|
||||
|
||||
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream):
|
||||
def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream, xfer_dest2=None):
|
||||
if xfer_source is not None:
|
||||
if getattr(xfer_source, "is_lowvram_patch", False):
|
||||
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
|
||||
else:
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream)
|
||||
if xfer_dest is not None:
|
||||
xfer_source.prepare(xfer_dest, stream, copy=True, commit=False)
|
||||
xfer_source = [ xfer_dest ]
|
||||
xfer_dest = xfer_dest2
|
||||
xfer_dest2 = None
|
||||
elif xfer_dest2 is not None:
|
||||
xfer_source.prepare(xfer_dest2, stream, copy=True, commit=False)
|
||||
return
|
||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream, r2=xfer_dest2)
|
||||
|
||||
def handle_pin(m, pin, source, dest, subset="weights", size=None):
|
||||
if pin is not None:
|
||||
@ -198,19 +183,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if signature is None:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
if pin is not None:
|
||||
if isinstance(source, list):
|
||||
comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest)
|
||||
else:
|
||||
cast_maybe_lowvram_patch(source, pin, None)
|
||||
cast_maybe_lowvram_patch([ pin ], dest, offload_stream)
|
||||
return
|
||||
if pin is None:
|
||||
pin_offset = get_stream_pin_buffer_offset(size)
|
||||
if pin_offset is not None:
|
||||
stream_pin_queue.append((source, pin_offset, size, dest))
|
||||
return
|
||||
cast_maybe_lowvram_patch(source, dest, offload_stream)
|
||||
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
|
||||
|
||||
handle_pin(s, pin, xfer_source, xfer_dest, size=dest_size)
|
||||
|
||||
@ -232,23 +205,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
prefetch["needs_cast"] = needs_cast
|
||||
s._prefetch = prefetch
|
||||
|
||||
if stream_pin_offset > 0:
|
||||
if stream_pin_hostbuf.size < stream_pin_offset:
|
||||
if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM):
|
||||
for xfer_source, _, _, xfer_dest in stream_pin_queue:
|
||||
cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream)
|
||||
return offload_stream
|
||||
stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf)
|
||||
stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf
|
||||
for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue:
|
||||
pin = stream_pin_tensor[pin_offset:pin_offset + pin_size]
|
||||
if isinstance(xfer_source, list):
|
||||
comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest)
|
||||
else:
|
||||
cast_maybe_lowvram_patch(xfer_source, pin, None)
|
||||
comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||
stream_pin_hostbuf._comfy_event = offload_stream.record_event()
|
||||
|
||||
return offload_stream
|
||||
|
||||
|
||||
|
||||
@ -1,17 +1,55 @@
|
||||
import bisect
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.memory_management
|
||||
import comfy.utils
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.torch
|
||||
import torch
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
def _add_to_bucket(module, buckets, size, priority):
|
||||
bucket = buckets.setdefault(size, [])
|
||||
entry = [-priority, 0, module]
|
||||
entry[1] = id(entry)
|
||||
bisect.insort(bucket, entry)
|
||||
module._pin_balancer_entry = entry
|
||||
|
||||
def _steal_pin(module, stack, buckets, size, priority):
|
||||
bucket = buckets.get(size)
|
||||
if bucket is None:
|
||||
return False
|
||||
|
||||
while bucket and bucket[-1][-1] is None:
|
||||
bucket.pop()
|
||||
if not bucket:
|
||||
del buckets[size]
|
||||
return False
|
||||
|
||||
if priority <= -bucket[-1][0]:
|
||||
return False
|
||||
|
||||
*_, victim = bucket.pop()
|
||||
module._pin = victim._pin
|
||||
module._pin_registered = victim._pin_registered
|
||||
module._pin_stack_index = victim._pin_stack_index
|
||||
stack[module._pin_stack_index] = (module, stack[module._pin_stack_index][1])
|
||||
|
||||
victim._pin_registered = False
|
||||
del victim._pin
|
||||
del victim._pin_stack_index
|
||||
del victim._pin_balancer_entry
|
||||
|
||||
_add_to_bucket(module, buckets, size, priority)
|
||||
return True
|
||||
|
||||
def get_pin(module, subset="weights"):
|
||||
pin = getattr(module, "_pin", None)
|
||||
if pin is None or module._pin_registered or args.disable_pinned_memory:
|
||||
return pin
|
||||
|
||||
_, _, stack_split, pinned_size = module._pin_state[subset]
|
||||
_, _, stack_split, pinned_size, *_ = module._pin_state[subset]
|
||||
size = pin.nbytes
|
||||
comfy.model_management.ensure_pin_registerable(size)
|
||||
|
||||
@ -31,33 +69,51 @@ def pin_memory(module, subset="weights", size=None):
|
||||
return
|
||||
|
||||
pin = get_pin(module, subset)
|
||||
if pin is not None or pin_state["failed"]:
|
||||
if pin is not None:
|
||||
return
|
||||
|
||||
hostbuf, stack, stack_split, pinned_size = pin_state[subset]
|
||||
hostbuf, stack, stack_split, pinned_size, counter, buckets = pin_state[subset]
|
||||
if size is None:
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
offset = hostbuf.size
|
||||
registerable_size = size + max(0, hostbuf.size - pinned_size[0])
|
||||
registerable_size = size
|
||||
priority = getattr(module, "_pin_balancer_priority", None)
|
||||
|
||||
if priority is None:
|
||||
priority = comfy.utils.bit_reverse_range(counter[0], 16)
|
||||
counter[0] += 1
|
||||
module._pin_balancer_priority = priority
|
||||
|
||||
comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM)
|
||||
if (not comfy.model_management.ensure_pin_budget(size) or
|
||||
not comfy.model_management.ensure_pin_registerable(registerable_size)):
|
||||
pin_state["failed"] = True
|
||||
return False
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
|
||||
extended = False
|
||||
try:
|
||||
hostbuf.extend(size=size)
|
||||
hostbuf.extend(size=size, register=False)
|
||||
extended = True
|
||||
pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||
pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
comfy.model_management.free_registrations(size)
|
||||
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
del pin
|
||||
hostbuf.truncate(offset, do_unregister=False)
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
except RuntimeError:
|
||||
pin_state["failed"] = True
|
||||
return False
|
||||
if extended:
|
||||
hostbuf.truncate(offset, do_unregister=False)
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||
module._pin = pin
|
||||
stack.append((module, offset))
|
||||
module._pin_registered = True
|
||||
module._pin_stack_index = len(stack) - 1
|
||||
stack_split[0] = max(stack_split[0], module._pin_stack_index)
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||
pinned_size[0] += size
|
||||
_add_to_bucket(module, buckets, size, priority)
|
||||
return True
|
||||
|
||||
@ -44,7 +44,13 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None,
|
||||
is_empty = torch.count_nonzero(latent_image) == 0
|
||||
if is_empty:
|
||||
if latent_format.latent_channels != latent_image.shape[1]:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
preserves_collapsed_channels = (
|
||||
getattr(latent_format, "preserve_empty_channel_multiples", False)
|
||||
and latent_image.ndim == 4
|
||||
and latent_image.shape[1] % latent_format.latent_channels == 0
|
||||
)
|
||||
if not preserves_collapsed_channels:
|
||||
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
|
||||
if downscale_ratio_spacial is not None:
|
||||
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
|
||||
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
|
||||
|
||||
@ -464,10 +464,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t
|
||||
|
||||
def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]):
|
||||
try:
|
||||
# TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once
|
||||
# we extend multigpu QA beyond CUDA. Unconditional call crashes on
|
||||
# XPU/NPU/MPS/CPU/DirectML backends.
|
||||
torch.cuda.set_device(device)
|
||||
comfy.model_management.set_torch_device(device)
|
||||
model_current: BaseModel = model_options["multigpu_clones"][device].model
|
||||
# run every hooked_to_run separately
|
||||
with torch.no_grad():
|
||||
|
||||
258
comfy/sd.py
258
comfy/sd.py
@ -1,3 +1,4 @@
|
||||
import inspect
|
||||
import json
|
||||
import torch
|
||||
from enum import Enum
|
||||
@ -16,6 +17,8 @@ import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.seedvr.vae
|
||||
import comfy.ldm.triposplat.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
import comfy.ldm.cogvideo.vae
|
||||
import comfy.ldm.hunyuan_video.vae
|
||||
@ -57,6 +60,7 @@ import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ideogram4
|
||||
import comfy.text_encoders.ovis
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.jina_clip_2
|
||||
@ -82,6 +86,36 @@ import comfy.latent_formats
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160
|
||||
|
||||
|
||||
def _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w):
|
||||
output_t = max(1, (latent_t - 1) * 4 + 1)
|
||||
return output_t * latent_h * 8 * latent_w * 8
|
||||
|
||||
|
||||
def _seedvr2_vae_decode_memory_used(shape):
|
||||
if len(shape) == 5:
|
||||
candidates = []
|
||||
if shape[1] == 16:
|
||||
candidates.append((shape[2], shape[3], shape[4]))
|
||||
if shape[-1] == 16:
|
||||
candidates.append((shape[1], shape[2], shape[3]))
|
||||
if len(candidates) == 0:
|
||||
candidates.append((shape[2], shape[3], shape[4]))
|
||||
output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates)
|
||||
elif len(shape) == 4:
|
||||
latent_t = max(1, (shape[1] + 15) // 16)
|
||||
latent_h, latent_w = shape[2], shape[3]
|
||||
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
|
||||
else:
|
||||
latent_t, latent_h, latent_w = 1, shape[-2], shape[-1]
|
||||
output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w)
|
||||
# SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels
|
||||
# plus int64 sort indices dominate peak memory, not the VAE weight dtype.
|
||||
return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
@ -465,8 +499,10 @@ class CLIP:
|
||||
|
||||
class VAE:
|
||||
def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
|
||||
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd
|
||||
if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
|
||||
if metadata is None or metadata.get("keep_diffusers_format") != "true":
|
||||
sd = diffusers_convert.convert_vae_state_dict(sd)
|
||||
|
||||
if model_management.is_amd():
|
||||
VAE_KL_MEM_RATIO = 2.73
|
||||
@ -538,6 +574,20 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
|
||||
self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
|
||||
self.latent_channels = 16
|
||||
self.latent_dim = 3
|
||||
self.disable_offload = True
|
||||
self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape)
|
||||
self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.process_input = lambda image: image * 2.0 - 1.0
|
||||
self.crop_input = False
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
@ -665,6 +715,7 @@ class VAE:
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
||||
self.downscale_index_formula = (8, 32, 32)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
@ -894,6 +945,16 @@ class VAE:
|
||||
#Force cast it for --disable-dynamic-vram users until there is a true core fix.
|
||||
if not comfy.memory_management.aimdo_enabled:
|
||||
self.disable_offload = True
|
||||
elif "gs.base_offset_scale" in sd and "octree.out_proj.weight" in sd: # TripoSplat octree gaussian decoder
|
||||
self.first_stage_model = comfy.ldm.triposplat.vae.OctreeGaussianDecoder()
|
||||
self.latent_channels = 16
|
||||
self.latent_dim = 1
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# The generic VAE.encode/decode path isn't used: VAEDecodeTripoSplat calls the gaussian
|
||||
# decoder directly (structured GaussianSplat objects, not a tensor and reserves VRAM itself from num_gaussians.
|
||||
def _no_generic_io(*args, **kwargs):
|
||||
raise RuntimeError("TripoSplat gaussian decoder: use the 'TripoSplat Decode' (VAEDecodeTripoSplat)")
|
||||
self.memory_used_encode = self.memory_used_decode = _no_generic_io
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@ -994,6 +1055,40 @@ class VAE:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4):
|
||||
sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8)
|
||||
sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4)
|
||||
if tile_t is None:
|
||||
tile_t = 16
|
||||
if overlap_t is None:
|
||||
overlap_t = 4
|
||||
if tile_t > 0:
|
||||
temporal_size = tile_t * sf_t
|
||||
temporal_overlap = max(0, overlap_t) * sf_t
|
||||
else:
|
||||
temporal_size = 0
|
||||
temporal_overlap = 0
|
||||
args = {
|
||||
"enable_tiling": True,
|
||||
"tile_size": (tile_y * sf_s, tile_x * sf_s),
|
||||
"tile_overlap": (overlap * sf_s, overlap * sf_s),
|
||||
"temporal_size": temporal_size,
|
||||
"temporal_overlap": temporal_overlap,
|
||||
}
|
||||
output = self.first_stage_model.decode(
|
||||
samples.to(self.vae_dtype).to(self.device),
|
||||
seedvr2_tiling=args,
|
||||
)
|
||||
return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
|
||||
|
||||
def _format_seedvr2_encoded_samples(self, samples):
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
if samples.ndim == 4:
|
||||
samples = samples.unsqueeze(2)
|
||||
samples = samples.contiguous()
|
||||
samples = samples * 0.9152
|
||||
return samples
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
@ -1030,6 +1125,36 @@ class VAE:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
if tile_y is None:
|
||||
tile_y = 512
|
||||
if tile_x is None:
|
||||
tile_x = 512
|
||||
if overlap is None:
|
||||
overlap_y = 64
|
||||
overlap_x = 64
|
||||
else:
|
||||
overlap_y = overlap
|
||||
overlap_x = overlap
|
||||
if tile_t is None:
|
||||
tile_t = 9999
|
||||
if overlap_t is None:
|
||||
overlap_t = 0
|
||||
overlap_y = min(overlap_y, max(0, tile_y - 8))
|
||||
overlap_x = min(overlap_x, max(0, tile_x - 8))
|
||||
self.first_stage_model.device = self.device
|
||||
x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device)
|
||||
output = comfy.ldm.seedvr.vae.tiled_vae(
|
||||
x,
|
||||
self.first_stage_model,
|
||||
tile_size=(tile_y, tile_x),
|
||||
tile_overlap=(overlap_y, overlap_x),
|
||||
temporal_size=tile_t,
|
||||
temporal_overlap=overlap_t,
|
||||
encode=True,
|
||||
)
|
||||
return output.to(device=self.output_device, dtype=self.vae_output_dtype())
|
||||
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
@ -1077,16 +1202,40 @@ class VAE:
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||
elif dims == 2:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
# SeedVR2 latents arrive in 4D collapsed form ``(B, 16*T, H, W)``
|
||||
# downstream of ``SeedVR2Conditioning`` (which performs the
|
||||
# ``rearrange(b c t h w -> b (c t) h w)`` collapse). The
|
||||
# generic ``decode_tiled_`` would treat the channel dim as
|
||||
# spatial-only and crash on the collapsed (16, T) layout
|
||||
# under ``tiled_scale``'s mask broadcast; route SeedVR2 4D
|
||||
# latents to ``decode_tiled_seedvr2`` instead, whose wrapper
|
||||
# dispatch handles both 4D and 5D inputs.
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
pixel_samples = self.decode_tiled_(samples_in)
|
||||
elif dims == 3:
|
||||
tile = 256 // self.spacial_compression_decode()
|
||||
overlap = tile // 4
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
|
||||
pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1)
|
||||
return pixel_samples
|
||||
|
||||
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
def decode_tiled(
|
||||
self,
|
||||
samples,
|
||||
tile_x=None,
|
||||
tile_y=None,
|
||||
overlap=None,
|
||||
tile_t=None,
|
||||
overlap_t=None,
|
||||
):
|
||||
self.throw_exception_if_invalid()
|
||||
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@ -1100,7 +1249,20 @@ class VAE:
|
||||
args["overlap"] = overlap
|
||||
|
||||
with model_management.cuda_device_context(self.device):
|
||||
if dims == 1 or self.extra_1d_channel is not None:
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3):
|
||||
seedvr2_args = {}
|
||||
if tile_x is not None:
|
||||
seedvr2_args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
seedvr2_args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
seedvr2_args["overlap"] = overlap
|
||||
if tile_t is not None:
|
||||
seedvr2_args["tile_t"] = tile_t
|
||||
if overlap_t is not None:
|
||||
seedvr2_args["overlap_t"] = overlap_t
|
||||
output = self.decode_tiled_seedvr2(samples, **seedvr2_args)
|
||||
elif dims == 1 or self.extra_1d_channel is not None:
|
||||
args.pop("tile_y")
|
||||
output = self.decode_tiled_1d(samples, **args)
|
||||
elif dims == 2:
|
||||
@ -1142,6 +1304,8 @@ class VAE:
|
||||
else:
|
||||
pixels_in = pixels_in.to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in)
|
||||
if isinstance(out, tuple):
|
||||
out = out[0]
|
||||
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
@ -1161,20 +1325,23 @@ class VAE:
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
else:
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1 or self.extra_1d_channel is not None:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return samples
|
||||
return self._format_seedvr2_encoded_samples(samples)
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
dims = self.latent_dim
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
if dims == 3:
|
||||
if dims == 3 and pixel_samples.ndim < 5:
|
||||
if not self.not_video:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
else:
|
||||
@ -1198,22 +1365,47 @@ class VAE:
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper):
|
||||
seedvr2_args = {}
|
||||
if tile_x is not None:
|
||||
seedvr2_args["tile_x"] = tile_x
|
||||
else:
|
||||
seedvr2_args["tile_x"] = 512
|
||||
if tile_y is not None:
|
||||
seedvr2_args["tile_y"] = tile_y
|
||||
else:
|
||||
seedvr2_args["tile_y"] = 512
|
||||
if overlap is not None:
|
||||
seedvr2_args["overlap"] = overlap
|
||||
else:
|
||||
seedvr2_args["overlap"] = 64
|
||||
if tile_t is not None:
|
||||
seedvr2_args["tile_t"] = tile_t
|
||||
else:
|
||||
seedvr2_args["tile_t"] = 9999
|
||||
if overlap_t is not None:
|
||||
seedvr2_args["overlap_t"] = overlap_t
|
||||
else:
|
||||
seedvr2_args["overlap_t"] = 0
|
||||
samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args)
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
if tile_t is not None:
|
||||
tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
|
||||
else:
|
||||
tile_t_latent = 9999
|
||||
args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
|
||||
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, overlap, overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
spatial_overlap = overlap if overlap is not None else 64
|
||||
if overlap_t is None:
|
||||
args["overlap"] = (1, spatial_overlap, spatial_overlap)
|
||||
else:
|
||||
args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap)
|
||||
maximum = pixel_samples.shape[2]
|
||||
maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
|
||||
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
|
||||
|
||||
return samples
|
||||
return self._format_seedvr2_encoded_samples(samples)
|
||||
|
||||
def get_sd(self):
|
||||
return self.first_stage_model.state_dict()
|
||||
@ -1287,6 +1479,7 @@ class CLIPType(Enum):
|
||||
COGVIDEOX = 27
|
||||
LENS = 28
|
||||
PIXELDIT = 29
|
||||
IDEOGRAM4 = 30
|
||||
|
||||
|
||||
|
||||
@ -1585,8 +1778,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
||||
elif te_model == TEModel.QWEN3_8B:
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
|
||||
if clip_type == CLIPType.IDEOGRAM4:
|
||||
clip_target.clip = comfy.text_encoders.ideogram4.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Tokenizer
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b")
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B
|
||||
elif te_model == TEModel.JINA_CLIP_2:
|
||||
clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper
|
||||
clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper
|
||||
@ -1735,6 +1932,17 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (model, clip, vae)
|
||||
|
||||
|
||||
def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device):
|
||||
set_dtype = model_config.set_inference_dtype
|
||||
parameters = inspect.signature(set_dtype).parameters
|
||||
supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values())
|
||||
if supports_device:
|
||||
set_dtype(dtype, manual_cast_dtype, device=device)
|
||||
else:
|
||||
set_dtype(dtype, manual_cast_dtype)
|
||||
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False):
|
||||
sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic)
|
||||
@ -1842,7 +2050,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||
|
||||
if model_config.clip_vision_prefix is not None:
|
||||
if output_clipvision:
|
||||
@ -1983,7 +2191,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
_set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device)
|
||||
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
@ -24,6 +24,7 @@ import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ideogram4
|
||||
import comfy.text_encoders.anima
|
||||
import comfy.text_encoders.ace15
|
||||
import comfy.text_encoders.longcat_image
|
||||
@ -1538,6 +1539,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
|
||||
|
||||
latent_format = latent_formats.Hunyuan3Dv2mini
|
||||
|
||||
class TripoSplat(supported_models_base.BASE):
|
||||
# Image -> 3D gaussian splat flow denoiser
|
||||
unet_config = {
|
||||
"image_model": "triposplat",
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 0.6
|
||||
|
||||
latent_format = latent_formats.TripoSplat
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.TripoSplat(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class HiDream(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hidream",
|
||||
@ -1647,6 +1672,35 @@ class Chroma(supported_models_base.BASE):
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
|
||||
|
||||
class SeedVR2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "seedvr2"
|
||||
}
|
||||
latent_format = comfy.latent_formats.SeedVR2
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
sampling_settings = {
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
|
||||
if (
|
||||
dtype == torch.float16
|
||||
and manual_cast_dtype is None
|
||||
and comfy.model_management.should_use_bf16(device)
|
||||
):
|
||||
manual_cast_dtype = torch.bfloat16
|
||||
super().set_inference_dtype(dtype, manual_cast_dtype, device=device)
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SeedVR2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
unet_config = {
|
||||
"image_model": "chroma_radiance",
|
||||
@ -1722,6 +1776,44 @@ class Omnigen2(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect))
|
||||
|
||||
class Ideogram4(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ideogram4",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 1.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 11.6
|
||||
|
||||
unet_extra_config = {
|
||||
"num_attention_heads": 18,
|
||||
"attention_head_dim": 256,
|
||||
"intermediate_size": 12288,
|
||||
"adaln_dim": 512,
|
||||
"llm_features_dim": 53248,
|
||||
"rope_theta": 5000000,
|
||||
"mrope_section": [24, 20, 20],
|
||||
"norm_eps": 1e-5,
|
||||
}
|
||||
latent_format = latent_formats.Flux2
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Ideogram4(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ideogram4.Ideogram4Tokenizer, comfy.text_encoders.ideogram4.te(**hunyuan_detect))
|
||||
|
||||
class QwenImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "qwen_image",
|
||||
@ -1966,7 +2058,6 @@ class LongCatImage(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class RT_DETR_v4(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "RT_DETR_v4",
|
||||
@ -2200,14 +2291,17 @@ models = [
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
Hunyuan3Dv2_1,
|
||||
TripoSplat,
|
||||
HiDream,
|
||||
HiDreamO1,
|
||||
Chroma,
|
||||
SeedVR2,
|
||||
ChromaRadiance,
|
||||
ACEStep,
|
||||
ACEStep15,
|
||||
Omnigen2,
|
||||
QwenImage,
|
||||
Ideogram4,
|
||||
Flux2,
|
||||
Lens,
|
||||
Kandinsky5Image,
|
||||
|
||||
@ -115,7 +115,7 @@ class BASE:
|
||||
replace_prefix = {"": self.vae_key_prefix[0]}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype, device=None):
|
||||
self.unet_config['dtype'] = dtype
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
||||
|
||||
77
comfy/text_encoders/ideogram4.py
Normal file
77
comfy/text_encoders/ideogram4.py
Normal file
@ -0,0 +1,77 @@
|
||||
"""Ideogram 4 text encoder: Qwen3-VL-8B language model, 13-layer tap.
|
||||
|
||||
Ideogram 4 conditions on the concatenation of hidden states from 13 layers of
|
||||
Qwen3-VL (layers 0,3,...,33,35), giving a 4096*13 = 53248-dim feature per token.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from transformers import Qwen2Tokenizer
|
||||
|
||||
import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
|
||||
# Reference taps outputs of layers (0,3,...,35); comfy captures layer inputs, offset by +1.
|
||||
IDEOGRAM4_TAP_LAYERS = [1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 36]
|
||||
|
||||
|
||||
class Qwen3VLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory,
|
||||
embedding_size=4096, embedding_key='qwen3vl_8b', tokenizer_class=Qwen2Tokenizer,
|
||||
has_start_token=False, has_end_token=False, pad_to_max_length=False,
|
||||
max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data,
|
||||
name="qwen3vl_8b", tokenizer=Qwen3VLTokenizer)
|
||||
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
|
||||
|
||||
# Qwen3-VL-8B = 5e6 (vs plain Qwen3-8B's 1e6)
|
||||
# final_norm/lm_head off -> Ideogram only reads raw tapped hidden states
|
||||
QWEN3VL_8B_CONFIG = {"rope_theta": 5000000.0, "final_norm": False, "lm_head": False}
|
||||
|
||||
|
||||
class Qwen3VL8BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=IDEOGRAM4_TAP_LAYERS, layer_idx=None,
|
||||
textmodel_json_config=dict(QWEN3VL_8B_CONFIG),
|
||||
dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False,
|
||||
model_class=comfy.text_encoders.llama.Qwen3_8B,
|
||||
enable_attention_masks=attention_mask, return_attention_masks=attention_mask,
|
||||
model_options=model_options)
|
||||
|
||||
|
||||
class Ideogram4TEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Qwen3VL8BModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||
b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096) stacked in ascending layer order.
|
||||
out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13). permute -> (B, seq, H, taps).
|
||||
return out, pooled, extra
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Ideogram4TEModel_(Ideogram4TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Ideogram4TEModel_
|
||||
@ -85,9 +85,9 @@ _TYPES = {
|
||||
def load_safetensors(ckpt):
|
||||
import comfy_aimdo.model_mmap
|
||||
|
||||
f = open(ckpt, "rb", buffering=0)
|
||||
file_lock = threading.Lock()
|
||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||
f = model_mmap.get_file_handle()
|
||||
file_size = os.path.getsize(ckpt)
|
||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||
|
||||
@ -1452,3 +1452,10 @@ def deepcopy_list_dict(obj, memo=None):
|
||||
|
||||
memo[obj_id] = res
|
||||
return res
|
||||
|
||||
def bit_reverse_range(index, bits):
|
||||
result = 0
|
||||
for _ in range(bits):
|
||||
result = (result << 1) | (index & 1)
|
||||
index >>= 1
|
||||
return result
|
||||
|
||||
@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
from comfy_execution.utils import get_executing_context
|
||||
@ -143,6 +143,7 @@ class Types:
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
SPLAT = SPLAT
|
||||
File3D = File3D
|
||||
|
||||
|
||||
|
||||
@ -65,6 +65,12 @@ class VideoInput(ABC):
|
||||
buffer.seek(0)
|
||||
return buffer
|
||||
|
||||
def get_active_trim_window(self) -> tuple[float, float]:
|
||||
"""Return the active trim as ``(start_time, duration)`` in seconds (start_time normalized
|
||||
to ``>= 0``; ``duration == 0`` means "until the end"). Default: no trim; trimmable subclasses override.
|
||||
"""
|
||||
return 0.0, 0.0
|
||||
|
||||
# Provide a default implementation, but subclasses can provide optimized versions
|
||||
# if possible.
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
|
||||
@ -75,6 +75,12 @@ class VideoFromFile(VideoInput):
|
||||
self.__file.seek(0)
|
||||
return self.__file
|
||||
|
||||
def get_active_trim_window(self) -> tuple[float, float]:
|
||||
start_time = self.__start_time
|
||||
if start_time < 0:
|
||||
start_time = max(self._get_raw_duration() + start_time, 0.0)
|
||||
return float(start_time), float(self.__duration)
|
||||
|
||||
def get_dimensions(self) -> tuple[int, int]:
|
||||
"""
|
||||
Returns the dimensions of the video input.
|
||||
|
||||
@ -28,7 +28,7 @@ if TYPE_CHECKING:
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL, SVG as _SVG, File3D
|
||||
from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D
|
||||
|
||||
|
||||
class FolderType(str, Enum):
|
||||
@ -684,6 +684,10 @@ class Voxel(ComfyTypeIO):
|
||||
class Mesh(ComfyTypeIO):
|
||||
Type = MESH
|
||||
|
||||
@comfytype(io_type="SPLAT")
|
||||
class Splat(ComfyTypeIO):
|
||||
Type = SPLAT
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D")
|
||||
class File3DAny(ComfyTypeIO):
|
||||
@ -727,6 +731,42 @@ class File3DUSDZ(ComfyTypeIO):
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_PLY")
|
||||
class File3DPLY(ComfyTypeIO):
|
||||
"""PLY format 3D file - point cloud or Gaussian splat."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_SPLAT")
|
||||
class File3DSPLAT(ComfyTypeIO):
|
||||
"""SPLAT format 3D file - 3D Gaussian splat."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_SPZ")
|
||||
class File3DSPZ(ComfyTypeIO):
|
||||
"""SPZ format 3D file - compressed 3D Gaussian splat."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_KSPLAT")
|
||||
class File3DKSPLAT(ComfyTypeIO):
|
||||
"""KSPLAT format 3D file - 3D Gaussian splat."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_SPLAT_ANY")
|
||||
class File3DSplatAny(ComfyTypeIO):
|
||||
"""General 3D Gaussian splat file type - accepts any supported splat container (.ply / .spz / .splat / .ksplat)."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_POINT_CLOUD_ANY")
|
||||
class File3DPointCloudAny(ComfyTypeIO):
|
||||
"""General point cloud file type - accepts any supported point cloud container (currently .ply)."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="HOOKS")
|
||||
class Hooks(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
@ -762,14 +802,32 @@ class Accumulation(ComfyTypeIO):
|
||||
@comfytype(io_type="LOAD3D_CAMERA")
|
||||
class Load3DCamera(ComfyTypeIO):
|
||||
class CameraInfo(TypedDict):
|
||||
position: dict[str, float | int]
|
||||
target: dict[str, float | int]
|
||||
zoom: int
|
||||
cameraType: str
|
||||
# Coordinate system: right-handed, Y-up, camera looks down -Z
|
||||
position: dict[str, float | int] # scene units
|
||||
target: dict[str, float | int] # scene units; OrbitControls focus point
|
||||
zoom: float | int # dimensionless, 1 = 100%
|
||||
cameraType: str # 'perspective' | 'orthographic'
|
||||
quaternion: NotRequired[dict[str, float | int]] # normalized, dimensionless; camera world rotation
|
||||
fov: NotRequired[float | int] # degrees, vertical FOV (perspective only)
|
||||
aspect: NotRequired[float | int] # width / height (perspective only)
|
||||
near: NotRequired[float | int] # scene units
|
||||
far: NotRequired[float | int] # scene units
|
||||
frustum: NotRequired[dict[str, float | int]] # orthographic only: {left, right, top, bottom} in scene units
|
||||
|
||||
Type = CameraInfo
|
||||
|
||||
|
||||
@comfytype(io_type="LOAD3D_MODEL_INFO")
|
||||
class Load3DModelInfo(ComfyTypeIO):
|
||||
class Model3DTransform(TypedDict):
|
||||
# Coordinate system: right-handed, Y-up, world space
|
||||
position: dict[str, float | int] # scene units
|
||||
quaternion: dict[str, float | int] # normalized, dimensionless; world rotation
|
||||
scale: dict[str, float | int] # dimensionless multiplier
|
||||
|
||||
Type = list[Model3DTransform]
|
||||
|
||||
|
||||
@comfytype(io_type="LOAD_3D")
|
||||
class Load3D(ComfyTypeIO):
|
||||
"""3D models are stored as a dictionary."""
|
||||
@ -779,6 +837,7 @@ class Load3D(ComfyTypeIO):
|
||||
normal: str
|
||||
camera_info: Load3DCamera.CameraInfo
|
||||
recording: NotRequired[str]
|
||||
model_3d_info: NotRequired[list[Load3DModelInfo.Model3DTransform]]
|
||||
|
||||
Type = Model3DDict
|
||||
|
||||
@ -2277,6 +2336,7 @@ __all__ = [
|
||||
"LossMap",
|
||||
"Voxel",
|
||||
"Mesh",
|
||||
"Splat",
|
||||
"File3DAny",
|
||||
"File3DGLB",
|
||||
"File3DGLTF",
|
||||
@ -2284,6 +2344,12 @@ __all__ = [
|
||||
"File3DOBJ",
|
||||
"File3DSTL",
|
||||
"File3DUSDZ",
|
||||
"File3DPLY",
|
||||
"File3DSPLAT",
|
||||
"File3DSPZ",
|
||||
"File3DKSPLAT",
|
||||
"File3DSplatAny",
|
||||
"File3DPointCloudAny",
|
||||
"Hooks",
|
||||
"HookKeyframes",
|
||||
"TimestepsRange",
|
||||
@ -2291,6 +2357,7 @@ __all__ = [
|
||||
"FlowControl",
|
||||
"Accumulation",
|
||||
"Load3DCamera",
|
||||
"Load3DModelInfo",
|
||||
"Load3D",
|
||||
"Load3DAnimation",
|
||||
"Photomaker",
|
||||
|
||||
@ -285,7 +285,7 @@ class AudioSaveHelper:
|
||||
results = []
|
||||
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||
file = f"{filename_with_batch_num}_{counter:05}.{format}"
|
||||
output_path = os.path.join(full_output_folder, file)
|
||||
|
||||
# Use original sample rate initially
|
||||
@ -452,6 +452,16 @@ class PreviewUI3D(_UIOutput):
|
||||
return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
|
||||
|
||||
|
||||
class PreviewUI3DAdvanced(_UIOutput):
|
||||
def __init__(self, model_file, camera_info, model_3d_info):
|
||||
self.model_file = model_file
|
||||
self.camera_info = camera_info
|
||||
self.model_3d_info = model_3d_info
|
||||
|
||||
def as_dict(self):
|
||||
return {"result": [self.model_file, self.camera_info, self.model_3d_info]}
|
||||
|
||||
|
||||
class PreviewText(_UIOutput):
|
||||
def __init__(self, value: str, **kwargs):
|
||||
self.value = value
|
||||
@ -471,5 +481,6 @@ __all__ = [
|
||||
"PreviewAudio",
|
||||
"PreviewVideo",
|
||||
"PreviewUI3D",
|
||||
"PreviewUI3DAdvanced",
|
||||
"PreviewText",
|
||||
]
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH, File3D
|
||||
from .geometry_types import VOXEL, MESH, SPLAT, File3D
|
||||
from .image_types import SVG
|
||||
|
||||
__all__ = [
|
||||
@ -9,6 +9,7 @@ __all__ = [
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
"SPLAT",
|
||||
"File3D",
|
||||
"SVG",
|
||||
]
|
||||
|
||||
@ -11,13 +11,32 @@ class VOXEL:
|
||||
self.data = data
|
||||
|
||||
|
||||
class SPLAT:
|
||||
"""A batch of 3D Gaussian splats in render-ready (activated, world-space) form.
|
||||
|
||||
Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the
|
||||
real per-item lengths (None when rows are uniform and no slicing is needed). SH coefficients are
|
||||
stored as (B, N, K, 3) with K = (sh_degree + 1)**2; the DC (diffuse) term is sh[..., 0, :].
|
||||
"""
|
||||
|
||||
def __init__(self, positions: torch.Tensor, scales: torch.Tensor, rotations: torch.Tensor,
|
||||
opacities: torch.Tensor, sh: torch.Tensor, counts: torch.Tensor | None = None):
|
||||
self.positions = positions # (B, N, 3) world-space centers
|
||||
self.scales = scales # (B, N, 3) linear (positive) per-axis std
|
||||
self.rotations = rotations # (B, N, 4) quaternion wxyz (normalized)
|
||||
self.opacities = opacities # (B, N, 1) in [0, 1]
|
||||
self.sh = sh # (B, N, K, 3) spherical-harmonic color coefficients
|
||||
self.counts = counts # (B,) real lengths, or None
|
||||
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor,
|
||||
uvs: torch.Tensor | None = None,
|
||||
vertex_colors: torch.Tensor | None = None,
|
||||
texture: torch.Tensor | None = None,
|
||||
vertex_counts: torch.Tensor | None = None,
|
||||
face_counts: torch.Tensor | None = None):
|
||||
face_counts: torch.Tensor | None = None,
|
||||
unlit: bool = False):
|
||||
|
||||
assert (vertex_counts is None) == (face_counts is None), \
|
||||
"vertex_counts and face_counts must be provided together (both or neither)"
|
||||
@ -30,6 +49,8 @@ class MESH:
|
||||
# these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed.
|
||||
self.vertex_counts = vertex_counts
|
||||
self.face_counts = face_counts
|
||||
# Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes.
|
||||
self.unlit = unlit
|
||||
|
||||
|
||||
class File3D:
|
||||
|
||||
32
comfy_api_nodes/apis/beeble.py
Normal file
32
comfy_api_nodes/apis/beeble.py
Normal file
@ -0,0 +1,32 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CreateSwitchXRequest(BaseModel):
|
||||
generation_type: str = Field(...)
|
||||
source_uri: str = Field(...)
|
||||
alpha_mode: str = Field(...)
|
||||
prompt: str | None = Field(None, max_length=2000)
|
||||
reference_image_uri: str | None = Field(None)
|
||||
alpha_uri: str | None = Field(None)
|
||||
max_resolution: int = Field(1080)
|
||||
callback_url: str | None = Field(None)
|
||||
idempotency_key: str | None = Field(None, max_length=256, min_length=1)
|
||||
|
||||
|
||||
class SwitchXOutputUrls(BaseModel):
|
||||
render: str | None = Field(None)
|
||||
source: str | None = Field(None)
|
||||
alpha: str | None = Field(None)
|
||||
|
||||
|
||||
class SwitchXStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
progress: int | None = Field(None)
|
||||
generation_type: str | None = Field(None)
|
||||
alpha_mode: str | None = Field(None)
|
||||
output: SwitchXOutputUrls | None = Field(None)
|
||||
error: str | None = Field(None)
|
||||
created_at: str | None = Field(None)
|
||||
modified_at: str | None = Field(None)
|
||||
completed_at: str | None = Field(None)
|
||||
@ -1,71 +1,72 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, confloat, conint
|
||||
|
||||
|
||||
class BFLOutputFormat(str, Enum):
|
||||
png = 'png'
|
||||
jpeg = 'jpeg'
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BFLFluxExpandImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
|
||||
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
|
||||
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
|
||||
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
top: int = Field(...)
|
||||
bottom: int = Field(...)
|
||||
left: int = Field(...)
|
||||
right: int = Field(...)
|
||||
steps: int = Field(...)
|
||||
guidance: float = Field(...)
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand")
|
||||
|
||||
|
||||
class BFLFluxFillImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
steps: int = Field(...)
|
||||
guidance: float = Field(...)
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
image: str = Field(
|
||||
None, description="Base64-encoded string representing the image to modify. Can contain alpha mask if desired.",
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
|
||||
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
mask: str = Field(
|
||||
None, description="Base64-encoded string representing the mask of the areas you wish to modify."
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
|
||||
|
||||
class BFLFluxEraseRequest(BaseModel):
|
||||
image: str = Field(..., description="A Base64-encoded string representing the image to erase from.")
|
||||
mask: str = Field(
|
||||
...,
|
||||
description="A Base64-encoded black/white mask matching the input dimensions; "
|
||||
"white (255) marks areas to remove, black (0) marks areas to preserve.",
|
||||
)
|
||||
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
|
||||
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
|
||||
dilate_pixels: int = Field(10)
|
||||
seed: int | None = Field(None)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
class BFLFluxVTORequest(BaseModel):
|
||||
prompt: str = Field(
|
||||
..., description="Natural-language styling instruction. Required field, but may be an empty string."
|
||||
)
|
||||
person: str = Field(..., description="A Base64-encoded string representing the person image.")
|
||||
garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.")
|
||||
seed: int | None = Field(None)
|
||||
safety_tolerance: int = Field(5)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
class BFLFluxProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
|
||||
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||
# None, description='Blend between the prompt and the image prompt.'
|
||||
# )
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
width: int = Field(1024, description="Must be a multiple of 32.")
|
||||
height: int = Field(768, description="Must be a multiple of 32.")
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
|
||||
|
||||
|
||||
class Flux2ProGenerateRequest(BaseModel):
|
||||
@ -83,55 +84,37 @@ class Flux2ProGenerateRequest(BaseModel):
|
||||
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
safety_tolerance: int | None = Field(
|
||||
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
|
||||
)
|
||||
output_format: str | None = Field(
|
||||
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
|
||||
)
|
||||
safety_tolerance: int = Field(5)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
class BFLFluxKontextProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
|
||||
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
|
||||
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
|
||||
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
|
||||
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
prompt: str = Field(...)
|
||||
input_image: str | None = Field(None, description="Image to edit in base64 format")
|
||||
seed: int | None = Field(None)
|
||||
guidance: float = Field(...)
|
||||
steps: int = Field(...)
|
||||
safety_tolerance: int = Field(2)
|
||||
output_format: str = Field("png")
|
||||
aspect_ratio: str | None = Field(None)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
|
||||
|
||||
class BFLFluxProUltraGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for image generation.')
|
||||
prompt_upsampling: Optional[bool] = Field(
|
||||
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
|
||||
)
|
||||
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
|
||||
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
|
||||
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
|
||||
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
|
||||
)
|
||||
output_format: Optional[BFLOutputFormat] = Field(
|
||||
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
|
||||
)
|
||||
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
|
||||
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
|
||||
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
|
||||
None, description='Blend between the prompt and the image prompt.'
|
||||
)
|
||||
prompt: str = Field(...)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
seed: int | None = Field(None)
|
||||
aspect_ratio: str | None = Field(None)
|
||||
safety_tolerance: int = Field(6)
|
||||
output_format: str = Field("png")
|
||||
raw: bool | None = Field(None)
|
||||
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
|
||||
image_prompt_strength: float | None = Field(None)
|
||||
|
||||
|
||||
class BFLFluxProGenerateResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
polling_url: str = Field(..., description="URL to poll for the generation result.")
|
||||
id: str = Field(...)
|
||||
polling_url: str = Field(...)
|
||||
cost: float | None = Field(None, description="Price in cents")
|
||||
|
||||
|
||||
@ -145,7 +128,7 @@ class BFLStatus(str, Enum):
|
||||
|
||||
|
||||
class BFLFluxStatusResponse(BaseModel):
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
status: BFLStatus = Field(..., description="The status of the task.")
|
||||
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
|
||||
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
|
||||
id: str = Field(...)
|
||||
status: BFLStatus = Field(...)
|
||||
result: dict[str, Any] | None = Field(None)
|
||||
progress: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
@ -97,3 +97,28 @@ class BriaRemoveVideoBackgroundResult(BaseModel):
|
||||
class BriaRemoveVideoBackgroundResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
result: BriaRemoveVideoBackgroundResult | None = Field(None)
|
||||
|
||||
|
||||
class BriaVideoGreenScreenRequest(BaseModel):
|
||||
video: str = Field(..., description="Publicly accessible URL of the input video.")
|
||||
green_shade: str = Field(
|
||||
default="broadcast_green",
|
||||
description="Solid chroma-key shade applied behind the foreground "
|
||||
"(broadcast_green, chroma_green, or blue_screen).",
|
||||
)
|
||||
output_container_and_codec: str = Field(...)
|
||||
preserve_audio: bool = Field(True)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class BriaVideoReplaceBackgroundRequest(BaseModel):
|
||||
video: str = Field(..., description="Publicly accessible URL of the input (foreground) video.")
|
||||
background_url: str = Field(
|
||||
...,
|
||||
description="Publicly accessible URL of the background image or video to composite behind "
|
||||
"the foreground. Stretched to the foreground frame; match its aspect ratio for "
|
||||
"undistorted results.",
|
||||
)
|
||||
output_container_and_codec: str = Field(...)
|
||||
preserve_audio: bool = Field(True)
|
||||
seed: int = Field(...)
|
||||
|
||||
@ -158,8 +158,9 @@ class SeedanceCreateAssetResponse(BaseModel):
|
||||
|
||||
|
||||
class SeedanceVirtualLibraryCreateAssetRequest(BaseModel):
|
||||
url: str = Field(..., description="Publicly accessible URL of the image asset to upload.")
|
||||
url: str = Field(..., description="Publicly accessible URL of the asset to upload.")
|
||||
hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.")
|
||||
asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.")
|
||||
|
||||
|
||||
# Dollars per 1K tokens, keyed by (model_id, has_video_input).
|
||||
|
||||
@ -108,13 +108,19 @@ class GeminiVideoMetadata(BaseModel):
|
||||
startOffset: GeminiOffset | None = Field(None)
|
||||
|
||||
|
||||
class GeminiThinkingConfig(BaseModel):
|
||||
includeThoughts: bool | None = Field(None)
|
||||
thinkingLevel: str = Field(...)
|
||||
|
||||
|
||||
class GeminiGenerationConfig(BaseModel):
|
||||
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
||||
maxOutputTokens: int | None = Field(None, ge=16, le=65536)
|
||||
seed: int | None = Field(None)
|
||||
stopSequences: list[str] | None = Field(None)
|
||||
temperature: float | None = Field(None, ge=0.0, le=2.0)
|
||||
topK: int | None = Field(None, ge=1)
|
||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||
thinkingConfig: GeminiThinkingConfig | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageOutputOptions(BaseModel):
|
||||
@ -128,11 +134,6 @@ class GeminiImageConfig(BaseModel):
|
||||
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
|
||||
|
||||
|
||||
class GeminiThinkingConfig(BaseModel):
|
||||
includeThoughts: bool | None = Field(None)
|
||||
thinkingLevel: str = Field(...)
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||
responseModalities: list[str] | None = Field(None)
|
||||
imageConfig: GeminiImageConfig | None = Field(None)
|
||||
|
||||
@ -290,3 +290,19 @@ class IdeogramV3Request(BaseModel):
|
||||
None,
|
||||
description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.'
|
||||
)
|
||||
|
||||
|
||||
class IdeogramV4Request(BaseModel):
|
||||
text_prompt: str | None = Field(
|
||||
None,
|
||||
description="Natural-language prompt; Magic Prompt is applied automatically. "
|
||||
"Supply exactly one of text_prompt or json_prompt.",
|
||||
)
|
||||
json_prompt: dict[str, Any] | None = Field(
|
||||
None,
|
||||
description="Structured V4 prompt object consumed directly (disables Magic Prompt). "
|
||||
"Supply exactly one of text_prompt or json_prompt.",
|
||||
)
|
||||
resolution: str | None = Field(None, description="Output resolution in WIDTHxHEIGHT (e.g. '2048x2048').")
|
||||
rendering_speed: str | None = Field(None, description="Rendering speed: 'TURBO', 'DEFAULT', or 'QUALITY'.")
|
||||
enable_copyright_detection: bool | None = Field(None, description="Opt into post-generation copyright detection.")
|
||||
|
||||
46
comfy_api_nodes/apis/krea.py
Normal file
46
comfy_api_nodes/apis/krea.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""Pydantic models for the Krea image-generation API."""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class KreaMoodboard(BaseModel):
|
||||
id: str = Field(...)
|
||||
strength: float = Field(default=0.35, ge=-0.5, le=1.5)
|
||||
|
||||
|
||||
class KreaImageStyleReference(BaseModel):
|
||||
strength: float = Field(..., ge=-2.0, le=2.0)
|
||||
url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaGenerateImageRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
aspect_ratio: str = Field(...)
|
||||
resolution: str = Field(...)
|
||||
seed: int | None = Field(default=None)
|
||||
creativity: str = Field(default="medium")
|
||||
moodboards: list[KreaMoodboard] | None = Field(default=None)
|
||||
image_style_references: list[KreaImageStyleReference] | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaJobResult(BaseModel):
|
||||
urls: list[str] | None = Field(default=None)
|
||||
style_id: str | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaJob(BaseModel):
|
||||
job_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
created_at: str = Field(...)
|
||||
completed_at: str | None = Field(default=None)
|
||||
result: KreaJobResult | None = Field(default=None)
|
||||
|
||||
|
||||
class KreaAssetResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
image_url: str = Field(...)
|
||||
uploaded_at: str = Field(...)
|
||||
width: float | None = Field(default=None)
|
||||
height: float | None = Field(default=None)
|
||||
size_bytes: float | None = Field(default=None)
|
||||
mime_type: str | None = Field(default=None)
|
||||
@ -1,25 +1,25 @@
|
||||
from enum import Enum
|
||||
from typing import Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field, RootModel
|
||||
|
||||
|
||||
class TripoModelVersion(str, Enum):
|
||||
v3_1_20260211 = 'v3.1-20260211'
|
||||
v3_0_20250812 = 'v3.0-20250812'
|
||||
v2_5_20250123 = 'v2.5-20250123'
|
||||
v2_0_20240919 = 'v2.0-20240919'
|
||||
v1_4_20240625 = 'v1.4-20240625'
|
||||
v3_1_20260211 = "v3.1-20260211"
|
||||
v3_0_20250812 = "v3.0-20250812"
|
||||
v2_5_20250123 = "v2.5-20250123"
|
||||
v2_0_20240919 = "v2.0-20240919"
|
||||
v1_4_20240625 = "v1.4-20240625"
|
||||
|
||||
|
||||
class TripoGeometryQuality(str, Enum):
|
||||
standard = 'standard'
|
||||
detailed = 'detailed'
|
||||
standard = "standard"
|
||||
detailed = "detailed"
|
||||
|
||||
|
||||
class TripoTextureQuality(str, Enum):
|
||||
standard = 'standard'
|
||||
detailed = 'detailed'
|
||||
standard = "standard"
|
||||
detailed = "detailed"
|
||||
|
||||
|
||||
class TripoStyle(str, Enum):
|
||||
@ -33,6 +33,7 @@ class TripoStyle(str, Enum):
|
||||
ANCIENT_BRONZE = "ancient_bronze"
|
||||
NONE = "None"
|
||||
|
||||
|
||||
class TripoTaskType(str, Enum):
|
||||
TEXT_TO_MODEL = "text_to_model"
|
||||
IMAGE_TO_MODEL = "image_to_model"
|
||||
@ -45,26 +46,27 @@ class TripoTaskType(str, Enum):
|
||||
STYLIZE_MODEL = "stylize_model"
|
||||
CONVERT_MODEL = "convert_model"
|
||||
|
||||
|
||||
class TripoTextureAlignment(str, Enum):
|
||||
ORIGINAL_IMAGE = "original_image"
|
||||
GEOMETRY = "geometry"
|
||||
|
||||
|
||||
class TripoOrientation(str, Enum):
|
||||
ALIGN_IMAGE = "align_image"
|
||||
DEFAULT = "default"
|
||||
|
||||
|
||||
class TripoOutFormat(str, Enum):
|
||||
GLB = "glb"
|
||||
FBX = "fbx"
|
||||
|
||||
class TripoTopology(str, Enum):
|
||||
BIP = "bip"
|
||||
QUAD = "quad"
|
||||
|
||||
class TripoSpec(str, Enum):
|
||||
MIXAMO = "mixamo"
|
||||
TRIPO = "tripo"
|
||||
|
||||
|
||||
class TripoAnimation(str, Enum):
|
||||
IDLE = "preset:idle"
|
||||
WALK = "preset:walk"
|
||||
@ -83,11 +85,6 @@ class TripoAnimation(str, Enum):
|
||||
SERPENTINE_MARCH = "preset:serpentine:march"
|
||||
AQUATIC_MARCH = "preset:aquatic:march"
|
||||
|
||||
class TripoStylizeStyle(str, Enum):
|
||||
LEGO = "lego"
|
||||
VOXEL = "voxel"
|
||||
VORONOI = "voronoi"
|
||||
MINECRAFT = "minecraft"
|
||||
|
||||
class TripoConvertFormat(str, Enum):
|
||||
GLTF = "GLTF"
|
||||
@ -97,6 +94,7 @@ class TripoConvertFormat(str, Enum):
|
||||
STL = "STL"
|
||||
_3MF = "3MF"
|
||||
|
||||
|
||||
class TripoTextureFormat(str, Enum):
|
||||
BMP = "BMP"
|
||||
DPX = "DPX"
|
||||
@ -108,6 +106,7 @@ class TripoTextureFormat(str, Enum):
|
||||
TIFF = "TIFF"
|
||||
WEBP = "WEBP"
|
||||
|
||||
|
||||
class TripoTaskStatus(str, Enum):
|
||||
QUEUED = "queued"
|
||||
RUNNING = "running"
|
||||
@ -118,183 +117,223 @@ class TripoTaskStatus(str, Enum):
|
||||
BANNED = "banned"
|
||||
EXPIRED = "expired"
|
||||
|
||||
|
||||
class TripoFbxPreset(str, Enum):
|
||||
BLENDER = "blender"
|
||||
MIXAMO = "mixamo"
|
||||
_3DSMAX = "3dsmax"
|
||||
|
||||
|
||||
class TripoFileTokenReference(BaseModel):
|
||||
type: Optional[str] = Field(None, description='The type of the reference')
|
||||
type: str | None = Field(None, description="The type of the reference")
|
||||
file_token: str
|
||||
|
||||
|
||||
class TripoUrlReference(BaseModel):
|
||||
type: Optional[str] = Field(None, description='The type of the reference')
|
||||
type: str | None = Field(None, description="The type of the reference")
|
||||
url: str
|
||||
|
||||
|
||||
class TripoObjectStorage(BaseModel):
|
||||
bucket: str
|
||||
key: str
|
||||
|
||||
|
||||
class TripoObjectReference(BaseModel):
|
||||
type: str
|
||||
object: TripoObjectStorage
|
||||
|
||||
|
||||
class TripoFileEmptyReference(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
class TripoFileReference(RootModel):
|
||||
root: TripoFileTokenReference | TripoUrlReference | TripoObjectReference | TripoFileEmptyReference
|
||||
|
||||
class TripoGetStsTokenRequest(BaseModel):
|
||||
format: str = Field(..., description='The format of the image')
|
||||
|
||||
class TripoTextToModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
|
||||
prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
|
||||
negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
|
||||
model_version: Optional[TripoModelVersion] = TripoModelVersion.v2_5_20250123
|
||||
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
|
||||
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
|
||||
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
|
||||
image_seed: Optional[int] = Field(None, description='The seed for the text')
|
||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||
style: Optional[TripoStyle] = None
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||
type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description="Type of task")
|
||||
prompt: str = Field(..., description="The text prompt describing the model to generate", max_length=1024)
|
||||
negative_prompt: str | None = Field(None, description="The negative text prompt", max_length=1024)
|
||||
model_version: TripoModelVersion | None = TripoModelVersion.v2_5_20250123
|
||||
face_limit: int | None = Field(None, description="The number of faces to limit the generation to")
|
||||
texture: bool | None = Field(True, description="Whether to apply texture to the generated model")
|
||||
pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model")
|
||||
image_seed: int | None = Field(None, description="The seed for the text")
|
||||
model_seed: int | None = Field(None, description="The seed for the model")
|
||||
texture_seed: int | None = Field(None, description="The seed for the texture")
|
||||
texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard
|
||||
geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard
|
||||
style: TripoStyle | None = None
|
||||
auto_size: bool | None = Field(False, description="Whether to auto-size the model")
|
||||
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
|
||||
|
||||
|
||||
class TripoImageToModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task')
|
||||
file: TripoFileReference = Field(..., description='The file reference to convert to a model')
|
||||
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
|
||||
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
|
||||
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
|
||||
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
|
||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
|
||||
style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT
|
||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||
type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description="Type of task")
|
||||
file: TripoFileReference = Field(..., description="The file reference to convert to a model")
|
||||
model_version: TripoModelVersion | None = Field(None, description="The model version to use for generation")
|
||||
face_limit: int | None = Field(None, description="The number of faces to limit the generation to")
|
||||
texture: bool | None = Field(True, description="Whether to apply texture to the generated model")
|
||||
pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model")
|
||||
model_seed: int | None = Field(None, description="The seed for the model")
|
||||
texture_seed: int | None = Field(None, description="The seed for the texture")
|
||||
texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard
|
||||
geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard
|
||||
texture_alignment: TripoTextureAlignment | None = Field(
|
||||
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
|
||||
)
|
||||
style: TripoStyle | None = Field(None, description="The style to apply to the generated model")
|
||||
auto_size: bool | None = Field(False, description="Whether to auto-size the model")
|
||||
orientation: TripoOrientation | None = TripoOrientation.DEFAULT
|
||||
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
|
||||
|
||||
|
||||
class TripoMultiviewToModelRequest(BaseModel):
|
||||
type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
|
||||
files: list[TripoFileReference] = Field(..., description='The file references to convert to a model')
|
||||
model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
|
||||
orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
|
||||
face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
|
||||
texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
|
||||
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
|
||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
|
||||
geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard
|
||||
texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
|
||||
auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
|
||||
orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
|
||||
quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
|
||||
files: list[TripoFileReference] = Field(..., description="The file references to convert to a model")
|
||||
model_version: TripoModelVersion | None = Field(None, description="The model version to use for generation")
|
||||
orthographic_projection: bool | None = Field(False, description="Whether to use orthographic projection")
|
||||
face_limit: int | None = Field(None, description="The number of faces to limit the generation to")
|
||||
texture: bool | None = Field(True, description="Whether to apply texture to the generated model")
|
||||
pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model")
|
||||
model_seed: int | None = Field(None, description="The seed for the model")
|
||||
texture_seed: int | None = Field(None, description="The seed for the texture")
|
||||
texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard
|
||||
geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard
|
||||
texture_alignment: TripoTextureAlignment | None = TripoTextureAlignment.ORIGINAL_IMAGE
|
||||
auto_size: bool | None = Field(False, description="Whether to auto-size the model")
|
||||
orientation: TripoOrientation | None = Field(TripoOrientation.DEFAULT, description="The orientation for the model")
|
||||
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
|
||||
|
||||
|
||||
class TripoTextureModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
texture: Optional[bool] = Field(True, description='Whether to apply texture to the model')
|
||||
pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model')
|
||||
model_seed: Optional[int] = Field(None, description='The seed for the model')
|
||||
texture_seed: Optional[int] = Field(None, description='The seed for the texture')
|
||||
texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture')
|
||||
texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
|
||||
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task")
|
||||
original_model_task_id: str = Field(..., description="The task ID of the original model")
|
||||
texture: bool | None = Field(True, description="Whether to apply texture to the model")
|
||||
pbr: bool | None = Field(True, description="Whether to apply PBR to the model")
|
||||
model_seed: int | None = Field(None, description="The seed for the model")
|
||||
texture_seed: int | None = Field(None, description="The seed for the texture")
|
||||
texture_quality: TripoTextureQuality | None = Field(None, description="The quality of the texture")
|
||||
texture_alignment: TripoTextureAlignment | None = Field(
|
||||
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
|
||||
)
|
||||
|
||||
|
||||
class TripoRefineModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task')
|
||||
draft_model_task_id: str = Field(..., description='The task ID of the draft model')
|
||||
type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description="Type of task")
|
||||
draft_model_task_id: str = Field(..., description="The task ID of the draft model")
|
||||
|
||||
class TripoAnimatePrerigcheckRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
|
||||
class TripoAnimateRigRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
|
||||
spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging')
|
||||
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description="Type of task")
|
||||
original_model_task_id: str = Field(..., description="The task ID of the original model")
|
||||
out_format: TripoOutFormat | None = Field(TripoOutFormat.GLB, description="The output format")
|
||||
spec: TripoSpec | None = Field(TripoSpec.TRIPO, description="The specification for rigging")
|
||||
|
||||
|
||||
class TripoAnimateRetargetRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
animation: TripoAnimation = Field(..., description='The animation to apply')
|
||||
out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
|
||||
bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation')
|
||||
type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description="Type of task")
|
||||
original_model_task_id: str = Field(..., description="The task ID of the original model")
|
||||
animation: TripoAnimation = Field(..., description="The animation to apply")
|
||||
out_format: TripoOutFormat | None = Field(TripoOutFormat.GLB, description="The output format")
|
||||
bake_animation: bool | None = Field(True, description="Whether to bake the animation")
|
||||
|
||||
class TripoStylizeModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task')
|
||||
style: TripoStylizeStyle = Field(..., description='The style to apply to the model')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
block_size: Optional[int] = Field(80, description='The block size for stylization')
|
||||
|
||||
class TripoConvertModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
|
||||
format: TripoConvertFormat = Field(..., description='The format to convert to')
|
||||
original_model_task_id: str = Field(..., description='The task ID of the original model')
|
||||
quad: Optional[bool] = Field(None, description='Whether to apply quad to the model')
|
||||
force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry')
|
||||
face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to')
|
||||
flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model')
|
||||
flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom')
|
||||
texture_size: Optional[int] = Field(None, description='The size of the texture')
|
||||
texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
|
||||
pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom')
|
||||
scale_factor: Optional[float] = Field(None, description='The scale factor for the model')
|
||||
with_animation: Optional[bool] = Field(None, description='Whether to include animations')
|
||||
pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs')
|
||||
bake: Optional[bool] = Field(None, description='Whether to bake the model')
|
||||
part_names: Optional[list[str]] = Field(None, description='The names of the parts to include')
|
||||
fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export')
|
||||
export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors')
|
||||
export_orientation: Optional[TripoOrientation] = Field(None, description='The orientation for the export')
|
||||
animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place')
|
||||
type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description="Type of task")
|
||||
format: TripoConvertFormat = Field(..., description="The format to convert to")
|
||||
original_model_task_id: str = Field(..., description="The task ID of the original model")
|
||||
quad: bool | None = Field(None, description="Whether to apply quad to the model")
|
||||
force_symmetry: bool | None = Field(None, description="Whether to force symmetry")
|
||||
face_limit: int | None = Field(None, description="The number of faces to limit the conversion to")
|
||||
flatten_bottom: bool | None = Field(None, description="Whether to flatten the bottom of the model")
|
||||
flatten_bottom_threshold: float | None = Field(None, description="The threshold for flattening the bottom")
|
||||
texture_size: int | None = Field(None, description="The size of the texture")
|
||||
texture_format: TripoTextureFormat | None = Field(TripoTextureFormat.JPEG, description="The format of the texture")
|
||||
pivot_to_center_bottom: bool | None = Field(None, description="Whether to pivot to the center bottom")
|
||||
scale_factor: float | None = Field(None, description="The scale factor for the model")
|
||||
with_animation: bool | None = Field(None, description="Whether to include animations")
|
||||
pack_uv: bool | None = Field(None, description="Whether to pack the UVs")
|
||||
bake: bool | None = Field(None, description="Whether to bake the model")
|
||||
part_names: list[str] | None = Field(None, description="The names of the parts to include")
|
||||
fbx_preset: TripoFbxPreset | None = Field(None, description="The preset for the FBX export")
|
||||
export_vertex_colors: bool | None = Field(None, description="Whether to export the vertex colors")
|
||||
export_orientation: TripoOrientation | None = Field(None, description="The orientation for the export")
|
||||
animate_in_place: bool | None = Field(None, description="Whether to animate in place")
|
||||
|
||||
|
||||
class TripoP1CommonRequest(BaseModel):
|
||||
"""Fields supported by Tripo P1 across all input types."""
|
||||
|
||||
model_version: str = Field("P1-20260311")
|
||||
model_seed: int | None = Field(None, description="Random seed for geometry generation")
|
||||
face_limit: int | None = Field(None, ge=48, le=20000, description="Target face count (48-20000)")
|
||||
texture: bool | None = Field(None, description="Enable texturing; pbr=True forces this true")
|
||||
pbr: bool | None = Field(None, description="Enable PBR maps; when true, texture is also enabled")
|
||||
texture_seed: int | None = Field(None, description="Random seed for texture generation")
|
||||
texture_quality: str | None = Field(None, description='"standard" or "detailed"')
|
||||
auto_size: bool | None = Field(None, description="Scale to real-world meters")
|
||||
compress: str | None = Field(None, description='Only "geometry" is supported')
|
||||
export_uv: bool | None = Field(None, description="Perform UV unwrapping during generation")
|
||||
|
||||
|
||||
class TripoP1TextToModelRequest(TripoP1CommonRequest):
|
||||
type: str = "text_to_model"
|
||||
prompt: str = Field(..., max_length=1024)
|
||||
negative_prompt: str | None = Field(None, max_length=255)
|
||||
image_seed: int | None = None
|
||||
|
||||
|
||||
class TripoP1ImageToModelRequest(TripoP1CommonRequest):
|
||||
type: str = "image_to_model"
|
||||
file: TripoFileReference
|
||||
enable_image_autofix: bool | None = None
|
||||
texture_alignment: str | None = Field(None, description='"original_image" or "geometry"')
|
||||
orientation: str | None = Field(None, description='"default" or "align_image"; needs texture=true')
|
||||
|
||||
|
||||
class TripoP1MultiviewToModelRequest(TripoP1CommonRequest):
|
||||
"""P1 multiview generation.
|
||||
|
||||
Tripo requires `files` to be exactly four entries in [front, left, back, right] order with `{}`
|
||||
(TripoFileEmptyReference) for omitted slots; front is required and at least two images total must be provided.
|
||||
"""
|
||||
|
||||
type: str = "multiview_to_model"
|
||||
files: list[TripoFileReference]
|
||||
texture_alignment: str | None = None
|
||||
orientation: str | None = None
|
||||
|
||||
|
||||
class TripoTaskOutput(BaseModel):
|
||||
model: Optional[str] = Field(None, description='URL to the model')
|
||||
base_model: Optional[str] = Field(None, description='URL to the base model')
|
||||
pbr_model: Optional[str] = Field(None, description='URL to the PBR model')
|
||||
rendered_image: Optional[str] = Field(None, description='URL to the rendered image')
|
||||
riggable: Optional[bool] = Field(None, description='Whether the model is riggable')
|
||||
model: str | None = Field(None, description="URL to the model")
|
||||
base_model: str | None = Field(None, description="URL to the base model")
|
||||
pbr_model: str | None = Field(None, description="URL to the PBR model")
|
||||
rendered_image: str | None = Field(None, description="URL to the rendered image")
|
||||
riggable: bool | None = Field(None, description="Whether the model is riggable")
|
||||
|
||||
|
||||
class TripoTask(BaseModel):
|
||||
task_id: str = Field(..., description='The task ID')
|
||||
type: Optional[str] = Field(None, description='The type of task')
|
||||
status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
|
||||
input: Optional[dict[str, Any]] = Field(None, description='The input parameters for the task')
|
||||
output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
|
||||
progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
|
||||
create_time: Optional[int] = Field(None, description='The creation time of the task')
|
||||
running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
|
||||
queue_position: Optional[int] = Field(None, description='The position in the queue')
|
||||
task_id: str = Field(..., description="The task ID")
|
||||
type: str | None = Field(None, description="The type of task")
|
||||
status: TripoTaskStatus | None = Field(None, description="The status of the task")
|
||||
input: dict[str, Any] | None = Field(None, description="The input parameters for the task")
|
||||
output: TripoTaskOutput | None = Field(None, description="The output of the task")
|
||||
progress: int | None = Field(None, description="The progress of the task", ge=0, le=100)
|
||||
create_time: int | None = Field(None, description="The creation time of the task")
|
||||
running_left_time: int | None = Field(None, description="The estimated time left for the task")
|
||||
queue_position: int | None = Field(None, description="The position in the queue")
|
||||
consumed_credit: int | None = Field(None)
|
||||
|
||||
|
||||
class TripoTaskResponse(BaseModel):
|
||||
code: int = Field(0, description='The response code')
|
||||
data: TripoTask = Field(..., description='The task data')
|
||||
code: int = Field(0, description="The response code")
|
||||
data: TripoTask = Field(..., description="The task data")
|
||||
|
||||
class TripoGeneralResponse(BaseModel):
|
||||
code: int = Field(0, description='The response code')
|
||||
data: dict[str, str] = Field(..., description='The task ID data')
|
||||
|
||||
class TripoBalanceData(BaseModel):
|
||||
balance: float = Field(..., description='The account balance')
|
||||
frozen: float = Field(..., description='The frozen balance')
|
||||
|
||||
class TripoBalanceResponse(BaseModel):
|
||||
code: int = Field(0, description='The response code')
|
||||
data: TripoBalanceData = Field(..., description='The balance data')
|
||||
|
||||
class TripoErrorResponse(BaseModel):
|
||||
code: int = Field(..., description='The error code')
|
||||
message: str = Field(..., description='The error message')
|
||||
suggestion: str = Field(..., description='The suggestion for fixing the error')
|
||||
code: int = Field(..., description="The error code")
|
||||
message: str = Field(..., description="The error message")
|
||||
suggestion: str = Field(..., description="The suggestion for fixing the error")
|
||||
|
||||
@ -155,7 +155,7 @@ class ClaudeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ClaudeNode",
|
||||
display_name="Anthropic Claude",
|
||||
category="api node/text/Anthropic",
|
||||
category="partner/text/Anthropic",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses with Anthropic's Claude models. "
|
||||
"Provide a text prompt and optionally one or more images for multimodal context.",
|
||||
|
||||
404
comfy_api_nodes/nodes_beeble.py
Normal file
404
comfy_api_nodes/nodes_beeble.py
Normal file
@ -0,0 +1,404 @@
|
||||
from fractions import Fraction
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types
|
||||
from comfy_api_nodes.apis.beeble import (
|
||||
CreateSwitchXRequest,
|
||||
SwitchXStatusResponse,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
bytesio_to_image_tensor,
|
||||
convert_mask_to_image,
|
||||
download_url_as_bytesio,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
downscale_image_tensor,
|
||||
downscale_video_to_max_pixels,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
validate_video_frame_count,
|
||||
)
|
||||
|
||||
_MAX_PIXELS = 2_770_000
|
||||
_MAX_FRAMES = 240
|
||||
_MAX_PROMPT_LEN = 2000
|
||||
|
||||
|
||||
def _validate_inputs(prompt: str | None, reference_image: Input.Image | None) -> str | None:
|
||||
"""Beeble requires at least one of prompt or reference_image. Returns the cleaned prompt."""
|
||||
cleaned = prompt.strip() if prompt else ""
|
||||
if not cleaned and reference_image is None:
|
||||
raise ValueError("At least one of 'prompt' or 'reference_image' must be provided.")
|
||||
if cleaned:
|
||||
validate_string(cleaned, strip_whitespace=False, max_length=_MAX_PROMPT_LEN)
|
||||
return cleaned or None
|
||||
|
||||
|
||||
async def _upload_mask_as_image(
|
||||
cls: type[IO.ComfyNode],
|
||||
mask: Input.Image,
|
||||
*,
|
||||
wait_label: str,
|
||||
) -> str:
|
||||
"""Encode a single-frame MASK (H, W) or (1, H, W) as a PNG and upload."""
|
||||
if mask.dim() == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
image = convert_mask_to_image(mask[:1])
|
||||
return await upload_image_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
mime_type="image/png",
|
||||
wait_label=wait_label,
|
||||
total_pixels=_MAX_PIXELS,
|
||||
)
|
||||
|
||||
|
||||
async def _upload_mask_batch_as_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
mask: Input.Image,
|
||||
*,
|
||||
frame_rate: Fraction,
|
||||
source_frame_count: int,
|
||||
wait_label: str,
|
||||
) -> str:
|
||||
"""Encode a MASK batch (N, H, W) as a grayscale H.264 MP4 at frame_rate and upload.
|
||||
|
||||
The matte is always downscaled to the pixel budget so it stays within Beeble's limit and
|
||||
keeps the same dimensions as the (similarly downscaled) source — both use the same algorithm
|
||||
from the same starting dimensions, and downscaling is a no-op when already within budget.
|
||||
"""
|
||||
if mask.dim() == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[0] != source_frame_count:
|
||||
raise ValueError(
|
||||
f"Custom alpha video frame count ({mask.shape[0]}) does not match the "
|
||||
f"source video frame count ({source_frame_count}). The Beeble API requires "
|
||||
"one mask per source frame."
|
||||
)
|
||||
images = downscale_image_tensor(convert_mask_to_image(mask), _MAX_PIXELS)
|
||||
alpha_video = InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=None, frame_rate=frame_rate))
|
||||
return await upload_video_to_comfyapi(cls, alpha_video, wait_label=wait_label)
|
||||
|
||||
|
||||
def _alpha_mode_input(*, video: bool) -> IO.DynamicCombo.Input:
|
||||
"""Build the alpha_mode DynamicCombo with mode-specific extra inputs."""
|
||||
select_keyframe_tooltip = (
|
||||
"First-frame keyframe mask. Beeble propagates this across the video." if video else "Grayscale keyframe mask."
|
||||
)
|
||||
custom_tooltip = (
|
||||
"Per-frame grayscale mask covering the entire video. "
|
||||
"Must have the same frame count as the source. "
|
||||
"Connect a MASK output from SAM3_TrackToMask or similar."
|
||||
if video
|
||||
else "Grayscale mask to apply."
|
||||
)
|
||||
return IO.DynamicCombo.Input(
|
||||
"alpha_mode",
|
||||
tooltip=(
|
||||
"Controls how SwitchX decides what to keep vs. regenerate. "
|
||||
"'auto' isolates the main subject automatically. "
|
||||
"'fill' regenerates the entire frame while preserving geometry. "
|
||||
"'select' propagates a first-frame keyframe across the clip. "
|
||||
"'custom' uses a per-frame alpha matte you provide."
|
||||
),
|
||||
options=[
|
||||
IO.DynamicCombo.Option("auto", []),
|
||||
IO.DynamicCombo.Option("fill", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"select",
|
||||
[IO.Mask.Input("alpha_keyframe", tooltip=select_keyframe_tooltip)],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
"custom",
|
||||
[IO.Mask.Input("alpha_mask", tooltip=custom_tooltip)],
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _common_inputs(*, source: IO.Input, video: bool) -> list[IO.Input]:
|
||||
return [
|
||||
source,
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip=(
|
||||
"Text description of the desired output (max 2000 chars). "
|
||||
"At least one of 'prompt' or 'reference_image' is required."
|
||||
),
|
||||
),
|
||||
IO.Image.Input(
|
||||
"reference_image",
|
||||
optional=True,
|
||||
tooltip=(
|
||||
"Reference image whose look (background, lighting, costume) the result "
|
||||
"should adopt. At least one of 'reference_image' or 'prompt' is required."
|
||||
),
|
||||
),
|
||||
_alpha_mode_input(video=video),
|
||||
IO.Combo.Input(
|
||||
"max_resolution",
|
||||
options=["1080p", "720p"],
|
||||
default="1080p",
|
||||
tooltip="Maximum output resolution.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip=(
|
||||
"Seed controls whether the node should re-run; " "results are non-deterministic regardless of seed."
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
async def _submit_and_poll(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: CreateSwitchXRequest,
|
||||
) -> SwitchXStatusResponse:
|
||||
initial = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/beeble/v1/switchx/generations", method="POST"),
|
||||
response_model=SwitchXStatusResponse,
|
||||
data=request,
|
||||
)
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/beeble/v1/switchx/generations/{initial.id}"),
|
||||
response_model=SwitchXStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
)
|
||||
|
||||
|
||||
def _require_output_url(response: SwitchXStatusResponse, name: str) -> str:
|
||||
if response.output is None or getattr(response.output, name) is None:
|
||||
raise RuntimeError(f"Beeble job {response.id} completed without a {name!r} output URL.")
|
||||
return getattr(response.output, name)
|
||||
|
||||
|
||||
def _alpha_url(response: SwitchXStatusResponse, mode: str) -> str | None:
|
||||
"""URL of the alpha matte, or None when the mode produces no separate matte.
|
||||
|
||||
'fill' selects the whole frame, so Beeble writes no alpha asset even though the status
|
||||
response still returns a (dangling) signed URL for it — fetching it 403s with S3
|
||||
AccessDenied. The other three modes ('auto', 'custom', 'select') all produce a real,
|
||||
downloadable matte.
|
||||
"""
|
||||
if mode == "fill" or response.output is None:
|
||||
return None
|
||||
return response.output.alpha
|
||||
|
||||
|
||||
class BeebleSwitchXVideoEdit(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="BeebleSwitchXVideoEdit",
|
||||
display_name="Beeble SwitchX Video Edit",
|
||||
category="partner/video/Beeble",
|
||||
description=(
|
||||
"Edit a video with Beeble SwitchX. Switches anything in the scene (background, "
|
||||
"lighting, costume) while preserving the original subject's pixels and motion. "
|
||||
"Provide a reference image and/or text prompt to describe the new look. "
|
||||
"Max 240 frames, max ~2.77MP per frame."
|
||||
),
|
||||
inputs=_common_inputs(source=IO.Video.Input("video"), video=True),
|
||||
outputs=[
|
||||
IO.Video.Output(display_name="video"),
|
||||
IO.Video.Output(
|
||||
display_name="alpha",
|
||||
tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143;
|
||||
{"type":"usd","usd": $rate, "format":{"suffix":"/30 frames"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
prompt: str,
|
||||
alpha_mode: dict,
|
||||
max_resolution: str,
|
||||
seed: int,
|
||||
reference_image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
cleaned_prompt = _validate_inputs(prompt, reference_image)
|
||||
|
||||
validate_video_frame_count(video, max_frame_count=_MAX_FRAMES)
|
||||
video = downscale_video_to_max_pixels(video, _MAX_PIXELS)
|
||||
|
||||
mode = alpha_mode["alpha_mode"]
|
||||
alpha_uri: str | None = None
|
||||
if mode == "select":
|
||||
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe")
|
||||
elif mode == "custom":
|
||||
alpha_uri = await _upload_mask_batch_as_video(
|
||||
cls,
|
||||
alpha_mode["alpha_mask"],
|
||||
frame_rate=video.get_frame_rate(),
|
||||
source_frame_count=video.get_frame_count(),
|
||||
wait_label="Uploading alpha video",
|
||||
)
|
||||
|
||||
source_uri = await upload_video_to_comfyapi(cls, video, wait_label="Uploading source")
|
||||
reference_uri: str | None = None
|
||||
if reference_image is not None:
|
||||
reference_uri = await upload_image_to_comfyapi(
|
||||
cls,
|
||||
reference_image,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading reference",
|
||||
total_pixels=_MAX_PIXELS,
|
||||
)
|
||||
|
||||
request = CreateSwitchXRequest(
|
||||
generation_type="video",
|
||||
source_uri=source_uri,
|
||||
alpha_mode=mode,
|
||||
prompt=cleaned_prompt,
|
||||
reference_image_uri=reference_uri,
|
||||
alpha_uri=alpha_uri,
|
||||
max_resolution=1080 if max_resolution == "1080p" else 720,
|
||||
)
|
||||
response = await _submit_and_poll(cls, request)
|
||||
|
||||
render = await download_url_to_video_output(_require_output_url(response, "render"))
|
||||
alpha = None
|
||||
if (alpha_url := _alpha_url(response, mode)) is not None:
|
||||
alpha = await download_url_to_video_output(alpha_url)
|
||||
return IO.NodeOutput(render, alpha)
|
||||
|
||||
|
||||
class BeebleSwitchXImageEdit(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="BeebleSwitchXImageEdit",
|
||||
display_name="Beeble SwitchX Image Edit",
|
||||
category="partner/image/Beeble",
|
||||
description=(
|
||||
"Edit a single image with Beeble SwitchX. Switches anything in the scene "
|
||||
"(background, lighting, costume) while preserving the original subject's pixels. "
|
||||
"Provide a reference image and/or text prompt to describe the new look. "
|
||||
"Max ~2.77MP."
|
||||
),
|
||||
inputs=_common_inputs(source=IO.Image.Input("image"), video=False),
|
||||
outputs=[
|
||||
IO.Image.Output(display_name="image"),
|
||||
IO.Mask.Output(
|
||||
display_name="alpha",
|
||||
tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.",
|
||||
),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143;
|
||||
{"type":"usd","usd": $rate}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
alpha_mode: dict,
|
||||
max_resolution: str,
|
||||
seed: int,
|
||||
reference_image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
cleaned_prompt = _validate_inputs(prompt, reference_image)
|
||||
|
||||
image = downscale_image_tensor(image, _MAX_PIXELS)
|
||||
|
||||
mode = alpha_mode["alpha_mode"]
|
||||
alpha_uri: str | None = None
|
||||
if mode == "select":
|
||||
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe")
|
||||
elif mode == "custom":
|
||||
alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_mask"], wait_label="Uploading alpha")
|
||||
|
||||
source_uri = await upload_image_to_comfyapi(
|
||||
cls,
|
||||
image,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading source",
|
||||
total_pixels=None,
|
||||
)
|
||||
reference_uri: str | None = None
|
||||
if reference_image is not None:
|
||||
reference_uri = await upload_image_to_comfyapi(
|
||||
cls,
|
||||
reference_image,
|
||||
mime_type="image/png",
|
||||
wait_label="Uploading reference",
|
||||
total_pixels=_MAX_PIXELS,
|
||||
)
|
||||
|
||||
request = CreateSwitchXRequest(
|
||||
generation_type="image",
|
||||
source_uri=source_uri,
|
||||
alpha_mode=mode,
|
||||
prompt=cleaned_prompt,
|
||||
reference_image_uri=reference_uri,
|
||||
alpha_uri=alpha_uri,
|
||||
max_resolution=1080 if max_resolution == "1080p" else 720,
|
||||
)
|
||||
response = await _submit_and_poll(cls, request)
|
||||
|
||||
render = await download_url_to_image_tensor(_require_output_url(response, "render"))
|
||||
alpha_mask = None
|
||||
if (alpha_url := _alpha_url(response, mode)) is not None:
|
||||
alpha_image = bytesio_to_image_tensor(await download_url_as_bytesio(alpha_url), mode="L")
|
||||
alpha_mask = alpha_image.squeeze(-1) if alpha_image.dim() == 4 else alpha_image
|
||||
return IO.NodeOutput(render, alpha_mask)
|
||||
|
||||
|
||||
class BeebleExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
BeebleSwitchXVideoEdit,
|
||||
BeebleSwitchXImageEdit,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BeebleExtension:
|
||||
return BeebleExtension()
|
||||
@ -4,17 +4,20 @@ from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bfl import (
|
||||
BFLFluxEraseRequest,
|
||||
BFLFluxExpandImageRequest,
|
||||
BFLFluxFillImageRequest,
|
||||
BFLFluxKontextProGenerateRequest,
|
||||
BFLFluxProGenerateResponse,
|
||||
BFLFluxProUltraGenerateRequest,
|
||||
BFLFluxStatusResponse,
|
||||
BFLFluxVTORequest,
|
||||
BFLStatus,
|
||||
Flux2ProGenerateRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
convert_mask_to_image,
|
||||
download_url_to_image_tensor,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
@ -22,19 +25,11 @@ from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
validate_aspect_ratio_string,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
def convert_mask_to_image(mask: Input.Image):
|
||||
"""
|
||||
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
|
||||
"""
|
||||
mask = mask.unsqueeze(-1)
|
||||
mask = torch.cat([mask] * 3, dim=-1)
|
||||
return mask
|
||||
|
||||
|
||||
class FluxProUltraImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -42,7 +37,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="FluxProUltraImageNode",
|
||||
display_name="Flux 1.1 [pro] Ultra Image",
|
||||
category="api node/image/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -160,7 +155,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id=cls.NODE_ID,
|
||||
display_name=cls.DISPLAY_NAME,
|
||||
category="api node/image/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -282,7 +277,7 @@ class FluxProExpandNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="FluxProExpandNode",
|
||||
display_name="Flux.1 Expand Image",
|
||||
category="api node/image/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Outpaints image based on prompt.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -419,7 +414,7 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="FluxProFillNode",
|
||||
display_name="Flux.1 Fill Image",
|
||||
category="api node/image/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Inpaints image based on mask and prompt.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -519,6 +514,174 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class FluxEraseNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="FluxEraseNode",
|
||||
display_name="Flux Erase Image",
|
||||
category="partner/image/BFL",
|
||||
description="Removes the masked object from an image and reconstructs the background. "
|
||||
"Paint the mask over what you want to erase.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."),
|
||||
IO.Int.Input(
|
||||
"dilate_pixels",
|
||||
default=10,
|
||||
min=0,
|
||||
max=25,
|
||||
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
mask: Input.Image,
|
||||
dilate_pixels: int = 10,
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_dimensions(image, min_width=256, min_height=256)
|
||||
mask = resize_mask_to_image(mask, image)
|
||||
mask = tensor_to_base64_string(convert_mask_to_image(mask))
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxEraseRequest(
|
||||
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
|
||||
mask=mask,
|
||||
dilate_pixels=dilate_pixels,
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
|
||||
def price_extractor(_r: BaseModel) -> float | None:
|
||||
return None if initial_response.cost is None else initial_response.cost / 100
|
||||
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
price_extractor=price_extractor,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class FluxVTONode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="FluxVTONode",
|
||||
display_name="Flux Virtual Try-On",
|
||||
category="partner/image/BFL",
|
||||
description="Virtual try-on: dresses the person in the provided garment.",
|
||||
inputs=[
|
||||
IO.Image.Input("person", tooltip="Image of the person to dress."),
|
||||
IO.Image.Input("garment", tooltip="Image of the garment to apply."),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
person: Input.Image,
|
||||
garment: Input.Image,
|
||||
prompt: str = "",
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxVTORequest(
|
||||
prompt=prompt,
|
||||
person=tensor_to_base64_string(person[:, :, :, :3]),
|
||||
garment=tensor_to_base64_string(garment[:, :, :, :3]),
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
|
||||
def price_extractor(_r: BaseModel) -> float | None:
|
||||
return None if initial_response.cost is None else initial_response.cost / 100
|
||||
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
price_extractor=price_extractor,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class Flux2ProImageNode(IO.ComfyNode):
|
||||
|
||||
NODE_ID = "Flux2ProImageNode"
|
||||
@ -545,7 +708,7 @@ class Flux2ProImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id=cls.NODE_ID,
|
||||
display_name=cls.DISPLAY_NAME,
|
||||
category="api node/image/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Generates images synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -716,7 +879,7 @@ class Flux2ImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Flux2ImageNode",
|
||||
display_name="Flux.2 Image",
|
||||
category="api node/image/BFL",
|
||||
category="partner/image/BFL",
|
||||
description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -853,6 +1016,8 @@ class BFLExtension(ComfyExtension):
|
||||
FluxKontextMaxImageNode,
|
||||
FluxProExpandNode,
|
||||
FluxProFillNode,
|
||||
FluxEraseNode,
|
||||
FluxVTONode,
|
||||
Flux2ProImageNode,
|
||||
Flux2MaxImageNode,
|
||||
Flux2ImageNode,
|
||||
|
||||
@ -1,14 +1,19 @@
|
||||
import av
|
||||
import torch
|
||||
from av.codec import CodecContext
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bria import (
|
||||
BriaEditImageRequest,
|
||||
BriaImageEditResponse,
|
||||
BriaRemoveBackgroundRequest,
|
||||
BriaRemoveBackgroundResponse,
|
||||
BriaRemoveVideoBackgroundRequest,
|
||||
BriaRemoveVideoBackgroundResponse,
|
||||
BriaImageEditResponse,
|
||||
BriaStatusResponse,
|
||||
BriaVideoGreenScreenRequest,
|
||||
BriaVideoReplaceBackgroundRequest,
|
||||
InputModerationSettings,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
@ -31,7 +36,7 @@ class BriaImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="BriaImageEditNode",
|
||||
display_name="Bria FIBO Image Edit",
|
||||
category="api node/image/Bria",
|
||||
category="partner/image/Bria",
|
||||
description="Edit images using Bria latest model",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["FIBO"]),
|
||||
@ -169,7 +174,7 @@ class BriaRemoveImageBackground(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="BriaRemoveImageBackground",
|
||||
display_name="Bria Remove Image Background",
|
||||
category="api node/image/Bria",
|
||||
category="partner/image/Bria",
|
||||
description="Remove the background from an image using Bria RMBG 2.0.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -245,7 +250,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="BriaRemoveVideoBackground",
|
||||
display_name="Bria Remove Video Background",
|
||||
category="api node/video/Bria",
|
||||
category="partner/video/Bria",
|
||||
description="Remove the background from a video using Bria. ",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
@ -316,6 +321,248 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||
|
||||
|
||||
class BriaVideoGreenScreen(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaVideoGreenScreen",
|
||||
display_name="Bria Video Green Screen",
|
||||
category="partner/video/Bria",
|
||||
description="Replace a video's background with a solid chroma-key screen using Bria.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Combo.Input(
|
||||
"green_shade",
|
||||
options=["broadcast_green", "chroma_green", "blue_screen"],
|
||||
tooltip="Solid chroma-key shade applied behind the foreground: "
|
||||
"broadcast_green (#00B140), chroma_green (#00FF00), or blue_screen (#0000FF).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
green_shade: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_video_duration(video, max_duration=60.0)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/green_screen", method="POST"),
|
||||
data=BriaVideoGreenScreenRequest(
|
||||
video=await upload_video_to_comfyapi(cls, video),
|
||||
green_shade=green_shade,
|
||||
output_container_and_codec="mp4_h264",
|
||||
seed=seed,
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaRemoveVideoBackgroundResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||
|
||||
|
||||
class BriaVideoReplaceBackground(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaVideoReplaceBackground",
|
||||
display_name="Bria Video Replace Background",
|
||||
category="partner/video/Bria",
|
||||
description="Replace a video's background with a supplied image or video using Bria. "
|
||||
"The output keeps the foreground's resolution and frame rate; a background with a "
|
||||
"different aspect ratio is stretched to fit, so match it for undistorted results.",
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="Foreground video whose background is replaced."),
|
||||
IO.Image.Input(
|
||||
"background_image",
|
||||
optional=True,
|
||||
tooltip="Background image to composite behind the foreground. "
|
||||
"Provide either a background image or a background video, not both.",
|
||||
),
|
||||
IO.Video.Input(
|
||||
"background_video",
|
||||
optional=True,
|
||||
tooltip="Background video to composite behind the foreground. "
|
||||
"Provide either a background image or a background video, not both.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
seed: int,
|
||||
background_image: Input.Image | None = None,
|
||||
background_video: Input.Video | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if (background_image is None) == (background_video is None):
|
||||
raise ValueError("Provide either a background image or a background video, not both.")
|
||||
validate_video_duration(video, max_duration=60.0)
|
||||
if background_video is not None:
|
||||
validate_video_duration(background_video, max_duration=60.0)
|
||||
background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background")
|
||||
else:
|
||||
background_url = await upload_image_to_comfyapi(cls, background_image, wait_label="Uploading background")
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"),
|
||||
data=BriaVideoReplaceBackgroundRequest(
|
||||
video=await upload_video_to_comfyapi(cls, video),
|
||||
background_url=background_url,
|
||||
output_container_and_codec="mp4_h264",
|
||||
seed=seed,
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaRemoveVideoBackgroundResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||
|
||||
|
||||
def _video_to_images_and_mask(video: Input.Video) -> tuple[Input.Image, Input.Mask]:
|
||||
"""Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask.
|
||||
|
||||
VP9 keeps its alpha in a side layer that PyAV's default vp9 decoder drops, so the frames
|
||||
are decoded with libvpx-vp9. Returns RGB images [B,H,W,3] in 0..1 and a mask [B,H,W]
|
||||
following the Load Image convention (1 = transparent) for compositing or Save WEBM.
|
||||
"""
|
||||
rgb_frames: list[torch.Tensor] = []
|
||||
alpha_frames: list[torch.Tensor] = []
|
||||
with av.open(video.get_stream_source(), mode="r") as container:
|
||||
stream = container.streams.video[0]
|
||||
decoder = CodecContext.create("libvpx-vp9", "r") if stream.codec_context.name == "vp9" else None
|
||||
for packet in container.demux(stream):
|
||||
for frame in (decoder.decode(packet) if decoder is not None else packet.decode()):
|
||||
rgba = torch.from_numpy(frame.to_ndarray(format="rgba")).float() / 255.0
|
||||
rgb_frames.append(rgba[..., :3])
|
||||
alpha_frames.append(rgba[..., 3])
|
||||
images = torch.stack(rgb_frames) if rgb_frames else torch.zeros(0, 0, 0, 3)
|
||||
mask = (1.0 - torch.stack(alpha_frames)) if alpha_frames else torch.zeros((images.shape[0], 64, 64))
|
||||
return images, mask
|
||||
|
||||
|
||||
class BriaTransparentVideoBackground(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaTransparentVideoBackground",
|
||||
display_name="Bria Remove Video Background (Transparent)",
|
||||
category="partner/video/Bria",
|
||||
description="Remove the background from a video using Bria and return the cut-out frames "
|
||||
"plus an alpha mask. Connect both to a compositing node, or feed them to Save WEBM to "
|
||||
"write a transparent video.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(display_name="images"),
|
||||
IO.Mask.Output(display_name="mask"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_video_duration(video, max_duration=60.0)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
|
||||
data=BriaRemoveVideoBackgroundRequest(
|
||||
video=await upload_video_to_comfyapi(cls, video),
|
||||
background_color="Transparent",
|
||||
output_container_and_codec="webm_vp9",
|
||||
seed=seed,
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaRemoveVideoBackgroundResponse,
|
||||
)
|
||||
video_out = await download_url_to_video_output(response.result.video_url)
|
||||
images, mask = _video_to_images_and_mask(video_out)
|
||||
return IO.NodeOutput(images, mask)
|
||||
|
||||
|
||||
class BriaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -323,6 +570,9 @@ class BriaExtension(ComfyExtension):
|
||||
BriaImageEditNode,
|
||||
BriaRemoveImageBackground,
|
||||
BriaRemoveVideoBackground,
|
||||
BriaVideoGreenScreen,
|
||||
# BriaVideoReplaceBackground, # server returns Status 500 when we pass background video
|
||||
BriaTransparentVideoBackground,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -2,11 +2,13 @@ import hashlib
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.bytedance import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
@ -43,6 +45,7 @@ from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
downscale_image_tensor_by_max_side,
|
||||
downscale_video_to_max_pixels,
|
||||
get_number_of_images,
|
||||
image_tensor_pair_to_batch,
|
||||
@ -121,6 +124,52 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st
|
||||
)
|
||||
|
||||
|
||||
def _prepare_seedance_image(image: Input.Image) -> Input.Image:
|
||||
"""Auto-downscale a Seedance image input to the per-side limits, then validate it."""
|
||||
validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
image = downscale_image_tensor_by_max_side(image, max_side=6000)
|
||||
validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000)
|
||||
return image
|
||||
|
||||
|
||||
# Supported output aspect ratios, used to pre-size FLF frames to matching pixel pair to avoid the 1080p stretch jump.
|
||||
SEEDANCE2_RATIO_WH = {
|
||||
"16:9": (16, 9),
|
||||
"4:3": (4, 3),
|
||||
"1:1": (1, 1),
|
||||
"3:4": (3, 4),
|
||||
"9:16": (9, 16),
|
||||
"21:9": (21, 9),
|
||||
}
|
||||
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080}
|
||||
|
||||
|
||||
def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]:
|
||||
"""Exact supported output (width, height) for (resolution, ratio).
|
||||
|
||||
The shorter side equals the resolution number (e.g. 1080p 16:9 -> 1920x1080). For ratio
|
||||
"adaptive" (or any unexpected value) the ratio is derived from the image's own aspect, snapped
|
||||
to the nearest supported ratio, so the output keeps the frame's orientation.
|
||||
"""
|
||||
short = SEEDANCE2_RES_SHORT_SIDE[resolution]
|
||||
if ratio not in SEEDANCE2_RATIO_WH:
|
||||
aspect = image.shape[-2] / image.shape[-3] # W / H; tensor is (B, H, W, C)
|
||||
ratio = min(SEEDANCE2_RATIO_WH, key=lambda k: abs(SEEDANCE2_RATIO_WH[k][0] / SEEDANCE2_RATIO_WH[k][1] - aspect))
|
||||
rw, rh = SEEDANCE2_RATIO_WH[ratio]
|
||||
if rw >= rh: # landscape or square: shorter side is the height
|
||||
out_w, out_h = round(short * rw / rh), short
|
||||
else: # portrait: shorter side is the width
|
||||
out_w, out_h = short, round(short * rh / rw)
|
||||
return out_w - out_w % 2, out_h - out_h % 2
|
||||
|
||||
|
||||
def _resize_to_exact(image: torch.Tensor, width: int, height: int) -> torch.Tensor:
|
||||
"""Center-crop to the target aspect and resize to exactly width x height (lanczos)."""
|
||||
samples = image.movedim(-1, 1) # (B, H, W, C) -> (B, C, H, W)
|
||||
resized = common_upscale(samples, width, height, "lanczos", "center")
|
||||
return resized.movedim(1, -1)
|
||||
|
||||
|
||||
async def _resolve_reference_assets(
|
||||
cls: type[IO.ComfyNode],
|
||||
asset_ids: list[str],
|
||||
@ -308,6 +357,26 @@ async def _seedance_virtual_library_upload_image_asset(
|
||||
return f"asset://{create_resp.asset_id}"
|
||||
|
||||
|
||||
async def _seedance_virtual_library_upload_video_asset(
|
||||
cls: type[IO.ComfyNode],
|
||||
video: Input.Video,
|
||||
*,
|
||||
wait_label: str = "Uploading video",
|
||||
) -> str:
|
||||
buf = BytesIO()
|
||||
video.save_to(buf, format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264)
|
||||
video_hash = hashlib.sha256(buf.getbuffer()).hexdigest()
|
||||
public_url = await upload_video_to_comfyapi(cls, video, wait_label=wait_label)
|
||||
create_resp = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST"),
|
||||
response_model=SeedanceCreateAssetResponse,
|
||||
data=SeedanceVirtualLibraryCreateAssetRequest(url=public_url, hash=video_hash, asset_type="Video"),
|
||||
)
|
||||
await _wait_for_asset_active(cls, create_resp.asset_id, group_id="virtual-library")
|
||||
return f"asset://{create_resp.asset_id}"
|
||||
|
||||
|
||||
def _seedance2_price_extractor(model_id: str, has_video_input: bool):
|
||||
"""Returns a price_extractor closure for Seedance 2.0 poll_op."""
|
||||
rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input))
|
||||
@ -338,7 +407,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceImageNode",
|
||||
display_name="ByteDance Image",
|
||||
category="api node/image/ByteDance",
|
||||
category="partner/image/ByteDance",
|
||||
description="Generate images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
|
||||
@ -462,7 +531,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedreamNode",
|
||||
display_name="ByteDance Seedream 4.5 & 5.0",
|
||||
category="api node/image/ByteDance",
|
||||
category="partner/image/ByteDance",
|
||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -724,7 +793,7 @@ class ByteDanceSeedreamNodeV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedreamNodeV2",
|
||||
display_name="ByteDance Seedream 4.5 & 5.0",
|
||||
category="api node/image/ByteDance",
|
||||
category="partner/image/ByteDance",
|
||||
description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -890,7 +959,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceTextToVideoNode",
|
||||
display_name="ByteDance Text to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1018,7 +1087,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceImageToVideoNode",
|
||||
display_name="ByteDance Image to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using ByteDance models via api based on image and prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1155,7 +1224,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceFirstLastFrameNode",
|
||||
display_name="ByteDance First-Last-Frame to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using prompt and first and last frames.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1303,7 +1372,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceImageReferenceNode",
|
||||
display_name="ByteDance Reference Images to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using prompt and reference images.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1546,7 +1615,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2TextToVideoNode",
|
||||
display_name="ByteDance Seedance 2.0 Text to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using Seedance 2.0 models based on a text prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1647,7 +1716,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2FirstLastFrameNode",
|
||||
display_name="ByteDance Seedance 2.0 First-Last-Frame to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1760,6 +1829,29 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
if last_frame is not None and last_frame_asset_id:
|
||||
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
|
||||
|
||||
request_ratio = model["ratio"]
|
||||
if first_frame_asset_id or last_frame_asset_id:
|
||||
if first_frame is not None:
|
||||
first_frame = _prepare_seedance_image(first_frame)
|
||||
if last_frame is not None:
|
||||
last_frame = _prepare_seedance_image(last_frame)
|
||||
else:
|
||||
# The 1080p FLF stretch fix (pre-size frames to a supported pixel pair + submit ratio="adaptive")
|
||||
# only applies to local image inputs we can resize.
|
||||
request_ratio = "adaptive"
|
||||
target_dims: tuple[int, int] | None = None
|
||||
if first_frame is not None:
|
||||
validate_image_aspect_ratio(first_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
||||
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], first_frame)
|
||||
first_frame = _resize_to_exact(first_frame, *target_dims)
|
||||
if last_frame is not None:
|
||||
validate_image_aspect_ratio(last_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
validate_image_dimensions(last_frame, min_width=300, min_height=300)
|
||||
if target_dims is None:
|
||||
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], last_frame)
|
||||
last_frame = _resize_to_exact(last_frame, *target_dims)
|
||||
|
||||
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
|
||||
image_assets: dict[str, str] = {}
|
||||
if asset_ids_to_resolve:
|
||||
@ -1809,7 +1901,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
content=content,
|
||||
generate_audio=model["generate_audio"],
|
||||
resolution=model["resolution"],
|
||||
ratio=model["ratio"],
|
||||
ratio=request_ratio,
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
@ -1866,7 +1958,7 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"auto_downscale",
|
||||
default=False,
|
||||
default=True,
|
||||
optional=True,
|
||||
tooltip="Automatically downscale reference videos that exceed the model's pixel budget "
|
||||
"for the selected resolution. Aspect ratio is preserved; videos already within limits are untouched.",
|
||||
@ -1909,7 +2001,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDance2ReferenceNode",
|
||||
display_name="ByteDance Seedance 2.0 Reference to Video",
|
||||
category="api node/video/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description="Generate, edit, or extend video using Seedance 2.0 with reference images, "
|
||||
"videos, and audio. Supports multimodal reference, video editing, and video extension.",
|
||||
inputs=[
|
||||
@ -2034,6 +2126,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3."
|
||||
)
|
||||
|
||||
for key in reference_images:
|
||||
reference_images[key] = _prepare_seedance_image(reference_images[key])
|
||||
|
||||
model_id = SEEDANCE_MODELS[model["model"]]
|
||||
has_video_input = total_videos > 0
|
||||
|
||||
@ -2106,7 +2201,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode):
|
||||
content.append(
|
||||
TaskVideoContent(
|
||||
video_url=TaskVideoContentUrl(
|
||||
url=await upload_video_to_comfyapi(
|
||||
url=await _seedance_virtual_library_upload_video_asset(
|
||||
cls,
|
||||
reference_videos[key],
|
||||
wait_label=f"Uploading video {i}",
|
||||
@ -2203,7 +2298,7 @@ class ByteDanceCreateImageAsset(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceCreateImageAsset",
|
||||
display_name="ByteDance Create Image Asset",
|
||||
category="api node/image/ByteDance",
|
||||
category="partner/image/ByteDance",
|
||||
description=(
|
||||
"Create a Seedance 2.0 personal image asset. Uploads the input image and "
|
||||
"registers it in the given asset group. If group_id is empty, runs a real-person "
|
||||
@ -2270,7 +2365,7 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceCreateVideoAsset",
|
||||
display_name="ByteDance Create Video Asset",
|
||||
category="api node/video/ByteDance",
|
||||
category="partner/video/ByteDance",
|
||||
description=(
|
||||
"Create a Seedance 2.0 personal video asset. Uploads the input video and "
|
||||
"registers it in the given asset group. If group_id is empty, runs a real-person "
|
||||
|
||||
@ -144,7 +144,7 @@ class ByteDanceSeedNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ByteDanceSeedNode",
|
||||
display_name="ByteDance Seed",
|
||||
category="api node/text/ByteDance",
|
||||
category="partner/text/ByteDance",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses with ByteDance's Seed 2.0 models. "
|
||||
"Provide a text prompt and optionally one or more images or videos for multimodal context.",
|
||||
|
||||
@ -69,7 +69,7 @@ class ElevenLabsSpeechToText(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsSpeechToText",
|
||||
display_name="ElevenLabs Speech to Text",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Transcribe audio to text. "
|
||||
"Supports automatic language detection, speaker diarization, and audio event tagging.",
|
||||
inputs=[
|
||||
@ -210,7 +210,7 @@ class ElevenLabsVoiceSelector(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsVoiceSelector",
|
||||
display_name="ElevenLabs Voice Selector",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Select a predefined ElevenLabs voice for text-to-speech generation.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -239,7 +239,7 @@ class ElevenLabsTextToSpeech(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsTextToSpeech",
|
||||
display_name="ElevenLabs Text to Speech",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Convert text to speech.",
|
||||
inputs=[
|
||||
IO.Custom(ELEVENLABS_VOICE).Input(
|
||||
@ -414,7 +414,7 @@ class ElevenLabsAudioIsolation(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsAudioIsolation",
|
||||
display_name="ElevenLabs Voice Isolation",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Remove background noise from audio, isolating vocals or speech.",
|
||||
inputs=[
|
||||
IO.Audio.Input(
|
||||
@ -459,7 +459,7 @@ class ElevenLabsTextToSoundEffects(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsTextToSoundEffects",
|
||||
display_name="ElevenLabs Text to Sound Effects",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Generate sound effects from text descriptions.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -555,7 +555,7 @@ class ElevenLabsInstantVoiceClone(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsInstantVoiceClone",
|
||||
display_name="ElevenLabs Instant Voice Clone",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Create a cloned voice from audio samples. "
|
||||
"Provide 1-8 audio recordings of the voice to clone.",
|
||||
inputs=[
|
||||
@ -658,7 +658,7 @@ class ElevenLabsSpeechToSpeech(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsSpeechToSpeech",
|
||||
display_name="ElevenLabs Speech to Speech",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Transform speech from one voice to another while preserving the original content and emotion.",
|
||||
inputs=[
|
||||
IO.Custom(ELEVENLABS_VOICE).Input(
|
||||
@ -793,7 +793,7 @@ class ElevenLabsTextToDialogue(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ElevenLabsTextToDialogue",
|
||||
display_name="ElevenLabs Text to Dialogue",
|
||||
category="api node/audio/ElevenLabs",
|
||||
category="partner/audio/ElevenLabs",
|
||||
description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.",
|
||||
inputs=[
|
||||
IO.Float.Input(
|
||||
|
||||
@ -8,7 +8,7 @@ import os
|
||||
from enum import Enum
|
||||
from fnmatch import fnmatch
|
||||
from io import BytesIO
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
@ -19,6 +19,7 @@ from comfy_api_nodes.apis.gemini import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
GeminiGenerateContentRequest,
|
||||
GeminiGenerationConfig,
|
||||
GeminiGenerateContentResponse,
|
||||
GeminiImageConfig,
|
||||
GeminiImageGenerateContentRequest,
|
||||
@ -40,13 +41,18 @@ from comfy_api_nodes.util import (
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
video_to_base64_string,
|
||||
)
|
||||
|
||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
GEMINI_URL_INPUT_BUDGET = 10
|
||||
GEMINI_MAX_INLINE_BYTES = 18 * 1024 * 1024
|
||||
GEMINI_IMAGE_SYS_PROMPT = (
|
||||
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
|
||||
"Interpret all user input—regardless of "
|
||||
@ -285,6 +291,140 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
return final_price / 1_000_000.0
|
||||
|
||||
|
||||
def create_video_parts(video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert a single video input to Gemini API compatible parts (inline MP4/H.264)."""
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.video_mp4,
|
||||
data=base_64_string,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def create_audio_parts(audio_input: Input.Audio) -> list[GeminiPart]:
|
||||
"""Convert an audio input to Gemini API compatible parts (one inline MP3 part per batch item)."""
|
||||
audio_parts: list[GeminiPart] = []
|
||||
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||
audio_at_index = Input.Audio(
|
||||
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
|
||||
sample_rate=audio_input["sample_rate"],
|
||||
)
|
||||
# Convert to MP3 format for compatibility with Gemini API
|
||||
audio_bytes = audio_to_base64_string(
|
||||
audio_at_index,
|
||||
container_format="mp3",
|
||||
codec_name="libmp3lame",
|
||||
)
|
||||
audio_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.audio_mp3,
|
||||
data=audio_bytes,
|
||||
)
|
||||
)
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
|
||||
def _flatten_images(images: list[Input.Image]) -> list[torch.Tensor]:
|
||||
"""Expand any batched image tensors into individual (H, W, C) frames, preserving order."""
|
||||
frames: list[torch.Tensor] = []
|
||||
for img in images:
|
||||
if len(img.shape) == 4:
|
||||
frames.extend(img[i] for i in range(img.shape[0]))
|
||||
else:
|
||||
frames.append(img)
|
||||
return frames
|
||||
|
||||
|
||||
def _flatten_audio(audios: list[Input.Audio]) -> list[Input.Audio]:
|
||||
"""Expand any batched audio inputs into individual single-clip audio inputs, preserving order."""
|
||||
clips: list[Input.Audio] = []
|
||||
for audio in audios:
|
||||
waveform = audio["waveform"]
|
||||
for i in range(waveform.shape[0]):
|
||||
clips.append(Input.Audio(waveform=waveform[i].unsqueeze(0), sample_rate=audio["sample_rate"]))
|
||||
return clips
|
||||
|
||||
|
||||
async def _media_url_part(cls: type[IO.ComfyNode], kind: str, payload: Any) -> GeminiPart:
|
||||
"""Upload a single media unit to ComfyAPI storage and return a fileData (URL) part."""
|
||||
if kind == "image":
|
||||
url = await upload_image_to_comfyapi(cls, payload, mime_type="image/png", wait_label="Uploading image")
|
||||
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.image_png, fileUri=url))
|
||||
if kind == "audio":
|
||||
url = await upload_audio_to_comfyapi(
|
||||
cls, payload, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mp3"
|
||||
)
|
||||
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.audio_mp3, fileUri=url))
|
||||
url = await upload_video_to_comfyapi(cls, payload, wait_label="Uploading video")
|
||||
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.video_mp4, fileUri=url))
|
||||
|
||||
|
||||
def _media_inline_part(kind: str, payload: Any) -> tuple[GeminiPart, int]:
|
||||
"""Encode a single media unit as an inline base64 part; returns (part, base64_length)."""
|
||||
if kind == "image":
|
||||
data = tensor_to_base64_string(payload, mime_type="image/webp")
|
||||
mime = GeminiMimeType.image_webp
|
||||
elif kind == "audio":
|
||||
data = audio_to_base64_string(payload, container_format="mp3", codec_name="libmp3lame")
|
||||
mime = GeminiMimeType.audio_mp3
|
||||
else:
|
||||
data = video_to_base64_string(
|
||||
payload, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
mime = GeminiMimeType.video_mp4
|
||||
return GeminiPart(inlineData=GeminiInlineData(mimeType=mime, data=data)), len(data)
|
||||
|
||||
|
||||
async def build_gemini_media_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: list[Input.Image],
|
||||
audios: list[Input.Audio],
|
||||
videos: list[Input.Video],
|
||||
*,
|
||||
url_budget: int = GEMINI_URL_INPUT_BUDGET,
|
||||
max_inline_bytes: int = GEMINI_MAX_INLINE_BYTES,
|
||||
) -> list[GeminiPart]:
|
||||
"""Build Gemini parts for multimodal inputs (images, audio, video).
|
||||
|
||||
fileData URLs are preferred for every media type: the upload is fetched directly by the
|
||||
model, keeping the request body tiny regardless of media size. The URL budget is shared
|
||||
across all media and assigned largest-first (video, then audio, then images), so that if it
|
||||
is ever exhausted the inline-base64 overflow is limited to the smallest items. Total inline
|
||||
payload is capped by `max_inline_bytes`.
|
||||
"""
|
||||
units: list[tuple[str, Any]] = (
|
||||
[("video", v) for v in videos]
|
||||
+ [("audio", a) for a in _flatten_audio(audios)]
|
||||
+ [("image", f) for f in _flatten_images(images)]
|
||||
)
|
||||
|
||||
parts: list[GeminiPart] = []
|
||||
url_used = 0
|
||||
inline_bytes = 0
|
||||
for kind, payload in units:
|
||||
if url_used < url_budget:
|
||||
parts.append(await _media_url_part(cls, kind, payload))
|
||||
url_used += 1
|
||||
continue
|
||||
part, nbytes = _media_inline_part(kind, payload)
|
||||
inline_bytes += nbytes
|
||||
if inline_bytes > max_inline_bytes:
|
||||
raise ValueError(
|
||||
f"Too much media to send inline (over {max_inline_bytes // (1024 * 1024)}MB after the first "
|
||||
f"{url_budget} inputs are uploaded as URLs). Reduce the number or size of attached media."
|
||||
)
|
||||
parts.append(part)
|
||||
return parts
|
||||
|
||||
|
||||
class GeminiNode(IO.ComfyNode):
|
||||
"""
|
||||
Node to generate text responses from a Gemini model.
|
||||
@ -300,7 +440,7 @@ class GeminiNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNode",
|
||||
display_name="Google Gemini",
|
||||
category="api node/text/Gemini",
|
||||
category="partner/text/Gemini",
|
||||
description="Generate text responses with Google's Gemini AI model. "
|
||||
"You can provide multiple types of inputs (text, images, audio, video) "
|
||||
"as context for generating more relevant and meaningful responses.",
|
||||
@ -407,58 +547,9 @@ class GeminiNode(IO.ComfyNode):
|
||||
)
|
||||
""",
|
||||
),
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert video input to Gemini API compatible parts."""
|
||||
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.video_mp4,
|
||||
data=base_64_string,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert audio input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded audio.
|
||||
"""
|
||||
audio_parts: list[GeminiPart] = []
|
||||
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||
audio_at_index = Input.Audio(
|
||||
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
|
||||
sample_rate=audio_input["sample_rate"],
|
||||
)
|
||||
# Convert to MP3 format for compatibility with Gemini API
|
||||
audio_bytes = audio_to_base64_string(
|
||||
audio_at_index,
|
||||
container_format="mp3",
|
||||
codec_name="libmp3lame",
|
||||
)
|
||||
audio_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.audio_mp3,
|
||||
data=audio_bytes,
|
||||
)
|
||||
)
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
@ -482,9 +573,9 @@ class GeminiNode(IO.ComfyNode):
|
||||
if images is not None:
|
||||
parts.extend(await create_image_parts(cls, images))
|
||||
if audio is not None:
|
||||
parts.extend(cls.create_audio_parts(audio))
|
||||
parts.extend(create_audio_parts(audio))
|
||||
if video is not None:
|
||||
parts.extend(cls.create_video_parts(video))
|
||||
parts.extend(create_video_parts(video))
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
@ -512,6 +603,210 @@ class GeminiNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||
|
||||
|
||||
GEMINI_V2_MODELS: dict[str, str] = {
|
||||
"Gemini 3.1 Pro": "gemini-3.1-pro-preview",
|
||||
"Gemini 3.1 Flash-Lite": "gemini-3.1-flash-lite-preview",
|
||||
}
|
||||
|
||||
|
||||
def _gemini_text_model_inputs(thinking_default: str) -> list[Input]:
|
||||
"""Per-model inputs revealed by the model DynamicCombo (shared media + sampling controls)."""
|
||||
return [
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, 17)],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional image(s) to use as context for the model. Up to 16 images.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"audio",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Audio.Input("audio"),
|
||||
names=["audio_1"],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional audio clip to use as context for the model.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"video",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("video"),
|
||||
names=["video_1"],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional video clip to use as context for the model.",
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Input Files node.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"thinking_level",
|
||||
options=["LOW", "HIGH"],
|
||||
default=thinking_default,
|
||||
tooltip="How hard the model reasons internally before answering. "
|
||||
"HIGH improves quality on difficult tasks but costs more (thinking) tokens and is slower.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.01,
|
||||
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more creative.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=0.95,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"max_output_tokens",
|
||||
default=32768,
|
||||
min=16,
|
||||
max=65536,
|
||||
tooltip="Maximum tokens to generate, including the model's internal thinking. "
|
||||
"With thinking_level HIGH, a low value can leave no room for the answer; raise this if "
|
||||
"responses come back empty or truncated. The model stops early when finished, so a higher "
|
||||
"cap costs nothing extra for short replies.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class GeminiNodeV2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNodeV2",
|
||||
display_name="Google Gemini",
|
||||
category="partner/text/Gemini",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses with Google's Gemini models. Provide a text prompt and, "
|
||||
"optionally, one or more images, audio clips, videos, or files as multimodal context.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text input to the model. Include detailed instructions, questions, or context.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Gemini 3.1 Pro", _gemini_text_model_inputs("HIGH")),
|
||||
IO.DynamicCombo.Option("Gemini 3.1 Flash-Lite", _gemini_text_model_inputs("LOW")),
|
||||
],
|
||||
tooltip="The Gemini model used to generate the response.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for sampling. Set to 0 for a random seed. Deterministic output isn't guaranteed.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Foundational instructions that dictate the model's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$contains($m, "lite") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.00025, 0.0015],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
} : {
|
||||
"type": "list_usd",
|
||||
"usd": [0.002, 0.012],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_id = GEMINI_V2_MODELS[model["model"]]
|
||||
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
images = [t for t in (model.get("images") or {}).values() if t is not None]
|
||||
audios = [a for a in (model.get("audio") or {}).values() if a is not None]
|
||||
videos = [v for v in (model.get("video") or {}).values() if v is not None]
|
||||
if images or audios or videos:
|
||||
parts.extend(await build_gemini_media_parts(cls, images, audios, videos))
|
||||
files = model.get("files")
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
|
||||
data=GeminiGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role=GeminiRole.user,
|
||||
parts=parts,
|
||||
)
|
||||
],
|
||||
generationConfig=GeminiGenerationConfig(
|
||||
temperature=model["temperature"],
|
||||
topP=model["top_p"],
|
||||
maxOutputTokens=model["max_output_tokens"],
|
||||
seed=seed if seed > 0 else None,
|
||||
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
|
||||
output_text = get_text_from_response(response)
|
||||
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||
|
||||
|
||||
class GeminiInputFiles(IO.ComfyNode):
|
||||
"""
|
||||
Loads and formats input files for use with the Gemini API.
|
||||
@ -541,7 +836,7 @@ class GeminiInputFiles(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiInputFiles",
|
||||
display_name="Gemini Input Files",
|
||||
category="api node/text/Gemini",
|
||||
category="partner/text/Gemini",
|
||||
description="Loads and prepares input files to include as inputs for Gemini LLM nodes. "
|
||||
"The files will be read by the Gemini model when generating a response. "
|
||||
"The contents of the text file count toward the token limit. "
|
||||
@ -598,7 +893,7 @@ class GeminiImage(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiImageNode",
|
||||
display_name="Nano Banana (Google Gemini Image)",
|
||||
category="api node/image/Gemini",
|
||||
category="partner/image/Gemini",
|
||||
description="Edit images synchronously via Google API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -731,7 +1026,7 @@ class GeminiImage2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiImage2Node",
|
||||
display_name="Nano Banana Pro (Google Gemini Image)",
|
||||
category="api node/image/Gemini",
|
||||
category="partner/image/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -869,7 +1164,7 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNanoBanana2",
|
||||
display_name="Nano Banana 2",
|
||||
category="api node/image/Gemini",
|
||||
category="partner/image/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -1085,7 +1380,7 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNanoBanana2V2",
|
||||
display_name="Nano Banana 2",
|
||||
category="api node/image/Gemini",
|
||||
category="partner/image/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -1129,6 +1424,26 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.01,
|
||||
optional=True,
|
||||
tooltip="Controls randomness in generation. Lower is more focused/deterministic.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=0.95,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
optional=True,
|
||||
tooltip="Nucleus sampling threshold. Lower is more focused, higher more diverse.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -1165,6 +1480,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
seed: int,
|
||||
response_modalities: str,
|
||||
system_prompt: str = "",
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 0.95,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_choice = model["model"]
|
||||
@ -1204,6 +1521,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
||||
temperature=temperature,
|
||||
topP=top_p,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
@ -1222,6 +1541,7 @@ class GeminiExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
GeminiNode,
|
||||
GeminiNodeV2,
|
||||
GeminiImage,
|
||||
GeminiImage2,
|
||||
GeminiNanoBanana2,
|
||||
|
||||
@ -29,6 +29,11 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
|
||||
_GROK_VIDEO_MODEL_API_IDS = {
|
||||
"grok-imagine-video-1.5": "grok-imagine-video-1.5-preview",
|
||||
}
|
||||
|
||||
|
||||
def _extract_grok_price(response) -> float | None:
|
||||
if response.usage and response.usage.cost_in_usd_ticks is not None:
|
||||
return response.usage.cost_in_usd_ticks / 10_000_000_000
|
||||
@ -49,7 +54,7 @@ class GrokImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageNode",
|
||||
display_name="Grok Image",
|
||||
category="api node/image/Grok",
|
||||
category="partner/image/Grok",
|
||||
description="Generate images using Grok based on a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -58,7 +63,6 @@ class GrokImageNode(IO.ComfyNode):
|
||||
"grok-imagine-image-quality",
|
||||
"grok-imagine-image-pro",
|
||||
"grok-imagine-image",
|
||||
"grok-imagine-image-beta",
|
||||
],
|
||||
),
|
||||
IO.String.Input(
|
||||
@ -224,7 +228,7 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageEditNode",
|
||||
display_name="Grok Image Edit",
|
||||
category="api node/image/Grok",
|
||||
category="partner/image/Grok",
|
||||
description="Modify an existing image based on a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -233,7 +237,6 @@ class GrokImageEditNode(IO.ComfyNode):
|
||||
"grok-imagine-image-quality",
|
||||
"grok-imagine-image-pro",
|
||||
"grok-imagine-image",
|
||||
"grok-imagine-image-beta",
|
||||
],
|
||||
),
|
||||
IO.Image.Input("image", display_name="images"),
|
||||
@ -366,7 +369,7 @@ class GrokImageEditNodeV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokImageEditNodeV2",
|
||||
display_name="Grok Image Edit",
|
||||
category="api node/image/Grok",
|
||||
category="partner/image/Grok",
|
||||
description="Modify an existing image based on a text prompt",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -503,10 +506,14 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoNode",
|
||||
display_name="Grok Video",
|
||||
category="api node/video/Grok",
|
||||
category="partner/video/Grok",
|
||||
description="Generate video from a prompt or an image",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["grok-imagine-video", "grok-imagine-video-1.5"],
|
||||
tooltip="grok-imagine-video-1.5 currently always requires an input image.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@ -542,7 +549,11 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
tooltip="Seed to determine if node should re-run; "
|
||||
"actual results are nondeterministic regardless of seed.",
|
||||
),
|
||||
IO.Image.Input("image", optional=True),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
tooltip="Optional starting image for grok-imagine-video. Required for grok-imagine-video-1.5.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
@ -554,12 +565,16 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]),
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"], inputs=["image"]),
|
||||
expr="""
|
||||
(
|
||||
$rate := widgets.resolution = "720p" ? 0.07 : 0.05;
|
||||
$is15 := $contains(widgets.model, "1.5");
|
||||
$rate := $is15
|
||||
? (widgets.resolution = "720p" ? 0.2002 : 0.1144)
|
||||
: (widgets.resolution = "720p" ? 0.07 : 0.05);
|
||||
$imgCost := $is15 ? 0.0143 : 0.002;
|
||||
$base := $rate * widgets.duration;
|
||||
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
|
||||
{"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base}
|
||||
)
|
||||
""",
|
||||
),
|
||||
@ -576,8 +591,8 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
seed: int,
|
||||
image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if model == "grok-imagine-video-beta":
|
||||
model = "grok-imagine-video"
|
||||
if image is None and model == "grok-imagine-video-1.5":
|
||||
raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.")
|
||||
image_url = None
|
||||
if image is not None:
|
||||
if get_number_of_images(image) != 1:
|
||||
@ -588,7 +603,7 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"),
|
||||
data=VideoGenerationRequest(
|
||||
model=model,
|
||||
model=_GROK_VIDEO_MODEL_API_IDS.get(model, model),
|
||||
image=image_url,
|
||||
prompt=prompt,
|
||||
resolution=resolution,
|
||||
@ -603,7 +618,7 @@ class GrokVideoNode(IO.ComfyNode):
|
||||
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
|
||||
status_extractor=lambda r: r.status if r.status is not None else "complete",
|
||||
response_model=VideoStatusResponse,
|
||||
price_extractor=_extract_grok_price,
|
||||
price_extractor=_extract_grok_video_price if model == "grok-imagine-video-1.5" else _extract_grok_price,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
|
||||
|
||||
@ -615,10 +630,10 @@ class GrokVideoEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoEditNode",
|
||||
display_name="Grok Video Edit",
|
||||
category="api node/video/Grok",
|
||||
category="partner/video/Grok",
|
||||
description="Edit an existing video based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
|
||||
IO.Combo.Input("model", options=["grok-imagine-video"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@ -693,7 +708,7 @@ class GrokVideoReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoReferenceNode",
|
||||
display_name="Grok Reference-to-Video",
|
||||
category="api node/video/Grok",
|
||||
category="partner/video/Grok",
|
||||
description="Generate video guided by reference images as style and content references.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -826,7 +841,7 @@ class GrokVideoExtendNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="GrokVideoExtendNode",
|
||||
display_name="Grok Video Extend",
|
||||
category="api node/video/Grok",
|
||||
category="partner/video/Grok",
|
||||
description="Extend an existing video with a seamless continuation based on a text prompt.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -71,7 +71,7 @@ class HitPawGeneralImageEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HitPawGeneralImageEnhance",
|
||||
display_name="HitPaw General Image Enhance",
|
||||
category="api node/image/HitPaw",
|
||||
category="partner/image/HitPaw",
|
||||
description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. "
|
||||
f"Maximum output: {MAX_MP_GENERATIVE} megapixels.",
|
||||
inputs=[
|
||||
@ -201,7 +201,7 @@ class HitPawVideoEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HitPawVideoEnhance",
|
||||
display_name="HitPaw Video Enhance",
|
||||
category="api node/video/HitPaw",
|
||||
category="partner/video/HitPaw",
|
||||
description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. "
|
||||
"Prices shown are per second of video.",
|
||||
inputs=[
|
||||
|
||||
@ -123,7 +123,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentTextToModelNode",
|
||||
display_name="Hunyuan3D: Text to Model",
|
||||
category="api node/3d/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
essentials_category="3D",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -242,7 +242,7 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentImageToModelNode",
|
||||
display_name="Hunyuan3D: Image(s) to Model",
|
||||
category="api node/3d/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
essentials_category="3D",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -415,7 +415,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentModelTo3DUVNode",
|
||||
display_name="Hunyuan3D: Model to UV",
|
||||
category="api node/3d/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
description="Perform UV unfolding on a 3D model to generate UV texture. "
|
||||
"Input model must have less than 30000 faces.",
|
||||
inputs=[
|
||||
@ -505,7 +505,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Tencent3DTextureEditNode",
|
||||
display_name="Hunyuan3D: 3D Texture Edit",
|
||||
category="api node/3d/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
description="After inputting the 3D model, perform 3D model texture redrawing.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
@ -594,7 +594,7 @@ class Tencent3DPartNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Tencent3DPartNode",
|
||||
display_name="Hunyuan3D: 3D Part",
|
||||
category="api node/3d/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
description="Automatically perform component identification and generation based on the model structure.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
@ -666,7 +666,7 @@ class TencentSmartTopologyNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TencentSmartTopologyNode",
|
||||
display_name="Hunyuan3D: Smart Topology",
|
||||
category="api node/3d/Tencent",
|
||||
category="partner/3d/Tencent",
|
||||
description="Perform smart retopology on a 3D model. "
|
||||
"Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.",
|
||||
inputs=[
|
||||
|
||||
@ -10,6 +10,7 @@ from comfy_api_nodes.apis.ideogram import (
|
||||
ImageRequest,
|
||||
IdeogramV3Request,
|
||||
IdeogramV3EditRequest,
|
||||
IdeogramV4Request,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
@ -17,6 +18,7 @@ from comfy_api_nodes.util import (
|
||||
download_url_as_bytesio,
|
||||
resize_mask_to_image,
|
||||
sync_op,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
V1_V1_RES_MAP = {
|
||||
@ -234,7 +236,7 @@ class IdeogramV1(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV1",
|
||||
display_name="Ideogram V1",
|
||||
category="api node/image/Ideogram",
|
||||
category="partner/image/Ideogram",
|
||||
description="Generates images using the Ideogram V1 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -360,7 +362,7 @@ class IdeogramV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV2",
|
||||
display_name="Ideogram V2",
|
||||
category="api node/image/Ideogram",
|
||||
category="partner/image/Ideogram",
|
||||
description="Generates images using the Ideogram V2 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -526,7 +528,7 @@ class IdeogramV3(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV3",
|
||||
display_name="Ideogram V3",
|
||||
category="api node/image/Ideogram",
|
||||
category="partner/image/Ideogram",
|
||||
description="Generates images using the Ideogram V3 model. "
|
||||
"Supports both regular image generation from text prompts and image editing with mask.",
|
||||
inputs=[
|
||||
@ -798,6 +800,119 @@ class IdeogramV3(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramV4(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="IdeogramV4",
|
||||
display_name="Ideogram V4",
|
||||
category="partner/image/Ideogram",
|
||||
description="Generates images using the Ideogram 4.0 model from a text prompt.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the image generation.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=[
|
||||
"Auto",
|
||||
"2048x2048 (1:1)",
|
||||
"1440x2880 (1:2)",
|
||||
"2880x1440 (2:1)",
|
||||
"1664x2496 (2:3)",
|
||||
"2496x1664 (3:2)",
|
||||
"1792x2240 (4:5)",
|
||||
"2240x1792 (5:4)",
|
||||
"1440x2560 (9:16)",
|
||||
"2560x1440 (16:9)",
|
||||
"1600x2560 (5:8)",
|
||||
"2560x1600 (8:5)",
|
||||
"1728x2304 (3:4)",
|
||||
"2304x1728 (4:3)",
|
||||
"1296x3168 (9:22)",
|
||||
"3168x1296 (22:9)",
|
||||
"1152x2944 (9:23)",
|
||||
"2944x1152 (23:9)",
|
||||
"1248x3328 (3:8)",
|
||||
"3328x1248 (8:3)",
|
||||
"1280x3072 (5:12)",
|
||||
"3072x1280 (12:5)",
|
||||
],
|
||||
default="Auto",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"rendering_speed",
|
||||
options=["DEFAULT", "TURBO", "QUALITY"],
|
||||
default="DEFAULT",
|
||||
tooltip="Controls the trade-off between generation speed and quality.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed"]),
|
||||
expr="""
|
||||
(
|
||||
$speed := widgets.rendering_speed;
|
||||
$price :=
|
||||
$contains($speed,"turbo") ? 0.0429 :
|
||||
$contains($speed,"quality") ? 0.143 :
|
||||
0.0858;
|
||||
{"type":"usd","usd": $price}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
resolution: str,
|
||||
rendering_speed: str,
|
||||
seed: int,
|
||||
):
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/ideogram/ideogram-v4/generate", method="POST"),
|
||||
response_model=IdeogramGenerateResponse,
|
||||
data=IdeogramV4Request(
|
||||
text_prompt=prompt,
|
||||
resolution=resolution.split(" ")[0] if resolution != "Auto" else None,
|
||||
rendering_speed=rendering_speed,
|
||||
),
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
if not response.data or len(response.data) == 0:
|
||||
raise Exception("No images were generated in the response")
|
||||
image_urls = [image_data.url for image_data in response.data if image_data.url]
|
||||
if not image_urls:
|
||||
raise Exception("No image URLs were generated in the response")
|
||||
return IO.NodeOutput(await download_and_process_images(image_urls))
|
||||
|
||||
|
||||
class IdeogramExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -805,6 +920,7 @@ class IdeogramExtension(ComfyExtension):
|
||||
IdeogramV1,
|
||||
IdeogramV2,
|
||||
IdeogramV3,
|
||||
IdeogramV4,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -642,7 +642,7 @@ class KlingCameraControls(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingCameraControls",
|
||||
display_name="Kling Camera Controls",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Allows specifying configuration options for Kling Camera Controls and motion control effects.",
|
||||
inputs=[
|
||||
IO.Combo.Input("camera_control_type", options=KlingCameraControlType),
|
||||
@ -762,7 +762,7 @@ class KlingTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingTextToVideoNode",
|
||||
display_name="Kling Text to Video",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Kling Text to Video Node",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -849,7 +849,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProTextToVideoNode",
|
||||
display_name="Kling 3.0 Omni Text to Video",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Use text prompts to generate videos with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -998,7 +998,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProFirstLastFrameNode",
|
||||
display_name="Kling 3.0 Omni First-Last-Frame to Video",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -1205,7 +1205,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProImageToVideoNode",
|
||||
display_name="Kling 3.0 Omni Image to Video",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Use up to 7 reference images to generate a video with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -1374,7 +1374,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProVideoToVideoNode",
|
||||
display_name="Kling 3.0 Omni Video to Video",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
@ -1485,7 +1485,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProEditVideoNode",
|
||||
display_name="Kling 3.0 Omni Edit Video",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
essentials_category="Video Generation",
|
||||
description="Edit an existing video with the latest model from Kling.",
|
||||
inputs=[
|
||||
@ -1593,7 +1593,7 @@ class OmniProImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProImageNode",
|
||||
display_name="Kling 3.0 Omni Image",
|
||||
category="api node/image/Kling",
|
||||
category="partner/image/Kling",
|
||||
description="Create or edit images with the latest model from Kling.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]),
|
||||
@ -1721,7 +1721,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingCameraControlT2VNode",
|
||||
display_name="Kling Text to Video (Camera Control)",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -1783,7 +1783,7 @@ class KlingImage2VideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingImage2VideoNode",
|
||||
display_name="Kling Image(First Frame) to Video",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
inputs=[
|
||||
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -1882,7 +1882,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingCameraControlI2VNode",
|
||||
display_name="Kling Image to Video (Camera Control)",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -1953,7 +1953,7 @@ class KlingStartEndFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingStartEndFrameNode",
|
||||
display_name="Kling Start-End Frame to Video",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -2047,7 +2047,7 @@ class KlingVideoExtendNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingVideoExtendNode",
|
||||
display_name="Kling Video Extend",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -2128,7 +2128,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingDualCharacterVideoEffectNode",
|
||||
display_name="Kling Dual Character Video Effects",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite.",
|
||||
inputs=[
|
||||
IO.Image.Input("image_left", tooltip="Left side image"),
|
||||
@ -2218,7 +2218,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingSingleImageVideoEffectNode",
|
||||
display_name="Kling Video Effects",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Achieve different special effects when generating a video based on the effect_scene.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -2291,7 +2291,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingLipSyncAudioToVideoNode",
|
||||
display_name="Kling Lip Sync Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
essentials_category="Video Generation",
|
||||
description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
|
||||
inputs=[
|
||||
@ -2343,7 +2343,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingLipSyncTextToVideoNode",
|
||||
display_name="Kling Lip Sync Video with Text",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
@ -2411,7 +2411,7 @@ class KlingVirtualTryOnNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingVirtualTryOnNode",
|
||||
display_name="Kling Virtual Try On",
|
||||
category="api node/image/Kling",
|
||||
category="partner/image/Kling",
|
||||
description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.",
|
||||
inputs=[
|
||||
IO.Image.Input("human_image"),
|
||||
@ -2478,7 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingImageGenerationNode",
|
||||
display_name="Kling 3.0 Image",
|
||||
category="api node/image/Kling",
|
||||
category="partner/image/Kling",
|
||||
description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -2615,7 +2615,7 @@ class TextToVideoWithAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingTextToVideoWithAudio",
|
||||
display_name="Kling 2.6 Text to Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
|
||||
@ -2683,7 +2683,7 @@ class ImageToVideoWithAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingImageToVideoWithAudio",
|
||||
display_name="Kling 2.6 Image(First Frame) to Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.Image.Input("start_frame"),
|
||||
@ -2753,7 +2753,7 @@ class MotionControl(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingMotionControl",
|
||||
display_name="Kling Motion Control",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True),
|
||||
IO.Image.Input("reference_image"),
|
||||
@ -2854,7 +2854,7 @@ class KlingVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingVideoNode",
|
||||
display_name="Kling 3.0 Video",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Generate videos with Kling V3. "
|
||||
"Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.",
|
||||
inputs=[
|
||||
@ -3077,7 +3077,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingFirstLastFrameNode",
|
||||
display_name="Kling 3.0 First-Last-Frame to Video",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Generate videos with Kling V3 using first and last frames.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
@ -3202,7 +3202,7 @@ class KlingAvatarNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="KlingAvatarNode",
|
||||
display_name="Kling Avatar 2.0",
|
||||
category="api node/video/Kling",
|
||||
category="partner/video/Kling",
|
||||
description="Generate broadcast-style digital human videos from a single photo and an audio file.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
|
||||
294
comfy_api_nodes/nodes_krea.py
Normal file
294
comfy_api_nodes/nodes_krea.py
Normal file
@ -0,0 +1,294 @@
|
||||
"""Krea image-generation nodes."""
|
||||
|
||||
import re
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.krea import (
|
||||
KreaAssetResponse,
|
||||
KreaGenerateImageRequest,
|
||||
KreaImageStyleReference,
|
||||
KreaJob,
|
||||
KreaMoodboard,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_bytesio,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
class KreaIO:
|
||||
STYLE_REF = "KREA_STYLE_REF"
|
||||
|
||||
|
||||
async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Image) -> str:
|
||||
"""Upload an image to Krea's /assets endpoint and return the Krea-hosted image URL."""
|
||||
img_io = tensor_to_bytesio(image, total_pixels=2048 * 2048, mime_type="image/png")
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/krea/assets", method="POST"),
|
||||
response_model=KreaAssetResponse,
|
||||
files=[("file", (img_io.name, img_io, "image/png"))],
|
||||
content_type="multipart/form-data",
|
||||
max_retries=1,
|
||||
wait_label="Uploading reference",
|
||||
)
|
||||
return response.image_url
|
||||
|
||||
|
||||
_MODEL_MEDIUM = "Krea 2 Medium"
|
||||
_MODEL_MEDIUM_TURBO = "Krea 2 Medium Turbo"
|
||||
_MODEL_LARGE = "Krea 2 Large"
|
||||
_MODEL_ENDPOINTS: dict[str, str] = {
|
||||
_MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
|
||||
_MODEL_MEDIUM_TURBO: "/proxy/krea/generate/image/krea/krea-2/medium-turbo",
|
||||
_MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
|
||||
}
|
||||
|
||||
_ASPECT_RATIOS = ["1:1", "4:3", "3:2", "16:9", "2.35:1", "4:5", "2:3", "9:16"]
|
||||
_RESOLUTIONS = ["1K"]
|
||||
_CREATIVITY_LEVELS = ["raw", "low", "medium", "high"]
|
||||
_KREA_QUEUED_STATUSES = ["backlogged", "queued", "scheduled"]
|
||||
|
||||
_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$")
|
||||
|
||||
|
||||
def _krea_model_inputs() -> list:
|
||||
"""Nested inputs shared by Krea 2 Medium, Medium Turbo and Large under the DynamicCombo."""
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=_ASPECT_RATIOS,
|
||||
tooltip="Output aspect ratio.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=_RESOLUTIONS,
|
||||
tooltip="Resolution scale.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"creativity",
|
||||
options=_CREATIVITY_LEVELS,
|
||||
default="medium",
|
||||
tooltip="Prompt interpretation strength: raw stays closest to the prompt; high is most creative.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"moodboard_id",
|
||||
default="",
|
||||
tooltip="Optional Krea moodboard UUID (e.g. from the Krea website). "
|
||||
"Leave empty to disable. Only one moodboard is supported per request.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"moodboard_strength",
|
||||
default=0.35,
|
||||
min=-0.5,
|
||||
max=1.5,
|
||||
step=0.05,
|
||||
tooltip="Moodboard influence; ignored when moodboard_id is empty.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Custom(KreaIO.STYLE_REF).Input(
|
||||
"style_reference",
|
||||
optional=True,
|
||||
tooltip="Optional chain of style references (max 10) from Krea 2 Style Reference nodes.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class Krea2ImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Krea2ImageNode",
|
||||
display_name="Krea 2 Image",
|
||||
category="partner/image/Krea",
|
||||
description=(
|
||||
"Generate images via Krea 2 — pick Medium (expressive illustrations) or "
|
||||
"Large (expressive photorealism). Supports an optional moodboard and up "
|
||||
"to 10 chained image style references."
|
||||
),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text prompt for the image.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
|
||||
IO.DynamicCombo.Option(_MODEL_MEDIUM_TURBO, _krea_model_inputs()),
|
||||
IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
|
||||
],
|
||||
tooltip="Krea 2 Medium is best for expressive illustrations; "
|
||||
"Krea 2 Large is best for expressive photorealism.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Random seed for reproducibility.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model", "model.moodboard_id"],
|
||||
inputs=["model.style_reference"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$rates := {
|
||||
"krea 2 medium turbo": {"text": 0.015, "style": 0.0175, "moodboard": 0.02},
|
||||
"krea 2 medium": {"text": 0.03, "style": 0.035, "moodboard": 0.04},
|
||||
"krea 2 large": {"text": 0.06, "style": 0.065, "moodboard": 0.07}
|
||||
};
|
||||
$r := $lookup($rates, widgets.model);
|
||||
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
|
||||
$hasStyle := $lookup(inputs, "model.style_reference").connected;
|
||||
$usd := $hasMoodboard ? $r.moodboard : ($hasStyle ? $r.style : $r.text);
|
||||
{"type":"usd","usd": $usd}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||
|
||||
model_choice = model["model"]
|
||||
endpoint_path = _MODEL_ENDPOINTS.get(model_choice)
|
||||
if endpoint_path is None:
|
||||
raise ValueError(f"Unknown Krea 2 model: {model_choice!r}")
|
||||
|
||||
moodboards: list[KreaMoodboard] | None = None
|
||||
mb_id = (model.get("moodboard_id") or "").strip()
|
||||
if mb_id:
|
||||
if not _UUID_RE.match(mb_id):
|
||||
raise ValueError(f"moodboard_id must be a UUID (received {mb_id!r}); copy it from the Krea website.")
|
||||
mb_strength = model.get("moodboard_strength")
|
||||
moodboards = [KreaMoodboard(id=mb_id, strength=0.35 if mb_strength is None else float(mb_strength))]
|
||||
|
||||
style_reference = model.get("style_reference")
|
||||
image_style_references: list[KreaImageStyleReference] | None = None
|
||||
if style_reference:
|
||||
if len(style_reference) > 10:
|
||||
raise ValueError(f"Krea 2 accepts at most 10 image_style_references; received {len(style_reference)}.")
|
||||
image_style_references = [
|
||||
KreaImageStyleReference(url=ref["url"], strength=float(ref["strength"])) for ref in style_reference
|
||||
]
|
||||
initial = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=endpoint_path, method="POST"),
|
||||
response_model=KreaJob,
|
||||
data=KreaGenerateImageRequest(
|
||||
prompt=prompt,
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
resolution=model["resolution"],
|
||||
seed=seed,
|
||||
creativity=model["creativity"],
|
||||
moodboards=moodboards,
|
||||
image_style_references=image_style_references,
|
||||
),
|
||||
)
|
||||
job = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/krea/jobs/{initial.job_id}", method="GET"),
|
||||
response_model=KreaJob,
|
||||
status_extractor=lambda r: r.status,
|
||||
queued_statuses=_KREA_QUEUED_STATUSES,
|
||||
)
|
||||
if not job.result or not job.result.urls:
|
||||
raise RuntimeError(f"Krea 2 job {job.job_id} completed without any image URLs.")
|
||||
image = await download_url_to_image_tensor(job.result.urls[0])
|
||||
return IO.NodeOutput(image)
|
||||
|
||||
|
||||
class Krea2StyleReferenceNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Krea2StyleReferenceNode",
|
||||
display_name="Krea 2 Style Reference",
|
||||
category="partner/image/Krea",
|
||||
description=(
|
||||
"Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 "
|
||||
"Style Reference nodes (max 10) and feed the final `style_reference` output "
|
||||
"into Krea 2 Image. Each image is uploaded to ComfyAPI storage and passed as URL."
|
||||
),
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="Reference image whose style influences the generation.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"strength",
|
||||
default=1.0,
|
||||
min=-2.0,
|
||||
max=2.0,
|
||||
step=0.05,
|
||||
tooltip="Reference strength; negative values invert the style influence.",
|
||||
),
|
||||
IO.Custom(KreaIO.STYLE_REF).Input(
|
||||
"style_reference",
|
||||
optional=True,
|
||||
tooltip="Optional incoming chain of style references; this node appends one more.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(KreaIO.STYLE_REF).Output(display_name="style_reference")],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
strength: float,
|
||||
style_reference: list[dict] | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain: list[dict] = list(style_reference) if style_reference else []
|
||||
if len(chain) >= 10:
|
||||
raise ValueError("Krea 2 accepts at most 10 image_style_references in one generation.")
|
||||
url = await _upload_image_to_krea_assets(cls, image)
|
||||
chain.append({"url": url, "strength": float(strength)})
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class KreaExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
Krea2ImageNode,
|
||||
Krea2StyleReferenceNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> KreaExtension:
|
||||
return KreaExtension()
|
||||
@ -50,7 +50,7 @@ class TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiTextToVideo",
|
||||
display_name="LTXV Text To Video",
|
||||
category="api node/video/LTXV",
|
||||
category="partner/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=list(MODELS_MAP.keys())),
|
||||
@ -127,7 +127,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LtxvApiImageToVideo",
|
||||
display_name="LTXV Image To Video",
|
||||
category="api node/video/LTXV",
|
||||
category="partner/video/LTXV",
|
||||
description="Professional-quality videos with customizable duration and resolution based on start image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="First frame to be used for the video."),
|
||||
|
||||
@ -46,7 +46,7 @@ class LumaReferenceNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaReferenceNode",
|
||||
display_name="Luma Reference",
|
||||
category="api node/image/Luma",
|
||||
category="partner/image/Luma",
|
||||
description="Holds an image and weight for use with Luma Generate Image node.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -85,7 +85,7 @@ class LumaConceptsNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaConceptsNode",
|
||||
display_name="Luma Concepts",
|
||||
category="api node/video/Luma",
|
||||
category="partner/video/Luma",
|
||||
description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -134,7 +134,7 @@ class LumaImageGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageNode",
|
||||
display_name="Luma Text to Image",
|
||||
category="api node/image/Luma",
|
||||
category="partner/image/Luma",
|
||||
description="Generates images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -278,7 +278,7 @@ class LumaImageModifyNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageModifyNode",
|
||||
display_name="Luma Image to Image",
|
||||
category="api node/image/Luma",
|
||||
category="partner/image/Luma",
|
||||
description="Modifies images synchronously based on prompt and aspect ratio.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -371,7 +371,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaVideoNode",
|
||||
display_name="Luma Text to Video",
|
||||
category="api node/video/Luma",
|
||||
category="partner/video/Luma",
|
||||
description="Generates videos synchronously based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -472,7 +472,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageToVideoNode",
|
||||
display_name="Luma Image to Video",
|
||||
category="api node/video/Luma",
|
||||
category="partner/video/Luma",
|
||||
description="Generates videos synchronously based on prompt, input images, and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -724,7 +724,7 @@ class LumaImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageNode2",
|
||||
display_name="Luma UNI-1 Image",
|
||||
category="api node/image/Luma",
|
||||
category="partner/image/Luma",
|
||||
description="Generate images from text using the Luma UNI-1 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -853,7 +853,7 @@ class LumaImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="LumaImageEditNode2",
|
||||
display_name="Luma UNI-1 Image Edit",
|
||||
category="api node/image/Luma",
|
||||
category="partner/image/Luma",
|
||||
description="Edit an existing image with a text prompt using the Luma UNI-1 model.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
|
||||
@ -61,7 +61,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageUpscalerCreativeNode",
|
||||
display_name="Magnific Image Upscale (Creative)",
|
||||
category="api node/image/Magnific",
|
||||
category="partner/image/Magnific",
|
||||
description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. "
|
||||
"Maximum output: 25.3 megapixels.",
|
||||
inputs=[
|
||||
@ -240,7 +240,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageUpscalerPreciseV2Node",
|
||||
display_name="Magnific Image Upscale (Precise V2)",
|
||||
category="api node/image/Magnific",
|
||||
category="partner/image/Magnific",
|
||||
description="High-fidelity upscaling with fine control over sharpness, grain, and detail. "
|
||||
"Maximum output: 10060×10060 pixels.",
|
||||
inputs=[
|
||||
@ -400,7 +400,7 @@ class MagnificImageStyleTransferNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageStyleTransferNode",
|
||||
display_name="Magnific Image Style Transfer",
|
||||
category="api node/image/Magnific",
|
||||
category="partner/image/Magnific",
|
||||
description="Transfer the style from a reference image to your input image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to apply style transfer to."),
|
||||
@ -549,7 +549,7 @@ class MagnificImageRelightNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageRelightNode",
|
||||
display_name="Magnific Image Relight",
|
||||
category="api node/image/Magnific",
|
||||
category="partner/image/Magnific",
|
||||
description="Relight an image with lighting adjustments and optional reference-based light transfer.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to relight."),
|
||||
@ -789,7 +789,7 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MagnificImageSkinEnhancerNode",
|
||||
display_name="Magnific Image Skin Enhancer",
|
||||
category="api node/image/Magnific",
|
||||
category="partner/image/Magnific",
|
||||
description="Skin enhancement for portraits with multiple processing modes.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The portrait image to enhance."),
|
||||
|
||||
@ -33,7 +33,7 @@ class MeshyTextToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyTextToModelNode",
|
||||
display_name="Meshy: Text to Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.String.Input("prompt", multiline=True, default=""),
|
||||
@ -145,7 +145,7 @@ class MeshyRefineNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyRefineNode",
|
||||
display_name="Meshy: Refine Draft Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
description="Refine a previously created draft model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
@ -240,7 +240,7 @@ class MeshyImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyImageToModelNode",
|
||||
display_name="Meshy: Image to Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Image.Input("image"),
|
||||
@ -405,7 +405,7 @@ class MeshyMultiImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyMultiImageToModelNode",
|
||||
display_name="Meshy: Multi-Image to Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Autogrow.Input(
|
||||
@ -575,7 +575,7 @@ class MeshyRigModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyRigModelNode",
|
||||
display_name="Meshy: Rig Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
description="Provides a rigged character in standard formats. "
|
||||
"Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, "
|
||||
"or humanoid assets with unclear limb and body structure.",
|
||||
@ -656,7 +656,7 @@ class MeshyAnimateModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyAnimateModelNode",
|
||||
display_name="Meshy: Animate Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
description="Apply a specific animation action to a previously rigged character.",
|
||||
inputs=[
|
||||
IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"),
|
||||
@ -722,7 +722,7 @@ class MeshyTextureNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MeshyTextureNode",
|
||||
display_name="Meshy: Texture Model",
|
||||
category="api node/3d/Meshy",
|
||||
category="partner/3d/Meshy",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["latest"]),
|
||||
IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"),
|
||||
|
||||
@ -101,7 +101,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxTextToVideoNode",
|
||||
display_name="MiniMax Text to Video",
|
||||
category="api node/video/MiniMax",
|
||||
category="partner/video/MiniMax",
|
||||
description="Generates videos synchronously based on a prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -163,7 +163,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxImageToVideoNode",
|
||||
display_name="MiniMax Image to Video",
|
||||
category="api node/video/MiniMax",
|
||||
category="partner/video/MiniMax",
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -230,7 +230,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxSubjectToVideoNode",
|
||||
display_name="MiniMax Subject to Video",
|
||||
category="api node/video/MiniMax",
|
||||
category="partner/video/MiniMax",
|
||||
description="Generates videos synchronously based on an image and prompt, and optional parameters.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
@ -294,7 +294,7 @@ class MinimaxHailuoVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="MinimaxHailuoVideoNode",
|
||||
display_name="MiniMax Hailuo Video",
|
||||
category="api node/video/MiniMax",
|
||||
category="partner/video/MiniMax",
|
||||
description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -99,7 +99,7 @@ class OpenAIDalle2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIDalle2",
|
||||
display_name="OpenAI DALL·E 2",
|
||||
category="api node/image/OpenAI",
|
||||
category="partner/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -249,7 +249,7 @@ class OpenAIDalle3(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIDalle3",
|
||||
display_name="OpenAI DALL·E 3",
|
||||
category="api node/image/OpenAI",
|
||||
category="partner/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -371,7 +371,7 @@ class OpenAIGPTImage1(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImage1",
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="api node/image/OpenAI",
|
||||
category="partner/image/OpenAI",
|
||||
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
|
||||
is_deprecated=True,
|
||||
inputs=[
|
||||
@ -695,7 +695,7 @@ class OpenAIGPTImageNodeV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIGPTImageNodeV2",
|
||||
display_name="OpenAI GPT Image 2",
|
||||
category="api node/image/OpenAI",
|
||||
category="partner/image/OpenAI",
|
||||
description="Generates images via OpenAI's GPT Image endpoint.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -962,7 +962,7 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIChatNode",
|
||||
display_name="OpenAI ChatGPT",
|
||||
category="api node/text/OpenAI",
|
||||
category="partner/text/OpenAI",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses from an OpenAI model.",
|
||||
inputs=[
|
||||
@ -1201,7 +1201,7 @@ class OpenAIInputFiles(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIInputFiles",
|
||||
display_name="OpenAI ChatGPT Input Files",
|
||||
category="api node/text/OpenAI",
|
||||
category="partner/text/OpenAI",
|
||||
description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -1248,7 +1248,7 @@ class OpenAIChatConfig(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIChatConfig",
|
||||
display_name="OpenAI ChatGPT Advanced Options",
|
||||
category="api node/text/OpenAI",
|
||||
category="partner/text/OpenAI",
|
||||
description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
|
||||
@ -265,7 +265,7 @@ class OpenRouterLLMNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenRouterLLMNode",
|
||||
display_name="OpenRouter LLM",
|
||||
category="api node/text/OpenRouter",
|
||||
category="partner/text/OpenRouter",
|
||||
essentials_category="Text Generation",
|
||||
description=(
|
||||
"Generate text responses through OpenRouter. Routes to a curated set of popular "
|
||||
|
||||
@ -53,7 +53,7 @@ class PixverseTemplateNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseTemplateNode",
|
||||
display_name="PixVerse Template",
|
||||
category="api node/video/PixVerse",
|
||||
category="partner/video/PixVerse",
|
||||
inputs=[
|
||||
IO.Combo.Input("template", options=list(pixverse_templates.keys())),
|
||||
],
|
||||
@ -74,7 +74,7 @@ class PixverseTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseTextToVideoNode",
|
||||
display_name="PixVerse Text to Video",
|
||||
category="api node/video/PixVerse",
|
||||
category="partner/video/PixVerse",
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -192,7 +192,7 @@ class PixverseImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseImageToVideoNode",
|
||||
display_name="PixVerse Image to Video",
|
||||
category="api node/video/PixVerse",
|
||||
category="partner/video/PixVerse",
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -310,7 +310,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="PixverseTransitionVideoNode",
|
||||
display_name="PixVerse Transition Video",
|
||||
category="api node/video/PixVerse",
|
||||
category="partner/video/PixVerse",
|
||||
description="Generates videos based on prompt and output_size.",
|
||||
inputs=[
|
||||
IO.Image.Input("first_frame"),
|
||||
|
||||
@ -62,7 +62,7 @@ class QuiverTextToSVGNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="QuiverTextToSVGNode",
|
||||
display_name="Quiver Text to SVG",
|
||||
category="api node/image/Quiver",
|
||||
category="partner/image/Quiver",
|
||||
description="Generate an SVG from a text prompt using Quiver AI.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -177,7 +177,7 @@ class QuiverImageToSVGNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="QuiverImageToSVGNode",
|
||||
display_name="Quiver Image to SVG",
|
||||
category="api node/image/Quiver",
|
||||
category="partner/image/Quiver",
|
||||
description="Vectorize a raster image into SVG using Quiver AI.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
|
||||
@ -178,7 +178,7 @@ class RecraftColorRGBNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftColorRGB",
|
||||
display_name="Recraft Color RGB",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Create Recraft Color by choosing specific RGB values.",
|
||||
inputs=[
|
||||
IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."),
|
||||
@ -204,7 +204,7 @@ class RecraftControlsNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftControls",
|
||||
display_name="Recraft Controls",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Create Recraft Controls for customizing Recraft generation.",
|
||||
inputs=[
|
||||
IO.Custom(RecraftIO.COLOR).Input("colors", optional=True),
|
||||
@ -228,7 +228,7 @@ class RecraftStyleV3RealisticImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3RealisticImage",
|
||||
display_name="Recraft Style - Realistic Image",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
|
||||
@ -253,7 +253,7 @@ class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3DigitalIllustration",
|
||||
display_name="Recraft Style - Digital Illustration",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
|
||||
@ -272,7 +272,7 @@ class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3VectorIllustrationNode",
|
||||
display_name="Recraft Style - Realistic Image",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
|
||||
@ -291,7 +291,7 @@ class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3LogoRaster",
|
||||
display_name="Recraft Style - Logo Raster",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Select realistic_image style and optional substyle.",
|
||||
inputs=[
|
||||
IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)),
|
||||
@ -308,7 +308,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftStyleV3InfiniteStyleLibrary",
|
||||
display_name="Recraft Style - Infinite Style Library",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
|
||||
inputs=[
|
||||
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
|
||||
@ -331,7 +331,7 @@ class RecraftCreateStyleNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCreateStyleNode",
|
||||
display_name="Recraft Create Style",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Create a custom style from reference images. "
|
||||
"Upload 1-5 images to use as style references. "
|
||||
"Total size of all images is limited to 5 MB.",
|
||||
@ -400,7 +400,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftTextToImageNode",
|
||||
display_name="Recraft Text to Image",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Generates images synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."),
|
||||
@ -512,7 +512,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftImageToImageNode",
|
||||
display_name="Recraft Image to Image",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Modify image based on prompt and strength.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -630,7 +630,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftImageInpaintingNode",
|
||||
display_name="Recraft Image Inpainting",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Modify image based on prompt and mask.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -732,7 +732,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftTextToVectorNode",
|
||||
display_name="Recraft Text to Vector",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Generates SVG synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True),
|
||||
@ -832,7 +832,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftVectorizeImageNode",
|
||||
display_name="Recraft Vectorize Image",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
essentials_category="Image Tools",
|
||||
description="Generates SVG synchronously from an input image.",
|
||||
inputs=[
|
||||
@ -876,7 +876,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftReplaceBackgroundNode",
|
||||
display_name="Recraft Replace Background",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Replace background on image, based on provided prompt.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -963,7 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftRemoveBackgroundNode",
|
||||
display_name="Recraft Remove Background",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
essentials_category="Image Tools",
|
||||
description="Remove background from image, and return processed image and mask.",
|
||||
inputs=[
|
||||
@ -1012,7 +1012,7 @@ class RecraftCrispUpscaleNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCrispUpscaleNode",
|
||||
display_name="Recraft Crisp Upscale Image",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Upscale image synchronously.\n"
|
||||
"Enhances a given raster image using ‘crisp upscale’ tool, "
|
||||
"increasing image resolution, making the image sharper and cleaner.",
|
||||
@ -1058,7 +1058,7 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftCreativeUpscaleNode",
|
||||
display_name="Recraft Creative Upscale Image",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Upscale image synchronously.\n"
|
||||
"Enhances a given raster image using ‘creative upscale’ tool, "
|
||||
"boosting resolution with a focus on refining small details and faces.",
|
||||
@ -1086,7 +1086,7 @@ class RecraftV4TextToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftV4TextToImageNode",
|
||||
display_name="Recraft V4 Text to Image",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Generates images using Recraft V4 or V4 Pro models.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -1210,7 +1210,7 @@ class RecraftV4TextToVectorNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RecraftV4TextToVectorNode",
|
||||
display_name="Recraft V4 Text to Vector",
|
||||
category="api node/image/Recraft",
|
||||
category="partner/image/Recraft",
|
||||
description="Generates SVG using Recraft V4 or V4 Pro models.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -109,7 +109,7 @@ class ReveImageCreateNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageCreateNode",
|
||||
display_name="Reve Image Create",
|
||||
category="api node/image/Reve",
|
||||
category="partner/image/Reve",
|
||||
description="Generate images from text descriptions using Reve.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -200,7 +200,7 @@ class ReveImageEditNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageEditNode",
|
||||
display_name="Reve Image Edit",
|
||||
category="api node/image/Reve",
|
||||
category="partner/image/Reve",
|
||||
description="Edit images using natural language instructions with Reve.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to edit."),
|
||||
@ -300,7 +300,7 @@ class ReveImageRemixNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageRemixNode",
|
||||
display_name="Reve Image Remix",
|
||||
category="api node/image/Reve",
|
||||
category="partner/image/Reve",
|
||||
description="Combine reference images with text prompts to create new images using Reve.",
|
||||
inputs=[
|
||||
IO.Autogrow.Input(
|
||||
|
||||
@ -230,7 +230,7 @@ class Rodin3D_Regular(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Regular",
|
||||
display_name="Rodin 3D Generate - Regular Generate",
|
||||
category="api node/3d/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -289,7 +289,7 @@ class Rodin3D_Detail(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Detail",
|
||||
display_name="Rodin 3D Generate - Detail Generate",
|
||||
category="api node/3d/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -348,7 +348,7 @@ class Rodin3D_Smooth(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Smooth",
|
||||
display_name="Rodin 3D Generate - Smooth Generate",
|
||||
category="api node/3d/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -406,7 +406,7 @@ class Rodin3D_Sketch(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Sketch",
|
||||
display_name="Rodin 3D Generate - Sketch Generate",
|
||||
category="api node/3d/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -468,7 +468,7 @@ class Rodin3D_Gen2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen2",
|
||||
display_name="Rodin 3D Generate - Gen-2 Generate",
|
||||
category="api node/3d/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("Images"),
|
||||
@ -941,7 +941,7 @@ class Rodin3D_Gen25_Image(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Image",
|
||||
display_name="Rodin 3D Gen-2.5 - Image to 3D",
|
||||
category="api node/3d/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
@ -1035,7 +1035,7 @@ class Rodin3D_Gen25_Text(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Rodin3D_Gen25_Text",
|
||||
display_name="Rodin 3D Gen-2.5 - Text to 3D",
|
||||
category="api node/3d/Rodin",
|
||||
category="partner/3d/Rodin",
|
||||
description=(
|
||||
"Generate a 3D model from a text prompt via Rodin Gen-2.5. "
|
||||
"Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost."
|
||||
|
||||
@ -140,7 +140,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen3a",
|
||||
display_name="Runway Image to Video (Gen3a Turbo)",
|
||||
category="api node/video/Runway",
|
||||
category="partner/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen3a Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
@ -234,7 +234,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayImageToVideoNodeGen4",
|
||||
display_name="Runway Image to Video (Gen4 Turbo)",
|
||||
category="api node/video/Runway",
|
||||
category="partner/video/Runway",
|
||||
description="Generate a video from a single starting frame using Gen4 Turbo model. "
|
||||
"Before diving in, review these best practices to ensure that "
|
||||
"your input selections will set your generation up for success: "
|
||||
@ -329,7 +329,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayFirstLastFrameNode",
|
||||
display_name="Runway First-Last-Frame to Video",
|
||||
category="api node/video/Runway",
|
||||
category="partner/video/Runway",
|
||||
description="Upload first and last keyframes, draft a prompt, and generate a video. "
|
||||
"More complex transitions, such as cases where the Last frame is completely different "
|
||||
"from the First frame, may benefit from the longer 10s duration. "
|
||||
@ -440,7 +440,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="RunwayTextToImageNode",
|
||||
display_name="Runway Text to Image",
|
||||
category="api node/image/Runway",
|
||||
category="partner/image/Runway",
|
||||
description="Generate an image from a text prompt using Runway's Gen 4 model. "
|
||||
"You can also include reference image to guide the generation.",
|
||||
inputs=[
|
||||
|
||||
@ -34,7 +34,7 @@ class SoniloVideoToMusic(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SoniloVideoToMusic",
|
||||
display_name="Sonilo Video to Music",
|
||||
category="api node/audio/Sonilo",
|
||||
category="partner/audio/Sonilo",
|
||||
description="Generate music from video content using Sonilo's AI model. "
|
||||
"Analyzes the video and creates matching music.",
|
||||
inputs=[
|
||||
@ -99,7 +99,7 @@ class SoniloTextToMusic(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SoniloTextToMusic",
|
||||
display_name="Sonilo Text to Music",
|
||||
category="api node/audio/Sonilo",
|
||||
category="partner/audio/Sonilo",
|
||||
description="Generate music from a text prompt using Sonilo's AI model. "
|
||||
"Leave duration at 0 to let the model infer it from the prompt.",
|
||||
inputs=[
|
||||
|
||||
@ -34,7 +34,7 @@ class OpenAIVideoSora2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="OpenAIVideoSora2",
|
||||
display_name="OpenAI Sora - Video (DEPRECATED)",
|
||||
category="api node/video/Sora",
|
||||
category="partner/video/Sora",
|
||||
description=(
|
||||
"OpenAI video and audio generation.\n\n"
|
||||
"DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. "
|
||||
|
||||
@ -62,7 +62,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageUltraNode",
|
||||
display_name="Stability AI Stable Image Ultra",
|
||||
category="api node/image/Stability AI",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -197,7 +197,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityStableImageSD_3_5Node",
|
||||
display_name="Stability AI Stable Diffusion 3.5 Image",
|
||||
category="api node/image/Stability AI",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -354,7 +354,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleConservativeNode",
|
||||
display_name="Stability AI Upscale Conservative",
|
||||
category="api node/image/Stability AI",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -457,7 +457,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleCreativeNode",
|
||||
display_name="Stability AI Upscale Creative",
|
||||
category="api node/image/Stability AI",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -578,7 +578,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityUpscaleFastNode",
|
||||
display_name="Stability AI Upscale Fast",
|
||||
category="api node/image/Stability AI",
|
||||
category="partner/image/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
@ -630,7 +630,7 @@ class StabilityTextToAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityTextToAudio",
|
||||
display_name="Stability AI Text To Audio",
|
||||
category="api node/audio/Stability AI",
|
||||
category="partner/audio/Stability AI",
|
||||
essentials_category="Audio",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
@ -708,7 +708,7 @@ class StabilityAudioToAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioToAudio",
|
||||
display_name="Stability AI Audio To Audio",
|
||||
category="api node/audio/Stability AI",
|
||||
category="partner/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -802,7 +802,7 @@ class StabilityAudioInpaint(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="StabilityAudioInpaint",
|
||||
display_name="Stability AI Audio Inpaint",
|
||||
category="api node/audio/Stability AI",
|
||||
category="partner/audio/Stability AI",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
|
||||
@ -52,7 +52,7 @@ class TopazImageEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TopazImageEnhance",
|
||||
display_name="Topaz Image Enhance",
|
||||
category="api node/image/Topaz",
|
||||
category="partner/image/Topaz",
|
||||
description="Industry-standard upscaling and image enhancement.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["Reimagine"]),
|
||||
@ -235,7 +235,7 @@ class TopazVideoEnhance(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhance",
|
||||
display_name="Topaz Video Enhance (Legacy)",
|
||||
category="api node/video/Topaz",
|
||||
category="partner/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
@ -475,7 +475,7 @@ class TopazVideoEnhanceV2(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhanceV2",
|
||||
display_name="Topaz Video Enhance",
|
||||
category="api node/video/Topaz",
|
||||
category="partner/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
|
||||
@ -11,6 +11,9 @@ from comfy_api_nodes.apis.tripo import (
|
||||
TripoModelVersion,
|
||||
TripoMultiviewToModelRequest,
|
||||
TripoOrientation,
|
||||
TripoP1ImageToModelRequest,
|
||||
TripoP1MultiviewToModelRequest,
|
||||
TripoP1TextToModelRequest,
|
||||
TripoRefineModelRequest,
|
||||
TripoStyle,
|
||||
TripoTaskResponse,
|
||||
@ -80,7 +83,7 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoTextToModelNode",
|
||||
display_name="Tripo: Text to Model",
|
||||
category="api node/3d/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True),
|
||||
IO.String.Input("negative_prompt", multiline=True, optional=True),
|
||||
@ -93,10 +96,22 @@ class TripoTextToModelNode(IO.ComfyNode):
|
||||
IO.Int.Input("image_seed", default=42, optional=True, advanced=True),
|
||||
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
|
||||
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
|
||||
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
|
||||
IO.Combo.Input(
|
||||
"texture_quality",
|
||||
default="standard",
|
||||
options=["standard", "detailed"],
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True, advanced=True),
|
||||
IO.Boolean.Input("quad", default=False, optional=True, advanced=True),
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
|
||||
IO.Combo.Input(
|
||||
"geometry_quality",
|
||||
default="standard",
|
||||
options=["standard", "detailed"],
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
@ -195,7 +210,7 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoImageToModelNode",
|
||||
display_name="Tripo: Image to Model",
|
||||
category="api node/3d/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Combo.Input(
|
||||
@ -209,16 +224,36 @@ class TripoImageToModelNode(IO.ComfyNode):
|
||||
IO.Boolean.Input("pbr", default=True, optional=True),
|
||||
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
|
||||
IO.Combo.Input(
|
||||
"orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True, advanced=True
|
||||
"orientation",
|
||||
options=TripoOrientation,
|
||||
default=TripoOrientation.DEFAULT,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
|
||||
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
|
||||
IO.Combo.Input(
|
||||
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
|
||||
"texture_quality",
|
||||
default="standard",
|
||||
options=["standard", "detailed"],
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"texture_alignment",
|
||||
default="original_image",
|
||||
options=["original_image", "geometry"],
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True),
|
||||
IO.Boolean.Input("quad", default=False, optional=True, advanced=True),
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
|
||||
IO.Combo.Input(
|
||||
"geometry_quality",
|
||||
default="standard",
|
||||
options=["standard", "detailed"],
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
@ -323,7 +358,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoMultiviewToModelNode",
|
||||
display_name="Tripo: Multiview to Model",
|
||||
category="api node/3d/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.Image.Input("image_left", optional=True),
|
||||
@ -346,13 +381,35 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
|
||||
IO.Boolean.Input("pbr", default=True, optional=True),
|
||||
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
|
||||
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
|
||||
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
|
||||
IO.Combo.Input(
|
||||
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
|
||||
"texture_quality",
|
||||
default="standard",
|
||||
options=["standard", "detailed"],
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"texture_alignment",
|
||||
default="original_image",
|
||||
options=["original_image", "geometry"],
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True),
|
||||
IO.Boolean.Input("quad", default=False, optional=True, advanced=True, tooltip="This parameter is deprecated and does nothing."),
|
||||
IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
|
||||
IO.Boolean.Input(
|
||||
"quad",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="This parameter is deprecated and does nothing.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"geometry_quality",
|
||||
default="standard",
|
||||
options=["standard", "detailed"],
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
@ -461,15 +518,25 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoTextureNode",
|
||||
display_name="Tripo: Texture model",
|
||||
category="api node/3d/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.Custom("MODEL_TASK_ID").Input("model_task_id"),
|
||||
IO.Boolean.Input("texture", default=True, optional=True),
|
||||
IO.Boolean.Input("pbr", default=True, optional=True),
|
||||
IO.Int.Input("texture_seed", default=42, optional=True, advanced=True),
|
||||
IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True),
|
||||
IO.Combo.Input(
|
||||
"texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True
|
||||
"texture_quality",
|
||||
default="standard",
|
||||
options=["standard", "detailed"],
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"texture_alignment",
|
||||
default="original_image",
|
||||
options=["original_image", "geometry"],
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
@ -528,7 +595,7 @@ class TripoRefineNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoRefineNode",
|
||||
display_name="Tripo: Refine Draft model",
|
||||
category="api node/3d/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
description="Refine a draft model created by v1.4 Tripo models only.",
|
||||
inputs=[
|
||||
IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"),
|
||||
@ -568,7 +635,7 @@ class TripoRigNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoRigNode",
|
||||
display_name="Tripo: Rig model",
|
||||
category="api node/3d/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
@ -605,7 +672,7 @@ class TripoRetargetNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoRetargetNode",
|
||||
display_name="Tripo: Retarget rigged model",
|
||||
category="api node/3d/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.Custom("RIG_TASK_ID").Input("original_model_task_id"),
|
||||
IO.Combo.Input(
|
||||
@ -626,7 +693,7 @@ class TripoRetargetNode(IO.ComfyNode):
|
||||
"preset:hexapod:walk",
|
||||
"preset:octopod:walk",
|
||||
"preset:serpentine:march",
|
||||
"preset:aquatic:march"
|
||||
"preset:aquatic:march",
|
||||
],
|
||||
),
|
||||
],
|
||||
@ -670,7 +737,7 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="TripoConversionNode",
|
||||
display_name="Tripo: Convert model",
|
||||
category="api node/3d/Tripo",
|
||||
category="partner/3d/Tripo",
|
||||
inputs=[
|
||||
IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"),
|
||||
IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]),
|
||||
@ -817,7 +884,7 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
# Parse part_names from comma-separated string to list
|
||||
part_names_list = None
|
||||
if part_names and part_names.strip():
|
||||
part_names_list = [name.strip() for name in part_names.split(',') if name.strip()]
|
||||
part_names_list = [name.strip() for name in part_names.split(",") if name.strip()]
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
@ -848,6 +915,373 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
return await poll_until_finished(cls, response, average_duration=30)
|
||||
|
||||
|
||||
def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str:
|
||||
return (
|
||||
"("
|
||||
" $mode := widgets.output_mode;"
|
||||
' $detailed := $lookup(widgets, "output_mode.texture_quality") = "detailed";'
|
||||
f' $credits := $mode = "geometry only" ? {geometry_credits} : ($detailed ? {detailed_credits} : {textured_credits});'
|
||||
' {"type":"usd","usd": $credits * 0.01, "format": {"approximate": true}}'
|
||||
")"
|
||||
)
|
||||
|
||||
|
||||
def _p1_textured_inputs(*, include_image_alignment: bool) -> list:
|
||||
"""Inputs shown inside the 'Textured' branch of the P1 output_mode DynamicCombo."""
|
||||
inputs: list = [
|
||||
IO.Boolean.Input("pbr", default=True, tooltip="Include PBR maps. When on, base texture is forced on too."),
|
||||
IO.Combo.Input("texture_quality", options=["standard", "detailed"], default="standard"),
|
||||
]
|
||||
if include_image_alignment:
|
||||
inputs.extend(
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"texture_alignment",
|
||||
options=["original_image", "geometry"],
|
||||
default="original_image",
|
||||
tooltip="Prioritize visual fidelity to the source image, or alignment to the mesh geometry.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"orientation",
|
||||
options=["default", "align_image"],
|
||||
default="default",
|
||||
tooltip="Rotate the output to match the source image. Only applies when textured.",
|
||||
),
|
||||
]
|
||||
)
|
||||
inputs.append(IO.Int.Input("texture_seed", default=42, advanced=True))
|
||||
return inputs
|
||||
|
||||
|
||||
def _build_p1_output_mode(*, include_image_alignment: bool) -> IO.DynamicCombo.Input:
|
||||
return IO.DynamicCombo.Input(
|
||||
"output_mode",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Geometry only", []),
|
||||
IO.DynamicCombo.Option("Textured", _p1_textured_inputs(include_image_alignment=include_image_alignment)),
|
||||
],
|
||||
tooltip='"Geometry only" returns an untextured mesh. "Textured" adds color/PBR maps.',
|
||||
)
|
||||
|
||||
|
||||
def _resolve_p1_texture_fields(output_mode: dict) -> dict:
|
||||
"""Translate the output_mode DynamicCombo payload into P1 request fields.
|
||||
|
||||
pbr=true forces texture=true server-side, but we send both explicitly so the
|
||||
intent is visible in the request body and logs.
|
||||
"""
|
||||
mode = output_mode["output_mode"]
|
||||
if mode == "Geometry only":
|
||||
return {"texture": False, "pbr": False}
|
||||
out = {
|
||||
"texture": True,
|
||||
"pbr": bool(output_mode.get("pbr", True)),
|
||||
"texture_quality": output_mode.get("texture_quality", "standard"),
|
||||
"texture_seed": output_mode.get("texture_seed"),
|
||||
}
|
||||
if "texture_alignment" in output_mode:
|
||||
out["texture_alignment"] = output_mode["texture_alignment"]
|
||||
if "orientation" in output_mode:
|
||||
out["orientation"] = output_mode["orientation"]
|
||||
return out
|
||||
|
||||
|
||||
def _p1_common_inputs() -> list:
|
||||
"""Inputs shared by all P1 nodes (placed after output_mode)."""
|
||||
return [
|
||||
IO.Int.Input(
|
||||
"face_limit",
|
||||
default=-1,
|
||||
min=-1,
|
||||
max=20000,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Target face count, 48-20000. -1 lets Tripo pick adaptively.",
|
||||
),
|
||||
IO.Int.Input("model_seed", default=42, optional=True, advanced=True),
|
||||
IO.Boolean.Input(
|
||||
"auto_size",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Scale the output to approximate real-world meters.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"export_uv",
|
||||
default=True,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="UV unwrap during generation. Turn off for faster geometry-only runs.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"compress_geometry",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Apply geometry-based compression. Decompress before editing.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _build_p1_request_kwargs(
|
||||
*,
|
||||
output_mode: dict,
|
||||
face_limit: int,
|
||||
model_seed: int,
|
||||
auto_size: bool,
|
||||
export_uv: bool,
|
||||
compress_geometry: bool,
|
||||
) -> dict:
|
||||
"""Common P1 request fields shared by all three node types."""
|
||||
kwargs: dict = {
|
||||
"model_seed": model_seed,
|
||||
"face_limit": face_limit if face_limit != -1 else None,
|
||||
"auto_size": auto_size,
|
||||
"export_uv": export_uv,
|
||||
"compress": "geometry" if compress_geometry else None,
|
||||
}
|
||||
kwargs.update(_resolve_p1_texture_fields(output_mode))
|
||||
return kwargs
|
||||
|
||||
|
||||
class TripoP1TextToModelNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoP1TextToModelNode",
|
||||
display_name="Tripo P1: Text to Model",
|
||||
category="partner/3d/Tripo",
|
||||
description="Tripo P1 text-to-3D. Optimized for low-poly, game-ready meshes with stable topology.",
|
||||
inputs=[
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Up to 1024 characters."),
|
||||
IO.String.Input("negative_prompt", multiline=True, optional=True, tooltip="Up to 255 characters."),
|
||||
_build_p1_output_mode(include_image_alignment=False),
|
||||
IO.Int.Input("image_seed", default=42, optional=True, advanced=True),
|
||||
*_p1_common_inputs(),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]),
|
||||
expr=_p1_price_expr(geometry_credits=30, textured_credits=40, detailed_credits=50),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
output_mode: dict,
|
||||
negative_prompt: str | None = None,
|
||||
image_seed: int | None = None,
|
||||
face_limit: int = -1,
|
||||
model_seed: int | None = None,
|
||||
auto_size: bool = False,
|
||||
export_uv: bool = True,
|
||||
compress_geometry: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
if not prompt:
|
||||
raise RuntimeError("Prompt is required")
|
||||
common = _build_p1_request_kwargs(
|
||||
output_mode=output_mode,
|
||||
face_limit=face_limit,
|
||||
model_seed=model_seed,
|
||||
auto_size=auto_size,
|
||||
export_uv=export_uv,
|
||||
compress_geometry=compress_geometry,
|
||||
)
|
||||
request = TripoP1TextToModelRequest(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt or None,
|
||||
image_seed=image_seed,
|
||||
**common,
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
|
||||
response_model=TripoTaskResponse,
|
||||
data=request,
|
||||
)
|
||||
return await poll_until_finished(cls, response, average_duration=60)
|
||||
|
||||
|
||||
class TripoP1ImageToModelNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoP1ImageToModelNode",
|
||||
display_name="Tripo P1: Image to Model",
|
||||
category="partner/3d/Tripo",
|
||||
description="Tripo P1 image-to-3D. Optimized for low-poly, game-ready meshes.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
_build_p1_output_mode(include_image_alignment=True),
|
||||
IO.Boolean.Input(
|
||||
"enable_image_autofix",
|
||||
default=False,
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Pre-process the input image for better generation quality.",
|
||||
),
|
||||
*_p1_common_inputs(),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]),
|
||||
expr=_p1_price_expr(geometry_credits=40, textured_credits=50, detailed_credits=60),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
output_mode: dict,
|
||||
enable_image_autofix: bool = False,
|
||||
face_limit: int = -1,
|
||||
model_seed: int | None = None,
|
||||
auto_size: bool = False,
|
||||
export_uv: bool = True,
|
||||
compress_geometry: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
if image is None:
|
||||
raise RuntimeError("Image is required")
|
||||
tripo_file = TripoFileReference(
|
||||
root=TripoUrlReference(
|
||||
url=(await upload_images_to_comfyapi(cls, image, max_images=1))[0],
|
||||
type="jpeg",
|
||||
)
|
||||
)
|
||||
common = _build_p1_request_kwargs(
|
||||
output_mode=output_mode,
|
||||
face_limit=face_limit,
|
||||
model_seed=model_seed,
|
||||
auto_size=auto_size,
|
||||
export_uv=export_uv,
|
||||
compress_geometry=compress_geometry,
|
||||
)
|
||||
request = TripoP1ImageToModelRequest(
|
||||
file=tripo_file,
|
||||
enable_image_autofix=enable_image_autofix,
|
||||
**common,
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
|
||||
response_model=TripoTaskResponse,
|
||||
data=request,
|
||||
)
|
||||
return await poll_until_finished(cls, response, average_duration=60)
|
||||
|
||||
|
||||
class TripoP1MultiviewToModelNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoP1MultiviewToModelNode",
|
||||
display_name="Tripo P1: Multiview to Model",
|
||||
category="partner/3d/Tripo",
|
||||
description="Tripo P1 multiview-to-3D from 2-4 reference images in [front, left, back, right] order. "
|
||||
"Front is required; any combination of the other three may be omitted.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="Front view (0°). Required."),
|
||||
IO.Image.Input(
|
||||
"image_left",
|
||||
optional=True,
|
||||
tooltip="Left view (90°), i.e. the subject's left side.",
|
||||
),
|
||||
IO.Image.Input("image_back", optional=True, tooltip="Back view (180°)."),
|
||||
IO.Image.Input(
|
||||
"image_right",
|
||||
optional=True,
|
||||
tooltip="Right view (270°), i.e. the subject's right side.",
|
||||
),
|
||||
_build_p1_output_mode(include_image_alignment=True),
|
||||
*_p1_common_inputs(),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]),
|
||||
expr=_p1_price_expr(geometry_credits=40, textured_credits=50, detailed_credits=60),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
output_mode: dict,
|
||||
image_left: Input.Image | None = None,
|
||||
image_back: Input.Image | None = None,
|
||||
image_right: Input.Image | None = None,
|
||||
face_limit: int = -1,
|
||||
model_seed: int | None = None,
|
||||
auto_size: bool = False,
|
||||
export_uv: bool = True,
|
||||
compress_geometry: bool = False,
|
||||
) -> IO.NodeOutput:
|
||||
views = [image, image_left, image_back, image_right]
|
||||
if sum(1 for v in views if v is not None) < 2:
|
||||
raise RuntimeError("Tripo P1 multiview requires at least 2 images (front plus one of left/back/right).")
|
||||
|
||||
files: list[TripoFileReference] = []
|
||||
for view in views:
|
||||
if view is None:
|
||||
files.append(TripoFileReference(root=TripoFileEmptyReference()))
|
||||
continue
|
||||
url = (await upload_images_to_comfyapi(cls, view, max_images=1))[0]
|
||||
files.append(TripoFileReference(root=TripoUrlReference(url=url, type="jpeg")))
|
||||
|
||||
common = _build_p1_request_kwargs(
|
||||
output_mode=output_mode,
|
||||
face_limit=face_limit,
|
||||
model_seed=model_seed,
|
||||
auto_size=auto_size,
|
||||
export_uv=export_uv,
|
||||
compress_geometry=compress_geometry,
|
||||
)
|
||||
request = TripoP1MultiviewToModelRequest(files=files, **common)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"),
|
||||
response_model=TripoTaskResponse,
|
||||
data=request,
|
||||
)
|
||||
return await poll_until_finished(cls, response, average_duration=80)
|
||||
|
||||
|
||||
class TripoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -855,6 +1289,9 @@ class TripoExtension(ComfyExtension):
|
||||
TripoTextToModelNode,
|
||||
TripoImageToModelNode,
|
||||
TripoMultiviewToModelNode,
|
||||
TripoP1TextToModelNode,
|
||||
TripoP1ImageToModelNode,
|
||||
TripoP1MultiviewToModelNode,
|
||||
TripoTextureNode,
|
||||
TripoRefineNode,
|
||||
TripoRigNode,
|
||||
|
||||
@ -45,7 +45,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="VeoVideoGenerationNode",
|
||||
display_name="Google Veo 2 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
category="partner/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 2 API",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -256,7 +256,7 @@ class Veo3VideoGenerationNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Veo3VideoGenerationNode",
|
||||
display_name="Google Veo 3 Video Generation",
|
||||
category="api node/video/Veo",
|
||||
category="partner/video/Veo",
|
||||
description="Generates videos from text prompts using Google's Veo 3 API",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
@ -468,7 +468,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Veo3FirstLastFrameNode",
|
||||
display_name="Google Veo 3 First-Last-Frame to Video",
|
||||
category="api node/video/Veo",
|
||||
category="partner/video/Veo",
|
||||
description="Generate video using prompt and first and last frames.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
|
||||
@ -71,7 +71,7 @@ class ViduTextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduTextToVideoNode",
|
||||
display_name="Vidu Text To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate video from a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -169,7 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduImageToVideoNode",
|
||||
display_name="Vidu Image To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate video from image and optional prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -273,7 +273,7 @@ class ViduReferenceVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduReferenceVideoNode",
|
||||
display_name="Vidu Reference To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate video from multiple images and a prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -388,7 +388,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduStartEndToVideoNode",
|
||||
display_name="Vidu Start End To Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from start and end frames and a prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"),
|
||||
@ -492,7 +492,7 @@ class Vidu2TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2TextToVideoNode",
|
||||
display_name="Vidu2 Text-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate video from a text prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2"]),
|
||||
@ -584,7 +584,7 @@ class Vidu2ImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2ImageToVideoNode",
|
||||
display_name="Vidu2 Image-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from an image and an optional prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
|
||||
@ -714,7 +714,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2ReferenceVideoNode",
|
||||
display_name="Vidu2 Reference-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from multiple reference images and a prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2"]),
|
||||
@ -849,7 +849,7 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu2StartEndToVideoNode",
|
||||
display_name="Vidu2 Start/End Frame-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from a start frame, an end frame, and a prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]),
|
||||
@ -969,7 +969,7 @@ class ViduExtendVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduExtendVideoNode",
|
||||
display_name="Vidu Video Extension",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Extend an existing video by generating additional frames.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1138,7 +1138,7 @@ class ViduMultiFrameVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="ViduMultiFrameVideoNode",
|
||||
display_name="Vidu Multi-Frame Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video with multiple keyframe transitions.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]),
|
||||
@ -1284,7 +1284,7 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3TextToVideoNode",
|
||||
display_name="Vidu Q3 Text-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate video from a text prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1429,7 +1429,7 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3ImageToVideoNode",
|
||||
display_name="Vidu Q3 Image-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from an image and an optional prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1571,7 +1571,7 @@ class Vidu3StartEndToVideoNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Vidu3StartEndToVideoNode",
|
||||
display_name="Vidu Q3 Start/End Frame-to-Video Generation",
|
||||
category="api node/video/Vidu",
|
||||
category="partner/video/Vidu",
|
||||
description="Generate a video from a start frame, an end frame, and a prompt.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
|
||||
@ -61,7 +61,7 @@ class WanTextToImageApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanTextToImageApi",
|
||||
display_name="Wan Text to Image",
|
||||
category="api node/image/Wan",
|
||||
category="partner/image/Wan",
|
||||
description="Generates an image based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -184,7 +184,7 @@ class WanImageToImageApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanImageToImageApi",
|
||||
display_name="Wan Image to Image",
|
||||
category="api node/image/Wan",
|
||||
category="partner/image/Wan",
|
||||
description="Generates an image from one or two input images and a text prompt. "
|
||||
"The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).",
|
||||
inputs=[
|
||||
@ -312,7 +312,7 @@ class WanTextToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanTextToVideoApi",
|
||||
display_name="Wan Text to Video",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generates a video based on a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -495,7 +495,7 @@ class WanImageToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanImageToVideoApi",
|
||||
display_name="Wan Image to Video",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generates a video from the first frame and a text prompt.",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
@ -674,7 +674,7 @@ class WanReferenceVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WanReferenceVideoApi",
|
||||
display_name="Wan Reference to Video",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Use the character and voice from input videos, combined with a prompt, "
|
||||
"to generate a new video that maintains character consistency.",
|
||||
inputs=[
|
||||
@ -828,7 +828,7 @@ class Wan2TextToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2TextToVideoApi",
|
||||
display_name="Wan 2.7 Text to Video",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generates a video based on a text prompt using the Wan 2.7 model.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -981,7 +981,7 @@ class Wan2ImageToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2ImageToVideoApi",
|
||||
display_name="Wan 2.7 Image to Video",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generate a video from a first-frame image, with optional last-frame image and audio.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1152,7 +1152,7 @@ class Wan2VideoContinuationApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2VideoContinuationApi",
|
||||
display_name="Wan 2.7 Video Continuation",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Continue a video from where it left off, with optional last-frame control.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1319,7 +1319,7 @@ class Wan2VideoEditApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2VideoEditApi",
|
||||
display_name="Wan 2.7 Video Edit",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Edit a video using text instructions, reference images, or style transfer.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1477,7 +1477,7 @@ class Wan2ReferenceVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="Wan2ReferenceVideoApi",
|
||||
display_name="Wan 2.7 Reference to Video",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generate a video featuring a person or object from reference materials. "
|
||||
"Supports single-character performances and multi-character interactions.",
|
||||
inputs=[
|
||||
@ -1651,7 +1651,7 @@ class HappyHorseTextToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseTextToVideoApi",
|
||||
display_name="HappyHorse Text to Video",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generates a video based on a text prompt using the HappyHorse model.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1775,7 +1775,7 @@ class HappyHorseImageToVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseImageToVideoApi",
|
||||
display_name="HappyHorse Image to Video",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generate a video from a first-frame image using the HappyHorse model.",
|
||||
inputs=[
|
||||
IO.DynamicCombo.Input(
|
||||
@ -1905,7 +1905,7 @@ class HappyHorseVideoEditApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseVideoEditApi",
|
||||
display_name="HappyHorse Video Edit",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Edit a video using text instructions or reference images with the HappyHorse model. "
|
||||
"Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.",
|
||||
inputs=[
|
||||
@ -2046,7 +2046,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="HappyHorseReferenceVideoApi",
|
||||
display_name="HappyHorse Reference to Video",
|
||||
category="api node/video/Wan",
|
||||
category="partner/video/Wan",
|
||||
description="Generate a video featuring a person or object from reference materials with the HappyHorse "
|
||||
"model. Supports single-character performances and multi-character interactions.",
|
||||
inputs=[
|
||||
|
||||
@ -27,7 +27,7 @@ class WavespeedFlashVSRNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WavespeedFlashVSRNode",
|
||||
display_name="FlashVSR Video Upscale",
|
||||
category="api node/video/WaveSpeed",
|
||||
category="partner/video/WaveSpeed",
|
||||
description="Fast, high-quality video upscaler that "
|
||||
"boosts resolution and restores clarity for low-resolution or blurry footage.",
|
||||
inputs=[
|
||||
@ -98,7 +98,7 @@ class WavespeedImageUpscaleNode(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="WavespeedImageUpscaleNode",
|
||||
display_name="WaveSpeed Image Upscale",
|
||||
category="api node/image/WaveSpeed",
|
||||
category="partner/image/WaveSpeed",
|
||||
description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]),
|
||||
|
||||
@ -86,7 +86,7 @@ class _PollUIState:
|
||||
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"]
|
||||
|
||||
|
||||
async def sync_op(
|
||||
|
||||
@ -469,6 +469,11 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
|
||||
input_container = None
|
||||
output_container = None
|
||||
|
||||
# get_stream_source() is untrimmed, so apply the trim window in this same pass.
|
||||
# start_time is normalized (>= 0); duration == 0 means "until the end".
|
||||
start_time, duration = video.get_active_trim_window()
|
||||
trimming = bool(start_time or duration)
|
||||
|
||||
try:
|
||||
input_source = video.get_stream_source()
|
||||
input_container = av.open(input_source, mode="r")
|
||||
@ -487,16 +492,45 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
|
||||
audio_stream.layout = stream.layout
|
||||
break
|
||||
|
||||
in_video = input_container.streams.video[0]
|
||||
start_pts = int(start_time / in_video.time_base) if trimming else 0
|
||||
end_pts = int((start_time + duration) / in_video.time_base) if duration else None
|
||||
if start_pts:
|
||||
input_container.seek(start_pts, stream=in_video)
|
||||
|
||||
encoded = 0
|
||||
for frame in input_container.decode(video=0):
|
||||
if trimming:
|
||||
if frame.pts is None or frame.pts < start_pts:
|
||||
continue
|
||||
if end_pts is not None and frame.pts >= end_pts:
|
||||
break
|
||||
frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
|
||||
# Re-wrap as a fresh frame: dropping irregular source timestamps (VFR/AVI/GIF/...)
|
||||
# lets the encoder assign clean ones and avoids mp4 muxer errors.
|
||||
frame = av.VideoFrame.from_ndarray(frame.to_ndarray(format="yuv420p"), format="yuv420p")
|
||||
for packet in video_stream.encode(frame):
|
||||
output_container.mux(packet)
|
||||
encoded += 1
|
||||
for packet in video_stream.encode():
|
||||
output_container.mux(packet)
|
||||
|
||||
if encoded == 0:
|
||||
raise ValueError(
|
||||
f"resize produced no frames (start_time={start_time}, duration={duration} "
|
||||
"selected nothing from the source)"
|
||||
)
|
||||
|
||||
if audio_stream is not None:
|
||||
input_container.seek(0)
|
||||
for audio_frame in input_container.decode(audio=0):
|
||||
if trimming:
|
||||
if audio_frame.time is None or audio_frame.time < start_time:
|
||||
continue
|
||||
if duration and audio_frame.time > start_time + duration:
|
||||
break
|
||||
# Carry odd audio time bases the mp4 muxer rejects; reset pts, encoder assigns clean ones (MP3-in-AVI)
|
||||
audio_frame.pts = None
|
||||
for packet in audio_stream.encode(audio_frame):
|
||||
output_container.mux(packet)
|
||||
for packet in audio_stream.encode():
|
||||
|
||||
@ -11,7 +11,7 @@ class TextEncodeAceStepAudio(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TextEncodeAceStepAudio",
|
||||
category="conditioning",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
IO.Clip.Input("clip"),
|
||||
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
@ -33,7 +33,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode):
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TextEncodeAceStepAudio1.5",
|
||||
category="conditioning",
|
||||
category="model/conditioning",
|
||||
inputs=[
|
||||
IO.Clip.Input("clip"),
|
||||
IO.String.Input("tags", multiline=True, dynamic_prompts=True),
|
||||
@ -67,7 +67,7 @@ class EmptyAceStepLatentAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="EmptyAceStepLatentAudio",
|
||||
display_name="Empty Ace Step 1.0 Latent Audio",
|
||||
category="latent/audio",
|
||||
category="model/latent/audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1),
|
||||
IO.Int.Input(
|
||||
@ -90,7 +90,7 @@ class EmptyAceStep15LatentAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="EmptyAceStep1.5LatentAudio",
|
||||
display_name="Empty Ace Step 1.5 Latent Audio",
|
||||
category="latent/audio",
|
||||
category="model/latent/audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01),
|
||||
IO.Int.Input(
|
||||
|
||||
@ -45,7 +45,7 @@ class SamplerLCMUpscale(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SamplerLCMUpscale",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
inputs=[
|
||||
io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01, advanced=True),
|
||||
io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1, advanced=True),
|
||||
@ -91,7 +91,7 @@ class SamplerLCM(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="SamplerLCM",
|
||||
category="sampling/samplers",
|
||||
category="model/sampling/samplers",
|
||||
description=("LCM sampler with tunable per-step noise. s_noise is a multiplier on the model's training noise scale"),
|
||||
inputs=[
|
||||
io.Float.Input("s_noise", default=1.0, min=0.0, max=64.0, step=0.01,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user