diff --git a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt index 2cbb00d99..2c72c8a13 100755 --- a/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt +++ b/.ci/windows_amd_base_files/README_VERY_IMPORTANT.txt @@ -1,5 +1,4 @@ -As of the time of writing this you need this driver for best results: -https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html +As of the time of writing this you need a recent driver. Updating to the latest driver is recommended. HOW TO RUN: @@ -7,9 +6,9 @@ If you have a AMD gpu: run_amd_gpu.bat -If you have memory issues you can try disabling the smart memory management by running comfyui with: +If you have memory issues you can try enabling the new dynamic memory management by running comfyui with: -run_amd_gpu_disable_smart_memory.bat +run_amd_gpu_enable_dynamic_vram.bat IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints diff --git a/.github/workflows/check-line-endings.yml b/.github/workflows/check-line-endings.yml index eeb594d6c..a69a24a87 100644 --- a/.github/workflows/check-line-endings.yml +++ b/.github/workflows/check-line-endings.yml @@ -17,7 +17,7 @@ jobs: - name: Check for Windows line endings (CRLF) run: | # Get the list of changed files in the PR - CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }}) + CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} -- ':!.ci') # Flag to track if CRLF is found CRLF_FOUND=false diff --git a/.github/workflows/detect-unreviewed-merge.yml b/.github/workflows/detect-unreviewed-merge.yml new file mode 100644 index 000000000..4fabecb94 --- /dev/null +++ b/.github/workflows/detect-unreviewed-merge.yml @@ -0,0 +1,24 @@ +name: Detect Unreviewed Merge + +# SOC 2 compliance — reusable workflow lives in 
Comfy-Org/github-workflows, +# tracking issues are filed in Comfy-Org/unreviewed-merges. + +on: + push: + branches: [master] + +concurrency: + group: detect-unreviewed-merge-${{ github.sha }} + cancel-in-progress: false + +permissions: + contents: read + pull-requests: read + +jobs: + detect: + uses: Comfy-Org/github-workflows/.github/workflows/detect-unreviewed-merge.yml@4d9cb6b87f953bb7cd69954280e1465fb9bd2040 # v1 + with: + approval-mode: latest-per-reviewer + secrets: + UNREVIEWED_MERGES_TOKEN: ${{ secrets.UNREVIEWED_MERGES_TOKEN }} diff --git a/comfy/background_removal/birefnet.py b/comfy/background_removal/birefnet.py index df54b2b90..78a80246e 100644 --- a/comfy/background_removal/birefnet.py +++ b/comfy/background_removal/birefnet.py @@ -105,7 +105,7 @@ class WindowAttention(nn.Module): relative_position_bias = self.relative_position_bias_table[self.relative_position_index.long().view(-1)].view( self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = comfy.ops.cast_to_input(relative_position_bias.permute(2, 0, 1).contiguous(), attn) # nH, Wh*Ww, Wh*Ww attn = attn + relative_position_bias.unsqueeze(0) if mask is not None: diff --git a/comfy/bg_removal_model.py b/comfy/bg_removal_model.py index 6dec65e63..c772c5f6a 100644 --- a/comfy/bg_removal_model.py +++ b/comfy/bg_removal_model.py @@ -55,12 +55,7 @@ class BackgroundRemovalModel(): out = torch.nn.functional.interpolate(out, size=(H, W), mode="bicubic", antialias=False) mask = out.sigmoid().to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) - if mask.ndim == 3: - mask = mask.unsqueeze(0) - if mask.shape[1] != 1: - mask = mask.movedim(-1, 1) - - return mask + return mask.squeeze(1) # (B, 1, H, W) -> (B, H, W) def load_background_removal_model(sd): diff --git a/comfy/cli_args.py 
b/comfy/cli_args.py index 9bda414d1..a4cabcc65 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -149,6 +149,7 @@ parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=Non parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.") +parser.add_argument("--fast-disk", action="store_true", help="Prefer disk-backed dynamic loading and offload over unpinned RAM. Can be faster for users with fast NVME disks.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 1691fca81..ce8924a11 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -9,6 +9,7 @@ import comfy.model_management import comfy.utils import comfy.clip_model import comfy.image_encoders.dino2 +import comfy.image_encoders.dino3 class Output: def __getitem__(self, key): @@ -23,12 +24,16 @@ IMAGE_ENCODERS = { "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, + "dinov3": comfy.image_encoders.dino3.DINOv3ViTModel, } class ClipVisionModel(): def __init__(self, json_config): - with open(json_config) as f: - config = json.load(f) + if isinstance(json_config, dict): + config = json_config + else: + with open(json_config) as f: + config = json.load(f) self.image_size = config.get("image_size", 224) self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073]) @@ 
-134,6 +139,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") elif 'encoder.layer.23.layer_scale2.lambda1' in sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") + elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers) + json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG else: return None diff --git a/comfy/float.py b/comfy/float.py index 184b3d6d0..3c82d6359 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -1,5 +1,20 @@ +import logging + import torch +_CK_STOCHASTIC_ROUNDING_AVAILABLE = False +try: + import comfy_kitchen as ck + _ck_stochastic_rounding_fp8 = ck.stochastic_rounding_fp8 + _CK_STOCHASTIC_ROUNDING_AVAILABLE = True +except (AttributeError, ImportError): + logging.warning("comfy_kitchen does not support stochastic FP8 rounding, please update comfy_kitchen.") + +if not _CK_STOCHASTIC_ROUNDING_AVAILABLE: + def _ck_stochastic_rounding_fp8(value, rng, dtype): + raise NotImplementedError("comfy_kitchen does not support stochastic FP8 rounding") + + def calc_mantissa(abs_x, exponent, normal_mask, MANTISSA_BITS, EXPONENT_BIAS, generator=None): mantissa_scaled = torch.where( normal_mask, @@ -57,6 +72,10 @@ def stochastic_rounding(value, dtype, seed=0): if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: generator = torch.Generator(device=value.device) generator.manual_seed(seed) + if _CK_STOCHASTIC_ROUNDING_AVAILABLE: + rng = torch.randint(0, 256, value.size(), dtype=torch.uint8, layout=value.layout, device=value.device, generator=generator) + return _ck_stochastic_rounding_fp8(value, rng, dtype) + output = torch.empty_like(value, dtype=dtype) num_slices = max(1, (value.numel() / (4096 * 4096))) slice_size = max(1, round(value.shape[0] / num_slices)) 
diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py new file mode 100644 index 000000000..ad29b06f8 --- /dev/null +++ b/comfy/image_encoders/dino3.py @@ -0,0 +1,259 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.ops +from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale + + +# DINOv3 ViT-H/16+ (SwiGLU) +DINOV3_VITH_CONFIG = { + "model_type": "dinov3", + "num_hidden_layers": 32, + "hidden_size": 1280, + "num_attention_heads": 20, + "num_register_tokens": 4, + "intermediate_size": 5120, + "layer_norm_eps": 1e-5, + "num_channels": 3, + "patch_size": 16, + "rope_theta": 100.0, + "use_gated_mlp": True, + "gated_mlp_act": "silu", + "image_size": 1024, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225], +} + + +class DINOv3ViTMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.GELU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, **kwargs): + num_tokens = q.shape[-2] + num_patches = sin.shape[-2] + num_prefix_tokens = num_tokens - num_patches + + q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2) + k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2) + + q_patches = (q_patches * cos) + (rotate_half(q_patches) 
* sin) + k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin) + + q = torch.cat((q_prefix_tokens, q_patches), dim=-2) + k = torch.cat((k_prefix_tokens, k_patches), dim=-2) + + return q, k + + +class DINOv3ViTAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, device, dtype, operations): + super().__init__() + self.embed_dim = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False + self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs): + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attn = optimized_attention_for_device(query_states.device, mask=False) + attn_output = attn( + query_states, key_states, value_states, self.num_heads, attention_mask, + skip_reshape=True, skip_output_reshape=True, low_precision_attention=False, + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, 
patches, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output + + +class DINOv3ViTGatedMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.SiLU() if act == "silu" else torch.nn.GELU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device): + coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device) + coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device) + coords_h = coords_h / num_patches_h + coords_w = coords_w / num_patches_w + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) + coords = coords.flatten(0, 1) + coords = 2.0 * coords - 1.0 + return coords + + +class DINOv3ViTRopePositionEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, rope_theta, hidden_size, num_attention_heads, patch_size, device, dtype): + super().__init__() + self.base = rope_theta + self.head_dim = hidden_size // num_attention_heads + self.patch_size = patch_size + + inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, pixel_values): + _, _, height, width = pixel_values.shape + num_patches_h = height // self.patch_size + num_patches_w = width // self.patch_size + + patch_coords = 
get_patches_center_coordinates(num_patches_h, num_patches_w, dtype=torch.float32, device=pixel_values.device) + self.inv_freq = self.inv_freq.to(pixel_values.device) + angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] + angles = angles.flatten(1, 2) + angles = angles.tile(2) + cos = torch.cos(angles).to(dtype=pixel_values.dtype) + sin = torch.sin(angles).to(dtype=pixel_values.dtype) + return cos, sin + + +class DINOv3ViTEmbeddings(nn.Module): + def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations): + super().__init__() + self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) + self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype)) + self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype)) + self.patch_embeddings = operations.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype + ) + + def forward(self, pixel_values, bool_masked_pos=None): + batch_size = pixel_values.shape[0] + + patch_embeddings = self.patch_embeddings(pixel_values) + patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) + + if bool_masked_pos is not None: + mask_token = comfy.ops.cast_to_input(self.mask_token, patch_embeddings) + patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) + + cls_token = comfy.ops.cast_to_input(self.cls_token.expand(batch_size, -1, -1), patch_embeddings) + register_tokens = comfy.ops.cast_to_input(self.register_tokens.expand(batch_size, -1, -1), patch_embeddings) + embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) + return embeddings + + +class DINOv3ViTLayer(nn.Module): + def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size, + num_attention_heads, device, dtype, operations, gated_mlp_act="silu"): + 
super().__init__() + self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) + self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) + self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) + + self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) + if use_gated_mlp: + self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations, act=gated_mlp_act) + else: + self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations) + self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) + + def forward(self, hidden_states, attention_mask=None, position_embeddings=None): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.attention(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings) + hidden_states = self.layer_scale1(hidden_states) + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.layer_scale2(hidden_states) + hidden_states = hidden_states + residual + return hidden_states + + +class DINOv3ViTModel(nn.Module): + def __init__(self, config, dtype, device, operations): + super().__init__() + num_hidden_layers = config["num_hidden_layers"] + hidden_size = config["hidden_size"] + num_attention_heads = config["num_attention_heads"] + num_register_tokens = config["num_register_tokens"] + intermediate_size = config["intermediate_size"] + layer_norm_eps = config["layer_norm_eps"] + num_channels = config["num_channels"] + patch_size = config["patch_size"] + rope_theta = config["rope_theta"] + use_gated_mlp = 
config.get("use_gated_mlp", False) + gated_mlp_act = config.get("gated_mlp_act", "silu") + + self.embeddings = DINOv3ViTEmbeddings( + hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size, + dtype=dtype, device=device, operations=operations + ) + self.rope_embeddings = DINOv3ViTRopePositionEmbedding( + rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device + ) + self.layer = nn.ModuleList([ + DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, mlp_bias=True, + intermediate_size=intermediate_size, num_attention_heads=num_attention_heads, + dtype=dtype, device=device, operations=operations, gated_mlp_act=gated_mlp_act) + for _ in range(num_hidden_layers)]) + self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward(self, pixel_values, bool_masked_pos=None, **kwargs): + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + position_embeddings = self.rope_embeddings(pixel_values) + + for layer_module in self.layer: + hidden_states = layer_module(hidden_states, position_embeddings=position_embeddings) + + if kwargs.get("skip_norm_elementwise", False): + sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:]) + else: + norm = self.norm.to(hidden_states.device) + sequence_output = norm(hidden_states) + pooled_output = sequence_output[:, 0, :] + return sequence_output, None, pooled_output, None diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 12a934d71..fcbd97c59 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -4,6 +4,7 @@ class LatentFormat: scale_factor = 1.0 latent_channels = 4 latent_dimensions = 2 + preserve_empty_channel_multiples = False latent_rgb_factors = None latent_rgb_factors_bias = None latent_rgb_factors_reshape = None @@ -239,6 +240,16 @@ class 
Flux2(LatentFormat): def process_out(self, latent): return latent +class TripoSplat(LatentFormat): + # Sequence latent (B, 8192, 16) the camera token rides alongside as a second nested latent + latent_channels = 16 + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent + class Mochi(LatentFormat): latent_channels = 12 latent_dimensions = 3 @@ -769,6 +780,10 @@ class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 +class SeedVR2(LatentFormat): + latent_channels = 16 + preserve_empty_channel_multiples = True + class ACEAudio15(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/audio/dit.py b/comfy/ldm/audio/dit.py index a6258b755..c28be5b49 100644 --- a/comfy/ldm/audio/dit.py +++ b/comfy/ldm/audio/dit.py @@ -433,11 +433,11 @@ class Attention(nn.Module): if self.differential: q, q_diff = q.unbind(dim=1) k, k_diff = k.unbind(dim=1) - out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) - out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, transformer_options=transformer_options) + out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options) + out_diff = optimized_attention(q_diff, k_diff, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options) out = out - out_diff else: - out = optimized_attention(q, k, v, h, skip_reshape=True, transformer_options=transformer_options) + out = optimized_attention(q, k, v, h, skip_reshape=True, low_precision_attention=False, transformer_options=transformer_options) out = self.to_out(out) diff --git a/comfy/ldm/audio/vae_sa3.py b/comfy/ldm/audio/vae_sa3.py index 276846444..8be36d6ee 100644 --- a/comfy/ldm/audio/vae_sa3.py +++ b/comfy/ldm/audio/vae_sa3.py @@ -138,11 +138,11 @@ class Attention(nn.Module): k_diff = _apply_rotary_pos_emb(k_diff.float(), freqs).to(k_dtype) if 
self.differential: - out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) - - optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True)) + out = (optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False) + - optimized_attention(q_diff, k_diff, v, h, mask=mask, skip_reshape=True, low_precision_attention=False)) del q, k, v, q_diff, k_diff else: - out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True) + out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True, low_precision_attention=False) del q, k, v return self.to_out(out) diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 4fb56165e..86af98d36 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -38,6 +38,8 @@ class ChromaRadianceParams(ChromaParams): # None means use the same dtype as the model. nerf_embedder_dtype: Optional[torch.dtype] use_x0: bool + # Use sequential txt_ids instead of zeros + use_sequential_txt_ids: bool class ChromaRadiance(Chroma): """ @@ -162,6 +164,9 @@ class ChromaRadiance(Chroma): if params.use_x0: self.register_buffer("__x0__", torch.tensor([])) + if params.use_sequential_txt_ids: + self.register_buffer("__sequential__", torch.tensor([])) + @property def _nerf_final_layer(self) -> nn.Module: if self.params.nerf_final_head_type == "linear": @@ -313,6 +318,9 @@ class ChromaRadiance(Chroma): img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + # Radiance after 2026-05-22 uses sequential txt_ids instead of zeros + if params.use_sequential_txt_ids: + txt_ids[:, :, 0] = torch.arange(context.shape[1], device=x.device, dtype=x.dtype).unsqueeze(0).expand(bs, -1) img_out = self.forward_orig( img, diff --git 
a/comfy/ldm/cosmos/predict2.py b/comfy/ldm/cosmos/predict2.py index 2268bff38..671fe834d 100644 --- a/comfy/ldm/cosmos/predict2.py +++ b/comfy/ldm/cosmos/predict2.py @@ -14,15 +14,7 @@ from torchvision import transforms import comfy.patcher_extension from comfy.ldm.modules.attention import optimized_attention import comfy.ldm.common_dit - -def apply_rotary_pos_emb( - t: torch.Tensor, - freqs: torch.Tensor, -) -> torch.Tensor: - t_ = t.reshape(*t.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2).float() - t_out = freqs[..., 0] * t_[..., 0] + freqs[..., 1] * t_[..., 1] - t_out = t_out.movedim(-1, -2).reshape(*t.shape).type_as(t) - return t_out +import comfy.quant_ops # ---------------------- Feed Forward Network ----------------------- @@ -173,8 +165,7 @@ class Attention(nn.Module): k = self.k_norm(k) v = self.v_norm(v) if self.is_selfattn and rope_emb is not None: # only apply to self-attention! - q = apply_rotary_pos_emb(q, rope_emb) - k = apply_rotary_pos_emb(k, rope_emb) + q, k = comfy.quant_ops.ck.apply_rope_split_half(q, k, rope_emb) return q, k, v q, k, v = apply_norm_and_rotary_pos_emb(q, k, v, rope_emb) diff --git a/comfy/ldm/ernie/model.py b/comfy/ldm/ernie/model.py index eba661aec..f158ca1d2 100644 --- a/comfy/ldm/ernie/model.py +++ b/comfy/ldm/ernie/model.py @@ -5,6 +5,7 @@ import torch.nn.functional as F from comfy.ldm.modules.attention import optimized_attention import comfy.model_management +import comfy.quant_ops def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: assert dim % 2 == 0 @@ -19,15 +20,6 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor: out = torch.stack([torch.cos(out), torch.sin(out)], dim=0) return out.to(dtype=torch.float32, device=pos.device) -def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - rot_dim = freqs_cis.shape[-1] - x, x_pass = x_in[..., :rot_dim], x_in[..., rot_dim:] - cos_ = freqs_cis[0] - sin_ = freqs_cis[1] - x1, x2 = x.chunk(2, dim=-1) - x_rotated = 
torch.cat((-x2, x1), dim=-1) - return torch.cat((x * cos_ + x_rotated * sin_, x_pass), dim=-1) - class ErnieImageEmbedND3(nn.Module): def __init__(self, dim: int, theta: int, axes_dim: tuple): super().__init__() @@ -37,8 +29,16 @@ class ErnieImageEmbedND3(nn.Module): def forward(self, ids: torch.Tensor) -> torch.Tensor: emb = torch.cat([rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(3)], dim=-1) - emb = emb.unsqueeze(3) # [2, B, S, 1, head_dim//2] - return torch.stack([emb, emb], dim=-1).reshape(*emb.shape[:-1], -1) # [B, S, 1, head_dim] + cos_ = emb[0] + sin_ = emb[1] + N = cos_.shape[-1] + half = N // 2 + cos_top = cos_[..., :half].repeat_interleave(2, dim=-1) + sin_top = sin_[..., :half].repeat_interleave(2, dim=-1) + cos_bot = cos_[..., half:].repeat_interleave(2, dim=-1) + sin_bot = sin_[..., half:].repeat_interleave(2, dim=-1) + rot = torch.stack([cos_top, -sin_top, sin_bot, cos_bot], dim=-1) + return rot.reshape(*rot.shape[:-1], 2, 2).unsqueeze(2) class ErnieImagePatchEmbedDynamic(nn.Module): def __init__(self, in_channels: int, embed_dim: int, patch_size: int, operations, device=None, dtype=None): @@ -115,8 +115,7 @@ class ErnieImageAttention(nn.Module): key = self.norm_k(key) if image_rotary_emb is not None: - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) + query, key = comfy.quant_ops.ck.apply_rope_split_half(query, key, image_rotary_emb) q_flat = query.reshape(B, S, -1) k_flat = key.reshape(B, S, -1) @@ -274,7 +273,7 @@ class ErnieImageModel(nn.Module): image_ids = image_ids.view(1, N_img, 3).expand(B, -1, -1) - rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)).to(x.dtype) + rotary_pos_emb = self.pos_embed(torch.cat([image_ids, text_ids], dim=1)) del image_ids, text_ids sample = self.time_proj(timesteps).to(dtype) diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 6d0aed827..891dea7dd 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py 
@@ -4,7 +4,7 @@ from torch import Tensor from comfy.ldm.modules.attention import optimized_attention import comfy.model_management -import logging +import comfy.quant_ops def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: @@ -44,21 +44,15 @@ def _apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor): return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) -try: - import comfy.quant_ops - q_apply_rope = comfy.quant_ops.ck.apply_rope - q_apply_rope1 = comfy.quant_ops.ck.apply_rope1 - def apply_rope(xq, xk, freqs_cis): - if comfy.model_management.in_training: - return _apply_rope(xq, xk, freqs_cis) - else: - return apply_rope1(xq, freqs_cis), apply_rope1(xk, freqs_cis) - def apply_rope1(x, freqs_cis): - if comfy.model_management.in_training: - return _apply_rope1(x, freqs_cis) - else: - return q_apply_rope1(x, freqs_cis) -except: - logging.warning("No comfy kitchen, using old apply_rope functions.") - apply_rope = _apply_rope - apply_rope1 = _apply_rope1 +def apply_rope(xq, xk, freqs_cis): + if comfy.model_management.in_training: + return _apply_rope(xq, xk, freqs_cis) + else: + return comfy.quant_ops.ck.apply_rope(xq, xk, freqs_cis) + + +def apply_rope1(x, freqs_cis): + if comfy.model_management.in_training: + return _apply_rope1(x, freqs_cis) + else: + return comfy.quant_ops.ck.apply_rope1(x, freqs_cis) diff --git a/comfy/ldm/ideogram4/model.py b/comfy/ldm/ideogram4/model.py new file mode 100644 index 000000000..b86c65bf0 --- /dev/null +++ b/comfy/ldm/ideogram4/model.py @@ -0,0 +1,297 @@ +""" +The Ideogram 4 transformer is a NextDiT/Lumina2-family single-stream model that +consumes Qwen3-VL hidden-state features (concatenated from 13 layers -> 53248 dims) and +packs ``[text tokens, image tokens]`` into one sequence with block-diagonal segment attention and 3D interleaved MRoPE.
+""" + +from __future__ import annotations + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.patcher_extension +from comfy.ldm.lumina.model import FeedForward +from comfy.ldm.modules.attention import optimized_attention_masked +from comfy.text_encoders.llama import apply_rope, precompute_freqs_cis + +# Per-token role indicators +SEQUENCE_PADDING_INDICATOR = -1 +OUTPUT_IMAGE_INDICATOR = 2 +LLM_TOKEN_INDICATOR = 3 +# Image grid coordinates are offset so they never collide with text positions +IMAGE_POSITION_OFFSET = 65536 + + +class Ideogram4Attention(nn.Module): + def __init__(self, hidden_size, num_heads, eps=1e-5, dtype=None, device=None, operations=None): + super().__init__() + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + self.hidden_size = hidden_size + + self.qkv = operations.Linear(hidden_size, hidden_size * 3, bias=False, dtype=dtype, device=device) + self.norm_q = operations.RMSNorm(self.head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device) + self.norm_k = operations.RMSNorm(self.head_dim, eps=eps, elementwise_affine=True, dtype=dtype, device=device) + self.o = operations.Linear(hidden_size, hidden_size, bias=False, dtype=dtype, device=device) + + def forward(self, x, attn_mask, freqs_cis, transformer_options={}): + batch_size, seq_len, _ = x.shape + qkv = self.qkv(x).view(batch_size, seq_len, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(dim=2) + + q = self.norm_q(q) + k = self.norm_k(k) + + # (B, heads, L, head_dim) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + q, k = apply_rope(q, k, freqs_cis) + + out = optimized_attention_masked(q, k, v, self.num_heads, attn_mask, skip_reshape=True, transformer_options=transformer_options) + return self.o(out) + + +class Ideogram4TransformerBlock(nn.Module): + def __init__(self, hidden_size, intermediate_size, num_heads, norm_eps, adaln_dim, dtype=None, device=None, operations=None): 
+ super().__init__() + self.attention = Ideogram4Attention(hidden_size, num_heads, eps=1e-5, dtype=dtype, device=device, operations=operations) + self.feed_forward = FeedForward( + dim=hidden_size, hidden_dim=intermediate_size, multiple_of=1, ffn_dim_multiplier=None, + operation_settings={"operations": operations, "dtype": dtype, "device": device}, + ) + + self.attention_norm1 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device) + self.ffn_norm1 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device) + self.attention_norm2 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device) + self.ffn_norm2 = operations.RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=True, dtype=dtype, device=device) + + self.adaln_modulation = operations.Linear(adaln_dim, 4 * hidden_size, bias=True, dtype=dtype, device=device) + + def forward(self, x, attn_mask, freqs_cis, adaln_input, transformer_options={}): + mod = self.adaln_modulation(adaln_input) + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.chunk(4, dim=-1) + gate_msa = torch.tanh(gate_msa) + gate_mlp = torch.tanh(gate_mlp) + scale_msa = 1.0 + scale_msa + scale_mlp = 1.0 + scale_mlp + + attn_out = self.attention(self.attention_norm1(x) * scale_msa, attn_mask, freqs_cis, transformer_options=transformer_options) + x = x + gate_msa * self.attention_norm2(attn_out) + x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp)) + return x + + +def _sinusoidal_embedding(t, dim, scale=1e4): + t = t.to(torch.float32) + half = dim // 2 + freq = math.log(scale) / (half - 1) + freq = torch.exp(torch.arange(half, dtype=torch.float32, device=t.device) * -freq) + emb = t.unsqueeze(-1) * freq + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + if dim % 2 == 1: + emb = F.pad(emb, (0, 1)) + return emb + + +class Ideogram4EmbedScalar(nn.Module): + def __init__(self, dim, 
input_range=(0.0, 1.0), dtype=None, device=None, operations=None): + super().__init__() + self.dim = dim + self.range_min, self.range_max = input_range + self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device) + self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device) + + def forward(self, x): + x = x.to(torch.float32) + scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min) + emb = _sinusoidal_embedding(scaled, self.dim) + emb = emb.to(self.mlp_in.weight.dtype) + emb = F.silu(self.mlp_in(emb)) + return self.mlp_out(emb) + + +class Ideogram4FinalLayer(nn.Module): + def __init__(self, hidden_size, out_channels, adaln_dim, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device) + self.adaln_modulation = operations.Linear(adaln_dim, hidden_size, bias=True, dtype=dtype, device=device) + + def forward(self, x, c): + scale = 1.0 + self.adaln_modulation(F.silu(c)) + return self.linear(self.norm_final(x) * scale) + + +class Ideogram4Transformer(nn.Module): + """A single Ideogram 4 backbone operating on a packed token sequence.""" + + def __init__(self, emb_dim, num_layers, num_heads, intermediate_size, adaln_dim, + in_channels, llm_features_dim, rope_theta, mrope_section, norm_eps, + dtype=None, device=None, operations=None): + super().__init__() + self.head_dim = emb_dim // num_heads + self.rope_theta = rope_theta + self.mrope_section = tuple(mrope_section) + + self.input_proj = operations.Linear(in_channels, emb_dim, bias=True, dtype=dtype, device=device) + self.llm_cond_norm = operations.RMSNorm(llm_features_dim, eps=1e-6, elementwise_affine=True, dtype=dtype, device=device) + self.llm_cond_proj = operations.Linear(llm_features_dim, emb_dim, bias=True, dtype=dtype, device=device) + 
self.t_embedding = Ideogram4EmbedScalar(emb_dim, input_range=(0.0, 1.0), dtype=dtype, device=device, operations=operations) + self.adaln_proj = operations.Linear(emb_dim, adaln_dim, bias=True, dtype=dtype, device=device) + + self.embed_image_indicator = operations.Embedding(2, emb_dim, dtype=dtype, device=device) + + self.layers = nn.ModuleList([ + Ideogram4TransformerBlock(emb_dim, intermediate_size, num_heads, norm_eps, adaln_dim, + dtype=dtype, device=device, operations=operations) + for _ in range(num_layers) + ]) + + self.final_layer = Ideogram4FinalLayer(emb_dim, in_channels, adaln_dim, dtype=dtype, device=device, operations=operations) + + def _backbone(self, llm_features, x, t, position_ids, attn_mask, indicator, transformer_options={}): + indicator = indicator.to(torch.long) + output_image_mask = (indicator == OUTPUT_IMAGE_INDICATOR).to(x.dtype).unsqueeze(-1) + + x = x * output_image_mask + h = self.input_proj(x) * output_image_mask + + t_cond = self.t_embedding(t) + if t.dim() == 1: + t_cond = t_cond.unsqueeze(1) + adaln_input = F.silu(self.adaln_proj(t_cond)) + + # h is zero on the text rows (content lives only on image rows), add writes the text features in place + if llm_features is not None: + L_text = llm_features.shape[1] + text_mask = (indicator[:, :L_text] == LLM_TOKEN_INDICATOR).to(x.dtype).unsqueeze(-1) + llm = self.llm_cond_norm(llm_features * text_mask) + llm = self.llm_cond_proj(llm) * text_mask + h[:, :L_text] = h[:, :L_text] + llm + + h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long), out_dtype=h.dtype) + + # Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch). 
+ freqs_cis = precompute_freqs_cis( + self.head_dim, position_ids[0].transpose(0, 1), self.rope_theta, + rope_dims=self.mrope_section, interleaved_mrope=True, device=position_ids.device, + ) + + if attn_mask is not None and attn_mask.dtype == torch.bool: + attn_mask = torch.zeros_like(attn_mask, dtype=h.dtype).masked_fill_(~attn_mask, -torch.finfo(h.dtype).max) + + for layer in self.layers: + h = layer(h, attn_mask, freqs_cis, adaln_input, transformer_options=transformer_options) + + return self.final_layer(h, adaln_input) + + +class Ideogram4Transformer2DModel(Ideogram4Transformer): + """Ideogram 4 single-stream DiT. + + Runs a packed ``[text, image]`` sequence when text context is supplied, or an image-only sequence when ``context is None``. + """ + + def __init__(self, image_model=None, in_channels=128, num_layers=34, num_attention_heads=18, attention_head_dim=256, intermediate_size=12288, + adaln_dim=512, llm_features_dim=53248, rope_theta=5000000, mrope_section=(24, 20, 20), norm_eps=1e-5, + dtype=None, device=None, operations=None, **kwargs): + emb_dim = num_attention_heads * attention_head_dim + super().__init__( + emb_dim=emb_dim, num_layers=num_layers, num_heads=num_attention_heads, + intermediate_size=intermediate_size, adaln_dim=adaln_dim, in_channels=in_channels, + llm_features_dim=llm_features_dim, rope_theta=rope_theta, mrope_section=mrope_section, + norm_eps=norm_eps, dtype=dtype, device=device, operations=operations) + self.dtype = dtype + self.in_channels = in_channels + self.out_channels = in_channels + # 128-dim token = patch (2x2) * ae_channels (32). 
+ self.patch_size = 2 + self.ae_channels = in_channels // (self.patch_size * self.patch_size) + + def _img_to_tokens(self, x): + B, C, gh, gw = x.shape + x = x.view(B, self.ae_channels, self.patch_size, self.patch_size, gh, gw) + x = x.permute(0, 4, 5, 2, 3, 1) # (B, gh, gw, pi, pj, c) + return x.reshape(B, gh * gw, C) + + def _tokens_to_img(self, tokens, gh, gw): + B = tokens.shape[0] + C = tokens.shape[-1] + x = tokens.reshape(B, gh, gw, self.patch_size, self.patch_size, self.ae_channels) + x = x.permute(0, 5, 3, 4, 1, 2) # (B, c, pi, pj, gh, gw) + return x.reshape(B, C, gh, gw) + + def _image_position_ids(self, gh, gw, device): + h_idx = torch.arange(gh, device=device).view(-1, 1).expand(gh, gw).reshape(-1) + w_idx = torch.arange(gw, device=device).view(1, -1).expand(gh, gw).reshape(-1) + t_idx = torch.zeros_like(h_idx) + return torch.stack([t_idx, h_idx, w_idx], dim=1) + IMAGE_POSITION_OFFSET # (L_img, 3) + + def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options): + B = x_chunk.shape[0] + device = x_chunk.device + img_tokens = self._img_to_tokens(x_chunk) + L_img = img_tokens.shape[1] + L_text = context_chunk.shape[1] + L = L_text + L_img + latent_dim = img_tokens.shape[-1] + + x_full = torch.zeros(B, L, latent_dim, dtype=img_tokens.dtype, device=device) + x_full[:, L_text:] = img_tokens + + text_pos = torch.arange(L_text, device=device).view(-1, 1).expand(L_text, 3) + img_pos = self._image_position_ids(gh, gw, device) + position_ids = torch.cat([text_pos, img_pos], dim=0).unsqueeze(0).expand(B, L, 3) + + indicator = torch.empty(B, L, dtype=torch.long, device=device) + indicator[:, :L_text] = LLM_TOKEN_INDICATOR + indicator[:, L_text:] = OUTPUT_IMAGE_INDICATOR + + attn_mask = None + if attn_mask_chunk is not None: + segment_ids = torch.ones(B, L, dtype=torch.long, device=device) + pad = (attn_mask_chunk == 0) + segment_ids[:, :L_text][pad] = SEQUENCE_PADDING_INDICATOR + indicator[:, :L_text][pad] = 0 + # 
Block-diagonal mask from segment ids: (B, 1, L, L), True = attend. + attn_mask = (segment_ids.unsqueeze(2) == segment_ids.unsqueeze(1)).unsqueeze(1) + + out = self._backbone(context_chunk, x_full, t_chunk, position_ids, attn_mask, indicator, + transformer_options=transformer_options) + return self._tokens_to_img(out[:, L_text:], gh, gw) + + def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options): + B = x_chunk.shape[0] + device = x_chunk.device + img_tokens = self._img_to_tokens(x_chunk) + L_img = img_tokens.shape[1] + + position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3) + indicator = torch.full((B, L_img), OUTPUT_IMAGE_INDICATOR, dtype=torch.long, device=device) + + # Image-only sequence is a single segment -> no mask, full attention, no LLM context. + out = self._backbone(None, img_tokens, t_chunk, position_ids, None, indicator, transformer_options=transformer_options) + return self._tokens_to_img(out, gh, gw) + + def forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options), + ).execute(x, timesteps, context, attention_mask, transformer_options, **kwargs) + + def _forward(self, x, timesteps, context=None, attention_mask=None, transformer_options={}, **kwargs): + bs, c, gh, gw = x.shape + + timesteps = 1.0 - timesteps + + # unconditional pass + if context is None: + return -self._run_image_only(x, timesteps, gh, gw, transformer_options) + + return -self._run_conditional(x, context, attention_mask, timesteps, gh, gw, transformer_options) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index 55360535a..b78e764c7 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -735,7 +735,86 @@ def attention_flash(q, k, 
v, heads, mask=None, attn_precision=None, skip_reshape ) return out +def _var_attention_qkv(q, k, v, heads, skip_reshape): + if skip_reshape: + return q, k, v, q.shape[-1] + total_tokens, embed_dim = q.shape + head_dim = embed_dim // heads + return ( + q.view(total_tokens, heads, head_dim), + k.view(k.shape[0], heads, head_dim), + v.view(v.shape[0], heads, head_dim), + head_dim, + ) + +def _var_attention_output(out, heads, head_dim, skip_output_reshape): + if skip_output_reshape: + return out + return out.reshape(-1, heads * head_dim) + + +def _use_blackwell_attention(): + device = model_management.get_torch_device() + if device.type != "cuda": + return False + major, minor = torch.cuda.get_device_capability(device) + return (major, minor) >= (12, 0) + + +def _validate_split_cu_seqlens(name, cu_seqlens, token_count): + if cu_seqlens.dtype not in (torch.int32, torch.int64): + raise ValueError(f"{name} must use an integer dtype") + if cu_seqlens.ndim != 1 or cu_seqlens.numel() < 2: + raise ValueError(f"{name} must be a 1D tensor with at least two offsets") + if cu_seqlens[0].item() != 0: + raise ValueError(f"{name} must start at 0") + if (cu_seqlens[1:] <= cu_seqlens[:-1]).any().item(): + raise ValueError(f"{name} must be strictly increasing") + if cu_seqlens[-1].item() != token_count: + raise ValueError(f"{name} does not match token count") + + +def _split_indices(cu_seqlens): + return cu_seqlens[1:-1].to(device="cpu", dtype=torch.long) + + +def var_attention_optimized_split(q, k, v, heads, cu_seqlens_q, cu_seqlens_k, *args, skip_reshape=False, skip_output_reshape=False, **kwargs): + q, k, v, head_dim = _var_attention_qkv(q, k, v, heads, skip_reshape) + + _validate_split_cu_seqlens("cu_seqlens_q", cu_seqlens_q, q.shape[0]) + _validate_split_cu_seqlens("cu_seqlens_k", cu_seqlens_k, k.shape[0]) + if cu_seqlens_k[-1].item() != v.shape[0]: + raise ValueError("cu_seqlens_k does not match v token count") + + q_split_indices = _split_indices(cu_seqlens_q) + k_split_indices 
= _split_indices(cu_seqlens_k) + q_splits = torch.tensor_split(q, q_split_indices, dim=0) + k_splits = torch.tensor_split(k, k_split_indices, dim=0) + v_splits = torch.tensor_split(v, k_split_indices, dim=0) + if len(q_splits) != len(k_splits) or len(q_splits) != len(v_splits): + raise ValueError("cu_seqlens_q and cu_seqlens_k must describe the same sequence count") + + out = [] + for q_i, k_i, v_i in zip(q_splits, k_splits, v_splits): + q_i = q_i.permute(1, 0, 2).unsqueeze(0) + k_i = k_i.permute(1, 0, 2).unsqueeze(0) + v_i = v_i.permute(1, 0, 2).unsqueeze(0) + out_dtype = q_i.dtype + if optimized_attention is attention_sage and q_i.dtype not in (torch.float16, torch.bfloat16): + q_i = q_i.to(torch.bfloat16) + k_i = k_i.to(torch.bfloat16) + v_i = v_i.to(torch.bfloat16) + out_i = optimized_attention(q_i, k_i, v_i, heads, skip_reshape=True, skip_output_reshape=True) + if out_i.dtype != out_dtype: + out_i = out_i.to(out_dtype) + out.append(out_i.squeeze(0).permute(1, 0, 2)) + + out = torch.cat(out, dim=0) + return _var_attention_output(out, heads, head_dim, skip_output_reshape) + + +optimized_var_attention = var_attention_optimized_split optimized_attention = attention_basic if model_management.sage_attention_enabled(): @@ -758,6 +837,8 @@ else: logging.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention") optimized_attention = attention_sub_quad +logging.info("Using optimized_attention split-loop for variable-length attention") + optimized_attention_masked = optimized_attention @@ -773,6 +854,7 @@ if model_management.xformers_enabled(): register_attention_function("pytorch", attention_pytorch) register_attention_function("sub_quad", attention_sub_quad) register_attention_function("split", attention_split) +register_attention_function("var_attention_optimized_split", var_attention_optimized_split) def optimized_attention_for_device(device, mask=False, small_input=False): @@ -1209,5 +1291,3 @@ 
class SpatialVideoTransformer(SpatialTransformer): x = self.proj_out(x) out = x + x_in return out - - diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index fcbaa074f..235df0b83 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -13,6 +13,7 @@ if model_management.xformers_enabled_vae(): import xformers import xformers.ops + def torch_cat_if_needed(xl, dim): xl = [x for x in xl if x is not None and x.shape[dim] > 0] if len(xl) > 1: @@ -22,7 +23,8 @@ def torch_cat_if_needed(xl, dim): else: return None -def get_timestep_embedding(timesteps, embedding_dim): + +def get_timestep_embedding(timesteps, embedding_dim, flip_sin_to_cos=False, downscale_freq_shift=1): """ This matches the implementation in Denoising Diffusion Probabilistic Models: From Fairseq. @@ -33,11 +35,13 @@ def get_timestep_embedding(timesteps, embedding_dim): assert len(timesteps.shape) == 1 half_dim = embedding_dim // 2 - emb = math.log(10000) / (half_dim - 1) + emb = math.log(10000) / (half_dim - downscale_freq_shift) emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) emb = emb.to(device=timesteps.device) emb = timesteps.float()[:, None] * emb[None, :] emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) if embedding_dim % 2 == 1: # zero pad emb = torch.nn.functional.pad(emb, (0,1,0,0)) return emb diff --git a/comfy/ldm/pixeldit/pid.py b/comfy/ldm/pixeldit/pid.py index 0ad4b7ce8..21b73907a 100644 --- a/comfy/ldm/pixeldit/pid.py +++ b/comfy/ldm/pixeldit/pid.py @@ -207,8 +207,9 @@ class PidNet(PixDiT_T2I): f"Flux1/SD3 = 16 channels, Flux2 = 128 channels." ) B = x.shape[0] - Hs = x.shape[2] // self.patch_size - Ws = x.shape[3] // self.patch_size + # Match the backbone's pad_to_patch_size (round up) so the LQ grid lines up with the patch stream. 
+ Hs = -(-x.shape[2] // self.patch_size) + Ws = -(-x.shape[3] // self.patch_size) degrade_sigma = degrade_sigma.to(device=x.device, dtype=torch.float32).reshape(-1) if degrade_sigma.numel() == 1 and B > 1: diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 0862f72f7..3462d8108 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -51,15 +51,6 @@ class FeedForward(nn.Module): return hidden_states -def apply_rotary_emb(x, freqs_cis): - if x.shape[1] == 0: - return x - - t_ = x.reshape(*x.shape[:-1], -1, 1, 2) - t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] - return t_out.reshape(*x.shape) - - class QwenTimestepProjEmbeddings(nn.Module): def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None): super().__init__() diff --git a/comfy/ldm/seedvr/color_fix.py b/comfy/ldm/seedvr/color_fix.py new file mode 100644 index 000000000..7ddfc03af --- /dev/null +++ b/comfy/ldm/seedvr/color_fix.py @@ -0,0 +1,340 @@ +import torch +import torch.nn.functional as F +from torch import Tensor + +from comfy.ldm.seedvr.model import safe_pad_operation +from comfy.ldm.seedvr.vae import safe_interpolate_operation +from comfy.ldm.seedvr.constants import ( + CIELAB_DELTA, + CIELAB_KAPPA, + D65_WHITE_X, + D65_WHITE_Z, + WAVELET_DECOMP_LEVELS, +) + + +def wavelet_blur(image: Tensor, radius): + max_safe_radius = max(1, min(image.shape[-2:]) // 8) + if radius > max_safe_radius: + radius = max_safe_radius + + num_channels = image.shape[1] + + kernel_vals = [ + [0.0625, 0.125, 0.0625], + [0.125, 0.25, 0.125], + [0.0625, 0.125, 0.0625], + ] + kernel = torch.tensor(kernel_vals, dtype=image.dtype, device=image.device) + kernel = kernel[None, None].repeat(num_channels, 1, 1, 1) + + image = safe_pad_operation(image, (radius, radius, radius, radius), mode='replicate') + output = F.conv2d(image, kernel, groups=num_channels, dilation=radius) + + return 
output + +def wavelet_decomposition(image: Tensor, levels: int = WAVELET_DECOMP_LEVELS): + high_freq = torch.zeros_like(image) + + for i in range(levels): + radius = 2 ** i + low_freq = wavelet_blur(image, radius) + high_freq.add_(image).sub_(low_freq) + image = low_freq + + return high_freq, low_freq + +def wavelet_reconstruction(content_feat: Tensor, style_feat: Tensor) -> Tensor: + + if content_feat.shape != style_feat.shape: + # Resize style to match content spatial dimensions + if len(content_feat.shape) >= 3: + # safe_interpolate_operation handles FP16 conversion automatically + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + + # Decompose both features into frequency components + content_high_freq, content_low_freq = wavelet_decomposition(content_feat) + del content_low_freq # Free memory immediately + + style_high_freq, style_low_freq = wavelet_decomposition(style_feat) + del style_high_freq # Free memory immediately + + if content_high_freq.shape != style_low_freq.shape: + style_low_freq = safe_interpolate_operation( + style_low_freq, + size=content_high_freq.shape[-2:], + mode='bilinear', + align_corners=False + ) + + content_high_freq.add_(style_low_freq) + + return content_high_freq.clamp_(-1.0, 1.0) + +def _histogram_matching_channel(source: Tensor, reference: Tensor, device: torch.device) -> Tensor: + original_shape = source.shape + + # Flatten + source_flat = source.flatten() + reference_flat = reference.flatten() + + # Sort both arrays + source_sorted, source_indices = torch.sort(source_flat) + reference_sorted, _ = torch.sort(reference_flat) + del reference_flat + + # Quantile mapping + n_source = len(source_sorted) + n_reference = len(reference_sorted) + + if n_source == n_reference: + matched_sorted = reference_sorted + else: + # Interpolate reference to match source quantiles + source_quantiles = torch.linspace(0, 1, n_source, device=device) + ref_indices = 
(source_quantiles * (n_reference - 1)).long() + ref_indices.clamp_(0, n_reference - 1) + matched_sorted = reference_sorted[ref_indices] + del source_quantiles, ref_indices, reference_sorted + + del source_sorted, source_flat + + # Reconstruct using argsort (portable across CUDA/ROCm/MPS) + inverse_indices = torch.argsort(source_indices) + del source_indices + matched_flat = matched_sorted[inverse_indices] + del matched_sorted, inverse_indices + + return matched_flat.reshape(original_shape) + +def _lab_to_rgb_batch(lab: Tensor, device: torch.device, matrix_inv: Tensor, epsilon: float, kappa: float) -> Tensor: + """Convert batch of CIELAB images to RGB color space.""" + L, a, b = lab[:, 0], lab[:, 1], lab[:, 2] + + # LAB to XYZ + fy = (L + 16.0) / 116.0 + fx = a.div(500.0).add_(fy) + fz = fy - b / 200.0 + del L, a, b + + # XYZ transformation + x = torch.where( + fx > epsilon, + torch.pow(fx, 3.0), + fx.mul(116.0).sub_(16.0).div_(kappa) + ) + y = torch.where( + fy > epsilon, + torch.pow(fy, 3.0), + fy.mul(116.0).sub_(16.0).div_(kappa) + ) + z = torch.where( + fz > epsilon, + torch.pow(fz, 3.0), + fz.mul(116.0).sub_(16.0).div_(kappa) + ) + del fx, fy, fz + + # Apply D65 white point (in-place) + x.mul_(D65_WHITE_X) + # y *= 1.00000 # (no-op, skip) + z.mul_(D65_WHITE_Z) + + xyz = torch.stack([x, y, z], dim=1) + del x, y, z + + # Matrix multiplication: XYZ -> RGB + B, C, H, W = xyz.shape + xyz_flat = xyz.permute(0, 2, 3, 1).reshape(-1, 3) + del xyz + + # Ensure dtype consistency for matrix multiplication + xyz_flat = xyz_flat.to(dtype=matrix_inv.dtype) + rgb_linear_flat = torch.matmul(xyz_flat, matrix_inv.T) + del xyz_flat + + rgb_linear = rgb_linear_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) + del rgb_linear_flat + + # Apply inverse gamma correction (delinearize) + mask = rgb_linear > 0.0031308 + rgb = torch.where( + mask, + torch.pow(torch.clamp(rgb_linear, min=0.0), 1.0 / 2.4).mul_(1.055).sub_(0.055), + rgb_linear * 12.92 + ) + del mask, rgb_linear + + return 
torch.clamp(rgb, 0.0, 1.0) + +def _rgb_to_lab_batch(rgb: Tensor, device: torch.device, matrix: Tensor, epsilon: float, kappa: float) -> Tensor: + """Convert batch of RGB images to CIELAB color space using D65 illuminant.""" + # Apply sRGB gamma correction (linearize) + mask = rgb > 0.04045 + rgb_linear = torch.where( + mask, + torch.pow((rgb + 0.055) / 1.055, 2.4), + rgb / 12.92 + ) + del mask + + # Matrix multiplication: RGB -> XYZ + B, C, H, W = rgb_linear.shape + rgb_flat = rgb_linear.permute(0, 2, 3, 1).reshape(-1, 3) + del rgb_linear + + # Ensure dtype consistency for matrix multiplication + rgb_flat = rgb_flat.to(dtype=matrix.dtype) + xyz_flat = torch.matmul(rgb_flat, matrix.T) + del rgb_flat + + xyz = xyz_flat.reshape(B, H, W, 3).permute(0, 3, 1, 2) + del xyz_flat + + # Normalize by D65 white point (in-place) + xyz[:, 0].div_(D65_WHITE_X) # X + # xyz[:, 1] /= 1.00000 # Y (no-op, skip) + xyz[:, 2].div_(D65_WHITE_Z) # Z + + # XYZ to LAB transformation + epsilon_cubed = epsilon ** 3 + mask = xyz > epsilon_cubed + f_xyz = torch.where( + mask, + torch.pow(xyz, 1.0 / 3.0), + xyz.mul(kappa).add_(16.0).div_(116.0) + ) + del xyz, mask + + # Extract channels and compute LAB + L = f_xyz[:, 1].mul(116.0).sub_(16.0) # Lightness [0, 100] + a = (f_xyz[:, 0] - f_xyz[:, 1]).mul_(500.0) # Green-Red [-128, 127] + b = (f_xyz[:, 1] - f_xyz[:, 2]).mul_(200.0) # Blue-Yellow [-128, 127] + del f_xyz + + return torch.stack([L, a, b], dim=1) + +def lab_color_transfer( + content_feat: Tensor, + style_feat: Tensor, + luminance_weight: float = 0.8 +) -> Tensor: + content_feat = wavelet_reconstruction(content_feat, style_feat) + + if content_feat.shape != style_feat.shape: + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False + ) + + device = content_feat.device + + def ensure_float32_precision(c): + orig_dtype = c.dtype + c = c.float() + return c, orig_dtype + content_feat, original_dtype = 
ensure_float32_precision(content_feat) + style_feat, _ = ensure_float32_precision(style_feat) + + rgb_to_xyz_matrix = torch.tensor([ + [0.4124564, 0.3575761, 0.1804375], + [0.2126729, 0.7151522, 0.0721750], + [0.0193339, 0.1191920, 0.9503041] + ], dtype=torch.float32, device=device) + + xyz_to_rgb_matrix = torch.tensor([ + [ 3.2404542, -1.5371385, -0.4985314], + [-0.9692660, 1.8760108, 0.0415560], + [ 0.0556434, -0.2040259, 1.0572252] + ], dtype=torch.float32, device=device) + + epsilon = CIELAB_DELTA + kappa = CIELAB_KAPPA + + content_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) + style_feat.add_(1.0).mul_(0.5).clamp_(0.0, 1.0) + + # Convert to LAB color space + content_lab = _rgb_to_lab_batch(content_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + del content_feat + + style_lab = _rgb_to_lab_batch(style_feat, device, rgb_to_xyz_matrix, epsilon, kappa) + del style_feat, rgb_to_xyz_matrix + + # Match chrominance channels (a*, b*) for accurate color transfer + matched_a = _histogram_matching_channel(content_lab[:, 1], style_lab[:, 1], device) + matched_b = _histogram_matching_channel(content_lab[:, 2], style_lab[:, 2], device) + + # Handle luminance with weighted blending + if luminance_weight < 1.0: + # Partially match luminance for better overall color accuracy + matched_L = _histogram_matching_channel(content_lab[:, 0], style_lab[:, 0], device) + # Blend: preserve some content L* for detail, adopt some style L* for color + result_L = content_lab[:, 0].mul(luminance_weight).add_(matched_L.mul(1.0 - luminance_weight)) + del matched_L + else: + # Fully preserve content luminance + result_L = content_lab[:, 0] + + del content_lab, style_lab + + # Reconstruct LAB with corrected channels + result_lab = torch.stack([result_L, matched_a, matched_b], dim=1) + del result_L, matched_a, matched_b + + # Convert back to RGB + result_rgb = _lab_to_rgb_batch(result_lab, device, xyz_to_rgb_matrix, epsilon, kappa) + del result_lab, xyz_to_rgb_matrix + + # Convert back to [-1, 1] 
range (in-place) + result = result_rgb.mul_(2.0).sub_(1.0) + del result_rgb + + result = result.to(original_dtype) + + return result + + +def wavelet_color_transfer(content_feat: Tensor, style_feat: Tensor) -> Tensor: + return wavelet_reconstruction(content_feat, style_feat) + + +def adain_color_transfer(content_feat: Tensor, style_feat: Tensor, eps: float = 1e-5) -> Tensor: + if content_feat.shape != style_feat.shape: + style_feat = safe_interpolate_operation( + style_feat, + size=content_feat.shape[-2:], + mode='bilinear', + align_corners=False, + ) + + original_dtype = content_feat.dtype + content_feat = content_feat.float() + style_feat = style_feat.float() + + b, c = content_feat.shape[:2] + content_flat = content_feat.reshape(b, c, -1) + style_flat = style_feat.reshape(b, c, -1) + + content_mean = content_flat.mean(dim=2).reshape(b, c, 1, 1) + content_std = (content_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) + style_mean = style_flat.mean(dim=2).reshape(b, c, 1, 1) + style_std = (style_flat.var(dim=2, correction=0) + eps).sqrt().reshape(b, c, 1, 1) + del content_flat, style_flat + + normalized = (content_feat - content_mean) / content_std + del content_mean, content_std + result = normalized * style_std + style_mean + del normalized, style_mean, style_std + + result = result.clamp_(-1.0, 1.0) + if result.dtype != original_dtype: + result = result.to(original_dtype) + return result diff --git a/comfy/ldm/seedvr/constants.py b/comfy/ldm/seedvr/constants.py new file mode 100644 index 000000000..95838d1dd --- /dev/null +++ b/comfy/ldm/seedvr/constants.py @@ -0,0 +1,79 @@ +"""Named constants for the SeedVR2 integration, grouped by provenance. + +Provenance prefixes: +- ``SEEDVR2_*`` - introduced by this integration (no external origin); rationale inline. +- ``BYTEDANCE_*`` - ported from the official ByteDance-Seed/SeedVR release; each cites + the upstream config/source path it was lifted from. 
+- unprefixed standards (``ROPE_THETA``, ``CIELAB_*``, ``D65_*``) - published literature / + ISO / CIE values; cite the standard. +""" + +# -------------------------------------------------------------------------------------- +# A. Progressive-sampler chunk-size law (SEEDVR2 - this integration's VRAM experiment) +# n_max(frames/chunk) = SEEDVR2_CHUNK_FRAMES_PER_GB * (free_GB - SEEDVR2_CHUNK_GB_MARGIN) +# rounded to the 4n+1 grid. Fit on 22 blocked-5090 cells, validated on a real RTX 4070 +# (3b and 7b). Resolution-independent (the VAE tiling sets the wall, not the DiT). +# -------------------------------------------------------------------------------------- +SEEDVR2_CHUNK_GB_MARGIN = 3 # fixed VRAM overhead before chunks scale (GiB) +SEEDVR2_CHUNK_FRAMES_PER_GB = 4 # empirical slope: pixel frames admitted per free GiB + +# -------------------------------------------------------------------------------------- +# B. Fork heuristics (SEEDVR2 - this integration) +# -------------------------------------------------------------------------------------- +SEEDVR2_7B_VID_DIM = 3072 # runtime 3b-vs-7b sentinel; tested against vid_dim. + # (3072 is ByteDance's 7b vid_dim; the sentinel use is ours.) +SEEDVR2_OOM_BACKOFF_DIVISOR = 2 # auto-chunk OOM retry: halve the chunk and retry. +SEEDVR2_DTYPE_BYTES_FLOOR = 4 # per-element byte floor for memory math (fp32 worst case). +SEEDVR2_7B_MLP_CHUNK = 8192 # 7b MLP token-chunk to bound peak VRAM. +SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS = 4096 # partial-RoPE application token-chunk. +SEEDVR2_LATENT_CHANNELS = 16 # SeedVR2 latent channel count (== BYTEDANCE latent_channels). +SEEDVR2_COND_CHANNELS = 17 # conditioning channels = vid_in_channels(33) - latent(16). +SEEDVR2_DEFAULT_TEMPORAL_SIZE = 16 # default VAE temporal tile when unset. + +# Color-correction memory model (fork tuning; per-frame VRAM estimate for chunk sizing) +SEEDVR2_COLOR_MEM_HEADROOM = 0.75 # fraction of free VRAM usable per color-correction chunk. 
+SEEDVR2_LAB_SCALE_MULTIPLIER = 13 # per-frame byte multiplier, LAB path. +SEEDVR2_WAVELET_SCALE_MULTIPLIER = 10 # per-frame byte multiplier, wavelet path. +SEEDVR2_ADAIN_SCALE_MULTIPLIER = 6 # per-frame byte multiplier, AdaIN path. + +# -------------------------------------------------------------------------------------- +# C. ByteDance config / source (BYTEDANCE - cite ByteDance-Seed/SeedVR) +# -------------------------------------------------------------------------------------- +BYTEDANCE_VAE_SCALING_FACTOR = 0.9152 # configs_3b/main.yaml:57 (scaling_factor); latent denorm. +BYTEDANCE_VAE_SHIFTING_FACTOR = 0.0 # infer.py (shifting_factor default); latent denorm shift. +BYTEDANCE_VAE_CONV_MEM_GIB = 0.5 # configs_3b/main.yaml:54 (conv_max_mem). +BYTEDANCE_VAE_NORM_MEM_GIB = 0.5 # configs_3b/main.yaml:55 (norm_max_mem). +BYTEDANCE_LOGVAR_CLAMP_MIN = -30.0 # video_vae_v3/modules/types.py:28. +BYTEDANCE_LOGVAR_CLAMP_MAX = 20.0 # video_vae_v3/modules/types.py:28. +BYTEDANCE_GN_CHUNKS_FP16 = 4 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp16). +BYTEDANCE_GN_CHUNKS_FP32 = 2 # causal_inflation_lib.py:351 (GroupNorm chunk count, fp32). +BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD = 64 # attn_video_vae.py:308 (force .contiguous() above this b*t). +BYTEDANCE_BLOCK_OUT_CHANNELS = (128, 256, 512, 512) # s8_c16_t4_inflation_sd3.yaml:7-11. +BYTEDANCE_SLICING_SAMPLE_MIN = 4 # s8_c16_t4_inflation_sd3.yaml:22 (slicing_sample_min_size). +BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE = 4 # infer.py:230 (temporal_downsample_factor); the 4n+1 factor. +BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE = 8 # infer.py:231 (spatial_downsample_factor). +BYTEDANCE_SCHEDULE_T = 1000.0 # configs_3b/main.yaml:65 (schedule.T); timestep range. +BYTEDANCE_SPATIAL_DIVISOR = 16 # inference_seedvr2_3b.py:241 (DivisibleCrop((16,16))). +BYTEDANCE_720P_REF_AREA = 45 * 80 # dit_v2/window.py:32 (720p reference area for window scaling). +BYTEDANCE_MAX_TEMPORAL_WINDOW = 30 # dit_v2/window.py:35 (max temporal window frames). 
+BYTEDANCE_ROPE_MAX_FREQ = 256 # dit_v2/rope.py:31 (pixel-RoPE max frequency). +BYTEDANCE_SINUSOIDAL_DIM = 256 # dit_3b/nadit.py:120 (timestep sinusoidal embed dim). +# Resolution-dependent timestep-shift linear fits: (x1, y1, x2, y2) for get_lin_function. +BYTEDANCE_IMG_SHIFT_FIT = (256 * 256, 1.0, 1024 * 1024, 3.2) # infer.py:242. +BYTEDANCE_VID_SHIFT_FIT = (256 * 256 * 37, 1.0, 1280 * 720 * 145, 5.0) # infer.py:243. + +# -------------------------------------------------------------------------------------- +# D. Published standards (cite the literature) +# -------------------------------------------------------------------------------------- +ROPE_THETA = 10000 # RoPE base; Su et al., "RoFormer", arXiv:2104.09864. + +# CIELAB f(t) piecewise constants and D65 white point (CIE 15 colorimetry; CIE D65). +CIELAB_DELTA = 6.0 / 29.0 # CIE 15 (delta). +CIELAB_KAPPA = (29.0 / 3.0) ** 3 # CIE 15 (kappa). +D65_WHITE_X = 0.95047 # CIE D65 standard illuminant Xn (Yn = 1). +D65_WHITE_Z = 1.08883 # CIE D65 standard illuminant Zn. +WAVELET_DECOMP_LEVELS = 5 # wavelet color-fix decomposition depth (GIMP/Krita; StableSR). + +# NOTE: the sRGB<->XYZ D65 3x3 matrices (IEC 61966-2-1) remain inline in the color code and +# are named (SRGB_TO_XYZ_D65 / XYZ_TO_SRGB_D65) during the color-module extraction, where the +# exact existing coefficients move verbatim rather than being retyped here. 
diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py new file mode 100644 index 000000000..3fa9fe07e --- /dev/null +++ b/comfy/ldm/seedvr/model.py @@ -0,0 +1,1665 @@ +from dataclasses import dataclass +from typing import Optional, Tuple, Union, List, Dict, Any, Callable +import einops +from einops import rearrange +import torch.nn.functional as F +from math import ceil, pi +import torch +from itertools import chain +from comfy.ldm.modules.diffusionmodules.model import get_timestep_embedding +from comfy.ldm.modules.attention import optimized_var_attention +from torch.nn.modules.utils import _triple +from torch import nn +import math +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.seedvr.constants import ( + BYTEDANCE_720P_REF_AREA, + BYTEDANCE_MAX_TEMPORAL_WINDOW, + BYTEDANCE_ROPE_MAX_FREQ, + BYTEDANCE_SINUSOIDAL_DIM, + ROPE_THETA, + SEEDVR2_7B_MLP_CHUNK, + SEEDVR2_7B_VID_DIM, + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, +) +import comfy.model_management +import numbers + +def _torch_float8_types(): + return tuple( + getattr(torch, name) + for name in ( + "float8_e4m3fn", + "float8_e4m3fnuz", + "float8_e5m2", + "float8_e5m2fnuz", + "float8_e8m0fnu", + ) + if hasattr(torch, name) + ) + +class CustomRMSNorm(nn.Module): + + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None): + super(CustomRMSNorm, self).__init__() + + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = torch.Size(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + + if self.elementwise_affine: + self.weight = nn.Parameter(torch.ones(*normalized_shape, device=device, dtype=dtype)) + else: + self.register_parameter('weight', None) + + def forward(self, input): + + dims = tuple(range(-len(self.normalized_shape), 0)) + + normalized = input.float() + variance = normalized.pow(2).mean(dim=dims, keepdim=True) + rms = torch.sqrt(variance + self.eps) 
class Cache:
    """Memoising store addressed by string keys, with prefix-based namespaces.

    Every namespace created from a cache shares the same underlying dict and
    only differs in the prefix prepended to its keys.
    """

    def __init__(self, disable=False, prefix="", cache=None):
        # A caller-supplied mapping is reused as-is (even when empty) so that
        # namespaces write into their parent's storage.
        self.cache = {} if cache is None else cache
        self.disable = disable
        self.prefix = prefix

    def __call__(self, key: str, fn: Callable):
        """Return the value stored under ``key``, computing it with ``fn`` on a miss.

        When the cache is disabled ``fn`` is invoked on every call and nothing
        is stored.
        """
        if self.disable:
            return fn()

        full_key = self.prefix + key
        if full_key not in self.cache:
            self.cache[full_key] = fn()
        return self.cache[full_key]

    def namespace(self, namespace: str):
        """Return a view on the same storage whose keys are prefixed with ``namespace + "."``."""
        return Cache(self.disable, self.prefix + namespace + ".", self.cache)

    def get(self, key: str):
        """Return the stored value for ``key``; raises ``KeyError`` when it was never computed."""
        return self.cache[self.prefix + key]
def safe_pad_operation(x, padding, mode='constant', value=0.0):
    """Pad ``x`` with ``F.pad``, retrying in float32 when the Half kernel is missing.

    Only the 'replicate', 'reflect' and 'circular' modes get the retry; every
    other mode is forwarded to ``F.pad`` unchanged.

    Args:
        x: Tensor to pad.
        padding: Padding sizes, in the layout ``F.pad`` expects.
        mode: Padding mode forwarded to ``F.pad``.
        value: Fill value forwarded to ``F.pad``.

    Returns:
        The padded tensor, in the dtype of ``x``.

    Raises:
        RuntimeError: Any ``F.pad`` failure other than the missing-Half-kernel case.
    """
    # 'constant' and any other mode not listed here needs no workaround.
    if mode not in ('replicate', 'reflect', 'circular'):
        return F.pad(x, padding, mode=mode, value=value)

    try:
        return F.pad(x, padding, mode=mode, value=value)
    except RuntimeError as err:
        if "not implemented for 'Half'" not in str(err):
            raise
        # No Half kernel for this mode: pad in float32, then restore the dtype.
        return F.pad(x.float(), padding, mode=mode, value=value).to(x.dtype)
def make_shifted_720Pwindows_bysize(size: Tuple[int, int, int], num_windows: Tuple[int, int, int]):
    """Return the ``(t, h, w)`` slice triples of half-window-shifted attention windows.

    Window extents are derived the same way as for the unshifted variant: the
    frame is rescaled to the 720p reference area and divided by the requested
    window counts; the temporal extent is capped by the maximum temporal
    window. An axis whose window is smaller than the axis is shifted by half a
    window, which adds one (partial) window on that axis; an axis fully covered
    by a single window is left unshifted.

    Args:
        size: ``(t, h, w)`` of the volume being partitioned.
        num_windows: Requested window counts ``(nt, nh, nw)`` at the reference size.

    Returns:
        A list of ``(slice_t, slice_h, slice_w)`` tuples. Empty windows are
        dropped; the width index varies slowest and the time index fastest.
    """
    t, h, w = size
    resized_nt, resized_nh, resized_nw = num_windows

    # Window extents measured on the frame rescaled to the 720p reference area.
    scale = math.sqrt(BYTEDANCE_720P_REF_AREA / (h * w))
    resized_h = round(h * scale)
    resized_w = round(w * scale)
    wh = ceil(resized_h / resized_nh)
    ww = ceil(resized_w / resized_nw)
    wt = ceil(min(t, BYTEDANCE_MAX_TEMPORAL_WINDOW) / resized_nt)

    def _axis_slices(extent, win):
        # Half-window shift only when the axis needs more than one window.
        shift = 0.5 if win < extent else 0
        # Number of windows on this axis: the shift adds one partial window.
        count = ceil((extent - shift) / win)
        count = count + 1 if shift > 0 else 1
        kept = []
        for idx in range(count):
            start = max(int((idx - shift) * win), 0)
            stop = min(int((idx - shift + 1) * win), extent)
            if stop > start:  # drop windows clipped to nothing
                kept.append(slice(start, stop))
        return kept

    t_slices = _axis_slices(t, wt)
    h_slices = _axis_slices(h, wh)
    w_slices = _axis_slices(w, ww)

    return [
        (t_slice, h_slice, w_slice)
        for w_slice in w_slices
        for h_slice in h_slices
        for t_slice in t_slices
    ]
/ (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + + self.cache_if_possible = cache_if_possible + self.cache_max_seq_len = cache_max_seq_len + + self.register_buffer('cached_freqs', torch.zeros(cache_max_seq_len, dim), persistent = False) + self.cached_freqs_seq_len = 0 + + self.freqs = nn.Parameter(freqs, requires_grad = learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.register_buffer('dummy', torch.tensor(0), persistent = False) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1. + self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + + if not use_xpos: + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + + self.register_buffer('scale', scale, persistent = False) + self.register_buffer('cached_scales', torch.zeros(cache_max_seq_len, dim), persistent = False) + self.cached_scales_seq_len = 0 + + # add apply_rotary_emb as static method + + self.apply_rotary_emb = staticmethod(apply_rotary_emb) + + @property + def device(self): + return self.dummy.device + + def get_axial_freqs( + self, + *dims, + offsets = None + ): + Colon = slice(None) + all_freqs = [] + + # handle offset + + if exists(offsets): + assert len(offsets) == len(dims) + + for ind, dim in enumerate(dims): + + offset = 0 + if exists(offsets): + offset = offsets[ind] + + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps = dim, device = self.device) + else: + pos = torch.arange(dim, device = self.device) + + pos = pos + offset + + freqs = self.forward(pos, seq_len = dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + 
class RotaryEmbeddingBase(nn.Module):
    """Holds a pixel-frequency ``RotaryEmbedding`` whose ``freqs`` is a buffer, not a parameter."""

    def __init__(self, dim: int, rope_dim: int):
        """Build the wrapped rotary embedding.

        Args:
            dim: Total rotary dimension; it is split evenly across ``rope_dim`` axes.
            rope_dim: Number of axes sharing ``dim``.
        """
        super().__init__()
        rope = RotaryEmbedding(
            dim=dim // rope_dim,
            freqs_for="pixel",
            max_freq=BYTEDANCE_ROPE_MAX_FREQ,
        )
        # RotaryEmbedding stores ``freqs`` as an nn.Parameter; swap it for a
        # plain buffer carrying the same data.
        freq_values = rope.freqs.data
        del rope.freqs
        rope.register_buffer("freqs", freq_values)
        self.rope = rope

    def get_axial_freqs(self, *dims):
        """Delegate to the wrapped embedding's ``get_axial_freqs``."""
        return self.rope.get_axial_freqs(*dims)
def slice_at_dim(t, dim_slice: slice, *, dim):
    """Apply ``dim_slice`` along dimension ``dim`` of ``t`` and keep every other dimension whole.

    Args:
        t: Tensor to slice.
        dim_slice: Slice applied on the chosen dimension.
        dim: Dimension index; a negative value counts from the last dimension.

    Returns:
        The indexed tensor ``t[:, ..., dim_slice, ..., :]``.
    """
    axis = dim + t.ndim if dim < 0 else dim
    selector = [slice(None) for _ in range(t.ndim)]
    selector[axis] = dim_slice
    return t[tuple(selector)]
def _apply_seedvr2_rotary_emb(
    freqs: torch.Tensor,
    t: torch.Tensor,
    start_index: int = 0,
    scale: float = 1.0,
    seq_dim: int = -2,
    freqs_seq_dim: int | None = None,
) -> torch.Tensor:
    """Rotate a band of the last dimension of ``t`` by the angles in ``freqs``.

    Uses the interleaved-pair form ``t * cos + rotate_half(t) * sin``. Features
    before ``start_index`` and after the rotated band pass through unchanged.

    Args:
        freqs: Rotation angles; the size of its last dimension is the number of
            rotated features.
        t: Tensor to rotate.
        start_index: First feature index of the rotated band.
        scale: Multiplier applied to both the cosine and the sine terms.
        seq_dim: Dimension of ``t`` holding the sequence length.
        freqs_seq_dim: Sequence dimension of ``freqs``. Defaults to 0 when
            ``freqs`` is 2-D or ``t`` is 3-D; otherwise ``freqs`` is not trimmed.

    Returns:
        A tensor shaped like ``t`` in the dtype of ``t``.
    """
    input_dtype = t.dtype

    if freqs_seq_dim is None and (freqs.ndim == 2 or t.ndim == 3):
        freqs_seq_dim = 0

    if freqs_seq_dim is not None or t.ndim == 3:
        # Keep only the trailing ``tokens`` positions of the angle table.
        tokens = t.shape[seq_dim]
        freqs = slice_at_dim(freqs, slice(-tokens, None), dim=freqs_seq_dim)

    stop_index = start_index + freqs.shape[-1]
    head = t[..., :start_index]
    band = t[..., start_index:stop_index]
    tail = t[..., stop_index:]

    angles = freqs.to(device=band.device, dtype=band.dtype)
    cos_term = angles.cos() * scale
    sin_term = angles.sin() * scale
    band = (band * cos_term) + (rotate_half(band) * sin_term)

    return torch.cat((head, band, tail), dim=-1).to(input_dtype)
def _apply_rope1_partial(t: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    """Rotate the leading ``rot_d = 2 * freqs_cis.shape[-3]`` features of ``t``; keep the rest.

    The rotation is delegated to ``apply_rope1`` and applied in chunks of
    ``SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS`` positions along ``t``'s second-to-last
    dimension, with ``freqs_cis`` sliced on its first dimension by the same
    token range. Features at index ``rot_d`` and beyond are never written, so
    they pass through unchanged (e.g. 3B: ``42 * 3 = 126`` rotated features of
    a 128-wide head, trailing 2 untouched).

    Args:
        t: Tensor whose last dimension is the head dimension and whose
            second-to-last dimension is the token axis.
        freqs_cis: Per-token rotation matrices, indexed by token on dim 0, with
            the number of rotated pairs at dim ``-3``.

    Returns:
        ``t`` itself, modified in place, when ``t`` does not require grad and
        ``comfy.model_management.in_training`` is falsy; otherwise a rotated
        clone, leaving ``t`` untouched for autograd.
    """
    out = t.clone() if t.requires_grad or comfy.model_management.in_training else t
    rot_d = 2 * freqs_cis.shape[-3]
    seq_len = out.shape[-2]
    for start in range(0, seq_len, SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS):
        end = min(start + SEEDVR2_ROPE_PARTIAL_CHUNK_TOKENS, seq_len)
        # ``:rot_d`` spans the entire last dimension when rot_d equals the head
        # dimension, so a single assignment covers both the full and the
        # partial rotation case.
        out[..., start:end, :rot_d] = apply_rope1(
            out[..., start:end, :rot_d], freqs_cis[start:end]
        ).to(out.dtype)
    return out
"mmrope_freqs_3d", + lambda: self.get_freqs(vid_shape, txt_shape), + ) + target_device = vid_q.device + if vid_freqs.device != target_device: + vid_freqs = vid_freqs.to(target_device) + if txt_freqs.device != target_device: + txt_freqs = txt_freqs.to(target_device) + vid_q = rearrange(vid_q, "L h d -> h L d") + vid_k = rearrange(vid_k, "L h d -> h L d") + vid_q = _apply_rope1_partial(vid_q, vid_freqs) + vid_k = _apply_rope1_partial(vid_k, vid_freqs) + vid_q = rearrange(vid_q, "h L d -> L h d") + vid_k = rearrange(vid_k, "h L d -> L h d") + + txt_q = rearrange(txt_q, "L h d -> h L d") + txt_k = rearrange(txt_k, "L h d -> h L d") + txt_q = _apply_rope1_partial(txt_q, txt_freqs) + txt_k = _apply_rope1_partial(txt_k, txt_freqs) + txt_q = rearrange(txt_q, "h L d -> L h d") + txt_k = rearrange(txt_k, "h L d -> L h d") + return vid_q, vid_k, txt_q, txt_k + + @torch._dynamo.disable # Disable compilation: .tolist() is data-dependent and causes graph breaks + def get_freqs( + self, + vid_shape: torch.LongTensor, + txt_shape: torch.LongTensor, + ) -> Tuple[ + torch.Tensor, + torch.Tensor, + ]: + + # Calculate actual max dimensions needed for this batch + max_temporal = 0 + max_height = 0 + max_width = 0 + max_txt_len = 0 + + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + max_temporal = max(max_temporal, l + f) # Need up to l+f for temporal + max_height = max(max_height, h) + max_width = max(max_width, w) + max_txt_len = max(max_txt_len, l) + + autocast_device = "cuda" if torch.cuda.is_available() else "cpu" + with torch.amp.autocast(autocast_device, enabled=False): + vid_freqs = self.get_axial_freqs( + max_temporal + 16, + max_height + 4, + max_width + 4, + ).float() + txt_freqs = self.get_axial_freqs(max_txt_len + 16) + + # Now slice as before + vid_freq_list, txt_freq_list = [], [] + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + vid_freq = vid_freqs[l : l + f, :h, :w].reshape(-1, vid_freqs.size(-1)) + txt_freq = 
class MMModule(nn.Module):
    """Video/text pair of sub-modules created from one factory.

    Constructor and forward arguments may be ``MMArg`` instances; each branch
    receives the field matching its modality, while plain values go to both.
    With ``shared_weights`` a single module (``self.all``) serves both
    modalities; otherwise ``self.vid`` and ``self.txt`` are built separately,
    and ``self.txt`` is ``None`` when ``vid_only`` is set.
    """

    def __init__(
        self,
        module: Callable[..., nn.Module],
        *args,
        shared_weights: bool = False,
        vid_only: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.shared_weights = shared_weights
        self.vid_only = vid_only

        vid_args = get_args("vid", args)
        vid_kwargs = get_kwargs("vid", kwargs)

        if shared_weights:
            # One module can only serve both modalities if they are configured alike.
            assert vid_args == get_args("txt", args)
            assert vid_kwargs == get_kwargs("txt", kwargs)
            self.all = module(*vid_args, **vid_kwargs)
            return

        self.vid = module(*vid_args, **vid_kwargs)
        if vid_only:
            self.txt = None
        else:
            self.txt = module(*get_args("txt", args), **get_kwargs("txt", kwargs))

    def forward(
        self,
        vid: torch.FloatTensor,
        txt: torch.FloatTensor,
        *args,
        **kwargs,
    ) -> Tuple[
        torch.FloatTensor,
        torch.FloatTensor,
    ]:
        """Run the video branch and, unless ``vid_only``, the text branch.

        The text input is moved to the device and dtype of the video output
        before its branch runs. With ``vid_only`` the text input is returned
        as received.
        """
        if self.shared_weights:
            vid_branch = txt_branch = self.all
        else:
            vid_branch, txt_branch = self.vid, self.txt

        vid = vid_branch(vid, *get_args("vid", args), **get_kwargs("vid", kwargs))
        if self.vid_only:
            return vid, txt

        txt = txt.to(device=vid.device, dtype=vid.dtype)
        txt = txt_branch(txt, *get_args("txt", args), **get_kwargs("txt", kwargs))
        return vid, txt
class NaMMAttention(nn.Module):
    """Parameter container for multimodal (video + text) attention.

    Builds the QKV / output projections, the per-head Q and K norms and the
    optional rotary embedding. ``forward`` is a no-op here.
    """

    def __init__(
        self,
        vid_dim: int,
        txt_dim: int,
        heads: int,
        head_dim: int,
        qk_bias: bool,
        qk_norm,
        qk_norm_eps: float,
        rope_type: Optional[str],
        rope_dim: int,
        shared_weights: bool,
        device, dtype, operations,
        **kwargs,
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = head_dim

        model_dim = MMArg(vid_dim, txt_dim)
        inner_dim = heads * head_dim

        # Keyword arguments common to every wrapped sub-module.
        shared = dict(shared_weights=shared_weights, device=device, dtype=dtype)
        norm_config = dict(
            normalized_shape=head_dim,
            eps=qk_norm_eps,
            elementwise_affine=True,
        )

        # Fused Q/K/V projection: three ``inner_dim`` blocks per token.
        self.proj_qkv = MMModule(operations.Linear, model_dim, inner_dim * 3, bias=qk_bias, **shared)
        self.proj_out = MMModule(operations.Linear, inner_dim, model_dim, **shared)
        self.norm_q = MMModule(qk_norm, **norm_config, **shared)
        self.norm_k = MMModule(qk_norm, **norm_config, **shared)

        self.rope = get_na_rope(rope_type=rope_type, dim=rope_dim)

    def forward(self):
        """No-op; this class defines no attention computation of its own."""
        pass
tgt_windows = window(hid_idx, hid_shape, window_fn) + tgt_idx = tgt_idx.squeeze(-1) + src_idx = torch.argsort(tgt_idx) + return ( + lambda hid: torch.index_select(hid, 0, tgt_idx), + lambda hid: torch.index_select(hid, 0, src_idx), + tgt_shape, + tgt_windows, + ) + +class NaSwinAttention(NaMMAttention): + def __init__( + self, + *args, + window: Union[int, Tuple[int, int, int]], + window_method: bool, # shifted or not + **kwargs, + ): + super().__init__(*args, **kwargs) + self.version_7b = kwargs.get("version", False) + self.window = _triple(window) + self.window_method = window_method + assert all(map(lambda v: isinstance(v, int) and v >= 0, self.window)) + + self.window_op = get_window_op(window_method) + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + + vid_qkv, txt_qkv = self.proj_qkv(vid, txt) + + # re-org the input seq for window attn + cache_win = cache.namespace(f"{self.window_method}_{self.window}_sd3") + + def make_window(x: torch.Tensor): + t, h, w, _ = x.shape + window_slices = self.window_op((t, h, w), self.window) + return [x[st, sh, sw] for (st, sh, sw) in window_slices] + + window_partition, window_reverse, window_shape, window_count = cache_win( + "win_transform", + lambda: window_idx(vid_shape, make_window), + ) + vid_qkv_win = window_partition(vid_qkv) + + vid_qkv_win = rearrange(vid_qkv_win, "l (o h d) -> l o h d", o=3, d=self.head_dim) + txt_qkv = rearrange(txt_qkv, "l (o h d) -> l o h d", o=3, d=self.head_dim) + + vid_q, vid_k, vid_v = vid_qkv_win.unbind(1) + txt_q, txt_k, txt_v = txt_qkv.unbind(1) + + vid_q, txt_q = self.norm_q(vid_q, txt_q) + vid_k, txt_k = self.norm_k(vid_k, txt_k) + + txt_len = cache("txt_len", lambda: txt_shape.prod(-1)) + + vid_len_win = cache_win("vid_len", lambda: window_shape.prod(-1)) + txt_len = txt_len.to(window_count.device) 
+ + # window rope + if self.rope: + if self.version_7b: + vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) + elif self.rope.mm: + # repeat text q and k for window mmrope + _, num_h, _ = txt_q.shape + txt_q_repeat = rearrange(txt_q, "l h d -> l (h d)") + txt_q_repeat = unflatten(txt_q_repeat, txt_shape) + txt_q_repeat = [[x] * n for x, n in zip(txt_q_repeat, window_count)] + txt_q_repeat = list(chain(*txt_q_repeat)) + txt_q_repeat, txt_shape_repeat = flatten(txt_q_repeat) + txt_q_repeat = rearrange(txt_q_repeat, "l (h d) -> l h d", h=num_h) + + txt_k_repeat = rearrange(txt_k, "l h d -> l (h d)") + txt_k_repeat = unflatten(txt_k_repeat, txt_shape) + txt_k_repeat = [[x] * n for x, n in zip(txt_k_repeat, window_count)] + txt_k_repeat = list(chain(*txt_k_repeat)) + txt_k_repeat, _ = flatten(txt_k_repeat) + txt_k_repeat = rearrange(txt_k_repeat, "l (h d) -> l h d", h=num_h) + + vid_q, vid_k, txt_q, txt_k = self.rope( + vid_q, vid_k, window_shape, txt_q_repeat, txt_k_repeat, txt_shape_repeat, cache_win + ) + else: + vid_q, vid_k = self.rope(vid_q, vid_k, window_shape, cache_win) + + txt_len_win = cache_win("txt_len", lambda: txt_len.repeat_interleave(window_count)) + all_len_win = cache_win("all_len", lambda: vid_len_win + txt_len_win) + concat_win, unconcat_win = cache_win( + "mm_pnp", lambda: repeat_concat_idx(vid_len_win, txt_len, window_count) + ) + out = optimized_var_attention( + q=concat_win(vid_q, txt_q), + k=concat_win(vid_k, txt_k), + v=concat_win(vid_v, txt_v), + heads=self.heads, skip_reshape=True, skip_output_reshape=True, + cu_seqlens_q=cache_win( + "vid_seqlens_q", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + ), + cu_seqlens_k=cache_win( + "vid_seqlens_k", lambda: safe_pad_operation(all_len_win.cumsum(0), (1, 0)).int() + ), + ) + vid_out, txt_out = unconcat_win(out) + + vid_out = rearrange(vid_out, "l h d -> l (h d)") + txt_out = rearrange(txt_out, "l h d -> l (h d)") + vid_out = window_reverse(vid_out) + + vid_out, 
class MLP(nn.Module):
    """Feed-forward block: Linear -> tanh-approximated GELU -> Linear.

    The hidden width is ``dim * expand_ratio``.  ``operations`` is the
    namespace providing the ``Linear`` class; ``device`` and ``dtype`` are
    forwarded to both linear layers.
    """

    def __init__(
        self,
        dim: int,
        expand_ratio: int,
        device, dtype, operations
    ):
        super().__init__()
        hidden_dim = dim * expand_ratio
        self.proj_in = operations.Linear(dim, hidden_dim, device=device, dtype=dtype)
        self.act = nn.GELU("tanh")
        self.proj_out = operations.Linear(hidden_dim, dim, device=device, dtype=dtype)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Apply ``proj_in``, the activation and ``proj_out`` in sequence."""
        return self.proj_out(self.act(self.proj_in(x)))


class SwiGLUMLP(nn.Module):
    """Gated feed-forward block: ``proj_out(silu(proj_in_gate(x)) * proj_in(x))``.

    The hidden width is ``int(2 * dim * expand_ratio / 3)`` rounded up to the
    next multiple of ``multiple_of``.  All three linear layers are bias-free.
    """

    def __init__(
        self,
        dim: int,
        expand_ratio: int,
        multiple_of: int = 256,
        device=None, dtype=None, operations=None
    ):
        super().__init__()
        hidden_dim = int(2 * dim * expand_ratio / 3)
        # Round up to the next multiple of ``multiple_of``.
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        # Registration order (gate, out, in) is kept as in the original.
        self.proj_in_gate = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)
        self.proj_out = operations.Linear(hidden_dim, dim, bias=False, device=device, dtype=dtype)
        self.proj_in = operations.Linear(dim, hidden_dim, bias=False, device=device, dtype=dtype)

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        """Return the SiLU-gated projection of ``x``."""
        gate = F.silu(self.proj_in_gate(x))
        return self.proj_out(gate * self.proj_in(x))


def get_mlp(mlp_type: Optional[str] = "normal"):
    """Return the MLP class selected by ``mlp_type``.

    Args:
        mlp_type: ``"normal"`` selects :class:`MLP`, ``"swiglu"`` selects
            :class:`SwiGLUMLP`.

    Returns:
        The class itself (not an instance).

    Raises:
        ValueError: if ``mlp_type`` is anything else (including ``None``).
            Previously such values fell through and returned ``None``, which
            only failed later, far from the actual mistake.
    """
    # 3b and 7b uses different mlp types
    if mlp_type == "normal":
        return MLP
    if mlp_type == "swiglu":
        return SwiGLUMLP
    raise ValueError(
        f"Unknown mlp_type {mlp_type!r}; expected 'normal' or 'swiglu'."
    )
eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, device=device, dtype=dtype) + + self.attn = NaSwinAttention( + vid_dim=vid_dim, + txt_dim=txt_dim, + heads=heads, + head_dim=head_dim, + qk_bias=qk_bias, + qk_norm=qk_norm, + qk_norm_eps=norm_eps, + rope_type=rope_type, + rope_dim=rope_dim, + shared_weights=shared_weights, + window=kwargs.pop("window", None), + window_method=kwargs.pop("window_method", None), + version=version, + device=device, dtype=dtype, operations=operations + ) + + self.mlp_norm = MMModule(norm, normalized_shape=dim, eps=norm_eps, elementwise_affine=False, shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) + self.mlp = MMModule( + get_mlp(mlp_type), + dim=dim, + expand_ratio=expand_ratio, + shared_weights=shared_weights, + vid_only=is_last_layer, + device=device, dtype=dtype, operations=operations + ) + self.ada = MMModule(ada, dim=dim, emb_dim=emb_dim, layers=["attn", "mlp"], shared_weights=shared_weights, vid_only=is_last_layer, device=device, dtype=dtype) + self.is_last_layer = is_last_layer + self.version = version + + def _seedvr2_7b_mlp( + self, + vid: torch.FloatTensor, + txt: torch.FloatTensor, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + ]: + vid_module = self.mlp.vid if not self.mlp.shared_weights else self.mlp.all + if comfy.model_management.in_training or vid.requires_grad: + vid = torch.cat([vid_module(chunk) for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0)], dim=0) + else: + vid_out = None + offset = 0 + for chunk in vid.split(SEEDVR2_7B_MLP_CHUNK, dim=0): + chunk_out = vid_module(chunk) + if vid_out is None: + vid_out = chunk_out.new_empty((vid.shape[0], *chunk_out.shape[1:])) + vid_out[offset:offset + chunk_out.shape[0]] = chunk_out + offset += chunk_out.shape[0] + vid = vid_out + if not self.mlp.vid_only: + txt_module = self.mlp.txt if not self.mlp.shared_weights else self.mlp.all + txt = txt.to(device=vid.device, dtype=vid.dtype) + txt = txt_module(txt) + 
return vid, txt + + def forward( + self, + vid: torch.FloatTensor, # l c + txt: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, # b 3 + txt_shape: torch.LongTensor, # b 1 + emb: torch.FloatTensor, + cache: Cache, + ) -> Tuple[ + torch.FloatTensor, + torch.FloatTensor, + torch.LongTensor, + torch.LongTensor, + ]: + hid_len = MMArg( + cache("vid_len", lambda: vid_shape.prod(-1)), + cache("txt_len", lambda: txt_shape.prod(-1)), + ) + ada_kwargs = { + "emb": emb, + "hid_len": hid_len, + "cache": cache, + "branch_tag": MMArg("vid", "txt"), + } + + vid_attn, txt_attn = self.attn_norm(vid, txt) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="in", **ada_kwargs) + vid_attn, txt_attn = self.attn(vid_attn, txt_attn, vid_shape, txt_shape, cache) + vid_attn, txt_attn = self.ada(vid_attn, txt_attn, layer="attn", mode="out", **ada_kwargs) + vid_attn, txt_attn = (vid_attn + vid), (txt_attn + txt) + + vid_mlp, txt_mlp = self.mlp_norm(vid_attn, txt_attn) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="in", **ada_kwargs) + if self.version: + vid_mlp, txt_mlp = self._seedvr2_7b_mlp(vid_mlp, txt_mlp) + else: + vid_mlp, txt_mlp = self.mlp(vid_mlp, txt_mlp) + vid_mlp, txt_mlp = self.ada(vid_mlp, txt_mlp, layer="mlp", mode="out", **ada_kwargs) + vid_mlp, txt_mlp = (vid_mlp + vid_attn), (txt_mlp + txt_attn) + + return vid_mlp, txt_mlp, vid_shape, txt_shape + +class PatchOut(nn.Module): + def __init__( + self, + out_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + device, dtype, operations + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = operations.Linear(dim, out_channels * t * h * w, device=device, dtype=dtype) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + vid = self.proj(vid) + vid = rearrange(vid, "b T H W (t h w c) -> b c (T t) (H h) (W w)", t=t, h=h, w=w) + if t > 1: + vid = vid[:, :, (t - 1) :] + return 
vid + +class NaPatchOut(PatchOut): + def forward( + self, + vid: torch.FloatTensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + vid_shape_before_patchify = None + ) -> Tuple[ + torch.FloatTensor, + torch.LongTensor, + ]: + + t, h, w = self.patch_size + vid = self.proj(vid) + + if not (t == h == w == 1): + vid = unflatten(vid, vid_shape) + for i in range(len(vid)): + vid[i] = rearrange(vid[i], "T H W (t h w c) -> (T t) (H h) (W w) c", t=t, h=h, w=w) + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = vid[i][(t - vid_shape_before_patchify[i, 0] % t) :] + vid, vid_shape = flatten(vid) + + return vid, vid_shape + +class PatchIn(nn.Module): + def __init__( + self, + in_channels: int, + patch_size: Union[int, Tuple[int, int, int]], + dim: int, + device, dtype, operations + ): + super().__init__() + t, h, w = _triple(patch_size) + self.patch_size = t, h, w + self.proj = operations.Linear(in_channels * t * h * w, dim, device=device, dtype=dtype) + + def forward( + self, + vid: torch.Tensor, + ) -> torch.Tensor: + t, h, w = self.patch_size + if t > 1: + assert vid.size(2) % t == 1 + vid = torch.cat([vid[:, :, :1]] * (t - 1) + [vid], dim=2) + vid = rearrange(vid, "b c (T t) (H h) (W w) -> b T H W (t h w c)", t=t, h=h, w=w) + vid = self.proj(vid) + return vid + +class NaPatchIn(PatchIn): + def forward( + self, + vid: torch.Tensor, # l c + vid_shape: torch.LongTensor, + cache: Cache = Cache(disable=True), # for test + ) -> torch.Tensor: + cache = cache.namespace("patch") + vid_shape_before_patchify = cache("vid_shape_before_patchify", lambda: vid_shape) + t, h, w = self.patch_size + if not (t == h == w == 1): + vid = unflatten(vid, vid_shape) + for i in range(len(vid)): + if t > 1 and vid_shape_before_patchify[i, 0] % t != 0: + vid[i] = torch.cat([vid[i][:1]] * (t - vid[i].size(0) % t) + [vid[i]], dim=0) + vid[i] = rearrange(vid[i], "(T t) (H h) (W w) c -> T H W (t h w c)", t=t, h=h, w=w) + vid, vid_shape = 
def expand_dims(x: torch.Tensor, dim: int, ndim: int):
    """Reshape ``x`` to ``ndim`` dimensions by inserting size-1 axes at ``dim``.

    When ``x`` already has ``ndim`` or more dimensions nothing is inserted and
    ``x`` is reshaped to its own shape.
    """
    lead, tail = tuple(x.shape[:dim]), tuple(x.shape[dim:])
    singleton_axes = (1,) * (ndim - x.ndim)
    return x.reshape(lead + singleton_axes + tail)


class AdaSingle(nn.Module):
    """Shift/scale/gate modulation combining an embedding with learned offsets.

    For every name in ``layers`` the module registers ``<name>_shift`` and
    ``<name>_scale`` parameters when ``"in"`` is in ``modes`` and a
    ``<name>_gate`` parameter when ``"out"`` is in ``modes``; each has length
    ``dim``.  ``forward`` adds them to the matching slices of ``emb`` and
    applies the result to ``hid`` in place.
    """

    def __init__(
        self,
        dim: int,
        emb_dim: int,
        layers: List[str],
        modes: Optional[List[str]] = None,
        device = None, dtype = None,
    ):
        """Create the per-layer modulation parameters.

        Args:
            dim: length of every registered parameter.
            emb_dim: must equal ``6 * dim`` (asserted).
            layers: layer names accepted by ``forward``.
            modes: subset of ``["in", "out"]``; ``None`` means both.  A
                ``None`` sentinel replaces the former list default so that no
                mutable object is shared between instances.
            device, dtype: forwarded to ``torch.randn``; ``dtype`` is omitted
                when it is one of the types from ``_torch_float8_types()``.
        """
        assert emb_dim == 6 * dim, "AdaSingle requires emb_dim == 6 * dim"
        super().__init__()
        if modes is None:
            modes = ["in", "out"]
        self.dim = dim
        self.emb_dim = emb_dim
        self.layers = layers

        randn_kwargs = {"device": device}
        fp8_types = _torch_float8_types()
        if dtype is not None and dtype not in fp8_types:
            randn_kwargs["dtype"] = dtype

        for layer_name in layers:
            if "in" in modes:
                # Passing fp8 ``dtype=`` here would break CPU weight
                # loads: CPU has no ``normal_kernel_cpu`` for fp8.
                self.register_parameter(
                    f"{layer_name}_shift", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5)
                )
                self.register_parameter(
                    f"{layer_name}_scale", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5 + 1)
                )
            if "out" in modes:
                self.register_parameter(
                    f"{layer_name}_gate", nn.Parameter(torch.randn(dim, **randn_kwargs) / dim**0.5)
                )

    def forward(
        self,
        hid: torch.FloatTensor, # b ... c
        emb: torch.FloatTensor, # b d
        layer: str,
        mode: str,
        cache: "Optional[Cache]" = None,
        branch_tag: str = "",
        hid_len: Optional[torch.LongTensor] = None, # b
    ) -> torch.FloatTensor:
        """Modulate ``hid`` in place and return it.

        Args:
            hid: tensor to modulate; it is mutated by ``mul_`` / ``add_``.
            emb: conditioning embedding, split into (shift, scale, gate) for
                each entry of ``self.layers``.
            layer: which entry of ``self.layers`` to use.
            mode: ``"in"`` applies ``hid * (scale) + (shift)``; ``"out"``
                applies ``hid * (gate)``.
            cache: memoizes the ``repeat_interleave`` result under a key built
                from the layer index and ``branch_tag``.  ``None`` (default)
                builds a ``Cache(disable=True)`` for this call instead of
                sharing one instance created at definition time.
            branch_tag: disambiguates the cache key between branches.
            hid_len: when given, ``emb`` rows are repeated this many times
                along dim 0 before use.

        Raises:
            ValueError: if ``layer`` is not in ``self.layers`` (``list.index``).
            NotImplementedError: if ``mode`` is neither ``"in"`` nor ``"out"``.
        """
        idx = self.layers.index(layer)
        emb = rearrange(emb, "b (d l g) -> b d l g", l=len(self.layers), g=3)[..., idx, :]
        emb = expand_dims(emb, 1, hid.ndim + 1)

        if hid_len is not None:
            if cache is None:
                cache = Cache(disable=True)
            # The lambda is evaluated inside ``cache(...)`` and therefore
            # still sees the un-repeated ``emb``.
            emb = cache(
                f"emb_repeat_{idx}_{branch_tag}",
                lambda: torch.repeat_interleave(emb, hid_len, dim=0),
            )

        shiftA, scaleA, gateA = emb.unbind(-1)
        shiftB, scaleB, gateB = (
            getattr(self, f"{layer}_shift", None),
            getattr(self, f"{layer}_scale", None),
            getattr(self, f"{layer}_gate", None),
        )

        # Parameters stored in one of the ``_torch_float8_types()`` dtypes
        # are cast to ``hid``'s dtype before being combined with ``emb``.
        fp8_types = _torch_float8_types()
        if fp8_types:
            target_dtype = hid.dtype

            if shiftB is not None and shiftB.dtype in fp8_types:
                shiftB = shiftB.to(target_dtype)
            if scaleB is not None and scaleB.dtype in fp8_types:
                scaleB = scaleB.to(target_dtype)
            if gateB is not None and gateB.dtype in fp8_types:
                gateB = gateB.to(target_dtype)

        if mode == "in":
            return hid.mul_(scaleA + scaleB).add_(shiftA + shiftB)
        if mode == "out":
            if gateB is not None:
                return hid.mul_(gateA + gateB)
            return hid.mul_(gateA)

        raise NotImplementedError


def emb_add(emb1: torch.Tensor, emb2: Optional[torch.Tensor]):
    """Return ``emb1 + emb2``, or ``emb1`` unchanged when ``emb2`` is ``None``."""
    if emb2 is None:
        return emb1
    return emb1 + emb2
def flatten(
    hid: List[torch.FloatTensor],  # List of (*** c)
) -> Tuple[
    torch.FloatTensor,  # (L c)
    torch.LongTensor,  # (b n)
]:
    """Concatenate variable-shape samples into one ``(L, c)`` tensor.

    Every sample has its leading dimensions merged into one and the results
    are concatenated along dim 0.  The leading shapes are returned alongside
    so that :func:`unflatten` can undo the operation.

    Args:
        hid: non-empty list of tensors whose last dimension is the channel
            dimension; all samples must have the same number of dimensions.

    Returns:
        ``(flat, shape)`` where ``shape[i]`` is ``hid[i].shape[:-1]``, placed
        on the device of ``hid[0]``.
    """
    assert len(hid) > 0
    # One tensor construction for the whole shape table instead of one tiny
    # tensor per sample followed by ``torch.stack``.
    shape = torch.tensor([list(x.shape[:-1]) for x in hid], device=hid[0].device)
    hid = torch.cat([x.flatten(0, -2) for x in hid])
    return hid, shape


def unflatten(
    hid: torch.FloatTensor,  # (L c) or (L ... c)
    hid_shape: torch.LongTensor,  # (b n)
) -> List[torch.Tensor]:  # List of (*** c) or (*** ... c)
    """Split a flattened tensor back into per-sample tensors.

    Args:
        hid: tensor whose dim 0 is the concatenation of all samples.
        hid_shape: one row per sample holding that sample's leading shape.

    Returns:
        A list with one tensor per row of ``hid_shape``; dim 0 of each piece
        is expanded to the sizes stored in that row.
    """
    # Two bulk ``tolist`` conversions rather than one per sample.
    lengths = hid_shape.prod(-1).tolist()
    shapes = hid_shape.tolist()
    return [piece.unflatten(0, sizes) for piece, sizes in zip(hid.split(lengths), shapes)]


def repeat(
    hid: torch.FloatTensor,  # (L c)
    hid_shape: torch.LongTensor,  # (b n)
    pattern: str,
    **kwargs: torch.LongTensor,  # (b)
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """Apply ``einops.repeat`` to every sample of a flattened batch.

    Args:
        hid: flattened samples, as produced by :func:`flatten`.
        hid_shape: per-sample leading shapes.
        pattern: einops pattern applied to each sample.
        **kwargs: axis lengths for the pattern; each value is indexed by
            sample position and converted with ``.item()``.

    Returns:
        The re-flattened tensor and its new shape table.
    """
    samples = unflatten(hid, hid_shape)
    axis_lengths = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(samples))]
    return flatten([einops.repeat(h, pattern, **a) for h, a in zip(samples, axis_lengths)])
The + # SeedVR2Conditioning fail-loud guard at + # ``comfy_extras/nodes_seedvr.py`` distinguishes "buffer was loaded" + # from "buffer was never populated by the file" by checking + # ``positive_conditioning.abs().sum() == 0``. That sentinel is only + # reliable if the post-construction buffer state is deterministically + # zero, so explicitly zero-fill here rather than relying on the + # allocator's zero-on-alloc behavior (allocator-dependent and not + # contractual). When ``load_state_dict`` populates these buffers + # from a properly-baked SeedVR2 .safetensors, the in-place copy + # overwrites the zeros with the universal SeedVR2 conditioning + # tensors (shape (58, 5120) and (64, 5120) bf16). + self.register_buffer("positive_conditioning", torch.zeros((58, 5120), device=device, dtype=dtype)) + self.register_buffer("negative_conditioning", torch.zeros((64, 5120), device=device, dtype=dtype)) + self.vid_in = NaPatchIn( + in_channels=vid_in_channels, + patch_size=patch_size, + dim=vid_dim, + device=device, dtype=dtype, operations=operations + ) + self.txt_in = ( + operations.Linear(txt_in_dim, txt_dim, **factory_kwargs) + if txt_in_dim and txt_in_dim != txt_dim + else nn.Identity() + ) + self.emb_in = TimeEmbedding( + sinusoidal_dim=BYTEDANCE_SINUSOIDAL_DIM, + hidden_dim=max(vid_dim, txt_dim), + output_dim=emb_dim, + device=device, dtype=dtype, operations=operations + ) + + if window is None or isinstance(window[0], int): + window = [window] * num_layers + if window_method is None or isinstance(window_method, str): + window_method = [window_method] * num_layers + if temporal_window_size is None or isinstance(temporal_window_size, int): + temporal_window_size = [temporal_window_size] * num_layers + if temporal_shifted is None or isinstance(temporal_shifted, bool): + temporal_shifted = [temporal_shifted] * num_layers + + rope_dim = rope_dim if rope_dim is not None else head_dim // 2 + self.blocks = nn.ModuleList( + [ + NaMMSRTransformerBlock( + vid_dim=vid_dim, + 
txt_dim=txt_dim, + emb_dim=emb_dim, + heads=heads, + head_dim=head_dim, + expand_ratio=expand_ratio, + norm=norm, + norm_eps=norm_eps, + ada=ada, + qk_bias=qk_bias, + qk_rope=qk_rope, + qk_norm=qk_norm, + shared_qkv=shared_qkv, + shared_mlp=shared_mlp, + mlp_type=mlp_type, + rope_dim = rope_dim, + window=window[i], + window_method=window_method[i], + temporal_window_size=temporal_window_size[i], + temporal_shifted=temporal_shifted[i], + is_last_layer=(i == num_layers - 1) and not self._7b_version, + rope_type = rope_type, + shared_weights=not ( + (i < mm_layers) if isinstance(mm_layers, int) else mm_layers[i] + ), + version = self._7b_version, + operations = operations, + **kwargs, + **factory_kwargs + ) + for i in range(num_layers) + ] + ) + self.vid_out = NaPatchOut( + out_channels=vid_out_channels, + patch_size=patch_size, + dim=vid_dim, + device=device, dtype=dtype, operations=operations + ) + + self.need_txt_repeat = block_type[0] in [ + "mmdit_stwin", + "mmdit_stwin_spatial", + "mmdit_stwin_3d_spatial", + ] + + self.vid_out_norm = None + if vid_out_norm is not None: + self.vid_out_norm = CustomRMSNorm( + normalized_shape=vid_dim, + eps=norm_eps, + elementwise_affine=True, + device=device, dtype=dtype + ) + self.vid_out_ada = ada( + dim=vid_dim, + emb_dim=emb_dim, + layers=["out"], + modes=["in"], + device=device, dtype=dtype + ) + + def _resolve_text_conditioning(self, context, cond_or_uncond=None): + if context is None or getattr(context, "numel", lambda: None)() == 0: + context = self.positive_conditioning + return flatten([context]) + if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): + if context.shape[0] == 1: + context = context.squeeze(0) + return flatten([context]) + return flatten(context.unbind(0)) + if context.shape[0] % 2 != 0: + raise ValueError(f"SeedVR2 expected an even text-conditioning batch, got shape {tuple(context.shape)}") + neg_cond, pos_cond = context.chunk(2, dim=0) + if pos_cond.shape[0] == 1: + pos_cond, neg_cond = 
pos_cond.squeeze(0), neg_cond.squeeze(0) + return flatten([pos_cond, neg_cond]) + return flatten((*pos_cond.unbind(0), *neg_cond.unbind(0))) + + @staticmethod + def _seedvr2_is_single_conditioning_branch(cond_or_uncond): + if cond_or_uncond is None or len(cond_or_uncond) == 0: + return False + first = cond_or_uncond[0] + return all(entry == first for entry in cond_or_uncond) + + def _swap_pos_neg_halves(self, out, cond_or_uncond=None): + if NaDiT._seedvr2_is_single_conditioning_branch(cond_or_uncond): + return out + # ``dim=0`` is explicit on both calls. The contract is "split + # the batch axis into two halves and swap them"; making the + # axis load-bearing in source guards against silent drift if a + # future refactor reorders tensor axes. + pos, neg = out.chunk(2, dim=0) + return torch.cat([neg, pos], dim=0) + + def forward( + self, + x, + timestep, + context, # l c + disable_cache: bool = False, # for test # TODO ? // gives an error when set to True + **kwargs + ): + transformer_options = kwargs.get("transformer_options", {}) + patches_replace = transformer_options.get("patches_replace", {}) + blocks_replace = patches_replace.get("dit", {}) + conditions = kwargs.get("condition") + b, tc, h, w = x.shape + x = x.view(b, 16, -1, h, w) + conditions = conditions.view(b, 17, -1, h, w) + x = x.movedim(1, -1) + conditions = conditions.movedim(1, -1) + cache = Cache(disable=disable_cache) + + txt, txt_shape = self._resolve_text_conditioning(context, transformer_options.get("cond_or_uncond")) + + vid, vid_shape = flatten(x) + cond_latent, _ = flatten(conditions) + + vid = torch.cat([vid, cond_latent], dim=-1) + if txt_shape.size(-1) == 1 and self.need_txt_repeat: + txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0]) + + txt = self.txt_in(txt) + + vid_shape_before_patchify = vid_shape + vid, vid_shape = self.vid_in(vid, vid_shape, cache=cache) + + emb = self.emb_in(timestep, device=vid.device, dtype=vid.dtype) + + for i, block in 
enumerate(self.blocks): + if ("block", i) in blocks_replace: + def block_wrap(args): + out = {} + out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] = block( + vid=args["vid"], + txt=args["txt"], + vid_shape=args["vid_shape"], + txt_shape=args["txt_shape"], + emb=args["emb"], + cache=args["cache"], + ) + return out + out = blocks_replace[("block", i)]({ + "vid":vid, + "txt":txt, + "vid_shape":vid_shape, + "txt_shape":txt_shape, + "emb":emb, + "cache":cache, + }, {"original_block": block_wrap}) + vid, txt, vid_shape, txt_shape = out["vid"], out["txt"], out["vid_shape"], out["txt_shape"] + else: + vid, txt, vid_shape, txt_shape = block( + vid=vid, + txt=txt, + vid_shape=vid_shape, + txt_shape=txt_shape, + emb=emb, + cache=cache, + ) + + if self.vid_out_norm: + vid = self.vid_out_norm(vid) + vid = self.vid_out_ada( + vid, + emb=emb, + layer="out", + mode="in", + hid_len=cache("vid_len", lambda: vid_shape.prod(-1)), + cache=cache, + branch_tag="vid", + ) + + vid, vid_shape = self.vid_out(vid, vid_shape, cache, vid_shape_before_patchify = vid_shape_before_patchify) + vid = unflatten(vid, vid_shape) + out = torch.stack(vid) + out = out.movedim(-1, 1) + out = rearrange(out, "b c t h w -> b (c t) h w") + return self._swap_pos_neg_halves(out, transformer_options.get("cond_or_uncond")) diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py new file mode 100644 index 000000000..68b11c0ff --- /dev/null +++ b/comfy/ldm/seedvr/vae.py @@ -0,0 +1,2110 @@ +from contextlib import nullcontext +from typing import Literal, Optional, Tuple +import gc +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch import Tensor +from contextlib import contextmanager +from comfy.utils import ProgressBar + +from comfy.ldm.seedvr.model import safe_pad_operation +from comfy.ldm.seedvr.constants import ( + BYTEDANCE_BLOCK_OUT_CHANNELS, + BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD, + BYTEDANCE_GN_CHUNKS_FP16, + BYTEDANCE_GN_CHUNKS_FP32, 
+ BYTEDANCE_LOGVAR_CLAMP_MAX, + BYTEDANCE_LOGVAR_CLAMP_MIN, + BYTEDANCE_SLICING_SAMPLE_MIN, + BYTEDANCE_VAE_CONV_MEM_GIB, + BYTEDANCE_VAE_NORM_MEM_GIB, + BYTEDANCE_VAE_SCALING_FACTOR, + BYTEDANCE_VAE_SHIFTING_FACTOR, + BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE, + BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE, + SEEDVR2_LATENT_CHANNELS, +) +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.modules.diffusionmodules.model import vae_attention + +import math +from enum import Enum +from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND + +import logging +import comfy.model_management +import comfy.ops +ops = comfy.ops.disable_weight_init + + +def _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, temporal_scale=1): + if temporal_size is None: + return None + + temporal_size = int(temporal_size) + if temporal_size <= 0: + return 0 + + temporal_overlap = max(0, int(temporal_overlap or 0)) + temporal_overlap = min(temporal_overlap, temporal_size - 1) + temporal_step = temporal_size - temporal_overlap + temporal_scale = max(1, int(temporal_scale)) + return max(1, math.ceil(temporal_step / temporal_scale)) + + +def _seedvr2_clamped_spatial_overlap(overlap, tile_size): + overlap = max(0, int(overlap)) + tile_size = max(1, int(tile_size)) + return min(overlap, tile_size - 1) + + +def _seedvr2_clear_temporal_memory(model): + for module in model.modules(): + if hasattr(module, "memory"): + module.memory = None + + +@torch.inference_mode() +def tiled_vae( + x, + vae_model, + tile_size=(512, 512), + tile_overlap=(64, 64), + temporal_size=16, + temporal_overlap=0, + encode=True, + **kwargs, +): + gc.collect() + comfy.model_management.soft_empty_cache() + + x = x.to(next(vae_model.parameters()).dtype) + if x.ndim != 5: + x = x.unsqueeze(2) + + _, _, d, h, w = x.shape + + sf_s = getattr(vae_model, "spatial_downsample_factor", BYTEDANCE_VAE_SPATIAL_DOWNSAMPLE) + sf_t = getattr(vae_model, "temporal_downsample_factor", BYTEDANCE_VAE_TEMPORAL_DOWNSAMPLE) + if 
encode: + slicing_attr = "slicing_sample_min_size" + slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap) + else: + slicing_attr = "slicing_latent_min_size" + slicing_min_size = _seedvr2_temporal_slicing_min_size(temporal_size, temporal_overlap, sf_t) + if encode: + ti_h, ti_w = tile_size + ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0], ti_h) + ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1], ti_w) + blend_ov_h = max(0, ov_h // sf_s) + blend_ov_w = max(0, ov_w // sf_s) + target_d = (d + sf_t - 1) // sf_t + target_h = (h + sf_s - 1) // sf_s + target_w = (w + sf_s - 1) // sf_s + else: + ti_h = max(1, tile_size[0] // sf_s) + ti_w = max(1, tile_size[1] // sf_s) + ov_h = _seedvr2_clamped_spatial_overlap(tile_overlap[0] // sf_s, ti_h) + ov_w = _seedvr2_clamped_spatial_overlap(tile_overlap[1] // sf_s, ti_w) + blend_ov_h = ov_h * sf_s + blend_ov_w = ov_w * sf_s + + target_d = max(1, d * sf_t - (sf_t - 1)) + target_h = h * sf_s + target_w = w * sf_s + + stride_h = max(1, ti_h - ov_h) + stride_w = max(1, ti_w - ov_w) + + storage_device = vae_model.device + result = None + count = None + def run_temporal_chunks(spatial_tile, model=vae_model, device=storage_device): + device = torch.device(device) + _seedvr2_clear_temporal_memory(model) + t_chunk = spatial_tile.to(device=device, dtype=next(model.parameters()).dtype, non_blocking=True).contiguous() + old_device = getattr(model, "device", None) + model.device = device + old_slicing_min_size = getattr(model, slicing_attr, None) + if old_slicing_min_size is not None and slicing_min_size is not None: + if slicing_min_size <= 0: + setattr(model, slicing_attr, t_chunk.shape[2]) + else: + setattr(model, slicing_attr, slicing_min_size) + try: + if encode: + out = model.encode(t_chunk)[0] + else: + out = model.decode_(t_chunk) + finally: + if old_slicing_min_size is not None and slicing_min_size is not None: + setattr(model, slicing_attr, old_slicing_min_size) + if old_device is not 
None: + model.device = old_device + if isinstance(out, (tuple, list)): + out = out[0] + if out.ndim == 4: + out = out.unsqueeze(2) + return out.to(storage_device) + + ramp_cache = {} + def get_ramp(steps): + if steps not in ramp_cache: + t = torch.linspace(0, 1, steps=steps, device=storage_device, dtype=torch.float32) + ramp_cache[steps] = 0.5 - 0.5 * torch.cos(t * torch.pi) + return ramp_cache[steps] + + tile_ranges = [] + for y_idx in range(0, h, stride_h): + y_end = min(y_idx + ti_h, h) + if y_idx > 0 and (y_end - y_idx) <= ov_h: + continue + for x_idx in range(0, w, stride_w): + x_end = min(x_idx + ti_w, w) + if x_idx > 0 and (x_end - x_idx) <= ov_w: + continue + tile_ranges.append((y_idx, y_end, x_idx, x_end)) + + total_tiles = len(tile_ranges) + bar = ProgressBar(total_tiles) + single_spatial_tile = h <= ti_h and w <= ti_w + + _seedvr2_clear_temporal_memory(vae_model) + + def run_tile(tile_index, tile_range): + y_idx, y_end, x_idx, x_end = tile_range + tile_x = x[:, :, :, y_idx:y_end, x_idx:x_end] + tile_out = run_temporal_chunks(tile_x) + return tile_index, y_idx, y_end, x_idx, x_end, tile_out + + ordered_tile_outputs = ( + run_tile(tile_index, tile_range) + for tile_index, tile_range in enumerate(tile_ranges) + ) + + for _, y_idx, y_end, x_idx, x_end, tile_out in ordered_tile_outputs: + + if single_spatial_tile: + result = tile_out[:, :, :target_d, :target_h, :target_w] + if result.device != x.device: + result = result.to(x.device).to(x.dtype) + if x.shape[2] == 1 and sf_t == 1: + result = result.squeeze(2) + bar.update(1) + return result + + if result is None: + b_out, c_out = tile_out.shape[0], tile_out.shape[1] + result = torch.zeros((b_out, c_out, target_d, target_h, target_w), device=storage_device, dtype=torch.float32) + count = torch.zeros((1, 1, 1, target_h, target_w), device=storage_device, dtype=torch.float32) + + if encode: + ys, ye = y_idx // sf_s, (y_idx // sf_s) + tile_out.shape[3] + xs, xe = x_idx // sf_s, (x_idx // sf_s) + tile_out.shape[4] 
# Module-wide limit read through get_norm_limit(); infinity is the unset value.
_NORM_LIMIT = float("inf")


def get_norm_limit():
    """Return the current module-wide norm limit."""
    return _NORM_LIMIT


def set_norm_limit(value: Optional[float] = None):
    """Set the module-wide norm limit; ``None`` resets it to infinity."""
    global _NORM_LIMIT
    _NORM_LIMIT = float("inf") if value is None else value


@contextmanager
def ignore_padding(model):
    """Temporarily set ``model.padding`` to ``(0, 0, 0)``.

    The previous value is restored on exit, including when the body raises.
    """
    saved_padding = model.padding
    model.padding = (0, 0, 0)
    try:
        yield
    finally:
        model.padding = saved_padding


class MemoryState(Enum):
    """Enumeration of memory states, with values 0 through 3."""

    DISABLED = 0
    INITIALIZING = 1
    ACTIVE = 2
    UNSET = 3
pad_len - dilated_kernerl_size) // conv_module.stride[dim] + 1 + remain_len = ( + input_len + pad_len - ((output_len - 1) * conv_module.stride[dim] + dilated_kernerl_size) + ) + overlap_len = dilated_kernerl_size - conv_module.stride[dim] + cache_len = overlap_len + remain_len # >= 0 + + assert output_len > 0 + return cache_len + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, BYTEDANCE_LOGVAR_CLAMP_MIN, BYTEDANCE_LOGVAR_CLAMP_MAX) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like( + self.mean, device=self.parameters.device, dtype=self.parameters.dtype + ) + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + sample = torch.randn( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def mode(self): + return self.mean + +class SpatialNorm(nn.Module): + def __init__( + self, + f_channels: int, + zq_channels: int, + ): + super().__init__() + self.norm_layer = ops.GroupNorm(num_channels=f_channels, num_groups=32, eps=1e-6, affine=True) + self.conv_y = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + self.conv_b = ops.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor: + f_size = f.shape[-2:] + zq = F.interpolate(zq, size=f_size, mode="nearest") + norm_f = self.norm_layer(f) + new_f = norm_f * self.conv_y(zq) + self.conv_b(zq) + return new_f + +# partial implementation of diffusers's Attention for comfyui +class Attention(nn.Module): + def __init__( + self, + query_dim: int, + 
cross_attention_dim: Optional[int] = None, + heads: int = 8, + kv_heads: Optional[int] = None, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + added_proj_bias: Optional[bool] = True, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + out_dim: int = None, + out_context_dim: int = None, + context_pre_only=None, + pre_only=False, + is_causal: bool = False, + ): + super().__init__() + + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim + self.context_pre_only = context_pre_only + self.pre_only = pre_only + self.is_causal = is_causal + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + self.heads = out_dim // dim_head if out_dim is not None else 
heads + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if norm_num_groups is not None: + self.group_norm = ops.GroupNorm(num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm(f_channels=query_dim, zq_channels=spatial_norm_dim) + else: + self.spatial_norm = None + + self.norm_q = None + self.norm_k = None + + self.norm_cross = None + self.to_q = ops.Linear(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + self.to_v = ops.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + self.added_proj_bias = added_proj_bias + if self.added_kv_proj_dim is not None: + self.add_k_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + self.add_v_proj = ops.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias) + if self.context_pre_only is not None: + self.add_q_proj = ops.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias) + else: + self.add_q_proj = None + self.add_k_proj = None + self.add_v_proj = None + + if not self.pre_only: + self.to_out = nn.ModuleList([]) + self.to_out.append(ops.Linear(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + else: + self.to_out = None + + if self.context_pre_only is not None and not self.context_pre_only: + self.to_add_out = ops.Linear(self.inner_dim, self.out_context_dim, bias=out_bias) + else: + self.to_add_out = None + + self.norm_added_q = None + self.norm_added_k = None + self.optimized_vae_attention = vae_attention() + + def __call__( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, 
+ attention_mask: Optional[torch.Tensor] = None, + temb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + + residual = hidden_states + if self.spatial_norm is not None: + hidden_states = self.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = self.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attention_mask.view(batch_size, self.heads, -1, attention_mask.shape[-1]) + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif self.norm_cross: + encoder_hidden_states = self.norm_encoder_hidden_states(encoder_hidden_states) + + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // self.heads + + query = query.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, self.heads, head_dim).transpose(1, 2) + + if self.norm_q is not None: + query = self.norm_q(query) + if self.norm_k is not None: + key = self.norm_k(key) + + if input_ndim == 4 and encoder_hidden_states is hidden_states and attention_mask is None and self.heads == 1: + query = query.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) + key = key.squeeze(1).transpose(1, 2).reshape(batch_size, head_dim, height, width) + value = value.squeeze(1).transpose(1, 2).reshape(batch_size, 
@torch.no_grad()
def inflate_weight(weight_2d: torch.Tensor, weight_3d: torch.Tensor):
    """Fill a 3D convolution weight in place from a 2D one.

    The 2D weight gets a new axis at position 2, is repeated ``depth`` times
    along it (``depth`` being ``weight_3d.size(2)``) and divided by ``depth``,
    so summing the result over that axis gives back ``weight_2d``.
    Runs without gradient tracking.

    Args:
        weight_2d: Source weight; the caller in this file only passes 4-D tensors.
        weight_3d: Destination tensor, overwritten in place.

    Returns:
        ``weight_3d`` itself, after the copy.
    """
    depth = weight_3d.size(2)
    stacked = weight_2d.unsqueeze(2).repeat(1, 1, depth, 1, 1)
    weight_3d.copy_(stacked / depth)
    return weight_3d
ops.RMSNorm)): + if x.ndim == 4: + x = rearrange(x, "b c h w -> b h w c") + x = norm_layer(x) + x = rearrange(x, "b h w c -> b c h w") + return x.to(input_dtype) + if x.ndim == 5: + x = rearrange(x, "b c t h w -> b t h w c") + x = norm_layer(x) + x = rearrange(x, "b t h w c -> b c t h w") + return x.to(input_dtype) + if isinstance(norm_layer, (ops.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)): + if x.ndim <= 4: + return norm_layer(x).to(input_dtype) + if x.ndim == 5: + t = x.size(2) + x = rearrange(x, "b c t h w -> (b t) c h w") + memory_occupy = x.numel() * x.element_size() / 1024**3 + if isinstance(norm_layer, ops.GroupNorm) and memory_occupy > get_norm_limit(): + num_chunks = min(BYTEDANCE_GN_CHUNKS_FP16 if x.element_size() == 2 else BYTEDANCE_GN_CHUNKS_FP32, norm_layer.num_groups) + assert norm_layer.num_groups % num_chunks == 0 + num_groups_per_chunk = norm_layer.num_groups // num_chunks + + x = list(x.chunk(num_chunks, dim=1)) + weights = norm_layer.weight.chunk(num_chunks, dim=0) + biases = norm_layer.bias.chunk(num_chunks, dim=0) + for i, (w, b) in enumerate(zip(weights, biases)): + x[i] = F.group_norm(x[i], num_groups_per_chunk, w, b, norm_layer.eps) + x[i] = x[i].to(input_dtype) + x = torch.cat(x, dim=1) + else: + x = norm_layer(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + return x.to(input_dtype) + raise NotImplementedError + +def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest', align_corners=None, recompute_scale_factor=None): + problematic_modes = ['bilinear', 'bicubic', 'trilinear'] + + if mode in problematic_modes: + try: + return F.interpolate( + x, + size=size, + scale_factor=scale_factor, + mode=mode, + align_corners=align_corners, + recompute_scale_factor=recompute_scale_factor + ) + except RuntimeError as e: + if ("not implemented for 'Half'" in str(e) or + "compute_indices_weights" in str(e)): + original_dtype = x.dtype + return F.interpolate( + x.float(), + size=size, + scale_factor=scale_factor, + 
def _conv_forward(self, input, weight, bias, *args, **kwargs):
    """Run the convolution, trying cuDNN first for half-precision weights.

    Order of attempts:
      1. A direct ``torch.cudnn_convolution`` call when the workaround flag is
         set, the weight is fp16/bf16 and cuDNN is available and enabled.
      2. The parent class implementation.
      3. A plain ``F.conv3d`` call if the parent raises ``NotImplementedError``.

    Args:
        input: Input tensor for the convolution.
        weight: Convolution weight.
        bias: Optional bias; added manually on the cuDNN path.
        *args, **kwargs: Forwarded to the parent ``_conv_forward`` only.

    Returns:
        The convolution output.
    """
    if (NVIDIA_MEMORY_CONV_BUG_WORKAROUND and
            weight.dtype in (torch.float16, torch.bfloat16) and
            hasattr(torch.backends.cudnn, 'is_available') and
            torch.backends.cudnn.is_available() and
            getattr(torch.backends.cudnn, 'enabled', True)):
        try:
            out = torch.cudnn_convolution(
                input, weight, self.padding, self.stride, self.dilation, self.groups,
                benchmark=False, deterministic=False, allow_tf32=True
            )
            if bias is not None:
                # cudnn_convolution takes no bias; broadcast it over all
                # non-channel dimensions and add in place.
                out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
            return out
        except RuntimeError:
            # NotImplementedError is a subclass of RuntimeError, so this one
            # handler also covers "not implemented for this backend" failures.
            # Fall through to the regular path.
            pass
    try:
        return super()._conv_forward(input, weight, bias, *args, **kwargs)
    except NotImplementedError:
        # for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend
        if not self.logged_once:
            logging.warning("VAE is on CPU for decoding. This is most likely due to not enough memory")
            self.logged_once = True
        # F.conv3d defaults to stride 1, no padding, no dilation and a single
        # group, so the module's own hyperparameters must be passed explicitly
        # or strided/padded layers would produce wrongly shaped outputs.
        # NOTE(review): extra *args/**kwargs are assumed to be options of the
        # parent implementation and are not forwarded to F.conv3d - confirm.
        return F.conv3d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
+ shape = torch.tensor(x.size()) + if prev_cache is not None: + shape[split_dim - 1] += prev_cache.size(split_dim - 1) + shape[-3:] += torch.tensor(padding).view(3, 2).sum(-1).flip(0) + memory_occupy = shape.prod() * x.element_size() / 1024**3 # GiB + if memory_occupy < self.memory_limit or split_dim == x.ndim: + x_concat = x + if prev_cache is not None: + x_concat = torch.cat([prev_cache, x], dim=split_dim - 1) + + def pad_and_forward(): + padded = safe_pad_operation(x_concat, padding, mode='constant', value=0.0) + if not padded.is_contiguous(): + padded = padded.contiguous() + with ignore_padding(self): + return torch.nn.Conv3d.forward(self, padded) + + return pad_and_forward() + + num_splits = math.ceil(memory_occupy / self.memory_limit) + size_per_split = x.size(split_dim) // num_splits + split_sizes = [size_per_split] * (num_splits - 1) + split_sizes += [x.size(split_dim) - sum(split_sizes)] + + x = list(x.split(split_sizes, dim=split_dim)) + if prev_cache is not None: + prev_cache = list(prev_cache.split(split_sizes, dim=split_dim)) + cache = None + for idx in range(len(x)): + if prev_cache is not None: + x[idx] = torch.cat([prev_cache[idx], x[idx]], dim=split_dim - 1) + + lpad_dim = (x[idx].ndim - split_dim - 1) * 2 + rpad_dim = lpad_dim + 1 + padding = list(padding) + padding[lpad_dim] = self.padding[split_dim - 2] if idx == 0 else 0 + padding[rpad_dim] = self.padding[split_dim - 2] if idx == len(x) - 1 else 0 + pad_len = padding[lpad_dim] + padding[rpad_dim] + padding = tuple(padding) + + next_cache = None + cache_len = cache.size(split_dim) if cache is not None else 0 + next_catch_size = get_cache_size( + conv_module=self, + input_len=x[idx].size(split_dim) + cache_len, + pad_len=pad_len, + dim=split_dim - 2, + ) + if next_catch_size != 0: + assert next_catch_size <= x[idx].size(split_dim) + next_cache = ( + x[idx].transpose(0, split_dim)[-next_catch_size:].transpose(0, split_dim) + ) + + x[idx] = self.memory_limit_conv( + x[idx], + split_dim=split_dim + 
def forward(self, input, memory_state: MemoryState = MemoryState.UNSET) -> Tensor:
    """Dispatch to the plain or the sliced causal convolution path.

    Args:
        input: A tensor, or any other input (e.g. a list of tensors), which
            is always handled by ``slicing_forward``.
        memory_state: Must be given explicitly; ``MemoryState.UNSET`` is
            rejected.

    Returns:
        Whatever the selected path returns.

    Raises:
        AssertionError: If ``memory_state`` is ``MemoryState.UNSET``.
    """
    assert memory_state != MemoryState.UNSET

    if memory_state != MemoryState.ACTIVE:
        # Only the ACTIVE state carries temporal memory over from a previous call.
        self.memory = None

    # The plain path is used only for a single tensor with no memory limit.
    needs_slicing = not math.isinf(self.memory_limit) or not torch.is_tensor(input)
    if needs_slicing:
        return self.slicing_forward(input, memory_state)
    return self.basic_forward(input, memory_state)
def remove_head(tensor: Tensor, times: int = 1) -> Tensor:
    """Drop the ``times`` entries that directly follow index 0 along dim 2.

    The first entry along dim 2 is kept, the next ``times`` entries are
    removed and the rest is appended unchanged.

    Args:
        tensor: Tensor with at least three dimensions.
        times: Number of entries to remove after the first one. With ``0``
            the input object itself is returned.

    Returns:
        The shortened tensor (a new tensor unless ``times`` is 0).
    """
    if times == 0:
        return tensor
    leading = tensor[:, :, :1]
    remainder = tensor[:, :, times + 1:]
    return torch.cat((leading, remainder), dim=2)
def forward(
    self,
    hidden_states: torch.FloatTensor,
    memory_state=None,
    **kwargs,
) -> torch.FloatTensor:
    """Upsample by a 1x1x1 channel expansion followed by a rearrange.

    Steps: optional causal norm, channel expansion through ``upscale_conv``,
    folding the extra channels into the time/height/width axes, removal of
    the duplicated head frame, and the optional trailing convolution. With
    ``slicing`` enabled the input is processed as two chunks split along
    dim 2 and concatenated again at the end.

    Args:
        hidden_states: Input whose dim 1 must equal ``self.channels``.
        memory_state: Passed on to the trailing convolution; also decides
            whether the head frame is removed after temporal upsampling.
        **kwargs: Accepted and ignored.

    Returns:
        The upsampled tensor.
    """
    assert hidden_states.shape[1] == self.channels

    norm = getattr(self, "norm", None)
    if norm is not None:
        # [Overridden] change to causal norm.
        hidden_states = causal_norm_wrapper(norm, hidden_states)

    if self.use_conv_transpose:
        return self.conv(hidden_states)

    if self.slicing:
        total_len = hidden_states.size(2)
        first_len = total_len // 2
        chunks = list(hidden_states.split([first_len, total_len - first_len], dim=2))
    else:
        chunks = [hidden_states]

    # Overwrite the list entries one by one so each input chunk is released
    # as soon as its expanded version exists.
    for idx in range(len(chunks)):
        chunks[idx] = self.upscale_conv(chunks[idx])
        chunks[idx] = rearrange(
            chunks[idx],
            "b (x y z c) f h w -> b c (f z) (h x) (w y)",
            x=self.spatial_ratio,
            y=self.spatial_ratio,
            z=self.temporal_ratio,
        )

    if self.temporal_up and memory_state != MemoryState.ACTIVE:
        chunks[0] = remove_head(chunks[0])

    # Without slicing the convolution receives a tensor, otherwise the list.
    output = chunks if self.slicing else chunks[0]

    if self.use_conv:
        conv = self.conv if self.name == "conv" else self.Conv2d_0
        output = conv(output, memory_state=memory_state)

    return torch.cat(output, dim=2) if self.slicing else output
class ResnetBlock3D(nn.Module):
    """Residual block made of two causal 3D convolutions.

    Layout: GroupNorm -> SiLU -> (optional up/down sampling) -> conv1 ->
    (optional time embedding) -> GroupNorm -> SiLU -> dropout -> conv2, added
    to the (optionally projected) input and divided by ``output_scale_factor``.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels produced by ``conv1``; defaults to ``in_channels``.
        temb_channels: Width of the time embedding, or ``None`` to build no
            time-embedding projection.
        groups / groups_out: Group counts of the first / second GroupNorm
            (``groups_out`` defaults to ``groups``).
        eps: Epsilon of both GroupNorm layers.
        dropout: Dropout probability applied before ``conv2``.
        output_scale_factor: Divisor applied to the residual sum.
        skip_time_act: Skip the SiLU in front of the time-embedding projection.
        up / down: Insert an ``Upsample3D`` / ``Downsample3D`` (``up`` wins).
        conv_2d_out_channels: Channels produced by ``conv2`` and by the
            shortcut; defaults to the resolved ``out_channels``.
        inflation_mode: Forwarded to every causal convolution.
        time_receptive_field: ``"half"`` gives ``conv1`` a (1, 3, 3) kernel,
            anything else a (3, 3, 3) kernel.
        slicing: Forwarded to the upsampler.

    ``conv_shortcut``, ``non_linearity``, ``time_embedding_norm``,
    ``use_in_shortcut``, ``conv_shortcut_bias`` and extra keyword arguments
    are accepted for signature compatibility but are not used.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        temb_channels: int = 512,
        groups: int = 32,
        groups_out: Optional[int] = None,
        eps: float = 1e-6,
        non_linearity: str = "swish",
        time_embedding_norm: str = "default",
        output_scale_factor: float = 1.0,
        skip_time_act: bool = False,
        use_in_shortcut: Optional[bool] = None,
        up: bool = False,
        down: bool = False,
        conv_shortcut_bias: bool = True,
        conv_2d_out_channels: Optional[int] = None,
        inflation_mode = "tail",
        time_receptive_field: _receptive_field_t = "half",
        slicing: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.up = up
        self.down = down
        self.in_channels = in_channels
        # Resolve the optional width once so that every layer below sees a
        # concrete channel count instead of None.
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        conv_2d_out_channels = conv_2d_out_channels or out_channels
        self.output_scale_factor = output_scale_factor
        self.skip_time_act = skip_time_act
        self.nonlinearity = nn.SiLU()
        if temb_channels is not None:
            self.time_emb_proj = ops.Linear(temb_channels, out_channels)
        else:
            self.time_emb_proj = None
        self.norm1 = ops.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        if groups_out is None:
            groups_out = groups
        self.norm2 = ops.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
        # The shortcut is derived from the channel counts alone; the
        # `use_in_shortcut` argument is not consulted.
        self.use_in_shortcut = self.in_channels != out_channels
        self.dropout = torch.nn.Dropout(dropout)

        half_field = time_receptive_field == "half"
        self.conv1 = InflatedCausalConv3d(
            self.in_channels,
            self.out_channels,
            kernel_size=(1, 3, 3) if half_field else (3, 3, 3),
            stride=1,
            padding=(0, 1, 1) if half_field else (1, 1, 1),
            inflation_mode=inflation_mode,
        )

        self.conv2 = InflatedCausalConv3d(
            self.out_channels,
            conv_2d_out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.upsample = self.downsample = None
        if self.up:
            self.upsample = Upsample3D(
                self.in_channels,
                use_conv=False,
                inflation_mode=inflation_mode,
                slicing=slicing,
            )
        elif self.down:
            self.downsample = Downsample3D(
                self.in_channels,
                use_conv=False,
                padding=1,
                name="op",
                inflation_mode=inflation_mode,
            )

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = InflatedCausalConv3d(
                self.in_channels,
                conv_2d_out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True,
                inflation_mode=inflation_mode,
            )

    def forward(
        self, input_tensor, temb, memory_state = None, **kwargs
    ):
        """Apply the residual block.

        Args:
            input_tensor: Block input; also the residual branch.
            temb: Optional time embedding. ``None`` skips the embedding even
                when a projection layer exists.
            memory_state: Forwarded to every causal convolution and sampler.
            **kwargs: Accepted and ignored.

        Returns:
            ``(shortcut(input) + residual) / output_scale_factor``.
        """
        hidden_states = input_tensor

        hidden_states = causal_norm_wrapper(self.norm1, hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            if hidden_states.shape[0] >= BYTEDANCE_CONTIGUOUS_BATCH_THRESHOLD:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor, memory_state=memory_state)
            hidden_states = self.upsample(hidden_states, memory_state=memory_state)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor, memory_state=memory_state)
            hidden_states = self.downsample(hidden_states, memory_state=memory_state)

        hidden_states = self.conv1(hidden_states, memory_state=memory_state)

        # Project the embedding only when one was actually supplied; the
        # blocks in this file call the resnets with temb=None.
        if self.time_emb_proj is not None and temb is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)
            # Append one singleton axis per non-(batch, channel) dimension so
            # the embedding broadcasts over 4-D and 5-D activations alike
            # (for 4-D input this equals [:, :, None, None]).
            temb = temb[(slice(None), slice(None)) + (None,) * (hidden_states.ndim - 2)]

        if temb is not None:
            hidden_states = hidden_states + temb

        hidden_states = causal_norm_wrapper(self.norm2, hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, memory_state=memory_state)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
class UpDecoderBlock3D(nn.Module):
    """Decoder stage: ``num_layers`` residual blocks plus an optional upsampler.

    The first residual block maps ``in_channels`` to ``out_channels``; the
    remaining ones keep ``out_channels``. ``temporal_modules`` holds one
    ``nn.Identity`` per residual block. When ``add_upsample`` is true a single
    learnable ``Upsample3D`` (``interpolate=False``) is appended, otherwise
    ``upsamplers`` is ``None``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",  # default, spatial
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor: float = 1.0,
        add_upsample: bool = True,
        temb_channels: Optional[int] = None,
        inflation_mode = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up: bool = True,
        spatial_up: bool = True,
        slicing: bool = False,
    ):
        super().__init__()

        def build_resnet(input_channels):
            # [Override] Replace module.
            return ResnetBlock3D(
                in_channels=input_channels,
                out_channels=out_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
                slicing=slicing,
            )

        # Only the first block sees the stage's input width.
        self.resnets = nn.ModuleList(
            [build_resnet(in_channels if layer == 0 else out_channels) for layer in range(num_layers)]
        )
        self.temporal_modules = nn.ModuleList([nn.Identity() for _ in range(num_layers)])

        # [Override] Replace module & use learnable upsample
        self.upsamplers = (
            nn.ModuleList(
                [
                    Upsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        temporal_up=temporal_up,
                        spatial_up=spatial_up,
                        interpolate=False,
                        inflation_mode=inflation_mode,
                        slicing=slicing,
                    )
                ]
            )
            if add_upsample
            else None
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: Optional[torch.FloatTensor] = None,
        memory_state=None
    ) -> torch.FloatTensor:
        """Run the residual blocks and then the upsampler, if any.

        ``temb`` is accepted but not forwarded: every residual block is
        called with ``temb=None``.
        """
        for resnet, temporal in zip(self.resnets, self.temporal_modules):
            hidden_states = temporal(
                resnet(hidden_states, temb=None, memory_state=memory_state)
            )

        if self.upsamplers is None:
            return hidden_states

        for upsampler in self.upsamplers:
            hidden_states = upsampler(hidden_states, memory_state=memory_state)
        return hidden_states
add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + # there is always at least one resnet + resnets = [ + # [Override] Replace module. + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ] + attentions = [] + + if attention_head_dim is None: + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=( + resnet_groups if resnet_time_scale_shift == "default" else None + ), + spatial_norm_dim=( + temb_channels if resnet_time_scale_shift == "spatial" else None + ), + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def 
class Encoder3D(nn.Module):
    """Causal 3D encoder: input conv -> down blocks -> mid block -> norm/act/output conv.

    One entry of ``conv_extra_cond`` is registered per down block. It is a
    zero-initialised 1x1x1 ``ops.Conv3d`` when ``extra_cond_dim`` is a positive
    int and ``None`` otherwise. The output conv emits ``2 * out_channels``
    channels when ``double_z`` is true.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        # [Override] add extra_cond_dim, temporal down num
        temporal_down_num: int = 2,
        extra_cond_dim: int = None,
        gradient_checkpoint: bool = False,
        inflation_mode = "tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.temporal_down_num = temporal_down_num

        self.conv_in = InflatedCausalConv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.down_blocks = nn.ModuleList([])
        self.extra_cond_dim = extra_cond_dim
        self.conv_extra_cond = nn.ModuleList([])

        def zeroed(module):
            # Zero every parameter of ``module`` in place and hand it back.
            for param in module.parameters():
                param.detach().zero_()
            return module

        uses_extra_cond = self.extra_cond_dim is not None and self.extra_cond_dim > 0
        num_stages = len(block_out_channels)
        # [Override] the last ``temporal_down_num + 1`` stages are flagged as temporal-down stages.
        first_temporal_stage = num_stages - self.temporal_down_num - 1

        # down
        output_channel = block_out_channels[0]
        for stage, stage_type in enumerate(down_block_types):
            input_channel, output_channel = output_channel, block_out_channels[stage]

            assert stage_type == "DownEncoderBlock3D"

            self.down_blocks.append(
                DownEncoderBlock3D(
                    num_layers=self.layers_per_block,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    # the final stage keeps its resolution
                    add_downsample=stage != num_stages - 1,
                    resnet_eps=1e-6,
                    downsample_padding=0,
                    # Note: Don't know why set it as 0
                    resnet_act_fn=act_fn,
                    resnet_groups=norm_num_groups,
                    temporal_down=stage >= first_temporal_stage,
                    spatial_down=True,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )

            if uses_extra_cond:
                cond_proj = zeroed(
                    ops.Conv3d(extra_cond_dim, output_channel, kernel_size=1, stride=1, padding=0)
                )
            else:
                cond_proj = None
            self.conv_extra_cond.append(cond_proj)

        # mid
        self.mid_block = UNetMidBlock3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # out
        self.conv_norm_out = ops.GroupNorm(
            num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6
        )
        self.conv_act = nn.SiLU()

        if double_z:
            conv_out_channels = 2 * out_channels
        else:
            conv_out_channels = out_channels
        self.conv_out = InflatedCausalConv3d(
            block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode
        )

        self.gradient_checkpointing = gradient_checkpoint

    def forward(
        self,
        sample: torch.FloatTensor,
        extra_cond=None,
        memory_state = None
    ) -> torch.FloatTensor:
        """Encode ``sample``; ``extra_cond`` feeds the per-stage conditioning convs when present.

        In training with gradient checkpointing enabled, the down blocks run
        under ``torch.utils.checkpoint`` and, as in the non-checkpointed path's
        counterpart, receive only ``sample`` (no ``memory_state``); the mid
        block is likewise called without ``memory_state`` on that path.
        """
        sample = sample.to(next(self.parameters()).device)
        sample = self.conv_in(sample, memory_state = memory_state)

        checkpointed = self.training and self.gradient_checkpointing

        # down
        # [Override] add extra block and extra cond
        for stage_block, cond_proj in zip(self.down_blocks, self.conv_extra_cond):
            if checkpointed:
                # The module is the checkpointed callable; it gets ``sample`` only.
                sample = torch.utils.checkpoint.checkpoint(
                    stage_block, sample, use_reentrant=False
                )
            else:
                sample = stage_block(sample, memory_state=memory_state)

            if cond_proj is not None:
                # Resize the projected condition to the current (T, H, W) before adding.
                sample = sample + safe_interpolate_operation(cond_proj(extra_cond), size=sample.shape[2:])

        # middle
        if checkpointed:
            sample = self.mid_block(sample)
        else:
            sample = self.mid_block(sample, memory_state=memory_state)

        # post-process
        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        return self.conv_out(sample, memory_state = memory_state)
= (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + mid_block_add_attention=True, + # [Override] add temporal up block + inflation_mode = "tail", + time_receptive_field: _receptive_field_t = "half", + temporal_up_num: int = 2, + slicing_up_num: int = 0, + gradient_checkpoint: bool = False, + ): + super().__init__() + self.layers_per_block = layers_per_block + self.temporal_up_num = temporal_up_num + + self.conv_in = InflatedCausalConv3d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + inflation_mode=inflation_mode, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_time_scale_shift="default" if norm_type == "group" else norm_type, + attention_head_dim=block_out_channels[-1], + resnet_groups=norm_num_groups, + temb_channels=temb_channels, + add_attention=mid_block_add_attention, + inflation_mode=inflation_mode, + time_receptive_field=time_receptive_field, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + is_temporal_up_block = i < self.temporal_up_num + is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num + # Note: Keep symmetric + + assert up_block_type == "UpDecoderBlock3D" + up_block = UpDecoderBlock3D( + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, 
class VideoAutoencoderKL(nn.Module):
    """KL video autoencoder built from ``Encoder3D`` / ``Decoder3D`` with temporal slicing.

    Long inputs are processed in temporal slices. The first slice (always
    including frame 0) runs with ``MemoryState.INITIALIZING`` and the rest with
    ``MemoryState.ACTIVE``; afterwards the ``memory`` attribute of every
    ``InflatedCausalConv3d`` is reset to ``None``.

    ``_encode`` / ``_decode`` read ``self.device``, which this class never
    sets. NOTE(review): callers (e.g. a subclass) are presumably expected to
    assign it before encoding/decoding.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        layers_per_block: int = 2,
        act_fn: str = "silu",
        latent_channels: int = SEEDVR2_LATENT_CHANNELS,
        norm_num_groups: int = 32,
        attention: bool = True,
        temporal_scale_num: int = 2,
        slicing_up_num: int = 0,
        gradient_checkpoint: bool = False,
        inflation_mode = "pad",
        time_receptive_field: _receptive_field_t = "full",
        use_quant_conv: bool = False,
        use_post_quant_conv: bool = False,
        slicing_sample_min_size = BYTEDANCE_SLICING_SAMPLE_MIN,
        *args,
        **kwargs,
    ):
        # Initialise the Module machinery before touching any attribute.
        super().__init__()

        self.slicing_sample_min_size = slicing_sample_min_size
        # Latent-space slice length: the sample slice length divided by the temporal scale.
        self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)
        extra_cond_dim = kwargs.pop("extra_cond_dim", None)
        block_out_channels = BYTEDANCE_BLOCK_OUT_CHANNELS
        down_block_types = ("DownEncoderBlock3D",) * 4
        up_block_types = ("UpDecoderBlock3D",) * 4

        # pass init params to Encoder
        self.encoder = Encoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            extra_cond_dim=extra_cond_dim,
            # [Override] add temporal_down_num parameter
            temporal_down_num=temporal_scale_num,
            gradient_checkpoint=gradient_checkpoint,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # pass init params to Decoder
        self.decoder = Decoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            # [Override] add temporal_up_num parameter
            temporal_up_num=temporal_scale_num,
            slicing_up_num=slicing_up_num,
            gradient_checkpoint=gradient_checkpoint,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        self.quant_conv = (
            InflatedCausalConv3d(
                in_channels=2 * latent_channels,
                out_channels=2 * latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_quant_conv
            else None
        )
        self.post_quant_conv = (
            InflatedCausalConv3d(
                in_channels=latent_channels,
                out_channels=latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_post_quant_conv
            else None
        )

        # A hacky way to remove attention.
        if not attention:
            self.encoder.mid_block.attentions = torch.nn.ModuleList([None])
            self.decoder.mid_block.attentions = torch.nn.ModuleList([None])

        self.use_slicing = True

    def encode(self, x: torch.FloatTensor, return_dict: bool = True):
        """Encode ``x`` and return the posterior mode (in a 1-tuple if ``return_dict`` is false)."""
        h = self.slicing_encode(x)
        posterior = DiagonalGaussianDistribution(h).mode()

        if not return_dict:
            return (posterior,)

        return posterior

    def decode_(
        self, z: torch.Tensor, return_dict: bool = True
    ):
        """Decode latent ``z`` (result in a 1-tuple if ``return_dict`` is false)."""
        decoded = self.slicing_decode(z)

        if not return_dict:
            return (decoded,)

        return decoded

    def _encode(
        self, x, memory_state = MemoryState.DISABLED
    ) -> torch.Tensor:
        """Run encoder (+ optional ``quant_conv``) on ``self.device``; result is moved back to ``x.device``."""
        _x = x.to(self.device)
        h = self.encoder(_x, memory_state=memory_state)
        if self.quant_conv is not None:
            output = self.quant_conv(h, memory_state=memory_state)
        else:
            output = h
        return output.to(x.device)

    def _decode(
        self, z, memory_state = MemoryState.DISABLED
    ) -> torch.Tensor:
        """Run (optional ``post_quant_conv`` +) decoder on ``self.device``; result is moved back to ``z.device``."""
        _z = z.to(self.device)

        if self.post_quant_conv is not None:
            _z = self.post_quant_conv(_z, memory_state=memory_state)

        output = self.decoder(_z, memory_state=memory_state)
        return output.to(z.device)

    def _clear_conv_memory(self) -> None:
        """Reset the ``memory`` attribute of every causal conv that currently holds one."""
        for module in self.modules():
            if isinstance(module, InflatedCausalConv3d) and module.memory is not None:
                module.memory = None

    def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` in temporal slices along dim 2 when it is long enough.

        Frame 0 is always prepended to the first slice. A trailing slice
        shorter than ``temporal_downsample_factor`` (1 if the attribute is
        absent) is merged into the previous slice. The conv memories are
        cleared even if a slice raises, so a failed call cannot leave cached
        tensors behind on the modules.
        """
        sp_size = 1
        min_slice_len = self.slicing_sample_min_size * sp_size
        if not self.use_slicing or (x.shape[2] - 1) <= min_slice_len:
            return self._encode(x)

        temporal_factor = getattr(self, "temporal_downsample_factor", 1)
        x_slices = list(x[:, :, 1:].split(split_size=max(min_slice_len, temporal_factor), dim=2))
        if len(x_slices) > 1 and x_slices[-1].shape[2] < temporal_factor:
            x_slices[-2] = torch.cat((x_slices[-2], x_slices[-1]), dim=2)
            x_slices.pop()

        try:
            encoded_slices = [
                self._encode(
                    torch.cat((x[:, :, :1], x_slices[0]), dim=2),
                    memory_state=MemoryState.INITIALIZING,
                )
            ]
            encoded_slices.extend(
                self._encode(x_slice, memory_state=MemoryState.ACTIVE)
                for x_slice in x_slices[1:]
            )
            return torch.cat(encoded_slices, dim=2)
        finally:
            self._clear_conv_memory()

    def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode ``z`` in temporal slices of ``slicing_latent_min_size`` latent frames.

        Latent frame 0 is always prepended to the first slice. The conv
        memories are cleared even if a slice raises.
        """
        sp_size = 1
        min_slice_len = self.slicing_latent_min_size * sp_size
        if not self.use_slicing or (z.shape[2] - 1) <= min_slice_len:
            return self._decode(z)

        z_slices = z[:, :, 1:].split(split_size=min_slice_len, dim=2)
        try:
            decoded_slices = [
                self._decode(
                    torch.cat((z[:, :, :1], z_slices[0]), dim=2),
                    memory_state=MemoryState.INITIALIZING
                )
            ]
            decoded_slices.extend(
                self._decode(z_slice, memory_state=MemoryState.ACTIVE)
                for z_slice in z_slices[1:]
            )
            return torch.cat(decoded_slices, dim=2)
        finally:
            self._clear_conv_memory()

    def tiled_encode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def tiled_decode(self, z: torch.Tensor, **kwargs) -> torch.Tensor:
        raise NotImplementedError

    def forward(
        self, x: torch.FloatTensor, mode: Literal["encode", "decode", "all"] = "all", **kwargs
    ):
        """Dispatch on ``mode``: encode only, decode only, or encode followed by decode."""
        # x: [b c t h w]
        def _unwrap(value):
            return value[0] if isinstance(value, tuple) else value

        if mode == "encode":
            return _unwrap(self.encode(x))
        elif mode == "decode":
            return _unwrap(self.decode_(x))
        else:
            latent = _unwrap(self.encode(x))
            return _unwrap(self.decode_(latent))
temporal_downsample_factor + self.freeze_encoder = freeze_encoder + self.enable_tiling = False + super().__init__(*args, **kwargs) + self.set_memory_limit(BYTEDANCE_VAE_CONV_MEM_GIB, BYTEDANCE_VAE_NORM_MEM_GIB) + + def forward(self, x: torch.FloatTensor): + with torch.no_grad() if self.freeze_encoder else nullcontext(): + z, p = self.encode(x) + x = self.decode(z) + return x, z, p + + def encode(self, x, orig_dims=None): + if x.ndim == 4: + x = x.unsqueeze(2) + x = x.to(dtype=next(self.parameters()).dtype) + self.device = x.device + p = super().encode(x) + z = p.squeeze(2) + return z, p + + def decode(self, z, seedvr2_tiling=None): + seedvr2_tiling = {} if seedvr2_tiling is None else seedvr2_tiling + if not isinstance(seedvr2_tiling, dict): + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: `seedvr2_tiling` must be a dict; " + f"got {type(seedvr2_tiling).__name__} with value {seedvr2_tiling!r}." + ) + + if z.ndim == 5: + b, c, t_latent, h, w = z.shape + if c != 16: + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: 5-D latent input must " + f"have 16 channels; got shape {tuple(z.shape)}." + ) + latent = z + elif z.ndim == 4: + b, tc, h, w = z.shape + if tc % 16 != 0: + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: 4-D latent input must " + "use collapsed channel layout (B, 16*T, H, W); " + f"got shape {tuple(z.shape)}." + ) + latent = z.reshape(b, 16, -1, h, w) + else: + raise RuntimeError( + "SeedVR2 VideoAutoencoderKLWrapper.decode: latent input must be " + "4-D collapsed (B, 16*T, H, W) or 5-D (B, 16, T, H, W); " + f"got shape {tuple(z.shape)}." 
+ ) + scale = BYTEDANCE_VAE_SCALING_FACTOR + shift = BYTEDANCE_VAE_SHIFTING_FACTOR + latent = latent / scale + shift + + self.device = latent.device + self.enable_tiling = seedvr2_tiling.get("enable_tiling", False) + + if self.enable_tiling: + decode_seedvr2_args = dict(seedvr2_tiling) + tile_h, tile_w = decode_seedvr2_args.get("tile_size", (512, 512)) + ov_h, ov_w = decode_seedvr2_args.get("tile_overlap", (64, 64)) + decode_seedvr2_args["tile_overlap"] = ( + min(ov_h, max(0, tile_h - 8)), + min(ov_w, max(0, tile_w - 8)), + ) + x = tiled_vae(latent, self, **decode_seedvr2_args, encode=False) + if x.ndim == 4: + # tiled_vae squeezes the temporal axis when + # temporal_downsample_factor == 1 AND latent T == 1 + # (see tiled_vae line 179-180); re-add it so the post-decode + # pipeline can keep batch and time distinct on the tiled path. + x = x.unsqueeze(2) + else: + x = super().decode_(latent) + + # ensure even dims for save video + h, w = x.shape[-2:] + w2 = w - (w % 2) + h2 = h - (h % 2) + x = x[..., :h2, :w2] + + return x + + def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device = "same"): + set_norm_limit(norm_max_mem) + for m in self.modules(): + if isinstance(m, InflatedCausalConv3d): + m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) + + for module in self.modules(): + if isinstance(module, InflatedCausalConv3d): + module.set_memory_device(memory_device) diff --git a/comfy/ldm/triposplat/gaussian.py b/comfy/ldm/triposplat/gaussian.py new file mode 100644 index 000000000..a4cd2f62f --- /dev/null +++ b/comfy/ldm/triposplat/gaussian.py @@ -0,0 +1,199 @@ +# TripoSplat 3D gaussian container. Operates on already-decoded +# tensors and exposes them as render-ready tensors (render_tensors) for the generic SPLAT type. 
class GaussianModel:
    # Container for one set of decoded 3D gaussians.
    #
    # Raw (pre-activation) tensors are assigned through the underscore properties
    # (_xyz, _features_dc, _opacity, _scaling, _rotation). Assigning _xyz, _opacity or
    # _scaling also stores the derived, activated tensor, readable through get_xyz /
    # get_opacity / get_scaling. Assigning None to one of those three clears both entries.
    #
    # aabb: six numbers; entries [:3] are added to and entries [3:] multiply the raw xyz.
    # mininum_kernel_size: combined in quadrature with the activated scale.
    # scaling_bias / opacity_bias: activated-space values that a raw input of 0 maps to.
    # scaling_activation: "exp" or "softplus"; anything else raises ValueError.
    def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
                 scaling_bias: float = 0.01, opacity_bias: float = 0.1,
                 scaling_activation: str = "exp", device=None):
        self.sh_degree = sh_degree
        self.mininum_kernel_size = mininum_kernel_size
        self.scaling_bias = scaling_bias
        self.opacity_bias = opacity_bias
        self.device = device
        self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)

        if scaling_activation == "exp":
            self._scaling_activation = torch.exp
            self._inverse_scaling_activation = torch.log
        elif scaling_activation == "softplus":
            self._scaling_activation = F.softplus
            # Inverse of softplus: x + log(1 - exp(-x)).
            self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
        else:
            # Fail here with a clear message instead of an AttributeError further down.
            raise ValueError(
                f"Unsupported scaling_activation {scaling_activation!r}; expected 'exp' or 'softplus'."
            )

        self._opacity_activation = torch.sigmoid
        # Inverse of sigmoid (logit).
        self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))

        # Biases live in pre-activation space so that a raw value of 0 activates to the
        # configured scaling_bias / opacity_bias.
        self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
        # Added to the raw rotation in render_tensors; (1, 0, 0, 0) in the wxyz layout used there.
        self.rots_bias = torch.zeros(4, device=self.device)
        self.rots_bias[0] = 1
        self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)

        self._storage = {}

    def _get_store(self, name):
        # Missing entries read as None.
        return self._storage.get(name)

    def _set_store(self, name, value):
        self._storage[name] = value

    @property
    def _xyz(self):
        return self._get_store("_xyz")
    @_xyz.setter
    def _xyz(self, value):
        if value is None:
            self._set_store("_xyz", None)
            self._set_store("xyz", None)
            return
        self._set_store("_xyz", value)
        # Affine map of the raw positions through the aabb (scale by [3:], shift by [:3]).
        self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])

    @property
    def get_xyz(self):
        return self._get_store("xyz")

    @property
    def _features_dc(self):
        return self._get_store("_features_dc")
    @_features_dc.setter
    def _features_dc(self, value):
        self._set_store("_features_dc", value)

    @property
    def _opacity(self):
        return self._get_store("_opacity")
    @_opacity.setter
    def _opacity(self, value):
        if value is None:
            self._set_store("_opacity", None)
            self._set_store("opacity", None)
            return
        self._set_store("_opacity", value)
        self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))

    @property
    def get_opacity(self):
        return self._get_store("opacity")

    @property
    def _scaling(self):
        return self._get_store("_scaling")
    @_scaling.setter
    def _scaling(self, value):
        if value is None:
            self._set_store("_scaling", None)
            self._set_store("scaling", None)
            return
        self._set_store("_scaling", value)
        s = self._scaling_activation(value + self.scale_bias)
        # sqrt(s^2 + k^2): the minimum kernel size is combined in quadrature.
        s = torch.square(s) + self.mininum_kernel_size ** 2
        self._set_store("scaling", torch.sqrt(s))

    @property
    def get_scaling(self):
        return self._get_store("scaling")

    @property
    def _rotation(self):
        return self._get_store("_rotation")
    @_rotation.setter
    def _rotation(self, value):
        self._set_store("_rotation", value)

    _DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]

    def render_tensors(self):
        # Render-ready (activated, world-space) tensors for the generic SPLAT type. The axis transform
        # (a 3x3 rotation, object frame -> viewer Y-up) is baked into positions and rotations.
        # Returns float tensors on the intermediate device: positions (N,3), scales (N,3) linear,
        # rotations (N,4) wxyz, opacities (N,1) in [0,1], sh (N,K,3) coefficients.
        # Requires _xyz, _scaling, _opacity, _rotation and _features_dc to have been assigned.
        xyz = self.get_xyz.float()
        scaling = self.get_scaling.float()
        opacity = self.get_opacity.float()
        rotation = (self._rotation + self.rots_bias[None, :]).float()
        sh = self._features_dc.float()  # (N, K, 3)
        T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device)
        xyz = xyz @ T.T
        # Compose the axis transform with each gaussian's rotation, then return to quaternions.
        rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation)))
        rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True)
        out_device = comfy.model_management.intermediate_device()
        return (
            xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(),
            rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(),
            sh.to(out_device).contiguous(),
        )
def build_gaussian_models(decoder, points_pred: dict, pred: dict):
    # Assemble one GaussianModel per batch item from the elastic decoder layout. decoder is the
    # ElasticGaussianFixedlenDecoder (carries layout / rep_config / _get_offset).
    # '_xyz' is the decoder offset added to the predicted points; '_xyz_center' and '_offset_scale'
    # layout entries are skipped; every other entry is sliced out of the feature tensor by its
    # layout range, reshaped to its layout shape and scaled by rep_config['lr'][name].
    offset = decoder._get_offset(pred['features'])
    features = pred["features"]
    skipped = ('_xyz_center', '_offset_scale')

    models = []
    for idx in range(features.shape[0]):
        cfg = decoder.rep_config
        model = GaussianModel(
            sh_degree=0,
            aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
            mininum_kernel_size=cfg['filter_kernel_size_3d'],
            scaling_bias=cfg['scaling_bias'],
            opacity_bias=cfg['opacity_bias'],
            scaling_activation=cfg['scaling_activation'],
            device=features.device,
        )
        # Extra axis so the points broadcast against offset[idx].
        anchors = points_pred["points"][idx, :, None, :]

        for name, spec in decoder.layout.items():
            if name in skipped:
                continue
            if name == '_xyz':
                value = (offset[idx] + anchors).flatten(0, 1)
            else:
                start, stop = spec['range'][0], spec['range'][1]
                chunk = features[idx][:, start:stop].reshape(-1, *spec['shape']).flatten(0, 1)
                value = chunk * decoder.rep_config['lr'][name]
            setattr(model, name, value)

        models.append(model)
    return models
class MultiHeadRMSNorm(nn.Module):
    # RMS norm followed by a learned (heads, dim) gain. The gain is allocated uninitialised
    # (torch.empty); presumably it is filled from a checkpoint before use.
    def __init__(self, dim, heads, dtype=None, device=None):
        super().__init__()
        gain = torch.empty(heads, dim, dtype=dtype, device=device)
        self.gamma = nn.Parameter(gain)

    def forward(self, x):
        normed = comfy.rmsnorm.rms_norm(x)
        # Cast the gain to the dtype/device of the normalised tensor before scaling.
        gain = comfy.model_management.cast_to(self.gamma, normed.dtype, normed.device)
        return normed * gain
self.freqs_2 = nn.Parameter(freqs_list[2]) + + def forward(self, hidden_states): + h = self.norm(hidden_states) + feat = self.act(self.gate_map(h)) * self.content_map(h) + out = self.final_map(feat) + B, L, _ = out.shape + delta_pos = out.reshape(B, L, self.num_heads, 3) + f0 = comfy.model_management.cast_to(self.freqs_0, torch.float32, out.device) + f1 = comfy.model_management.cast_to(self.freqs_1, torch.float32, out.device) + f2 = comfy.model_management.cast_to(self.freqs_2, torch.float32, out.device) + ang_0 = delta_pos[..., 0].unsqueeze(-1) * f0 * torch.pi + ang_1 = delta_pos[..., 1].unsqueeze(-1) * f1 * torch.pi + ang_2 = delta_pos[..., 2].unsqueeze(-1) * f2 * torch.pi + ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # (B, L, heads, head_dim/2) + cos, sin = ang.cos(), ang.sin() + return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*ang.shape, 2, 2) + + +class PcdAbsolutePositionEmbedder(nn.Module): + # Sinusoidal absolute position embedding. Two fixed schedules are used in TripoSplat: + # "pow2" (flow-model latent anchors) and "log2" (octree / gaussian decoders). 
+ def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"): + super().__init__() + self.channels = channels + self.in_channels = in_channels + self.max_res = max_res + self.schedule = schedule + self.freq_dim = channels // in_channels // 2 + + def _freqs(self, device): + if self.schedule == "pow2": + freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device) + res_dim = max(0, self.freq_dim - self.max_res) + freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res + if res_dim > 0 else torch.empty(0, device=device)) + freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim] + return torch.pow(2.0, freqs) * 2.0 # *2 folds this schedule's 2*pi into the shared *pi below + logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device) + return torch.pow(2.0, logs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + orig_dtype = x.dtype + x = x.float() + *dims, D = x.shape + out = torch.outer(x.reshape(-1), self._freqs(x.device)) * torch.pi + out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1) + if out.shape[-1] < self.channels: + out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1], + device=out.device, dtype=out.dtype)], dim=-1) + return out.to(orig_dtype) + + +def attention(q, k, v, transformer_options=None): + # q, k, v: (B, L, heads, dim) -> (B, L, heads, dim). Shared optimized_attention call convention. 
+ out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads=q.shape[2], + skip_reshape=True, skip_output_reshape=True, low_precision_attention=False, + transformer_options=transformer_options) + return out.transpose(1, 2) + + +# Transformer building blocks + +class MLP(nn.Module): + def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.mlp = nn.Sequential( + operations.Linear(in_channels, hidden_channels, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(hidden_channels, out_channels, dtype=dtype, device=device), + ) + + def forward(self, x): + return self.mlp(x) + + +class RopeMultiHeadAttention(nn.Module): + def __init__(self, channels, num_heads, qkv_bias=True, qk_rms_norm=False, use_rope=False, + dtype=None, device=None, operations=None): + super().__init__() + self.channels = channels + self.num_heads = num_heads + self.head_dim = channels // num_heads + self.qk_rms_norm = qk_rms_norm + self.use_rope = use_rope + self.qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device) + if self.qk_rms_norm: + self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device) + self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device) + self.out = operations.Linear(channels, channels, dtype=dtype, device=device) + + def forward(self, x, rope_emb=None, transformer_options=None): + B, L, C = x.shape + qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim) + q, k, v = qkv.unbind(2) + if self.use_rope: + q, k = apply_rope(q, k, rope_emb) + if self.qk_rms_norm: + q = self.q_norm(q) + k = self.k_norm(k) + h = attention(q, k, v, transformer_options) # (B, L, heads, dim) + return self.out(h.reshape(B, L, C)) + + +class UnifiedTransformerBlock(nn.Module): + def __init__(self, channels, num_heads, mlp_ratio=4.0, + use_rope=False, qk_rms_norm=False, 
qkv_bias=True, + modulation=True, share_mod=False, + dtype=None, device=None, operations=None): + super().__init__() + self.modulation = modulation + self.share_mod = share_mod + self.norm1 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device) + self.attn = RopeMultiHeadAttention(channels, num_heads=num_heads, + qkv_bias=qkv_bias, use_rope=use_rope, qk_rms_norm=qk_rms_norm, + dtype=dtype, device=device, operations=operations) + self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations) + if modulation: + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)) + self.shift_table = nn.Parameter(torch.empty(1, 6 * channels, dtype=dtype, device=device)) + + def forward(self, x, mod=None, rotary_emb=None, transformer_options=None): + if self.modulation: + if not self.share_mod: + mod = self.adaLN_modulation(mod) + mod = mod + comfy.model_management.cast_to(self.shift_table, mod.dtype, mod.device) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1)) + x = torch.addcmul(x, self.attn(h, rope_emb=rotary_emb, transformer_options=transformer_options), gate_msa.unsqueeze(1)) + h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1)) + x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1)) + else: + x = x + self.attn(self.norm1(x), rope_emb=rotary_emb, transformer_options=transformer_options) + x = x + self.mlp(self.norm2(x)) + return x + + +class TimestepEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + super().__init__() + self.mlp 
= nn.Sequential( + operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + half = dim // 2 + freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + emb = self.timestep_embedding(t, self.frequency_embedding_size) + return self.mlp(emb.to(self.mlp[0].weight.dtype)) + + +class LatentSeqMMFlowModel(nn.Module): + def __init__(self, image_model=None, q_token_length=8192, in_channels=16, model_channels=1024, + cond_channels=1280, out_channels=16, num_blocks=24, num_refiner_blocks=2, + num_heads=None, num_head_channels=64, cam_channels=5, cond2_channels=128, + mlp_ratio=4, share_mod=True, qk_rms_norm=True, + dtype=None, device=None, operations=None, **kwargs): + super().__init__() + self.dtype = dtype + self.q_token_length = q_token_length + self.in_channels = in_channels + self.cam_channels = cam_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.cond2_channels = cond2_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_refiner_blocks = num_refiner_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.share_mod = share_mod + self.qk_rms_norm = qk_rms_norm + + factory_kwargs = dict(dtype=dtype, device=device) + op_kwargs = dict(operations=operations, **factory_kwargs) + + self.t_embedder = TimestepEmbedder(model_channels, **op_kwargs) + if share_mod: + 
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs)) + + self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs) + self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs) + self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None + + # Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding. + # The embedder is parameter-free and the anchors are fixed, precompute once. + sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length) + pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0)) + self.register_buffer("pos_emb", pos_emb, persistent=False) + + # RePo3DRotaryEmbedding layers for the refiner and main blocks + repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs) + self.noise_repo_layers = nn.ModuleList( + [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)]) + self.context_repo_layers = nn.ModuleList( + [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)]) + self.repo_layers = nn.ModuleList( + [RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)]) + + # Refiner blocks + block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs) + self.noise_refiner = nn.ModuleList( + [UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)]) + self.context_refiner = nn.ModuleList( + [UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)]) + + self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs) + 
+ self.blocks = nn.ModuleList( + [UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_blocks)]) + + self.shift_table = nn.Parameter(torch.empty(1, 2, model_channels, **factory_kwargs)) + self.out_layer = operations.Linear(model_channels, out_channels, **factory_kwargs) + self.cam_out_layer = operations.Linear(model_channels, cam_channels, **factory_kwargs) + + def forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, t, context, ref_latents, transformer_options, **kwargs) + + def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs): + # x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)]. + # context == feature1. + z, camera = x[0], x[1] + feat1 = context + + h_x = self.input_layer(z) + h_cond = self.cond_embedder(feat1) + if ref_latents is not None and self.cond_embedder2 is not None: + # Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad to feat1's length + # (the pad count = feat1's prefix tokens: DINOv3 cls + registers), then add to the context. 
+ feat2 = ref_latents[0].flatten(2).transpose(1, 2) + feat2 = F.pad(feat2, (0, 0, feat1.shape[1] - feat2.shape[1], 0)) + h_cond = h_cond + self.cond_embedder2(feat2.to(h_cond.dtype)) + t_emb = self.t_embedder(t) + t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb + + h_x = h_x + self.pos_emb.to(z) + + for i, block in enumerate(self.noise_refiner): + h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x), transformer_options=transformer_options) + + for i, block in enumerate(self.context_refiner): + h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond), transformer_options=transformer_options) + + cam = camera.to(z) + h_cam = self.cam_refiner(cam) + h = torch.cat([h_x, h_cond, h_cam], dim=1) + + for i, block in enumerate(self.blocks): + h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h), transformer_options=transformer_options) + + h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).to(z) + h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).to(z) + + shift, scale = (comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device) + t_emb.unsqueeze(1)).chunk(2, dim=1) + scale = 1 + scale + h_x = torch.addcmul(shift, h_x, scale) + h_cam = torch.addcmul(shift, h_cam, scale) + + return self.out_layer(h_x), self.cam_out_layer(h_cam) diff --git a/comfy/ldm/triposplat/preview.py b/comfy/ldm/triposplat/preview.py new file mode 100644 index 000000000..6a942bb53 --- /dev/null +++ b/comfy/ldm/triposplat/preview.py @@ -0,0 +1,91 @@ +# Live preview for TripoSplat: decode an x0 estimate into a coarse gaussian splat and render it with a perspective orbit camera. 
import numpy as np
from PIL import Image

# Zeroth-order (DC) spherical-harmonic basis constant, 1 / (2 * sqrt(pi)); used to turn the DC feature into RGB.
_C0 = 0.28209479177387814
_LATENT_TOKENS = 8192  # q_token_length
_LATENT_CH = 16  # in_channels
_OBJECT_TO_VIEWER = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32)  # object frame -> viewer Y-up frame


def _view_matrix(yaw_deg, pitch_deg):
    # 3x3 orbit rotation: yaw about the Y axis first, then pitch about the X axis (angles in degrees).
    y, p = np.radians(yaw_deg), np.radians(pitch_deg)
    Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]], np.float32)
    Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]], np.float32)
    return Rx @ Ry


def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0,
                 max_px=9, min_opacity=0.0, fov=35.0, dist=2.2):
    # Project gaussian centers with a perspective camera and paint each as a filled disk whose screen
    # radius follows the gaussian's world-space scale, composited with a nearest-wins z-buffer.
    # gain scales the footprint (≈ std spanned), `min_px`/`max_px` clamp the on-screen radius.
    #
    # xyz: (N, 3) centers, rgb: (N, >=3) colors in [0, 1], scale: (N,) world radius, opacity: optional (N,).
    # Returns a size x size RGB PIL image on a black background.
    # Pixels that fall outside the viewport (or have non-finite coordinates) are discarded rather than
    # clamped, so off-screen splats no longer smear along the image border.

    pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T
    v = pts @ _view_matrix(yaw, pitch).T
    zc = v[:, 2] + dist  # camera-space depth; the camera sits `dist` in front of the origin
    keep = zc > 1e-2  # drop points at or behind the near plane
    if opacity is not None and min_opacity > 0.0:  # culls gaussians with very low opacity
        keep = keep & (opacity > min_opacity)
    v, zc, scale = v[keep], zc[keep], scale[keep]
    col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep]
    if v.shape[0] == 0:
        return Image.fromarray(np.zeros((size, size, 3), np.uint8))
    f = (size / 2) / np.tan(np.radians(fov) / 2)  # focal length in pixels
    cx = size / 2 + f * v[:, 0] / zc
    cy = size / 2 + f * v[:, 1] / zc
    radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32)

    # Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized.
    px, py, pz, pc = [], [], [], []
    for r in range(int(radius.min()), int(radius.max()) + 1):
        m = radius == r
        if not m.any():
            continue
        dy, dx = np.mgrid[-r:r + 1, -r:r + 1]
        disk = (dx * dx + dy * dy) <= r * r
        ox, oy = dx[disk], dy[disk]
        px.append((cx[m, None] + ox).ravel())
        py.append((cy[m, None] + oy).ravel())
        pz.append(np.repeat(zc[m], ox.size))
        pc.append(np.repeat(col[m], ox.size, axis=0))
    px, py = np.concatenate(px), np.concatenate(py)
    pz, pc = np.concatenate(pz), np.concatenate(pc)

    # Keep only pixels inside the viewport; NaN coordinates compare False and are dropped as well.
    inside = (px >= 0) & (px < size) & (py >= 0) & (py < size)
    px, py, pz, pc = px[inside], py[inside], pz[inside], pc[inside]
    xi = px.astype(np.int64)
    yi = py.astype(np.int64)

    # Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest
    # splat, then decode the winning index back to its color.
    pid = yi * size + xi
    q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1)  # near = small
    key = (q << 32) | np.arange(pid.size, dtype=np.int64)
    buf = np.full(size * size, 1 << 62, np.int64)  # sentinel larger than any packed key
    np.minimum.at(buf, pid, key)
    img = np.zeros((size * size, 3), np.uint8)
    hit = buf < (1 << 62)
    img[hit] = pc[buf[hit] & 0xFFFFFFFF]
    return Image.fromarray(img.reshape(size, size, 3))


def _extract_latent(x0):
    # x0 from the sampler callback is the nested latent packed to (B, 1, TOKENS*CH + 1*5);
    # the plain single-latent case is (B, TOKENS, CH). Return the (B, TOKENS, CH) latent stream.
    if x0.ndim == 3 and x0.shape[1] == _LATENT_TOKENS and x0.shape[2] == _LATENT_CH:
        return x0
    # Packed case: the latent occupies the first TOKENS*CH values of each flattened sample.
    flat = x0.reshape(x0.shape[0], -1)
    return flat[:, :_LATENT_TOKENS * _LATENT_CH].reshape(x0.shape[0], _LATENT_TOKENS, _LATENT_CH)


def decode_x0_to_image(decoder, x0, cfg):
    # Decode x0 at a coarse octree level / few gaussians and render a preview image.
    # cfg is a dict-like of optional overrides: "gaussians", "level", "yaw", "pitch", "size", "point_size".
    # Only the first decoded gaussian set (batch item 0) is rendered.
    latent = _extract_latent(x0)
    fsm = decoder.first_stage_model
    gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype),
                          num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5))[0]
    xyz = gaussian.get_xyz.float().cpu().numpy()
    rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5
    scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1)  # per-splat world radius (largest axis)
    opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0]
    return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0),
                        size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3),
                        min_opacity=0.01)


# ---- patch boundary (kept from the original diff) ----
# diff --git a/comfy/ldm/triposplat/vae.py b/comfy/ldm/triposplat/vae.py
# new file mode 100644
# index 000000000..e5ed9fd36
# --- /dev/null
# +++ b/comfy/ldm/triposplat/vae.py
# @@ -0,0 +1,382 @@
# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an
# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns
# a Gaussian. The octree sampler uses the global torch RNG (no generator) like upstream, so seed it for repeatable decodes.
+import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.model_management +import comfy.ops +from .gaussian import build_gaussian_models +from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention + + +# Quasi-random sampling utilities (pure functions, dtype/device-agnostic) + +PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + + +def radical_inverse(base, n): + val = 0 + inv_base = 1.0 / base + inv_base_n = inv_base + while n > 0: + digit = n % base + val += digit * inv_base_n + n //= base + inv_base_n *= inv_base + return val + + +def halton_sequence(dim, n): + return [radical_inverse(PRIMES[i], n) for i in range(dim)] + + +def hammersley_sequence(dim, n, num_samples): + return [n / num_samples] + halton_sequence(dim - 1, n) + + +def sample_probs(probs, counts, generator=None): + # Systematic resampling: distribute counts[r] draws across the P bins of row r + batch_shape = counts.shape + R = counts.numel() + P = probs.size(-1) + device = probs.device + probs = probs.reshape(R, P).to(torch.float32).clamp_min(0) + counts = counts.reshape(R).to(device=device, dtype=torch.long) + + row_sums = probs.sum(1, keepdim=True) + probs = torch.where(row_sums == 0, probs.new_tensor(1.0 / P), probs / row_sums.clamp_min(1)) + cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12) + + Nmax = int(counts.max()) + if Nmax == 0: + return counts.new_zeros(*batch_shape, P) + cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1) + grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax) + u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded) + idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1) + weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r] + out = torch.zeros(R, P, dtype=torch.float32, device=device) + out.scatter_add_(1, idx, weight) + return 
out.to(torch.long).view(*batch_shape, P) + + +class MultiHeadAttention(nn.Module): + def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False, + dtype=None, device=None, operations=None): + super().__init__() + assert channels % num_heads == 0 + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.qk_rms_norm = qk_rms_norm + if self._type == "self": + self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device) + else: + self.to_q = operations.Linear(channels, channels, bias=qkv_bias, dtype=dtype, device=device) + self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, dtype=dtype, device=device) + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device) + self.to_out = operations.Linear(channels, channels, dtype=dtype, device=device) + + def forward(self, x, context=None): + B, L, C = x.shape + if self._type == "self": + q, k, v = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1).unbind(dim=2) + else: + Lkv = context.shape[1] + q = self.to_q(x).reshape(B, L, self.num_heads, -1) + k, v = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1).unbind(dim=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + h = attention(q, k, v) + return self.to_out(h.reshape(B, L, -1)) + + +# Octree probability decoder + +class LevelEmbedder(nn.Module): + def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024, + dtype=None, device=None, operations=None): + super().__init__() + self.mlp = nn.Sequential( + operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + 
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + self.max_period = max_period + + @staticmethod + def level_embedding(t, dim, max_period=1024): + half = dim // 2 + freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device) + args = t[:, None].float() * freqs[None] * 2 * torch.pi + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + emb = self.level_embedding(t, self.frequency_embedding_size, self.max_period) + return self.mlp(emb.to(self.mlp[0].weight.dtype)) + + +class ModulatedTransformerCrossOnlyBlock(nn.Module): + def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False, + qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None): + super().__init__() + self.share_mod = share_mod + self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, + type="cross", qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations) + self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)) + + def forward(self, x, mod, context): + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = 
self.adaLN_modulation(mod).chunk(6, dim=1) + h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1)) + x = torch.addcmul(x, self.cross_attn(h, context), gate_msa.unsqueeze(1)) + h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1)) + x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1)) + return x + + +class OctreeProbabilityFixedlenDecoder(nn.Module): + # Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits. + def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16, + num_head_channels=64, mlp_ratio=4.0, share_mod=True, + qk_rms_norm_cross=True, dtype=None, device=None, operations=None): + super().__init__() + self.model_channels = model_channels + self.cond_channels = cond_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.share_mod = share_mod + self.qk_rms_norm_cross = qk_rms_norm_cross + self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device) + self.l_embedder = LevelEmbedder(model_channels, dtype=dtype, device=device, operations=operations) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, dtype=dtype, device=device)) + if cond_channels is not None: + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossOnlyBlock( + model_channels, ctx_channels=cond_channels, num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross, + share_mod=self.share_mod, dtype=dtype, device=device, operations=operations) + for _ in range(num_blocks) + ]) + self.out_proj = operations.Linear(model_channels, 8, dtype=dtype, device=device) + self.in_proj = operations.Linear(3, model_channels, dtype=dtype, device=device) + self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, 
max_res=10, schedule="log2") + + def forward(self, x, l, cond): + d = next(self.parameters()).dtype + B, L, _ = x.shape + h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d) + h = self.input_layer(h) + l_emb = self.l_embedder(l) + if self.share_mod: + l_emb = self.adaLN_modulation(l_emb) + cond = cond.to(d) + for block in self.blocks: + h = block(h, l_emb, cond) + h = F.layer_norm(h.float(), h.shape[-1:]).to(d) + logits = self.out_proj(h) + return {"logits": logits, "probs": torch.softmax(logits, dim=-1)} + + @staticmethod + def sample(model, cond, num_points, level, temperature=1.0, generator=None): + B = cond.shape[0] + device = cond.device + child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]], + dtype=torch.long, device=device) + prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device) + prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device) + prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device) + batch_indices_range = torch.arange(B, device=device).unsqueeze(1) + + for lv in range(1, level + 1): + res_p = 1 << (lv - 1) + res = 1 << lv + parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p + res_tensor = torch.full((B,), res, dtype=torch.long, device=device) + pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature + pred_probs = torch.softmax(pred_logits, dim=-1) + pred_log_probs = torch.log_softmax(pred_logits, dim=-1) + sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2) + pred_log_probs = pred_log_probs.flatten(1, 2) + prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1) + child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2) + mask = sampled > 0 + max_valid = mask.sum(dim=1).max().item() + scatter_indices = mask.cumsum(dim=1) - 1 + valid_scatter_indices = scatter_indices[mask] + 
valid_batch_indices = batch_indices_range.expand_as(mask)[mask] + next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device) + next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask] + next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device) + next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask] + next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device) + next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask] + prev_coords_int = next_prev_coords_int + prev_counts = next_prev_counts + prev_log_probs = next_prev_log_probs + + res = 1 << level + prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points) + coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1) + rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device) + coords_norm = (coords_int.to(torch.float32) + rand) / res + return {"points": coords_norm, "log_probs": prev_log_probs} + + +# Elastic gaussian decoder + +class TransformerCrossBlock(nn.Module): + def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, + qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True, + dtype=None, device=None, operations=None): + super().__init__() + self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device) + self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm, dtype=dtype, device=device, 
operations=operations) + self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross", + qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations) + self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations) + + def forward(self, x, context): + x = x + self.self_attn(self.norm1(x)) + x = x + self.cross_attn(self.norm2(x), context) + x = x + self.mlp(self.norm3(x)) + return x + + +class ElasticGaussianFixedlenDecoder(nn.Module): + # Cross-attention transformer over sampled octree points -> per-point gaussian params. + def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16, + num_head_channels=64, mlp_ratio=4.0, *, representation_config=None, + qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None): + super().__init__() + self.rep_config = representation_config or dict( + lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1), + perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32, + filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1, + scaling_activation="softplus", + ) + self.out_channels = self._calc_layout() + self.model_channels = model_channels + self.cond_channels = cond_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device) + if cond_channels is not None: + self.blocks = nn.ModuleList([ + TransformerCrossBlock(model_channels, ctx_channels=cond_channels, + num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, + qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross, + dtype=dtype, device=device, operations=operations) + for _ in range(num_blocks) + ]) + self.in_proj = operations.Linear(in_channels, 
model_channels, dtype=dtype, device=device) + self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2") + self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device) + self._build_perturbation() + + def _calc_layout(self): + ng = self.rep_config['num_gaussians'] + self.layout = { + '_xyz': {'shape': (ng, 3), 'size': ng * 3}, + '_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3}, + '_scaling': {'shape': (ng, 3), 'size': ng * 3}, + '_rotation': {'shape': (ng, 4), 'size': ng * 4}, + '_opacity': {'shape': (ng, 1), 'size': ng}, + } + self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng} + start = 0 + for k, v in self.layout.items(): + v['range'] = (start, start + v['size']) + start += v['size'] + return start + + def _build_perturbation(self): + ng = self.rep_config['num_gaussians'] + perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float() + perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size']) + self.register_buffer('points_offset_perturbation', perturbation) + base = torch.tensor(self.rep_config['offset_scale']) + self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0)) + + def _get_offset(self, h): + B = h.shape[0] + r = self.layout['_offset_scale']['range'] + _offset_scale = F.softplus( + h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape']) + + comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device)) + + r = self.layout['_xyz']['range'] + offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape']) + offset = offset * self.rep_config['lr']['_xyz'] + if self.rep_config['perturb_offset']: + offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device) + offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size'] + offset = offset * _offset_scale + return offset + + def 
forward(self, x=None, cond=None): + pcd = x["points"] + d = next(self.parameters()).dtype + B, L, _ = pcd.shape + h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d) + h = self.input_layer(h) + cond = cond.to(d) + for block in self.blocks: + h = block(h, cond) + h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype) + return {"features": self.out_proj(h)} + + +# Combined octree gaussian decoder (comfy first-stage model) + +class OctreeGaussianDecoder(nn.Module): + _MAX_VOXEL_LEVEL = 8 + + def __init__(self, dtype=None, device=None, operations=None): + super().__init__() + if operations is None: + operations = comfy.ops.disable_weight_init + self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=operations) + self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=operations) + + @property + def gaussians_per_point(self) -> int: + return self.gs.rep_config['num_gaussians'] + + def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None): + # level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews. + # generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG. 
+ level = self._MAX_VOXEL_LEVEL if level is None else level + num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point) + points_pred = OctreeProbabilityFixedlenDecoder.sample( + self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator, + ) + pred = self.gs(x=points_pred, cond=latent) + return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 962addb27..e032b7dcd 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -4,6 +4,7 @@ import dataclasses import torch from typing import NamedTuple +import comfy_aimdo.host_buffer from comfy.quant_ops import QuantizedTensor @@ -17,21 +18,18 @@ class TensorFileSlice(NamedTuple): def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=None): if isinstance(tensor, QuantizedTensor): - if not isinstance(destination, QuantizedTensor): - return False - if tensor._layout_cls != destination._layout_cls: - return False - - if not read_tensor_file_slice_into(tensor._qdata, destination._qdata, stream=stream, + if not read_tensor_file_slice_into(tensor._qdata, + destination._qdata if destination is not None else None, stream=stream, destination2=(destination2._qdata if destination2 is not None else None)): return False - dst_orig_dtype = destination._params.orig_dtype - destination._params.copy_from(tensor._params, non_blocking=False) - destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) + if destination is not None: + dst_orig_dtype = destination._params.orig_dtype + destination._params.copy_from(tensor._params, non_blocking=False) + destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) if destination2 is not None: dst_orig_dtype = destination2._params.orig_dtype - destination2._params.copy_from(destination._params, non_blocking=True) + 
destination2._params.copy_from(destination._params if destination is not None else tensor._params, non_blocking=True) destination2._params = dataclasses.replace(destination2._params, orig_dtype=dst_orig_dtype) return True @@ -39,10 +37,15 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N if info is None: return False + if destination is not None and destination.device.type != "cpu" and destination2 is None: + destination2 = destination + destination = None + file_obj = info.file_ref - if (destination.device.type != "cpu" - or file_obj is None - or destination.numel() * destination.element_size() < info.size + if (file_obj is None + or (destination is None and destination2 is None) + or (destination is not None and (destination.device.type != "cpu" or destination.numel() * destination.element_size() < info.size)) + or (destination2 is not None and (destination2.device.type == "cpu" or destination2.numel() * destination2.element_size() < info.size)) or tensor.numel() * tensor.element_size() != info.size or tensor.storage_offset() != 0 or not tensor.is_contiguous()): @@ -51,6 +54,14 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N if info.size == 0: return True + if destination is None: + stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 + comfy_aimdo.host_buffer.read_file_to_device(file_obj, info.offset, info.size, + stream_ptr, destination2.data_ptr(), + destination2.device.index, + mark_cold=False) + return True + hostbuf = getattr(destination.untyped_storage(), "_comfy_hostbuf", None) if hostbuf is not None: stream_ptr = getattr(stream, "cuda_stream", 0) if stream is not None else 0 @@ -63,6 +74,9 @@ def read_tensor_file_slice_into(tensor, destination, stream=None, destination2=N device=None if destination2 is None else destination2.device.index) return True + if not hasattr(file_obj, "seek") or not hasattr(file_obj, "readinto"): + return False + buf_type = 
ctypes.c_ubyte * info.size view = memoryview(buf_type.from_address(destination.data_ptr())) diff --git a/comfy/model_base.py b/comfy/model_base.py index e55808633..c084e23bb 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -46,6 +46,7 @@ import comfy.ldm.wan.model_animate import comfy.ldm.wan.ar_model import comfy.ldm.wan.model_wandancer import comfy.ldm.hunyuan3d.model +import comfy.ldm.triposplat.model import comfy.ldm.hidream.model import comfy.ldm.chroma.model import comfy.ldm.chroma_radiance.model @@ -53,7 +54,10 @@ import comfy.ldm.pixeldit.model import comfy.ldm.pixeldit.pid import comfy.ldm.ace.model import comfy.ldm.omnigen.omnigen2 +import comfy.ldm.seedvr.model + import comfy.ldm.qwen_image.model +import comfy.ldm.ideogram4.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model import comfy.ldm.ace.ace_step15 @@ -926,6 +930,16 @@ class HunyuanDiT(BaseModel): out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) return out +class SeedVR2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT) + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + condition = kwargs.get("condition", None) + if condition is not None: + out["condition"] = comfy.conds.CONDRegular(condition) + return out + class PixArt(BaseModel): def __init__(self, model_config, model_type=ModelType.EPS, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS) @@ -1428,6 +1442,23 @@ class PiD(PixelDiTT2I): out["degrade_sigma"] = comfy.conds.CONDRegular(degrade_sigma) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "lq_latent" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + lq = 
cond_value.cond + dim = window.dim + if dim >= lq.ndim: + return None + lq_proj = self.diffusion_model.lq_proj + ratio = lq_proj.sr_scale * lq_proj.latent_spatial_down_factor + # Map x window indices -> lq indices (deduplicated, sorted, in-bounds). + lq_size = lq.size(dim) + lq_indices = sorted({i // ratio for i in window.index_list if 0 <= i // ratio < lq_size}) + if not lq_indices: + return None + idx = tuple([slice(None)] * dim + [lq_indices]) + return cond_value._copy_with(lq[idx].to(device)) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): @@ -1789,6 +1820,24 @@ class Hunyuan3Dv2_1(BaseModel): out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance])) return out +class TripoSplat(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.triposplat.model.LatentSeqMMFlowModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) # DINOv3 token sequence -> cross-attention context. + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + ref_latents = kwargs.get("reference_latents", None) # Flux2 VAE image latent -> additive second conditioning. 
+ if ref_latents is not None: + out['ref_latents'] = comfy.conds.CONDList(list(ref_latents)) + latent_shapes = kwargs.get("latent_shapes", None) # {latent, camera} nested latent + if latent_shapes is not None: + out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes) + return out + + class HiDream(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel) @@ -1982,6 +2031,21 @@ class QwenImage(BaseModel): out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) return out +class Ideogram4(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ideogram4.model.Ideogram4Transformer2DModel) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + if torch.numel(attention_mask) != attention_mask.sum(): + out['attention_mask'] = comfy.conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn) + return out + class HunyuanImage21(BaseModel): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index f0db7d388..955581006 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -313,6 +313,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["use_x0"] = True else: dit_config["use_x0"] = False + if "{}__sequential__".format(key_prefix) in state_dict_keys: # sequential txt_ids + dit_config["use_sequential_txt_ids"] = True + 
else: + dit_config["use_sequential_txt_ids"] = False else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys @@ -594,6 +598,56 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): return dit_config + if "{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix) in state_dict_keys and state_dict["{}blocks.35.mlp.vid.proj_in.weight".format(key_prefix)].shape[1] == 3072: # seedvr2 7b + dit_config = {} + dit_config["image_model"] = "seedvr2" + dit_config["vid_dim"] = 3072 + dit_config["heads"] = 24 + dit_config["num_layers"] = 36 + # 7B uses non-shared MMModule layout (separate ``vid.`` / ``txt.`` + # submodules) at EVERY block — verified by inspecting the 7B + # state_dict at ``blocks.31.ada.txt.attn_gate`` (txt. prefix means + # ``MMModule.shared_weights=False``). Native NaDiT computes + # per-block ``shared_weights = not (i < mm_layers)``, so to keep + # every block non-shared we set ``mm_layers = num_layers``. + # Without this, blocks at index >= mm_layers (default 10) try to + # load ``blocks.N.*.all.*`` keys that don't exist in the file, + # silently miss-load → all-black output. + dit_config["mm_layers"] = 36 + dit_config["norm_eps"] = 1e-5 + dit_config["qk_rope"] = True + dit_config["rope_type"] = "rope3d" + dit_config["rope_dim"] = 64 + dit_config["mlp_type"] = "normal" + return dit_config + elif "{}blocks.35.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 7b + dit_config = {} + dit_config["image_model"] = "seedvr2" + dit_config["vid_dim"] = 3072 + dit_config["heads"] = 24 + dit_config["num_layers"] = 36 + # This checkpoint layout carries shared ``all.`` MMModule keys. + # Preserve the historical split: the initial blocks use separate + # vid/txt modules, later blocks use shared modules. 
+ dit_config["mm_layers"] = 10 + dit_config["norm_eps"] = 1e-5 + dit_config["qk_rope"] = True + dit_config["rope_type"] = "rope3d" + dit_config["rope_dim"] = 64 + dit_config["mlp_type"] = "swiglu" + return dit_config + elif "{}blocks.31.mlp.all.proj_in_gate.weight".format(key_prefix) in state_dict_keys: # seedvr2 3b + dit_config = {} + dit_config["image_model"] = "seedvr2" + dit_config["vid_dim"] = 2560 + dit_config["heads"] = 20 + dit_config["num_layers"] = 32 + dit_config["norm_eps"] = 1.0e-05 + dit_config["qk_rope"] = None + dit_config["mlp_type"] = "swiglu" + dit_config["vid_out_norm"] = True + return dit_config + if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 dit_config = {} dit_config["image_model"] = "wan2.1" @@ -676,6 +730,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys return dit_config + if '{}cam_out_layer.weight'.format(key_prefix) in state_dict_keys and '{}repo_layers.0.final_map.weight'.format(key_prefix) in state_dict_keys: # TripoSplat + return {"image_model": "triposplat"} + if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1 return {"image_model": "hidream_o1"} @@ -808,6 +865,13 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["default_ref_method"] = "negative_index" return dit_config + if '{}embed_image_indicator.weight'.format(key_prefix) in state_dict_keys: # Ideogram 4 + dit_config = {} + dit_config["image_model"] = "ideogram4" + dit_config["in_channels"] = state_dict['{}input_proj.weight'.format(key_prefix)].shape[1] + dit_config["num_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') + return dit_config + if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 dit_config = 
{} model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0] diff --git a/comfy/model_management.py b/comfy/model_management.py index b01c4d7fa..8e786c0a5 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -641,15 +641,17 @@ def free_pins(size, evict_active=False): return freed_total def ensure_pin_budget(size, evict_active=False): - shortfall = size + comfy.memory_management.RAM_CACHE_HEADROOM / 2 - psutil.virtual_memory().available + if args.fast_disk: + shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY + else: + shortfall = size + max(comfy.memory_management.RAM_CACHE_HEADROOM / 2, 2048 * 1024 ** 2) - psutil.virtual_memory().available if shortfall <= 0: return True to_free = shortfall + PIN_PRESSURE_HYSTERESIS return free_pins(to_free, evict_active=evict_active) >= shortfall -def ensure_pin_registerable(size, evict_active=False): - shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY +def free_registrations(shortfall, evict_active=True): if MAX_PINNED_MEMORY <= 0: return False if shortfall <= 0: @@ -658,12 +660,22 @@ def ensure_pin_registerable(size, evict_active=False): shortfall += REGISTERABLE_PIN_HYSTERESIS for loaded_model in reversed(current_loaded_models): model = loaded_model.model - if model is not None and model.is_dynamic() and (evict_active or not model.model.dynamic_pins[model.load_device]["active"]): + if model is not None and model.is_dynamic() and not model.model.dynamic_pins[model.load_device]["active"]: shortfall -= model.unregister_inactive_pins(shortfall) if shortfall <= 0: return True + if evict_active: + for loaded_model in current_loaded_models: + model = loaded_model.model + if model is not None and model.is_dynamic() and model.model.dynamic_pins[model.load_device]["active"]: + shortfall -= model.unregister_inactive_pins(shortfall) + if shortfall <= 0: + return True return shortfall <= REGISTERABLE_PIN_HYSTERESIS +def ensure_pin_registerable(size, evict_active=True): + 
return free_registrations(TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY, evict_active=evict_active) + class LoadedModel: def __init__(self, model: ModelPatcher): self._set_model(model) @@ -803,9 +815,9 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins for x in can_unload_sorted: i = x[-1] memory_to_free = 1e32 - if current_loaded_models[i].model.is_dynamic() and (not DISABLE_SMART_MEMORY or device is None): + if not DISABLE_SMART_MEMORY or device is None: memory_to_free = 0 if device is None else memory_required - get_free_memory(device) - if for_dynamic: + if current_loaded_models[i].model.is_dynamic() and for_dynamic: #don't actually unload dynamic models for the sake of other dynamic models #as that works on-demand. memory_required -= current_loaded_models[i].model.loaded_size() @@ -817,6 +829,10 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) + if not for_dynamic and pins_required > 0: + ensure_pin_budget(pins_required) + ensure_pin_registerable(pins_required) + if len(unloaded_model) > 0: soft_empty_cache() elif device is not None: @@ -879,15 +895,19 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu model_to_unload.model_finalizer.detach() total_memory_required = {} + total_pins_required = {} for loaded_model in models_to_load: device = loaded_model.device total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device) + if not loaded_model.model.is_dynamic(): + total_pins_required[device] = total_pins_required.get(device, 0) + loaded_model.model_memory() for device in total_memory_required: if device != torch.device("cpu"): free_memory(total_memory_required[device] * 1.1 + extra_mem, device, - for_dynamic=free_for_dynamic) + for_dynamic=free_for_dynamic, + pins_required=total_pins_required.get(device, 0)) for 
device in total_memory_required: if device != torch.device("cpu"): @@ -1283,7 +1303,6 @@ STREAM_CAST_BUFFERS = {} LARGEST_CASTED_WEIGHT = (None, 0) STREAM_AIMDO_CAST_BUFFERS = {} LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) -STREAM_PIN_BUFFERS = {} DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3 @@ -1326,42 +1345,13 @@ def get_aimdo_cast_buffer(offload_stream, device): STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer return cast_buffer -def get_pin_buffer(offload_stream): - pin_buffer = STREAM_PIN_BUFFERS.get(offload_stream, None) - if pin_buffer is None: - pin_buffer = comfy_aimdo.host_buffer.HostBuffer(0, 0, pinned_hostbuf_size(8 * 1024**3), mark_cold=False) - STREAM_PIN_BUFFERS[offload_stream] = pin_buffer - elif offload_stream is not None: - event = getattr(pin_buffer, "_comfy_event", None) - if event is not None: - event.synchronize() - delattr(pin_buffer, "_comfy_event") - return pin_buffer - -def resize_pin_buffer(pin_buffer, size): - global TOTAL_PINNED_MEMORY - old_size = pin_buffer.size - if size <= old_size: - return True - growth = size - old_size - comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) - ensure_pin_budget(growth, evict_active=True) - ensure_pin_registerable(growth, evict_active=True) - try: - pin_buffer.extend(size=size, reallocate=True) - except RuntimeError: - return False - TOTAL_PINNED_MEMORY += pin_buffer.size - old_size - return True - def reset_cast_buffers(): - global TOTAL_PINNED_MEMORY global LARGEST_CASTED_WEIGHT global LARGEST_AIMDO_CASTED_WEIGHT LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_AIMDO_CASTED_WEIGHT = (None, 0) - for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS) | set(STREAM_PIN_BUFFERS): + for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS): if offload_stream is not None: offload_stream.synchronize() synchronize() @@ -1370,20 +1360,24 @@ def reset_cast_buffers(): mmap_obj.bounce() DIRTY_MMAPS.clear() - for 
pin_buffer in STREAM_PIN_BUFFERS.values(): - TOTAL_PINNED_MEMORY -= pin_buffer.size - TOTAL_PINNED_MEMORY = max(0, TOTAL_PINNED_MEMORY) - for loaded_model in current_loaded_models: model = loaded_model.model if model is not None and model.is_dynamic(): - model.model.dynamic_pins[model.load_device]["active"] = False + pin_state = model.model.dynamic_pins[model.load_device] + + if pin_state["active"]: + *_, buckets = pin_state["weights"] + for size, bucket in list(buckets.items()): + bucket[:] = [ entry for entry in bucket if entry[-1] is not None ] + if not bucket: + del buckets[size] + + pin_state["active"] = False model.partially_unload_ram(1e30, subsets=[ "patches" ]) - model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0]) + model.model.dynamic_pins[model.load_device]["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, pinned_hostbuf_size(model.model_size())), [], [-1], [0], [0], {}) STREAM_CAST_BUFFERS.clear() STREAM_AIMDO_CAST_BUFFERS.clear() - STREAM_PIN_BUFFERS.clear() soft_empty_cache() def get_offload_stream(device): @@ -1436,7 +1430,7 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None): if hasattr(wf_context, "as_context"): wf_context = wf_context.as_context(stream) - dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) + dest_views = comfy.memory_management.interpret_gathered_like(tensors, r) if r is not None else [None] * len(tensors) dest2_views = comfy.memory_management.interpret_gathered_like(tensors, r2) if r2 is not None else None with wf_context: for tensor in tensors: @@ -1448,9 +1442,10 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None, r2=None): continue storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() mark_mmap_dirty(storage) - dest_view.copy_(tensor, 
non_blocking=non_blocking) + if dest_view is not None: + dest_view.copy_(tensor, non_blocking=non_blocking) if dest2_view is not None: - dest2_view.copy_(dest_view, non_blocking=non_blocking) + dest2_view.copy_(tensor if dest_view is None else dest_view, non_blocking=non_blocking) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None): @@ -1723,6 +1718,13 @@ def is_device_xpu(device): def is_device_cuda(device): return is_device_type(device, 'cuda') +def set_torch_device(device): + """Set the current device for the given torch device. Supports CUDA and XPU.""" + if is_device_cuda(device): + torch.cuda.set_device(device) + elif is_device_xpu(device): + torch.xpu.set_device(device) + def is_directml_enabled(): global directml_enabled if directml_enabled: diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 00a15fa63..b716a69e2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -1721,8 +1721,8 @@ class ModelPatcherDynamic(ModelPatcher): """ if device not in self.model.dynamic_pins: self.model.dynamic_pins[device] = { - "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), - "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0]), + "weights": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}), + "patches": (comfy_aimdo.host_buffer.HostBuffer(0, 0, 0), [], [-1], [0], [0], {}), "hostbufs_initialized": False, "failed": False, "active": False, @@ -1799,8 +1799,8 @@ class ModelPatcherDynamic(ModelPatcher): pin_state = self.model.dynamic_pins[self.load_device] if not pin_state["hostbufs_initialized"]: hostbuf_size = comfy.model_management.pinned_hostbuf_size(self.model_size()) - pin_state["weights"] = (comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0]) - pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0]) + pin_state["weights"] = 
(comfy_aimdo.host_buffer.HostBuffer(0, 64 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {}) + pin_state["patches"] = (comfy_aimdo.host_buffer.HostBuffer(0, 8 * 1024 * 1024, hostbuf_size), [], [-1], [0], [0], {}) pin_state["hostbufs_initialized"] = True pin_state["failed"] = False pin_state["active"] = True @@ -1942,18 +1942,16 @@ class ModelPatcherDynamic(ModelPatcher): return freed def loaded_ram_size(self): - return (self.model.dynamic_pins[self.load_device]["weights"][0].size + - self.model.dynamic_pins[self.load_device]["patches"][0].size) + return (self.model.dynamic_pins[self.load_device]["weights"][0].size) def pinned_memory_size(self): - return (self.model.dynamic_pins[self.load_device]["weights"][3][0] + - self.model.dynamic_pins[self.load_device]["patches"][3][0]) + return (self.model.dynamic_pins[self.load_device]["weights"][3][0]) def unregister_inactive_pins(self, ram_to_unload, subsets=[ "weights", "patches" ]): freed = 0 pin_state = self.model.dynamic_pins[self.load_device] for subset in subsets: - hostbuf, stack, stack_split, pinned_size = pin_state[subset] + hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset] split = stack_split[0] while split >= 0: module, offset = stack[split] @@ -1978,10 +1976,12 @@ class ModelPatcherDynamic(ModelPatcher): freed = 0 pin_state = self.model.dynamic_pins[self.load_device] for subset in subsets: - hostbuf, stack, stack_split, pinned_size = pin_state[subset] + hostbuf, stack, stack_split, pinned_size, *_ = pin_state[subset] while len(stack) > 0: module, offset = stack.pop() size = module._pin.numel() * module._pin.element_size() + module._pin_balancer_entry[-1] = None + del module._pin_balancer_entry del module._pin hostbuf.truncate(offset, do_unregister=module._pin_registered) stack_split[0] = min(stack_split[0], len(stack) - 1) diff --git a/comfy/model_prefetch.py b/comfy/model_prefetch.py index 72e11dec6..aa6d22d77 100644 --- a/comfy/model_prefetch.py +++ b/comfy/model_prefetch.py @@ -1,4 +1,5 @@ 
import comfy_aimdo.model_vbar +import comfy.memory_management import comfy.model_management import comfy.ops @@ -50,7 +51,17 @@ def prefetch_queue_pop(queue, device, module): if hasattr(s, "_v"): comfy_modules.append(s) + registerable_size = 0 + for s in comfy_modules: + registerable_size += comfy.memory_management.vram_aligned_size([s.weight, s.bias]) + for param_key in ("weight", "bias"): + lowvram_fn = getattr(s, param_key + "_lowvram_function", None) + if lowvram_fn is not None: + registerable_size += lowvram_fn.memory_required() + offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True) + if not comfy.model_management.args.fast_disk: + comfy.model_management.ensure_pin_registerable(registerable_size) comfy.model_management.sync_stream(device, offload_stream) queue[0] = (offload_stream, (prefetch, comfy_modules)) diff --git a/comfy/multigpu.py b/comfy/multigpu.py index e7f5b3d6f..2b6d8260d 100644 --- a/comfy/multigpu.py +++ b/comfy/multigpu.py @@ -17,7 +17,7 @@ class MultiGPUThreadPool: """Persistent thread pool for multi-GPU work distribution. Maintains one worker thread per extra GPU device. Each thread calls - torch.cuda.set_device() once at startup so that compiled kernel caches + set_torch_device() once at startup so that compiled kernel caches (inductor/triton) stay warm across diffusion steps. 
""" @@ -37,7 +37,7 @@ class MultiGPUThreadPool: def _worker_loop(self, device: torch.device, work_q: queue.Queue, result_q: queue.Queue): try: - torch.cuda.set_device(device) + comfy.model_management.set_torch_device(device) except Exception as e: logging.error(f"MultiGPUThreadPool: failed to set device {device}: {e}") while True: @@ -54,6 +54,8 @@ class MultiGPUThreadPool: try: result = fn(*args, **kwargs) result_q.put((result, None)) + except comfy.model_management.InterruptProcessingException as e: + result_q.put((None, e)) except Exception as e: result_q.put((None, e)) diff --git a/comfy/ops.py b/comfy/ops.py index 56445be8d..119177c37 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -76,8 +76,6 @@ except: cast_to = comfy.model_management.cast_to #TODO: remove once no more references -STREAM_PIN_BUFFER_HEADROOM = 8 * 1024 * 1024 - def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) @@ -94,9 +92,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin offload_stream = None cast_buffer = None cast_buffer_offset = 0 - stream_pin_hostbuf = None - stream_pin_offset = 0 - stream_pin_queue = [] def ensure_offload_stream(module, required_size, check_largest): nonlocal offload_stream @@ -130,22 +125,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin cast_buffer_offset += buffer_size return buffer - def get_stream_pin_buffer_offset(buffer_size): - nonlocal stream_pin_hostbuf - nonlocal stream_pin_offset - - if buffer_size == 0 or offload_stream is None: - return None - - if stream_pin_hostbuf is None: - stream_pin_hostbuf = comfy.model_management.get_pin_buffer(offload_stream) - if stream_pin_hostbuf is None: - return None - - offset = stream_pin_offset - stream_pin_offset += buffer_size - return offset - for s in comfy_modules: signature = comfy_aimdo.model_vbar.vbar_fault(s._v) 
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) @@ -184,12 +163,18 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin if xfer_dest is None: xfer_dest = get_cast_buffer(dest_size) - def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream): + def cast_maybe_lowvram_patch(xfer_source, xfer_dest, stream, xfer_dest2=None): if xfer_source is not None: if getattr(xfer_source, "is_lowvram_patch", False): - xfer_source.prepare(xfer_dest, stream, copy=True, commit=False) - else: - comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream) + if xfer_dest is not None: + xfer_source.prepare(xfer_dest, stream, copy=True, commit=False) + xfer_source = [ xfer_dest ] + xfer_dest = xfer_dest2 + xfer_dest2 = None + elif xfer_dest2 is not None: + xfer_source.prepare(xfer_dest2, stream, copy=True, commit=False) + return + comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=stream, r2=xfer_dest2) def handle_pin(m, pin, source, dest, subset="weights", size=None): if pin is not None: @@ -198,19 +183,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin if signature is None: comfy.pinned_memory.pin_memory(m, subset=subset, size=size) pin = comfy.pinned_memory.get_pin(m, subset=subset) - if pin is not None: - if isinstance(source, list): - comfy.model_management.cast_to_gathered(source, pin, non_blocking=non_blocking, stream=offload_stream, r2=dest) - else: - cast_maybe_lowvram_patch(source, pin, None) - cast_maybe_lowvram_patch([ pin ], dest, offload_stream) - return - if pin is None: - pin_offset = get_stream_pin_buffer_offset(size) - if pin_offset is not None: - stream_pin_queue.append((source, pin_offset, size, dest)) - return - cast_maybe_lowvram_patch(source, dest, offload_stream) + cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest) handle_pin(s, pin, xfer_source, 
xfer_dest, size=dest_size) @@ -232,23 +205,6 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin prefetch["needs_cast"] = needs_cast s._prefetch = prefetch - if stream_pin_offset > 0: - if stream_pin_hostbuf.size < stream_pin_offset: - if not comfy.model_management.resize_pin_buffer(stream_pin_hostbuf, stream_pin_offset + STREAM_PIN_BUFFER_HEADROOM): - for xfer_source, _, _, xfer_dest in stream_pin_queue: - cast_maybe_lowvram_patch(xfer_source, xfer_dest, offload_stream) - return offload_stream - stream_pin_tensor = comfy_aimdo.torch.hostbuf_to_tensor(stream_pin_hostbuf) - stream_pin_tensor.untyped_storage()._comfy_hostbuf = stream_pin_hostbuf - for xfer_source, pin_offset, pin_size, xfer_dest in stream_pin_queue: - pin = stream_pin_tensor[pin_offset:pin_offset + pin_size] - if isinstance(xfer_source, list): - comfy.model_management.cast_to_gathered(xfer_source, pin, non_blocking=non_blocking, stream=offload_stream, r2=xfer_dest) - else: - cast_maybe_lowvram_patch(xfer_source, pin, None) - comfy.model_management.cast_to_gathered([ pin ], xfer_dest, non_blocking=non_blocking, stream=offload_stream) - stream_pin_hostbuf._comfy_event = offload_stream.record_event() - return offload_stream diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 0e8f573ba..cb77c517a 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -1,17 +1,55 @@ +import bisect + import comfy.model_management import comfy.memory_management +import comfy.utils import comfy_aimdo.host_buffer import comfy_aimdo.torch import torch from comfy.cli_args import args +def _add_to_bucket(module, buckets, size, priority): + bucket = buckets.setdefault(size, []) + entry = [-priority, 0, module] + entry[1] = id(entry) + bisect.insort(bucket, entry) + module._pin_balancer_entry = entry + +def _steal_pin(module, stack, buckets, size, priority): + bucket = buckets.get(size) + if bucket is None: + return False + + while bucket and bucket[-1][-1] is None: + 
bucket.pop() + if not bucket: + del buckets[size] + return False + + if priority <= -bucket[-1][0]: + return False + + *_, victim = bucket.pop() + module._pin = victim._pin + module._pin_registered = victim._pin_registered + module._pin_stack_index = victim._pin_stack_index + stack[module._pin_stack_index] = (module, stack[module._pin_stack_index][1]) + + victim._pin_registered = False + del victim._pin + del victim._pin_stack_index + del victim._pin_balancer_entry + + _add_to_bucket(module, buckets, size, priority) + return True + def get_pin(module, subset="weights"): pin = getattr(module, "_pin", None) if pin is None or module._pin_registered or args.disable_pinned_memory: return pin - _, _, stack_split, pinned_size = module._pin_state[subset] + _, _, stack_split, pinned_size, *_ = module._pin_state[subset] size = pin.nbytes comfy.model_management.ensure_pin_registerable(size) @@ -31,33 +69,51 @@ def pin_memory(module, subset="weights", size=None): return pin = get_pin(module, subset) - if pin is not None or pin_state["failed"]: + if pin is not None: return - hostbuf, stack, stack_split, pinned_size = pin_state[subset] + hostbuf, stack, stack_split, pinned_size, counter, buckets = pin_state[subset] if size is None: size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) offset = hostbuf.size - registerable_size = size + max(0, hostbuf.size - pinned_size[0]) + registerable_size = size + priority = getattr(module, "_pin_balancer_priority", None) + + if priority is None: + priority = comfy.utils.bit_reverse_range(counter[0], 16) + counter[0] += 1 + module._pin_balancer_priority = priority comfy.memory_management.extra_ram_release(comfy.memory_management.RAM_CACHE_HEADROOM) if (not comfy.model_management.ensure_pin_budget(size) or not comfy.model_management.ensure_pin_registerable(registerable_size)): - pin_state["failed"] = True - return False + return _steal_pin(module, stack, buckets, size, priority) + extended = False try: - 
hostbuf.extend(size=size) + hostbuf.extend(size=size, register=False) + extended = True + pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] + pin.untyped_storage()._comfy_hostbuf = hostbuf + if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0: + comfy.model_management.discard_cuda_async_error() + comfy.model_management.free_registrations(size) + if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0: + comfy.model_management.discard_cuda_async_error() + del pin + hostbuf.truncate(offset, do_unregister=False) + return _steal_pin(module, stack, buckets, size, priority) except RuntimeError: - pin_state["failed"] = True - return False + if extended: + hostbuf.truncate(offset, do_unregister=False) + return _steal_pin(module, stack, buckets, size, priority) - module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size] - module._pin.untyped_storage()._comfy_hostbuf = hostbuf + module._pin = pin stack.append((module, offset)) module._pin_registered = True module._pin_stack_index = len(stack) - 1 stack_split[0] = max(stack_split[0], module._pin_stack_index) comfy.model_management.TOTAL_PINNED_MEMORY += size pinned_size[0] += size + _add_to_bucket(module, buckets, size, priority) return True diff --git a/comfy/sample.py b/comfy/sample.py index 2be0cae5f..de71596b3 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -44,7 +44,13 @@ def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None, is_empty = torch.count_nonzero(latent_image) == 0 if is_empty: if latent_format.latent_channels != latent_image.shape[1]: - latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) + preserves_collapsed_channels = ( + getattr(latent_format, "preserve_empty_channel_multiples", False) + and latent_image.ndim == 4 + and latent_image.shape[1] % latent_format.latent_channels == 0 + ) + if not preserves_collapsed_channels: + latent_image = 
comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) if downscale_ratio_spacial is not None: if downscale_ratio_spacial != latent_format.spacial_downscale_ratio: ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio diff --git a/comfy/samplers.py b/comfy/samplers.py index e31277f7b..25c5a855f 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -464,10 +464,7 @@ def _calc_cond_batch_multigpu(model: BaseModel, conds: list[list[dict]], x_in: t def _handle_batch(device: torch.device, batch_tuple: tuple[comfy.hooks.HookGroup, tuple], results: list[thread_result]): try: - # TODO: non-NVIDIA support -- guard with `if device.type == "cuda":` once - # we extend multigpu QA beyond CUDA. Unconditional call crashes on - # XPU/NPU/MPS/CPU/DirectML backends. - torch.cuda.set_device(device) + comfy.model_management.set_torch_device(device) model_current: BaseModel = model_options["multigpu_clones"][device].model # run every hooked_to_run separately with torch.no_grad(): diff --git a/comfy/sd.py b/comfy/sd.py index 30b877b85..8ac08ac42 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -1,3 +1,4 @@ +import inspect import json import torch from enum import Enum @@ -16,6 +17,8 @@ import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae +import comfy.ldm.seedvr.vae +import comfy.ldm.triposplat.vae import comfy.ldm.ace.vae.music_dcae_pipeline import comfy.ldm.cogvideo.vae import comfy.ldm.hunyuan_video.vae @@ -57,6 +60,7 @@ import comfy.text_encoders.omnigen2 import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.z_image +import comfy.text_encoders.ideogram4 import comfy.text_encoders.ovis import comfy.text_encoders.kandinsky5 import comfy.text_encoders.jina_clip_2 @@ -82,6 +86,36 @@ import comfy.latent_formats import comfy.ldm.flux.redux +SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL = 160 + + +def 
_seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w): + output_t = max(1, (latent_t - 1) * 4 + 1) + return output_t * latent_h * 8 * latent_w * 8 + + +def _seedvr2_vae_decode_memory_used(shape): + if len(shape) == 5: + candidates = [] + if shape[1] == 16: + candidates.append((shape[2], shape[3], shape[4])) + if shape[-1] == 16: + candidates.append((shape[1], shape[2], shape[3])) + if len(candidates) == 0: + candidates.append((shape[2], shape[3], shape[4])) + output_pixels = max(_seedvr2_vae_decode_output_pixels(*candidate) for candidate in candidates) + elif len(shape) == 4: + latent_t = max(1, (shape[1] + 15) // 16) + latent_h, latent_w = shape[2], shape[3] + output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w) + else: + latent_t, latent_h, latent_w = 1, shape[-2], shape[-1] + output_pixels = _seedvr2_vae_decode_output_pixels(latent_t, latent_h, latent_w) + # SeedVR2 decode performs full-frame LAB histogram matching: fp32 channels + # plus int64 sort indices dominate peak memory, not the VAE weight dtype. 
+ return output_pixels * SEEDVR2_VAE_DECODE_BYTES_PER_OUTPUT_PIXEL + + def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None): key_map = {} if model is not None: @@ -465,8 +499,10 @@ class CLIP: class VAE: def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None): - if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format - sd = diffusers_convert.convert_vae_state_dict(sd) + is_seedvr2_vae = "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd + if not is_seedvr2_vae and 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format + if metadata is None or metadata.get("keep_diffusers_format") != "true": + sd = diffusers_convert.convert_vae_state_dict(sd) if model_management.is_amd(): VAE_KL_MEM_RATIO = 2.73 @@ -538,6 +574,20 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 + elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2 + self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper() + self.latent_channels = 16 + self.latent_dim = 3 + self.disable_offload = True + self.memory_used_decode = lambda shape, dtype: _seedvr2_vae_decode_memory_used(shape) + self.memory_used_encode = lambda shape, dtype: (max(shape[2], 5) * shape[3] * shape[4] * 64) * model_management.dtype_size(dtype) + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.process_input = lambda image: image * 2.0 - 1.0 + self.crop_input = False elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, 
"ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} @@ -665,6 +715,7 @@ class VAE: self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] + elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] @@ -894,6 +945,16 @@ class VAE: #Force cast it for --disable-dynamic-vram users until there is a true core fix. if not comfy.memory_management.aimdo_enabled: self.disable_offload = True + elif "gs.base_offset_scale" in sd and "octree.out_proj.weight" in sd: # TripoSplat octree gaussian decoder + self.first_stage_model = comfy.ldm.triposplat.vae.OctreeGaussianDecoder() + self.latent_channels = 16 + self.latent_dim = 1 + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + # The generic VAE.encode/decode path isn't used: VAEDecodeTripoSplat calls the gaussian + # decoder directly (structured GaussianSplat objects, not a tensor and reserves VRAM itself from num_gaussians. 
+ def _no_generic_io(*args, **kwargs): + raise RuntimeError("TripoSplat gaussian decoder: use the 'TripoSplat Decode' (VAEDecodeTripoSplat)") + self.memory_used_encode = self.memory_used_decode = _no_generic_io else: logging.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -994,6 +1055,40 @@ class VAE: decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) + def decode_tiled_seedvr2(self, samples, tile_x=32, tile_y=32, overlap=8, tile_t=16, overlap_t=4): + sf_s = getattr(self.first_stage_model, "spatial_downsample_factor", 8) + sf_t = getattr(self.first_stage_model, "temporal_downsample_factor", 4) + if tile_t is None: + tile_t = 16 + if overlap_t is None: + overlap_t = 4 + if tile_t > 0: + temporal_size = tile_t * sf_t + temporal_overlap = max(0, overlap_t) * sf_t + else: + temporal_size = 0 + temporal_overlap = 0 + args = { + "enable_tiling": True, + "tile_size": (tile_y * sf_s, tile_x * sf_s), + "tile_overlap": (overlap * sf_s, overlap * sf_s), + "temporal_size": temporal_size, + "temporal_overlap": temporal_overlap, + } + output = self.first_stage_model.decode( + samples.to(self.vae_dtype).to(self.device), + seedvr2_tiling=args, + ) + return self.process_output(output.to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)) + + def _format_seedvr2_encoded_samples(self, samples): + if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): + if samples.ndim == 4: + samples = samples.unsqueeze(2) + samples = samples.contiguous() + samples = samples * 0.9152 + return samples + def encode_tiled_(self, pixel_samples, tile_x=512, 
tile_y=512, overlap = 64): steps = pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x // 2, tile_y * 2, overlap) @@ -1030,6 +1125,36 @@ class VAE: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) + def encode_tiled_seedvr2(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + if tile_y is None: + tile_y = 512 + if tile_x is None: + tile_x = 512 + if overlap is None: + overlap_y = 64 + overlap_x = 64 + else: + overlap_y = overlap + overlap_x = overlap + if tile_t is None: + tile_t = 9999 + if overlap_t is None: + overlap_t = 0 + overlap_y = min(overlap_y, max(0, tile_y - 8)) + overlap_x = min(overlap_x, max(0, tile_x - 8)) + self.first_stage_model.device = self.device + x = self.process_input(pixel_samples).to(self.vae_dtype).to(self.device) + output = comfy.ldm.seedvr.vae.tiled_vae( + x, + self.first_stage_model, + tile_size=(tile_y, tile_x), + tile_overlap=(overlap_y, overlap_x), + temporal_size=tile_t, + temporal_overlap=overlap_t, + encode=True, + ) + return output.to(device=self.output_device, dtype=self.vae_output_dtype()) + def decode(self, samples_in, vae_options={}): self.throw_exception_if_invalid() pixel_samples = None @@ -1077,16 +1202,40 @@ class VAE: if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) elif dims == 2: - pixel_samples = self.decode_tiled_(samples_in) + # SeedVR2 latents arrive in 4D 
collapsed form ``(B, 16*T, H, W)`` + # downstream of ``SeedVR2Conditioning`` (which performs the + # ``rearrange(b c t h w -> b (c t) h w)`` collapse). The + # generic ``decode_tiled_`` would treat the channel dim as + # spatial-only and crash on the collapsed (16, T) layout + # under ``tiled_scale``'s mask broadcast; route SeedVR2 4D + # latents to ``decode_tiled_seedvr2`` instead, whose wrapper + # dispatch handles both 4D and 5D inputs. + if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): + tile = 256 // self.spacial_compression_decode() + overlap = tile // 4 + pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) + else: + pixel_samples = self.decode_tiled_(samples_in) elif dims == 3: tile = 256 // self.spacial_compression_decode() overlap = tile // 4 - pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): + pixel_samples = self.decode_tiled_seedvr2(samples_in, tile_x=tile, tile_y=tile, overlap=overlap) + else: + pixel_samples = self.decode_tiled_3d(samples_in, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) pixel_samples = pixel_samples.to(self.output_device).movedim(1,-1) return pixel_samples - def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): + def decode_tiled( + self, + samples, + tile_x=None, + tile_y=None, + overlap=None, + tile_t=None, + overlap_t=None, + ): self.throw_exception_if_invalid() memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -1100,7 +1249,20 @@ class VAE: args["overlap"] = overlap with model_management.cuda_device_context(self.device): - if dims == 1 or self.extra_1d_channel is not None: 
+ if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper) and dims in (2, 3): + seedvr2_args = {} + if tile_x is not None: + seedvr2_args["tile_x"] = tile_x + if tile_y is not None: + seedvr2_args["tile_y"] = tile_y + if overlap is not None: + seedvr2_args["overlap"] = overlap + if tile_t is not None: + seedvr2_args["tile_t"] = tile_t + if overlap_t is not None: + seedvr2_args["overlap_t"] = overlap_t + output = self.decode_tiled_seedvr2(samples, **seedvr2_args) + elif dims == 1 or self.extra_1d_channel is not None: args.pop("tile_y") output = self.decode_tiled_1d(samples, **args) elif dims == 2: @@ -1142,6 +1304,8 @@ class VAE: else: pixels_in = pixels_in.to(self.device) out = self.first_stage_model.encode(pixels_in) + if isinstance(out, tuple): + out = out[0] out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) @@ -1161,20 +1325,23 @@ class VAE: if self.latent_dim == 3: tile = 256 overlap = tile // 4 - samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) + if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): + samples = self.encode_tiled_seedvr2(pixel_samples, tile_x=tile, tile_y=tile, overlap=overlap) + else: + samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap)) elif self.latent_dim == 1 or self.extra_1d_channel is not None: samples = self.encode_tiled_1d(pixel_samples) else: samples = self.encode_tiled_(pixel_samples) - return samples + return self._format_seedvr2_encoded_samples(samples) def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): self.throw_exception_if_invalid() pixel_samples = self.vae_encode_crop_pixels(pixel_samples) dims = self.latent_dim pixel_samples = 
pixel_samples.movedim(-1, 1) - if dims == 3: + if dims == 3 and pixel_samples.ndim < 5: if not self.not_video: pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0) else: @@ -1198,22 +1365,47 @@ class VAE: elif dims == 2: samples = self.encode_tiled_(pixel_samples, **args) elif dims == 3: - if tile_t is not None: - tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) + if isinstance(self.first_stage_model, comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper): + seedvr2_args = {} + if tile_x is not None: + seedvr2_args["tile_x"] = tile_x + else: + seedvr2_args["tile_x"] = 512 + if tile_y is not None: + seedvr2_args["tile_y"] = tile_y + else: + seedvr2_args["tile_y"] = 512 + if overlap is not None: + seedvr2_args["overlap"] = overlap + else: + seedvr2_args["overlap"] = 64 + if tile_t is not None: + seedvr2_args["tile_t"] = tile_t + else: + seedvr2_args["tile_t"] = 9999 + if overlap_t is not None: + seedvr2_args["overlap_t"] = overlap_t + else: + seedvr2_args["overlap_t"] = 0 + samples = self.encode_tiled_seedvr2(pixel_samples, **seedvr2_args) else: - tile_t_latent = 9999 - args["tile_t"] = self.upscale_ratio[0](tile_t_latent) + if tile_t is not None: + tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) + else: + tile_t_latent = 9999 + args["tile_t"] = self.upscale_ratio[0](tile_t_latent) - if overlap_t is None: - args["overlap"] = (1, overlap, overlap) - else: - args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) - maximum = pixel_samples.shape[2] - maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) + spatial_overlap = overlap if overlap is not None else 64 + if overlap_t is None: + args["overlap"] = (1, spatial_overlap, spatial_overlap) + else: + args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), spatial_overlap, spatial_overlap) + maximum = pixel_samples.shape[2] + maximum = 
self.upscale_ratio[0](self.downscale_ratio[0](maximum)) - samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) + samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) - return samples + return self._format_seedvr2_encoded_samples(samples) def get_sd(self): return self.first_stage_model.state_dict() @@ -1287,6 +1479,7 @@ class CLIPType(Enum): COGVIDEOX = 27 LENS = 28 PIXELDIT = 29 + IDEOGRAM4 = 30 @@ -1585,8 +1778,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer elif te_model == TEModel.QWEN3_8B: - clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b") - clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B + if clip_type == CLIPType.IDEOGRAM4: + clip_target.clip = comfy.text_encoders.ideogram4.te(**llama_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.ideogram4.Ideogram4Tokenizer + else: + clip_target.clip = comfy.text_encoders.flux.klein_te(**llama_detect(clip_data), model_type="qwen3_8b") + clip_target.tokenizer = comfy.text_encoders.flux.KleinTokenizer8B elif te_model == TEModel.JINA_CLIP_2: clip_target.clip = comfy.text_encoders.jina_clip_2.JinaClip2TextModelWrapper clip_target.tokenizer = comfy.text_encoders.jina_clip_2.JinaClip2TokenizerWrapper @@ -1735,6 +1932,17 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl return (model, clip, vae) + +def _set_model_config_inference_dtype(model_config, dtype, manual_cast_dtype, device): + set_dtype = model_config.set_inference_dtype + parameters = inspect.signature(set_dtype).parameters + supports_device = "device" in parameters or any(p.kind == inspect.Parameter.VAR_KEYWORD for p in parameters.values()) + if supports_device: + set_dtype(dtype, manual_cast_dtype, device=device) + else: + set_dtype(dtype, 
manual_cast_dtype) + + def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, disable_dynamic=False): sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata, disable_dynamic=disable_dynamic) @@ -1842,7 +2050,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) + _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) if model_config.clip_vision_prefix is not None: if output_clipvision: @@ -1983,7 +2191,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None, disable manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) - model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) + _set_model_config_inference_dtype(model_config, unet_dtype, manual_cast_dtype, load_device) if custom_operations is not None: model_config.custom_operations = custom_operations diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 00941da53..fa95003cc 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -24,6 +24,7 @@ import comfy.text_encoders.qwen_image import comfy.text_encoders.hunyuan_image import comfy.text_encoders.kandinsky5 import 
comfy.text_encoders.z_image +import comfy.text_encoders.ideogram4 import comfy.text_encoders.anima import comfy.text_encoders.ace15 import comfy.text_encoders.longcat_image @@ -1538,6 +1539,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2): latent_format = latent_formats.Hunyuan3Dv2mini +class TripoSplat(supported_models_base.BASE): + # Image -> 3D gaussian splat flow denoiser + unet_config = { + "image_model": "triposplat", + } + + unet_extra_config = {} + + sampling_settings = { + "shift": 3.0, + } + + memory_usage_factor = 0.6 + + latent_format = latent_formats.TripoSplat + + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.TripoSplat(self, device=device) + + def clip_target(self, state_dict={}): + return None + class HiDream(supported_models_base.BASE): unet_config = { "image_model": "hidream", @@ -1647,6 +1672,35 @@ class Chroma(supported_models_base.BASE): t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect)) +class SeedVR2(supported_models_base.BASE): + unet_config = { + "image_model": "seedvr2" + } + latent_format = comfy.latent_formats.SeedVR2 + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + sampling_settings = { + "shift": 1.0, + } + + def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): + if ( + dtype == torch.float16 + and manual_cast_dtype is None + and comfy.model_management.should_use_bf16(device) + ): + manual_cast_dtype = torch.bfloat16 + super().set_inference_dtype(dtype, manual_cast_dtype, device=device) + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.SeedVR2(self, device=device) + return out + + def 
clip_target(self, state_dict={}): + return None + class ChromaRadiance(Chroma): unet_config = { "image_model": "chroma_radiance", @@ -1722,6 +1776,44 @@ class Omnigen2(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_3b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.omnigen2.Omnigen2Tokenizer, comfy.text_encoders.omnigen2.te(**hunyuan_detect)) +class Ideogram4(supported_models_base.BASE): + unet_config = { + "image_model": "ideogram4", + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 1.0, + } + + memory_usage_factor = 11.6 + + unet_extra_config = { + "num_attention_heads": 18, + "attention_head_dim": 256, + "intermediate_size": 12288, + "adaln_dim": 512, + "llm_features_dim": 53248, + "rope_theta": 5000000, + "mrope_section": [24, 20, 20], + "norm_eps": 1e-5, + } + latent_format = latent_formats.Flux2 + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Ideogram4(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3vl_8b.transformer.".format(pref)) + return supported_models_base.ClipTarget(comfy.text_encoders.ideogram4.Ideogram4Tokenizer, comfy.text_encoders.ideogram4.te(**hunyuan_detect)) + class QwenImage(supported_models_base.BASE): unet_config = { "image_model": "qwen_image", @@ -1966,7 +2058,6 @@ class LongCatImage(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect)) - class 
RT_DETR_v4(supported_models_base.BASE): unet_config = { "image_model": "RT_DETR_v4", @@ -2200,14 +2291,17 @@ models = [ Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, + TripoSplat, HiDream, HiDreamO1, Chroma, + SeedVR2, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, + Ideogram4, Flux2, Lens, Kandinsky5Image, diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index 0e7a829ba..572f9984e 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -115,7 +115,7 @@ class BASE: replace_prefix = {"": self.vae_key_prefix[0]} return utils.state_dict_prefix_replace(state_dict, replace_prefix) - def set_inference_dtype(self, dtype, manual_cast_dtype): + def set_inference_dtype(self, dtype, manual_cast_dtype, device=None): self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype diff --git a/comfy/text_encoders/ideogram4.py b/comfy/text_encoders/ideogram4.py new file mode 100644 index 000000000..55e655d67 --- /dev/null +++ b/comfy/text_encoders/ideogram4.py @@ -0,0 +1,77 @@ +"""Ideogram 4 text encoder: Qwen3-VL-8B language model, 13-layer tap. + +Ideogram 4 conditions on the concatenation of hidden states from 13 layers of +Qwen3-VL (layers 0,3,...,33,35), giving a 4096*13 = 53248-dim feature per token. +""" + +import os + +from transformers import Qwen2Tokenizer + +import comfy.text_encoders.llama +from comfy import sd1_clip + +# Reference taps outputs of layers (0,3,...,35); comfy captures layer inputs, offset by +1. 
+IDEOGRAM4_TAP_LAYERS = [1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31, 34, 36] + + +class Qwen3VLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, + embedding_size=4096, embedding_key='qwen3vl_8b', tokenizer_class=Qwen2Tokenizer, + has_start_token=False, has_end_token=False, pad_to_max_length=False, + max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + + +class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, + name="qwen3vl_8b", tokenizer=Qwen3VLTokenizer) + + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + + +# Qwen3-VL-8B = 5e6 (vs plain Qwen3-8B's 1e6) +# final_norm/lm_head off -> Ideogram only reads raw tapped hidden states +QWEN3VL_8B_CONFIG = {"rope_theta": 5000000.0, "final_norm": False, "lm_head": False} + + +class Qwen3VL8BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=IDEOGRAM4_TAP_LAYERS, layer_idx=None, + textmodel_json_config=dict(QWEN3VL_8B_CONFIG), + dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, + model_class=comfy.text_encoders.llama.Qwen3_8B, + enable_attention_masks=attention_mask, 
return_attention_masks=attention_mask, + model_options=model_options) + + +class Ideogram4TEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3vl_8b", clip_model=Qwen3VL8BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + b, n, seq, h = out.shape # (B, n_taps=13, seq, 4096) stacked in ascending layer order. + out = out.permute(0, 2, 3, 1).reshape(b, seq, h * n) # (B, seq, 4096*13). permute -> (B, seq, H, taps). + return out, pooled, extra + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class Ideogram4TEModel_(Ideogram4TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Ideogram4TEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index 49ae12b06..09d783fff 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -85,9 +85,9 @@ _TYPES = { def load_safetensors(ckpt): import comfy_aimdo.model_mmap - f = open(ckpt, "rb", buffering=0) file_lock = threading.Lock() model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt) + f = model_mmap.get_file_handle() file_size = os.path.getsize(ckpt) mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get())) @@ -1452,3 +1452,10 @@ def deepcopy_list_dict(obj, memo=None): memo[obj_id] = res return res + +def bit_reverse_range(index, bits): + result = 0 + for _ in range(bits): + result = (result << 1) | (index & 1) + index >>= 1 + return result diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index e0a585b10..294ad425e 100644 --- a/comfy_api/latest/__init__.py 
+++ b/comfy_api/latest/__init__.py @@ -5,7 +5,7 @@ from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput from ._input_impl import VideoFromFile, VideoFromComponents -from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, File3D +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL, SPLAT, File3D from . import _io_public as io from . import _ui_public as ui from comfy_execution.utils import get_executing_context @@ -143,6 +143,7 @@ class Types: VideoComponents = VideoComponents MESH = MESH VOXEL = VOXEL + SPLAT = SPLAT File3D = File3D diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index 451e9526e..8fff52c16 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -65,6 +65,12 @@ class VideoInput(ABC): buffer.seek(0) return buffer + def get_active_trim_window(self) -> tuple[float, float]: + """Return the active trim as ``(start_time, duration)`` in seconds (start_time normalized + to ``>= 0``; ``duration == 0`` means "until the end"). Default: no trim; trimmable subclasses override. + """ + return 0.0, 0.0 + # Provide a default implementation, but subclasses can provide optimized versions # if possible. 
def get_dimensions(self) -> tuple[int, int]: diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index 99e67d363..4a12ff9c1 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -75,6 +75,12 @@ class VideoFromFile(VideoInput): self.__file.seek(0) return self.__file + def get_active_trim_window(self) -> tuple[float, float]: + start_time = self.__start_time + if start_time < 0: + start_time = max(self._get_raw_duration() + start_time, 0.0) + return float(start_time), float(self.__duration) + def get_dimensions(self) -> tuple[int, int]: """ Returns the dimensions of the video input. diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index 5ed968960..37614a4c3 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) from comfy_execution.graph_utils import ExecutionBlocker -from ._util import MESH, VOXEL, SVG as _SVG, File3D +from ._util import MESH, VOXEL, SPLAT, SVG as _SVG, File3D class FolderType(str, Enum): @@ -684,6 +684,10 @@ class Voxel(ComfyTypeIO): class Mesh(ComfyTypeIO): Type = MESH +@comfytype(io_type="SPLAT") +class Splat(ComfyTypeIO): + Type = SPLAT + @comfytype(io_type="FILE_3D") class File3DAny(ComfyTypeIO): @@ -727,6 +731,42 @@ class File3DUSDZ(ComfyTypeIO): Type = File3D +@comfytype(io_type="FILE_3D_PLY") +class File3DPLY(ComfyTypeIO): + """PLY format 3D file - point cloud or Gaussian splat.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_SPLAT") +class File3DSPLAT(ComfyTypeIO): + """SPLAT format 3D file - 3D Gaussian splat.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_SPZ") +class File3DSPZ(ComfyTypeIO): + """SPZ format 3D file - compressed 3D Gaussian splat.""" + Type = File3D + + 
+@comfytype(io_type="FILE_3D_KSPLAT") +class File3DKSPLAT(ComfyTypeIO): + """KSPLAT format 3D file - 3D Gaussian splat.""" + Type = File3D + + +@comfytype(io_type="FILE_3D_SPLAT_ANY") +class File3DSplatAny(ComfyTypeIO): + """General 3D Gaussian splat file type - accepts any supported splat container (.ply / .spz / .splat / .ksplat).""" + Type = File3D + + +@comfytype(io_type="FILE_3D_POINT_CLOUD_ANY") +class File3DPointCloudAny(ComfyTypeIO): + """General point cloud file type - accepts any supported point cloud container (currently .ply).""" + Type = File3D + + @comfytype(io_type="HOOKS") class Hooks(ComfyTypeIO): if TYPE_CHECKING: @@ -762,14 +802,32 @@ class Accumulation(ComfyTypeIO): @comfytype(io_type="LOAD3D_CAMERA") class Load3DCamera(ComfyTypeIO): class CameraInfo(TypedDict): - position: dict[str, float | int] - target: dict[str, float | int] - zoom: int - cameraType: str + # Coordinate system: right-handed, Y-up, camera looks down -Z + position: dict[str, float | int] # scene units + target: dict[str, float | int] # scene units; OrbitControls focus point + zoom: float | int # dimensionless, 1 = 100% + cameraType: str # 'perspective' | 'orthographic' + quaternion: NotRequired[dict[str, float | int]] # normalized, dimensionless; camera world rotation + fov: NotRequired[float | int] # degrees, vertical FOV (perspective only) + aspect: NotRequired[float | int] # width / height (perspective only) + near: NotRequired[float | int] # scene units + far: NotRequired[float | int] # scene units + frustum: NotRequired[dict[str, float | int]] # orthographic only: {left, right, top, bottom} in scene units Type = CameraInfo +@comfytype(io_type="LOAD3D_MODEL_INFO") +class Load3DModelInfo(ComfyTypeIO): + class Model3DTransform(TypedDict): + # Coordinate system: right-handed, Y-up, world space + position: dict[str, float | int] # scene units + quaternion: dict[str, float | int] # normalized, dimensionless; world rotation + scale: dict[str, float | int] # dimensionless 
multiplier + + Type = list[Model3DTransform] + + @comfytype(io_type="LOAD_3D") class Load3D(ComfyTypeIO): """3D models are stored as a dictionary.""" @@ -779,6 +837,7 @@ class Load3D(ComfyTypeIO): normal: str camera_info: Load3DCamera.CameraInfo recording: NotRequired[str] + model_3d_info: NotRequired[list[Load3DModelInfo.Model3DTransform]] Type = Model3DDict @@ -2277,6 +2336,7 @@ __all__ = [ "LossMap", "Voxel", "Mesh", + "Splat", "File3DAny", "File3DGLB", "File3DGLTF", @@ -2284,6 +2344,12 @@ __all__ = [ "File3DOBJ", "File3DSTL", "File3DUSDZ", + "File3DPLY", + "File3DSPLAT", + "File3DSPZ", + "File3DKSPLAT", + "File3DSplatAny", + "File3DPointCloudAny", "Hooks", "HookKeyframes", "TimestepsRange", @@ -2291,6 +2357,7 @@ __all__ = [ "FlowControl", "Accumulation", "Load3DCamera", + "Load3DModelInfo", "Load3D", "Load3DAnimation", "Photomaker", diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index e238cdf3c..b48713d41 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -285,7 +285,7 @@ class AudioSaveHelper: results = [] for batch_number, waveform in enumerate(audio["waveform"].cpu()): filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) - file = f"{filename_with_batch_num}_{counter:05}_.{format}" + file = f"{filename_with_batch_num}_{counter:05}.{format}" output_path = os.path.join(full_output_folder, file) # Use original sample rate initially @@ -452,6 +452,16 @@ class PreviewUI3D(_UIOutput): return {"result": [self.model_file, self.camera_info, self.bg_image_path]} +class PreviewUI3DAdvanced(_UIOutput): + def __init__(self, model_file, camera_info, model_3d_info): + self.model_file = model_file + self.camera_info = camera_info + self.model_3d_info = model_3d_info + + def as_dict(self): + return {"result": [self.model_file, self.camera_info, self.model_3d_info]} + + class PreviewText(_UIOutput): def __init__(self, value: str, **kwargs): self.value = value @@ -471,5 +481,6 @@ __all__ = [ "PreviewAudio", 
"PreviewVideo", "PreviewUI3D", + "PreviewUI3DAdvanced", "PreviewText", ] diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index 115baf392..b27f5a97e 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,5 +1,5 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents -from .geometry_types import VOXEL, MESH, File3D +from .geometry_types import VOXEL, MESH, SPLAT, File3D from .image_types import SVG __all__ = [ @@ -9,6 +9,7 @@ __all__ = [ "VideoComponents", "VOXEL", "MESH", + "SPLAT", "File3D", "SVG", ] diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py index cdde60b10..84a18d69a 100644 --- a/comfy_api/latest/_util/geometry_types.py +++ b/comfy_api/latest/_util/geometry_types.py @@ -11,13 +11,32 @@ class VOXEL: self.data = data +class SPLAT: + """A batch of 3D Gaussian splats in render-ready (activated, world-space) form. + + Tensors are (B, N, ...) and zero-padded to a common N across the batch; `counts` (B,) holds the + real per-item lengths (None when rows are uniform and no slicing is needed). SH coefficients are + stored as (B, N, K, 3) with K = (sh_degree + 1)**2; the DC (diffuse) term is sh[..., 0, :]. 
+ """ + + def __init__(self, positions: torch.Tensor, scales: torch.Tensor, rotations: torch.Tensor, + opacities: torch.Tensor, sh: torch.Tensor, counts: torch.Tensor | None = None): + self.positions = positions # (B, N, 3) world-space centers + self.scales = scales # (B, N, 3) linear (positive) per-axis std + self.rotations = rotations # (B, N, 4) quaternion wxyz (normalized) + self.opacities = opacities # (B, N, 1) in [0, 1] + self.sh = sh # (B, N, K, 3) spherical-harmonic color coefficients + self.counts = counts # (B,) real lengths, or None + + class MESH: def __init__(self, vertices: torch.Tensor, faces: torch.Tensor, uvs: torch.Tensor | None = None, vertex_colors: torch.Tensor | None = None, texture: torch.Tensor | None = None, vertex_counts: torch.Tensor | None = None, - face_counts: torch.Tensor | None = None): + face_counts: torch.Tensor | None = None, + unlit: bool = False): assert (vertex_counts is None) == (face_counts is None), \ "vertex_counts and face_counts must be provided together (both or neither)" @@ -30,6 +49,8 @@ class MESH: # these hold the real per-item lengths (B,). None means rows are uniform and no slicing is needed. self.vertex_counts = vertex_counts self.face_counts = face_counts + # Render flat / emissive (no scene lighting) when saved, e.g. for gaussian-splat-derived meshes. + self.unlit = unlit class File3D: diff --git a/comfy_api_nodes/apis/beeble.py b/comfy_api_nodes/apis/beeble.py new file mode 100644 index 000000000..90175b214 --- /dev/null +++ b/comfy_api_nodes/apis/beeble.py @@ -0,0 +1,32 @@ +from pydantic import BaseModel, Field + + +class CreateSwitchXRequest(BaseModel): + generation_type: str = Field(...) + source_uri: str = Field(...) + alpha_mode: str = Field(...) 
+ prompt: str | None = Field(None, max_length=2000) + reference_image_uri: str | None = Field(None) + alpha_uri: str | None = Field(None) + max_resolution: int = Field(1080) + callback_url: str | None = Field(None) + idempotency_key: str | None = Field(None, max_length=256, min_length=1) + + +class SwitchXOutputUrls(BaseModel): + render: str | None = Field(None) + source: str | None = Field(None) + alpha: str | None = Field(None) + + +class SwitchXStatusResponse(BaseModel): + id: str = Field(...) + status: str = Field(...) + progress: int | None = Field(None) + generation_type: str | None = Field(None) + alpha_mode: str | None = Field(None) + output: SwitchXOutputUrls | None = Field(None) + error: str | None = Field(None) + created_at: str | None = Field(None) + modified_at: str | None = Field(None) + completed_at: str | None = Field(None) diff --git a/comfy_api_nodes/apis/bfl.py b/comfy_api_nodes/apis/bfl.py index f0665fa09..4c950da84 100644 --- a/comfy_api_nodes/apis/bfl.py +++ b/comfy_api_nodes/apis/bfl.py @@ -1,71 +1,72 @@ from enum import Enum -from typing import Any, Dict, Optional +from typing import Any -from pydantic import BaseModel, Field, confloat, conint - - -class BFLOutputFormat(str, Enum): - png = 'png' - jpeg = 'jpeg' +from pydantic import BaseModel, Field class BFLFluxExpandImageRequest(BaseModel): - prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' 
- ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image') - bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image') - left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image') - right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand') + prompt: str = Field(...) + prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + top: int = Field(...) + bottom: int = Field(...) + left: int = Field(...) + right: int = Field(...) + steps: int = Field(...) + guidance: float = Field(...) + safety_tolerance: int = Field(6) + output_format: str = Field("png") + image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand") class BFLFluxFillImageRequest(BaseModel): - prompt: str = Field(..., description='The description of the changes you want to make. 
This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' + prompt: str = Field(...) + prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + steps: int = Field(...) + guidance: float = Field(...) + safety_tolerance: int = Field(6) + output_format: str = Field("png") + image: str = Field( + None, description="Base64-encoded string representing the image to modify. Can contain alpha mask if desired.", ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' + mask: str = Field( + None, description="Base64-encoded string representing the mask of the areas you wish to modify." ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] + + +class BFLFluxEraseRequest(BaseModel): + image: str = Field(..., description="A Base64-encoded string representing the image to erase from.") + mask: str = Field( + ..., + description="A Base64-encoded black/white mask matching the input dimensions; " + "white (255) marks areas to remove, black (0) marks areas to preserve.", ) - image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. 
Can contain alpha mask if desired.') - mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.') + dilate_pixels: int = Field(10) + seed: int | None = Field(None) + output_format: str = Field("png") + + +class BFLFluxVTORequest(BaseModel): + prompt: str = Field( + ..., description="Natural-language styling instruction. Required field, but may be an empty string." + ) + person: str = Field(..., description="A Base64-encoded string representing the person image.") + garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.") + seed: int | None = Field(None) + safety_tolerance: int = Field(5) + output_format: str = Field("png") class BFLFluxProGenerateRequest(BaseModel): - prompt: str = Field(..., description='The text prompt for image generation.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.') - height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. 
Can be 'jpeg' or 'png'.", examples=['png'] - ) - image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format') - # image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field( - # None, description='Blend between the prompt and the image prompt.' - # ) + prompt: str = Field(...) + prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + width: int = Field(1024, description="Must be a multiple of 32.") + height: int = Field(768, description="Must be a multiple of 32.") + safety_tolerance: int = Field(6) + output_format: str = Field("png") + image_prompt: str | None = Field(None, description="Optional image to remix in base64 format") class Flux2ProGenerateRequest(BaseModel): @@ -83,55 +84,37 @@ class Flux2ProGenerateRequest(BaseModel): input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation") input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation") - safety_tolerance: int | None = Field( - 5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5 - ) - output_format: str | None = Field( - "png", description="Output format for the generated image. Can be 'jpeg' or 'png'." 
- ) + safety_tolerance: int = Field(5) + output_format: str = Field("png") class BFLFluxKontextProGenerateRequest(BaseModel): - prompt: str = Field(..., description='The text prompt for what you wannt to edit.') - input_image: Optional[str] = Field(None, description='Image to edit in base64 format') - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process') - steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=2)] = Field( - 2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) + prompt: str = Field(...) + input_image: str | None = Field(None, description="Image to edit in base64 format") + seed: int | None = Field(None) + guidance: float = Field(...) + steps: int = Field(...) + safety_tolerance: int = Field(2) + output_format: str = Field("png") + aspect_ratio: str | None = Field(None) + prompt_upsampling: bool | None = Field(None) class BFLFluxProUltraGenerateRequest(BaseModel): - prompt: str = Field(..., description='The text prompt for image generation.') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' 
- ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.') - image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format') - image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field( - None, description='Blend between the prompt and the image prompt.' - ) + prompt: str = Field(...) + prompt_upsampling: bool | None = Field(None) + seed: int | None = Field(None) + aspect_ratio: str | None = Field(None) + safety_tolerance: int = Field(6) + output_format: str = Field("png") + raw: bool | None = Field(None) + image_prompt: str | None = Field(None, description="Optional image to remix in base64 format") + image_prompt_strength: float | None = Field(None) class BFLFluxProGenerateResponse(BaseModel): - id: str = Field(..., description="The unique identifier for the generation task.") - polling_url: str = Field(..., description="URL to poll for the generation result.") + id: str = Field(...) + polling_url: str = Field(...) 
cost: float | None = Field(None, description="Price in cents") @@ -145,7 +128,7 @@ class BFLStatus(str, Enum): class BFLFluxStatusResponse(BaseModel): - id: str = Field(..., description="The unique identifier for the generation task.") - status: BFLStatus = Field(..., description="The status of the task.") - result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).") - progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0) + id: str = Field(...) + status: BFLStatus = Field(...) + result: dict[str, Any] | None = Field(None) + progress: float | None = Field(None, ge=0.0, le=1.0) diff --git a/comfy_api_nodes/apis/bria.py b/comfy_api_nodes/apis/bria.py index e08a519a8..7a98428c3 100644 --- a/comfy_api_nodes/apis/bria.py +++ b/comfy_api_nodes/apis/bria.py @@ -97,3 +97,28 @@ class BriaRemoveVideoBackgroundResult(BaseModel): class BriaRemoveVideoBackgroundResponse(BaseModel): status: str = Field(...) result: BriaRemoveVideoBackgroundResult | None = Field(None) + + +class BriaVideoGreenScreenRequest(BaseModel): + video: str = Field(..., description="Publicly accessible URL of the input video.") + green_shade: str = Field( + default="broadcast_green", + description="Solid chroma-key shade applied behind the foreground " + "(broadcast_green, chroma_green, or blue_screen).", + ) + output_container_and_codec: str = Field(...) + preserve_audio: bool = Field(True) + seed: int = Field(...) + + +class BriaVideoReplaceBackgroundRequest(BaseModel): + video: str = Field(..., description="Publicly accessible URL of the input (foreground) video.") + background_url: str = Field( + ..., + description="Publicly accessible URL of the background image or video to composite behind " + "the foreground. Stretched to the foreground frame; match its aspect ratio for " + "undistorted results.", + ) + output_container_and_codec: str = Field(...) 
+ preserve_audio: bool = Field(True) + seed: int = Field(...) diff --git a/comfy_api_nodes/apis/bytedance.py b/comfy_api_nodes/apis/bytedance.py index 03f4c445b..47f24586c 100644 --- a/comfy_api_nodes/apis/bytedance.py +++ b/comfy_api_nodes/apis/bytedance.py @@ -158,8 +158,9 @@ class SeedanceCreateAssetResponse(BaseModel): class SeedanceVirtualLibraryCreateAssetRequest(BaseModel): - url: str = Field(..., description="Publicly accessible URL of the image asset to upload.") + url: str = Field(..., description="Publicly accessible URL of the asset to upload.") hash: str = Field(..., description="Dedup key. Re-submitting the same hash returns the existing asset id.") + asset_type: str | None = Field(None, description="BytePlus asset type. Defaults to Image server-side when omitted.") # Dollars per 1K tokens, keyed by (model_id, has_video_input). diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index 22879fe18..caaba8f36 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -108,13 +108,19 @@ class GeminiVideoMetadata(BaseModel): startOffset: GeminiOffset | None = Field(None) +class GeminiThinkingConfig(BaseModel): + includeThoughts: bool | None = Field(None) + thinkingLevel: str = Field(...) 
+ + class GeminiGenerationConfig(BaseModel): - maxOutputTokens: int | None = Field(None, ge=16, le=8192) + maxOutputTokens: int | None = Field(None, ge=16, le=65536) seed: int | None = Field(None) stopSequences: list[str] | None = Field(None) temperature: float | None = Field(None, ge=0.0, le=2.0) topK: int | None = Field(None, ge=1) topP: float | None = Field(None, ge=0.0, le=1.0) + thinkingConfig: GeminiThinkingConfig | None = Field(None) class GeminiImageOutputOptions(BaseModel): @@ -128,11 +134,6 @@ class GeminiImageConfig(BaseModel): imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions) -class GeminiThinkingConfig(BaseModel): - includeThoughts: bool | None = Field(None) - thinkingLevel: str = Field(...) - - class GeminiImageGenerationConfig(GeminiGenerationConfig): responseModalities: list[str] | None = Field(None) imageConfig: GeminiImageConfig | None = Field(None) diff --git a/comfy_api_nodes/apis/ideogram.py b/comfy_api_nodes/apis/ideogram.py index 737e18e3b..c5ad9559f 100644 --- a/comfy_api_nodes/apis/ideogram.py +++ b/comfy_api_nodes/apis/ideogram.py @@ -290,3 +290,19 @@ class IdeogramV3Request(BaseModel): None, description='Optional masks for character reference images. When provided, must match the number of character_reference_images. Each mask should be a grayscale image of the same dimensions as the corresponding character reference image. The images should be in JPEG, PNG or WebP format.' ) + + +class IdeogramV4Request(BaseModel): + text_prompt: str | None = Field( + None, + description="Natural-language prompt; Magic Prompt is applied automatically. " + "Supply exactly one of text_prompt or json_prompt.", + ) + json_prompt: dict[str, Any] | None = Field( + None, + description="Structured V4 prompt object consumed directly (disables Magic Prompt). " + "Supply exactly one of text_prompt or json_prompt.", + ) + resolution: str | None = Field(None, description="Output resolution in WIDTHxHEIGHT (e.g. 
'2048x2048').") + rendering_speed: str | None = Field(None, description="Rendering speed: 'TURBO', 'DEFAULT', or 'QUALITY'.") + enable_copyright_detection: bool | None = Field(None, description="Opt into post-generation copyright detection.") diff --git a/comfy_api_nodes/apis/krea.py b/comfy_api_nodes/apis/krea.py new file mode 100644 index 000000000..6e294a3b7 --- /dev/null +++ b/comfy_api_nodes/apis/krea.py @@ -0,0 +1,46 @@ +"""Pydantic models for the Krea image-generation API.""" + +from pydantic import BaseModel, Field + + +class KreaMoodboard(BaseModel): + id: str = Field(...) + strength: float = Field(default=0.35, ge=-0.5, le=1.5) + + +class KreaImageStyleReference(BaseModel): + strength: float = Field(..., ge=-2.0, le=2.0) + url: str | None = Field(default=None) + + +class KreaGenerateImageRequest(BaseModel): + prompt: str = Field(...) + aspect_ratio: str = Field(...) + resolution: str = Field(...) + seed: int | None = Field(default=None) + creativity: str = Field(default="medium") + moodboards: list[KreaMoodboard] | None = Field(default=None) + image_style_references: list[KreaImageStyleReference] | None = Field(default=None) + + +class KreaJobResult(BaseModel): + urls: list[str] | None = Field(default=None) + style_id: str | None = Field(default=None) + + +class KreaJob(BaseModel): + job_id: str = Field(...) + status: str = Field(...) + created_at: str = Field(...) + completed_at: str | None = Field(default=None) + result: KreaJobResult | None = Field(default=None) + + +class KreaAssetResponse(BaseModel): + id: str = Field(...) + image_url: str = Field(...) + uploaded_at: str = Field(...) 
+ width: float | None = Field(default=None) + height: float | None = Field(default=None) + size_bytes: float | None = Field(default=None) + mime_type: str | None = Field(default=None) diff --git a/comfy_api_nodes/apis/tripo.py b/comfy_api_nodes/apis/tripo.py index bce6b0e89..7ac81d42c 100644 --- a/comfy_api_nodes/apis/tripo.py +++ b/comfy_api_nodes/apis/tripo.py @@ -1,25 +1,25 @@ from enum import Enum -from typing import Optional, Any +from typing import Any from pydantic import BaseModel, Field, RootModel class TripoModelVersion(str, Enum): - v3_1_20260211 = 'v3.1-20260211' - v3_0_20250812 = 'v3.0-20250812' - v2_5_20250123 = 'v2.5-20250123' - v2_0_20240919 = 'v2.0-20240919' - v1_4_20240625 = 'v1.4-20240625' + v3_1_20260211 = "v3.1-20260211" + v3_0_20250812 = "v3.0-20250812" + v2_5_20250123 = "v2.5-20250123" + v2_0_20240919 = "v2.0-20240919" + v1_4_20240625 = "v1.4-20240625" class TripoGeometryQuality(str, Enum): - standard = 'standard' - detailed = 'detailed' + standard = "standard" + detailed = "detailed" class TripoTextureQuality(str, Enum): - standard = 'standard' - detailed = 'detailed' + standard = "standard" + detailed = "detailed" class TripoStyle(str, Enum): @@ -33,6 +33,7 @@ class TripoStyle(str, Enum): ANCIENT_BRONZE = "ancient_bronze" NONE = "None" + class TripoTaskType(str, Enum): TEXT_TO_MODEL = "text_to_model" IMAGE_TO_MODEL = "image_to_model" @@ -45,26 +46,27 @@ class TripoTaskType(str, Enum): STYLIZE_MODEL = "stylize_model" CONVERT_MODEL = "convert_model" + class TripoTextureAlignment(str, Enum): ORIGINAL_IMAGE = "original_image" GEOMETRY = "geometry" + class TripoOrientation(str, Enum): ALIGN_IMAGE = "align_image" DEFAULT = "default" + class TripoOutFormat(str, Enum): GLB = "glb" FBX = "fbx" -class TripoTopology(str, Enum): - BIP = "bip" - QUAD = "quad" class TripoSpec(str, Enum): MIXAMO = "mixamo" TRIPO = "tripo" + class TripoAnimation(str, Enum): IDLE = "preset:idle" WALK = "preset:walk" @@ -83,11 +85,6 @@ class TripoAnimation(str, Enum): 
SERPENTINE_MARCH = "preset:serpentine:march" AQUATIC_MARCH = "preset:aquatic:march" -class TripoStylizeStyle(str, Enum): - LEGO = "lego" - VOXEL = "voxel" - VORONOI = "voronoi" - MINECRAFT = "minecraft" class TripoConvertFormat(str, Enum): GLTF = "GLTF" @@ -97,6 +94,7 @@ class TripoConvertFormat(str, Enum): STL = "STL" _3MF = "3MF" + class TripoTextureFormat(str, Enum): BMP = "BMP" DPX = "DPX" @@ -108,6 +106,7 @@ class TripoTextureFormat(str, Enum): TIFF = "TIFF" WEBP = "WEBP" + class TripoTaskStatus(str, Enum): QUEUED = "queued" RUNNING = "running" @@ -118,183 +117,223 @@ class TripoTaskStatus(str, Enum): BANNED = "banned" EXPIRED = "expired" + class TripoFbxPreset(str, Enum): BLENDER = "blender" MIXAMO = "mixamo" _3DSMAX = "3dsmax" + class TripoFileTokenReference(BaseModel): - type: Optional[str] = Field(None, description='The type of the reference') + type: str | None = Field(None, description="The type of the reference") file_token: str + class TripoUrlReference(BaseModel): - type: Optional[str] = Field(None, description='The type of the reference') + type: str | None = Field(None, description="The type of the reference") url: str + class TripoObjectStorage(BaseModel): bucket: str key: str + class TripoObjectReference(BaseModel): type: str object: TripoObjectStorage + class TripoFileEmptyReference(BaseModel): pass + class TripoFileReference(RootModel): root: TripoFileTokenReference | TripoUrlReference | TripoObjectReference | TripoFileEmptyReference -class TripoGetStsTokenRequest(BaseModel): - format: str = Field(..., description='The format of the image') class TripoTextToModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task') - prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024) - negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024) - model_version: Optional[TripoModelVersion] = 
TripoModelVersion.v2_5_20250123 - face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') - image_seed: Optional[int] = Field(None, description='The seed for the text') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for the texture') - texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard - geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard - style: Optional[TripoStyle] = None - auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') - quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') + type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description="Type of task") + prompt: str = Field(..., description="The text prompt describing the model to generate", max_length=1024) + negative_prompt: str | None = Field(None, description="The negative text prompt", max_length=1024) + model_version: TripoModelVersion | None = TripoModelVersion.v2_5_20250123 + face_limit: int | None = Field(None, description="The number of faces to limit the generation to") + texture: bool | None = Field(True, description="Whether to apply texture to the generated model") + pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model") + image_seed: int | None = Field(None, description="The seed for the text") + model_seed: int | None = Field(None, description="The seed for the model") + texture_seed: int | None = Field(None, description="The seed for the texture") + texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard + geometry_quality: TripoGeometryQuality | None = 
TripoGeometryQuality.standard + style: TripoStyle | None = None + auto_size: bool | None = Field(False, description="Whether to auto-size the model") + quad: bool | None = Field(False, description="Whether to apply quad to the generated model") + class TripoImageToModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task') - file: TripoFileReference = Field(..., description='The file reference to convert to a model') - model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation') - face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for the texture') - texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard - geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard - texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') - style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model') - auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') - orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT - quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') + type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description="Type of task") + file: TripoFileReference = Field(..., description="The file reference to convert to a model") + model_version: TripoModelVersion | None = Field(None, description="The model version to use for 
generation") + face_limit: int | None = Field(None, description="The number of faces to limit the generation to") + texture: bool | None = Field(True, description="Whether to apply texture to the generated model") + pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model") + model_seed: int | None = Field(None, description="The seed for the model") + texture_seed: int | None = Field(None, description="The seed for the texture") + texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard + geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard + texture_alignment: TripoTextureAlignment | None = Field( + TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method" + ) + style: TripoStyle | None = Field(None, description="The style to apply to the generated model") + auto_size: bool | None = Field(False, description="Whether to auto-size the model") + orientation: TripoOrientation | None = TripoOrientation.DEFAULT + quad: bool | None = Field(False, description="Whether to apply quad to the generated model") + class TripoMultiviewToModelRequest(BaseModel): type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL - files: list[TripoFileReference] = Field(..., description='The file references to convert to a model') - model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation') - orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection') - face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for 
the texture') - texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard - geometry_quality: Optional[TripoGeometryQuality] = TripoGeometryQuality.standard - texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE - auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model') - orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model') - quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model') + files: list[TripoFileReference] = Field(..., description="The file references to convert to a model") + model_version: TripoModelVersion | None = Field(None, description="The model version to use for generation") + orthographic_projection: bool | None = Field(False, description="Whether to use orthographic projection") + face_limit: int | None = Field(None, description="The number of faces to limit the generation to") + texture: bool | None = Field(True, description="Whether to apply texture to the generated model") + pbr: bool | None = Field(True, description="Whether to apply PBR to the generated model") + model_seed: int | None = Field(None, description="The seed for the model") + texture_seed: int | None = Field(None, description="The seed for the texture") + texture_quality: TripoTextureQuality | None = TripoTextureQuality.standard + geometry_quality: TripoGeometryQuality | None = TripoGeometryQuality.standard + texture_alignment: TripoTextureAlignment | None = TripoTextureAlignment.ORIGINAL_IMAGE + auto_size: bool | None = Field(False, description="Whether to auto-size the model") + orientation: TripoOrientation | None = Field(TripoOrientation.DEFAULT, description="The orientation for the model") + quad: bool | None = Field(False, description="Whether to apply quad to the generated model") + class TripoTextureModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, 
description='Type of task') - original_model_task_id: str = Field(..., description='The task ID of the original model') - texture: Optional[bool] = Field(True, description='Whether to apply texture to the model') - pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model') - model_seed: Optional[int] = Field(None, description='The seed for the model') - texture_seed: Optional[int] = Field(None, description='The seed for the texture') - texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture') - texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method') + type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task") + original_model_task_id: str = Field(..., description="The task ID of the original model") + texture: bool | None = Field(True, description="Whether to apply texture to the model") + pbr: bool | None = Field(True, description="Whether to apply PBR to the model") + model_seed: int | None = Field(None, description="The seed for the model") + texture_seed: int | None = Field(None, description="The seed for the texture") + texture_quality: TripoTextureQuality | None = Field(None, description="The quality of the texture") + texture_alignment: TripoTextureAlignment | None = Field( + TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method" + ) + class TripoRefineModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task') - draft_model_task_id: str = Field(..., description='The task ID of the draft model') + type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description="Type of task") + draft_model_task_id: str = Field(..., description="The task ID of the draft model") -class TripoAnimatePrerigcheckRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task') - 
original_model_task_id: str = Field(..., description='The task ID of the original model') class TripoAnimateRigRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task') - original_model_task_id: str = Field(..., description='The task ID of the original model') - out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format') - spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging') + type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description="Type of task") + original_model_task_id: str = Field(..., description="The task ID of the original model") + out_format: TripoOutFormat | None = Field(TripoOutFormat.GLB, description="The output format") + spec: TripoSpec | None = Field(TripoSpec.TRIPO, description="The specification for rigging") + class TripoAnimateRetargetRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task') - original_model_task_id: str = Field(..., description='The task ID of the original model') - animation: TripoAnimation = Field(..., description='The animation to apply') - out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format') - bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation') + type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description="Type of task") + original_model_task_id: str = Field(..., description="The task ID of the original model") + animation: TripoAnimation = Field(..., description="The animation to apply") + out_format: TripoOutFormat | None = Field(TripoOutFormat.GLB, description="The output format") + bake_animation: bool | None = Field(True, description="Whether to bake the animation") -class TripoStylizeModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task') - style: TripoStylizeStyle = Field(..., 
description='The style to apply to the model') - original_model_task_id: str = Field(..., description='The task ID of the original model') - block_size: Optional[int] = Field(80, description='The block size for stylization') class TripoConvertModelRequest(BaseModel): - type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task') - format: TripoConvertFormat = Field(..., description='The format to convert to') - original_model_task_id: str = Field(..., description='The task ID of the original model') - quad: Optional[bool] = Field(None, description='Whether to apply quad to the model') - force_symmetry: Optional[bool] = Field(None, description='Whether to force symmetry') - face_limit: Optional[int] = Field(None, description='The number of faces to limit the conversion to') - flatten_bottom: Optional[bool] = Field(None, description='Whether to flatten the bottom of the model') - flatten_bottom_threshold: Optional[float] = Field(None, description='The threshold for flattening the bottom') - texture_size: Optional[int] = Field(None, description='The size of the texture') - texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture') - pivot_to_center_bottom: Optional[bool] = Field(None, description='Whether to pivot to the center bottom') - scale_factor: Optional[float] = Field(None, description='The scale factor for the model') - with_animation: Optional[bool] = Field(None, description='Whether to include animations') - pack_uv: Optional[bool] = Field(None, description='Whether to pack the UVs') - bake: Optional[bool] = Field(None, description='Whether to bake the model') - part_names: Optional[list[str]] = Field(None, description='The names of the parts to include') - fbx_preset: Optional[TripoFbxPreset] = Field(None, description='The preset for the FBX export') - export_vertex_colors: Optional[bool] = Field(None, description='Whether to export the vertex colors') - export_orientation: 
Optional[TripoOrientation] = Field(None, description='The orientation for the export') - animate_in_place: Optional[bool] = Field(None, description='Whether to animate in place') + type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description="Type of task") + format: TripoConvertFormat = Field(..., description="The format to convert to") + original_model_task_id: str = Field(..., description="The task ID of the original model") + quad: bool | None = Field(None, description="Whether to apply quad to the model") + force_symmetry: bool | None = Field(None, description="Whether to force symmetry") + face_limit: int | None = Field(None, description="The number of faces to limit the conversion to") + flatten_bottom: bool | None = Field(None, description="Whether to flatten the bottom of the model") + flatten_bottom_threshold: float | None = Field(None, description="The threshold for flattening the bottom") + texture_size: int | None = Field(None, description="The size of the texture") + texture_format: TripoTextureFormat | None = Field(TripoTextureFormat.JPEG, description="The format of the texture") + pivot_to_center_bottom: bool | None = Field(None, description="Whether to pivot to the center bottom") + scale_factor: float | None = Field(None, description="The scale factor for the model") + with_animation: bool | None = Field(None, description="Whether to include animations") + pack_uv: bool | None = Field(None, description="Whether to pack the UVs") + bake: bool | None = Field(None, description="Whether to bake the model") + part_names: list[str] | None = Field(None, description="The names of the parts to include") + fbx_preset: TripoFbxPreset | None = Field(None, description="The preset for the FBX export") + export_vertex_colors: bool | None = Field(None, description="Whether to export the vertex colors") + export_orientation: TripoOrientation | None = Field(None, description="The orientation for the export") + animate_in_place: bool | None = Field(None, 
description="Whether to animate in place") + + +class TripoP1CommonRequest(BaseModel): + """Fields supported by Tripo P1 across all input types.""" + + model_version: str = Field("P1-20260311") + model_seed: int | None = Field(None, description="Random seed for geometry generation") + face_limit: int | None = Field(None, ge=48, le=20000, description="Target face count (48-20000)") + texture: bool | None = Field(None, description="Enable texturing; pbr=True forces this true") + pbr: bool | None = Field(None, description="Enable PBR maps; when true, texture is also enabled") + texture_seed: int | None = Field(None, description="Random seed for texture generation") + texture_quality: str | None = Field(None, description='"standard" or "detailed"') + auto_size: bool | None = Field(None, description="Scale to real-world meters") + compress: str | None = Field(None, description='Only "geometry" is supported') + export_uv: bool | None = Field(None, description="Perform UV unwrapping during generation") + + +class TripoP1TextToModelRequest(TripoP1CommonRequest): + type: str = "text_to_model" + prompt: str = Field(..., max_length=1024) + negative_prompt: str | None = Field(None, max_length=255) + image_seed: int | None = None + + +class TripoP1ImageToModelRequest(TripoP1CommonRequest): + type: str = "image_to_model" + file: TripoFileReference + enable_image_autofix: bool | None = None + texture_alignment: str | None = Field(None, description='"original_image" or "geometry"') + orientation: str | None = Field(None, description='"default" or "align_image"; needs texture=true') + + +class TripoP1MultiviewToModelRequest(TripoP1CommonRequest): + """P1 multiview generation. + + Tripo requires `files` to be exactly four entries in [front, left, back, right] order with `{}` + (TripoFileEmptyReference) for omitted slots; front is required and at least two images total must be provided. 
+ """ + + type: str = "multiview_to_model" + files: list[TripoFileReference] + texture_alignment: str | None = None + orientation: str | None = None class TripoTaskOutput(BaseModel): - model: Optional[str] = Field(None, description='URL to the model') - base_model: Optional[str] = Field(None, description='URL to the base model') - pbr_model: Optional[str] = Field(None, description='URL to the PBR model') - rendered_image: Optional[str] = Field(None, description='URL to the rendered image') - riggable: Optional[bool] = Field(None, description='Whether the model is riggable') + model: str | None = Field(None, description="URL to the model") + base_model: str | None = Field(None, description="URL to the base model") + pbr_model: str | None = Field(None, description="URL to the PBR model") + rendered_image: str | None = Field(None, description="URL to the rendered image") + riggable: bool | None = Field(None, description="Whether the model is riggable") + class TripoTask(BaseModel): - task_id: str = Field(..., description='The task ID') - type: Optional[str] = Field(None, description='The type of task') - status: Optional[TripoTaskStatus] = Field(None, description='The status of the task') - input: Optional[dict[str, Any]] = Field(None, description='The input parameters for the task') - output: Optional[TripoTaskOutput] = Field(None, description='The output of the task') - progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100) - create_time: Optional[int] = Field(None, description='The creation time of the task') - running_left_time: Optional[int] = Field(None, description='The estimated time left for the task') - queue_position: Optional[int] = Field(None, description='The position in the queue') + task_id: str = Field(..., description="The task ID") + type: str | None = Field(None, description="The type of task") + status: TripoTaskStatus | None = Field(None, description="The status of the task") + input: dict[str, Any] | None = 
Field(None, description="The input parameters for the task") + output: TripoTaskOutput | None = Field(None, description="The output of the task") + progress: int | None = Field(None, description="The progress of the task", ge=0, le=100) + create_time: int | None = Field(None, description="The creation time of the task") + running_left_time: int | None = Field(None, description="The estimated time left for the task") + queue_position: int | None = Field(None, description="The position in the queue") consumed_credit: int | None = Field(None) + class TripoTaskResponse(BaseModel): - code: int = Field(0, description='The response code') - data: TripoTask = Field(..., description='The task data') + code: int = Field(0, description="The response code") + data: TripoTask = Field(..., description="The task data") -class TripoGeneralResponse(BaseModel): - code: int = Field(0, description='The response code') - data: dict[str, str] = Field(..., description='The task ID data') - -class TripoBalanceData(BaseModel): - balance: float = Field(..., description='The account balance') - frozen: float = Field(..., description='The frozen balance') - -class TripoBalanceResponse(BaseModel): - code: int = Field(0, description='The response code') - data: TripoBalanceData = Field(..., description='The balance data') class TripoErrorResponse(BaseModel): - code: int = Field(..., description='The error code') - message: str = Field(..., description='The error message') - suggestion: str = Field(..., description='The suggestion for fixing the error') + code: int = Field(..., description="The error code") + message: str = Field(..., description="The error message") + suggestion: str = Field(..., description="The suggestion for fixing the error") diff --git a/comfy_api_nodes/nodes_anthropic.py b/comfy_api_nodes/nodes_anthropic.py index 42ec5708f..87a870553 100644 --- a/comfy_api_nodes/nodes_anthropic.py +++ b/comfy_api_nodes/nodes_anthropic.py @@ -155,7 +155,7 @@ class ClaudeNode(IO.ComfyNode): 
return IO.Schema( node_id="ClaudeNode", display_name="Anthropic Claude", - category="api node/text/Anthropic", + category="partner/text/Anthropic", essentials_category="Text Generation", description="Generate text responses with Anthropic's Claude models. " "Provide a text prompt and optionally one or more images for multimodal context.", diff --git a/comfy_api_nodes/nodes_beeble.py b/comfy_api_nodes/nodes_beeble.py new file mode 100644 index 000000000..d863c2130 --- /dev/null +++ b/comfy_api_nodes/nodes_beeble.py @@ -0,0 +1,404 @@ +from fractions import Fraction + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl, Types +from comfy_api_nodes.apis.beeble import ( + CreateSwitchXRequest, + SwitchXStatusResponse, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + bytesio_to_image_tensor, + convert_mask_to_image, + download_url_as_bytesio, + download_url_to_image_tensor, + download_url_to_video_output, + downscale_image_tensor, + downscale_video_to_max_pixels, + poll_op, + sync_op, + upload_image_to_comfyapi, + upload_video_to_comfyapi, + validate_string, + validate_video_frame_count, +) + +_MAX_PIXELS = 2_770_000 +_MAX_FRAMES = 240 +_MAX_PROMPT_LEN = 2000 + + +def _validate_inputs(prompt: str | None, reference_image: Input.Image | None) -> str | None: + """Beeble requires at least one of prompt or reference_image. 
Returns the cleaned prompt.""" + cleaned = prompt.strip() if prompt else "" + if not cleaned and reference_image is None: + raise ValueError("At least one of 'prompt' or 'reference_image' must be provided.") + if cleaned: + validate_string(cleaned, strip_whitespace=False, max_length=_MAX_PROMPT_LEN) + return cleaned or None + + +async def _upload_mask_as_image( + cls: type[IO.ComfyNode], + mask: Input.Image, + *, + wait_label: str, +) -> str: + """Encode a single-frame MASK (H, W) or (1, H, W) as a PNG and upload.""" + if mask.dim() == 2: + mask = mask.unsqueeze(0) + image = convert_mask_to_image(mask[:1]) + return await upload_image_to_comfyapi( + cls, + image, + mime_type="image/png", + wait_label=wait_label, + total_pixels=_MAX_PIXELS, + ) + + +async def _upload_mask_batch_as_video( + cls: type[IO.ComfyNode], + mask: Input.Image, + *, + frame_rate: Fraction, + source_frame_count: int, + wait_label: str, +) -> str: + """Encode a MASK batch (N, H, W) as a grayscale H.264 MP4 at frame_rate and upload. + + The matte is always downscaled to the pixel budget so it stays within Beeble's limit and + keeps the same dimensions as the (similarly downscaled) source — both use the same algorithm + from the same starting dimensions, and downscaling is a no-op when already within budget. + """ + if mask.dim() == 2: + mask = mask.unsqueeze(0) + if mask.shape[0] != source_frame_count: + raise ValueError( + f"Custom alpha video frame count ({mask.shape[0]}) does not match the " + f"source video frame count ({source_frame_count}). The Beeble API requires " + "one mask per source frame." 
+ ) + images = downscale_image_tensor(convert_mask_to_image(mask), _MAX_PIXELS) + alpha_video = InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=None, frame_rate=frame_rate)) + return await upload_video_to_comfyapi(cls, alpha_video, wait_label=wait_label) + + +def _alpha_mode_input(*, video: bool) -> IO.DynamicCombo.Input: + """Build the alpha_mode DynamicCombo with mode-specific extra inputs.""" + select_keyframe_tooltip = ( + "First-frame keyframe mask. Beeble propagates this across the video." if video else "Grayscale keyframe mask." + ) + custom_tooltip = ( + "Per-frame grayscale mask covering the entire video. " + "Must have the same frame count as the source. " + "Connect a MASK output from SAM3_TrackToMask or similar." + if video + else "Grayscale mask to apply." + ) + return IO.DynamicCombo.Input( + "alpha_mode", + tooltip=( + "Controls how SwitchX decides what to keep vs. regenerate. " + "'auto' isolates the main subject automatically. " + "'fill' regenerates the entire frame while preserving geometry. " + "'select' propagates a first-frame keyframe across the clip. " + "'custom' uses a per-frame alpha matte you provide." + ), + options=[ + IO.DynamicCombo.Option("auto", []), + IO.DynamicCombo.Option("fill", []), + IO.DynamicCombo.Option( + "select", + [IO.Mask.Input("alpha_keyframe", tooltip=select_keyframe_tooltip)], + ), + IO.DynamicCombo.Option( + "custom", + [IO.Mask.Input("alpha_mask", tooltip=custom_tooltip)], + ), + ], + ) + + +def _common_inputs(*, source: IO.Input, video: bool) -> list[IO.Input]: + return [ + source, + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip=( + "Text description of the desired output (max 2000 chars). " + "At least one of 'prompt' or 'reference_image' is required." + ), + ), + IO.Image.Input( + "reference_image", + optional=True, + tooltip=( + "Reference image whose look (background, lighting, costume) the result " + "should adopt. 
At least one of 'reference_image' or 'prompt' is required." + ), + ), + _alpha_mode_input(video=video), + IO.Combo.Input( + "max_resolution", + options=["1080p", "720p"], + default="1080p", + tooltip="Maximum output resolution.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip=( + "Seed controls whether the node should re-run; " "results are non-deterministic regardless of seed." + ), + ), + ] + + +async def _submit_and_poll( + cls: type[IO.ComfyNode], + request: CreateSwitchXRequest, +) -> SwitchXStatusResponse: + initial = await sync_op( + cls, + ApiEndpoint(path="/proxy/beeble/v1/switchx/generations", method="POST"), + response_model=SwitchXStatusResponse, + data=request, + ) + return await poll_op( + cls, + ApiEndpoint(path=f"/proxy/beeble/v1/switchx/generations/{initial.id}"), + response_model=SwitchXStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + ) + + +def _require_output_url(response: SwitchXStatusResponse, name: str) -> str: + if response.output is None or getattr(response.output, name) is None: + raise RuntimeError(f"Beeble job {response.id} completed without a {name!r} output URL.") + return getattr(response.output, name) + + +def _alpha_url(response: SwitchXStatusResponse, mode: str) -> str | None: + """URL of the alpha matte, or None when the mode produces no separate matte. + + 'fill' selects the whole frame, so Beeble writes no alpha asset even though the status + response still returns a (dangling) signed URL for it — fetching it 403s with S3 + AccessDenied. The other three modes ('auto', 'custom', 'select') all produce a real, + downloadable matte. 
+ """ + if mode == "fill" or response.output is None: + return None + return response.output.alpha + + +class BeebleSwitchXVideoEdit(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="BeebleSwitchXVideoEdit", + display_name="Beeble SwitchX Video Edit", + category="partner/video/Beeble", + description=( + "Edit a video with Beeble SwitchX. Switches anything in the scene (background, " + "lighting, costume) while preserving the original subject's pixels and motion. " + "Provide a reference image and/or text prompt to describe the new look. " + "Max 240 frames, max ~2.77MP per frame." + ), + inputs=_common_inputs(source=IO.Video.Input("video"), video=True), + outputs=[ + IO.Video.Output(display_name="video"), + IO.Video.Output( + display_name="alpha", + tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.", + ), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]), + expr=""" + ( + $rate := widgets.max_resolution = "1080p" ? 
0.429 : 0.143; + {"type":"usd","usd": $rate, "format":{"suffix":"/30 frames"}} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + prompt: str, + alpha_mode: dict, + max_resolution: str, + seed: int, + reference_image: Input.Image | None = None, + ) -> IO.NodeOutput: + cleaned_prompt = _validate_inputs(prompt, reference_image) + + validate_video_frame_count(video, max_frame_count=_MAX_FRAMES) + video = downscale_video_to_max_pixels(video, _MAX_PIXELS) + + mode = alpha_mode["alpha_mode"] + alpha_uri: str | None = None + if mode == "select": + alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe") + elif mode == "custom": + alpha_uri = await _upload_mask_batch_as_video( + cls, + alpha_mode["alpha_mask"], + frame_rate=video.get_frame_rate(), + source_frame_count=video.get_frame_count(), + wait_label="Uploading alpha video", + ) + + source_uri = await upload_video_to_comfyapi(cls, video, wait_label="Uploading source") + reference_uri: str | None = None + if reference_image is not None: + reference_uri = await upload_image_to_comfyapi( + cls, + reference_image, + mime_type="image/png", + wait_label="Uploading reference", + total_pixels=_MAX_PIXELS, + ) + + request = CreateSwitchXRequest( + generation_type="video", + source_uri=source_uri, + alpha_mode=mode, + prompt=cleaned_prompt, + reference_image_uri=reference_uri, + alpha_uri=alpha_uri, + max_resolution=1080 if max_resolution == "1080p" else 720, + ) + response = await _submit_and_poll(cls, request) + + render = await download_url_to_video_output(_require_output_url(response, "render")) + alpha = None + if (alpha_url := _alpha_url(response, mode)) is not None: + alpha = await download_url_to_video_output(alpha_url) + return IO.NodeOutput(render, alpha) + + +class BeebleSwitchXImageEdit(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="BeebleSwitchXImageEdit", + 
display_name="Beeble SwitchX Image Edit", + category="partner/image/Beeble", + description=( + "Edit a single image with Beeble SwitchX. Switches anything in the scene " + "(background, lighting, costume) while preserving the original subject's pixels. " + "Provide a reference image and/or text prompt to describe the new look. " + "Max ~2.77MP." + ), + inputs=_common_inputs(source=IO.Image.Input("image"), video=False), + outputs=[ + IO.Image.Output(display_name="image"), + IO.Mask.Output( + display_name="alpha", + tooltip="The alpha matte Beeble used. Empty for 'fill' mode, which has no separate matte.", + ), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["max_resolution"]), + expr=""" + ( + $rate := widgets.max_resolution = "1080p" ? 0.429 : 0.143; + {"type":"usd","usd": $rate} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + prompt: str, + alpha_mode: dict, + max_resolution: str, + seed: int, + reference_image: Input.Image | None = None, + ) -> IO.NodeOutput: + cleaned_prompt = _validate_inputs(prompt, reference_image) + + image = downscale_image_tensor(image, _MAX_PIXELS) + + mode = alpha_mode["alpha_mode"] + alpha_uri: str | None = None + if mode == "select": + alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_keyframe"], wait_label="Uploading keyframe") + elif mode == "custom": + alpha_uri = await _upload_mask_as_image(cls, alpha_mode["alpha_mask"], wait_label="Uploading alpha") + + source_uri = await upload_image_to_comfyapi( + cls, + image, + mime_type="image/png", + wait_label="Uploading source", + total_pixels=None, + ) + reference_uri: str | None = None + if reference_image is not None: + reference_uri = await upload_image_to_comfyapi( + cls, + reference_image, + mime_type="image/png", + wait_label="Uploading reference", + total_pixels=_MAX_PIXELS, + ) 
+ + request = CreateSwitchXRequest( + generation_type="image", + source_uri=source_uri, + alpha_mode=mode, + prompt=cleaned_prompt, + reference_image_uri=reference_uri, + alpha_uri=alpha_uri, + max_resolution=1080 if max_resolution == "1080p" else 720, + ) + response = await _submit_and_poll(cls, request) + + render = await download_url_to_image_tensor(_require_output_url(response, "render")) + alpha_mask = None + if (alpha_url := _alpha_url(response, mode)) is not None: + alpha_image = bytesio_to_image_tensor(await download_url_as_bytesio(alpha_url), mode="L") + alpha_mask = alpha_image.squeeze(-1) if alpha_image.dim() == 4 else alpha_image + return IO.NodeOutput(render, alpha_mask) + + +class BeebleExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + BeebleSwitchXVideoEdit, + BeebleSwitchXImageEdit, + ] + + +async def comfy_entrypoint() -> BeebleExtension: + return BeebleExtension() diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 3f0ce29d8..259c54ef9 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -4,17 +4,20 @@ from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bfl import ( + BFLFluxEraseRequest, BFLFluxExpandImageRequest, BFLFluxFillImageRequest, BFLFluxKontextProGenerateRequest, BFLFluxProGenerateResponse, BFLFluxProUltraGenerateRequest, BFLFluxStatusResponse, + BFLFluxVTORequest, BFLStatus, Flux2ProGenerateRequest, ) from comfy_api_nodes.util import ( ApiEndpoint, + convert_mask_to_image, download_url_to_image_tensor, get_number_of_images, poll_op, @@ -22,19 +25,11 @@ from comfy_api_nodes.util import ( sync_op, tensor_to_base64_string, validate_aspect_ratio_string, + validate_image_dimensions, validate_string, ) -def convert_mask_to_image(mask: Input.Image): - """ - Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. 
- """ - mask = mask.unsqueeze(-1) - mask = torch.cat([mask] * 3, dim=-1) - return mask - - class FluxProUltraImageNode(IO.ComfyNode): @classmethod @@ -42,7 +37,7 @@ class FluxProUltraImageNode(IO.ComfyNode): return IO.Schema( node_id="FluxProUltraImageNode", display_name="Flux 1.1 [pro] Ultra Image", - category="api node/image/BFL", + category="partner/image/BFL", description="Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.", inputs=[ IO.String.Input( @@ -160,7 +155,7 @@ class FluxKontextProImageNode(IO.ComfyNode): return IO.Schema( node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, - category="api node/image/BFL", + category="partner/image/BFL", description="Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio.", inputs=[ IO.String.Input( @@ -282,7 +277,7 @@ class FluxProExpandNode(IO.ComfyNode): return IO.Schema( node_id="FluxProExpandNode", display_name="Flux.1 Expand Image", - category="api node/image/BFL", + category="partner/image/BFL", description="Outpaints image based on prompt.", inputs=[ IO.Image.Input("image"), @@ -419,7 +414,7 @@ class FluxProFillNode(IO.ComfyNode): return IO.Schema( node_id="FluxProFillNode", display_name="Flux.1 Fill Image", - category="api node/image/BFL", + category="partner/image/BFL", description="Inpaints image based on mask and prompt.", inputs=[ IO.Image.Input("image"), @@ -519,6 +514,174 @@ class FluxProFillNode(IO.ComfyNode): return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) +class FluxEraseNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="FluxEraseNode", + display_name="Flux Erase Image", + category="partner/image/BFL", + description="Removes the masked object from an image and reconstructs the background. 
" + "Paint the mask over what you want to erase.", + inputs=[ + IO.Image.Input("image"), + IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."), + IO.Int.Input( + "dilate_pixels", + default=10, + min=0, + max=25, + tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + optional=True, + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""", + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + mask: Input.Image, + dilate_pixels: int = 10, + seed: int = 0, + ) -> IO.NodeOutput: + validate_image_dimensions(image, min_width=256, min_height=256) + mask = resize_mask_to_image(mask, image) + mask = tensor_to_base64_string(convert_mask_to_image(mask)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxEraseRequest( + image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed + mask=mask, + dilate_pixels=dilate_pixels, + seed=seed, + ), + ) + + def price_extractor(_r: BaseModel) -> float | None: + return None if initial_response.cost is None else initial_response.cost / 100 + + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + price_extractor=price_extractor, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + 
BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + + +class FluxVTONode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="FluxVTONode", + display_name="Flux Virtual Try-On", + category="partner/image/BFL", + description="Virtual try-on: dresses the person in the provided garment.", + inputs=[ + IO.Image.Input("person", tooltip="Image of the person to dress."), + IO.Image.Input("garment", tooltip="Image of the garment to apply."), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="The random seed used for creating the noise.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""", + ), + ) + + @classmethod + async def execute( + cls, + person: Input.Image, + garment: Input.Image, + prompt: str = "", + seed: int = 0, + ) -> IO.NodeOutput: + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxVTORequest( + prompt=prompt, + person=tensor_to_base64_string(person[:, :, :, :3]), + garment=tensor_to_base64_string(garment[:, :, :, :3]), + seed=seed, + ), + ) + + def price_extractor(_r: BaseModel) -> float | None: + return None if initial_response.cost is None else initial_response.cost / 100 + + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: 
r.status, + progress_extractor=lambda r: r.progress, + price_extractor=price_extractor, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) + + class Flux2ProImageNode(IO.ComfyNode): NODE_ID = "Flux2ProImageNode" @@ -545,7 +708,7 @@ class Flux2ProImageNode(IO.ComfyNode): return IO.Schema( node_id=cls.NODE_ID, display_name=cls.DISPLAY_NAME, - category="api node/image/BFL", + category="partner/image/BFL", description="Generates images synchronously based on prompt and resolution.", inputs=[ IO.String.Input( @@ -716,7 +879,7 @@ class Flux2ImageNode(IO.ComfyNode): return IO.Schema( node_id="Flux2ImageNode", display_name="Flux.2 Image", - category="api node/image/BFL", + category="partner/image/BFL", description="Generate images via Flux.2 [pro] or Flux.2 [max] from a prompt and optional reference images.", inputs=[ IO.String.Input( @@ -853,6 +1016,8 @@ class BFLExtension(ComfyExtension): FluxKontextMaxImageNode, FluxProExpandNode, FluxProFillNode, + FluxEraseNode, + FluxVTONode, Flux2ProImageNode, Flux2MaxImageNode, Flux2ImageNode, diff --git a/comfy_api_nodes/nodes_bria.py b/comfy_api_nodes/nodes_bria.py index 4044ee3ea..e138fafa9 100644 --- a/comfy_api_nodes/nodes_bria.py +++ b/comfy_api_nodes/nodes_bria.py @@ -1,14 +1,19 @@ +import av +import torch +from av.codec import CodecContext from typing_extensions import override from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis.bria import ( BriaEditImageRequest, + BriaImageEditResponse, BriaRemoveBackgroundRequest, BriaRemoveBackgroundResponse, BriaRemoveVideoBackgroundRequest, BriaRemoveVideoBackgroundResponse, - BriaImageEditResponse, BriaStatusResponse, + BriaVideoGreenScreenRequest, + BriaVideoReplaceBackgroundRequest, InputModerationSettings, ) from 
comfy_api_nodes.util import ( @@ -31,7 +36,7 @@ class BriaImageEditNode(IO.ComfyNode): return IO.Schema( node_id="BriaImageEditNode", display_name="Bria FIBO Image Edit", - category="api node/image/Bria", + category="partner/image/Bria", description="Edit images using Bria latest model", inputs=[ IO.Combo.Input("model", options=["FIBO"]), @@ -169,7 +174,7 @@ class BriaRemoveImageBackground(IO.ComfyNode): return IO.Schema( node_id="BriaRemoveImageBackground", display_name="Bria Remove Image Background", - category="api node/image/Bria", + category="partner/image/Bria", description="Remove the background from an image using Bria RMBG 2.0.", inputs=[ IO.Image.Input("image"), @@ -245,7 +250,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode): return IO.Schema( node_id="BriaRemoveVideoBackground", display_name="Bria Remove Video Background", - category="api node/video/Bria", + category="partner/video/Bria", description="Remove the background from a video using Bria. ", inputs=[ IO.Video.Input("video"), @@ -316,6 +321,248 @@ class BriaRemoveVideoBackground(IO.ComfyNode): return IO.NodeOutput(await download_url_to_video_output(response.result.video_url)) +class BriaVideoGreenScreen(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaVideoGreenScreen", + display_name="Bria Video Green Screen", + category="partner/video/Bria", + description="Replace a video's background with a solid chroma-key screen using Bria.", + inputs=[ + IO.Video.Input("video"), + IO.Combo.Input( + "green_shade", + options=["broadcast_green", "chroma_green", "blue_screen"], + tooltip="Solid chroma-key shade applied behind the foreground: " + "broadcast_green (#00B140), chroma_green (#00FF00), or blue_screen (#0000FF).", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless 
of seed.", + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + green_shade: str, + seed: int, + ) -> IO.NodeOutput: + validate_video_duration(video, max_duration=60.0) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bria/v2/video/edit/green_screen", method="POST"), + data=BriaVideoGreenScreenRequest( + video=await upload_video_to_comfyapi(cls, video), + green_shade=green_shade, + output_container_and_codec="mp4_h264", + seed=seed, + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaRemoveVideoBackgroundResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.result.video_url)) + + +class BriaVideoReplaceBackground(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaVideoReplaceBackground", + display_name="Bria Video Replace Background", + category="partner/video/Bria", + description="Replace a video's background with a supplied image or video using Bria. " + "The output keeps the foreground's resolution and frame rate; a background with a " + "different aspect ratio is stretched to fit, so match it for undistorted results.", + inputs=[ + IO.Video.Input("video", tooltip="Foreground video whose background is replaced."), + IO.Image.Input( + "background_image", + optional=True, + tooltip="Background image to composite behind the foreground. " + "Provide either a background image or a background video, not both.", + ), + IO.Video.Input( + "background_video", + optional=True, + tooltip="Background video to composite behind the foreground. 
" + "Provide either a background image or a background video, not both.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[IO.Video.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + seed: int, + background_image: Input.Image | None = None, + background_video: Input.Video | None = None, + ) -> IO.NodeOutput: + if (background_image is None) == (background_video is None): + raise ValueError("Provide either a background image or a background video, not both.") + validate_video_duration(video, max_duration=60.0) + if background_video is not None: + validate_video_duration(background_video, max_duration=60.0) + background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background") + else: + background_url = await upload_image_to_comfyapi(cls, background_image, wait_label="Uploading background") + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"), + data=BriaVideoReplaceBackgroundRequest( + video=await upload_video_to_comfyapi(cls, video), + background_url=background_url, + output_container_and_codec="mp4_h264", + seed=seed, + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaRemoveVideoBackgroundResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.result.video_url)) + + +def _video_to_images_and_mask(video: 
Input.Video) -> tuple[Input.Image, Input.Mask]: + """Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask. + + VP9 keeps its alpha in a side layer that PyAV's default vp9 decoder drops, so the frames + are decoded with libvpx-vp9. Returns RGB images [B,H,W,3] in 0..1 and a mask [B,H,W] + following the Load Image convention (1 = transparent) for compositing or Save WEBM. + """ + rgb_frames: list[torch.Tensor] = [] + alpha_frames: list[torch.Tensor] = [] + with av.open(video.get_stream_source(), mode="r") as container: + stream = container.streams.video[0] + decoder = CodecContext.create("libvpx-vp9", "r") if stream.codec_context.name == "vp9" else None + for packet in container.demux(stream): + for frame in (decoder.decode(packet) if decoder is not None else packet.decode()): + rgba = torch.from_numpy(frame.to_ndarray(format="rgba")).float() / 255.0 + rgb_frames.append(rgba[..., :3]) + alpha_frames.append(rgba[..., 3]) + images = torch.stack(rgb_frames) if rgb_frames else torch.zeros(0, 0, 0, 3) + mask = (1.0 - torch.stack(alpha_frames)) if alpha_frames else torch.zeros((images.shape[0], 64, 64)) + return images, mask + + +class BriaTransparentVideoBackground(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="BriaTransparentVideoBackground", + display_name="Bria Remove Video Background (Transparent)", + category="partner/video/Bria", + description="Remove the background from a video using Bria and return the cut-out frames " + "plus an alpha mask. 
Connect both to a compositing node, or feed them to Save WEBM to " + "write a transparent video.", + inputs=[ + IO.Video.Input("video"), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed controls whether the node should re-run; " + "results are non-deterministic regardless of seed.", + ), + ], + outputs=[ + IO.Image.Output(display_name="images"), + IO.Mask.Output(display_name="mask"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""", + ), + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + seed: int, + ) -> IO.NodeOutput: + validate_video_duration(video, max_duration=60.0) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"), + data=BriaRemoveVideoBackgroundRequest( + video=await upload_video_to_comfyapi(cls, video), + background_color="Transparent", + output_container_and_codec="webm_vp9", + seed=seed, + ), + response_model=BriaStatusResponse, + ) + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"), + status_extractor=lambda r: r.status, + response_model=BriaRemoveVideoBackgroundResponse, + ) + video_out = await download_url_to_video_output(response.result.video_url) + images, mask = _video_to_images_and_mask(video_out) + return IO.NodeOutput(images, mask) + + class BriaExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -323,6 +570,9 @@ class BriaExtension(ComfyExtension): BriaImageEditNode, BriaRemoveImageBackground, BriaRemoveVideoBackground, + BriaVideoGreenScreen, + # BriaVideoReplaceBackground, # server returns Status 500 when we pass background video + BriaTransparentVideoBackground, ] diff --git 
a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index e08fc0b01..c30ddc446 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -2,11 +2,13 @@ import hashlib import logging import math import re +from io import BytesIO import torch from typing_extensions import override -from comfy_api.latest import IO, ComfyExtension, Input +from comfy.utils import common_upscale +from comfy_api.latest import IO, ComfyExtension, Input, Types from comfy_api_nodes.apis.bytedance import ( RECOMMENDED_PRESETS, RECOMMENDED_PRESETS_SEEDREAM_4, @@ -43,6 +45,7 @@ from comfy_api_nodes.util import ( ApiEndpoint, download_url_to_image_tensor, download_url_to_video_output, + downscale_image_tensor_by_max_side, downscale_video_to_max_pixels, get_number_of_images, image_tensor_pair_to_batch, @@ -121,6 +124,52 @@ def _validate_ref_video_pixels(video: Input.Video, model_id: str, resolution: st ) +def _prepare_seedance_image(image: Input.Image) -> Input.Image: + """Auto-downscale a Seedance image input to the per-side limits, then validate it.""" + validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + image = downscale_image_tensor_by_max_side(image, max_side=6000) + validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) + return image + + +# Supported output aspect ratios, used to pre-size FLF frames to matching pixel pair to avoid the 1080p stretch jump. +SEEDANCE2_RATIO_WH = { + "16:9": (16, 9), + "4:3": (4, 3), + "1:1": (1, 1), + "3:4": (3, 4), + "9:16": (9, 16), + "21:9": (21, 9), +} +SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080} + + +def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]: + """Exact supported output (width, height) for (resolution, ratio). + + The shorter side equals the resolution number (e.g. 1080p 16:9 -> 1920x1080). 
For ratio + "adaptive" (or any unexpected value) the ratio is derived from the image's own aspect, snapped + to the nearest supported ratio, so the output keeps the frame's orientation. + """ + short = SEEDANCE2_RES_SHORT_SIDE[resolution] + if ratio not in SEEDANCE2_RATIO_WH: + aspect = image.shape[-2] / image.shape[-3] # W / H; tensor is (B, H, W, C) + ratio = min(SEEDANCE2_RATIO_WH, key=lambda k: abs(SEEDANCE2_RATIO_WH[k][0] / SEEDANCE2_RATIO_WH[k][1] - aspect)) + rw, rh = SEEDANCE2_RATIO_WH[ratio] + if rw >= rh: # landscape or square: shorter side is the height + out_w, out_h = round(short * rw / rh), short + else: # portrait: shorter side is the width + out_w, out_h = short, round(short * rh / rw) + return out_w - out_w % 2, out_h - out_h % 2 + + +def _resize_to_exact(image: torch.Tensor, width: int, height: int) -> torch.Tensor: + """Center-crop to the target aspect and resize to exactly width x height (lanczos).""" + samples = image.movedim(-1, 1) # (B, H, W, C) -> (B, C, H, W) + resized = common_upscale(samples, width, height, "lanczos", "center") + return resized.movedim(1, -1) + + async def _resolve_reference_assets( cls: type[IO.ComfyNode], asset_ids: list[str], @@ -308,6 +357,26 @@ async def _seedance_virtual_library_upload_image_asset( return f"asset://{create_resp.asset_id}" +async def _seedance_virtual_library_upload_video_asset( + cls: type[IO.ComfyNode], + video: Input.Video, + *, + wait_label: str = "Uploading video", +) -> str: + buf = BytesIO() + video.save_to(buf, format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264) + video_hash = hashlib.sha256(buf.getbuffer()).hexdigest() + public_url = await upload_video_to_comfyapi(cls, video, wait_label=wait_label) + create_resp = await sync_op( + cls, + ApiEndpoint(path="/proxy/seedance/virtual-library/assets", method="POST"), + response_model=SeedanceCreateAssetResponse, + data=SeedanceVirtualLibraryCreateAssetRequest(url=public_url, hash=video_hash, asset_type="Video"), + ) + await 
_wait_for_asset_active(cls, create_resp.asset_id, group_id="virtual-library") + return f"asset://{create_resp.asset_id}" + + def _seedance2_price_extractor(model_id: str, has_video_input: bool): """Returns a price_extractor closure for Seedance 2.0 poll_op.""" rate = SEEDANCE2_PRICE_PER_1K_TOKENS.get((model_id, has_video_input)) @@ -338,7 +407,7 @@ class ByteDanceImageNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceImageNode", display_name="ByteDance Image", - category="api node/image/ByteDance", + category="partner/image/ByteDance", description="Generate images using ByteDance models via api based on prompt", inputs=[ IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]), @@ -462,7 +531,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedreamNode", display_name="ByteDance Seedream 4.5 & 5.0", - category="api node/image/ByteDance", + category="partner/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ IO.Combo.Input( @@ -724,7 +793,7 @@ class ByteDanceSeedreamNodeV2(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedreamNodeV2", display_name="ByteDance Seedream 4.5 & 5.0", - category="api node/image/ByteDance", + category="partner/image/ByteDance", description="Unified text-to-image generation and precise single-sentence editing at up to 4K resolution.", inputs=[ IO.String.Input( @@ -890,7 +959,7 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceTextToVideoNode", display_name="ByteDance Text to Video", - category="api node/video/ByteDance", + category="partner/video/ByteDance", description="Generate video using ByteDance models via api based on prompt", inputs=[ IO.Combo.Input( @@ -1018,7 +1087,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceImageToVideoNode", display_name="ByteDance Image to Video", - category="api node/video/ByteDance", + 
category="partner/video/ByteDance", description="Generate video using ByteDance models via api based on image and prompt", inputs=[ IO.Combo.Input( @@ -1155,7 +1224,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceFirstLastFrameNode", display_name="ByteDance First-Last-Frame to Video", - category="api node/video/ByteDance", + category="partner/video/ByteDance", description="Generate video using prompt and first and last frames.", inputs=[ IO.Combo.Input( @@ -1303,7 +1372,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceImageReferenceNode", display_name="ByteDance Reference Images to Video", - category="api node/video/ByteDance", + category="partner/video/ByteDance", description="Generate video using prompt and reference images.", inputs=[ IO.Combo.Input( @@ -1546,7 +1615,7 @@ class ByteDance2TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2TextToVideoNode", display_name="ByteDance Seedance 2.0 Text to Video", - category="api node/video/ByteDance", + category="partner/video/ByteDance", description="Generate video using Seedance 2.0 models based on a text prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1647,7 +1716,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2FirstLastFrameNode", display_name="ByteDance Seedance 2.0 First-Last-Frame to Video", - category="api node/video/ByteDance", + category="partner/video/ByteDance", description="Generate video using Seedance 2.0 from a first frame image and optional last frame image.", inputs=[ IO.DynamicCombo.Input( @@ -1760,6 +1829,29 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): if last_frame is not None and last_frame_asset_id: raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.") + request_ratio = model["ratio"] + if first_frame_asset_id or last_frame_asset_id: + if first_frame is not None: + first_frame = _prepare_seedance_image(first_frame) + 
if last_frame is not None: + last_frame = _prepare_seedance_image(last_frame) + else: + # The 1080p FLF stretch fix (pre-size frames to a supported pixel pair + submit ratio="adaptive") + # only applies to local image inputs we can resize. + request_ratio = "adaptive" + target_dims: tuple[int, int] | None = None + if first_frame is not None: + validate_image_aspect_ratio(first_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + validate_image_dimensions(first_frame, min_width=300, min_height=300) + target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], first_frame) + first_frame = _resize_to_exact(first_frame, *target_dims) + if last_frame is not None: + validate_image_aspect_ratio(last_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + validate_image_dimensions(last_frame, min_width=300, min_height=300) + if target_dims is None: + target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], last_frame) + last_frame = _resize_to_exact(last_frame, *target_dims) + asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a] image_assets: dict[str, str] = {} if asset_ids_to_resolve: @@ -1809,7 +1901,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode): content=content, generate_audio=model["generate_audio"], resolution=model["resolution"], - ratio=model["ratio"], + ratio=request_ratio, duration=model["duration"], seed=seed, watermark=watermark, @@ -1866,7 +1958,7 @@ def _seedance2_reference_inputs(resolutions: list[str], default_ratio: str = "16 ), IO.Boolean.Input( "auto_downscale", - default=False, + default=True, optional=True, tooltip="Automatically downscale reference videos that exceed the model's pixel budget " "for the selected resolution. 
Aspect ratio is preserved; videos already within limits are untouched.", @@ -1909,7 +2001,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode): return IO.Schema( node_id="ByteDance2ReferenceNode", display_name="ByteDance Seedance 2.0 Reference to Video", - category="api node/video/ByteDance", + category="partner/video/ByteDance", description="Generate, edit, or extend video using Seedance 2.0 with reference images, " "videos, and audio. Supports multimodal reference, video editing, and video extension.", inputs=[ @@ -2034,6 +2126,9 @@ class ByteDance2ReferenceNode(IO.ComfyNode): f"(audios={len(reference_audios)}, audio assets={len(reference_audio_assets)}). Maximum is 3." ) + for key in reference_images: + reference_images[key] = _prepare_seedance_image(reference_images[key]) + model_id = SEEDANCE_MODELS[model["model"]] has_video_input = total_videos > 0 @@ -2106,7 +2201,7 @@ class ByteDance2ReferenceNode(IO.ComfyNode): content.append( TaskVideoContent( video_url=TaskVideoContentUrl( - url=await upload_video_to_comfyapi( + url=await _seedance_virtual_library_upload_video_asset( cls, reference_videos[key], wait_label=f"Uploading video {i}", @@ -2203,7 +2298,7 @@ class ByteDanceCreateImageAsset(IO.ComfyNode): return IO.Schema( node_id="ByteDanceCreateImageAsset", display_name="ByteDance Create Image Asset", - category="api node/image/ByteDance", + category="partner/image/ByteDance", description=( "Create a Seedance 2.0 personal image asset. Uploads the input image and " "registers it in the given asset group. If group_id is empty, runs a real-person " @@ -2270,7 +2365,7 @@ class ByteDanceCreateVideoAsset(IO.ComfyNode): return IO.Schema( node_id="ByteDanceCreateVideoAsset", display_name="ByteDance Create Video Asset", - category="api node/video/ByteDance", + category="partner/video/ByteDance", description=( "Create a Seedance 2.0 personal video asset. Uploads the input video and " "registers it in the given asset group. 
If group_id is empty, runs a real-person " diff --git a/comfy_api_nodes/nodes_bytedance_llm.py b/comfy_api_nodes/nodes_bytedance_llm.py index fa7fe370a..cb41defa0 100644 --- a/comfy_api_nodes/nodes_bytedance_llm.py +++ b/comfy_api_nodes/nodes_bytedance_llm.py @@ -144,7 +144,7 @@ class ByteDanceSeedNode(IO.ComfyNode): return IO.Schema( node_id="ByteDanceSeedNode", display_name="ByteDance Seed", - category="api node/text/ByteDance", + category="partner/text/ByteDance", essentials_category="Text Generation", description="Generate text responses with ByteDance's Seed 2.0 models. " "Provide a text prompt and optionally one or more images or videos for multimodal context.", diff --git a/comfy_api_nodes/nodes_elevenlabs.py b/comfy_api_nodes/nodes_elevenlabs.py index e452daf77..eba578a45 100644 --- a/comfy_api_nodes/nodes_elevenlabs.py +++ b/comfy_api_nodes/nodes_elevenlabs.py @@ -69,7 +69,7 @@ class ElevenLabsSpeechToText(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsSpeechToText", display_name="ElevenLabs Speech to Text", - category="api node/audio/ElevenLabs", + category="partner/audio/ElevenLabs", description="Transcribe audio to text. 
" "Supports automatic language detection, speaker diarization, and audio event tagging.", inputs=[ @@ -210,7 +210,7 @@ class ElevenLabsVoiceSelector(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsVoiceSelector", display_name="ElevenLabs Voice Selector", - category="api node/audio/ElevenLabs", + category="partner/audio/ElevenLabs", description="Select a predefined ElevenLabs voice for text-to-speech generation.", inputs=[ IO.Combo.Input( @@ -239,7 +239,7 @@ class ElevenLabsTextToSpeech(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToSpeech", display_name="ElevenLabs Text to Speech", - category="api node/audio/ElevenLabs", + category="partner/audio/ElevenLabs", description="Convert text to speech.", inputs=[ IO.Custom(ELEVENLABS_VOICE).Input( @@ -414,7 +414,7 @@ class ElevenLabsAudioIsolation(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsAudioIsolation", display_name="ElevenLabs Voice Isolation", - category="api node/audio/ElevenLabs", + category="partner/audio/ElevenLabs", description="Remove background noise from audio, isolating vocals or speech.", inputs=[ IO.Audio.Input( @@ -459,7 +459,7 @@ class ElevenLabsTextToSoundEffects(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToSoundEffects", display_name="ElevenLabs Text to Sound Effects", - category="api node/audio/ElevenLabs", + category="partner/audio/ElevenLabs", description="Generate sound effects from text descriptions.", inputs=[ IO.String.Input( @@ -555,7 +555,7 @@ class ElevenLabsInstantVoiceClone(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsInstantVoiceClone", display_name="ElevenLabs Instant Voice Clone", - category="api node/audio/ElevenLabs", + category="partner/audio/ElevenLabs", description="Create a cloned voice from audio samples. 
" "Provide 1-8 audio recordings of the voice to clone.", inputs=[ @@ -658,7 +658,7 @@ class ElevenLabsSpeechToSpeech(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsSpeechToSpeech", display_name="ElevenLabs Speech to Speech", - category="api node/audio/ElevenLabs", + category="partner/audio/ElevenLabs", description="Transform speech from one voice to another while preserving the original content and emotion.", inputs=[ IO.Custom(ELEVENLABS_VOICE).Input( @@ -793,7 +793,7 @@ class ElevenLabsTextToDialogue(IO.ComfyNode): return IO.Schema( node_id="ElevenLabsTextToDialogue", display_name="ElevenLabs Text to Dialogue", - category="api node/audio/ElevenLabs", + category="partner/audio/ElevenLabs", description="Generate multi-speaker dialogue from text. Each dialogue entry has its own text and voice.", inputs=[ IO.Float.Input( diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index d18c958a8..3d4be6065 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -8,7 +8,7 @@ import os from enum import Enum from fnmatch import fnmatch from io import BytesIO -from typing import Literal +from typing import Any, Literal import torch from typing_extensions import override @@ -19,6 +19,7 @@ from comfy_api_nodes.apis.gemini import ( GeminiContent, GeminiFileData, GeminiGenerateContentRequest, + GeminiGenerationConfig, GeminiGenerateContentResponse, GeminiImageConfig, GeminiImageGenerateContentRequest, @@ -40,13 +41,18 @@ from comfy_api_nodes.util import ( get_number_of_images, sync_op, tensor_to_base64_string, + upload_audio_to_comfyapi, + upload_image_to_comfyapi, upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_string, video_to_base64_string, ) GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB +GEMINI_URL_INPUT_BUDGET = 10 +GEMINI_MAX_INLINE_BYTES = 18 * 1024 * 1024 GEMINI_IMAGE_SYS_PROMPT = ( "You are an expert image-generation engine. 
You must ALWAYS produce an image.\n" "Interpret all user input—regardless of " @@ -285,6 +291,140 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N return final_price / 1_000_000.0 +def create_video_parts(video_input: Input.Video) -> list[GeminiPart]: + """Convert a single video input to Gemini API compatible parts (inline MP4/H.264).""" + base_64_string = video_to_base64_string( + video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 + ) + return [ + GeminiPart( + inlineData=GeminiInlineData( + mimeType=GeminiMimeType.video_mp4, + data=base_64_string, + ) + ) + ] + + +def create_audio_parts(audio_input: Input.Audio) -> list[GeminiPart]: + """Convert an audio input to Gemini API compatible parts (one inline MP3 part per batch item).""" + audio_parts: list[GeminiPart] = [] + for batch_index in range(audio_input["waveform"].shape[0]): + # Recreate an IO.AUDIO object for the given batch dimension index + audio_at_index = Input.Audio( + waveform=audio_input["waveform"][batch_index].unsqueeze(0), + sample_rate=audio_input["sample_rate"], + ) + # Convert to MP3 format for compatibility with Gemini API + audio_bytes = audio_to_base64_string( + audio_at_index, + container_format="mp3", + codec_name="libmp3lame", + ) + audio_parts.append( + GeminiPart( + inlineData=GeminiInlineData( + mimeType=GeminiMimeType.audio_mp3, + data=audio_bytes, + ) + ) + ) + return audio_parts + + +def _flatten_images(images: list[Input.Image]) -> list[torch.Tensor]: + """Expand any batched image tensors into individual (H, W, C) frames, preserving order.""" + frames: list[torch.Tensor] = [] + for img in images: + if len(img.shape) == 4: + frames.extend(img[i] for i in range(img.shape[0])) + else: + frames.append(img) + return frames + + +def _flatten_audio(audios: list[Input.Audio]) -> list[Input.Audio]: + """Expand any batched audio inputs into individual single-clip audio inputs, preserving order.""" + clips: list[Input.Audio] = 
[] + for audio in audios: + waveform = audio["waveform"] + for i in range(waveform.shape[0]): + clips.append(Input.Audio(waveform=waveform[i].unsqueeze(0), sample_rate=audio["sample_rate"])) + return clips + + +async def _media_url_part(cls: type[IO.ComfyNode], kind: str, payload: Any) -> GeminiPart: + """Upload a single media unit to ComfyAPI storage and return a fileData (URL) part.""" + if kind == "image": + url = await upload_image_to_comfyapi(cls, payload, mime_type="image/png", wait_label="Uploading image") + return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.image_png, fileUri=url)) + if kind == "audio": + url = await upload_audio_to_comfyapi( + cls, payload, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mp3" + ) + return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.audio_mp3, fileUri=url)) + url = await upload_video_to_comfyapi(cls, payload, wait_label="Uploading video") + return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.video_mp4, fileUri=url)) + + +def _media_inline_part(kind: str, payload: Any) -> tuple[GeminiPart, int]: + """Encode a single media unit as an inline base64 part; returns (part, base64_length).""" + if kind == "image": + data = tensor_to_base64_string(payload, mime_type="image/webp") + mime = GeminiMimeType.image_webp + elif kind == "audio": + data = audio_to_base64_string(payload, container_format="mp3", codec_name="libmp3lame") + mime = GeminiMimeType.audio_mp3 + else: + data = video_to_base64_string( + payload, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 + ) + mime = GeminiMimeType.video_mp4 + return GeminiPart(inlineData=GeminiInlineData(mimeType=mime, data=data)), len(data) + + +async def build_gemini_media_parts( + cls: type[IO.ComfyNode], + images: list[Input.Image], + audios: list[Input.Audio], + videos: list[Input.Video], + *, + url_budget: int = GEMINI_URL_INPUT_BUDGET, + max_inline_bytes: int = GEMINI_MAX_INLINE_BYTES, +) -> 
list[GeminiPart]: + """Build Gemini parts for multimodal inputs (images, audio, video). + + fileData URLs are preferred for every media type: the upload is fetched directly by the + model, keeping the request body tiny regardless of media size. The URL budget is shared + across all media and assigned largest-first (video, then audio, then images), so that if it + is ever exhausted the inline-base64 overflow is limited to the smallest items. Total inline + payload is capped by `max_inline_bytes`. + """ + units: list[tuple[str, Any]] = ( + [("video", v) for v in videos] + + [("audio", a) for a in _flatten_audio(audios)] + + [("image", f) for f in _flatten_images(images)] + ) + + parts: list[GeminiPart] = [] + url_used = 0 + inline_bytes = 0 + for kind, payload in units: + if url_used < url_budget: + parts.append(await _media_url_part(cls, kind, payload)) + url_used += 1 + continue + part, nbytes = _media_inline_part(kind, payload) + inline_bytes += nbytes + if inline_bytes > max_inline_bytes: + raise ValueError( + f"Too much media to send inline (over {max_inline_bytes // (1024 * 1024)}MB after the first " + f"{url_budget} inputs are uploaded as URLs). Reduce the number or size of attached media." + ) + parts.append(part) + return parts + + class GeminiNode(IO.ComfyNode): """ Node to generate text responses from a Gemini model. @@ -300,7 +440,7 @@ class GeminiNode(IO.ComfyNode): return IO.Schema( node_id="GeminiNode", display_name="Google Gemini", - category="api node/text/Gemini", + category="partner/text/Gemini", description="Generate text responses with Google's Gemini AI model. 
" "You can provide multiple types of inputs (text, images, audio, video) " "as context for generating more relevant and meaningful responses.", @@ -407,58 +547,9 @@ class GeminiNode(IO.ComfyNode): ) """, ), + is_deprecated=True, ) - @classmethod - def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]: - """Convert video input to Gemini API compatible parts.""" - - base_64_string = video_to_base64_string( - video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 - ) - return [ - GeminiPart( - inlineData=GeminiInlineData( - mimeType=GeminiMimeType.video_mp4, - data=base_64_string, - ) - ) - ] - - @classmethod - def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]: - """ - Convert audio input to Gemini API compatible parts. - - Args: - audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate. - - Returns: - List of GeminiPart objects containing the encoded audio. - """ - audio_parts: list[GeminiPart] = [] - for batch_index in range(audio_input["waveform"].shape[0]): - # Recreate an IO.AUDIO object for the given batch dimension index - audio_at_index = Input.Audio( - waveform=audio_input["waveform"][batch_index].unsqueeze(0), - sample_rate=audio_input["sample_rate"], - ) - # Convert to MP3 format for compatibility with Gemini API - audio_bytes = audio_to_base64_string( - audio_at_index, - container_format="mp3", - codec_name="libmp3lame", - ) - audio_parts.append( - GeminiPart( - inlineData=GeminiInlineData( - mimeType=GeminiMimeType.audio_mp3, - data=audio_bytes, - ) - ) - ) - return audio_parts - @classmethod async def execute( cls, @@ -482,9 +573,9 @@ class GeminiNode(IO.ComfyNode): if images is not None: parts.extend(await create_image_parts(cls, images)) if audio is not None: - parts.extend(cls.create_audio_parts(audio)) + parts.extend(create_audio_parts(audio)) if video is not None: - parts.extend(cls.create_video_parts(video)) + parts.extend(create_video_parts(video)) 
if files is not None: parts.extend(files) @@ -512,6 +603,210 @@ class GeminiNode(IO.ComfyNode): return IO.NodeOutput(output_text or "Empty response from Gemini model...") +GEMINI_V2_MODELS: dict[str, str] = { + "Gemini 3.1 Pro": "gemini-3.1-pro-preview", + "Gemini 3.1 Flash-Lite": "gemini-3.1-flash-lite-preview", +} + + +def _gemini_text_model_inputs(thinking_default: str) -> list[Input]: + """Per-model inputs revealed by the model DynamicCombo (shared media + sampling controls).""" + return [ + IO.Autogrow.Input( + "images", + template=IO.Autogrow.TemplateNames( + IO.Image.Input("image"), + names=[f"image_{i}" for i in range(1, 17)], + min=0, + ), + tooltip="Optional image(s) to use as context for the model. Up to 16 images.", + ), + IO.Autogrow.Input( + "audio", + template=IO.Autogrow.TemplateNames( + IO.Audio.Input("audio"), + names=["audio_1"], + min=0, + ), + tooltip="Optional audio clip to use as context for the model.", + ), + IO.Autogrow.Input( + "video", + template=IO.Autogrow.TemplateNames( + IO.Video.Input("video"), + names=["video_1"], + min=0, + ), + tooltip="Optional video clip to use as context for the model.", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Input Files node.", + ), + IO.Combo.Input( + "thinking_level", + options=["LOW", "HIGH"], + default=thinking_default, + tooltip="How hard the model reasons internally before answering. " + "HIGH improves quality on difficult tasks but costs more (thinking) tokens and is slower.", + ), + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.01, + tooltip="Controls randomness. 
Lower is more focused/deterministic, higher is more creative.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=0.95, + min=0.0, + max=1.0, + step=0.01, + tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.", + advanced=True, + ), + IO.Int.Input( + "max_output_tokens", + default=32768, + min=16, + max=65536, + tooltip="Maximum tokens to generate, including the model's internal thinking. " + "With thinking_level HIGH, a low value can leave no room for the answer; raise this if " + "responses come back empty or truncated. The model stops early when finished, so a higher " + "cap costs nothing extra for short replies.", + advanced=True, + ), + ] + + +class GeminiNodeV2(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiNodeV2", + display_name="Google Gemini", + category="partner/text/Gemini", + essentials_category="Text Generation", + description="Generate text responses with Google's Gemini models. Provide a text prompt and, " + "optionally, one or more images, audio clips, videos, or files as multimodal context.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text input to the model. Include detailed instructions, questions, or context.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option("Gemini 3.1 Pro", _gemini_text_model_inputs("HIGH")), + IO.DynamicCombo.Option("Gemini 3.1 Flash-Lite", _gemini_text_model_inputs("LOW")), + ], + tooltip="The Gemini model used to generate the response.", + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed for sampling. Set to 0 for a random seed. 
Deterministic output isn't guaranteed.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default="", + optional=True, + advanced=True, + tooltip="Foundational instructions that dictate the model's behavior.", + ), + ], + outputs=[ + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["model"]), + expr=""" + ( + $m := widgets.model; + $contains($m, "lite") ? { + "type": "list_usd", + "usd": [0.00025, 0.0015], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } : { + "type": "list_usd", + "usd": [0.002, 0.012], + "format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" } + } + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + model_id = GEMINI_V2_MODELS[model["model"]] + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + images = [t for t in (model.get("images") or {}).values() if t is not None] + audios = [a for a in (model.get("audio") or {}).values() if a is not None] + videos = [v for v in (model.get("video") or {}).values() if v is not None] + if images or audios or videos: + parts.extend(await build_gemini_media_parts(cls, images, audios, videos)) + files = model.get("files") + if files is not None: + parts.extend(files) + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"), + data=GeminiGenerateContentRequest( + contents=[ + GeminiContent( + role=GeminiRole.user, + parts=parts, + ) + ], + generationConfig=GeminiGenerationConfig( + 
temperature=model["temperature"], + topP=model["top_p"], + maxOutputTokens=model["max_output_tokens"], + seed=seed if seed > 0 else None, + thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]), + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + + output_text = get_text_from_response(response) + return IO.NodeOutput(output_text or "Empty response from Gemini model...") + + class GeminiInputFiles(IO.ComfyNode): """ Loads and formats input files for use with the Gemini API. @@ -541,7 +836,7 @@ class GeminiInputFiles(IO.ComfyNode): return IO.Schema( node_id="GeminiInputFiles", display_name="Gemini Input Files", - category="api node/text/Gemini", + category="partner/text/Gemini", description="Loads and prepares input files to include as inputs for Gemini LLM nodes. " "The files will be read by the Gemini model when generating a response. " "The contents of the text file count toward the token limit. 
" @@ -598,7 +893,7 @@ class GeminiImage(IO.ComfyNode): return IO.Schema( node_id="GeminiImageNode", display_name="Nano Banana (Google Gemini Image)", - category="api node/image/Gemini", + category="partner/image/Gemini", description="Edit images synchronously via Google API.", inputs=[ IO.String.Input( @@ -731,7 +1026,7 @@ class GeminiImage2(IO.ComfyNode): return IO.Schema( node_id="GeminiImage2Node", display_name="Nano Banana Pro (Google Gemini Image)", - category="api node/image/Gemini", + category="partner/image/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( @@ -869,7 +1164,7 @@ class GeminiNanoBanana2(IO.ComfyNode): return IO.Schema( node_id="GeminiNanoBanana2", display_name="Nano Banana 2", - category="api node/image/Gemini", + category="partner/image/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( @@ -1085,7 +1380,7 @@ class GeminiNanoBanana2V2(IO.ComfyNode): return IO.Schema( node_id="GeminiNanoBanana2V2", display_name="Nano Banana 2", - category="api node/image/Gemini", + category="partner/image/Gemini", description="Generate or edit images synchronously via Google Vertex API.", inputs=[ IO.String.Input( @@ -1129,6 +1424,26 @@ class GeminiNanoBanana2V2(IO.ComfyNode): tooltip="Foundational instructions that dictate an AI's behavior.", advanced=True, ), + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.01, + optional=True, + tooltip="Controls randomness in generation. Lower is more focused/deterministic.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=0.95, + min=0.0, + max=1.0, + step=0.01, + optional=True, + tooltip="Nucleus sampling threshold. 
Lower is more focused, higher more diverse.", + advanced=True, + ), ], outputs=[ IO.Image.Output(), @@ -1165,6 +1480,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode): seed: int, response_modalities: str, system_prompt: str = "", + temperature: float = 1.0, + top_p: float = 0.95, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) model_choice = model["model"] @@ -1204,6 +1521,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode): responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=image_config, thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]), + temperature=temperature, + topP=top_p, ), systemInstruction=gemini_system_prompt, ), @@ -1222,6 +1541,7 @@ class GeminiExtension(ComfyExtension): async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ GeminiNode, + GeminiNodeV2, GeminiImage, GeminiImage2, GeminiNanoBanana2, diff --git a/comfy_api_nodes/nodes_grok.py b/comfy_api_nodes/nodes_grok.py index a103f24ee..2ae529813 100644 --- a/comfy_api_nodes/nodes_grok.py +++ b/comfy_api_nodes/nodes_grok.py @@ -29,6 +29,11 @@ from comfy_api_nodes.util import ( ) +_GROK_VIDEO_MODEL_API_IDS = { + "grok-imagine-video-1.5": "grok-imagine-video-1.5-preview", +} + + def _extract_grok_price(response) -> float | None: if response.usage and response.usage.cost_in_usd_ticks is not None: return response.usage.cost_in_usd_ticks / 10_000_000_000 @@ -49,7 +54,7 @@ class GrokImageNode(IO.ComfyNode): return IO.Schema( node_id="GrokImageNode", display_name="Grok Image", - category="api node/image/Grok", + category="partner/image/Grok", description="Generate images using Grok based on a text prompt", inputs=[ IO.Combo.Input( @@ -58,7 +63,6 @@ class GrokImageNode(IO.ComfyNode): "grok-imagine-image-quality", "grok-imagine-image-pro", "grok-imagine-image", - "grok-imagine-image-beta", ], ), IO.String.Input( @@ -224,7 +228,7 @@ class GrokImageEditNode(IO.ComfyNode): return IO.Schema( 
node_id="GrokImageEditNode", display_name="Grok Image Edit", - category="api node/image/Grok", + category="partner/image/Grok", description="Modify an existing image based on a text prompt", inputs=[ IO.Combo.Input( @@ -233,7 +237,6 @@ class GrokImageEditNode(IO.ComfyNode): "grok-imagine-image-quality", "grok-imagine-image-pro", "grok-imagine-image", - "grok-imagine-image-beta", ], ), IO.Image.Input("image", display_name="images"), @@ -366,7 +369,7 @@ class GrokImageEditNodeV2(IO.ComfyNode): return IO.Schema( node_id="GrokImageEditNodeV2", display_name="Grok Image Edit", - category="api node/image/Grok", + category="partner/image/Grok", description="Modify an existing image based on a text prompt", inputs=[ IO.String.Input( @@ -503,10 +506,14 @@ class GrokVideoNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoNode", display_name="Grok Video", - category="api node/video/Grok", + category="partner/video/Grok", description="Generate video from a prompt or an image", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), + IO.Combo.Input( + "model", + options=["grok-imagine-video", "grok-imagine-video-1.5"], + tooltip="grok-imagine-video-1.5 currently always requires an input image.", + ), IO.String.Input( "prompt", multiline=True, @@ -542,7 +549,11 @@ class GrokVideoNode(IO.ComfyNode): tooltip="Seed to determine if node should re-run; " "actual results are nondeterministic regardless of seed.", ), - IO.Image.Input("image", optional=True), + IO.Image.Input( + "image", + optional=True, + tooltip="Optional starting image for grok-imagine-video. 
Required for grok-imagine-video-1.5.", + ), ], outputs=[ IO.Video.Output(), @@ -554,12 +565,16 @@ class GrokVideoNode(IO.ComfyNode): ], is_api_node=True, price_badge=IO.PriceBadge( - depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]), + depends_on=IO.PriceBadgeDepends(widgets=["model", "duration", "resolution"], inputs=["image"]), expr=""" ( - $rate := widgets.resolution = "720p" ? 0.07 : 0.05; + $is15 := $contains(widgets.model, "1.5"); + $rate := $is15 + ? (widgets.resolution = "720p" ? 0.2002 : 0.1144) + : (widgets.resolution = "720p" ? 0.07 : 0.05); + $imgCost := $is15 ? 0.0143 : 0.002; $base := $rate * widgets.duration; - {"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base} + {"type":"usd","usd": inputs.image.connected ? $base + $imgCost : $base} ) """, ), @@ -576,8 +591,8 @@ class GrokVideoNode(IO.ComfyNode): seed: int, image: Input.Image | None = None, ) -> IO.NodeOutput: - if model == "grok-imagine-video-beta": - model = "grok-imagine-video" + if image is None and model == "grok-imagine-video-1.5": + raise ValueError(f"The '{model}' model requires an input image; connect one to the 'image' input.") image_url = None if image is not None: if get_number_of_images(image) != 1: @@ -588,7 +603,7 @@ class GrokVideoNode(IO.ComfyNode): cls, ApiEndpoint(path="/proxy/xai/v1/videos/generations", method="POST"), data=VideoGenerationRequest( - model=model, + model=_GROK_VIDEO_MODEL_API_IDS.get(model, model), image=image_url, prompt=prompt, resolution=resolution, @@ -603,7 +618,7 @@ class GrokVideoNode(IO.ComfyNode): ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"), status_extractor=lambda r: r.status if r.status is not None else "complete", response_model=VideoStatusResponse, - price_extractor=_extract_grok_price, + price_extractor=_extract_grok_video_price if model == "grok-imagine-video-1.5" else _extract_grok_price, ) return IO.NodeOutput(await download_url_to_video_output(response.video.url)) 
@@ -615,10 +630,10 @@ class GrokVideoEditNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoEditNode", display_name="Grok Video Edit", - category="api node/video/Grok", + category="partner/video/Grok", description="Edit an existing video based on a text prompt.", inputs=[ - IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]), + IO.Combo.Input("model", options=["grok-imagine-video"]), IO.String.Input( "prompt", multiline=True, @@ -693,7 +708,7 @@ class GrokVideoReferenceNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoReferenceNode", display_name="Grok Reference-to-Video", - category="api node/video/Grok", + category="partner/video/Grok", description="Generate video guided by reference images as style and content references.", inputs=[ IO.String.Input( @@ -826,7 +841,7 @@ class GrokVideoExtendNode(IO.ComfyNode): return IO.Schema( node_id="GrokVideoExtendNode", display_name="Grok Video Extend", - category="api node/video/Grok", + category="partner/video/Grok", description="Extend an existing video with a seamless continuation based on a text prompt.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_hitpaw.py b/comfy_api_nodes/nodes_hitpaw.py index bca5170e4..062d3cf1d 100644 --- a/comfy_api_nodes/nodes_hitpaw.py +++ b/comfy_api_nodes/nodes_hitpaw.py @@ -71,7 +71,7 @@ class HitPawGeneralImageEnhance(IO.ComfyNode): return IO.Schema( node_id="HitPawGeneralImageEnhance", display_name="HitPaw General Image Enhance", - category="api node/image/HitPaw", + category="partner/image/HitPaw", description="Upscale low-resolution images to super-resolution, eliminate artifacts and noise. 
" f"Maximum output: {MAX_MP_GENERATIVE} megapixels.", inputs=[ @@ -201,7 +201,7 @@ class HitPawVideoEnhance(IO.ComfyNode): return IO.Schema( node_id="HitPawVideoEnhance", display_name="HitPaw Video Enhance", - category="api node/video/HitPaw", + category="partner/video/HitPaw", description="Upscale low-resolution videos to high resolution, eliminate artifacts and noise. " "Prices shown are per second of video.", inputs=[ diff --git a/comfy_api_nodes/nodes_hunyuan3d.py b/comfy_api_nodes/nodes_hunyuan3d.py index 5fc31bccd..fcd27b7fb 100644 --- a/comfy_api_nodes/nodes_hunyuan3d.py +++ b/comfy_api_nodes/nodes_hunyuan3d.py @@ -123,7 +123,7 @@ class TencentTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="TencentTextToModelNode", display_name="Hunyuan3D: Text to Model", - category="api node/3d/Tencent", + category="partner/3d/Tencent", essentials_category="3D", inputs=[ IO.Combo.Input( @@ -242,7 +242,7 @@ class TencentImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="TencentImageToModelNode", display_name="Hunyuan3D: Image(s) to Model", - category="api node/3d/Tencent", + category="partner/3d/Tencent", essentials_category="3D", inputs=[ IO.Combo.Input( @@ -415,7 +415,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode): return IO.Schema( node_id="TencentModelTo3DUVNode", display_name="Hunyuan3D: Model to UV", - category="api node/3d/Tencent", + category="partner/3d/Tencent", description="Perform UV unfolding on a 3D model to generate UV texture. 
" "Input model must have less than 30000 faces.", inputs=[ @@ -505,7 +505,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode): return IO.Schema( node_id="Tencent3DTextureEditNode", display_name="Hunyuan3D: 3D Texture Edit", - category="api node/3d/Tencent", + category="partner/3d/Tencent", description="After inputting the 3D model, perform 3D model texture redrawing.", inputs=[ IO.MultiType.Input( @@ -594,7 +594,7 @@ class Tencent3DPartNode(IO.ComfyNode): return IO.Schema( node_id="Tencent3DPartNode", display_name="Hunyuan3D: 3D Part", - category="api node/3d/Tencent", + category="partner/3d/Tencent", description="Automatically perform component identification and generation based on the model structure.", inputs=[ IO.MultiType.Input( @@ -666,7 +666,7 @@ class TencentSmartTopologyNode(IO.ComfyNode): return IO.Schema( node_id="TencentSmartTopologyNode", display_name="Hunyuan3D: Smart Topology", - category="api node/3d/Tencent", + category="partner/3d/Tencent", description="Perform smart retopology on a 3D model. 
" "Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.", inputs=[ diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 97c3609bd..3b914a850 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -10,6 +10,7 @@ from comfy_api_nodes.apis.ideogram import ( ImageRequest, IdeogramV3Request, IdeogramV3EditRequest, + IdeogramV4Request, ) from comfy_api_nodes.util import ( ApiEndpoint, @@ -17,6 +18,7 @@ from comfy_api_nodes.util import ( download_url_as_bytesio, resize_mask_to_image, sync_op, + validate_string, ) V1_V1_RES_MAP = { @@ -234,7 +236,7 @@ class IdeogramV1(IO.ComfyNode): return IO.Schema( node_id="IdeogramV1", display_name="Ideogram V1", - category="api node/image/Ideogram", + category="partner/image/Ideogram", description="Generates images using the Ideogram V1 model.", inputs=[ IO.String.Input( @@ -360,7 +362,7 @@ class IdeogramV2(IO.ComfyNode): return IO.Schema( node_id="IdeogramV2", display_name="Ideogram V2", - category="api node/image/Ideogram", + category="partner/image/Ideogram", description="Generates images using the Ideogram V2 model.", inputs=[ IO.String.Input( @@ -526,7 +528,7 @@ class IdeogramV3(IO.ComfyNode): return IO.Schema( node_id="IdeogramV3", display_name="Ideogram V3", - category="api node/image/Ideogram", + category="partner/image/Ideogram", description="Generates images using the Ideogram V3 model. 
" "Supports both regular image generation from text prompts and image editing with mask.", inputs=[ @@ -798,6 +800,119 @@ class IdeogramV3(IO.ComfyNode): return IO.NodeOutput(await download_and_process_images(image_urls)) +class IdeogramV4(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="IdeogramV4", + display_name="Ideogram V4", + category="partner/image/Ideogram", + description="Generates images using the Ideogram 4.0 model from a text prompt.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the image generation.", + ), + IO.Combo.Input( + "resolution", + options=[ + "Auto", + "2048x2048 (1:1)", + "1440x2880 (1:2)", + "2880x1440 (2:1)", + "1664x2496 (2:3)", + "2496x1664 (3:2)", + "1792x2240 (4:5)", + "2240x1792 (5:4)", + "1440x2560 (9:16)", + "2560x1440 (16:9)", + "1600x2560 (5:8)", + "2560x1600 (8:5)", + "1728x2304 (3:4)", + "2304x1728 (4:3)", + "1296x3168 (9:22)", + "3168x1296 (22:9)", + "1152x2944 (9:23)", + "2944x1152 (23:9)", + "1248x3328 (3:8)", + "3328x1248 (8:3)", + "1280x3072 (5:12)", + "3072x1280 (12:5)", + ], + default="Auto", + ), + IO.Combo.Input( + "rendering_speed", + options=["DEFAULT", "TURBO", "QUALITY"], + default="DEFAULT", + tooltip="Controls the trade-off between generation speed and quality.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + step=1, + control_after_generate=True, + display_mode=IO.NumberDisplay.number, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["rendering_speed"]), + expr=""" + ( + $speed := widgets.rendering_speed; + $price := + $contains($speed,"turbo") ? 0.0429 : + $contains($speed,"quality") ? 
0.143 : + 0.0858; + {"type":"usd","usd": $price} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + resolution: str, + rendering_speed: str, + seed: int, + ): + validate_string(prompt, strip_whitespace=True, min_length=1) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/ideogram/ideogram-v4/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=IdeogramV4Request( + text_prompt=prompt, + resolution=resolution.split(" ")[0] if resolution != "Auto" else None, + rendering_speed=rendering_speed, + ), + max_retries=1, + ) + + if not response.data or len(response.data) == 0: + raise Exception("No images were generated in the response") + image_urls = [image_data.url for image_data in response.data if image_data.url] + if not image_urls: + raise Exception("No image URLs were generated in the response") + return IO.NodeOutput(await download_and_process_images(image_urls)) + + class IdeogramExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -805,6 +920,7 @@ class IdeogramExtension(ComfyExtension): IdeogramV1, IdeogramV2, IdeogramV3, + IdeogramV4, ] diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 7586f1816..d11e42540 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -642,7 +642,7 @@ class KlingCameraControls(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControls", display_name="Kling Camera Controls", - category="api node/video/Kling", + category="partner/video/Kling", description="Allows specifying configuration options for Kling Camera Controls and motion control effects.", inputs=[ IO.Combo.Input("camera_control_type", options=KlingCameraControlType), @@ -762,7 +762,7 @@ class KlingTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingTextToVideoNode", display_name="Kling Text to Video", - category="api node/video/Kling", + category="partner/video/Kling", description="Kling 
Text to Video Node", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -849,7 +849,7 @@ class OmniProTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProTextToVideoNode", display_name="Kling 3.0 Omni Text to Video", - category="api node/video/Kling", + category="partner/video/Kling", description="Use text prompts to generate videos with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -998,7 +998,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProFirstLastFrameNode", display_name="Kling 3.0 Omni First-Last-Frame to Video", - category="api node/video/Kling", + category="partner/video/Kling", description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1205,7 +1205,7 @@ class OmniProImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProImageToVideoNode", display_name="Kling 3.0 Omni Image to Video", - category="api node/video/Kling", + category="partner/video/Kling", description="Use up to 7 reference images to generate a video with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1374,7 +1374,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProVideoToVideoNode", display_name="Kling 3.0 Omni Video to Video", - category="api node/video/Kling", + category="partner/video/Kling", description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), @@ -1485,7 +1485,7 @@ class OmniProEditVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProEditVideoNode", display_name="Kling 3.0 Omni Edit Video", - category="api node/video/Kling", + 
category="partner/video/Kling", essentials_category="Video Generation", description="Edit an existing video with the latest model from Kling.", inputs=[ @@ -1593,7 +1593,7 @@ class OmniProImageNode(IO.ComfyNode): return IO.Schema( node_id="KlingOmniProImageNode", display_name="Kling 3.0 Omni Image", - category="api node/image/Kling", + category="partner/image/Kling", description="Create or edit images with the latest model from Kling.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-image-o1"]), @@ -1721,7 +1721,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControlT2VNode", display_name="Kling Text to Video (Camera Control)", - category="api node/video/Kling", + category="partner/video/Kling", description="Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text.", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -1783,7 +1783,7 @@ class KlingImage2VideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingImage2VideoNode", display_name="Kling Image(First Frame) to Video", - category="api node/video/Kling", + category="partner/video/Kling", inputs=[ IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -1882,7 +1882,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode): return IO.Schema( node_id="KlingCameraControlI2VNode", display_name="Kling Image to Video (Camera Control)", - category="api node/video/Kling", + category="partner/video/Kling", description="Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. 
Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image.", inputs=[ IO.Image.Input( @@ -1953,7 +1953,7 @@ class KlingStartEndFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingStartEndFrameNode", display_name="Kling Start-End Frame to Video", - category="api node/video/Kling", + category="partner/video/Kling", description="Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last.", inputs=[ IO.Image.Input( @@ -2047,7 +2047,7 @@ class KlingVideoExtendNode(IO.ComfyNode): return IO.Schema( node_id="KlingVideoExtendNode", display_name="Kling Video Extend", - category="api node/video/Kling", + category="partner/video/Kling", description="Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes.", inputs=[ IO.String.Input( @@ -2128,7 +2128,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode): return IO.Schema( node_id="KlingDualCharacterVideoEffectNode", display_name="Kling Dual Character Video Effects", - category="api node/video/Kling", + category="partner/video/Kling", description="Achieve different special effects when generating a video based on the effect_scene. 
First image will be positioned on left side, second on right side of the composite.", inputs=[ IO.Image.Input("image_left", tooltip="Left side image"), @@ -2218,7 +2218,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): return IO.Schema( node_id="KlingSingleImageVideoEffectNode", display_name="Kling Video Effects", - category="api node/video/Kling", + category="partner/video/Kling", description="Achieve different special effects when generating a video based on the effect_scene.", inputs=[ IO.Image.Input( @@ -2291,7 +2291,7 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingLipSyncAudioToVideoNode", display_name="Kling Lip Sync Video with Audio", - category="api node/video/Kling", + category="partner/video/Kling", essentials_category="Video Generation", description="Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", inputs=[ @@ -2343,7 +2343,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingLipSyncTextToVideoNode", display_name="Kling Lip Sync Video with Text", - category="api node/video/Kling", + category="partner/video/Kling", description="Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. 
The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length.", inputs=[ IO.Video.Input("video"), @@ -2411,7 +2411,7 @@ class KlingVirtualTryOnNode(IO.ComfyNode): return IO.Schema( node_id="KlingVirtualTryOnNode", display_name="Kling Virtual Try On", - category="api node/image/Kling", + category="partner/image/Kling", description="Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background.", inputs=[ IO.Image.Input("human_image"), @@ -2478,7 +2478,7 @@ class KlingImageGenerationNode(IO.ComfyNode): return IO.Schema( node_id="KlingImageGenerationNode", display_name="Kling 3.0 Image", - category="api node/image/Kling", + category="partner/image/Kling", description="Kling Image Generation Node. Generate an image from a text prompt with an optional reference image.", inputs=[ IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), @@ -2615,7 +2615,7 @@ class TextToVideoWithAudio(IO.ComfyNode): return IO.Schema( node_id="KlingTextToVideoWithAudio", display_name="Kling 2.6 Text to Video with Audio", - category="api node/video/Kling", + category="partner/video/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."), @@ -2683,7 +2683,7 @@ class ImageToVideoWithAudio(IO.ComfyNode): return IO.Schema( node_id="KlingImageToVideoWithAudio", display_name="Kling 2.6 Image(First Frame) to Video with Audio", - category="api node/video/Kling", + category="partner/video/Kling", inputs=[ IO.Combo.Input("model_name", options=["kling-v2-6"]), IO.Image.Input("start_frame"), @@ -2753,7 +2753,7 @@ class MotionControl(IO.ComfyNode): return IO.Schema( node_id="KlingMotionControl", display_name="Kling Motion Control", - category="api node/video/Kling", + 
category="partner/video/Kling", inputs=[ IO.String.Input("prompt", multiline=True), IO.Image.Input("reference_image"), @@ -2854,7 +2854,7 @@ class KlingVideoNode(IO.ComfyNode): return IO.Schema( node_id="KlingVideoNode", display_name="Kling 3.0 Video", - category="api node/video/Kling", + category="partner/video/Kling", description="Generate videos with Kling V3. " "Supports text-to-video and image-to-video with optional storyboard multi-prompt and audio generation.", inputs=[ @@ -3077,7 +3077,7 @@ class KlingFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="KlingFirstLastFrameNode", display_name="Kling 3.0 First-Last-Frame to Video", - category="api node/video/Kling", + category="partner/video/Kling", description="Generate videos with Kling V3 using first and last frames.", inputs=[ IO.String.Input("prompt", multiline=True, default=""), @@ -3202,7 +3202,7 @@ class KlingAvatarNode(IO.ComfyNode): return IO.Schema( node_id="KlingAvatarNode", display_name="Kling Avatar 2.0", - category="api node/video/Kling", + category="partner/video/Kling", description="Generate broadcast-style digital human videos from a single photo and an audio file.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_krea.py b/comfy_api_nodes/nodes_krea.py new file mode 100644 index 000000000..b9e6268f2 --- /dev/null +++ b/comfy_api_nodes/nodes_krea.py @@ -0,0 +1,294 @@ +"""Krea image-generation nodes.""" + +import re + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.krea import ( + KreaAssetResponse, + KreaGenerateImageRequest, + KreaImageStyleReference, + KreaJob, + KreaMoodboard, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + poll_op, + sync_op, + tensor_to_bytesio, + validate_string, +) + + +class KreaIO: + STYLE_REF = "KREA_STYLE_REF" + + +async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Image) -> str: + """Upload an image to 
Krea's /assets endpoint and return the Krea-hosted image URL.""" + img_io = tensor_to_bytesio(image, total_pixels=2048 * 2048, mime_type="image/png") + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/krea/assets", method="POST"), + response_model=KreaAssetResponse, + files=[("file", (img_io.name, img_io, "image/png"))], + content_type="multipart/form-data", + max_retries=1, + wait_label="Uploading reference", + ) + return response.image_url + + +_MODEL_MEDIUM = "Krea 2 Medium" +_MODEL_MEDIUM_TURBO = "Krea 2 Medium Turbo" +_MODEL_LARGE = "Krea 2 Large" +_MODEL_ENDPOINTS: dict[str, str] = { + _MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium", + _MODEL_MEDIUM_TURBO: "/proxy/krea/generate/image/krea/krea-2/medium-turbo", + _MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large", +} + +_ASPECT_RATIOS = ["1:1", "4:3", "3:2", "16:9", "2.35:1", "4:5", "2:3", "9:16"] +_RESOLUTIONS = ["1K"] +_CREATIVITY_LEVELS = ["raw", "low", "medium", "high"] +_KREA_QUEUED_STATUSES = ["backlogged", "queued", "scheduled"] + +_UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}$") + + +def _krea_model_inputs() -> list: + """Nested inputs shared by Krea 2 Medium, Medium Turbo and Large under the DynamicCombo.""" + return [ + IO.Combo.Input( + "aspect_ratio", + options=_ASPECT_RATIOS, + tooltip="Output aspect ratio.", + ), + IO.Combo.Input( + "resolution", + options=_RESOLUTIONS, + tooltip="Resolution scale.", + ), + IO.Combo.Input( + "creativity", + options=_CREATIVITY_LEVELS, + default="medium", + tooltip="Prompt interpretation strength: raw stays closest to the prompt; high is most creative.", + ), + IO.String.Input( + "moodboard_id", + default="", + tooltip="Optional Krea moodboard UUID (e.g. from the Krea website). " + "Leave empty to disable. 
Only one moodboard is supported per request.", + optional=True, + ), + IO.Float.Input( + "moodboard_strength", + default=0.35, + min=-0.5, + max=1.5, + step=0.05, + tooltip="Moodboard influence; ignored when moodboard_id is empty.", + optional=True, + ), + IO.Custom(KreaIO.STYLE_REF).Input( + "style_reference", + optional=True, + tooltip="Optional chain of style references (max 10) from Krea 2 Style Reference nodes.", + ), + ] + + +class Krea2ImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Krea2ImageNode", + display_name="Krea 2 Image", + category="partner/image/Krea", + description=( + "Generate images via Krea 2 — pick Medium (expressive illustrations) or " + "Large (expressive photorealism). Supports an optional moodboard and up " + "to 10 chained image style references." + ), + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text prompt for the image.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()), + IO.DynamicCombo.Option(_MODEL_MEDIUM_TURBO, _krea_model_inputs()), + IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()), + ], + tooltip="Krea 2 Medium is best for expressive illustrations; " + "Krea 2 Large is best for expressive photorealism.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Random seed for reproducibility.", + ), + ], + outputs=[IO.Image.Output()], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends( + widgets=["model", "model.moodboard_id"], + inputs=["model.style_reference"], + ), + expr=""" + ( + $rates := { + "krea 2 medium turbo": {"text": 0.015, "style": 0.0175, "moodboard": 0.02}, + "krea 2 medium": {"text": 0.03, "style": 0.035, "moodboard": 0.04}, + "krea 2 large": 
{"text": 0.06, "style": 0.065, "moodboard": 0.07} + }; + $r := $lookup($rates, widgets.model); + $hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0; + $hasStyle := $lookup(inputs, "model.style_reference").connected; + $usd := $hasMoodboard ? $r.moodboard : ($hasStyle ? $r.style : $r.text); + {"type":"usd","usd": $usd} + ) + """, + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1) + + model_choice = model["model"] + endpoint_path = _MODEL_ENDPOINTS.get(model_choice) + if endpoint_path is None: + raise ValueError(f"Unknown Krea 2 model: {model_choice!r}") + + moodboards: list[KreaMoodboard] | None = None + mb_id = (model.get("moodboard_id") or "").strip() + if mb_id: + if not _UUID_RE.match(mb_id): + raise ValueError(f"moodboard_id must be a UUID (received {mb_id!r}); copy it from the Krea website.") + mb_strength = model.get("moodboard_strength") + moodboards = [KreaMoodboard(id=mb_id, strength=0.35 if mb_strength is None else float(mb_strength))] + + style_reference = model.get("style_reference") + image_style_references: list[KreaImageStyleReference] | None = None + if style_reference: + if len(style_reference) > 10: + raise ValueError(f"Krea 2 accepts at most 10 image_style_references; received {len(style_reference)}.") + image_style_references = [ + KreaImageStyleReference(url=ref["url"], strength=float(ref["strength"])) for ref in style_reference + ] + initial = await sync_op( + cls, + ApiEndpoint(path=endpoint_path, method="POST"), + response_model=KreaJob, + data=KreaGenerateImageRequest( + prompt=prompt, + aspect_ratio=model["aspect_ratio"], + resolution=model["resolution"], + seed=seed, + creativity=model["creativity"], + moodboards=moodboards, + image_style_references=image_style_references, + ), + ) + job = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/krea/jobs/{initial.job_id}", method="GET"), + 
response_model=KreaJob, + status_extractor=lambda r: r.status, + queued_statuses=_KREA_QUEUED_STATUSES, + ) + if not job.result or not job.result.urls: + raise RuntimeError(f"Krea 2 job {job.job_id} completed without any image URLs.") + image = await download_url_to_image_tensor(job.result.urls[0]) + return IO.NodeOutput(image) + + +class Krea2StyleReferenceNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="Krea2StyleReferenceNode", + display_name="Krea 2 Style Reference", + category="partner/image/Krea", + description=( + "Add an image style reference to a Krea 2 generation. Chain multiple Krea 2 " + "Style Reference nodes (max 10) and feed the final `style_reference` output " + "into Krea 2 Image. Each image is uploaded to ComfyAPI storage and passed as URL." + ), + inputs=[ + IO.Image.Input( + "image", + tooltip="Reference image whose style influences the generation.", + ), + IO.Float.Input( + "strength", + default=1.0, + min=-2.0, + max=2.0, + step=0.05, + tooltip="Reference strength; negative values invert the style influence.", + ), + IO.Custom(KreaIO.STYLE_REF).Input( + "style_reference", + optional=True, + tooltip="Optional incoming chain of style references; this node appends one more.", + ), + ], + outputs=[IO.Custom(KreaIO.STYLE_REF).Output(display_name="style_reference")], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + strength: float, + style_reference: list[dict] | None = None, + ) -> IO.NodeOutput: + chain: list[dict] = list(style_reference) if style_reference else [] + if len(chain) >= 10: + raise ValueError("Krea 2 accepts at most 10 image_style_references in one generation.") + url = await _upload_image_to_krea_assets(cls, image) + chain.append({"url": url, "strength": float(strength)}) + return IO.NodeOutput(chain) + + +class KreaExtension(ComfyExtension): + 
@override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + Krea2ImageNode, + Krea2StyleReferenceNode, + ] + + +async def comfy_entrypoint() -> KreaExtension: + return KreaExtension() diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py index 0a219af96..878e04b4e 100644 --- a/comfy_api_nodes/nodes_ltxv.py +++ b/comfy_api_nodes/nodes_ltxv.py @@ -50,7 +50,7 @@ class TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="LtxvApiTextToVideo", display_name="LTXV Text To Video", - category="api node/video/LTXV", + category="partner/video/LTXV", description="Professional-quality videos with customizable duration and resolution.", inputs=[ IO.Combo.Input("model", options=list(MODELS_MAP.keys())), @@ -127,7 +127,7 @@ class ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="LtxvApiImageToVideo", display_name="LTXV Image To Video", - category="api node/video/LTXV", + category="partner/video/LTXV", description="Professional-quality videos with customizable duration and resolution based on start image.", inputs=[ IO.Image.Input("image", tooltip="First frame to be used for the video."), diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index d92a7c382..0d31ac77e 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -46,7 +46,7 @@ class LumaReferenceNode(IO.ComfyNode): return IO.Schema( node_id="LumaReferenceNode", display_name="Luma Reference", - category="api node/image/Luma", + category="partner/image/Luma", description="Holds an image and weight for use with Luma Generate Image node.", inputs=[ IO.Image.Input( @@ -85,7 +85,7 @@ class LumaConceptsNode(IO.ComfyNode): return IO.Schema( node_id="LumaConceptsNode", display_name="Luma Concepts", - category="api node/video/Luma", + category="partner/video/Luma", description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.", inputs=[ IO.Combo.Input( @@ -134,7 +134,7 @@ class 
LumaImageGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageNode", display_name="Luma Text to Image", - category="api node/image/Luma", + category="partner/image/Luma", description="Generates images synchronously based on prompt and aspect ratio.", inputs=[ IO.String.Input( @@ -278,7 +278,7 @@ class LumaImageModifyNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageModifyNode", display_name="Luma Image to Image", - category="api node/image/Luma", + category="partner/image/Luma", description="Modifies images synchronously based on prompt and aspect ratio.", inputs=[ IO.Image.Input( @@ -371,7 +371,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaVideoNode", display_name="Luma Text to Video", - category="api node/video/Luma", + category="partner/video/Luma", description="Generates videos synchronously based on prompt and output_size.", inputs=[ IO.String.Input( @@ -472,7 +472,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageToVideoNode", display_name="Luma Image to Video", - category="api node/video/Luma", + category="partner/video/Luma", description="Generates videos synchronously based on prompt, input images, and output_size.", inputs=[ IO.String.Input( @@ -724,7 +724,7 @@ class LumaImageNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageNode2", display_name="Luma UNI-1 Image", - category="api node/image/Luma", + category="partner/image/Luma", description="Generate images from text using the Luma UNI-1 model.", inputs=[ IO.String.Input( @@ -853,7 +853,7 @@ class LumaImageEditNode(IO.ComfyNode): return IO.Schema( node_id="LumaImageEditNode2", display_name="Luma UNI-1 Image Edit", - category="api node/image/Luma", + category="partner/image/Luma", description="Edit an existing image with a text prompt using the Luma UNI-1 model.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_magnific.py b/comfy_api_nodes/nodes_magnific.py index 38b881fea..4ce4735df 
100644 --- a/comfy_api_nodes/nodes_magnific.py +++ b/comfy_api_nodes/nodes_magnific.py @@ -61,7 +61,7 @@ class MagnificImageUpscalerCreativeNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageUpscalerCreativeNode", display_name="Magnific Image Upscale (Creative)", - category="api node/image/Magnific", + category="partner/image/Magnific", description="Prompt‑guided enhancement, stylization, and 2x/4x/8x/16x upscaling. " "Maximum output: 25.3 megapixels.", inputs=[ @@ -240,7 +240,7 @@ class MagnificImageUpscalerPreciseV2Node(IO.ComfyNode): return IO.Schema( node_id="MagnificImageUpscalerPreciseV2Node", display_name="Magnific Image Upscale (Precise V2)", - category="api node/image/Magnific", + category="partner/image/Magnific", description="High-fidelity upscaling with fine control over sharpness, grain, and detail. " "Maximum output: 10060×10060 pixels.", inputs=[ @@ -400,7 +400,7 @@ class MagnificImageStyleTransferNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageStyleTransferNode", display_name="Magnific Image Style Transfer", - category="api node/image/Magnific", + category="partner/image/Magnific", description="Transfer the style from a reference image to your input image.", inputs=[ IO.Image.Input("image", tooltip="The image to apply style transfer to."), @@ -549,7 +549,7 @@ class MagnificImageRelightNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageRelightNode", display_name="Magnific Image Relight", - category="api node/image/Magnific", + category="partner/image/Magnific", description="Relight an image with lighting adjustments and optional reference-based light transfer.", inputs=[ IO.Image.Input("image", tooltip="The image to relight."), @@ -789,7 +789,7 @@ class MagnificImageSkinEnhancerNode(IO.ComfyNode): return IO.Schema( node_id="MagnificImageSkinEnhancerNode", display_name="Magnific Image Skin Enhancer", - category="api node/image/Magnific", + category="partner/image/Magnific", description="Skin enhancement for portraits with 
multiple processing modes.", inputs=[ IO.Image.Input("image", tooltip="The portrait image to enhance."), diff --git a/comfy_api_nodes/nodes_meshy.py b/comfy_api_nodes/nodes_meshy.py index 3cf577f4a..3a24f1095 100644 --- a/comfy_api_nodes/nodes_meshy.py +++ b/comfy_api_nodes/nodes_meshy.py @@ -33,7 +33,7 @@ class MeshyTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyTextToModelNode", display_name="Meshy: Text to Model", - category="api node/3d/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.String.Input("prompt", multiline=True, default=""), @@ -145,7 +145,7 @@ class MeshyRefineNode(IO.ComfyNode): return IO.Schema( node_id="MeshyRefineNode", display_name="Meshy: Refine Draft Model", - category="api node/3d/Meshy", + category="partner/3d/Meshy", description="Refine a previously created draft model.", inputs=[ IO.Combo.Input("model", options=["latest"]), @@ -240,7 +240,7 @@ class MeshyImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyImageToModelNode", display_name="Meshy: Image to Model", - category="api node/3d/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Image.Input("image"), @@ -405,7 +405,7 @@ class MeshyMultiImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyMultiImageToModelNode", display_name="Meshy: Multi-Image to Model", - category="api node/3d/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Autogrow.Input( @@ -575,7 +575,7 @@ class MeshyRigModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyRigModelNode", display_name="Meshy: Rig Model", - category="api node/3d/Meshy", + category="partner/3d/Meshy", description="Provides a rigged character in standard formats. 
" "Auto-rigging is currently not suitable for untextured meshes, non-humanoid assets, " "or humanoid assets with unclear limb and body structure.", @@ -656,7 +656,7 @@ class MeshyAnimateModelNode(IO.ComfyNode): return IO.Schema( node_id="MeshyAnimateModelNode", display_name="Meshy: Animate Model", - category="api node/3d/Meshy", + category="partner/3d/Meshy", description="Apply a specific animation action to a previously rigged character.", inputs=[ IO.Custom("MESHY_RIGGED_TASK_ID").Input("rig_task_id"), @@ -722,7 +722,7 @@ class MeshyTextureNode(IO.ComfyNode): return IO.Schema( node_id="MeshyTextureNode", display_name="Meshy: Texture Model", - category="api node/3d/Meshy", + category="partner/3d/Meshy", inputs=[ IO.Combo.Input("model", options=["latest"]), IO.Custom("MESHY_TASK_ID").Input("meshy_task_id"), diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index b5d0b461f..6250af146 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -101,7 +101,7 @@ class MinimaxTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxTextToVideoNode", display_name="MiniMax Text to Video", - category="api node/video/MiniMax", + category="partner/video/MiniMax", description="Generates videos synchronously based on a prompt, and optional parameters.", inputs=[ IO.String.Input( @@ -163,7 +163,7 @@ class MinimaxImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxImageToVideoNode", display_name="MiniMax Image to Video", - category="api node/video/MiniMax", + category="partner/video/MiniMax", description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( @@ -230,7 +230,7 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxSubjectToVideoNode", display_name="MiniMax Subject to Video", - category="api node/video/MiniMax", + category="partner/video/MiniMax", description="Generates videos synchronously based on 
an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( @@ -294,7 +294,7 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): return IO.Schema( node_id="MinimaxHailuoVideoNode", display_name="MiniMax Hailuo Video", - category="api node/video/MiniMax", + category="partner/video/MiniMax", description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index a5a188634..0fe5fb9d0 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -99,7 +99,7 @@ class OpenAIDalle2(IO.ComfyNode): return IO.Schema( node_id="OpenAIDalle2", display_name="OpenAI DALL·E 2", - category="api node/image/OpenAI", + category="partner/image/OpenAI", description="Generates images synchronously via OpenAI's DALL·E 2 endpoint.", inputs=[ IO.String.Input( @@ -249,7 +249,7 @@ class OpenAIDalle3(IO.ComfyNode): return IO.Schema( node_id="OpenAIDalle3", display_name="OpenAI DALL·E 3", - category="api node/image/OpenAI", + category="partner/image/OpenAI", description="Generates images synchronously via OpenAI's DALL·E 3 endpoint.", inputs=[ IO.String.Input( @@ -371,7 +371,7 @@ class OpenAIGPTImage1(IO.ComfyNode): return IO.Schema( node_id="OpenAIGPTImage1", display_name="OpenAI GPT Image 2", - category="api node/image/OpenAI", + category="partner/image/OpenAI", description="Generates images synchronously via OpenAI's GPT Image endpoint.", is_deprecated=True, inputs=[ @@ -695,7 +695,7 @@ class OpenAIGPTImageNodeV2(IO.ComfyNode): return IO.Schema( node_id="OpenAIGPTImageNodeV2", display_name="OpenAI GPT Image 2", - category="api node/image/OpenAI", + category="partner/image/OpenAI", description="Generates images via OpenAI's GPT Image endpoint.", inputs=[ IO.String.Input( @@ -962,7 +962,7 @@ class OpenAIChatNode(IO.ComfyNode): return IO.Schema( node_id="OpenAIChatNode", display_name="OpenAI ChatGPT", - category="api 
node/text/OpenAI", + category="partner/text/OpenAI", essentials_category="Text Generation", description="Generate text responses from an OpenAI model.", inputs=[ @@ -1201,7 +1201,7 @@ class OpenAIInputFiles(IO.ComfyNode): return IO.Schema( node_id="OpenAIInputFiles", display_name="OpenAI ChatGPT Input Files", - category="api node/text/OpenAI", + category="partner/text/OpenAI", description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.", inputs=[ IO.Combo.Input( @@ -1248,7 +1248,7 @@ class OpenAIChatConfig(IO.ComfyNode): return IO.Schema( node_id="OpenAIChatConfig", display_name="OpenAI ChatGPT Advanced Options", - category="api node/text/OpenAI", + category="partner/text/OpenAI", description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.", inputs=[ IO.Combo.Input( diff --git a/comfy_api_nodes/nodes_openrouter.py b/comfy_api_nodes/nodes_openrouter.py index 031301870..ba98133f0 100644 --- a/comfy_api_nodes/nodes_openrouter.py +++ b/comfy_api_nodes/nodes_openrouter.py @@ -265,7 +265,7 @@ class OpenRouterLLMNode(IO.ComfyNode): return IO.Schema( node_id="OpenRouterLLMNode", display_name="OpenRouter LLM", - category="api node/text/OpenRouter", + category="partner/text/OpenRouter", essentials_category="Text Generation", description=( "Generate text responses through OpenRouter. 
Routes to a curated set of popular " diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index e17a24ae7..4c8b723b9 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -53,7 +53,7 @@ class PixverseTemplateNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTemplateNode", display_name="PixVerse Template", - category="api node/video/PixVerse", + category="partner/video/PixVerse", inputs=[ IO.Combo.Input("template", options=list(pixverse_templates.keys())), ], @@ -74,7 +74,7 @@ class PixverseTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTextToVideoNode", display_name="PixVerse Text to Video", - category="api node/video/PixVerse", + category="partner/video/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.String.Input( @@ -192,7 +192,7 @@ class PixverseImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseImageToVideoNode", display_name="PixVerse Image to Video", - category="api node/video/PixVerse", + category="partner/video/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("image"), @@ -310,7 +310,7 @@ class PixverseTransitionVideoNode(IO.ComfyNode): return IO.Schema( node_id="PixverseTransitionVideoNode", display_name="PixVerse Transition Video", - category="api node/video/PixVerse", + category="partner/video/PixVerse", description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("first_frame"), diff --git a/comfy_api_nodes/nodes_quiver.py b/comfy_api_nodes/nodes_quiver.py index 3269c0afe..34929fa0c 100644 --- a/comfy_api_nodes/nodes_quiver.py +++ b/comfy_api_nodes/nodes_quiver.py @@ -62,7 +62,7 @@ class QuiverTextToSVGNode(IO.ComfyNode): return IO.Schema( node_id="QuiverTextToSVGNode", display_name="Quiver Text to SVG", - category="api node/image/Quiver", + category="partner/image/Quiver", description="Generate an SVG from a text prompt using 
Quiver AI.", inputs=[ IO.String.Input( @@ -177,7 +177,7 @@ class QuiverImageToSVGNode(IO.ComfyNode): return IO.Schema( node_id="QuiverImageToSVGNode", display_name="Quiver Image to SVG", - category="api node/image/Quiver", + category="partner/image/Quiver", description="Vectorize a raster image into SVG using Quiver AI.", inputs=[ IO.Image.Input( diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index c60cfbc4a..c44942f50 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -178,7 +178,7 @@ class RecraftColorRGBNode(IO.ComfyNode): return IO.Schema( node_id="RecraftColorRGB", display_name="Recraft Color RGB", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Create Recraft Color by choosing specific RGB values.", inputs=[ IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."), @@ -204,7 +204,7 @@ class RecraftControlsNode(IO.ComfyNode): return IO.Schema( node_id="RecraftControls", display_name="Recraft Controls", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Create Recraft Controls for customizing Recraft generation.", inputs=[ IO.Custom(RecraftIO.COLOR).Input("colors", optional=True), @@ -228,7 +228,7 @@ class RecraftStyleV3RealisticImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftStyleV3RealisticImage", display_name="Recraft Style - Realistic Image", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -253,7 +253,7 @@ class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3DigitalIllustration", display_name="Recraft Style - Digital Illustration", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Select realistic_image style 
and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -272,7 +272,7 @@ class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3VectorIllustrationNode", display_name="Recraft Style - Realistic Image", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), @@ -291,7 +291,7 @@ class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode): return IO.Schema( node_id="RecraftStyleV3LogoRaster", display_name="Recraft Style - Logo Raster", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Select realistic_image style and optional substyle.", inputs=[ IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)), @@ -308,7 +308,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode): return IO.Schema( node_id="RecraftStyleV3InfiniteStyleLibrary", display_name="Recraft Style - Infinite Style Library", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.", inputs=[ IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."), @@ -331,7 +331,7 @@ class RecraftCreateStyleNode(IO.ComfyNode): return IO.Schema( node_id="RecraftCreateStyleNode", display_name="Recraft Create Style", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Create a custom style from reference images. " "Upload 1-5 images to use as style references. 
" "Total size of all images is limited to 5 MB.", @@ -400,7 +400,7 @@ class RecraftTextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftTextToImageNode", display_name="Recraft Text to Image", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Generates images synchronously based on prompt and resolution.", inputs=[ IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), @@ -512,7 +512,7 @@ class RecraftImageToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftImageToImageNode", display_name="Recraft Image to Image", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Modify image based on prompt and strength.", inputs=[ IO.Image.Input("image"), @@ -630,7 +630,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode): return IO.Schema( node_id="RecraftImageInpaintingNode", display_name="Recraft Image Inpainting", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Modify image based on prompt and mask.", inputs=[ IO.Image.Input("image"), @@ -732,7 +732,7 @@ class RecraftTextToVectorNode(IO.ComfyNode): return IO.Schema( node_id="RecraftTextToVectorNode", display_name="Recraft Text to Vector", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Generates SVG synchronously based on prompt and resolution.", inputs=[ IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True), @@ -832,7 +832,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftVectorizeImageNode", display_name="Recraft Vectorize Image", - category="api node/image/Recraft", + category="partner/image/Recraft", essentials_category="Image Tools", description="Generates SVG synchronously from an input image.", inputs=[ @@ -876,7 +876,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode): return IO.Schema( node_id="RecraftReplaceBackgroundNode", 
display_name="Recraft Replace Background", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Replace background on image, based on provided prompt.", inputs=[ IO.Image.Input("image"), @@ -963,7 +963,7 @@ class RecraftRemoveBackgroundNode(IO.ComfyNode): return IO.Schema( node_id="RecraftRemoveBackgroundNode", display_name="Recraft Remove Background", - category="api node/image/Recraft", + category="partner/image/Recraft", essentials_category="Image Tools", description="Remove background from image, and return processed image and mask.", inputs=[ @@ -1012,7 +1012,7 @@ class RecraftCrispUpscaleNode(IO.ComfyNode): return IO.Schema( node_id="RecraftCrispUpscaleNode", display_name="Recraft Crisp Upscale Image", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Upscale image synchronously.\n" "Enhances a given raster image using ‘crisp upscale’ tool, " "increasing image resolution, making the image sharper and cleaner.", @@ -1058,7 +1058,7 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): return IO.Schema( node_id="RecraftCreativeUpscaleNode", display_name="Recraft Creative Upscale Image", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Upscale image synchronously.\n" "Enhances a given raster image using ‘creative upscale’ tool, " "boosting resolution with a focus on refining small details and faces.", @@ -1086,7 +1086,7 @@ class RecraftV4TextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RecraftV4TextToImageNode", display_name="Recraft V4 Text to Image", - category="api node/image/Recraft", + category="partner/image/Recraft", description="Generates images using Recraft V4 or V4 Pro models.", inputs=[ IO.String.Input( @@ -1210,7 +1210,7 @@ class RecraftV4TextToVectorNode(IO.ComfyNode): return IO.Schema( node_id="RecraftV4TextToVectorNode", display_name="Recraft V4 Text to Vector", - category="api node/image/Recraft", + 
category="partner/image/Recraft", description="Generates SVG using Recraft V4 or V4 Pro models.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_reve.py b/comfy_api_nodes/nodes_reve.py index a87395394..177349a8b 100644 --- a/comfy_api_nodes/nodes_reve.py +++ b/comfy_api_nodes/nodes_reve.py @@ -109,7 +109,7 @@ class ReveImageCreateNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageCreateNode", display_name="Reve Image Create", - category="api node/image/Reve", + category="partner/image/Reve", description="Generate images from text descriptions using Reve.", inputs=[ IO.String.Input( @@ -200,7 +200,7 @@ class ReveImageEditNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageEditNode", display_name="Reve Image Edit", - category="api node/image/Reve", + category="partner/image/Reve", description="Edit images using natural language instructions with Reve.", inputs=[ IO.Image.Input("image", tooltip="The image to edit."), @@ -300,7 +300,7 @@ class ReveImageRemixNode(IO.ComfyNode): return IO.Schema( node_id="ReveImageRemixNode", display_name="Reve Image Remix", - category="api node/image/Reve", + category="partner/image/Reve", description="Combine reference images with text prompts to create new images using Reve.", inputs=[ IO.Autogrow.Input( diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 2df5a3e13..0375a2123 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -230,7 +230,7 @@ class Rodin3D_Regular(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Regular", display_name="Rodin 3D Generate - Regular Generate", - category="api node/3d/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -289,7 +289,7 @@ class Rodin3D_Detail(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Detail", display_name="Rodin 3D Generate - Detail Generate", - category="api node/3d/Rodin", + category="partner/3d/Rodin", 
description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -348,7 +348,7 @@ class Rodin3D_Smooth(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Smooth", display_name="Rodin 3D Generate - Smooth Generate", - category="api node/3d/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -406,7 +406,7 @@ class Rodin3D_Sketch(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Sketch", display_name="Rodin 3D Generate - Sketch Generate", - category="api node/3d/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -468,7 +468,7 @@ class Rodin3D_Gen2(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen2", display_name="Rodin 3D Generate - Gen-2 Generate", - category="api node/3d/Rodin", + category="partner/3d/Rodin", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("Images"), @@ -941,7 +941,7 @@ class Rodin3D_Gen25_Image(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen25_Image", display_name="Rodin 3D Gen-2.5 - Image to 3D", - category="api node/3d/Rodin", + category="partner/3d/Rodin", description=( "Generate a 3D model from 1-5 reference images via Rodin Gen-2.5. " "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." @@ -1035,7 +1035,7 @@ class Rodin3D_Gen25_Text(IO.ComfyNode): return IO.Schema( node_id="Rodin3D_Gen25_Text", display_name="Rodin 3D Gen-2.5 - Text to 3D", - category="api node/3d/Rodin", + category="partner/3d/Rodin", description=( "Generate a 3D model from a text prompt via Rodin Gen-2.5. " "Pick a mode (Fast / Regular / Extreme-High) to tune quality vs. cost." 
diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index 573170ba2..b9c5c81a1 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -140,7 +140,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): return IO.Schema( node_id="RunwayImageToVideoNodeGen3a", display_name="Runway Image to Video (Gen3a Turbo)", - category="api node/video/Runway", + category="partner/video/Runway", description="Generate a video from a single starting frame using Gen3a Turbo model. " "Before diving in, review these best practices to ensure that " "your input selections will set your generation up for success: " @@ -234,7 +234,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): return IO.Schema( node_id="RunwayImageToVideoNodeGen4", display_name="Runway Image to Video (Gen4 Turbo)", - category="api node/video/Runway", + category="partner/video/Runway", description="Generate a video from a single starting frame using Gen4 Turbo model. " "Before diving in, review these best practices to ensure that " "your input selections will set your generation up for success: " @@ -329,7 +329,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="RunwayFirstLastFrameNode", display_name="Runway First-Last-Frame to Video", - category="api node/video/Runway", + category="partner/video/Runway", description="Upload first and last keyframes, draft a prompt, and generate a video. " "More complex transitions, such as cases where the Last frame is completely different " "from the First frame, may benefit from the longer 10s duration. " @@ -440,7 +440,7 @@ class RunwayTextToImageNode(IO.ComfyNode): return IO.Schema( node_id="RunwayTextToImageNode", display_name="Runway Text to Image", - category="api node/image/Runway", + category="partner/image/Runway", description="Generate an image from a text prompt using Runway's Gen 4 model. 
" "You can also include reference image to guide the generation.", inputs=[ diff --git a/comfy_api_nodes/nodes_sonilo.py b/comfy_api_nodes/nodes_sonilo.py index 5518f5902..9ce896ed0 100644 --- a/comfy_api_nodes/nodes_sonilo.py +++ b/comfy_api_nodes/nodes_sonilo.py @@ -34,7 +34,7 @@ class SoniloVideoToMusic(IO.ComfyNode): return IO.Schema( node_id="SoniloVideoToMusic", display_name="Sonilo Video to Music", - category="api node/audio/Sonilo", + category="partner/audio/Sonilo", description="Generate music from video content using Sonilo's AI model. " "Analyzes the video and creates matching music.", inputs=[ @@ -99,7 +99,7 @@ class SoniloTextToMusic(IO.ComfyNode): return IO.Schema( node_id="SoniloTextToMusic", display_name="Sonilo Text to Music", - category="api node/audio/Sonilo", + category="partner/audio/Sonilo", description="Generate music from a text prompt using Sonilo's AI model. " "Leave duration at 0 to let the model infer it from the prompt.", inputs=[ diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index c1d485188..4ff1d649f 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -34,7 +34,7 @@ class OpenAIVideoSora2(IO.ComfyNode): return IO.Schema( node_id="OpenAIVideoSora2", display_name="OpenAI Sora - Video (DEPRECATED)", - category="api node/video/Sora", + category="partner/video/Sora", description=( "OpenAI video and audio generation.\n\n" "DEPRECATION NOTICE: OpenAI will stop serving the Sora v2 API in September 2026. 
" diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 906d8ff35..9eaba173b 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -62,7 +62,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode): return IO.Schema( node_id="StabilityStableImageUltraNode", display_name="Stability AI Stable Image Ultra", - category="api node/image/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.String.Input( @@ -197,7 +197,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): return IO.Schema( node_id="StabilityStableImageSD_3_5Node", display_name="Stability AI Stable Diffusion 3.5 Image", - category="api node/image/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.String.Input( @@ -354,7 +354,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleConservativeNode", display_name="Stability AI Upscale Conservative", - category="api node/image/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -457,7 +457,7 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleCreativeNode", display_name="Stability AI Upscale Creative", - category="api node/image/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -578,7 +578,7 @@ class StabilityUpscaleFastNode(IO.ComfyNode): return IO.Schema( node_id="StabilityUpscaleFastNode", display_name="Stability AI Upscale Fast", - category="api node/image/Stability AI", + category="partner/image/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Image.Input("image"), @@ -630,7 +630,7 @@ class StabilityTextToAudio(IO.ComfyNode): return IO.Schema( node_id="StabilityTextToAudio", display_name="Stability 
AI Text To Audio", - category="api node/audio/Stability AI", + category="partner/audio/Stability AI", essentials_category="Audio", description=cleandoc(cls.__doc__ or ""), inputs=[ @@ -708,7 +708,7 @@ class StabilityAudioToAudio(IO.ComfyNode): return IO.Schema( node_id="StabilityAudioToAudio", display_name="Stability AI Audio To Audio", - category="api node/audio/Stability AI", + category="partner/audio/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Combo.Input( @@ -802,7 +802,7 @@ class StabilityAudioInpaint(IO.ComfyNode): return IO.Schema( node_id="StabilityAudioInpaint", display_name="Stability AI Audio Inpaint", - category="api node/audio/Stability AI", + category="partner/audio/Stability AI", description=cleandoc(cls.__doc__ or ""), inputs=[ IO.Combo.Input( diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py index e79c16d3c..f7ef4cbf6 100644 --- a/comfy_api_nodes/nodes_topaz.py +++ b/comfy_api_nodes/nodes_topaz.py @@ -52,7 +52,7 @@ class TopazImageEnhance(IO.ComfyNode): return IO.Schema( node_id="TopazImageEnhance", display_name="Topaz Image Enhance", - category="api node/image/Topaz", + category="partner/image/Topaz", description="Industry-standard upscaling and image enhancement.", inputs=[ IO.Combo.Input("model", options=["Reimagine"]), @@ -235,7 +235,7 @@ class TopazVideoEnhance(IO.ComfyNode): return IO.Schema( node_id="TopazVideoEnhance", display_name="Topaz Video Enhance (Legacy)", - category="api node/video/Topaz", + category="partner/video/Topaz", description="Breathe new life into video with powerful upscaling and recovery technology.", inputs=[ IO.Video.Input("video"), @@ -475,7 +475,7 @@ class TopazVideoEnhanceV2(IO.ComfyNode): return IO.Schema( node_id="TopazVideoEnhanceV2", display_name="Topaz Video Enhance", - category="api node/video/Topaz", + category="partner/video/Topaz", description="Breathe new life into video with powerful upscaling and recovery technology.", inputs=[ IO.Video.Input("video"), 
diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index d6501dee4..a3f2cb053 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -11,6 +11,9 @@ from comfy_api_nodes.apis.tripo import ( TripoModelVersion, TripoMultiviewToModelRequest, TripoOrientation, + TripoP1ImageToModelRequest, + TripoP1MultiviewToModelRequest, + TripoP1TextToModelRequest, TripoRefineModelRequest, TripoStyle, TripoTaskResponse, @@ -80,7 +83,7 @@ class TripoTextToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoTextToModelNode", display_name="Tripo: Text to Model", - category="api node/3d/Tripo", + category="partner/3d/Tripo", inputs=[ IO.String.Input("prompt", multiline=True), IO.String.Input("negative_prompt", multiline=True, optional=True), @@ -93,10 +96,22 @@ class TripoTextToModelNode(IO.ComfyNode): IO.Int.Input("image_seed", default=42, optional=True, advanced=True), IO.Int.Input("model_seed", default=42, optional=True, advanced=True), IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), + IO.Combo.Input( + "texture_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), IO.Int.Input("face_limit", default=-1, min=-1, max=2000000, optional=True, advanced=True), IO.Boolean.Input("quad", default=False, optional=True, advanced=True), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), + IO.Combo.Input( + "geometry_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), ], outputs=[ IO.String.Output(display_name="model_file"), # for backward compatibility only @@ -195,7 +210,7 @@ class TripoImageToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoImageToModelNode", display_name="Tripo: Image to Model", - 
category="api node/3d/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Image.Input("image"), IO.Combo.Input( @@ -209,16 +224,36 @@ class TripoImageToModelNode(IO.ComfyNode): IO.Boolean.Input("pbr", default=True, optional=True), IO.Int.Input("model_seed", default=42, optional=True, advanced=True), IO.Combo.Input( - "orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True, advanced=True + "orientation", + options=TripoOrientation, + default=TripoOrientation.DEFAULT, + optional=True, + advanced=True, ), IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True + "texture_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), + IO.Combo.Input( + "texture_alignment", + default="original_image", + options=["original_image", "geometry"], + optional=True, + advanced=True, ), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True), IO.Boolean.Input("quad", default=False, optional=True, advanced=True), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), + IO.Combo.Input( + "geometry_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), ], outputs=[ IO.String.Output(display_name="model_file"), # for backward compatibility only @@ -323,7 +358,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode): return IO.Schema( node_id="TripoMultiviewToModelNode", display_name="Tripo: Multiview to Model", - category="api node/3d/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Image.Input("image"), IO.Image.Input("image_left", optional=True), @@ -346,13 +381,35 @@ class 
TripoMultiviewToModelNode(IO.ComfyNode): IO.Boolean.Input("pbr", default=True, optional=True), IO.Int.Input("model_seed", default=42, optional=True, advanced=True), IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True + "texture_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), + IO.Combo.Input( + "texture_alignment", + default="original_image", + options=["original_image", "geometry"], + optional=True, + advanced=True, ), IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True, advanced=True), - IO.Boolean.Input("quad", default=False, optional=True, advanced=True, tooltip="This parameter is deprecated and does nothing."), - IO.Combo.Input("geometry_quality", default="standard", options=["standard", "detailed"], optional=True, advanced=True), + IO.Boolean.Input( + "quad", + default=False, + optional=True, + advanced=True, + tooltip="This parameter is deprecated and does nothing.", + ), + IO.Combo.Input( + "geometry_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), ], outputs=[ IO.String.Output(display_name="model_file"), # for backward compatibility only @@ -461,15 +518,25 @@ class TripoTextureNode(IO.ComfyNode): return IO.Schema( node_id="TripoTextureNode", display_name="Tripo: Texture model", - category="api node/3d/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Custom("MODEL_TASK_ID").Input("model_task_id"), IO.Boolean.Input("texture", default=True, optional=True), IO.Boolean.Input("pbr", default=True, optional=True), IO.Int.Input("texture_seed", default=42, optional=True, advanced=True), - IO.Combo.Input("texture_quality", default="standard", 
options=["standard", "detailed"], optional=True, advanced=True), IO.Combo.Input( - "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True, advanced=True + "texture_quality", + default="standard", + options=["standard", "detailed"], + optional=True, + advanced=True, + ), + IO.Combo.Input( + "texture_alignment", + default="original_image", + options=["original_image", "geometry"], + optional=True, + advanced=True, ), ], outputs=[ @@ -528,7 +595,7 @@ class TripoRefineNode(IO.ComfyNode): return IO.Schema( node_id="TripoRefineNode", display_name="Tripo: Refine Draft model", - category="api node/3d/Tripo", + category="partner/3d/Tripo", description="Refine a draft model created by v1.4 Tripo models only.", inputs=[ IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), @@ -568,7 +635,7 @@ class TripoRigNode(IO.ComfyNode): return IO.Schema( node_id="TripoRigNode", display_name="Tripo: Rig model", - category="api node/3d/Tripo", + category="partner/3d/Tripo", inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], outputs=[ IO.String.Output(display_name="model_file"), # for backward compatibility only @@ -605,7 +672,7 @@ class TripoRetargetNode(IO.ComfyNode): return IO.Schema( node_id="TripoRetargetNode", display_name="Tripo: Retarget rigged model", - category="api node/3d/Tripo", + category="partner/3d/Tripo", inputs=[ IO.Custom("RIG_TASK_ID").Input("original_model_task_id"), IO.Combo.Input( @@ -626,7 +693,7 @@ class TripoRetargetNode(IO.ComfyNode): "preset:hexapod:walk", "preset:octopod:walk", "preset:serpentine:march", - "preset:aquatic:march" + "preset:aquatic:march", ], ), ], @@ -670,7 +737,7 @@ class TripoConversionNode(IO.ComfyNode): return IO.Schema( node_id="TripoConversionNode", display_name="Tripo: Convert model", - category="api node/3d/Tripo", + category="partner/3d/Tripo", inputs=[ 
IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"), IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]), @@ -817,7 +884,7 @@ class TripoConversionNode(IO.ComfyNode): # Parse part_names from comma-separated string to list part_names_list = None if part_names and part_names.strip(): - part_names_list = [name.strip() for name in part_names.split(',') if name.strip()] + part_names_list = [name.strip() for name in part_names.split(",") if name.strip()] response = await sync_op( cls, @@ -848,6 +915,373 @@ class TripoConversionNode(IO.ComfyNode): return await poll_until_finished(cls, response, average_duration=30) +def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str: + return ( + "(" + " $mode := widgets.output_mode;" + ' $detailed := $lookup(widgets, "output_mode.texture_quality") = "detailed";' + f' $credits := $mode = "geometry only" ? {geometry_credits} : ($detailed ? {detailed_credits} : {textured_credits});' + ' {"type":"usd","usd": $credits * 0.01, "format": {"approximate": true}}' + ")" + ) + + +def _p1_textured_inputs(*, include_image_alignment: bool) -> list: + """Inputs shown inside the 'Textured' branch of the P1 output_mode DynamicCombo.""" + inputs: list = [ + IO.Boolean.Input("pbr", default=True, tooltip="Include PBR maps. When on, base texture is forced on too."), + IO.Combo.Input("texture_quality", options=["standard", "detailed"], default="standard"), + ] + if include_image_alignment: + inputs.extend( + [ + IO.Combo.Input( + "texture_alignment", + options=["original_image", "geometry"], + default="original_image", + tooltip="Prioritize visual fidelity to the source image, or alignment to the mesh geometry.", + ), + IO.Combo.Input( + "orientation", + options=["default", "align_image"], + default="default", + tooltip="Rotate the output to match the source image. 
Only applies when textured.", + ), + ] + ) + inputs.append(IO.Int.Input("texture_seed", default=42, advanced=True)) + return inputs + + +def _build_p1_output_mode(*, include_image_alignment: bool) -> IO.DynamicCombo.Input: + return IO.DynamicCombo.Input( + "output_mode", + options=[ + IO.DynamicCombo.Option("Geometry only", []), + IO.DynamicCombo.Option("Textured", _p1_textured_inputs(include_image_alignment=include_image_alignment)), + ], + tooltip='"Geometry only" returns an untextured mesh. "Textured" adds color/PBR maps.', + ) + + +def _resolve_p1_texture_fields(output_mode: dict) -> dict: + """Translate the output_mode DynamicCombo payload into P1 request fields. + + pbr=true forces texture=true server-side, but we send both explicitly so the + intent is visible in the request body and logs. + """ + mode = output_mode["output_mode"] + if mode == "Geometry only": + return {"texture": False, "pbr": False} + out = { + "texture": True, + "pbr": bool(output_mode.get("pbr", True)), + "texture_quality": output_mode.get("texture_quality", "standard"), + "texture_seed": output_mode.get("texture_seed"), + } + if "texture_alignment" in output_mode: + out["texture_alignment"] = output_mode["texture_alignment"] + if "orientation" in output_mode: + out["orientation"] = output_mode["orientation"] + return out + + +def _p1_common_inputs() -> list: + """Inputs shared by all P1 nodes (placed after output_mode).""" + return [ + IO.Int.Input( + "face_limit", + default=-1, + min=-1, + max=20000, + optional=True, + advanced=True, + tooltip="Target face count, 48-20000. -1 lets Tripo pick adaptively.", + ), + IO.Int.Input("model_seed", default=42, optional=True, advanced=True), + IO.Boolean.Input( + "auto_size", + default=False, + optional=True, + advanced=True, + tooltip="Scale the output to approximate real-world meters.", + ), + IO.Boolean.Input( + "export_uv", + default=True, + optional=True, + advanced=True, + tooltip="UV unwrap during generation. 
Turn off for faster geometry-only runs.", + ), + IO.Boolean.Input( + "compress_geometry", + default=False, + optional=True, + advanced=True, + tooltip="Apply geometry-based compression. Decompress before editing.", + ), + ] + + +def _build_p1_request_kwargs( + *, + output_mode: dict, + face_limit: int, + model_seed: int, + auto_size: bool, + export_uv: bool, + compress_geometry: bool, +) -> dict: + """Common P1 request fields shared by all three node types.""" + kwargs: dict = { + "model_seed": model_seed, + "face_limit": face_limit if face_limit != -1 else None, + "auto_size": auto_size, + "export_uv": export_uv, + "compress": "geometry" if compress_geometry else None, + } + kwargs.update(_resolve_p1_texture_fields(output_mode)) + return kwargs + + +class TripoP1TextToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoP1TextToModelNode", + display_name="Tripo P1: Text to Model", + category="partner/3d/Tripo", + description="Tripo P1 text-to-3D. 
Optimized for low-poly, game-ready meshes with stable topology.", + inputs=[ + IO.String.Input("prompt", multiline=True, tooltip="Up to 1024 characters."), + IO.String.Input("negative_prompt", multiline=True, optional=True, tooltip="Up to 255 characters."), + _build_p1_output_mode(include_image_alignment=False), + IO.Int.Input("image_seed", default=42, optional=True, advanced=True), + *_p1_common_inputs(), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]), + expr=_p1_price_expr(geometry_credits=30, textured_credits=40, detailed_credits=50), + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + output_mode: dict, + negative_prompt: str | None = None, + image_seed: int | None = None, + face_limit: int = -1, + model_seed: int | None = None, + auto_size: bool = False, + export_uv: bool = True, + compress_geometry: bool = False, + ) -> IO.NodeOutput: + if not prompt: + raise RuntimeError("Prompt is required") + common = _build_p1_request_kwargs( + output_mode=output_mode, + face_limit=face_limit, + model_seed=model_seed, + auto_size=auto_size, + export_uv=export_uv, + compress_geometry=compress_geometry, + ) + request = TripoP1TextToModelRequest( + prompt=prompt, + negative_prompt=negative_prompt or None, + image_seed=image_seed, + **common, + ) + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=request, + ) + return await poll_until_finished(cls, response, average_duration=60) + + +class TripoP1ImageToModelNode(IO.ComfyNode): + + @classmethod + def 
define_schema(cls): + return IO.Schema( + node_id="TripoP1ImageToModelNode", + display_name="Tripo P1: Image to Model", + category="partner/3d/Tripo", + description="Tripo P1 image-to-3D. Optimized for low-poly, game-ready meshes.", + inputs=[ + IO.Image.Input("image"), + _build_p1_output_mode(include_image_alignment=True), + IO.Boolean.Input( + "enable_image_autofix", + default=False, + optional=True, + advanced=True, + tooltip="Pre-process the input image for better generation quality.", + ), + *_p1_common_inputs(), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]), + expr=_p1_price_expr(geometry_credits=40, textured_credits=50, detailed_credits=60), + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + output_mode: dict, + enable_image_autofix: bool = False, + face_limit: int = -1, + model_seed: int | None = None, + auto_size: bool = False, + export_uv: bool = True, + compress_geometry: bool = False, + ) -> IO.NodeOutput: + if image is None: + raise RuntimeError("Image is required") + tripo_file = TripoFileReference( + root=TripoUrlReference( + url=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], + type="jpeg", + ) + ) + common = _build_p1_request_kwargs( + output_mode=output_mode, + face_limit=face_limit, + model_seed=model_seed, + auto_size=auto_size, + export_uv=export_uv, + compress_geometry=compress_geometry, + ) + request = TripoP1ImageToModelRequest( + file=tripo_file, + enable_image_autofix=enable_image_autofix, + **common, + ) + response = await sync_op( + cls, + 
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=request, + ) + return await poll_until_finished(cls, response, average_duration=60) + + +class TripoP1MultiviewToModelNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoP1MultiviewToModelNode", + display_name="Tripo P1: Multiview to Model", + category="partner/3d/Tripo", + description="Tripo P1 multiview-to-3D from 2-4 reference images in [front, left, back, right] order. " + "Front is required; any combination of the other three may be omitted.", + inputs=[ + IO.Image.Input("image", tooltip="Front view (0°). Required."), + IO.Image.Input( + "image_left", + optional=True, + tooltip="Left view (90°), i.e. the subject's left side.", + ), + IO.Image.Input("image_back", optional=True, tooltip="Back view (180°)."), + IO.Image.Input( + "image_right", + optional=True, + tooltip="Right view (270°), i.e. the subject's right side.", + ), + _build_p1_output_mode(include_image_alignment=True), + *_p1_common_inputs(), + ], + outputs=[ + IO.String.Output(display_name="model_file"), # for backward compatibility only + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + IO.File3DGLB.Output(display_name="GLB"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + depends_on=IO.PriceBadgeDepends(widgets=["output_mode", "output_mode.texture_quality"]), + expr=_p1_price_expr(geometry_credits=40, textured_credits=50, detailed_credits=60), + ), + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + output_mode: dict, + image_left: Input.Image | None = None, + image_back: Input.Image | None = None, + image_right: Input.Image | None = None, + face_limit: int = -1, + model_seed: int | None = None, + auto_size: bool = False, + export_uv: bool = True, + compress_geometry: bool = False, + ) -> 
IO.NodeOutput: + views = [image, image_left, image_back, image_right] + if sum(1 for v in views if v is not None) < 2: + raise RuntimeError("Tripo P1 multiview requires at least 2 images (front plus one of left/back/right).") + + files: list[TripoFileReference] = [] + for view in views: + if view is None: + files.append(TripoFileReference(root=TripoFileEmptyReference())) + continue + url = (await upload_images_to_comfyapi(cls, view, max_images=1))[0] + files.append(TripoFileReference(root=TripoUrlReference(url=url, type="jpeg"))) + + common = _build_p1_request_kwargs( + output_mode=output_mode, + face_limit=face_limit, + model_seed=model_seed, + auto_size=auto_size, + export_uv=export_uv, + compress_geometry=compress_geometry, + ) + request = TripoP1MultiviewToModelRequest(files=files, **common) + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=request, + ) + return await poll_until_finished(cls, response, average_duration=80) + + class TripoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -855,6 +1289,9 @@ class TripoExtension(ComfyExtension): TripoTextToModelNode, TripoImageToModelNode, TripoMultiviewToModelNode, + TripoP1TextToModelNode, + TripoP1ImageToModelNode, + TripoP1MultiviewToModelNode, TripoTextureNode, TripoRefineNode, TripoRigNode, diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 2ff75d9b2..ed34e928b 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -45,7 +45,7 @@ class VeoVideoGenerationNode(IO.ComfyNode): return IO.Schema( node_id="VeoVideoGenerationNode", display_name="Google Veo 2 Video Generation", - category="api node/video/Veo", + category="partner/video/Veo", description="Generates videos from text prompts using Google's Veo 2 API", inputs=[ IO.String.Input( @@ -256,7 +256,7 @@ class Veo3VideoGenerationNode(IO.ComfyNode): return 
IO.Schema( node_id="Veo3VideoGenerationNode", display_name="Google Veo 3 Video Generation", - category="api node/video/Veo", + category="partner/video/Veo", description="Generates videos from text prompts using Google's Veo 3 API", inputs=[ IO.String.Input( @@ -468,7 +468,7 @@ class Veo3FirstLastFrameNode(IO.ComfyNode): return IO.Schema( node_id="Veo3FirstLastFrameNode", display_name="Google Veo 3 First-Last-Frame to Video", - category="api node/video/Veo", + category="partner/video/Veo", description="Generate video using prompt and first and last frames.", inputs=[ IO.String.Input( diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 8d90cefeb..8c5a43f5b 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -71,7 +71,7 @@ class ViduTextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduTextToVideoNode", display_name="Vidu Text To Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate video from a text prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -169,7 +169,7 @@ class ViduImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduImageToVideoNode", display_name="Vidu Image To Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate video from image and optional prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -273,7 +273,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduReferenceVideoNode", display_name="Vidu Reference To Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate video from multiple images and a prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -388,7 +388,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduStartEndToVideoNode", display_name="Vidu Start 
End To Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate a video from start and end frames and a prompt", inputs=[ IO.Combo.Input("model", options=["viduq1"], tooltip="Model name"), @@ -492,7 +492,7 @@ class Vidu2TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2TextToVideoNode", display_name="Vidu2 Text-to-Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate video from a text prompt", inputs=[ IO.Combo.Input("model", options=["viduq2"]), @@ -584,7 +584,7 @@ class Vidu2ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2ImageToVideoNode", display_name="Vidu2 Image-to-Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate a video from an image and an optional prompt.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), @@ -714,7 +714,7 @@ class Vidu2ReferenceVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2ReferenceVideoNode", display_name="Vidu2 Reference-to-Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate a video from multiple reference images and a prompt.", inputs=[ IO.Combo.Input("model", options=["viduq2"]), @@ -849,7 +849,7 @@ class Vidu2StartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu2StartEndToVideoNode", display_name="Vidu2 Start/End Frame-to-Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate a video from a start frame, an end frame, and a prompt.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro-fast", "viduq2-pro", "viduq2-turbo"]), @@ -969,7 +969,7 @@ class ViduExtendVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduExtendVideoNode", display_name="Vidu Video Extension", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Extend an existing video 
by generating additional frames.", inputs=[ IO.DynamicCombo.Input( @@ -1138,7 +1138,7 @@ class ViduMultiFrameVideoNode(IO.ComfyNode): return IO.Schema( node_id="ViduMultiFrameVideoNode", display_name="Vidu Multi-Frame Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate a video with multiple keyframe transitions.", inputs=[ IO.Combo.Input("model", options=["viduq2-pro", "viduq2-turbo"]), @@ -1284,7 +1284,7 @@ class Vidu3TextToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3TextToVideoNode", display_name="Vidu Q3 Text-to-Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate video from a text prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1429,7 +1429,7 @@ class Vidu3ImageToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3ImageToVideoNode", display_name="Vidu Q3 Image-to-Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate a video from an image and an optional prompt.", inputs=[ IO.DynamicCombo.Input( @@ -1571,7 +1571,7 @@ class Vidu3StartEndToVideoNode(IO.ComfyNode): return IO.Schema( node_id="Vidu3StartEndToVideoNode", display_name="Vidu Q3 Start/End Frame-to-Video Generation", - category="api node/video/Vidu", + category="partner/video/Vidu", description="Generate a video from a start frame, an end frame, and a prompt.", inputs=[ IO.DynamicCombo.Input( diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index 68061bb5c..b7b97d70f 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -61,7 +61,7 @@ class WanTextToImageApi(IO.ComfyNode): return IO.Schema( node_id="WanTextToImageApi", display_name="Wan Text to Image", - category="api node/image/Wan", + category="partner/image/Wan", description="Generates an image based on a text prompt.", inputs=[ IO.Combo.Input( @@ -184,7 +184,7 @@ class WanImageToImageApi(IO.ComfyNode): return IO.Schema( 
node_id="WanImageToImageApi", display_name="Wan Image to Image", - category="api node/image/Wan", + category="partner/image/Wan", description="Generates an image from one or two input images and a text prompt. " "The output image is currently fixed at 1.6 MP, and its aspect ratio matches the input image(s).", inputs=[ @@ -312,7 +312,7 @@ class WanTextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanTextToVideoApi", display_name="Wan Text to Video", - category="api node/video/Wan", + category="partner/video/Wan", description="Generates a video based on a text prompt.", inputs=[ IO.Combo.Input( @@ -495,7 +495,7 @@ class WanImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanImageToVideoApi", display_name="Wan Image to Video", - category="api node/video/Wan", + category="partner/video/Wan", description="Generates a video from the first frame and a text prompt.", inputs=[ IO.Combo.Input( @@ -674,7 +674,7 @@ class WanReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="WanReferenceVideoApi", display_name="Wan Reference to Video", - category="api node/video/Wan", + category="partner/video/Wan", description="Use the character and voice from input videos, combined with a prompt, " "to generate a new video that maintains character consistency.", inputs=[ @@ -828,7 +828,7 @@ class Wan2TextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2TextToVideoApi", display_name="Wan 2.7 Text to Video", - category="api node/video/Wan", + category="partner/video/Wan", description="Generates a video based on a text prompt using the Wan 2.7 model.", inputs=[ IO.DynamicCombo.Input( @@ -981,7 +981,7 @@ class Wan2ImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2ImageToVideoApi", display_name="Wan 2.7 Image to Video", - category="api node/video/Wan", + category="partner/video/Wan", description="Generate a video from a first-frame image, with optional last-frame image and audio.", inputs=[ IO.DynamicCombo.Input( @@ -1152,7 +1152,7 @@ class 
Wan2VideoContinuationApi(IO.ComfyNode): return IO.Schema( node_id="Wan2VideoContinuationApi", display_name="Wan 2.7 Video Continuation", - category="api node/video/Wan", + category="partner/video/Wan", description="Continue a video from where it left off, with optional last-frame control.", inputs=[ IO.DynamicCombo.Input( @@ -1319,7 +1319,7 @@ class Wan2VideoEditApi(IO.ComfyNode): return IO.Schema( node_id="Wan2VideoEditApi", display_name="Wan 2.7 Video Edit", - category="api node/video/Wan", + category="partner/video/Wan", description="Edit a video using text instructions, reference images, or style transfer.", inputs=[ IO.DynamicCombo.Input( @@ -1477,7 +1477,7 @@ class Wan2ReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="Wan2ReferenceVideoApi", display_name="Wan 2.7 Reference to Video", - category="api node/video/Wan", + category="partner/video/Wan", description="Generate a video featuring a person or object from reference materials. " "Supports single-character performances and multi-character interactions.", inputs=[ @@ -1651,7 +1651,7 @@ class HappyHorseTextToVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseTextToVideoApi", display_name="HappyHorse Text to Video", - category="api node/video/Wan", + category="partner/video/Wan", description="Generates a video based on a text prompt using the HappyHorse model.", inputs=[ IO.DynamicCombo.Input( @@ -1775,7 +1775,7 @@ class HappyHorseImageToVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseImageToVideoApi", display_name="HappyHorse Image to Video", - category="api node/video/Wan", + category="partner/video/Wan", description="Generate a video from a first-frame image using the HappyHorse model.", inputs=[ IO.DynamicCombo.Input( @@ -1905,7 +1905,7 @@ class HappyHorseVideoEditApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseVideoEditApi", display_name="HappyHorse Video Edit", - category="api node/video/Wan", + category="partner/video/Wan", description="Edit a video using text 
instructions or reference images with the HappyHorse model. " "Output duration is 3-15s and matches the input video; inputs longer than 15s are truncated.", inputs=[ @@ -2046,7 +2046,7 @@ class HappyHorseReferenceVideoApi(IO.ComfyNode): return IO.Schema( node_id="HappyHorseReferenceVideoApi", display_name="HappyHorse Reference to Video", - category="api node/video/Wan", + category="partner/video/Wan", description="Generate a video featuring a person or object from reference materials with the HappyHorse " "model. Supports single-character performances and multi-character interactions.", inputs=[ diff --git a/comfy_api_nodes/nodes_wavespeed.py b/comfy_api_nodes/nodes_wavespeed.py index 65e45f60a..5839f9d37 100644 --- a/comfy_api_nodes/nodes_wavespeed.py +++ b/comfy_api_nodes/nodes_wavespeed.py @@ -27,7 +27,7 @@ class WavespeedFlashVSRNode(IO.ComfyNode): return IO.Schema( node_id="WavespeedFlashVSRNode", display_name="FlashVSR Video Upscale", - category="api node/video/WaveSpeed", + category="partner/video/WaveSpeed", description="Fast, high-quality video upscaler that " "boosts resolution and restores clarity for low-resolution or blurry footage.", inputs=[ @@ -98,7 +98,7 @@ class WavespeedImageUpscaleNode(IO.ComfyNode): return IO.Schema( node_id="WavespeedImageUpscaleNode", display_name="WaveSpeed Image Upscale", - category="api node/image/WaveSpeed", + category="partner/image/WaveSpeed", description="Boost image resolution and quality, upscaling photos to 4K or 8K for sharp, detailed results.", inputs=[ IO.Combo.Input("model", options=["SeedVR2", "Ultimate"]), diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py index 052301c33..57c501724 100644 --- a/comfy_api_nodes/util/client.py +++ b/comfy_api_nodes/util/client.py @@ -86,7 +86,7 @@ class _PollUIState: _RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] 
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] -QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait", "in_queue"] async def sync_op( diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py index 5738df57f..a1b5d599c 100644 --- a/comfy_api_nodes/util/conversions.py +++ b/comfy_api_nodes/util/conversions.py @@ -469,6 +469,11 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input input_container = None output_container = None + # get_stream_source() is untrimmed, so apply the trim window in this same pass. + # start_time is normalized (>= 0); duration == 0 means "until the end". + start_time, duration = video.get_active_trim_window() + trimming = bool(start_time or duration) + try: input_source = video.get_stream_source() input_container = av.open(input_source, mode="r") @@ -487,16 +492,45 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input audio_stream.layout = stream.layout break + in_video = input_container.streams.video[0] + start_pts = int(start_time / in_video.time_base) if trimming else 0 + end_pts = int((start_time + duration) / in_video.time_base) if duration else None + if start_pts: + input_container.seek(start_pts, stream=in_video) + + encoded = 0 for frame in input_container.decode(video=0): + if trimming: + if frame.pts is None or frame.pts < start_pts: + continue + if end_pts is not None and frame.pts >= end_pts: + break frame = frame.reformat(width=out_w, height=out_h, format="yuv420p") + # Re-wrap as a fresh frame: dropping irregular source timestamps (VFR/AVI/GIF/...) + # lets the encoder assign clean ones and avoids mp4 muxer errors. 
+ frame = av.VideoFrame.from_ndarray(frame.to_ndarray(format="yuv420p"), format="yuv420p") for packet in video_stream.encode(frame): output_container.mux(packet) + encoded += 1 for packet in video_stream.encode(): output_container.mux(packet) + if encoded == 0: + raise ValueError( + f"resize produced no frames (start_time={start_time}, duration={duration} " + "selected nothing from the source)" + ) + if audio_stream is not None: input_container.seek(0) for audio_frame in input_container.decode(audio=0): + if trimming: + if audio_frame.time is None or audio_frame.time < start_time: + continue + if duration and audio_frame.time > start_time + duration: + break + # Carry odd audio time bases the mp4 muxer rejects; reset pts, encoder assigns clean ones (MP3-in-AVI) + audio_frame.pts = None for packet in audio_stream.encode(audio_frame): output_container.mux(packet) for packet in audio_stream.encode(): diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py index 247d9ae8a..044077b18 100644 --- a/comfy_extras/nodes_ace.py +++ b/comfy_extras/nodes_ace.py @@ -11,7 +11,7 @@ class TextEncodeAceStepAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TextEncodeAceStepAudio", - category="conditioning", + category="model/conditioning", inputs=[ IO.Clip.Input("clip"), IO.String.Input("tags", multiline=True, dynamic_prompts=True), @@ -33,7 +33,7 @@ class TextEncodeAceStepAudio15(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="TextEncodeAceStepAudio1.5", - category="conditioning", + category="model/conditioning", inputs=[ IO.Clip.Input("clip"), IO.String.Input("tags", multiline=True, dynamic_prompts=True), @@ -67,7 +67,7 @@ class EmptyAceStepLatentAudio(IO.ComfyNode): return IO.Schema( node_id="EmptyAceStepLatentAudio", display_name="Empty Ace Step 1.0 Latent Audio", - category="latent/audio", + category="model/latent/audio", inputs=[ IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.1), IO.Int.Input( @@ -90,7 
+90,7 @@ class EmptyAceStep15LatentAudio(IO.ComfyNode): return IO.Schema( node_id="EmptyAceStep1.5LatentAudio", display_name="Empty Ace Step 1.5 Latent Audio", - category="latent/audio", + category="model/latent/audio", inputs=[ IO.Float.Input("seconds", default=120.0, min=1.0, max=1000.0, step=0.01), IO.Int.Input( diff --git a/comfy_extras/nodes_advanced_samplers.py b/comfy_extras/nodes_advanced_samplers.py index 20717ca38..77a561e30 100644 --- a/comfy_extras/nodes_advanced_samplers.py +++ b/comfy_extras/nodes_advanced_samplers.py @@ -45,7 +45,7 @@ class SamplerLCMUpscale(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="SamplerLCMUpscale", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("scale_ratio", default=1.0, min=0.1, max=20.0, step=0.01, advanced=True), io.Int.Input("scale_steps", default=-1, min=-1, max=1000, step=1, advanced=True), @@ -91,7 +91,7 @@ class SamplerLCM(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="SamplerLCM", - category="sampling/samplers", + category="model/sampling/samplers", description=("LCM sampler with tunable per-step noise. 
s_noise is a multiplier on the model's training noise scale"), inputs=[ io.Float.Input("s_noise", default=1.0, min=0.0, max=64.0, step=0.01, diff --git a/comfy_extras/nodes_align_your_steps.py b/comfy_extras/nodes_align_your_steps.py index 307f41337..f89a809bb 100644 --- a/comfy_extras/nodes_align_your_steps.py +++ b/comfy_extras/nodes_align_your_steps.py @@ -29,7 +29,7 @@ class AlignYourStepsScheduler(io.ComfyNode): return io.Schema( node_id="AlignYourStepsScheduler", search_aliases=["AYS scheduler"], - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Combo.Input("model_type", options=["SD1", "SDXL", "SVD"]), io.Int.Input("steps", default=10, min=1, max=10000), diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py index fd561d360..4a352038a 100644 --- a/comfy_extras/nodes_apg.py +++ b/comfy_extras/nodes_apg.py @@ -16,7 +16,7 @@ class APG(io.ComfyNode): return io.Schema( node_id="APG", display_name="Adaptive Projected Guidance", - category="sampling/custom_sampling", + category="model/sampling/custom_sampling", inputs=[ io.Model.Input("model"), io.Float.Input( diff --git a/comfy_extras/nodes_ar_video.py b/comfy_extras/nodes_ar_video.py index 1a15facfa..c22359eb2 100644 --- a/comfy_extras/nodes_ar_video.py +++ b/comfy_extras/nodes_ar_video.py @@ -19,7 +19,7 @@ class EmptyARVideoLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptyARVideoLatent", - category="latent/video", + category="model/latent/video", inputs=[ io.Int.Input("width", default=832, min=16, max=8192, step=16), io.Int.Input("height", default=480, min=16, max=8192, step=16), @@ -53,7 +53,7 @@ class SamplerARVideo(io.ComfyNode): return io.Schema( node_id="SamplerARVideo", display_name="Sampler AR Video", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Int.Input( "num_frame_per_block", @@ -85,7 +85,7 @@ class ARVideoI2V(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ARVideoI2V", - 
category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Model.Input("model"), io.Vae.Input("vae"), diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index f09a8a874..1dc97ecd7 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -16,7 +16,7 @@ class EmptyLatentAudio(IO.ComfyNode): return IO.Schema( node_id="EmptyLatentAudio", display_name="Empty Latent Audio", - category="latent/audio", + category="model/latent/audio", essentials_category="Audio", inputs=[ IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), @@ -41,7 +41,7 @@ class ConditioningStableAudio(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="ConditioningStableAudio", - category="conditioning", + category="model/conditioning", inputs=[ IO.Conditioning.Input("positive"), IO.Conditioning.Input("negative"), @@ -70,7 +70,7 @@ class VAEEncodeAudio(IO.ComfyNode): node_id="VAEEncodeAudio", search_aliases=["audio to latent"], display_name="VAE Encode Audio", - category="latent/audio", + category="model/latent/audio", inputs=[ IO.Audio.Input("audio"), IO.Vae.Input("vae"), @@ -115,7 +115,7 @@ class VAEDecodeAudio(IO.ComfyNode): node_id="VAEDecodeAudio", search_aliases=["latent to audio"], display_name="VAE Decode Audio", - category="latent/audio", + category="model/latent/audio", inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), @@ -137,7 +137,7 @@ class VAEDecodeAudioTiled(IO.ComfyNode): node_id="VAEDecodeAudioTiled", search_aliases=["latent to audio"], display_name="VAE Decode Audio (Tiled)", - category="latent/audio", + category="model/latent/audio", inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), @@ -158,7 +158,7 @@ class SaveAudio(IO.ComfyNode): return IO.Schema( node_id="SaveAudio", search_aliases=["export flac"], - display_name="Save Audio (FLAC)", + display_name="Save Audio (FLAC) (Deprecated)", category="audio", essentials_category="Audio", inputs=[ @@ -167,6 +167,7 
@@ class SaveAudio(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + is_deprecated=True, ) @classmethod @@ -186,7 +187,7 @@ class SaveAudioMP3(IO.ComfyNode): return IO.Schema( node_id="SaveAudioMP3", search_aliases=["export mp3"], - display_name="Save Audio (MP3)", + display_name="Save Audio (MP3) (Deprecated)", category="audio", essentials_category="Audio", inputs=[ @@ -196,6 +197,7 @@ class SaveAudioMP3(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + is_deprecated=True, ) @classmethod @@ -217,7 +219,7 @@ class SaveAudioOpus(IO.ComfyNode): return IO.Schema( node_id="SaveAudioOpus", search_aliases=["export opus"], - display_name="Save Audio (Opus)", + display_name="Save Audio (Opus) (Deprecated)", category="audio", inputs=[ IO.Audio.Input("audio"), @@ -226,6 +228,7 @@ class SaveAudioOpus(IO.ComfyNode): ], hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], is_output_node=True, + is_deprecated=True, ) @classmethod @@ -241,6 +244,54 @@ class SaveAudioOpus(IO.ComfyNode): save_opus = execute # TODO: remove +class SaveAudioAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioAdvanced", + search_aliases=["save audio", "export audio", "output audio", "write audio", "flac", "mp3", "opus"], + display_name="Save Audio (Advanced)", + description="Saves the input audio to your ComfyUI output directory.", + category="audio", + inputs=[ + IO.Audio.Input("audio", tooltip="The audio to save."), + IO.String.Input( + "filename_prefix", + default="audio/ComfyUI", + tooltip=( + "The prefix for the file to save. May include formatting tokens " + "such as %date:yyyy-MM-dd%." 
+ ), + ), + IO.DynamicCombo.Input( + "format", + options=[ + IO.DynamicCombo.Option("flac", []), + IO.DynamicCombo.Option("mp3", [ + IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), + ]), + IO.DynamicCombo.Option("opus", [ + IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), + ]), + ], + tooltip="The file format in which to save the audio.", + ), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) + + @classmethod + def execute(cls, audio, filename_prefix: str, format: dict) -> IO.NodeOutput: + file_format = format.get("format", None) + quality = format.get("quality", None) + if quality: + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format, quality=quality) + else: + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format) + return IO.NodeOutput(ui=ui) + + class PreviewAudio(IO.ComfyNode): @classmethod def define_schema(cls): @@ -822,6 +873,7 @@ class AudioExtension(ComfyExtension): SaveAudio, SaveAudioMP3, SaveAudioOpus, + SaveAudioAdvanced, LoadAudio, PreviewAudio, ConditioningStableAudio, diff --git a/comfy_extras/nodes_audio_encoder.py b/comfy_extras/nodes_audio_encoder.py index 6a85da89b..2ae30d321 100644 --- a/comfy_extras/nodes_audio_encoder.py +++ b/comfy_extras/nodes_audio_encoder.py @@ -11,7 +11,7 @@ class AudioEncoderLoader(io.ComfyNode): return io.Schema( node_id="AudioEncoderLoader", display_name="Load Audio Encoder", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input( "audio_encoder_name", @@ -36,7 +36,7 @@ class AudioEncoderEncode(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="AudioEncoderEncode", - category="conditioning", + category="model/conditioning", inputs=[ io.AudioEncoder.Input("audio_encoder"), io.Audio.Input("audio"), diff --git a/comfy_extras/nodes_bg_removal.py 
b/comfy_extras/nodes_bg_removal.py index 793fd802b..9dc9ad854 100644 --- a/comfy_extras/nodes_bg_removal.py +++ b/comfy_extras/nodes_bg_removal.py @@ -11,7 +11,7 @@ class LoadBackgroundRemovalModel(IO.ComfyNode): return IO.Schema( node_id="LoadBackgroundRemovalModel", display_name="Load Background Removal Model", - category="loaders", + category="model/loaders", inputs=[ IO.Combo.Input("bg_removal_name", options=sorted(files), tooltip="The model used to remove backgrounds from images"), ], diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py index 34b78e81b..13a1448f4 100644 --- a/comfy_extras/nodes_camera_trajectory.py +++ b/comfy_extras/nodes_camera_trajectory.py @@ -153,7 +153,7 @@ class WanCameraEmbedding(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanCameraEmbedding", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Combo.Input( "camera_pose", diff --git a/comfy_extras/nodes_chroma_radiance.py b/comfy_extras/nodes_chroma_radiance.py index 509436062..a4f673001 100644 --- a/comfy_extras/nodes_chroma_radiance.py +++ b/comfy_extras/nodes_chroma_radiance.py @@ -13,7 +13,7 @@ class EmptyChromaRadianceLatentImage(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="EmptyChromaRadianceLatentImage", - category="latent/chroma_radiance", + category="model/latent/chroma_radiance", inputs=[ io.Int.Input(id="width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input(id="height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -33,7 +33,7 @@ class ChromaRadianceOptions(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="ChromaRadianceOptions", - category="model_patches/chroma_radiance", + category="model/patch/chroma_radiance", description="Allows setting advanced options for the Chroma Radiance model.", inputs=[ io.Model.Input(id="model"), @@ -65,6 +65,12 @@ class 
ChromaRadianceOptions(io.ComfyNode): tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).", advanced=True, ), + io.Boolean.Input( + id="force_sequential_txt_ids", + default=False, + tooltip="Force usage of sequential text token IDs instead of zeroes. Should be used for checkpoints from 2026-05-22 to 2026-06-01 that are trained in this way but do not contain the __sequential__ key in the state dict.", + advanced=True, + ), ], outputs=[io.Model.Output()], ) @@ -78,11 +84,15 @@ class ChromaRadianceOptions(io.ComfyNode): start_sigma: float, end_sigma: float, nerf_tile_size: int, + force_sequential_txt_ids: bool, ) -> io.NodeOutput: radiance_options = {} if nerf_tile_size >= 0: radiance_options["nerf_tile_size"] = nerf_tile_size + if force_sequential_txt_ids: + radiance_options["use_sequential_txt_ids"] = True + if not radiance_options: return io.NodeOutput(model) diff --git a/comfy_extras/nodes_color.py b/comfy_extras/nodes_color.py index 80ba121cd..688254e4e 100644 --- a/comfy_extras/nodes_color.py +++ b/comfy_extras/nodes_color.py @@ -7,29 +7,29 @@ class ColorToRGBInt(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="ColorToRGBInt", - display_name="Color to RGB Int", - category="utils", - description="Convert a color to a RGB integer value.", + display_name="Color Picker", + category="utilities", + description="Return a color RGB integer value and hexadecimal representation.", inputs=[ io.Color.Input("color"), ], outputs=[ io.Int.Output(display_name="rgb_int"), + io.Color.Output(display_name="hex") ], ) @classmethod - def execute( - cls, - color: str, - ) -> io.NodeOutput: + def execute(cls, color: str) -> io.NodeOutput: # expect format #RRGGBB if len(color) != 7 or color[0] != "#": raise ValueError("Color must be in format #RRGGBB") r = int(color[1:3], 16) g = int(color[3:5], 16) b = int(color[5:7], 16) - return io.NodeOutput(r * 256 * 256 + g * 256 
+ b) + + rgb_int = r * 256 * 256 + g * 256 + b + return io.NodeOutput(rgb_int, color) class ColorExtension(ComfyExtension): diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 24729c3a7..d9e32b9d9 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -9,7 +9,7 @@ class ContextWindowsManualNode(io.ComfyNode): return io.Schema( node_id="ContextWindowsManual", display_name="Context Windows (Manual)", - category="model_patches", + category="model/patch", description="Manually set context windows.", inputs=[ io.Model.Input("model", tooltip="The model to apply context windows to during sampling."), diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py index 847cb0bdf..17d965405 100644 --- a/comfy_extras/nodes_controlnet.py +++ b/comfy_extras/nodes_controlnet.py @@ -9,7 +9,7 @@ class SetUnionControlNetType(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SetUnionControlNetType", - category="conditioning/controlnet", + category="model/conditioning/controlnet", inputs=[ io.ControlNet.Input("control_net"), io.Combo.Input("type", options=["auto"] + list(UNION_CONTROLNET_TYPES.keys())), @@ -39,7 +39,7 @@ class ControlNetInpaintingAliMamaApply(io.ComfyNode): return io.Schema( node_id="ControlNetInpaintingAliMamaApply", search_aliases=["masked controlnet"], - category="conditioning/controlnet", + category="model/conditioning/controlnet", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py index 7dd129d19..d754ab442 100644 --- a/comfy_extras/nodes_cosmos.py +++ b/comfy_extras/nodes_cosmos.py @@ -13,7 +13,7 @@ class EmptyCosmosLatentVideo(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="EmptyCosmosLatentVideo", - category="latent/video", + category="model/latent/video", inputs=[ io.Int.Input("width", default=1280, min=16, 
max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=704, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -45,7 +45,7 @@ class CosmosImageToVideoLatent(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="CosmosImageToVideoLatent", - category="conditioning/inpaint", + category="model/conditioning/inpaint", inputs=[ io.Vae.Input("vae"), io.Int.Input("width", default=1280, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -88,7 +88,7 @@ class CosmosPredict2ImageToVideoLatent(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="CosmosPredict2ImageToVideoLatent", - category="conditioning/inpaint", + category="model/conditioning/inpaint", inputs=[ io.Vae.Input("vae"), io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), diff --git a/comfy_extras/nodes_curve.py b/comfy_extras/nodes_curve.py index 099453131..aa2d94bb6 100644 --- a/comfy_extras/nodes_curve.py +++ b/comfy_extras/nodes_curve.py @@ -11,7 +11,7 @@ class CurveEditor(io.ComfyNode): return io.Schema( node_id="CurveEditor", display_name="Curve Editor", - category="utils", + category="utilities", inputs=[ io.Curve.Input("curve"), io.Histogram.Input("histogram", optional=True), @@ -38,7 +38,7 @@ class ImageHistogram(io.ComfyNode): return io.Schema( node_id="ImageHistogram", display_name="Image Histogram", - category="utils", + category="utilities", inputs=[ io.Image.Input("image"), ], diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py index 10b56b91c..3e97084a4 100644 --- a/comfy_extras/nodes_custom_sampler.py +++ b/comfy_extras/nodes_custom_sampler.py @@ -1,5 +1,7 @@ import math import comfy.samplers +import comfy.sampler_helpers +import comfy.patcher_extension import comfy.sample from comfy.k_diffusion import sampling as k_diffusion_sampling from comfy.k_diffusion import sa_solver @@ -17,7 +19,7 @@ class BasicScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( 
node_id="BasicScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Model.Input("model"), io.Combo.Input("scheduler", options=comfy.samplers.SCHEDULER_NAMES), @@ -47,7 +49,7 @@ class KarrasScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="KarrasScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), @@ -69,7 +71,7 @@ class ExponentialScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ExponentialScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), @@ -90,7 +92,7 @@ class PolyexponentialScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PolyexponentialScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), @@ -112,7 +114,7 @@ class LaplaceScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LaplaceScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), @@ -136,7 +138,7 @@ class SDTurboScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SDTurboScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Model.Input("model"), io.Int.Input("steps", default=1, min=1, max=10), @@ 
-160,7 +162,7 @@ class BetaSamplingScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="BetaSamplingScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Model.Input("model"), io.Int.Input("steps", default=20, min=1, max=10000), @@ -182,7 +184,7 @@ class VPScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="VPScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False, advanced=True), #TODO: fix default values @@ -204,7 +206,7 @@ class SplitSigmas(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SplitSigmas", - category="sampling/sigmas", + category="model/sampling/sigmas", inputs=[ io.Sigmas.Input("sigmas"), io.Int.Input("step", default=0, min=0, max=10000), @@ -228,7 +230,7 @@ class SplitSigmasDenoise(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SplitSigmasDenoise", - category="sampling/sigmas", + category="model/sampling/sigmas", inputs=[ io.Sigmas.Input("sigmas"), io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), @@ -254,7 +256,7 @@ class FlipSigmas(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="FlipSigmas", - category="sampling/sigmas", + category="model/sampling/sigmas", inputs=[io.Sigmas.Input("sigmas")], outputs=[io.Sigmas.Output()] ) @@ -276,7 +278,7 @@ class SetFirstSigma(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SetFirstSigma", - category="sampling/sigmas", + category="model/sampling/sigmas", inputs=[ io.Sigmas.Input("sigmas"), io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False), @@ -298,7 +300,7 @@ class ExtendIntermediateSigmas(io.ComfyNode): return io.Schema( node_id="ExtendIntermediateSigmas", search_aliases=["interpolate sigmas"], - category="sampling/sigmas", 
+ category="model/sampling/sigmas", inputs=[ io.Sigmas.Input("sigmas"), io.Int.Input("steps", default=2, min=1, max=100), @@ -351,7 +353,7 @@ class SamplingPercentToSigma(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplingPercentToSigma", - category="sampling/sigmas", + category="model/sampling/sigmas", inputs=[ io.Model.Input("model"), io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001), @@ -379,7 +381,7 @@ class KSamplerSelect(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="KSamplerSelect", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)], outputs=[io.Sampler.Output()] ) @@ -396,7 +398,7 @@ class SamplerDPMPP_3M_SDE(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerDPMPP_3M_SDE", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), @@ -421,7 +423,7 @@ class SamplerDPMPP_2M_SDE(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerDPMPP_2M_SDE", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Combo.Input("solver_type", options=['midpoint', 'heun']), io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), @@ -448,7 +450,7 @@ class SamplerDPMPP_SDE(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerDPMPP_SDE", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), @@ -474,7 +476,7 @@ class SamplerDPMPP_2S_Ancestral(io.ComfyNode): def 
define_schema(cls): return io.Schema( node_id="SamplerDPMPP_2S_Ancestral", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), @@ -494,7 +496,7 @@ class SamplerEulerAncestral(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerEulerAncestral", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False, advanced=True), @@ -515,7 +517,7 @@ class SamplerEulerAncestralCFGPP(io.ComfyNode): return io.Schema( node_id="SamplerEulerAncestralCFGPP", display_name="SamplerEulerAncestralCFG++", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False), io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False), @@ -537,7 +539,7 @@ class SamplerLMS(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerLMS", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[io.Int.Input("order", default=4, min=1, max=100, advanced=True)], outputs=[io.Sampler.Output()] ) @@ -554,7 +556,7 @@ class SamplerDPMAdaptative(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerDPMAdaptative", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Int.Input("order", default=3, min=2, max=3, advanced=True), io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False, advanced=True), @@ -585,7 +587,7 @@ class SamplerER_SDE(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerER_SDE", - category="sampling/samplers", + category="model/sampling/samplers", 
inputs=[ io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), io.Int.Input("max_stage", default=3, min=1, max=3, advanced=True), @@ -623,7 +625,7 @@ class SamplerSASolver(io.ComfyNode): return io.Schema( node_id="SamplerSASolver", search_aliases=["sde"], - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Model.Input("model"), io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False, advanced=True), @@ -668,7 +670,7 @@ class SamplerSEEDS2(io.ComfyNode): return io.Schema( node_id="SamplerSEEDS2", search_aliases=["sde", "exp heun"], - category="sampling/samplers", + category="model/sampling/samplers", inputs=[ io.Combo.Input("solver_type", options=["phi_1", "phi_2"]), io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength", advanced=True), @@ -727,7 +729,7 @@ class SamplerCustom(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SamplerCustom", - category="sampling/custom_sampling", + category="model/sampling/custom_sampling", inputs=[ io.Model.Input("model"), io.Boolean.Input("add_noise", default=True, advanced=True), @@ -795,7 +797,7 @@ class BasicGuider(io.ComfyNode): return io.Schema( node_id="BasicGuider", display_name="Basic Guider", - category="sampling/guiders", + category="model/sampling/guiders", inputs=[ io.Model.Input("model"), io.Conditioning.Input("conditioning"), @@ -817,7 +819,7 @@ class CFGGuider(io.ComfyNode): return io.Schema( node_id="CFGGuider", display_name="CFG Guider", - category="sampling/guiders", + category="model/sampling/guiders", inputs=[ io.Model.Input("model"), io.Conditioning.Input("positive"), @@ -872,7 +874,7 @@ class DualCFGGuider(io.ComfyNode): node_id="DualCFGGuider", search_aliases=["dual prompt guidance"], display_name="Dual CFG Guider", - category="sampling/guiders", + category="model/sampling/guiders", inputs=[ io.Model.Input("model"), io.Conditioning.Input("cond1"), @@ -894,13 +896,92 
class Guider_DualModel(comfy.samplers.CFGGuider):
    """CFG guider that runs the positive (cond) pass on the main model and the
    negative (uncond) pass on a separate model.

    The second model is only prepared when the guider's cfg is not 1.0 at the start of
    sampling; otherwise every step returns the cond prediction of the main model alone.
    """

    def __init__(self, model_patcher, uncond_model_patcher):
        super().__init__(model_patcher)
        self.uncond_model_patcher = uncond_model_patcher
        self._reset_uncond_state()

    def _reset_uncond_state(self):
        # Per-run state of the uncond model. Kept in one place so __init__, the start of a run
        # and the end of a run all agree on what "not loaded" looks like.
        self.uncond_inner = None     # inner model returned by prepare_sampling, None while unloaded
        self.uncond_loaded = []      # extra models reported by prepare_sampling, handed to cleanup_models
        self._uncond_neg = None      # negative conds as returned by prepare_sampling
        self._uncond_conds = None    # negative conds after process_conds (set in inner_sample)

    def outer_sample(self, noise, latent_image, sampler, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None):
        """Prepare the uncond model (unless cfg is 1.0), run the parent sampling, then always clean up."""
        self._reset_uncond_state()
        try:
            # skip at cfg 1.0
            if not math.isclose(self.cfg, 1.0):
                uc = {"negative": [c.copy() for c in self.conds["negative"]]}
                self.uncond_inner, uc, self.uncond_loaded = comfy.sampler_helpers.prepare_sampling(
                    self.uncond_model_patcher, noise.shape, uc, self.uncond_model_patcher.model_options)
                self._uncond_neg = uc["negative"]
                # pre_run sits inside the try so a failure here still releases what was just prepared.
                self.uncond_model_patcher.pre_run()
            return super().outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)
        finally:
            if self.uncond_inner is not None:
                try:
                    self.uncond_model_patcher.cleanup()
                finally:
                    # Runs even if the patcher cleanup raises.
                    comfy.sampler_helpers.cleanup_models({"negative": self._uncond_neg}, self.uncond_loaded)
            # Drop references to the conds/models of this run.
            self._reset_uncond_state()

    def inner_sample(self, noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None):
        """Process the negative conds against the uncond model, then defer to the parent."""
        if self.uncond_inner is not None:
            li = latent_image
            # Only an all-zero (empty) latent is passed through unprocessed.
            if li is not None and torch.count_nonzero(li) > 0:
                li = self.uncond_inner.process_latent_in(li)
            self._uncond_conds = comfy.samplers.process_conds(
                self.uncond_inner, noise, {"negative": self._uncond_neg}, device, li, denoise_mask, seed, latent_shapes=latent_shapes)["negative"]
        return super().inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes)

    def predict_noise(self, x, timestep, model_options={}, seed=None):
        """Return the CFG-combined prediction: cond from the main model, uncond from the second model.

        `model_options` is only read here, never mutated, so the shared default dict is harmless.
        """
        positive = self.conds.get("positive", None)
        cond = comfy.samplers.calc_cond_batch(self.inner_model, [positive], x, timestep, model_options)[0]
        # uncond model not loaded (base cfg==1/no negative), or cfg driven to 1.0 this step -> single model, cond only
        if self.uncond_inner is None or (math.isclose(self.cfg, 1.0) and not model_options.get("disable_cfg1_optimization", False)):
            return cond

        uncond_model_options = model_options
        if "multigpu_clones" in model_options:  # TODO: support multigpu instead of just running uncond on a single GPU
            uncond_model_options = {k: v for k, v in model_options.items() if k != "multigpu_clones"}
        uncond = comfy.samplers.calc_cond_batch(self.uncond_inner, [self._uncond_conds], x, timestep, uncond_model_options)[0]
        return comfy.samplers.cfg_function(self.inner_model, cond, uncond, self.cfg, x, timestep,
                                           model_options=model_options, cond=positive, uncond=self._uncond_conds)
class CFGOverride(io.ComfyNode):
    """Model patch that forces the guider's cfg to a fixed value while sigma is inside a range."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="CFGOverride",
            display_name="CFG Override",
            description="Override cfg to a fixed value over a [start, end] percent (sigma) range. "
                        "With multiple overrides, the one nearest the sampler wins on overlap.",
            # Same menu path as the other custom-sampling nodes (SamplerCustom, SamplerCustomAdvanced).
            category="model/sampling/custom_sampling",
            inputs=[
                io.Model.Input("model"),
                io.Float.Input("cfg", default=1.0, min=0.0, max=100.0, step=0.1, round=0.01),
                io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001),
                io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001),
            ],
            outputs=[io.Model.Output()],
        )

    @classmethod
    def execute(cls, model, cfg, start_percent, end_percent) -> io.NodeOutput:
        """Return a clone of `model` carrying a PREDICT_NOISE wrapper that swaps in `cfg` inside the range."""
        ms = model.get_model_object("model_sampling")
        sigma_hi = ms.percent_to_sigma(start_percent)  # percent->sigma decreasing, so hi >= lo
        sigma_lo = ms.percent_to_sigma(end_percent)

        def predict_noise_wrapper(executor, *args, **kwargs):
            # Wrapped call is (x, timestep, model_options, seed); tolerate timestep passed by keyword.
            timestep = args[1] if len(args) > 1 else kwargs["timestep"]
            sigma = float(timestep.flatten()[0])
            if not (sigma_lo <= sigma <= sigma_hi):
                return executor(*args, **kwargs)
            guider = executor.class_obj  # guider.cfg feeds cond_scale
            saved = guider.cfg
            guider.cfg = cfg
            try:
                return executor(*args, **kwargs)
            finally:
                guider.cfg = saved  # restore for other steps/overrides

        m = model.clone()
        m.add_wrapper(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, predict_noise_wrapper)
        return io.NodeOutput(m)
save_images_to_folder(image_list, output_dir, prefix="image"): +def save_images_to_folder(image_list, output_dir, prefix="image", overwrite=True): """Utility function to save a list of image tensors to disk. Args: @@ -197,7 +197,11 @@ def save_images_to_folder(image_list, output_dir, prefix="image"): raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}") # Save image - filename = f"{prefix}_{idx:05d}.png" + if overwrite: + filename = f"{prefix}_{idx:05d}.png" + else: + _, _, counter, _, resolved_prefix = folder_paths.get_save_image_path(prefix, output_dir) + filename = f"{resolved_prefix}_{counter:05}_{idx:05d}.png" filepath = os.path.join(output_dir, filename) img.save(filepath) saved_files.append(filename) @@ -230,19 +234,26 @@ class SaveImageDataSetToFolderNode(io.ComfyNode): tooltip="Prefix for saved image filenames.", advanced=True, ), + io.Combo.Input( + "mode", + default="overwrite", + options=["overwrite", "increment"], + tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting." 
+ ), ], outputs=[], is_deprecated=True, # This node is redundant and superseded by existing Save Image nodes where the target folder can be specified in the filename_prefix ) @classmethod - def execute(cls, images, folder_name, filename_prefix): + def execute(cls, images, folder_name, filename_prefix, mode): # Extract scalar values folder_name = folder_name[0] filename_prefix = filename_prefix[0] + mode = mode[0] output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) - saved_files = save_images_to_folder(images, output_dir, filename_prefix) + saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode=='overwrite') logging.info(f"Saved {len(saved_files)} images to {output_dir}.") return io.NodeOutput() @@ -278,18 +289,25 @@ class SaveImageTextDataSetToFolderNode(io.ComfyNode): tooltip="Prefix for saved image filenames.", advanced=True, ), + io.Combo.Input( + "mode", + default="overwrite", + options=["overwrite", "increment"], + tooltip="Whether to overwrite existing files or increment filenames to avoid overwriting." 
+ ), ], outputs=[], ) @classmethod - def execute(cls, images, folder_name, filename_prefix, texts=None): + def execute(cls, images, folder_name, filename_prefix, mode, texts=None): # Extract scalar values folder_name = folder_name[0] filename_prefix = filename_prefix[0] + mode = mode[0] output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) - saved_files = save_images_to_folder(images, output_dir, filename_prefix) + saved_files = save_images_to_folder(images, output_dir, filename_prefix, mode=='overwrite') # Save captions if texts: @@ -393,6 +411,21 @@ class ImageProcessingNode(io.ComfyNode): return has_group + @classmethod + def _ensure_image_list(cls, images): + """Normalize to a flat list of [1, H, W, C] tensors.""" + if isinstance(images, torch.Tensor): + if images.ndim != 4: + raise ValueError(f"Expected 4D image tensor, got shape {tuple(images.shape)}") + return [images[i:i+1] for i in range(images.shape[0])] + + flat = [] + for item in images: + if not isinstance(item, torch.Tensor) or item.ndim != 4: + raise ValueError(f"Expected 4D image tensor, got {type(item).__name__} shape {getattr(item, 'shape', None)}") + flat.extend([item[i:i+1] for i in range(item.shape[0])]) + return flat + @classmethod def define_schema(cls): if cls.node_id is None: @@ -440,6 +473,9 @@ class ImageProcessingNode(io.ComfyNode): """Execute the node. 
Routes to _process or _group_process based on mode.""" is_group = cls._detect_processing_mode() + if is_group: + images = cls._ensure_image_list(images) + # Extract scalar values from lists for parameters params = {} for k, v in kwargs.items(): @@ -574,7 +610,7 @@ class TextProcessingNode(io.ComfyNode): return io.Schema( node_id=cls.node_id, display_name=cls.display_name or cls.node_id, - category="dataset/text", + category="text", is_experimental=True, is_input_list=is_group, # True for group, False for individual inputs=inputs, @@ -1208,7 +1244,7 @@ class ResolutionBucket(io.ComfyNode): node_id="ResolutionBucket", search_aliases=["bucket by resolution", "group by resolution", "batch by resolution"], display_name="Resolution Bucket", - category="training", + category="model/training", description="Group latents and conditionings into buckets", is_experimental=True, is_input_list=True, @@ -1302,7 +1338,7 @@ class MakeTrainingDataset(io.ComfyNode): node_id="MakeTrainingDataset", search_aliases=["encode dataset"], display_name="Make Training Dataset", - category="training", + category="model/training", description="Encode images with VAE and texts with CLIP to create a training dataset of latents and conditionings.", is_experimental=True, is_input_list=True, # images and texts as lists @@ -1390,7 +1426,7 @@ class SaveTrainingDataset(io.ComfyNode): node_id="SaveTrainingDataset", search_aliases=["export dataset", "save dataset"], display_name="Save Training Dataset", - category="training", + category="model/training", description="Save encoded training dataset (latents + conditioning) to disk for efficient loading during training.", is_experimental=True, is_output_node=True, @@ -1493,7 +1529,7 @@ class LoadTrainingDataset(io.ComfyNode): node_id="LoadTrainingDataset", search_aliases=["import dataset", "training data"], display_name="Load Training Dataset", - category="training", + category="model/training", description="Load encoded training dataset (latents + 
conditioning) from disk for use in training.", is_experimental=True, inputs=[ diff --git a/comfy_extras/nodes_eps.py b/comfy_extras/nodes_eps.py index 0fb3871c8..8c397f132 100644 --- a/comfy_extras/nodes_eps.py +++ b/comfy_extras/nodes_eps.py @@ -18,7 +18,7 @@ class EpsilonScaling(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Epsilon Scaling", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Float.Input( @@ -84,7 +84,7 @@ class TemporalScoreRescaling(io.ComfyNode): return io.Schema( node_id="TemporalScoreRescaling", display_name="TSR - Temporal Score Rescaling", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Float.Input( diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 997f21c09..afc663b22 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -40,7 +40,7 @@ class EmptyFlux2LatentImage(io.ComfyNode): return io.Schema( node_id="EmptyFlux2LatentImage", display_name="Empty Flux 2 Latent", - category="latent", + category="model/latent", inputs=[ io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -215,7 +215,7 @@ class Flux2Scheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Flux2Scheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=4096), io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=1), diff --git a/comfy_extras/nodes_frame_interpolation.py b/comfy_extras/nodes_frame_interpolation.py index 9dd34cfb8..4d5bca17e 100644 --- a/comfy_extras/nodes_frame_interpolation.py +++ b/comfy_extras/nodes_frame_interpolation.py @@ -19,7 +19,7 @@ class FrameInterpolationModelLoader(io.ComfyNode): return io.Schema( node_id="FrameInterpolationModelLoader", 
display_name="Load Frame Interpolation Model", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("frame_interpolation"), tooltip="Select a frame interpolation model to load. Models must be placed in the 'frame_interpolation' folder."), diff --git a/comfy_extras/nodes_freelunch.py b/comfy_extras/nodes_freelunch.py index 248efdef3..ccbd1fd90 100644 --- a/comfy_extras/nodes_freelunch.py +++ b/comfy_extras/nodes_freelunch.py @@ -29,7 +29,7 @@ class FreeU(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="FreeU", - category="model_patches/unet", + category="model/patch/unet", inputs=[ IO.Model.Input("model"), IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01, advanced=True), @@ -76,7 +76,7 @@ class FreeU_V2(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="FreeU_V2", - category="model_patches/unet", + category="model/patch/unet", inputs=[ IO.Model.Input("model"), IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01, advanced=True), diff --git a/comfy_extras/nodes_gaussian_splat.py b/comfy_extras/nodes_gaussian_splat.py new file mode 100644 index 000000000..116c14fde --- /dev/null +++ b/comfy_extras/nodes_gaussian_splat.py @@ -0,0 +1,1664 @@ +# Generic utility nodes for the SPLAT type (3D gaussian splats) + +import gzip +import logging +import math +import struct +from io import BytesIO + +import numpy as np +import torch +from typing_extensions import override +from scipy.ndimage import map_coordinates, minimum as _ndi_minimum, maximum as _ndi_maximum +from scipy.sparse import coo_matrix +from scipy.sparse.csgraph import connected_components + +import comfy.model_management +import comfy.utils +from comfy_api.latest import ComfyExtension, IO, Types +from comfy_extras.nodes_save_3d import pack_variable_mesh_batch +from server import PromptServer + +_C0 = 0.28209479177387814 # SH band-0 constant: DC coefficient -> base RGB + + +def 
_srgb_to_linear(c): + return torch.where(c <= 0.04045, c / 12.92, ((c.clamp_min(0) + 0.055) / 1.055) ** 2.4) + + +def _linear_to_srgb(c): + return torch.where(c <= 0.0031308, c * 12.92, 1.055 * c.clamp_min(0) ** (1 / 2.4) - 0.055) + + +def _real_len(g: Types.SPLAT, i: int) -> int: + # Real splat count of batch item i (honors variable-length `counts`). + return int(g.counts[i].item()) if g.counts is not None else g.positions.shape[1] + + +def _hex_to_rgb(h: str) -> tuple[float, float, float]: + # "#RRGGBB" -> (r,g,b) in [0,1]; falls back to black. + h = h.lstrip("#") + if len(h) != 6: + return (0.0, 0.0, 0.0) + return tuple(int(h[i:i + 2], 16) / 255.0 for i in (0, 2, 4)) + + +def _quantile(x, q): + # torch.quantile errors above 2**24 elements; stride-subsample large inputs for the estimate. + lim = 1 << 24 + if x.numel() > lim: + x = x[:: x.numel() // lim + 1] + return torch.quantile(x, q) + + +def _gaussian_ply_bytes(positions, scales, rotations, opacities, sh) -> bytes: + """Serialize render-ready gaussian tensors as a binary 3DGS .ply. + + positions (N,3) world; scales (N,3) linear; rotations (N,4) quat wxyz; opacities (N,1) in [0,1]; + sh (N,K,3) SH coefficients. Activated values are inverted to the standard 3D gaussian splat storage convention + (log scale, logit opacity). 
+ """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + normals = np.zeros_like(xyz) + f = sh.cpu().numpy().astype(np.float32) # (N, K, 3) + f_dc = f[:, 0, :] # (N, 3) + f_rest = f[:, 1:, :].transpose(0, 2, 1).reshape(n, -1) # (N, 3*(K-1)) channel-major + op = opacities.cpu().numpy().astype(np.float32).reshape(n, 1).clip(1e-6, 1 - 1e-6) + op = np.log(op / (1.0 - op)) # inverse sigmoid (logit) + scale = np.log(scales.cpu().numpy().astype(np.float32).clip(min=1e-8)) + rot = rotations.cpu().numpy().astype(np.float32) # (N, 4) + + attrs = (['x', 'y', 'z', 'nx', 'ny', 'nz'] + + [f'f_dc_{i}' for i in range(3)] + + [f'f_rest_{i}' for i in range(f_rest.shape[1])] + + ['opacity'] + [f'scale_{i}' for i in range(3)] + [f'rot_{i}' for i in range(4)]) + elements = np.empty(n, dtype=[(a, 'f4') for a in attrs]) + elements[:] = list(map(tuple, np.concatenate([xyz, normals, f_dc, f_rest, op, scale, rot], axis=1))) + + header = "ply\nformat binary_little_endian 1.0\n" + f"element vertex {n}\n" + header += "".join(f"property float {a}\n" for a in attrs) + "end_header\n" + return header.encode('ascii') + elements.tobytes() + + +# .ksplat (mkkellogg SplatBuffer) level 0, SH degree 0: 4096-byte header, one 1024-byte section header, +# then N 44-byte records. Bucketing/quantization only exist at levels >= 1. See SplatBuffer.js. +_KSPLAT_HEADER_BYTES = 4096 +_KSPLAT_SECTION_HEADER_BYTES = 1024 +_KSPLAT_BYTES_PER_SPLAT = 44 # center 12 + scale 12 + rotation 16 + color(RGBA u8) 4 +_KSPLAT_VERSION = (0, 1) # SplatBuffer CurrentMajor/MinorVersion + + +def _gaussian_ksplat_bytes(positions, scales, rotations, opacities, sh) -> bytes: + """Serialize gaussian tensors as a level-0, SH degree-0 .ksplat (linear scale, opacity in color alpha). + + positions (N,3) world; scales (N,3) linear; rotations (N,4) wxyz; opacities (N,1) in [0,1]; sh (N,K,3). 
+ """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + scale = scales.cpu().numpy().astype(np.float32) + rot = rotations.cpu().numpy().astype(np.float32) # wxyz, mirrors the .ply rot order + rot = rot / np.linalg.norm(rot, axis=1, keepdims=True).clip(1e-12) + rgb = np.clip(sh[:, 0, :].cpu().numpy().astype(np.float32) * _C0 + 0.5, 0, 1) + op = opacities.cpu().numpy().astype(np.float32).reshape(n, 1).clip(0, 1) + rgba = np.round(np.concatenate([rgb, op], axis=1) * 255.0).astype(np.uint8) # (N, 4) RGBA + + # 44-byte record: float center(3) + scale(3) + rot(4), then uint8 rgba(4). + floats = np.concatenate([xyz, scale, rot], axis=1).astype(' bytes: + """Serialize gaussian tensors as a gzip-compressed .spz (Niantic v2, SH degree 0, base color only). + + positions (N,3) world; scales (N,3) linear; rotations (N,4) wxyz; opacities (N,1) in [0,1]; sh (N,K,3). + """ + xyz = positions.cpu().numpy().astype(np.float32) + n = xyz.shape[0] + if n == 0: + raise ValueError("SplatToFile3D: gaussian is empty") + + # Positions: fixed point, masked to 24 bits, little-endian 3-byte words. 
+ fixed = 1 << _SPZ_FRACTIONAL_BITS + qi = np.clip(np.round(xyz * fixed), -(1 << 23), (1 << 23) - 1).astype(np.int32) + qu = (qi & 0xFFFFFF).astype(np.uint32) + pos = np.stack([qu & 0xFF, (qu >> 8) & 0xFF, (qu >> 16) & 0xFF], axis=-1).reshape(n, 9).astype(np.uint8) + + alpha = np.round(opacities.cpu().numpy().astype(np.float32).reshape(n) * 255.0).clip(0, 255).astype(np.uint8) + + rgb = sh[:, 0, :].cpu().numpy().astype(np.float32) * _C0 + 0.5 + col = np.round(((rgb - 0.5) / _SPZ_COLOR_SCALE + 0.5) * 255.0).clip(0, 255).astype(np.uint8) # (N,3) + + sln = np.log(scales.cpu().numpy().astype(np.float32).clip(min=1e-9)) + scb = np.round((sln + 10.0) * 16.0).clip(0, 255).astype(np.uint8) # (N,3) inverts exp(b/16-10) + + rot = rotations.cpu().numpy().astype(np.float32) # wxyz + rot = rot / np.linalg.norm(rot, axis=1, keepdims=True).clip(1e-12) + rot[rot[:, 0] < 0] *= -1.0 # canonical w >= 0 (w dropped on decode) + rotb = np.round((rot[:, 1:4] + 1.0) * 127.5).clip(0, 255).astype(np.uint8) # (N,3) x,y,z + + header = bytearray(16) + struct.pack_into(' (positions, scales linear, rotations wxyz, opacities [0,1], sh (N,K,3)) ---- +# Inverse of the writers above and of spark's loaders. ksplat/splat/spz carry base color only (SH degree 0 +# -> K=1); .ply round-trips full SH. None of the formats flip axes, so import is the identity of export. 
+_PLY_DTYPES = {'char': 'i1', 'uchar': 'u1', 'short': 'i2', 'ushort': 'u2', 'int': 'i4', 'uint': 'u4', + 'float': 'f4', 'double': 'f8', 'int8': 'i1', 'uint8': 'u1', 'int16': 'i2', 'uint16': 'u2', + 'int32': 'i4', 'uint32': 'u4', 'float32': 'f4', 'float64': 'f8'} +_KSPLAT_COMPRESSION = { # level -> (bytesPerCenter, scale, rotation, color, shComponent, defaultScaleRange) + 0: (12, 12, 16, 4, 4, 1), 1: (6, 6, 8, 4, 2, 32767), 2: (6, 6, 8, 4, 1, 32767)} +_KSPLAT_SH_COMPONENTS = {0: 0, 1: 9, 2: 24, 3: 45} + + +def _rgb_to_sh_dc(rgb): + return ((np.asarray(rgb, np.float32) - 0.5) / _C0)[:, None, :] # (N,3) base color -> (N,1,3) SH DC + + +def _norm_quat(q): + return q / np.linalg.norm(q, axis=1, keepdims=True).clip(1e-12) + + +def _parse_ply_gaussian(data: bytes): + end = data.find(b'end_header') + if end < 0: + raise ValueError("File3DToSplat: not a PLY (missing end_header)") + header = data[:end].decode('ascii', 'replace') + body = end + len(b'end_header') + body += 2 if data[body:body + 2] == b'\r\n' else 1 + count, props, in_vertex = 0, [], False + for line in header.splitlines(): + p = line.split() + if not p: + continue + if p[0] == 'format' and p[1] != 'binary_little_endian': + raise ValueError(f"File3DToSplat: unsupported PLY format '{p[1]}' (need binary_little_endian)") + if p[0] == 'element': + in_vertex = p[1] == 'vertex' + if in_vertex: + count = int(p[2]) + elif p[0] == 'property' and in_vertex: + if p[1] == 'list': + raise ValueError("File3DToSplat: PLY vertex has list properties (unsupported)") + props.append((p[2], '<' + _PLY_DTYPES[p[1]])) + arr = np.frombuffer(data, np.dtype(props), count=count, offset=body) + names = arr.dtype.names + c = lambda k: arr[k].astype(np.float32) + n = count + + xyz = np.stack([c('x'), c('y'), c('z')], 1) + if 'scale_0' in names: + scale = np.exp(np.stack([c('scale_0'), c('scale_1'), c('scale_2')], 1)) # 3DGS stores log scale + else: + scale = np.full((n, 3), 0.01, np.float32) + if 'rot_0' in names: + rot = 
_norm_quat(np.stack([c('rot_0'), c('rot_1'), c('rot_2'), c('rot_3')], 1)) # wxyz + else: + rot = np.tile(np.array([1, 0, 0, 0], np.float32), (n, 1)) + opacity = 1.0 / (1.0 + np.exp(-c('opacity'))) if 'opacity' in names else np.ones(n, np.float32) + + if 'f_dc_0' in names: + dc = np.stack([c('f_dc_0'), c('f_dc_1'), c('f_dc_2')], 1) # (N,3) + rest = sorted((k for k in names if k.startswith('f_rest_')), key=lambda s: int(s.split('_')[-1])) + if rest: + r = np.stack([c(k) for k in rest], 1) # (N, 3*(K-1)) channel-major + kk = r.shape[1] // 3 + 1 + r = r.reshape(n, 3, kk - 1).transpose(0, 2, 1) # -> (N, K-1, 3) + sh = np.concatenate([dc[:, None, :], r], 1) + else: + sh = dc[:, None, :] + elif 'red' in names: + sh = _rgb_to_sh_dc(np.stack([c('red'), c('green'), c('blue')], 1) / 255.0) + else: + sh = np.zeros((n, 1, 3), np.float32) + return xyz, scale, rot, opacity, sh + + +def _parse_splat_gaussian(data: bytes): + # antimatter15 .splat: 32-byte records (f32 xyz, f32 scale, u8 rgba, u8 quat as (b-128)/128 wxyz). 
+ if len(data) % 32 != 0: + raise ValueError("File3DToSplat: .splat size is not a multiple of 32 bytes") + rec = np.frombuffer(data, np.dtype([('xyz', ' 0: + ct, ft = (' full_splats: + lengths = np.frombuffer(data, '> 30) & 3 + q = np.zeros((n, 4), np.float32) # x,y,z,w + remaining, sumsq = combined.copy(), np.zeros(n, np.float64) + for comp in (3, 2, 1, 0): + active = comp != largest + value = (remaining & 0x1FF).astype(np.float64) + sign = (remaining >> 9) & 1 + remaining = np.where(active, remaining >> 10, remaining) + val = (1.0 / math.sqrt(2)) * (value / 0x1FF) + val = np.where(sign == 1, -val, val) + q[active, comp] = val[active] + sumsq += np.where(active, val * val, 0.0) + q[np.arange(n), largest] = np.sqrt(np.clip(1.0 - sumsq, 0, None)) + rot = _norm_quat(np.stack([q[:, 3], q[:, 0], q[:, 1], q[:, 2]], 1)) # xyzw -> wxyz + else: + qb = np.frombuffer(raw, np.uint8, count=n * 3, offset=off).reshape(n, 3).astype(np.float32) + xq = qb / 127.5 - 1.0 + w = np.sqrt(np.clip(1.0 - (xq ** 2).sum(1), 0, None)) + rot = _norm_quat(np.concatenate([w[:, None], xq], 1)) # wxyz + return xyz, scale, rot, alpha, _rgb_to_sh_dc(rgb) + + +_GAUSSIAN_PARSERS = {"ply": _parse_ply_gaussian, "splat": _parse_splat_gaussian, + "ksplat": _parse_ksplat_gaussian, "spz": _parse_spz_gaussian} + + +def _detect_splat_format(data: bytes) -> str: + if data[:3] == b'ply': + return "ply" + if data[:2] == b'\x1f\x8b': # gzip -> spz + return "spz" + if len(data) >= 2 and data[0] == 0 and data[1] >= 1: # ksplat version 0.x header + return "ksplat" + if len(data) % 32 == 0: + return "splat" + raise ValueError("File3DToSplat: could not determine splat format from contents") + + +def _gaussian_item(g: Types.SPLAT, i: int, device): + # Slice batch item i to its real length, as float32 torch tensors on `device` (SH DC -> base RGB). 
+ end = _real_len(g, i) + to = lambda a: a.to(device=device, dtype=torch.float32) + xyz = to(g.positions[i, :end]) + rgb = (to(g.sh[i, :end, 0, :]) * _C0 + 0.5).clamp(0, 1) + opacity = to(g.opacities[i, :end]).reshape(-1) + scale = to(g.scales[i, :end]) + rot = to(g.rotations[i, :end]) + return xyz, rgb, opacity, scale, rot + + +def _quat_to_mat(q): + # q: (N, 4) wxyz, normalized -> (N, 3, 3) + q = q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12) + w, x, y, z = q.unbind(-1) + return torch.stack([ + 1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y), + 2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x), + 2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y), + ], dim=-1).reshape(-1, 3, 3) + + +def _quat_mul(a, b): + # Hamilton product a (x) b, wxyz. + aw, ax, ay, az = a.unbind(-1) + bw, bx, by, bz = b.unbind(-1) + return torch.stack([ + aw * bw - ax * bx - ay * by - az * bz, + aw * bx + ax * bw + ay * bz - az * by, + aw * by - ax * bz + ay * bw + az * bx, + aw * bz + ax * by - ay * bx + az * bw, + ], dim=-1) + + +def _euler_to_quat(rx, ry, rz): + # Degrees, applied as Rz @ Ry @ Rx (rotate about X, then Y, then Z in world). Returns wxyz. + c, s = np.cos(np.radians([rx, ry, rz]) / 2.0), np.sin(np.radians([rx, ry, rz]) / 2.0) + qx = torch.tensor([c[0], s[0], 0.0, 0.0], dtype=torch.float32) + qy = torch.tensor([c[1], 0.0, s[1], 0.0], dtype=torch.float32) + qz = torch.tensor([c[2], 0.0, 0.0, s[2]], dtype=torch.float32) + return _quat_mul(_quat_mul(qz, qy), qx) + + +def _mat_to_quat(m): + # Rotation matrix (..., 3, 3) -> quaternion (..., 4) wxyz. Batched; builds the four candidate quaternions + # and keeps the one with the largest component (numerically stable across all rotations). 
+ m00, m11, m22 = m[..., 0, 0], m[..., 1, 1], m[..., 2, 2] + m21, m12 = m[..., 2, 1], m[..., 1, 2] + m02, m20 = m[..., 0, 2], m[..., 2, 0] + m10, m01 = m[..., 1, 0], m[..., 0, 1] + q2 = torch.stack([1 + m00 + m11 + m22, 1 + m00 - m11 - m22, + 1 - m00 + m11 - m22, 1 - m00 - m11 + m22], -1) # 4 * (w^2, x^2, y^2, z^2) + cand = torch.stack([ + torch.stack([q2[..., 0], m21 - m12, m02 - m20, m10 - m01], -1), + torch.stack([m21 - m12, q2[..., 1], m10 + m01, m02 + m20], -1), + torch.stack([m02 - m20, m10 + m01, q2[..., 2], m12 + m21], -1), + torch.stack([m10 - m01, m02 + m20, m12 + m21, q2[..., 3]], -1), + ], -2) # (...,4,4) candidates, rows = wxyz + sel = q2.argmax(-1) + q = torch.gather(cand, -2, sel[..., None, None].expand(sel.shape + (1, 4)))[..., 0, :] + return q / q.norm(dim=-1, keepdim=True).clamp_min(1e-12) + + +class SplatToFile3D(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SplatToFile3D", + display_name="Create 3D File (from Splat)", + search_aliases=["gaussian to ply", "splat to file", "export gaussian"], + category="3d/splat", + description="Serialize a gaussian splat to a File3D object for Save / Preview 3D nodes. " + "Supports one item per batch only.", + inputs=[ + IO.Splat.Input("splat"), + IO.Combo.Input("format", options=["ply", "ksplat", "spz"], # TODO: add "splat" when we have a writer for it + tooltip="ply: standard 3D Gaussian Splat with full spherical harmonics. " + "ksplat: mkkellogg SplatBuffer (level 0, uncompressed), base color only " + "spz: Niantic gzip-compressed (~10x smaller), base color only " + ), + ], + outputs=[IO.File3DSplatAny.Output(display_name="model_3d")], + ) + + @classmethod + def execute(cls, splat, format="ply") -> IO.NodeOutput: + if splat.positions.shape[0] > 1: + logging.warning("SplatToFile3D supports one item per batch only. 
Got %d; using first.", splat.positions.shape[0]) + end = _real_len(splat, 0) + writer = {"ksplat": _gaussian_ksplat_bytes, "spz": _gaussian_spz_bytes}.get(format, _gaussian_ply_bytes) + data = writer(splat.positions[0, :end], splat.scales[0, :end], + splat.rotations[0, :end], splat.opacities[0, :end], splat.sh[0, :end]) + return IO.NodeOutput(Types.File3D(BytesIO(data), file_format=format)) + + +class File3DToSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="File3DToSplat", + display_name="Get Splat", + search_aliases=["load splat", "ply to splat", "import splat", "file to splat"], + category="3d/splat", + description="Parse a splat File3D into a gaussian splat. Inverse of Create 3D File (from Splat). " + "Supported format: PLY, SPLAT, KSPLAT, SPZ. PLY carries full spherical harmonics, " + "the other formats are base color only. Format is auto-detected from the file contents.", + inputs=[ + IO.MultiType.Input( + IO.File3DAny.Input("model_3d"), + types=[IO.File3DSplatAny, IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ], + tooltip="A gaussian splat 3D file", + ), + ], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, model_3d: Types.File3D) -> IO.NodeOutput: + data = model_3d.get_bytes() + fmt = (model_3d.format or "").lower() + parser = _GAUSSIAN_PARSERS.get(fmt) or _GAUSSIAN_PARSERS[_detect_splat_format(data)] + xyz, scale, rot, opacity, sh = parser(data) + + t = lambda a: torch.from_numpy(np.ascontiguousarray(a)).float() + splat = Types.SPLAT( + t(xyz)[None], # (1, N, 3) + t(scale)[None], # (1, N, 3) linear + t(rot)[None], # (1, N, 4) wxyz + t(opacity).reshape(1, -1, 1), # (1, N, 1) + t(sh)[None], # (1, N, K, 3) + ) + return IO.NodeOutput(splat) + + +def _view_matrix_t(yaw_deg, pitch_deg, device): + y, p = math.radians(yaw_deg), math.radians(pitch_deg) + cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p) + Ry = torch.tensor([[cy, 0, sy], [0, 1, 0], 
def _view_matrix_t(yaw_deg, pitch_deg, device):
    # 3x3 view rotation Rx(pitch) @ Ry(yaw) built from angles in degrees, created on `device`.
    yaw, pitch = math.radians(yaw_deg), math.radians(pitch_deg)
    cos_y, sin_y = math.cos(yaw), math.sin(yaw)
    cos_p, sin_p = math.cos(pitch), math.sin(pitch)
    rot_y = torch.tensor([[cos_y, 0, sin_y], [0, 1, 0], [-sin_y, 0, cos_y]], device=device)
    rot_x = torch.tensor([[1, 0, 0], [0, cos_p, -sin_p], [0, sin_p, cos_p]], device=device)
    return torch.matmul(rot_x, rot_y)


def _camera_basis(camera_info, dev):
    # Look-at basis in the splat frame, named by their projection rows: right = image +x, up = image +y
    # (down, since yflip=1), fwd = view/depth axis (eye -> scene). Load3D is three.js (right-handed, Y-up,
    # camera looks down -Z); the splat is 3DGS (Y-down, Z-forward). World -> splat is a 180 deg rotation
    # about X: (x, y, z) -> (x, -y, -z) (det +1, no mirror, no axis swap).
    # Returns (eye, target, right, up, fwd), all as 3-vectors on `dev`.
    def point_to_splat(d):
        # {"x","y","z"} dict (missing keys read as 0) -> splat-frame point
        return torch.tensor([float(d.get("x", 0.0)), -float(d.get("y", 0.0)), -float(d.get("z", 0.0))], device=dev)

    def dir_to_splat(v):
        # same world->splat map, for direction vectors
        return torch.stack([v[0], -v[1], -v[2]])

    def unit(v):
        return v / v.norm().clamp_min(1e-8)

    eye = point_to_splat(camera_info.get("position", {}))
    target = point_to_splat(camera_info.get("target", {}))
    quat = camera_info.get("quaternion")
    if not quat:
        # No quaternion: orbit-consistent, roll-free basis recovered from the eye -> target direction.
        view = unit(target - eye)
        yaw = math.degrees(math.atan2(-float(view[0]), float(view[2])))
        pitch = math.degrees(math.asin(max(-1.0, min(1.0, float(view[1])))))
        basis = _view_matrix_t(yaw, pitch, dev)
        return eye, target, basis[0], basis[1], basis[2]

    # Exact camera world rotation (incl. roll).
    wxyz = torch.tensor([float(quat.get("w", 1.0)), float(quat.get("x", 0.0)),
                         float(quat.get("y", 0.0)), float(quat.get("z", 0.0))], device=dev)
    cam_axes = _quat_to_mat(wxyz[None])[0]        # columns = camera world axes; looks down local -Z
    right = unit(dir_to_splat(cam_axes[:, 0]))    # camera +X -> image right
    up = unit(dir_to_splat(-cam_axes[:, 1]))      # camera +Y is image up; image-down row is its negative
    fwd = unit(dir_to_splat(-cam_axes[:, 2]))     # camera looks down local -Z -> view direction
    return eye, target, right, up, fwd
def _lookat_quat_wxyz(position, target, dev):
    # three.js lookAt in world frame: camera local +Z = (eye - target), up = world +Y. Returns wxyz.
    back = position - target
    back = back / back.norm().clamp_min(1e-8)
    world_up = torch.tensor([0.0, 1.0, 0.0], device=dev)
    if back.dot(world_up).abs() > 0.999:
        # Looking straight up/down: world +Y is (nearly) parallel to the view axis, use +Z as the up hint.
        world_up = torch.tensor([0.0, 0.0, 1.0], device=dev)
    side = torch.linalg.cross(world_up, back)
    side = side / side.norm().clamp_min(1e-8)
    cam_up = torch.linalg.cross(back, side)
    axes = torch.stack([side, cam_up, back], dim=1)  # columns = camera world axes
    return _mat_to_quat(axes[None])[0]


def _lookat_camera_info(position, target, fov, dev, zoom=1.0, camera_type="perspective", roll=0.0):
    # Build a camera_info from a world-space (right-handed, Y-up) eye + look-at target; up = world +Y.
    # `roll` is in degrees; a falsy roll leaves the look-at orientation untouched.
    def as_xyz(v):
        return {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])}

    eye = torch.as_tensor(position, dtype=torch.float32, device=dev)
    aim = torch.as_tensor(target, dtype=torch.float32, device=dev)
    quat = _lookat_quat_wxyz(eye, aim, dev)
    if roll:
        # Roll about the view axis (camera local Z): post-multiply by a Z-axis rotation.
        half = math.radians(roll) / 2
        roll_q = torch.tensor([math.cos(half), 0.0, 0.0, math.sin(half)], device=dev)
        quat = _quat_mul(quat[None], roll_q[None])[0]
    # The dict stores the quaternion as x/y/z/w although it is computed in wxyz order.
    quat_dict = {"x": float(quat[1]), "y": float(quat[2]), "z": float(quat[3]), "w": float(quat[0])}
    return {
        "position": as_xyz(eye),
        "target": as_xyz(aim),
        "quaternion": quat_dict,
        "fov": float(fov),
        "cameraType": str(camera_type),
        "zoom": float(zoom),
    }
def _quat_camera_info(position, quat_xyzw, fov, dev, zoom=1.0, camera_type="perspective"):
    # camera_info from an explicit world position + camera-rotation quaternion (three.js: looks down local -Z).
    # `quat_xyzw` is given in x, y, z, w order and is normalized here. A (near) zero-length quaternion has no
    # direction to normalize, so it falls back to the identity rotation instead of emitting (0, 0, 0, 0).
    pos = torch.as_tensor(position, dtype=torch.float32, device=dev)
    qx, qy, qz, qw = (float(c) for c in quat_xyzw)
    qwxyz = torch.tensor([qw, qx, qy, qz], dtype=torch.float32, device=dev)
    norm = qwxyz.norm()
    if float(norm) < 1e-8:
        qwxyz = torch.tensor([1.0, 0.0, 0.0, 0.0], dtype=torch.float32, device=dev)  # identity (wxyz)
    else:
        qwxyz = qwxyz / norm
    R = _quat_to_mat(qwxyz[None])[0]
    tgt = pos - R[:, 2]  # look one unit down local -Z

    def xyz(v):
        return {"x": float(v[0]), "y": float(v[1]), "z": float(v[2])}

    return {"position": xyz(pos), "target": xyz(tgt),
            "quaternion": {"x": float(qwxyz[1]), "y": float(qwxyz[2]), "z": float(qwxyz[3]), "w": float(qwxyz[0])},
            "fov": float(fov), "cameraType": str(camera_type), "zoom": float(zoom)}


def _orbit_camera_info(yaw, pitch, distance, fov, pivot_splat, dev):
    # Orbit helper for RenderSplat's default camera: yaw/pitch about `pivot_splat` (splat frame) at `distance`.
    # World<->splat is the (x,-y,-z) map, so _camera_basis recovers exactly _view_matrix_t(yaw, pitch).
    # The eye is placed `distance` behind the pivot along the view direction, then both points are mapped
    # to world space and handed to _lookat_camera_info.
    y, p = math.radians(yaw), math.radians(pitch)
    cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p)
    fwd_splat = torch.tensor([-cp * sy, sp, cp * cy], device=dev)  # == _view_matrix_t(yaw, pitch)[2]

    def m(v):
        # splat<->world (its own inverse)
        return torch.stack([v[0], -v[1], -v[2]])

    return _lookat_camera_info(m(pivot_splat - distance * fwd_splat), m(pivot_splat), fov, dev)
def _orbit_camera_info_yaw(camera_info, angle_deg, dev):
    # Turntable: rigidly rotate a camera_info about world +Y around its target by angle_deg. Returns a new dict
    # (the input is not modified); every key other than "position"/"quaternion" is copied through unchanged.
    #
    # The quaternion is rotated only when the input actually carries one. A camera_info without a quaternion
    # is rendered through the look-at fallback of _camera_basis (eye -> target); inventing a quaternion from
    # identity here would switch it to the quaternion path and the camera would stop aiming at the target.
    a = math.radians(angle_deg)
    ca, sa = math.cos(a), math.sin(a)

    def vec(d):
        # {"x","y","z"} dict (missing keys read as 0) -> world-space 3-vector
        return torch.tensor([float(d.get("x", 0.0)), float(d.get("y", 0.0)), float(d.get("z", 0.0))], device=dev)

    pos, tgt = vec(camera_info.get("position", {})), vec(camera_info.get("target", {}))
    Ry = torch.tensor([[ca, 0.0, sa], [0.0, 1.0, 0.0], [-sa, 0.0, ca]], device=dev)
    new_pos = tgt + Ry @ (pos - tgt)
    out = {**camera_info,
           "position": {"x": float(new_pos[0]), "y": float(new_pos[1]), "z": float(new_pos[2])}}

    q = camera_info.get("quaternion")
    if q:
        qcur = torch.tensor([float(q.get("w", 1.0)), float(q.get("x", 0.0)),
                             float(q.get("y", 0.0)), float(q.get("z", 0.0))], device=dev)
        qy = torch.tensor([math.cos(a / 2), 0.0, math.sin(a / 2), 0.0], device=dev)  # world +Y rotation
        qn = _quat_mul(qy[None], qcur[None])[0]  # world rotation applied on top of the current orientation
        out["quaternion"] = {"x": float(qn[1]), "y": float(qn[2]), "z": float(qn[3]), "w": float(qn[0])}
    return out


def _gauss_blur(x, sigma, dev):
    # Separable Gaussian blur of (1, C, H, W). Used to denoise the screen-space normal map.
    # The kernel spans +-round(3*sigma) taps (at least 1), is normalized to sum 1 and applied per channel
    # (depthwise) first along width, then along height. Borders are zero-padded by conv2d.
    r = max(1, int(round(3 * sigma)))
    k = torch.exp(-0.5 * (torch.arange(-r, r + 1, device=dev, dtype=torch.float32) / sigma) ** 2)
    k = k / k.sum()
    c = x.shape[1]
    x = torch.nn.functional.conv2d(x, k.view(1, 1, 1, -1).expand(c, 1, 1, -1), padding=(0, r), groups=c)
    x = torch.nn.functional.conv2d(x, k.view(1, 1, -1, 1).expand(c, 1, -1, 1), padding=(r, 0), groups=c)
    return x
def _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale, bg, camera_info,
                     sharpen=1.0, headlight_shading=0.0, render_style="color"):
    # Perspective-correct anisotropic gaussian splat rasterizer. Each splat is weighted by its 3D Gaussian's
    # peak along each pixel's ray (AAA / Hahlbohm), composited front-to-back across depth slabs. `render_style`
    # selects the image: color / clay / depth / normal.
    #
    # xyz/rgb/opacity/scale/rot are the per-splat arrays of one batch item; `bg` is either a solid RGB colour
    # or a full (H, W, 3) plate; all camera state comes from `camera_info`.
    # Returns (image HxWx3, coverage mask HxW), both moved to the intermediate device/dtype.
    dev = comfy.model_management.get_torch_device()
    t = lambda a: torch.as_tensor(a, dtype=torch.float32, device=dev)
    idev, idtype = comfy.model_management.intermediate_device(), comfy.model_management.intermediate_dtype()
    xyz, rgb, opacity = t(xyz), t(rgb).clamp(0, 1), t(opacity).reshape(-1)
    scale, rot = t(scale) * float(splat_scale), t(rot)
    do_linear = render_style == "color"  # colour blends in linear light, re-encoded at the end
    if do_linear:
        rgb = _srgb_to_linear(rgb)
    flat = width * height
    bg_t = t(bg)
    bg_comp = _srgb_to_linear(bg_t) if do_linear else bg_t  # background blended in the same space as the splats
    need_depth = render_style == "depth"
    need_normal = render_style in ("normal", "clay") or headlight_shading > 0

    def background_only():  # no splats to rasterize -> just the background + empty mask
        # color and clay both composite the background in the normal path, so they must both show it here;
        # depth / normal are black off-object.
        if render_style in ("color", "clay"):
            img = bg_t.expand(height, width, 3)
        else:
            img = torch.zeros(height, width, 3, device=dev)
        return img.to(idev, idtype), torch.zeros(height, width, device=idev, dtype=idtype)

    if xyz.shape[0] == 0:  # empty input (e.g. all culled by opacity_threshold)
        return background_only()

    eye, target, right, up, fwd = _camera_basis(camera_info, dev)  # all camera state comes from camera_info
    W = torch.stack([right, up, fwd], 0)  # rows = camera axes (world -> camera)
    cam = (xyz - eye) @ W.T
    fov = float(camera_info.get("fov", 0) or 0) or 35.0  # missing / zero fov -> 35 degrees
    zoom = float(camera_info.get("zoom", 1.0) or 1.0)  # three.js digital zoom: scales the focal length
    is_ortho = str(camera_info.get("cameraType", "")).lower().startswith("ortho")
    xc, yc, zc = cam.unbind(-1)

    keep = zc > 1e-2  # drop splats at or behind the camera plane
    xc, yc, zc, rgb, opacity, scale, rot = (a[keep] for a in (xc, yc, zc, rgb, opacity, scale, rot))
    if xc.shape[0] == 0:  # nothing in front of the camera -> background only
        return background_only()
    if render_style == "clay":
        rgb = torch.full_like(rgb, 0.75)  # neutral albedo -> shading shows pure geometry

    f = (min(width, height) / 2) / math.tan(math.radians(fov) / 2) * zoom  # fov over the smaller axis, x camera zoom
    cx0, cy0 = width / 2, height / 2

    # Camera-space 3D covariance per splat: Sigma = (W Rq) diag(scale^2) (W Rq)^T, plus a tiny relative
    # regularizer for a stable inverse (a pixel-size Mip low-pass would over-thicken flat surfels and blur).
    Mw = W[None] @ _quat_to_mat(rot)  # (N,3,3) world -> camera
    cam_cov = (Mw * scale.square()[:, None, :]) @ Mw.transpose(1, 2)
    cam_cov = cam_cov + (cam_cov.diagonal(dim1=-2, dim2=-1).mean(-1) * 1e-3)[:, None, None] * torch.eye(3, device=dev)

    # Perspective-correct weighting: peak of the 3D Gaussian along each pixel ray. Precompute Si, Si@mu, mu^T Si mu.
    mu = torch.stack([xc, yc, zc], -1)
    si = torch.linalg.inv(cam_cov)
    simu = (si @ mu[:, :, None])[:, :, 0]  # (N,3)
    musimu = (mu * simu).sum(-1)  # (N,)
    s00, s01, s02 = si[:, 0, 0], si[:, 0, 1], si[:, 0, 2]
    s11, s12, s22 = si[:, 1, 1], si[:, 1, 2], si[:, 2, 2]
    simu0, simu1, simu2 = simu.unbind(-1)
    if need_normal:  # surfel normal = thinnest axis, oriented toward camera
        nrm = Mw[torch.arange(Mw.shape[0], device=dev), :, scale.argmin(-1)]  # (N,3) camera-space normal
        nrm = nrm * torch.where(nrm[:, 2:3] > 0, -1.0, 1.0)  # flip so nz <= 0 (faces camera)

    # Screen centre (exact) + footprint radius from the affine 2D projection (used only to size the kernel).
    # The image is +y-down, so the projection's y row is unflipped - it matches the splat frame's +Y.
    jm = torch.zeros(xc.shape[0], 2, 3, device=dev)
    if is_ortho:  # parallel projection: screen = s * (xc, yc)
        s = f / float((target - eye).norm().clamp_min(1e-6))  # pixels per world unit at the target plane
        cx, cy = cx0 + s * xc, cy0 + s * yc
        jm[:, 0, 0] = s
        jm[:, 1, 1] = s
    else:  # perspective: screen = f * (xc, yc) / zc
        invz = 1.0 / zc
        cx, cy = cx0 + f * xc * invz, cy0 + f * yc * invz
        jm[:, 0, 0], jm[:, 0, 2] = f * invz, -f * xc * invz.square()
        jm[:, 1, 1], jm[:, 1, 2] = f * invz, -f * yc * invz.square()
    cov2 = jm @ cam_cov @ jm.transpose(1, 2)
    a, b, c = cov2[:, 0, 0], cov2[:, 0, 1], cov2[:, 1, 1]
    max_eig = (a + c) * 0.5 + (((a - c) * 0.5).square() + b * b).clamp_min(0).sqrt()
    radius = 3.0 * max_eig.clamp_min(1e-8).sqrt()
    K = int(min(max(24, min(width, height) // 16), max(2, math.ceil(_quantile(radius, 0.995).item()))))

    # Per-splat kernel size: bucket splats by radius into a coarse ladder of window sizes (global K stays the cap) so
    # small splats (the bulk of it) use a small window.
    levels = [L for L in (16, 64, 256) if L < K] + [K]
    levels_t = torch.tensor(levels, device=dev, dtype=torch.float32)
    grids = []
    for L in levels:
        rng = torch.arange(-L, L + 1, device=dev, dtype=torch.float32)
        gy, gx = torch.meshgrid(rng, rng, indexing="ij")
        grids.append((gx.reshape(-1), gy.reshape(-1)))
    blevel = torch.bucketize(radius * (4.0 / 3.0), levels_t).clamp_(max=len(levels) - 1)  # window >= ~4 sigma

    n = zc.shape[0]
    ns = int(min(256, max(1, n // 1000)))  # depth slabs: 1 per ~1000 splats, capped
    nl = len(levels)
    order = torch.argsort(zc)  # front (small zc) -> back -> defines the slabs
    bounds = torch.linspace(0, n, ns + 1, device=dev).round().long()
    rank = torch.empty(n, dtype=torch.long, device=dev)
    rank[order] = torch.arange(n, device=dev)  # depth rank of each splat
    slab_id = (torch.searchsorted(bounds, rank, right=True) - 1).clamp_(0, ns - 1)
    key = slab_id * nl + blevel  # group by slab, then kernel level (order-free within)
    order = torch.argsort(key)
    key = key[order]

    cxr, cyr = cx[order].round(), cy[order].round()
    s00, s01, s02 = s00[order], s01[order], s02[order]
    s11, s12, s22 = s11[order], s12[order], s22[order]
    s01b, s02b, s12b = s01 * 2, s02 * 2, s12 * 2  # doubled cross terms for the fused quadratic forms
    simu0, simu1, simu2, musimu = simu0[order], simu1[order], simu2[order], musimu[order]
    opacity, rgb = opacity[order], rgb[order]
    zc_o = zc[order] if need_depth else None
    nrm_o = nrm[order] if need_normal else None
    mux_o, muy_o, muz_o = (xc[order], yc[order], zc[order]) if is_ortho else (None, None, None)

    # Pack the per-splat scalars into one tensor so each chunk slices once
    common = [cxr, cyr, s00, s11, s22, s01b, s02b, s12b, opacity]
    pstack = torch.stack(common + ([s02, s12, mux_o, muy_o, muz_o] if is_ortho else [simu0, simu1, simu2, musimu]))

    # Precompute the (slab, level) run table on-GPU and pull it to the CPU once
    starts = torch.cat([torch.zeros(1, dtype=torch.long, device=dev), (key[1:] != key[:-1]).nonzero().flatten() + 1])
    ks = key[starts]
    run_lo = starts.tolist() + [n]
    run_lev = (ks % nl).tolist()
    run_slab = torch.div(ks, nl, rounding_mode="floor").tolist()
    slab_runs = [[] for _ in range(ns)]
    for r in range(len(run_lev)):
        slab_runs[run_slab[r]].append((run_lo[r], run_lo[r + 1], run_lev[r]))

    def splat(lo, hi, ox, oy):  # -> pixel idx (m,M), alpha (m,M); weight = 3D Gaussian peak along each pixel's ray
        cols = pstack[:, lo:hi, None].unbind(0)
        cxr_, cyr_, a00, a11, a22, b01, b02, b12, opa = cols[:9]  # a* = Si components; b* = 2 * cross terms
        px = cxr_ + ox[None, :]
        py = cyr_ + oy[None, :]
        valid = (px >= 0) & (px < width) & (py >= 0) & (py < height)
        if is_ortho:  # parallel ray (0,0,1) from screen point (X, Y, 0); rz constant per splat
            c02, c12, mx, my, mz = cols[9:]
            rx = (px - cx0) / s - mx
            ry = (py - cy0) / s - my
            rz = -mz
            a22rz = a22 * rz
            inx = torch.addcmul(b02 * rz, a00, rx).addcmul_(b01, ry)  # a00 rx + b01 ry + b02 rz
            rSr = torch.addcmul(a22rz * rz, rx, inx).addcmul_(ry, torch.addcmul(b12 * rz, a11, ry))
            dsr = torch.addcmul(a22rz, c02, rx).addcmul_(c12, ry)
            q = torch.addcdiv(rSr, dsr * dsr, a22.clamp_min(1e-12), value=-1).clamp_min_(0)
        else:  # perspective ray (dx,dy,1) through the camera origin
            su0, su1, su2, mus = cols[9:]
            dx, dy = (px - cx0) / f, (py - cy0) / f
            dsid = torch.addcmul(a22, dx, torch.addcmul(b02, a00, dx))  # a22 + dx*(a00 dx + b02)
            dsid = dsid.addcmul_(dy, torch.addcmul(b12, a11, dy))  # + dy*(a11 dy + b12)
            dsid = dsid.addcmul_(b01 * dx, dy)  # + (2 s01) dx dy
            dsimu = torch.addcmul(su2, dx, su0).addcmul_(dy, su1)
            q = torch.addcdiv(mus, dsimu * dsimu, dsid.clamp_min(1e-12), value=-1).clamp_min_(0)
        # Off-screen taps get alpha 0 (via `valid`); their clamped index is then a harmless no-op add.
        alpha = (opa * torch.exp(-0.5 * q) * valid).clamp_(0, 0.999)
        idx = py.long().clamp(0, height - 1) * width + px.long().clamp(0, width - 1)
        return idx, alpha

    # Front-to-back compositing over the depth slabs set up above. Within a slab the accumulation is a pure
    # sum (order-independent), so splats are grouped by kernel level and each level uses its own tight window.
    sharp = sharpen != 1.0  # winner-take-more colour blend: dominant splat shows more
    cacc = torch.zeros((flat, 3), device=dev)
    trans = torch.ones((flat,), device=dev)
    a_buf = torch.zeros((flat,), device=dev)  # sum alpha -> colour/depth/normal weight (alpha-weighted mean)
    tau_buf = torch.zeros((flat,), device=dev)  # sum -ln(1-alpha) -> slab opacity = 1-prod(1-alpha)
    crgb = torch.zeros((flat, 3), device=dev)  # sum alpha^p * rgb -> slab colour
    wbuf = torch.zeros((flat,), device=dev) if sharp else None  # sum alpha^p -> colour normalizer (sharp only)
    dacc = torch.zeros((flat,), device=dev) if need_depth else None  # front-weighted depth
    nacc = torch.zeros((flat, 3), device=dev) if need_normal else None  # front-weighted camera-space normal
    zslab = torch.zeros((flat,), device=dev) if need_depth else None
    nslab = torch.zeros((flat, 3), device=dev) if need_normal else None
    stale = 0  # consecutive fully-occluded checkpoint slabs -> early-out
    for slab_i in range(ns):  # named slab_i so it does not shadow the inverse-covariance tensor `si`
        runs = slab_runs[slab_i]
        if not runs:
            continue
        a_buf.zero_()
        tau_buf.zero_()
        crgb.zero_()
        if sharp:
            wbuf.zero_()
        if need_depth:
            zslab.zero_()
        if need_normal:
            nslab.zero_()
        for r_lo, r_hi, li in runs:  # contiguous same-kernel-level runs in this slab
            ox, oy = grids[li]
            # splats/chunk, bounded by this level's kernel size
            # NOTE(review): the max() makes 2048 a floor, so very large kernels still process 2048 splats
            # per chunk - confirm that this memory/throughput trade-off is intended.
            ch = max(2048, 10_000_000 // ox.shape[0])
            for lo in range(r_lo, r_hi, ch):
                hi = min(lo + ch, r_hi)
                idx, alpha = splat(lo, hi, ox, oy)
                idx, af = idx.reshape(-1), alpha.reshape(-1)
                a_buf.index_add_(0, idx, af)
                tau_buf.index_add_(0, idx, (-torch.log1p(-alpha)).reshape(-1))  # -ln(1-alpha), correct opacity merge
                apw = alpha.pow(sharpen) if sharp else alpha  # bias colour toward the highest-alpha splat
                crgb.index_add_(0, idx, (apw[:, :, None] * rgb[lo:hi, None, :]).reshape(-1, 3))
                if sharp:
                    wbuf.index_add_(0, idx, apw.reshape(-1))
                if need_depth:
                    zslab.index_add_(0, idx, (alpha * zc_o[lo:hi, None]).reshape(-1))
                if need_normal:
                    nslab.index_add_(0, idx, (alpha[:, :, None] * nrm_o[lo:hi, None, :]).reshape(-1, 3))
        slab_a = 1 - torch.exp(-tau_buf)  # 1 - prod(1-alpha): true opacity of the slab's splats
        front = trans * slab_a
        denom = wbuf if sharp else a_buf
        cacc.addcmul_(front[:, None], crgb / denom.clamp_min(1e-8)[:, None])  # cacc += front * (crgb/denom)
        if need_depth or need_normal:
            ainv = a_buf.clamp_min(1e-8)  # alpha-weighted-mean normalizer (depth/normal only)
            if need_depth:
                dacc.addcmul_(front, zslab / ainv)
            if need_normal:
                nacc.addcmul_(front[:, None], nslab / ainv[:, None])
        trans.mul_(1 - slab_a)
        if slab_i % 8 == 7:  # checkpoint every 8 slabs (a per-slab GPU sync would cost more)
            if float(front.max()) < 1e-3:
                # Only a slab that HAS on-screen opacity but contributes nothing is evidence of occlusion.
                # A slab whose splats all fall outside the image (common for the nearest slabs when the
                # camera sits inside the scene) is no evidence at all, so it must not trigger the early-out.
                if float(slab_a.max()) >= 1e-3:
                    stale += 1
                    if stale >= 2:  # two occluded checkpoints running -> the rest are too -> stop
                        break
            else:
                stale = 0

    cov = 1 - trans
    covg = cov.reshape(height, width)
    covm = covg > 0.5 if render_style in ("depth", "normal") else None  # silhouette mask (depth/normal styles only)
    depth_map = (dacc / cov.clamp_min(1e-6)).reshape(height, width) if need_depth else None
    nrm_map = None
    if need_normal:
        # Per-splat surfel normals are jittery, so do a masked blur
        nb = nacc.reshape(height, width, 3).permute(2, 0, 1)[None]
        cb = cov.reshape(1, 1, height, width)
        nb, cb = _gauss_blur(nb, 1.2, dev), _gauss_blur(cb, 1.2, dev)
        normal = (nb / cb.clamp_min(1e-6))[0].permute(1, 2, 0)
        nrm_map = normal / normal.norm(dim=-1, keepdim=True).clamp_min(1e-6)

    if render_style == "depth":  # near = bright, far = dark, 0 off-object
        d = torch.zeros(height, width, device=dev)
        if bool(covm.any()):
            lo, hi = depth_map[covm].min(), depth_map[covm].max()
            d = torch.where(covm, ((hi - depth_map) / (hi - lo).clamp_min(1e-6)).clamp(0, 1), d)
        img = d[:, :, None].expand(height, width, 3)
    elif render_style == "normal":  # OpenGL normal map: +X right, +Y up, +Z to viewer
        enc = (nrm_map * t([1.0, -1.0, -1.0]) * 0.5 + 0.5).clamp(0, 1)
        img = enc * covm[:, :, None]
    else:  # color / clay
        img = cacc.reshape(height, width, 3)
        if render_style == "clay":  # studio key light + ambient -> sculpted matte look
            kl = t([-0.4, -0.7, -0.6])  # key from screen upper-left, angled toward the viewer
            kl = kl / kl.norm()
            hl = (0.5 * (nrm_map * kl).sum(-1) + 0.5).clamp(0, 1)  # half-Lambert: soft terminator, no harsh dark side
            img = img * (0.35 + 0.65 * hl * hl)[:, :, None]  # ambient floor + diffuse key
        elif headlight_shading > 0:  # camera headlight: darken faces turned from view
            k = float(headlight_shading)
            ndotl = (-nrm_map[:, :, 2]).clamp(0, 1)
            # For k > 1/0.6 the factor would go negative and subtract light from the composited background
            # at soft edges; a light factor is floored at 0 (fully dark) instead.
            img = img * (1 - 0.6 * k + 0.6 * k * ndotl).clamp_min(0)[:, :, None]
        img = img.addcmul_(trans.reshape(height, width, 1), bg_comp)
        if do_linear:  # back to display space after linear compositing
            img = _linear_to_srgb(img)
    return img.clamp(0, 1).to(idev, idtype), covg.clamp(0, 1).to(idev, idtype)
class RenderSplat(IO.ComfyNode):
    # Node that rasterizes a gaussian splat batch into images + coverage masks, optionally as a turntable.

    @classmethod
    def define_schema(cls):
        # Declarative description of the node's inputs/outputs; every user-facing text lives here.
        node_description = ("Render a gaussian splat as an image with an anisotropic EWA rasterizer (oriented "
                            "elliptical splats, antialiased, depth-sorted front-to-back). The camera comes from a "
                            "camera_info input (Load / Preview 3D, or a Create Camera Info node); leave it empty to "
                            "auto-frame the splat. Set frames greater than 1 for a turntable batch of images to feed a Video node.")
        frames_tip = ("-1, 0, 1 = single still image; >1 = turntable, the camera orbits over a full "
                      "360 turn (works with any camera_info). Negative value orbits the other way.")
        splat_scale_tip = ("Multiplier on each splat's projected footprint (lower = crisper points, "
                           "higher = softer/fuller surface).")
        sharpen_tip = ("Sharpen overlapping splats: 1.0 = physically-correct blend; higher biases "
                       "each pixel toward its dominant (nearest) splat for crisper texture, without "
                       "shrinking splats or opening gaps. Non-physical above 1.")
        headlight_tip = ("Diffuse shading from a light at the camera (headlight), using the splat surfel "
                         "normals: darkens surfaces that turn away from view to reveal form/curvature. "
                         "0 = flat albedo, 1 = strongest shading.")
        style_tip = ("What the image output shows: color, clay (neutral-albedo shaded), "
                     "depth (near=bright), normal (OpenGL normal map).")
        bg_image_tip = ("Optional background plate composited behind the splat (overrides the solid "
                        "background colour). Resized to the render size; a batch is used per frame, "
                        "a single image for all. color/clay only.")
        camera_tip = ("Camera to render from - a Load3D / Preview3D camera or a Create Camera "
                      "Info node. If empty, the splat is auto-framed from a default 3/4 view.")
        return IO.Schema(
            node_id="RenderSplat",
            display_name="Render Splat",
            search_aliases=["splat to image", "render splat", "gaussian turntable"],
            category="3d/splat",
            description=node_description,
            inputs=[
                IO.Splat.Input("splat"),
                IO.Int.Input("width", default=1024, min=64, max=2048, step=8),
                IO.Int.Input("height", default=1024, min=64, max=2048, step=8),
                IO.Int.Input("frames", default=1, min=-240, max=240, tooltip=frames_tip),
                IO.Float.Input("splat_scale", default=1.0, min=0.1, max=5.0, step=0.05, advanced=True,
                               tooltip=splat_scale_tip),
                IO.Float.Input("sharpen", default=2.0, min=1.0, max=8.0, step=0.5, tooltip=sharpen_tip),
                IO.Float.Input("headlight_shading", default=0.0, min=0.0, max=3.0, step=0.05, advanced=True,
                               tooltip=headlight_tip),
                IO.Float.Input("opacity_threshold", default=0.0, min=0.0, max=1.0, step=0.01, advanced=True,
                               tooltip="Cull gaussians with opacity below this (removes faint floaters)."),
                IO.Combo.Input("render_style", options=["color", "clay", "depth", "normal"], tooltip=style_tip),
                IO.Color.Input("background", default="#000000"),
                IO.Image.Input("bg_image", optional=True, tooltip=bg_image_tip),
                IO.Load3DCamera.Input("camera_info", optional=True, tooltip=camera_tip),
            ],
            outputs=[IO.Image.Output(display_name="image"), IO.Mask.Output(display_name="mask")],
        )

    @classmethod
    def execute(cls, splat, width, height, frames, splat_scale, sharpen, headlight_shading,
                opacity_threshold, background, render_style, camera_info=None, bg_image=None) -> IO.NodeOutput:
        # Renders every batch item for every frame and stacks the results into one image/mask batch.
        solid_bg = _hex_to_rgb(background)
        plates = None
        if bg_image is not None:
            # Resize the plate(s) to the render size: (B,H,W,3) -> (B,3,H,W) -> upscale -> (B,H,W,3).
            chw = bg_image[..., :3].movedim(-1, 1)
            chw = comfy.utils.common_upscale(chw, width, height, "bicubic", "disabled")
            plates = chw.movedim(1, -1).clamp(0, 1)

        n_frames = abs(int(frames)) or 1          # magnitude = frame count (0 -> single still)
        orbit_dir = -1.0 if frames < 0 else 1.0   # sign = orbit direction
        device = comfy.model_management.get_torch_device()
        batch = splat.positions.shape[0]
        total = batch * n_frames
        pbar = comfy.utils.ProgressBar(total) if total > 1 else None

        imgs, masks = [], []
        rendered = 0  # running frame index across the whole batch; selects the per-frame plate
        for i in range(batch):
            xyz, rgb, opacity, scale, rot = _gaussian_item(splat, i, device)
            if opacity_threshold > 0:
                keep = opacity >= opacity_threshold
                xyz, rgb, opacity, scale, rot = (v[keep] for v in (xyz, rgb, opacity, scale, rot))

            base_cam = camera_info
            if base_cam is None:
                # No camera -> default 3/4 view, auto-framed on the (culled) splat.
                if xyz.shape[0]:
                    center = xyz.mean(0)
                    extent = _quantile((xyz - center).norm(dim=-1), 0.99).clamp_min(1e-4)
                else:
                    center = torch.zeros(3, device=device)
                    extent = torch.tensor(1.0, device=device)
                dist = float(extent / (math.tan(math.radians(35.0) / 2) * 0.9))
                base_cam = _orbit_camera_info(35.0, 30.0, dist, 35.0, center, device)

            for fr in range(n_frames):
                if n_frames == 1:
                    cam_fr = base_cam
                else:
                    cam_fr = _orbit_camera_info_yaw(base_cam, orbit_dir * 360.0 * fr / n_frames, device)
                # Per-frame plate (cycled over the plate batch), or the solid colour.
                backdrop = solid_bg if plates is None else plates[rendered % plates.shape[0]]
                img, mask = _render_gaussian(xyz, rgb, opacity, scale, rot, width, height, splat_scale,
                                             backdrop, cam_fr, sharpen=sharpen,
                                             headlight_shading=headlight_shading, render_style=render_style)
                imgs.append(img)
                masks.append(mask)
                rendered += 1
                if pbar is not None:
                    pbar.update(1)
        return IO.NodeOutput(torch.stack(imgs), torch.stack(masks))
class CreateCameraInfo(IO.ComfyNode):  # TODO: move to better file
    # Node that builds a camera_info dict from orbit angles, an explicit position, or a position + quaternion.

    @classmethod
    def define_schema(cls):
        # Declarative description of the node. "mode" is a dynamic combo: each option exposes its own inputs,
        # which arrive in execute() as keys of the `mode` dict.
        return IO.Schema(
            node_id="CreateCameraInfo",
            display_name="Create Camera Info",
            search_aliases=["camera position", "make camera info", "orbit camera", "look at camera"],
            category="3d",
            # The first literal needs its own sentence end: adjacent literals are concatenated verbatim.
            description="Build a camera_info. "
                        "Mode 'orbit' aims with yaw/pitch/distance around the target; "
                        "'look_at' places the camera at world position. Coordinates are the viewer's world space (right-handed,Y-up).",
            inputs=[
                IO.DynamicCombo.Input("mode", options=[
                    IO.DynamicCombo.Option("orbit", [
                        IO.Float.Input("yaw", default=35.0, min=-360.0, max=360.0, step=1.0),
                        IO.Float.Input("pitch", default=30.0, min=-89.0, max=89.0, step=1.0),
                        IO.Float.Input("distance", default=4.0, min=0.01, max=1000.0, step=0.01,
                                       tooltip="Camera distance from the target."),
                    ]),
                    IO.DynamicCombo.Option("look_at", [
                        IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01,
                                       tooltip="Camera position in world space (right-handed, Y-up)."),
                        IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                        IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                    ]),
                    IO.DynamicCombo.Option("quaternion", [
                        IO.Float.Input("position_x", default=4.0, min=-1000.0, max=1000.0, step=0.01,
                                       tooltip="Camera position in world space (right-handed, Y-up)."),
                        IO.Float.Input("position_y", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                        IO.Float.Input("position_z", default=4.0, min=-1000.0, max=1000.0, step=0.01),
                        IO.Float.Input("quat_x", default=0.0, min=-1.0, max=1.0, step=0.001),
                        IO.Float.Input("quat_y", default=0.0, min=-1.0, max=1.0, step=0.001),
                        IO.Float.Input("quat_z", default=0.0, min=-1.0, max=1.0, step=0.001),
                        IO.Float.Input("quat_w", default=1.0, min=-1.0, max=1.0, step=0.001,
                                       tooltip="Camera world-rotation quaternion (three.js: looks down local -Z). Normalized for you."),
                    ]),
                ], tooltip="How to define the camera: orbit angles, an explicit position, or a position + quaternion."),
                IO.Float.Input("target_x", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True,
                               tooltip="Look-at point (orbit pivot / aim). In orbit mode, move it to pan/translate the "
                                       "whole camera. Ignored in quaternion mode. Defaults to the origin."),
                IO.Float.Input("target_y", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True),
                IO.Float.Input("target_z", default=0.0, min=-1000.0, max=1000.0, step=0.01, advanced=True),
                IO.Float.Input("roll", default=0.0, min=-180.0, max=180.0, step=1.0,
                               tooltip="Camera roll about the view axis, degrees."),
                IO.Float.Input("fov", default=35.0, min=1.0, max=120.0, step=1.0,
                               tooltip="Vertical field of view in degrees."),
                IO.Float.Input("zoom", default=1.0, min=0.01, max=100.0, step=0.01,
                               tooltip="Digital zoom (focal-length multiplier). >1 zooms in without moving the camera."),
                IO.Combo.Input("camera_type", options=["perspective", "orthographic"],
                               tooltip="Projection used by Render Splat: perspective (foreshortening) or orthographic (parallel)."),
            ],
            outputs=[IO.Load3DCamera.Output(display_name="camera_info")],
        )

    @classmethod
    def execute(cls, mode, target_x, target_y, target_z, roll, fov, zoom=1.0, camera_type="perspective") -> IO.NodeOutput:
        # `mode` carries the selected option under "mode" plus that option's own inputs.
        dev = comfy.model_management.get_torch_device()
        kind = mode["mode"]
        if kind == "quaternion":  # explicit world position + camera rotation
            # The rotation is taken entirely from the quaternion: target_* and roll are not used in this mode.
            position = [mode["position_x"], mode["position_y"], mode["position_z"]]
            quat = [mode["quat_x"], mode["quat_y"], mode["quat_z"], mode["quat_w"]]
            return IO.NodeOutput(_quat_camera_info(position, quat, fov, dev, zoom=zoom, camera_type=camera_type))
        target = [target_x, target_y, target_z]  # orbit pivot / aim; move it to pan the whole camera
        if kind == "orbit":  # yaw/pitch/distance about the target (world Y-up)
            y, p = math.radians(mode["yaw"]), math.radians(mode["pitch"])
            cy, sy, cp, sp = math.cos(y), math.sin(y), math.cos(p), math.sin(p)
            d = mode["distance"]
            position = [target_x + d * cp * sy, target_y + d * sp, target_z + d * cp * cy]
        else:  # look_at: explicit world-space camera position
            position = [mode["position_x"], mode["position_y"], mode["position_z"]]
        return IO.NodeOutput(_lookat_camera_info(position, target, fov, dev, zoom=zoom, camera_type=camera_type, roll=roll))
class TransformSplat(IO.ComfyNode):
    # Node that applies translate / rotate / per-axis scale to every splat of a SPLAT batch.

    @classmethod
    def define_schema(cls):
        # Declarative description: nine float controls (3 translate, 3 rotate in degrees, 3 scale).
        def translate(axis):
            return IO.Float.Input(f"translate_{axis}", default=0.0, min=-100.0, max=100.0, step=0.01)

        def rotate(axis):
            return IO.Float.Input(f"rotate_{axis}", default=0.0, min=-360.0, max=360.0, step=1.0)

        def scale(axis):
            return IO.Float.Input(f"scale_{axis}", default=1.0, min=0.01, max=100.0, step=0.01)

        return IO.Schema(
            node_id="TransformSplat",
            display_name="Transform Splat",
            search_aliases=["move splat", "rotate splat", "scale splat", "gaussian transform"],
            category="3d/splat",
            description="Translate, rotate, and scale a gaussian splat. "
                        "Non-uniform scale also reshapes every individual splat, slower process.",
            inputs=[
                IO.Splat.Input("splat"),
                translate("x"), translate("y"), translate("z"),
                rotate("x"), rotate("y"), rotate("z"),
                scale("x"), scale("y"), scale("z"),
            ],
            outputs=[IO.Splat.Output(display_name="splat")],
        )

    @classmethod
    def execute(cls, splat, translate_x, translate_y, translate_z,
                rotate_x, rotate_y, rotate_z, scale_x, scale_y, scale_z) -> IO.NodeOutput:
        src_pos = splat.positions
        device, dtype = src_pos.device, src_pos.dtype
        node_q = _euler_to_quat(rotate_x, rotate_y, rotate_z).to(device=device, dtype=dtype)
        node_rot = _quat_to_mat(node_q[None])[0]  # (3, 3) node rotation
        axis_scale = torch.tensor([scale_x, scale_y, scale_z], dtype=dtype, device=device)
        affine = axis_scale[:, None] * node_rot  # diag(D) @ R: per-axis scale after rotation
        offset = torch.tensor([translate_x, translate_y, translate_z], dtype=dtype, device=device)

        positions = src_pos @ affine.T + offset  # rotate, scale per-axis, then translate
        is_uniform = scale_x == scale_y == scale_z
        if not is_uniform:
            # Non-uniform: transform Sigma = A R s^2 R^T A^T, then re-extract axes and scales from it.
            splat_rot = _quat_to_mat(splat.rotations.reshape(-1, 4))  # (M,3,3) per-splat rotation
            variances = splat.scales.reshape(-1, 3).square()
            sigma = (splat_rot * variances[:, None, :]) @ splat_rot.transpose(-1, -2)  # Sigma
            sigma = affine @ sigma @ affine.T  # A Sigma A^T (A broadcast over splats)
            eigvals, axes = torch.linalg.eigh(sigma)  # symmetric -> eigenvalues (asc), orthonormal axes
            handedness = torch.where(torch.linalg.det(axes) < 0, -1.0, 1.0)
            axes = axes * handedness[..., None, None]  # keep a proper rotation
            scales = eigvals.clamp_min(0).sqrt().reshape(splat.scales.shape)
            rotations = _mat_to_quat(axes).reshape(splat.rotations.shape)
        else:
            # Uniform: rotation/scale factor out cleanly.
            scales = splat.scales * scale_x
            rotations = _quat_mul(node_q.expand_as(splat.rotations), splat.rotations)
            rotations = rotations / rotations.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        result = Types.SPLAT(positions, scales, rotations, splat.opacities, splat.sh,
                             counts=getattr(splat, "counts", None))
        return IO.NodeOutput(result)


class GetSplatCount(IO.ComfyNode):
    # Pass-through node that also reports how many splats the batch holds in total.

    @classmethod
    def define_schema(cls):
        return IO.Schema(
            node_id="GetSplatCount",
            display_name="Get Splat Count",
            search_aliases=["splat count", "gaussian count", "number of splats", "splat info"],
            category="3d/splat",
            description="Returns the number of splats summed across the batch.",
            inputs=[IO.Splat.Input("splat")],
            outputs=[
                IO.Splat.Output(display_name="splat"),
                IO.Int.Output(display_name="count"),
            ],
            hidden=[IO.Hidden.unique_id],
        )

    @classmethod
    def execute(cls, splat) -> IO.NodeOutput:
        batch = splat.positions.shape[0]
        count = sum(_real_len(splat, i) for i in range(batch))
        node_id = cls.hidden.unique_id
        if node_id:
            # Show the count inline on the node.
            PromptServer.instance.send_progress_text(f"{count:,} splats", node_id)
        return IO.NodeOutput(splat, count)


def _pad_stack(items, n):
    # Stack a list of (L_i, *tail) tensors into (B, n, *tail), zero-padding each row up to n.
    # dtype/device and the tail shape are taken from the first item.
    first = items[0]
    stacked = first.new_zeros((len(items), n) + tuple(first.shape[1:]))
    for row, item in zip(stacked, items):
        row[:item.shape[0]] = item  # `row` is a view into `stacked`, so this writes in place
    return stacked
+ gs = [g for g in gaussians if g is not None] + if not gs: + raise ValueError("MergeSplat: no gaussians to merge") + b = gs[0].positions.shape[0] + for g in gs: + if g.positions.shape[0] != b: + raise ValueError(f"MergeSplat: batch size mismatch ({b} vs {g.positions.shape[0]}).") + max_k = max(g.sh.shape[2] for g in gs) + + pos_b, scl_b, rot_b, op_b, sh_b, lengths = [], [], [], [], [], [] + for i in range(b): + pos_i, scl_i, rot_i, op_i, sh_i = [], [], [], [], [] + for g in gs: + end = _real_len(g, i) + pos_i.append(g.positions[i, :end]) + scl_i.append(g.scales[i, :end]) + rot_i.append(g.rotations[i, :end]) + op_i.append(g.opacities[i, :end]) + sh = g.sh[i, :end] # (end, K, 3) + if sh.shape[1] < max_k: # zero-pad lower-degree SH + sh = torch.cat([sh, sh.new_zeros(sh.shape[0], max_k - sh.shape[1], sh.shape[2])], dim=1) + sh_i.append(sh) + pos_b.append(torch.cat(pos_i)) + scl_b.append(torch.cat(scl_i)) + rot_b.append(torch.cat(rot_i)) + op_b.append(torch.cat(op_i)) + sh_b.append(torch.cat(sh_i)) + lengths.append(pos_b[-1].shape[0]) + + n = max(lengths) + counts = None + if len(set(lengths)) > 1: + counts = torch.tensor(lengths, device=gs[0].positions.device, dtype=torch.int64) + return Types.SPLAT(_pad_stack(pos_b, n), _pad_stack(scl_b, n), _pad_stack(rot_b, n), + _pad_stack(op_b, n), _pad_stack(sh_b, n), counts=counts) + + +class MergeSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + # Autogrow: a splat0/splat1/... input list that grows a fresh slot as you connect splats. + splats = IO.Autogrow.TemplatePrefix(IO.Splat.Input("splat"), prefix="splat", min=2, max=32) + return IO.Schema( + node_id="MergeSplat", + display_name="Merge Splats", + search_aliases=["union splat", "densify gaussian", "combine splat", "merge gaussian"], + category="3d/splat", + description="Concatenate any number of gaussian splats into one. 
Unioning several decodes of the same " + "latent at different seeds densifies the surface, this can improve surface quality when meshing.", + inputs=[IO.Autogrow.Input("splats", template=splats)], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, splats: IO.Autogrow.Type) -> IO.NodeOutput: + gs = [v for v in splats.values() if v is not None] + if not gs: + raise ValueError("MergeSplat: connect at least one splat.") + return IO.NodeOutput(_merge_gaussians(gs)) + + +def _inverse_covariance(scale, quat): + # Per-splat Sigma^-1 = R diag(1/s^2) R^T. scale (N,3) linear std, quat (N,4) wxyz -> (N,3,3). + q = quat / quat.norm(dim=1, keepdim=True).clamp_min(1e-12) + w, x, y, z = q.unbind(-1) + R = torch.stack([ + 1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y), + 2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x), + 2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y), + ], dim=1).reshape(-1, 3, 3) + inv_s2 = 1.0 / scale.clamp_min(1e-8) ** 2 # (N, 3) + return torch.einsum("nij,nj,nkj->nik", R, inv_s2, R) + + +def _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device, color_sharpen=1.0, chunk=4096, progress=None, + col_dtype=torch.float16): + # Splat each gaussian as its oriented-covariance disk (3-sigma, opacity-weighted) into a density grid, + # plus a colour volume. Each gaussian uses a voxel window sized to its OWN 3-sigma (capped at `kernel`). + # Colour is weighted by w^color_sharpen: >1 biases each voxel toward its dominant gaussian (crisper + # texture). Returns (density, colour numerator, colour normaliser, origin, voxel). 
+ pad = 4.0 * scale.median() + lo = xyz.amin(0) - pad + hi = xyz.amax(0) + pad + voxel = ((hi - lo).max() / res).clamp_min(1e-8) + dx, dy, dz = (torch.ceil((hi - lo) / voxel).long() + 1).tolist() + + sinv = _inverse_covariance(scale, quat) + kreq = torch.ceil(3.0 * scale.amax(-1) / voxel).long().clamp(1, int(kernel)) # per-gaussian half-width + sharp = color_sharpen != 1.0 + vol = torch.zeros(dx * dy * dz, device=device) # Sum(w) density (surface) + colvol = torch.zeros(dx * dy * dz, 3, device=device, dtype=col_dtype) # Sum(w^p * rgb) colour numerator + wcol = torch.zeros(dx * dy * dz, device=device, dtype=col_dtype) if sharp else None # Sum(w^p) normaliser (p>1) + n, done = xyz.shape[0], 0 + for k in range(1, int(kernel) + 1): + sel = (kreq == k).nonzero(as_tuple=True)[0] + if sel.numel() == 0: + continue + rng = torch.arange(-k, k + 1, device=device, dtype=torch.float32) + off = torch.stack(torch.meshgrid(rng, rng, rng, indexing="ij"), -1).reshape(-1, 3) # (M, 3) + for st in range(0, sel.numel(), chunk): + gi = sel[st:st + chunk] + cc = xyz[gi] + idx = ((cc - lo) / voxel).round()[:, None, :] + off[None] # (b, M, 3) voxel coords + d = (lo + idx * voxel) - cc[:, None, :] # world offset to voxel center + quad = torch.einsum("bmi,bij,bmj->bm", d, sinv[gi], d) + wgt = opacity[gi, None] * torch.exp(-0.5 * quad) + wgt = torch.where(quad < 9.0, wgt, torch.zeros_like(wgt)) # clip beyond 3 sigma + ii = idx.long() + ix = ii[..., 0].clamp(0, dx - 1) + iy = ii[..., 1].clamp(0, dy - 1) + iz = ii[..., 2].clamp(0, dz - 1) + flat = (ix * (dy * dz) + iy * dz + iz).reshape(-1) + vol.index_add_(0, flat, wgt.reshape(-1)) + wp = wgt.pow(color_sharpen) if sharp else wgt # winner-take-more colour weight + colvol.index_add_(0, flat, (wp[..., None] * rgb[gi, None, :]).reshape(-1, 3).to(col_dtype)) + if sharp: + wcol.index_add_(0, flat, wp.reshape(-1).to(col_dtype)) + done += gi.numel() + if progress is not None: + progress(min(1.0, done / max(1, n))) + colnorm = (wcol if sharp else 
vol).reshape(dx, dy, dz) # p==1 -> Sum(w) == density + return vol.reshape(dx, dy, dz), colvol.reshape(dx, dy, dz, 3), colnorm, lo.cpu().numpy(), float(voxel) + + +def _connected_components_gpu(faces, nv): + # FastSV connected components: grandparent hooking + shortcutting, ~O(log nv) iterations. + # Returns per-vertex component labels (min node id, not densified). + a = torch.cat([faces[:, 0], faces[:, 1]]) # 2F edge endpoints: (v0,v1),(v1,v2) + b = torch.cat([faces[:, 1], faces[:, 2]]) + f = torch.arange(nv, device=faces.device) + while True: + gp = f[f] # grandparent + ga, gb = gp[a], gp[b] + new = f.clone() + new.scatter_reduce_(0, f[a], gb, "amin", include_self=True) # stochastic hooking onto roots + new.scatter_reduce_(0, f[b], ga, "amin", include_self=True) + new.scatter_reduce_(0, a, gb, "amin", include_self=True) # aggressive hooking, both directions + new.scatter_reduce_(0, b, ga, "amin", include_self=True) + new = new[new] # shortcut (path compression) + if torch.equal(new, f): + return f + f = new + + +def _clean_components_gpu(verts, faces, min_verts, device): + # GPU port of _clean_components: FastSV components + scatter reductions. 
Byte-identical to the numpy path + vt = torch.as_tensor(verts, device=device) + ft = torch.as_tensor(faces, device=device) + nv = vt.shape[0] + _, label = torch.unique(_connected_components_gpu(ft, nv), return_inverse=True) # dense 0..ncomp-1 + ncomp = int(label.max()) + 1 + flabel = label[ft[:, 0]] # component id per face + keep = torch.bincount(label, minlength=ncomp) >= min_verts # per-component vertex-count gate + if int(keep.sum()) > 1: + fcount = torch.bincount(flabel, minlength=ncomp) + largest = int(torch.where(keep, fcount, fcount.new_tensor(-1)).argmax()) + v0, v1, v2 = vt[ft[:, 0]], vt[ft[:, 1]], vt[ft[:, 2]] + cvol = torch.zeros(ncomp, device=device).scatter_add_(0, flabel, (v0 * torch.linalg.cross(v1, v2)).sum(-1)) + idx3 = label[:, None].expand(-1, 3) # per-component vertex bbox + cmin = torch.full((ncomp, 3), float("inf"), device=device).scatter_reduce_(0, idx3, vt, "amin", include_self=True) + cmax = torch.full((ncomp, 3), float("-inf"), device=device).scatter_reduce_(0, idx3, vt, "amax", include_self=True) + tol = 1e-4 * (cmax[largest] - cmin[largest]).max() + enclosed = (cmin >= cmin[largest] - tol).all(1) & (cmax <= cmax[largest] + tol).all(1) + inner = enclosed & (torch.sign(cvol) != torch.sign(cvol[largest])) & (torch.arange(ncomp, device=device) != largest) + keep &= ~inner + faces_k = ft[keep[flabel]] + if faces_k.shape[0] == 0: + return verts[:0], faces[:0] + used = torch.unique(faces_k) # sorted, matches np.unique + remap = torch.full((nv,), -1, dtype=torch.int64, device=device) + remap[used] = torch.arange(used.shape[0], device=device) + return vt[used].cpu().numpy(), remap[faces_k].cpu().numpy() + + +def _clean_components(verts, faces, min_verts, device=None): + # Drop floaters (components with < min_verts vertices) and inner shells - the surfel shell density + # extracts a double wall (outer + inner cavity surface). GPU path (FastSV CC + scatter reductions, ~13x + # faster) when an accelerator has headroom; else numpy/scipy. 
Both produce byte-identical output. + if device is not None and not comfy.model_management.is_device_cpu(device) and \ + comfy.model_management.get_free_memory(device) > 10 * faces.size * 8: # peak ~8.4x faces bytes + return _clean_components_gpu(verts, faces, min_verts, device) + nv = len(verts) + e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0) + ncomp, label = connected_components(coo_matrix((np.ones(len(e)), (e[:, 0], e[:, 1])), shape=(nv, nv)), directed=False) + flabel = label[faces[:, 0]] # component id per face + keep = np.bincount(label, minlength=ncomp) >= min_verts # per-component vertex-count gate + if keep.sum() > 1: + fcount = np.bincount(flabel, minlength=ncomp) + largest = np.where(keep, fcount, -1).argmax() + v0, v1, v2 = verts[faces[:, 0]], verts[faces[:, 1]], verts[faces[:, 2]] + cvol = np.bincount(flabel, weights=np.einsum("ij,ij->i", v0, np.cross(v1, v2)), minlength=ncomp) # 6*signed vol + cidx = np.arange(ncomp) # per-component vertex bbox via ndimage (~6x faster than ufunc.at) + cmin = np.stack([_ndi_minimum(verts[:, a], label, cidx) for a in range(3)], 1) + cmax = np.stack([_ndi_maximum(verts[:, a], label, cidx) for a in range(3)], 1) + tol = 1e-4 * (cmax[largest] - cmin[largest]).max() + enclosed = (cmin >= cmin[largest] - tol).all(1) & (cmax <= cmax[largest] + tol).all(1) + inner = enclosed & (np.sign(cvol) != np.sign(cvol[largest])) & (np.arange(ncomp) != largest) + keep &= ~inner + faces = faces[keep[flabel]] + if len(faces) == 0: + return verts[:0], faces + used = np.unique(faces) + remap = np.full(nv, -1, np.int64) + remap[used] = np.arange(len(used)) + return verts[used], remap[faces] + + +def _surface_nets(vol, level, voxel, origin, device): + # Vectorized Surface Nets: one dual vertex per sign-changing cell at its edge-crossing mean, quads wound CCW-outward. + # Returns verts (V,3), faces (F,3). 
+ vol = vol.to(device=device, dtype=torch.float32) + dx, dy, dz = vol.shape + origin_t = torch.as_tensor(origin, device=device, dtype=torch.float32) + empty = (np.zeros((0, 3), np.float32), np.zeros((0, 3), np.int64)) + if dx < 2 or dy < 2 or dz < 2: + return empty + + # Active = cells whose 8 corners aren't all in/all out. + inside = vol >= level # (dx,dy,dz) bool + cs8 = [inside[ox:ox + dx - 1, oy:oy + dy - 1, oz:oz + dz - 1] + for ox, oy, oz in ((0, 0, 0), (1, 0, 0), (0, 1, 0), (1, 1, 0), + (0, 0, 1), (1, 0, 1), (0, 1, 1), (1, 1, 1))] + any_in = cs8[0] | cs8[1] | cs8[2] | cs8[3] | cs8[4] | cs8[5] | cs8[6] | cs8[7] + all_in = cs8[0] & cs8[1] & cs8[2] & cs8[3] & cs8[4] & cs8[5] & cs8[6] & cs8[7] + active = any_in & ~all_in # (cx,cy,cz) straddling cells + nv = int(active.sum()) + if nv == 0: + return empty + + # Active cells only (a thin shell): each dual vertex = mean of its 12 edges' zero-crossings. + del any_in, all_in, cs8 # corner bool grids no longer needed + ac = active.nonzero(as_tuple=False) # (nv,3) cell min-corner indices + offs = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], + [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]], device=device) + offf = offs.to(torch.float32) + edges = torch.tensor([[0, 1], [0, 2], [0, 4], [1, 3], [1, 5], [2, 3], + [2, 6], [3, 7], [4, 5], [4, 6], [5, 7], [6, 7]], device=device) + e0, e1 = edges[:, 0], edges[:, 1] + oe0, oe1 = offf[e0], offf[e1] # (12,3) edge endpoints + + cstep = 1 << 18 # chunk to bound peak memory (CPU RAM too) + loc = [] + for st in range(0, nv, cstep): + ci = ac[st:st + cstep, None, :] + offs[None] # (m,8,3) + cval = vol[ci[..., 0], ci[..., 1], ci[..., 2]] # (m,8) corner values + csl = cval >= level + v0, v1 = cval[:, e0], cval[:, e1] # (m,12) + cross = (csl[:, e0] != csl[:, e1])[..., None].to(torch.float32) + denom = v1 - v0 + t = torch.where(denom.abs() > 1e-12, (level - v0) / denom, torch.full_like(denom, 0.5)).clamp(0, 1) + pts = torch.lerp(oe0, oe1, t[..., None]) # (m,12,3) local crossings 
(fused interp) + loc.append((pts * cross).sum(1) / cross.sum(1).clamp_min(1.0)) # (m,3) in [0,1] + local = torch.cat(loc, 0) if len(loc) > 1 else loc[0] # (nv,3) + verts = origin_t + (ac.to(torch.float32) + local) * voxel # world space + del loc, local, ac + + vid = torch.full((dx - 1, dy - 1, dz - 1), -1, dtype=torch.int32, device=device) + vid[active] = torch.arange(nv, dtype=torch.int32, device=device) + del active + + # Each straddling grid edge -> one quad from its 4 cells; `sol` (low-end sign) picks outward winding. + faces = [] + + def emit(cr, sol, a, b, d, c): + valid = cr & (a >= 0) & (b >= 0) & (c >= 0) & (d >= 0) + if not bool(valid.any()): + return + a, b, c, d, sol = a[valid], b[valid], c[valid], d[valid], sol[valid] + p2, p4 = torch.where(sol, b, c), torch.where(sol, c, b) # reverse quad winding where ~sol + faces.append(torch.stack([a, p2, d], 1)) + faces.append(torch.stack([a, d, p4], 1)) + + a = inside[0:dx - 1, 1:dy - 1, 1:dz - 1] + emit(a != inside[1:dx, 1:dy - 1, 1:dz - 1], a, + vid[:, 0:dy - 2, 0:dz - 2], vid[:, 1:dy - 1, 0:dz - 2], + vid[:, 1:dy - 1, 1:dz - 1], vid[:, 0:dy - 2, 1:dz - 1]) + a = inside[1:dx - 1, 0:dy - 1, 1:dz - 1] + emit(a != inside[1:dx - 1, 1:dy, 1:dz - 1], a, + vid[0:dx - 2, :, 0:dz - 2], vid[0:dx - 2, :, 1:dz - 1], + vid[1:dx - 1, :, 1:dz - 1], vid[1:dx - 1, :, 0:dz - 2]) + a = inside[1:dx - 1, 1:dy - 1, 0:dz - 1] + emit(a != inside[1:dx - 1, 1:dy - 1, 1:dz], a, + vid[0:dx - 2, 0:dy - 2, :], vid[1:dx - 1, 0:dy - 2, :], + vid[1:dx - 1, 1:dy - 1, :], vid[0:dx - 2, 1:dy - 1, :]) + + if not faces: + return empty + return verts.cpu().numpy().astype(np.float32), torch.cat(faces, 0).cpu().numpy().astype(np.int64) + + +def _otsu_level(values, bins=256): + # Otsu threshold: the density value that best splits inside/outside (max between-class variance). 
+ hist, edges = np.histogram(values, bins=bins) + hist = hist.astype(np.float64) + centers = (edges[:-1] + edges[1:]) * 0.5 + w = np.cumsum(hist) # background-class weight at each split + mu = np.cumsum(hist * centers) + wf = w[-1] - w # foreground-class weight + mb = mu / np.where(w > 0, w, 1.0) + mf = (mu[-1] - mu) / np.where(wf > 0, wf, 1.0) + var_b = w * wf * (mb - mf) ** 2 # between-class variance + var_b[(w <= 0) | (wf <= 0)] = -1.0 + return float(centers[int(np.argmax(var_b))]) + + +def _taubin_smooth(verts, faces, iters, lam=0.5, mu=-0.53): + # Taubin lambda|mu smoothing: low-pass the mesh surface without the shrinkage of a Laplacian blur + # (the mu inflation pass cancels the lambda pass's volume loss). Uniform (umbrella) weights. + if iters <= 0 or len(verts) == 0 or len(faces) == 0: + return verts + nv = len(verts) + e = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [0, 2]]], 0) + e = np.concatenate([e, e[:, ::-1]], 0) # symmetric adjacency + adj = coo_matrix((np.ones(len(e), np.float32), (e[:, 0], e[:, 1])), shape=(nv, nv)).tocsr() + adj.data[:] = 1.0 + deg = np.clip(np.asarray(adj.sum(1)).ravel(), 1.0, None).astype(np.float32)[:, None] + v = verts.astype(np.float32) # fp32 matvec: ~2x faster, sub-micron drift on unit-scale verts + for _ in range(int(iters)): + for fac in (lam, mu): + v = v + np.float32(fac) * ((adj @ v) / deg - v) # fac * (mean(neighbours) - v) + return np.ascontiguousarray(v) + + +def _sample_vertex_colours_gpu(colvol, colnorm, verts, origin, voxel, device): + # GPU trilinear sampling of the colour numerator (3ch) and normaliser (1ch) at vertex grid-coords + # reproduces scipy map_coordinates(order=1, mode='nearest'). Returns col (V,3) numpy. 
def _gaussian_to_mesh(g: Types.SPLAT, i, res, kernel, taubin, level_bias, min_component, min_opacity, color_sharpen, device, progress=None):
    # Mesh one splat: density + colour grids -> Surface Nets -> floater removal -> Taubin smoothing ->
    # volume-sampled colours. Returns (verts, faces int64, colors in [0,1]), or None if no surface.
    #
    # Parameters (as used below):
    #   g, i           - splat batch and the batch index to mesh; only the first _real_len(g, i) splats are read.
    #   res, kernel    - forwarded to _splat_density (grid resolution and max per-gaussian half-width).
    #   taubin         - iteration count forwarded to _taubin_smooth.
    #   level_bias     - multiplier applied to the Otsu level before it is clamped into the data range.
    #   min_component  - forwarded to _clean_components as its vertex-count gate; 0 skips that step.
    #   min_opacity    - gaussians with opacity below this are dropped before anything else.
    #   color_sharpen  - forwarded to _splat_density.
    #   device         - torch device for the heavy steps; individual steps fall back to CPU when the
    #                    free-memory checks below fail.
    #   progress       - optional callable receiving a fraction; called with 0.25*f during the density
    #                    build, then 0.40, 0.55, 0.7 and 1.0. Early `return None` exits skip the rest.
    rep = progress if progress is not None else (lambda *_: None)

    end = _real_len(g, i)
    xyz = g.positions[i, :end].to(device=device, dtype=torch.float32)
    scale = g.scales[i, :end].to(device=device, dtype=torch.float32)
    quat = g.rotations[i, :end].to(device=device, dtype=torch.float32)
    opacity = g.opacities[i, :end].reshape(-1).to(device=device, dtype=torch.float32)
    # Base colour from the first SH coefficient only. NOTE(review): _C0 is defined elsewhere in the
    # file - presumably the degree-0 SH basis constant; confirm.
    rgb = (g.sh[i, :end, 0, :].to(device=device, dtype=torch.float32) * _C0 + 0.5).clamp(0, 1)

    keep = opacity >= min_opacity
    xyz, scale, quat, opacity, rgb = xyz[keep], scale[keep], quat[keep], opacity[keep], rgb[keep]
    if xyz.shape[0] == 0:
        return None

    vol, colvol, colnorm, origin, voxel = _splat_density(xyz, opacity, scale, quat, rgb, res, kernel, device,
                                                         color_sharpen=color_sharpen,
                                                         progress=lambda f: rep(0.25 * f))  # density build: 0 -> 25%
    # Colour: sample on the GPU (grid_sample) when there's headroom
    colour_gpu = not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) > 6 * vol.numel() * 4
    if colour_gpu:
        colvol_cpu, colnorm_cpu = colvol.cpu(), colnorm.half().cpu()  # park colours (fp16) off-GPU during meshing
        colvol_np = colnorm_np = None
    else:
        colvol_np = colvol.cpu().numpy().astype(np.float32)  # Sum(w^p * rgb) colour numerator (fp16 grid -> fp32)
        colnorm_np = colnorm.cpu().numpy().astype(np.float32)  # Sum(w^p) colour normaliser
    del colvol, colnorm  # free the colour grids before iso-surfacing
    rep(0.40)

    vmin, vmax = float(vol.min()), float(vol.max())
    occ = vol[vol > vmax * 1e-3]  # occupied voxels (skip the empty-space peak)
    if occ.numel() == 0:
        return None
    # Otsu picks the inside/outside split principledly; `level_bias` nudges it (1.0 = auto). Clamp strictly
    # inside the data range so a bias can't push the iso off the histogram.
    level = min(max(_otsu_level(occ.cpu().numpy()) * level_bias, vmin + 1e-6 * (vmax - vmin)),
                vmax - 1e-6 * (vmax - vmin))

    # Iso-surface on the accelerator when there's headroom: ~15x faster than CPU, identical output. Chunked
    # Surface Nets peaks at ~3-3.5x the density grid, so fall back to CPU for large grids / tight VRAM.
    sn_dev = device
    if not comfy.model_management.is_device_cpu(device) and comfy.model_management.get_free_memory(device) < 6 * vol.numel() * 4:
        sn_dev = torch.device("cpu")
        vol = vol.cpu()
    verts, faces = _surface_nets(vol, level, voxel, origin, sn_dev)
    del vol
    rep(0.55)
    # Component cleanup gets the original `device` (not sn_dev); it makes its own CPU/GPU decision.
    if min_component > 0 and len(faces) > 0:
        verts, faces = _clean_components(verts, faces, min_component, device)
    if len(verts) == 0 or len(faces) == 0:
        return None

    # Taubin smooths the blocky iso without shrinking it (unlike blurring the density, which rounds features).
    verts = _taubin_smooth(verts, faces, taubin)
    rep(0.7)

    # Colour each vertex from the co-splatted colour volume: trilinearly sample the numerator Sum(w^p*rgb)
    # and normaliser Sum(w^p) separately, then divide. Normalising AFTER interpolation keeps zero-density
    # edge voxels from pulling colours toward black, and matches the gaussians that formed the surface.
    if colour_gpu:
        col = _sample_vertex_colours_gpu(colvol_cpu, colnorm_cpu, verts, origin, voxel, device)
    else:
        coords = ((verts - origin) / voxel).T  # (3, V) grid-index coords, matching volume axes
        num = np.stack([map_coordinates(colvol_np[..., c], coords, order=1, mode="nearest") for c in range(3)], -1)
        den = map_coordinates(colnorm_np, coords, order=1, mode="nearest")
        col = num / np.clip(den, 1e-8, None)[:, None]
    rep(1.0)

    # The unlit material's COLOR_0 is linear and the viewer sRGB-encodes it on output; the splat colours
    # are display (sRGB) values, so convert sRGB -> linear here to land at the same brightness as the splat.
    col = np.clip(col, 0, 1)
    col = np.where(col <= 0.04045, col / 12.92, ((col + 0.055) / 1.055) ** 2.4).astype(np.float32)

    # Splat +Y is glTF's -Y: rotate 180 deg about X (negate Y,Z) to land upright. Proper rotation, so
    # winding is kept; done after colouring (which works in the splat frame).
    verts = np.ascontiguousarray(verts * np.array([1.0, -1.0, -1.0], dtype=np.float32))
    return (torch.from_numpy(verts), torch.from_numpy(faces), torch.from_numpy(col))
" + "Removes detached floater blobs and the inner shell of the double wall."), + IO.Float.Input("min_opacity", default=0.02, min=0.0, max=1.0, step=0.01, advanced=True, + tooltip="Ignore gaussians fainter than this before meshing."), + IO.Float.Input("color_sharpen", default=2.0, min=1.0, max=8.0, step=0.5, + tooltip="Crisp up the vertex texture: 1.0 = physically-correct blend; higher biases " + "each voxel's colour toward its dominant gaussian instead of averaging " + "neighbours (de-smears the texture). Colour only - geometry is unchanged."), + ], + outputs=[IO.Mesh.Output(display_name="mesh")], + ) + + @classmethod + def execute(cls, splat, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen) -> IO.NodeOutput: + device = comfy.model_management.get_torch_device() + b = splat.positions.shape[0] + prec = 1000 # each splat owns a 0..prec block of the bar; its callback advances within that block + pbar = comfy.utils.ProgressBar(b * prec) + + verts_l, faces_l, colors_l = [], [], [] + for i in range(b): + cb = lambda f, base=i * prec: pbar.update_absolute(base + int(min(max(f, 0.0), 1.0) * prec)) + res = _gaussian_to_mesh(splat, i, resolution, kernel, smooth, level, min_component, min_opacity, color_sharpen, device, cb) + if res is None: + logging.warning("SplatToMesh: splat %d produced no surface; emitting an empty mesh.", i) + v, f, c = torch.zeros((0, 3)), torch.zeros((0, 3), dtype=torch.int64), torch.zeros((0, 3)) + else: + v, f, c = res + verts_l.append(v) + faces_l.append(f) + colors_l.append(c) + pbar.update_absolute((i + 1) * prec) # snap to block end (covers empty / early-out splats) + # unlit: render flat (emissive-like) so SaveGLB matches the splat instead of lighting/washing it. 
class GaussianExtension(ComfyExtension):
    # Extension object for this module: reports the gaussian-splat node classes it provides.
    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        # Order is kept exactly as originally listed.
        node_classes = [
            SplatToFile3D,
            File3DToSplat,
            RenderSplat,
            CreateCameraInfo,
            TransformSplat,
            GetSplatCount,
            MergeSplat,
            SplatToMesh,
        ]
        return node_classes
diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py index 9e4873be5..16fff12af 100644 --- a/comfy_extras/nodes_hunyuan.py +++ b/comfy_extras/nodes_hunyuan.py @@ -41,7 +41,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode): return io.Schema( node_id="EmptyHunyuanLatentVideo", display_name="Empty HunyuanVideo 1.0 Latent", - category="latent/video", + category="model/latent/video", inputs=[ io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -81,7 +81,7 @@ class HunyuanVideo15ImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="HunyuanVideo15ImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -132,7 +132,7 @@ class HunyuanVideo15SuperResolution(io.ComfyNode): return io.Schema( node_id="HunyuanVideo15SuperResolution", display_name="Hunyuan Video 1.5 Super Resolution", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -178,7 +178,7 @@ class LatentUpscaleModelLoader(io.ComfyNode): return io.Schema( node_id="LatentUpscaleModelLoader", display_name="Load Latent Upscale Model", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")), ], @@ -227,7 +227,7 @@ class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode): return io.Schema( node_id="HunyuanVideo15LatentUpscaleWithModel", display_name="Hunyuan Video 15 Latent Upscale With Model", - category="latent", + category="model/latent", inputs=[ io.LatentUpscaleModel.Input("model"), io.Latent.Input("samples"), @@ -308,7 +308,7 @@ class HunyuanImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( 
node_id="HunyuanImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Vae.Input("vae"), @@ -359,7 +359,7 @@ class EmptyHunyuanImageLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptyHunyuanImageLatent", - category="latent", + category="model/latent", inputs=[ io.Int.Input("width", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32), io.Int.Input("height", default=2048, min=64, max=nodes.MAX_RESOLUTION, step=32), @@ -384,7 +384,7 @@ class HunyuanRefinerLatent(io.ComfyNode): return io.Schema( node_id="HunyuanRefinerLatent", display_name="Hunyuan Latent Refiner", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index bcd3f9198..60e530626 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -12,7 +12,7 @@ class EmptyLatentHunyuan3Dv2(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="EmptyLatentHunyuan3Dv2", - category="latent/3d", + category="model/latent/3d", inputs=[ IO.Int.Input("resolution", default=3072, min=1, max=8192), IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), @@ -35,7 +35,7 @@ class Hunyuan3Dv2Conditioning(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="Hunyuan3Dv2Conditioning", - category="conditioning/3d_models", + category="model/conditioning/3d_models", inputs=[ IO.ClipVisionOutput.Input("clip_vision_output"), ], @@ -60,7 +60,7 @@ class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="Hunyuan3Dv2ConditioningMultiView", - category="conditioning/3d_models", + category="model/conditioning/3d_models", inputs=[ IO.ClipVisionOutput.Input("front", optional=True), 
IO.ClipVisionOutput.Input("left", optional=True), @@ -97,7 +97,7 @@ class VAEDecodeHunyuan3D(IO.ComfyNode): def define_schema(cls): return IO.Schema( node_id="VAEDecodeHunyuan3D", - category="latent/3d", + category="model/latent/3d", inputs=[ IO.Latent.Input("samples"), IO.Vae.Input("vae"), diff --git a/comfy_extras/nodes_hypernetwork.py b/comfy_extras/nodes_hypernetwork.py index 44a9c6f97..2d3f1bd05 100644 --- a/comfy_extras/nodes_hypernetwork.py +++ b/comfy_extras/nodes_hypernetwork.py @@ -103,7 +103,7 @@ class HypernetworkLoader(IO.ComfyNode): return IO.Schema( node_id="HypernetworkLoader", display_name="Load Hypernetwork", - category="loaders", + category="model/loaders", inputs=[ IO.Model.Input("model"), IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")), diff --git a/comfy_extras/nodes_hypertile.py b/comfy_extras/nodes_hypertile.py index 354d96db1..2a96416be 100644 --- a/comfy_extras/nodes_hypertile.py +++ b/comfy_extras/nodes_hypertile.py @@ -27,7 +27,7 @@ class HyperTile(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="HyperTile", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Int.Input("tile_size", default=256, min=1, max=2048, advanced=True), diff --git a/comfy_extras/nodes_ideogram4.py b/comfy_extras/nodes_ideogram4.py new file mode 100644 index 000000000..d5827db4f --- /dev/null +++ b/comfy_extras/nodes_ideogram4.py @@ -0,0 +1,64 @@ +"""Ideogram 4 sampling helper +""" + +import math + +import torch +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + +_LOGSNR_MIN = -15.0 +_LOGSNR_MAX = 18.0 + + +def _logit_normal_schedule(u, mean, std): + # Reference time (0=noise..1=clean) via the probit/ndtri quantile. 
+ u = torch.as_tensor(u, dtype=torch.float64) + t = 1.0 - torch.special.expit(mean + std * torch.special.ndtri(u)) + t_min = 1.0 / (1.0 + math.exp(0.5 * _LOGSNR_MAX)) + t_max = 1.0 / (1.0 + math.exp(0.5 * _LOGSNR_MIN)) + return t.clamp(t_min, t_max) + + +def ideogram4_sigmas(num_steps, width, height, mu, std): + """Descending sigmas (len num_steps+1) for the reference schedule. + + mu + the resolution term form the logSNR shift; std is the spread. + """ + mean = mu + 0.5 * math.log((width * height) / (512 * 512)) + u = torch.linspace(0.0, 1.0, num_steps + 1, dtype=torch.float64) + sigmas = (1.0 - _logit_normal_schedule(u, mean, std)).flip(0) + sigmas[-1] = 0.0 # clamp leaves ~6e-4; force full denoise + return sigmas.to(torch.float32) + + +class Ideogram4Scheduler(io.ComfyNode): + @classmethod + def define_schema(cls) -> io.Schema: + return io.Schema( + node_id="Ideogram4Scheduler", + display_name="Ideogram 4 Scheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=200), + io.Int.Input("width", default=1024, min=256, max=8192, step=16), + io.Int.Input("height", default=1024, min=256, max=8192, step=16), + io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.05), + io.Float.Input("std", default=1.75, min=0.1, max=5.0, step=0.05), + ], + outputs=[io.Sigmas.Output()], + ) + + @classmethod + def execute(cls, steps, width, height, mu, std) -> io.NodeOutput: + return io.NodeOutput(ideogram4_sigmas(steps, width, height, mu, std)) + + +class Ideogram4Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [Ideogram4Scheduler] + + +async def comfy_entrypoint() -> Ideogram4Extension: + return Ideogram4Extension() diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index fe6008aa3..469a7be55 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -95,7 +95,7 @@ class BoundingBox(IO.ComfyNode): return 
IO.Schema( node_id="PrimitiveBoundingBox", display_name="Bounding Box", - category="utils/primitive", + category="utilities/primitive", inputs=[ IO.Int.Input("x", default=0, min=0, max=MAX_RESOLUTION), IO.Int.Input("y", default=0, min=0, max=MAX_RESOLUTION), diff --git a/comfy_extras/nodes_ip2p.py b/comfy_extras/nodes_ip2p.py index 78f29915d..9c80834f0 100644 --- a/comfy_extras/nodes_ip2p.py +++ b/comfy_extras/nodes_ip2p.py @@ -9,7 +9,7 @@ class InstructPixToPixConditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="InstructPixToPixConditioning", - category="conditioning/instructpix2pix", + category="model/conditioning/instructpix2pix", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index 346c50cde..015965498 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -13,7 +13,7 @@ class Kandinsky5ImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Kandinsky5ImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -71,7 +71,7 @@ class NormalizeVideoLatentStart(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="NormalizeVideoLatentStart", - category="conditioning/video_models", + category="model/conditioning/video_models", description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. 
Helps reduce differences between the starting frames and the rest of the video.", inputs=[ io.Latent.Input("latent"), diff --git a/comfy_extras/nodes_latent.py b/comfy_extras/nodes_latent.py index 8bb368dec..32da9e8ac 100644 --- a/comfy_extras/nodes_latent.py +++ b/comfy_extras/nodes_latent.py @@ -22,7 +22,7 @@ class LatentAdd(io.ComfyNode): return io.Schema( node_id="LatentAdd", search_aliases=["combine latents", "sum latents"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), @@ -49,7 +49,7 @@ class LatentSubtract(io.ComfyNode): return io.Schema( node_id="LatentSubtract", search_aliases=["difference latent", "remove features"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), @@ -76,7 +76,7 @@ class LatentMultiply(io.ComfyNode): return io.Schema( node_id="LatentMultiply", search_aliases=["scale latent", "amplify latent", "latent gain"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples"), io.Float.Input("multiplier", default=1.0, min=-10.0, max=10.0, step=0.01), @@ -100,7 +100,7 @@ class LatentInterpolate(io.ComfyNode): return io.Schema( node_id="LatentInterpolate", search_aliases=["blend latent", "mix latent", "lerp latent", "transition"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), @@ -139,7 +139,7 @@ class LatentConcat(io.ComfyNode): return io.Schema( node_id="LatentConcat", search_aliases=["join latents", "stitch latents"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples1"), io.Latent.Input("samples2"), @@ -179,7 +179,7 @@ class LatentCut(io.ComfyNode): return io.Schema( node_id="LatentCut", search_aliases=["crop latent", "slice latent", "extract region"], - category="latent/advanced", + 
category="model/latent/advanced", inputs=[ io.Latent.Input("samples"), io.Combo.Input("dim", options=["x", "y", "t"]), @@ -220,7 +220,7 @@ class LatentCutToBatch(io.ComfyNode): return io.Schema( node_id="LatentCutToBatch", search_aliases=["slice to batch", "split latent", "tile latent"], - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples"), io.Combo.Input("dim", options=["t", "x", "y"]), @@ -262,7 +262,7 @@ class LatentBatch(io.ComfyNode): return io.Schema( node_id="LatentBatch", search_aliases=["combine latents", "merge latents", "join latents"], - category="latent/batch", + category="model/latent/batch", is_deprecated=True, inputs=[ io.Latent.Input("samples1"), @@ -290,7 +290,7 @@ class LatentBatchSeedBehavior(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentBatchSeedBehavior", - category="latent/advanced", + category="model/latent/advanced", inputs=[ io.Latent.Input("samples"), io.Combo.Input("seed_behavior", options=["random", "fixed"], default="fixed"), @@ -319,7 +319,7 @@ class LatentApplyOperation(io.ComfyNode): return io.Schema( node_id="LatentApplyOperation", search_aliases=["transform latent"], - category="latent/advanced/operations", + category="model/latent/advanced/operations", is_experimental=True, inputs=[ io.Latent.Input("samples"), @@ -343,7 +343,7 @@ class LatentApplyOperationCFG(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentApplyOperationCFG", - category="latent/advanced/operations", + category="model/latent/advanced/operations", is_experimental=True, inputs=[ io.Model.Input("model"), @@ -375,7 +375,7 @@ class LatentOperationTonemapReinhard(io.ComfyNode): return io.Schema( node_id="LatentOperationTonemapReinhard", search_aliases=["hdr latent"], - category="latent/advanced/operations", + category="model/latent/advanced/operations", is_experimental=True, inputs=[ io.Float.Input("multiplier", default=1.0, min=0.0, max=100.0, step=0.01), @@ -410,7 +410,7 
@@ class LatentOperationSharpen(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LatentOperationSharpen", - category="latent/advanced/operations", + category="model/latent/advanced/operations", is_experimental=True, inputs=[ io.Int.Input("sharpen_radius", default=9, min=1, max=31, step=1, advanced=True), @@ -447,7 +447,7 @@ class ReplaceVideoLatentFrames(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ReplaceVideoLatentFrames", - category="latent/batch", + category="model/latent/batch", inputs=[ io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."), io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."), diff --git a/comfy_extras/nodes_load_3d.py b/comfy_extras/nodes_load_3d.py index 9112bdd0a..455897859 100644 --- a/comfy_extras/nodes_load_3d.py +++ b/comfy_extras/nodes_load_3d.py @@ -34,7 +34,7 @@ class Load3D(IO.ComfyNode): essentials_category="Basics", is_experimental=True, inputs=[ - IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model), + IO.Combo.Input("model_file", options=["none"] + sorted(files), upload=IO.UploadType.model), IO.Load3D.Input("image"), IO.Int.Input("width", default=1024, min=1, max=4096, step=1), IO.Int.Input("height", default=1024, min=1, max=4096, step=1), @@ -47,9 +47,18 @@ class Load3D(IO.ComfyNode): IO.Load3DCamera.Output(display_name="camera_info"), IO.Video.Output(display_name="recording_video"), IO.File3DAny.Output(display_name="model_3d"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), ], ) + @classmethod + def validate_inputs(cls, model_file, **kwargs) -> bool | str: + if not model_file or model_file == "none": + return True + if not folder_paths.exists_annotated_filepath(model_file): + return f"Invalid 3D model file: {model_file}" + return True + @classmethod def execute(cls, 
model_file, image, **kwargs) -> IO.NodeOutput: image_path = folder_paths.get_annotated_filepath(image['image']) @@ -68,8 +77,13 @@ class Load3D(IO.ComfyNode): video = InputImpl.VideoFromFile(recording_video_path) - file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file)) - return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video, file_3d) + file_3d = None + mesh_path = "" + if model_file and model_file != "none": + file_3d = Types.File3D(folder_paths.get_annotated_filepath(model_file)) + mesh_path = model_file + model_3d_info = image.get('model_3d_info', []) + return IO.NodeOutput(output_image, output_mask, mesh_path, normal_image, image['camera_info'], video, file_3d, model_3d_info) process = execute # TODO: remove @@ -118,12 +132,200 @@ class Preview3D(IO.ComfyNode): process = execute # TODO: remove +class Preview3DAdvanced(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Preview3DAdvanced", + display_name="Preview 3D (Advanced)", + search_aliases=["preview 3d", "3d viewer", "view mesh", "frame 3d", "3d camera output"], + category="3d", + is_experimental=True, + is_output_node=True, + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[ + IO.File3DGLB, + IO.File3DGLTF, + IO.File3DFBX, + IO.File3DOBJ, + IO.File3DSTL, + IO.File3DUSDZ, + IO.File3DAny, + ], + tooltip="3D model file from an upstream 3D node.", + ), + IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), + IO.Load3D.Input("viewport_state"), + IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DAny.Output(display_name="model_3d"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Int.Output(display_name="width"), + 
IO.Int.Output(display_name="height"), + ], + ) + + @classmethod + def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: + filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_3d.format}" + model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename)) + + camera_info_input = kwargs.get("camera_info", None) + camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info'] + model_3d_info_input = kwargs.get("model_3d_info", None) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) + return IO.NodeOutput( + model_3d, + model_3d_info, + camera_info, + width, + height, + ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info), + ) + + +class PreviewGaussianSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewGaussianSplat", + display_name="Preview Splat", + category="3d", + is_experimental=True, + is_output_node=True, + search_aliases=[ + "view splat", + "view gaussian", + "view gaussian splat", + "preview gaussian", + "preview gaussian splat", + "view 3dgs", + "preview 3dgs", + "preview ply", + "preview spz", + "preview splat", + "preview ksplat", + ], + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[ + IO.File3DSplatAny, + IO.File3DPLY, + IO.File3DSPLAT, + IO.File3DSPZ, + IO.File3DKSPLAT, + ], + tooltip="A gaussian splat 3D file.", + ), + IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), + IO.Load3D.Input("viewport_state"), + IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DSplatAny.Output(display_name="model_3d"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), + 
IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + ], + ) + + @classmethod + def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: + filename = f"preview_splat_{uuid.uuid4().hex}.{model_3d.format}" + model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename)) + + camera_info_input = kwargs.get("camera_info", None) + camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info'] + model_3d_info_input = kwargs.get("model_3d_info", None) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) + return IO.NodeOutput( + model_3d, + model_3d_info, + camera_info, + width, + height, + ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info), + ) + + +class PreviewPointCloud(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewPointCloud", + display_name="Preview Point Cloud", + category="3d", + is_experimental=True, + is_output_node=True, + search_aliases=[ + "view point cloud", + "view pointcloud", + "preview point cloud", + "preview pointcloud", + "preview ply", + ], + inputs=[ + IO.MultiType.Input( + "model_3d", + types=[ + IO.File3DPointCloudAny, + IO.File3DPLY, + ], + tooltip="Point cloud file (.ply)", + ), + IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True), + IO.Load3D.Input("viewport_state"), + IO.Load3DCamera.Input("camera_info", optional=True, advanced=True), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.File3DPointCloudAny.Output(display_name="model_3d"), + IO.Load3DModelInfo.Output(display_name="model_3d_info"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Int.Output(display_name="width"), + IO.Int.Output(display_name="height"), + ], + ) + + @classmethod + def execute(cls, 
model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput: + filename = f"preview_pointcloud_{uuid.uuid4().hex}.{model_3d.format}" + model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename)) + + camera_info_input = kwargs.get("camera_info", None) + camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info'] + model_3d_info_input = kwargs.get("model_3d_info", None) + model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', []) + return IO.NodeOutput( + model_3d, + model_3d_info, + camera_info, + width, + height, + ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info), + ) + + class Load3DExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ Load3D, Preview3D, + Preview3DAdvanced, + PreviewGaussianSplat, + PreviewPointCloud, ] diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py index 92507f1fc..95f6ab848 100644 --- a/comfy_extras/nodes_logic.py +++ b/comfy_extras/nodes_logic.py @@ -13,7 +13,7 @@ class NotNode(io.ComfyNode): return io.Schema( node_id="ComfyNotNode", display_name="Not", - category="utils/logic", + category="utilities/logic", description="Logical NOT operation. Returns true if the value is falsy. Uses Python's rules for truthiness.", search_aliases=["invert", "toggle", "negate", "flip boolean"], inputs=[ @@ -40,7 +40,7 @@ class AndNode(io.ComfyNode): return io.Schema( node_id="ComfyAndNode", display_name="And", - category="utils/logic", + category="utilities/logic", description="Logical AND operation. Returns true if all of the values are truthy. Uses Python's rules for truthiness.", search_aliases=["all", "every"], inputs=[ @@ -67,7 +67,7 @@ class OrNode(io.ComfyNode): return io.Schema( node_id="ComfyOrNode", display_name="Or", - category="utils/logic", + category="utilities/logic", description="Logical OR operation. 
Returns true if any of the values are truthy. Uses Python's rules for truthiness.", search_aliases=["any", "some"], inputs=[ @@ -90,7 +90,7 @@ class SwitchNode(io.ComfyNode): return io.Schema( node_id="ComfySwitchNode", display_name="Switch", - category="utils/logic", + category="utilities/logic", is_experimental=True, inputs=[ io.Boolean.Input("switch"), @@ -121,7 +121,7 @@ class SoftSwitchNode(io.ComfyNode): return io.Schema( node_id="ComfySoftSwitchNode", display_name="Soft Switch", - category="utils/logic", + category="utilities/logic", is_experimental=True, inputs=[ io.Boolean.Input("switch"), @@ -176,7 +176,7 @@ class CustomComboNode(io.ComfyNode): return io.Schema( node_id="CustomCombo", display_name="Custom Combo", - category="utils", + category="utilities", is_experimental=True, inputs=[io.Combo.Input("choice", options=[])], outputs=[ @@ -211,7 +211,7 @@ class DCTestNode(io.ComfyNode): return io.Schema( node_id="DCTestNode", display_name="DCTest", - category="utils/logic", + category="utilities/logic", is_output_node=True, inputs=[io.DynamicCombo.Input("combo", options=[ io.DynamicCombo.Option("option1", [io.String.Input("string")]), @@ -249,7 +249,7 @@ class AutogrowNamesTestNode(io.ComfyNode): return io.Schema( node_id="AutogrowNamesTestNode", display_name="AutogrowNamesTest", - category="utils/logic", + category="utilities/logic", inputs=[ _io.Autogrow.Input("autogrow", template=template) ], @@ -269,7 +269,7 @@ class AutogrowPrefixTestNode(io.ComfyNode): return io.Schema( node_id="AutogrowPrefixTestNode", display_name="AutogrowPrefixTest", - category="utils/logic", + category="utilities/logic", inputs=[ _io.Autogrow.Input("autogrow", template=template) ], @@ -288,7 +288,7 @@ class ComboOutputTestNode(io.ComfyNode): return io.Schema( node_id="ComboOptionTestNode", display_name="ComboOptionTest", - category="utils/logic", + category="utilities/logic", inputs=[io.Combo.Input("combo", options=["option1", "option2", "option3"]), io.Combo.Input("combo2", 
options=["option4", "option5", "option6"])], outputs=[io.Combo.Output(), io.Combo.Output()], @@ -305,7 +305,7 @@ class ConvertStringToComboNode(io.ComfyNode): node_id="ConvertStringToComboNode", search_aliases=["string to dropdown", "text to combo"], display_name="Convert String to Combo", - category="utils/logic", + category="utilities/logic", inputs=[io.String.Input("string")], outputs=[io.Combo.Output()], ) @@ -321,7 +321,7 @@ class InvertBooleanNode(io.ComfyNode): node_id="InvertBooleanNode", search_aliases=["not", "toggle", "negate", "flip boolean"], display_name="Invert Boolean", - category="utils/logic", + category="utilities/logic", inputs=[io.Boolean.Input("boolean")], outputs=[io.Boolean.Output()], ) diff --git a/comfy_extras/nodes_lora_debug.py b/comfy_extras/nodes_lora_debug.py index 937a0fbfb..3f68064e5 100644 --- a/comfy_extras/nodes_lora_debug.py +++ b/comfy_extras/nodes_lora_debug.py @@ -30,7 +30,7 @@ class LoraLoaderBypass: OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") FUNCTION = "load_lora" - CATEGORY = "loaders" + CATEGORY = "model/loaders" DESCRIPTION = "Apply LoRA in bypass mode. Unlike regular LoRA, this doesn't modify model weights - instead it injects the LoRA computation during forward pass. Useful for training scenarios." 
EXPERIMENTAL = True diff --git a/comfy_extras/nodes_lotus.py b/comfy_extras/nodes_lotus.py index 9f62ba2bf..9fe4c5c7b 100644 --- a/comfy_extras/nodes_lotus.py +++ b/comfy_extras/nodes_lotus.py @@ -10,7 +10,7 @@ class LotusConditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LotusConditioning", - category="conditioning/lotus", + category="model/conditioning/lotus", inputs=[], outputs=[io.Conditioning.Output(display_name="conditioning")], ) diff --git a/comfy_extras/nodes_lt.py b/comfy_extras/nodes_lt.py index 48d75c9e5..6d6078abe 100644 --- a/comfy_extras/nodes_lt.py +++ b/comfy_extras/nodes_lt.py @@ -25,7 +25,7 @@ class GetICLoRAParameters(io.ComfyNode): display_name="Get IC-LoRA Parameters", description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded " "model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).", - category="conditioning/video_models", + category="model/conditioning/video_models", search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"], inputs=[ io.Model.Input( @@ -62,7 +62,7 @@ class EmptyLTXVLatentVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptyLTXVLatentVideo", - category="latent/video/ltxv", + category="model/latent/video/ltxv", inputs=[ io.Int.Input("width", default=768, min=64, max=nodes.MAX_RESOLUTION, step=32), io.Int.Input("height", default=512, min=64, max=nodes.MAX_RESOLUTION, step=32), @@ -86,7 +86,7 @@ class LTXVImgToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVImgToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -131,7 +131,7 @@ class LTXVImgToVideoInplace(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVImgToVideoInplace", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ 
io.Vae.Input("vae"), io.Image.Input("image"), @@ -251,7 +251,7 @@ class LTXVAddGuide(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVAddGuide", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -498,7 +498,7 @@ class LTXVCropGuides(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVCropGuides", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -542,7 +542,7 @@ class LTXVConditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVConditioning", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -611,7 +611,7 @@ class LTXVScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Int.Input("steps", default=20, min=1, max=10000), io.Float.Input("max_shift", default=2.05, min=0.0, max=100.0, step=0.01), @@ -746,7 +746,7 @@ class LTXVConcatAVLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVConcatAVLatent", - category="latent/video/ltxv", + category="model/latent/video/ltxv", inputs=[ io.Latent.Input("video_latent"), io.Latent.Input("audio_latent"), @@ -781,7 +781,7 @@ class LTXVSeparateAVLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="LTXVSeparateAVLatent", - category="latent/video/ltxv", + category="model/latent/video/ltxv", description="LTXV Separate AV Latent", inputs=[ io.Latent.Input("av_latent"), @@ -814,7 +814,7 @@ class LTXVReferenceAudio(io.ComfyNode): return io.Schema( node_id="LTXVReferenceAudio", display_name="LTXV Reference Audio (ID-LoRA)", - category="conditioning/audio", + 
category="model/conditioning/audio", description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).", inputs=[ io.Model.Input("model"), diff --git a/comfy_extras/nodes_lt_audio.py b/comfy_extras/nodes_lt_audio.py index 51ddf584a..052186083 100644 --- a/comfy_extras/nodes_lt_audio.py +++ b/comfy_extras/nodes_lt_audio.py @@ -12,7 +12,7 @@ class LTXVAudioVAELoader(io.ComfyNode): return io.Schema( node_id="LTXVAudioVAELoader", display_name="Load LTXV Audio VAE", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input( "ckpt_name", @@ -40,7 +40,7 @@ class LTXVAudioVAEEncode(VAEEncodeAudio): return io.Schema( node_id="LTXVAudioVAEEncode", display_name="LTXV Audio VAE Encode", - category="latent/audio", + category="model/latent/audio", inputs=[ io.Audio.Input("audio", tooltip="The audio to be encoded."), io.Vae.Input( @@ -63,7 +63,7 @@ class LTXVAudioVAEDecode(io.ComfyNode): return io.Schema( node_id="LTXVAudioVAEDecode", display_name="LTXV Audio VAE Decode", - category="latent/audio", + category="model/latent/audio", inputs=[ io.Latent.Input("samples", tooltip="The latent to be decoded."), io.Vae.Input( @@ -96,7 +96,7 @@ class LTXVEmptyLatentAudio(io.ComfyNode): return io.Schema( node_id="LTXVEmptyLatentAudio", display_name="LTXV Empty Latent Audio", - category="latent/audio", + category="model/latent/audio", inputs=[ io.Int.Input( "frames_number", diff --git a/comfy_extras/nodes_lt_upsampler.py b/comfy_extras/nodes_lt_upsampler.py index f99ba13fb..be9a36e69 100644 --- a/comfy_extras/nodes_lt_upsampler.py +++ b/comfy_extras/nodes_lt_upsampler.py @@ -1,32 +1,32 @@ from comfy import model_management +from comfy_api.latest import ComfyExtension, IO +from typing_extensions import override import math -class LTXVLatentUpsampler: + +class 
LTXVLatentUpsampler(IO.ComfyNode): """ Upsamples a video latent by a factor of 2. """ @classmethod - def INPUT_TYPES(s): - return { - "required": { - "samples": ("LATENT",), - "upscale_model": ("LATENT_UPSCALE_MODEL",), - "vae": ("VAE",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="LTXVLatentUpsampler", + category="model/latent/video", + is_experimental=True, + inputs=[ + IO.Latent.Input("samples"), + IO.LatentUpscaleModel.Input("upscale_model"), + IO.Vae.Input("vae"), + ], + outputs=[ + IO.Latent.Output(), + ], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "upsample_latent" - CATEGORY = "latent/video" - EXPERIMENTAL = True - - def upsample_latent( - self, - samples: dict, - upscale_model, - vae, - ) -> tuple: + @classmethod + def execute(cls, samples, upscale_model, vae) -> IO.NodeOutput: """ Upsample the input latent using the provided model. @@ -34,7 +34,6 @@ class LTXVLatentUpsampler: samples (dict): Input latent samples upscale_model (LatentUpsampler): Loaded upscale model vae: VAE model for normalization - auto_tiling (bool): Whether to automatically tile the input for processing Returns: tuple: Tuple containing the upsampled latent @@ -67,9 +66,16 @@ class LTXVLatentUpsampler: return_dict = samples.copy() return_dict["samples"] = upsampled_latents return_dict.pop("noise_mask", None) - return (return_dict,) + return IO.NodeOutput(return_dict) + + upsample_latent = execute # TODO: remove -NODE_CLASS_MAPPINGS = { - "LTXVLatentUpsampler": LTXVLatentUpsampler, -} +class LTXVLatentUpsamplerExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [LTXVLatentUpsampler] + + +async def comfy_entrypoint() -> LTXVLatentUpsamplerExtension: + return LTXVLatentUpsamplerExtension() diff --git a/comfy_extras/nodes_lumina2.py b/comfy_extras/nodes_lumina2.py index b35ab8b7d..c060a86a0 100644 --- a/comfy_extras/nodes_lumina2.py +++ b/comfy_extras/nodes_lumina2.py @@ -81,7 +81,7 @@ class 
CLIPTextEncodeLumina2(io.ComfyNode): node_id="CLIPTextEncodeLumina2", search_aliases=["lumina prompt"], display_name="CLIP Text Encode for Lumina2", - category="conditioning", + category="model/conditioning", description="Encodes a system prompt and a user prompt using a CLIP model into an embedding " "that can be used to guide the diffusion model towards generating specific images.", inputs=[ diff --git a/comfy_extras/nodes_mask.py b/comfy_extras/nodes_mask.py index d15f1f4e7..52484697a 100644 --- a/comfy_extras/nodes_mask.py +++ b/comfy_extras/nodes_mask.py @@ -53,7 +53,7 @@ class LatentCompositeMasked(IO.ComfyNode): return IO.Schema( node_id="LatentCompositeMasked", search_aliases=["overlay latent", "layer latent", "paste latent", "inpaint latent"], - category="latent", + category="model/latent", inputs=[ IO.Latent.Input("destination"), IO.Latent.Input("source"), diff --git a/comfy_extras/nodes_math.py b/comfy_extras/nodes_math.py index 0040d1a92..0883c65ac 100644 --- a/comfy_extras/nodes_math.py +++ b/comfy_extras/nodes_math.py @@ -69,7 +69,7 @@ class MathExpressionNode(io.ComfyNode): return io.Schema( node_id="ComfyMathExpression", display_name="Math Expression", - category="utils", + category="utilities", search_aliases=[ "expression", "formula", "calculate", "calculator", "eval", "math", @@ -102,11 +102,18 @@ class MathExpressionNode(io.ComfyNode): f"Math Expression '{expression}' must evaluate to a numeric result, " f"got {type(result).__name__}: {result!r}" ) - if not math.isfinite(result): + try: + float_result = float(result) + except OverflowError: + raise ValueError( + f"Math Expression '{expression}' produced a result too large to " + f"represent as a float: {result}" + ) from None + if not math.isfinite(float_result): raise ValueError( f"Math Expression '{expression}' produced a non-finite result: {result}" ) - return io.NodeOutput(float(result), int(result), bool(result)) + return io.NodeOutput(float_result, int(result), bool(result)) class 
MathExtension(ComfyExtension): diff --git a/comfy_extras/nodes_mediapipe.py b/comfy_extras/nodes_mediapipe.py index 32dc22de3..343d88dbb 100644 --- a/comfy_extras/nodes_mediapipe.py +++ b/comfy_extras/nodes_mediapipe.py @@ -205,7 +205,7 @@ class LoadMediaPipeFaceLandmarker(io.ComfyNode): node_id="LoadMediaPipeFaceLandmarker", search_aliases=["face", "facial", "mediapipe", "face landmark", "face mesh", "blazeface", "face detection"], display_name="Load Face Detection Model (MediaPipe)", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("detection"), tooltip="Face detection model from models/detection/."), diff --git a/comfy_extras/nodes_mochi.py b/comfy_extras/nodes_mochi.py index d750194fc..3dcea6ab3 100644 --- a/comfy_extras/nodes_mochi.py +++ b/comfy_extras/nodes_mochi.py @@ -10,7 +10,7 @@ class EmptyMochiLatentVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptyMochiLatentVideo", - category="latent/video", + category="model/latent/video", inputs=[ io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), diff --git a/comfy_extras/nodes_model_downscale.py b/comfy_extras/nodes_model_downscale.py index 24d47a903..817542452 100644 --- a/comfy_extras/nodes_model_downscale.py +++ b/comfy_extras/nodes_model_downscale.py @@ -10,7 +10,7 @@ class PatchModelAddDownscale(io.ComfyNode): return io.Schema( node_id="PatchModelAddDownscale", display_name="PatchModelAddDownscale (Kohya Deep Shrink)", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Int.Input("block_number", default=3, min=1, max=32, step=1, advanced=True), diff --git a/comfy_extras/nodes_model_patch.py b/comfy_extras/nodes_model_patch.py index 748559a6b..bdccbf8c4 100644 --- a/comfy_extras/nodes_model_patch.py +++ b/comfy_extras/nodes_model_patch.py @@ -548,7 
+548,7 @@ class USOStyleReference: FUNCTION = "apply_patch" EXPERIMENTAL = True - CATEGORY = "advanced/model_patches/flux" + CATEGORY = "model/patch/flux" def apply_patch(self, model, model_patch, clip_vision_output): encoded_image = torch.stack((clip_vision_output.all_hidden_states[:, -20], clip_vision_output.all_hidden_states[:, -11], clip_vision_output.penultimate_hidden_states)) @@ -594,7 +594,7 @@ class SUPIRApply(io.ComfyNode): def define_schema(cls) -> io.Schema: return io.Schema( node_id="SUPIRApply", - category="model_patches/supir", + category="model/patch/supir", is_experimental=True, inputs=[ io.Model.Input("model"), diff --git a/comfy_extras/nodes_moge.py b/comfy_extras/nodes_moge.py index 79aec5d7f..422949531 100644 --- a/comfy_extras/nodes_moge.py +++ b/comfy_extras/nodes_moge.py @@ -78,7 +78,7 @@ class LoadMoGeModel(io.ComfyNode): return io.Schema( node_id="LoadMoGeModel", display_name="Load MoGe Model", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("geometry_estimation")), ], @@ -104,7 +104,7 @@ class MoGePanoramaInference(io.ComfyNode): node_id="MoGePanoramaInference", search_aliases=["moge", "panorama", "depth", "geometry", "depth estimation", "geometry estimation"], display_name="Run MoGe Panorama Inference", - category="image/geometry_estimation", + category="image/geometry estimation", description="Run MoGe on an equirectangular panorama by splitting it into 12 perspective views, running inference on each, and merging the results into a single depth map.", inputs=[ MoGeModelType.Input("moge_model"), @@ -226,7 +226,7 @@ class MoGeInference(io.ComfyNode): search_aliases=["moge", "depth", "geometry", "depth estimation", "geometry estimation"], display_name="Run MoGe Inference", description="Run MoGe on a single image to estimate depth and geometry.", - category="image/geometry_estimation", + category="image/geometry estimation", inputs=[ 
MoGeModelType.Input("moge_model"), io.Image.Input("image"), @@ -283,7 +283,7 @@ class MoGeRender(io.ComfyNode): search_aliases=["moge", "render", "geometry", "depth", "normal"], display_name="Render MoGe Geometry", description="Render a depth map or normal map from geometry data", - category="image/geometry_estimation", + category="image/geometry estimation", inputs=[ MoGeGeometry.Input("moge_geometry"), io.Combo.Input("output", options=["depth", "depth_colored", "normal_opengl", "normal_directx", "mask"], default="depth", @@ -350,7 +350,7 @@ class MoGePointMapToMesh(io.ComfyNode): search_aliases=["moge", "mesh", "geometry", "point map"], display_name="Convert MoGe Point Map to Mesh", description="Convert a MoGe point map into a 3D mesh.", - category="image/geometry_estimation", + category="image/geometry estimation", inputs=[ MoGeGeometry.Input("moge_geometry"), io.Int.Input("batch_index", default=0, min=0, max=4096, diff --git a/comfy_extras/nodes_number_convert.py b/comfy_extras/nodes_number_convert.py index 01593b6e6..d7e557e95 100644 --- a/comfy_extras/nodes_number_convert.py +++ b/comfy_extras/nodes_number_convert.py @@ -20,7 +20,7 @@ class NumberConvertNode(io.ComfyNode): return io.Schema( node_id="ComfyNumberConvert", display_name="Convert Number", - category="utils", + category="utilities", search_aliases=[ "int to float", "float to int", "number convert", "int2float", "float2int", "cast", "parse number", diff --git a/comfy_extras/nodes_optimalsteps.py b/comfy_extras/nodes_optimalsteps.py index 5beeaa7db..19629790f 100644 --- a/comfy_extras/nodes_optimalsteps.py +++ b/comfy_extras/nodes_optimalsteps.py @@ -31,7 +31,7 @@ class OptimalStepsScheduler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="OptimalStepsScheduler", - category="sampling/schedulers", + category="model/sampling/schedulers", inputs=[ io.Combo.Input("model_type", options=["FLUX", "Wan", "Chroma"]), io.Int.Input("steps", default=20, min=3, max=1000), diff --git 
a/comfy_extras/nodes_pag.py b/comfy_extras/nodes_pag.py index 79fea5f0c..c875e1e06 100644 --- a/comfy_extras/nodes_pag.py +++ b/comfy_extras/nodes_pag.py @@ -15,7 +15,7 @@ class PerturbedAttentionGuidance(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="PerturbedAttentionGuidance", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Float.Input("scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01), diff --git a/comfy_extras/nodes_pid.py b/comfy_extras/nodes_pid.py index 811b9ae8e..71855254e 100644 --- a/comfy_extras/nodes_pid.py +++ b/comfy_extras/nodes_pid.py @@ -21,8 +21,8 @@ class PiDConditioning(io.ComfyNode): inputs=[ io.Conditioning.Input("positive"), io.Latent.Input("latent", tooltip="latent (from VAEEncode or a KSampler)."), - io.Combo.Input("latent_format", options=["flux", "sd3"], default="flux", - tooltip="Flux1 and Flux2 latents auto-detected from channel dim, sd3 has to be selected manually."), + io.Combo.Input("latent_format", options=["flux", "sd3", "sdxl", "qwenimage"], default="flux", + tooltip="Flux1 (16-ch) and Flux2 (128-ch) latents are auto-detected from channel dim under 'flux'. For SD3 (16-ch), SDXL (4-ch), or QwenImage (16-ch), select manually."), io.Float.Input( "degrade_sigma", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="0 = clean latent. 
Increase to denoise corrupted latent outputs.", @@ -36,9 +36,17 @@ class PiDConditioning(io.ComfyNode): samples = latent["samples"] if latent_format == "flux": fmt_cls = comfy.latent_formats.Flux2 if samples.shape[1] == 128 else comfy.latent_formats.Flux - else: + elif latent_format == "sd3": fmt_cls = comfy.latent_formats.SD3 + elif latent_format == "sdxl": + fmt_cls = comfy.latent_formats.SDXL + elif latent_format == "qwenimage": + fmt_cls = comfy.latent_formats.Wan21 + else: + raise ValueError(f"Unknown latent_format: {latent_format}") lq_latent = fmt_cls().process_in(samples) + if lq_latent.ndim == 5: + lq_latent = lq_latent[:, :, 0] sigma_t = torch.tensor([float(degrade_sigma)], dtype=torch.float32) return io.NodeOutput(node_helpers.conditioning_set_values( positive, {"lq_latent": lq_latent, "degrade_sigma": sigma_t}, diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index a25db277c..3e440433e 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -616,7 +616,7 @@ class BatchLatentsNode(io.ComfyNode): node_id="BatchLatentsNode", search_aliases=["combine latents", "stack latents", "merge latents"], display_name="Batch Latents", - category="latent", + category="model/latent", inputs=[ io.Autogrow.Input("latents", template=autogrow_template) ], diff --git a/comfy_extras/nodes_preview_any.py b/comfy_extras/nodes_preview_any.py index 17e25d514..1070a69d0 100644 --- a/comfy_extras/nodes_preview_any.py +++ b/comfy_extras/nodes_preview_any.py @@ -16,7 +16,7 @@ class PreviewAny(): FUNCTION = "main" OUTPUT_NODE = True - CATEGORY = "utils" + CATEGORY = "utilities" SEARCH_ALIASES = ["show output", "inspect", "debug", "print value", "show text"] def main(self, source=None): diff --git a/comfy_extras/nodes_primitive.py b/comfy_extras/nodes_primitive.py index 33373266b..c44b09098 100644 --- a/comfy_extras/nodes_primitive.py +++ b/comfy_extras/nodes_primitive.py @@ -11,7 +11,7 @@ class 
String(io.ComfyNode): node_id="PrimitiveString", search_aliases=["text", "string", "text box", "prompt"], display_name="Text String", - category="utils/primitive", + category="utilities/primitive", inputs=[ io.String.Input("value"), ], @@ -30,7 +30,7 @@ class StringMultiline(io.ComfyNode): node_id="PrimitiveStringMultiline", search_aliases=["text", "string", "text multiline", "string multiline", "text box", "prompt"], display_name="Text String (Multiline)", - category="utils/primitive", + category="utilities/primitive", essentials_category="Basics", inputs=[ io.String.Input("value", multiline=True), @@ -49,7 +49,7 @@ class Int(io.ComfyNode): return io.Schema( node_id="PrimitiveInt", display_name="Int", - category="utils/primitive", + category="utilities/primitive", inputs=[ io.Int.Input("value", min=-sys.maxsize, max=sys.maxsize, control_after_generate=io.ControlAfterGenerate.fixed), ], @@ -67,7 +67,7 @@ class Float(io.ComfyNode): return io.Schema( node_id="PrimitiveFloat", display_name="Float", - category="utils/primitive", + category="utilities/primitive", inputs=[ io.Float.Input("value", min=-sys.maxsize, max=sys.maxsize, step=0.1), ], @@ -85,7 +85,7 @@ class Boolean(io.ComfyNode): return io.Schema( node_id="PrimitiveBoolean", display_name="Boolean", - category="utils/primitive", + category="utilities/primitive", inputs=[ io.Boolean.Input("value"), ], diff --git a/comfy_extras/nodes_qwen.py b/comfy_extras/nodes_qwen.py index fde8fac9a..5b92814a4 100644 --- a/comfy_extras/nodes_qwen.py +++ b/comfy_extras/nodes_qwen.py @@ -112,7 +112,7 @@ class EmptyQwenImageLayeredLatentImage(io.ComfyNode): return io.Schema( node_id="EmptyQwenImageLayeredLatentImage", display_name="Empty Qwen Image Layered Latent", - category="latent/qwen", + category="model/latent/qwen", inputs=[ io.Int.Input("width", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=640, min=16, max=nodes.MAX_RESOLUTION, step=16), diff --git a/comfy_extras/nodes_rebatch.py 
b/comfy_extras/nodes_rebatch.py index 5f4e82aef..2185385f0 100644 --- a/comfy_extras/nodes_rebatch.py +++ b/comfy_extras/nodes_rebatch.py @@ -10,7 +10,7 @@ class LatentRebatch(io.ComfyNode): return io.Schema( node_id="RebatchLatents", display_name="Rebatch Latents", - category="latent/batch", + category="model/latent/batch", is_input_list=True, inputs=[ io.Latent.Input("latents"), diff --git a/comfy_extras/nodes_resolution.py b/comfy_extras/nodes_resolution.py index 1628038cc..083e47ae4 100644 --- a/comfy_extras/nodes_resolution.py +++ b/comfy_extras/nodes_resolution.py @@ -6,24 +6,24 @@ from comfy_api.latest import ComfyExtension, io class AspectRatio(str, Enum): SQUARE = "1:1 (Square)" + PHOTO_V = "2:3 (Portrait Photo)" PHOTO_H = "3:2 (Photo)" + STANDARD_V = "3:4 (Portrait Standard)" STANDARD_H = "4:3 (Standard)" + WIDESCREEN_V = "9:16 (Portrait Widescreen)" WIDESCREEN_H = "16:9 (Widescreen)" ULTRAWIDE_H = "21:9 (Ultrawide)" - PHOTO_V = "2:3 (Portrait Photo)" - STANDARD_V = "3:4 (Portrait Standard)" - WIDESCREEN_V = "9:16 (Portrait Widescreen)" ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = { AspectRatio.SQUARE: (1, 1), + AspectRatio.PHOTO_V: (2, 3), AspectRatio.PHOTO_H: (3, 2), + AspectRatio.STANDARD_V: (3, 4), AspectRatio.STANDARD_H: (4, 3), + AspectRatio.WIDESCREEN_V: (9, 16), AspectRatio.WIDESCREEN_H: (16, 9), AspectRatio.ULTRAWIDE_H: (21, 9), - AspectRatio.PHOTO_V: (2, 3), - AspectRatio.STANDARD_V: (3, 4), - AspectRatio.WIDESCREEN_V: (9, 16), } @@ -35,7 +35,7 @@ class ResolutionSelector(io.ComfyNode): return io.Schema( node_id="ResolutionSelector", display_name="Resolution Selector", - category="utils", + category="utilities", description="Calculate width and height from aspect ratio and megapixel target. Useful for setting up Empty Latent Image dimensions.", inputs=[ io.Combo.Input( @@ -50,26 +50,35 @@ class ResolutionSelector(io.ComfyNode): min=0.1, max=16.0, step=0.1, - tooltip="Target total megapixels. 
1.0 MP ≈ 1024×1024 for square.", + tooltip="Target total megapixels. 1.0 MP ≈ 1024x1024 for square.", + ), + io.Int.Input( + id="multiple", + default=8, + min=8, + max=128, + step=4, + tooltip="Nearest multiple of the result to set the selected resolution to.", + advanced=True, ), ], outputs=[ io.Int.Output( - "width", tooltip="Calculated width in pixels (multiple of 8)." + "width", tooltip="Calculated width in pixels multiplied by the selected multiple." ), io.Int.Output( - "height", tooltip="Calculated height in pixels (multiple of 8)." + "height", tooltip="Calculated height in pixels multiplied by the selected multiple." ), ], ) @classmethod - def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput: + def execute(cls, aspect_ratio: str, megapixels: float, multiple: int) -> io.NodeOutput: w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio] total_pixels = megapixels * 1024 * 1024 scale = math.sqrt(total_pixels / (w_ratio * h_ratio)) - width = round(w_ratio * scale / 8) * 8 - height = round(h_ratio * scale / 8) * 8 + width = round(w_ratio * scale / multiple) * multiple + height = round(h_ratio * scale / multiple) * multiple return io.NodeOutput(width, height) diff --git a/comfy_extras/nodes_rope.py b/comfy_extras/nodes_rope.py index 918ddc02b..808eee29b 100644 --- a/comfy_extras/nodes_rope.py +++ b/comfy_extras/nodes_rope.py @@ -7,7 +7,7 @@ class ScaleROPE(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="ScaleROPE", - category="advanced/model_patches", + category="model/patch", description="Scale and shift the ROPE of the model.", is_experimental=True, inputs=[ diff --git a/comfy_extras/nodes_save_3d.py b/comfy_extras/nodes_save_3d.py index c03524246..1b6592bb2 100644 --- a/comfy_extras/nodes_save_3d.py +++ b/comfy_extras/nodes_save_3d.py @@ -16,7 +16,7 @@ from comfy.cli_args import args from comfy_api.latest import ComfyExtension, IO, Types -def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None): +def 
pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=None, unlit=False): # Pack lists of (Nᵢ, *) vertex/face/color/uv tensors into padded batched tensors, # stashing per-item lengths as runtime attrs so consumers can recover the real slice. # colors and uvs are 1:1 with vertices, so they're padded to max_vertices and read with vertex_counts. @@ -54,7 +54,7 @@ def pack_variable_mesh_batch(vertices, faces, colors=None, uvs=None, texture=Non return Types.MESH(packed_vertices, packed_faces, uvs=packed_uvs, vertex_colors=packed_colors, texture=texture, - vertex_counts=vertex_counts, face_counts=face_counts) + vertex_counts=vertex_counts, face_counts=face_counts, unlit=unlit) def get_mesh_batch_item(mesh, index): @@ -77,7 +77,7 @@ def get_mesh_batch_item(mesh, index): def save_glb(vertices, faces, filepath, metadata=None, - uvs=None, vertex_colors=None, texture_image=None): + uvs=None, vertex_colors=None, texture_image=None, unlit=False): """ Save PyTorch tensor vertices and faces as a GLB file without external dependencies. @@ -234,6 +234,17 @@ def save_glb(vertices, faces, filepath, metadata=None, textures = [] samplers = [] materials = [] + extensions_used = [] + if unlit and texture_png_bytes is None: + # Flat, light-independent shading (KHR_materials_unlit): COLOR_0 is shown as-is, matching how a + # gaussian splat renders (emissive). Without this the viewer lights the mesh and washes the colours. 
+ materials.append({ + "pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0], "metallicFactor": 0.0, "roughnessFactor": 1.0}, + "extensions": {"KHR_materials_unlit": {}}, + "doubleSided": True, + }) + extensions_used.append("KHR_materials_unlit") + primitive["material"] = 0 if texture_png_bytes is not None and "TEXCOORD_0" in primitive_attributes: buffer_views.append({ "buffer": 0, @@ -271,6 +282,8 @@ def save_glb(vertices, faces, filepath, metadata=None, gltf["textures"] = textures if materials: gltf["materials"] = materials + if extensions_used: + gltf["extensionsUsed"] = extensions_used if metadata: gltf["asset"]["extras"] = metadata @@ -324,6 +337,12 @@ class SaveGLB(IO.ComfyNode): IO.File3DFBX, IO.File3DSTL, IO.File3DUSDZ, + IO.File3DPLY, + IO.File3DSPLAT, + IO.File3DSPZ, + IO.File3DKSPLAT, + IO.File3DSplatAny, + IO.File3DPointCloudAny, IO.File3DAny, ], tooltip="Mesh or 3D file to save", @@ -376,7 +395,8 @@ class SaveGLB(IO.ComfyNode): save_glb(vertices_i, faces_i, os.path.join(full_output_folder, f), metadata, uvs=uvs_i, vertex_colors=v_colors, - texture_image=tex_img) + texture_image=tex_img, + unlit=getattr(mesh, "unlit", False)) results.append({ "filename": f, "subfolder": subfolder, diff --git a/comfy_extras/nodes_sd3.py b/comfy_extras/nodes_sd3.py index 6655c1ba7..38cbf117b 100644 --- a/comfy_extras/nodes_sd3.py +++ b/comfy_extras/nodes_sd3.py @@ -41,7 +41,7 @@ class EmptySD3LatentImage(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptySD3LatentImage", - category="latent/sd3", + category="model/latent/sd3", inputs=[ io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), @@ -113,7 +113,7 @@ class ControlNetApplySD3(io.ComfyNode): return io.Schema( node_id="ControlNetApplySD3", display_name="Apply Controlnet with VAE", - category="conditioning/controlnet", + category="model/conditioning/controlnet", inputs=[ 
io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_sdupscale.py b/comfy_extras/nodes_sdupscale.py index 5877719d3..ea283e971 100644 --- a/comfy_extras/nodes_sdupscale.py +++ b/comfy_extras/nodes_sdupscale.py @@ -9,7 +9,7 @@ class SD_4XUpscale_Conditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="SD_4XUpscale_Conditioning", - category="conditioning/upscale_diffusion", + category="model/conditioning/upscale_diffusion", inputs=[ io.Image.Input("images"), io.Conditioning.Input("positive"), diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py new file mode 100644 index 000000000..d5cd029ba --- /dev/null +++ b/comfy_extras/nodes_seedvr.py @@ -0,0 +1,1015 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +import torch +import math +import logging +from einops import rearrange + +import gc +import comfy.model_management +import comfy.sample +import comfy.samplers +from comfy.ldm.seedvr.color_fix import ( + adain_color_transfer, + lab_color_transfer, + wavelet_color_transfer, +) +from comfy.ldm.seedvr.constants import ( + BYTEDANCE_IMG_SHIFT_FIT, + BYTEDANCE_SCHEDULE_T, + BYTEDANCE_VID_SHIFT_FIT, + SEEDVR2_ADAIN_SCALE_MULTIPLIER, + SEEDVR2_COLOR_MEM_HEADROOM, + SEEDVR2_COND_CHANNELS, + SEEDVR2_DTYPE_BYTES_FLOOR, + SEEDVR2_LAB_SCALE_MULTIPLIER, + SEEDVR2_LATENT_CHANNELS, + SEEDVR2_OOM_BACKOFF_DIVISOR, + SEEDVR2_WAVELET_SCALE_MULTIPLIER, +) + +from torchvision.transforms import functional as TVF +from torchvision.transforms import Lambda +from torchvision.transforms.functional import InterpolationMode + + +_SEEDVR2_INVALID_MODEL_MSG_PREFIX = ( + "SeedVR2Conditioning: model object does not match expected SeedVR2 structure" +) + +# Private sentinel for getattr default: distinguishes "attribute missing" +# from "attribute present but None" so the failure message is accurate. 
+_ATTR_MISSING = object() + + +def _seedvr2_auto_chunk_attempts(t_latent, t_pixel, frames_per_chunk): + """Return stricter 4n+1 frame chunk sizes for auto OOM retries.""" + attempts = [frames_per_chunk] + current_chunk_latent = ( + t_latent if t_pixel <= frames_per_chunk + else (frames_per_chunk - 1) // 4 + 1 + ) + current_chunk_count = max(1, math.ceil(t_latent / current_chunk_latent)) + seen = {frames_per_chunk} + + for target_chunks in range(max(2, current_chunk_count + 1), t_latent + 1): + chunk_latent = max(1, math.ceil(t_latent / target_chunks)) + candidate = 4 * (chunk_latent - 1) + 1 + if candidate in seen: + continue + if candidate >= attempts[-1]: + continue + attempts.append(candidate) + seen.add(candidate) + + return attempts + + +def _resolve_seedvr2_diffusion_model(model): + """Resolve ``model.model.diffusion_model``, failing loud via the ``_ATTR_MISSING`` sentinel so each of the four modes (model/diffusion_model missing vs None) gives an accurate message.""" + inner = getattr(model, "model", _ATTR_MISSING) + if inner is _ATTR_MISSING: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input has no 'model' attribute " + f"(got type {type(model).__name__})." + ) + if inner is None: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: input.model is None " + f"(input type {type(model).__name__})." + ) + diffusion_model = getattr(inner, "diffusion_model", _ATTR_MISSING) + if diffusion_model is _ATTR_MISSING: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model' has no " + f"'diffusion_model' attribute (got type {type(inner).__name__})." + ) + if diffusion_model is None: + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: 'model.model.diffusion_model' " + f"is None (model.model type {type(inner).__name__})." 
+ ) + return diffusion_model + + +def _apply_rope_freqs_float32_cast(diffusion_model): + """Cast every module's ``rope.freqs`` to float32; the per-tensor dtype check (not a sentinel attr) self-corrects across Comfy's unload/reload, which would otherwise restore the archived fp16/bf16 dtype.""" + for module in diffusion_model.modules(): + if hasattr(module, 'rope') and hasattr(module.rope, 'freqs'): + if module.rope.freqs.data.dtype != torch.float32: + module.rope.freqs.data = module.rope.freqs.data.to(torch.float32) + + +def clear_vae_memory(vae_model): + for module in vae_model.modules(): + if hasattr(module, "memory"): + module.memory = None + gc.collect() + comfy.model_management.soft_empty_cache() + +def expand_dims(tensor, ndim): + shape = tensor.shape + (1,) * (ndim - tensor.ndim) + return tensor.reshape(shape) + +def get_conditions(latent, latent_blur): + t, h, w, c = latent.shape + cond = torch.ones([t, h, w, c + 1], device=latent.device, dtype=latent.dtype) + cond[:, ..., :-1] = latent_blur[:] + cond[:, ..., -1:] = 1.0 + return cond + +def timestep_transform(timesteps, latents_shapes): + vt = 4 + vs = 8 + frames = (latents_shapes[:, 0] - 1) * vt + 1 + heights = latents_shapes[:, 1] * vs + widths = latents_shapes[:, 2] * vs + + # Compute shift factor. + def get_lin_function(x1, y1, x2, y2): + m = (y2 - y1) / (x2 - x1) + b = y1 - m * x1 + return lambda x: m * x + b + + img_shift_fn = get_lin_function(*BYTEDANCE_IMG_SHIFT_FIT) + vid_shift_fn = get_lin_function(*BYTEDANCE_VID_SHIFT_FIT) + shift = torch.where( + frames > 1, + vid_shift_fn(heights * widths * frames), + img_shift_fn(heights * widths), + ).to(timesteps.device) + + # Shift timesteps. 
+ T = BYTEDANCE_SCHEDULE_T + timesteps = timesteps / T + timesteps = shift * timesteps / (1 + (shift - 1) * timesteps) + timesteps = timesteps * T + return timesteps + +def inter(x_0, x_T, t): + t = expand_dims(t, x_0.ndim) + T = BYTEDANCE_SCHEDULE_T + B = lambda t: t / T + A = lambda t: 1 - (t / T) + return A(t) * x_0 + B(t) * x_T + +def div_pad(image, factor): + + height_factor, width_factor = factor + height, width = image.shape[-2:] + + pad_height = (height_factor - (height % height_factor)) % height_factor + pad_width = (width_factor - (width % width_factor)) % width_factor + + if pad_height == 0 and pad_width == 0: + return image + + if isinstance(image, torch.Tensor): + padding = (0, pad_width, 0, pad_height) + image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0) + + return image + +def cut_videos(videos): + t = videos.size(1) + if t == 1: + return videos + if t <= 4 : + padding = [videos[:, -1].unsqueeze(1)] * (4 - t + 1) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + return videos + if (t - 1) % (4) == 0: + return videos + else: + padding = [videos[:, -1].unsqueeze(1)] * ( + 4 - ((t - 1) % (4)) + ) + padding = torch.cat(padding, dim=1) + videos = torch.cat([videos, padding], dim=1) + assert (videos.size(1) - 1) % (4) == 0 + return videos + +def _seedvr2_input_shorter_edge(images, node_name): + if images.dim() == 4: + return min(images.shape[1], images.shape[2]) + if images.dim() == 5: + return min(images.shape[2], images.shape[3]) + raise ValueError( + f"{node_name}: expected 4-D or 5-D IMAGE tensor, " + f"got shape {tuple(images.shape)}" + ) + + +def _seedvr2_pad(images, upscaled_shorter_edge, node_name): + if upscaled_shorter_edge < 2: + raise ValueError( + f"{node_name}: input shorter edge must be at least 2 pixels; " + f"got {upscaled_shorter_edge}." 
+ ) + if images.shape[-1] > 3: + images = images[..., :3] + if images.dim() == 4: + # Comfy video components arrive as a 4-D IMAGE frame sequence: + # (frames, H, W, C). SeedVR2 consumes that as one video. + images = images.unsqueeze(0) + elif images.dim() != 5: + raise ValueError( + f"{node_name}: expected 4-D or 5-D IMAGE tensor, " + f"got shape {tuple(images.shape)}" + ) + images = images.permute(0, 1, 4, 2, 3) + + b, t, c, h, w = images.shape + images = images.reshape(b * t, c, h, w) + + clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0)) + images = clip(images) + images = div_pad(images, (16, 16)) + _, _, new_h, new_w = images.shape + + images = images.reshape(b, t, c, new_h, new_w) + images = cut_videos(images) + images_bthwc = rearrange(images, "b t c h w -> b t h w c") + + return io.NodeOutput(images_bthwc) + + +class SeedVR2Preprocess(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2Preprocess", + display_name="Pre-Process SeedVR2 Input", + category="image/upscaling", + description="Pad a resized image for SeedVR2 model. Alpha channel is dropped. 
The node Post-Process SeedVR2 Output re-applies it from the original resized image.", + inputs=[ + io.Image.Input("resized_images", tooltip="The resized image to process."), + ], + outputs=[ + io.Image.Output("images"), + ] + ) + + @classmethod + def execute(cls, resized_images): + upscaled_shorter_edge = _seedvr2_input_shorter_edge(resized_images, "SeedVR2Preprocess") + return _seedvr2_pad( + resized_images, upscaled_shorter_edge, "SeedVR2Preprocess", + ) + + +class SeedVR2PostProcessing(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2PostProcessing", + display_name="Post-Process SeedVR2 Output", + category="image/upscaling", + description="Align the generated image with the original resized image and apply color correction.", + inputs=[ + io.Image.Input("images", tooltip="The generated image to process."), + io.Image.Input("original_resized_images", tooltip="The original resized image before pre-processing, used as reference."), + io.Combo.Input("color_correction_method", options=["lab", "wavelet", "adain", "none"], default="lab", tooltip="Method to match the generated image colors to the original image. lab: transfer color in CIELAB space, preserving detail (most faithful). wavelet: transfer low-frequency color, keeping upscaled high-frequency detail. adain: match per-channel mean/std (fastest, global tint). 
none: skip color transfer (geometry alignment only)."), + ], + outputs=[io.Image.Output(display_name="images")], + ) + + @classmethod + def execute(cls, images, original_resized_images, color_correction_method): + alpha_input = None + if original_resized_images.shape[-1] == 4: + alpha_input = original_resized_images[..., 3:4] + original_resized_images = original_resized_images[..., :3] + decoded_5d, decoded_was_4d = cls._as_bthwc(images) + reference_full, _ = cls._as_bthwc(original_resized_images) + decoded_5d = cls._restore_reference_batch_time(decoded_5d, reference_full) + + b = min(decoded_5d.shape[0], reference_full.shape[0]) + t = min(decoded_5d.shape[1], reference_full.shape[1]) + reference_h = reference_full.shape[2] + reference_w = reference_full.shape[3] + + decoded_5d = decoded_5d[:b, :t, :, :, :] + target_h = min(decoded_5d.shape[2], reference_h) + target_w = min(decoded_5d.shape[3], reference_w) + decoded_5d = decoded_5d[:, :, :target_h, :target_w, :] + if color_correction_method in ("lab", "wavelet", "adain"): + reference_5d = reference_full[:b, :t, :, :, :] + reference_5d = cls._resize_reference(reference_5d, target_h, target_w) + output_device = decoded_5d.device + decoded_raw = cls._to_seedvr2_raw(decoded_5d) + reference_raw = cls._to_seedvr2_raw(reference_5d) + decoded_flat = rearrange(decoded_raw, "b t h w c -> (b t) c h w") + reference_flat = rearrange(reference_raw, "b t h w c -> (b t) c h w") + output = cls._color_transfer_chunked( + decoded_flat, reference_flat, output_device, color_correction_method, + ) + output = rearrange(output, "(b t) c h w -> b t h w c", b=b, t=t) + output = output.add(1.0).div(2.0).clamp(0.0, 1.0) + elif color_correction_method == "none": + output = decoded_5d + else: + raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") + + if alpha_input is not None: + alpha_5d, _ = cls._as_bthwc(alpha_input) + alpha_5d = alpha_5d[:output.shape[0], :output.shape[1], 
:output.shape[2], :output.shape[3], :] + output = torch.cat([output, alpha_5d.to(dtype=output.dtype, device=output.device)], dim=-1) + h2 = output.shape[-3] - (output.shape[-3] % 2) + w2 = output.shape[-2] - (output.shape[-2] % 2) + output = output[:, :, :h2, :w2, :] + if decoded_was_4d: + output = output.reshape(-1, output.shape[-3], output.shape[-2], output.shape[-1]) + return io.NodeOutput(output) + + @staticmethod + def _as_bthwc(images): + if images.ndim == 4: + return images.unsqueeze(0), True + if images.ndim == 5: + return images, False + raise ValueError( + f"SeedVR2PostProcessing: expected 4-D or 5-D IMAGE tensor, got shape {tuple(images.shape)}" + ) + + @staticmethod + def _restore_reference_batch_time(decoded, reference): + if decoded.shape[0] != 1: + return decoded + ref_b, ref_t = reference.shape[:2] + if ref_b < 1 or decoded.shape[1] % ref_b != 0: + return decoded + decoded_t = decoded.shape[1] // ref_b + if decoded_t < ref_t: + return decoded + return decoded.reshape(ref_b, decoded_t, decoded.shape[2], decoded.shape[3], decoded.shape[4]) + + @staticmethod + def _to_seedvr2_raw(images): + return images.mul(2.0).sub(1.0) + + @staticmethod + def _color_transfer_on_vae_device(decoded_flat, reference_flat, output_device, transfer_fn): + color_device = comfy.model_management.vae_device() + decoded_flat = decoded_flat.to(device=color_device) + reference_flat = reference_flat.to(device=color_device) + output = transfer_fn(decoded_flat, reference_flat) + return output.to(device=output_device) + + @staticmethod + def _lab_color_transfer_on_vae_device(decoded_flat, reference_flat, output_device): + color_device = comfy.model_management.vae_device() + result = None + for start in range(decoded_flat.shape[0]): + decoded_frame = decoded_flat[start:start + 1].to(device=color_device).clone() + reference_frame = reference_flat[start:start + 1].to(device=color_device).clone() + output = lab_color_transfer(decoded_frame, reference_frame).to(device=output_device) + if 
result is None: + result = torch.empty( + (decoded_flat.shape[0],) + tuple(output.shape[1:]), + device=output_device, + dtype=output.dtype, + ) + result[start:start + 1].copy_(output) + if result is None: + raise ValueError("SeedVR2PostProcessing: LAB color correction requires at least one frame.") + return result + + @classmethod + def _color_transfer_chunked(cls, decoded_flat, reference_flat, output_device, color_correction_method): + chunk_size = cls._estimate_color_correction_chunk_size(decoded_flat, color_correction_method) + while True: + next_chunk_size = None + try: + return cls._run_color_transfer_chunks( + decoded_flat, reference_flat, output_device, color_correction_method, chunk_size, + ) + except Exception as e: + comfy.model_management.raise_non_oom(e) + if chunk_size <= 1: + raise RuntimeError( + "SeedVR2PostProcessing: color correction OOM at one frame; " + f"color_correction_method={color_correction_method}, shape={tuple(decoded_flat.shape)}." + ) from e + next_chunk_size = max(1, chunk_size // SEEDVR2_OOM_BACKOFF_DIVISOR) + + comfy.model_management.soft_empty_cache() + chunk_size = next_chunk_size + + @classmethod + def _run_color_transfer_chunks(cls, decoded_flat, reference_flat, output_device, color_correction_method, chunk_size): + result = None + for start in range(0, decoded_flat.shape[0], chunk_size): + end = min(start + chunk_size, decoded_flat.shape[0]) + decoded_chunk = decoded_flat[start:end] + reference_chunk = reference_flat[start:end] + if color_correction_method == "lab": + output = cls._lab_color_transfer_on_vae_device(decoded_chunk, reference_chunk, output_device) + elif color_correction_method == "wavelet": + output = cls._color_transfer_on_vae_device( + decoded_chunk, reference_chunk, output_device, wavelet_color_transfer, + ) + else: + output = cls._color_transfer_on_vae_device( + decoded_chunk, reference_chunk, output_device, adain_color_transfer, + ) + if result is None: + result = torch.empty( + (decoded_flat.shape[0],) + 
tuple(output.shape[1:]), + device=output_device, + dtype=output.dtype, + ) + result[start:end].copy_(output) + if result is None: + raise ValueError("SeedVR2PostProcessing: color correction requires at least one frame.") + return result + + @classmethod + def _estimate_color_correction_chunk_size(cls, decoded_flat, color_correction_method): + multiplier = cls._color_correction_memory_multiplier(color_correction_method) + frames = decoded_flat.shape[0] + _, channels, height, width = decoded_flat.shape + dtype_bytes = max(decoded_flat.element_size(), SEEDVR2_DTYPE_BYTES_FLOOR) + bytes_per_frame = height * width * channels * dtype_bytes * multiplier + if bytes_per_frame <= 0: + return frames + color_device = comfy.model_management.vae_device() + free_memory = comfy.model_management.get_free_memory(color_device) + chunk_size = int((free_memory * SEEDVR2_COLOR_MEM_HEADROOM) // bytes_per_frame) + return max(1, min(frames, chunk_size)) + + @staticmethod + def _color_correction_memory_multiplier(color_correction_method): + if color_correction_method == "lab": + return SEEDVR2_LAB_SCALE_MULTIPLIER + if color_correction_method == "wavelet": + return SEEDVR2_WAVELET_SCALE_MULTIPLIER + if color_correction_method == "adain": + return SEEDVR2_ADAIN_SCALE_MULTIPLIER + raise ValueError(f"SeedVR2PostProcessing: unknown color_correction_method {color_correction_method!r}") + + @staticmethod + def _resize_reference(reference, height, width): + if reference.shape[2] == height and reference.shape[3] == width: + return reference + b, t = reference.shape[:2] + reference_flat = rearrange(reference, "b t h w c -> (b t) c h w") + resized = TVF.resize( + reference_flat, + size=(height, width), + interpolation=InterpolationMode.BICUBIC, + antialias=not (isinstance(reference_flat, torch.Tensor) and reference_flat.device.type == "mps"), + ) + return rearrange(resized, "(b t) c h w -> b t h w c", b=b, t=t) + + +class SeedVR2Conditioning(io.ComfyNode): + @classmethod + def define_schema(cls): + 
return io.Schema( + node_id="SeedVR2Conditioning", + display_name="Apply SeedVR2 Conditioning", + category="conditioning", + description="Build SeedVR2 positive/negative conditioning from a VAE latent.", + inputs=[ + io.Model.Input("model", tooltip="The SeedVR2 model."), + io.Latent.Input("vae_conditioning", display_name="latent"), + ], + outputs=[ + io.Model.Output(display_name = "model"), + io.Conditioning.Output(display_name = "positive"), + io.Conditioning.Output(display_name = "negative"), + io.Latent.Output(display_name = "latent"), + ], + ) + + @classmethod + def execute(cls, model, vae_conditioning) -> io.NodeOutput: + + vae_conditioning = vae_conditioning["samples"] + if vae_conditioning.ndim != 5: + raise ValueError( + "SeedVR2Conditioning expects a 5-D VAE latent in Comfy " + f"channel-first layout; got shape {tuple(vae_conditioning.shape)}." + ) + if vae_conditioning.shape[-1] == SEEDVR2_LATENT_CHANNELS and vae_conditioning.shape[1] != SEEDVR2_LATENT_CHANNELS: + raise ValueError( + "SeedVR2Conditioning expects SeedVR2 VAE latents in Comfy " + f"channel-first layout (B, {SEEDVR2_LATENT_CHANNELS}, T, H, W); " + f"got channel-last shape {tuple(vae_conditioning.shape)}." + ) + vae_conditioning = vae_conditioning.movedim(1, -1).contiguous() + model_patcher = model + model = _resolve_seedvr2_diffusion_model(model_patcher) + pos_cond = model.positive_conditioning + neg_cond = model.negative_conditioning + + # Fail-loud guard against silently-wrong output when a + # DiT-only ``.safetensors`` (no ``positive_conditioning`` / + # ``negative_conditioning`` keys) is loaded via ``UNETLoader``. + # ``NaDiT.__init__`` zero-fills the buffers via ``torch.zeros`` (see + # ``comfy/ldm/seedvr/model.py``); ``load_state_dict(strict=False)`` + # leaves them at zero when the keys are absent. Detect that state + # here rather than at ``BaseModel.extra_conds`` (per sampling step, + # wasteful) or at the resolver helper (mixes structural shape with + # semantic content). 
Both buffers must be checked together — partial + # bake regressions could populate one but not the other. + if ( + pos_cond.float().abs().sum().item() == 0 + and neg_cond.float().abs().sum().item() == 0 + ): + raise RuntimeError( + f"{_SEEDVR2_INVALID_MODEL_MSG_PREFIX}: positive_conditioning " + f"and negative_conditioning buffers are zero-valued — model " + f"file appears to be a DiT-only export missing " + f"the SeedVR2 conditioning tensors. " + f"Re-bake the file with ``positive_conditioning`` (58, 5120) " + f"and ``negative_conditioning`` (64, 5120) keys at top level, " + f"or load via CheckpointLoaderSimple from a bundled " + f"checkpoint." + ) + + _apply_rope_freqs_float32_cast(model) + + condition = torch.stack([get_conditions(c, c) for c in vae_conditioning]) + condition = condition.movedim(-1, 1) + latent = vae_conditioning.movedim(-1, 1) + + latent = rearrange(latent, "b c t h w -> b (c t) h w") + condition = rearrange(condition, "b c t h w -> b (c t) h w") + + negative = [[neg_cond.unsqueeze(0), {"condition": condition}]] + positive = [[pos_cond.unsqueeze(0), {"condition": condition}]] + + return io.NodeOutput(model_patcher, positive, negative, {"samples": latent}) + +def _slice_collapsed_4d_along_t(tensor_4d: torch.Tensor, t_start: int, + t_end: int, channels: int) -> torch.Tensor: + """Slice collapsed ``(B, channels*T, H, W)`` along latent T: reshape (accepts non-contiguous inputs), slice, ``.contiguous()`` (T-slice of 5D is a non-contiguous view; re-collapse needs contiguous), re-collapse.""" + B, CT, H, W = tensor_4d.shape + if CT % channels != 0: + raise ValueError( + f"_slice_collapsed_4d_along_t: collapsed channel dim {CT} is not " + f"divisible by channels={channels}; tensor shape {tuple(tensor_4d.shape)}." + ) + T = CT // channels + if not (0 <= t_start < t_end <= T): + raise ValueError( + f"_slice_collapsed_4d_along_t: slice [{t_start}:{t_end}] out of " + f"range for T={T}." 
+ ) + new_T = t_end - t_start + sliced = tensor_4d.reshape(B, channels, T, H, W)[:, :, t_start:t_end, :, :].contiguous() + return sliced.reshape(B, channels * new_T, H, W) + + +def _slice_seedvr2_cond_along_t(cond_list, t_start: int, t_end: int): + """Return a new conditioning list with each entry's ``options["condition"]`` (collapsed ``(B, 17*T, H, W)``) sliced along latent T; text tensors, other option keys, and condition-less entries pass through unchanged and inputs are not mutated.""" + new_list = [] + for entry in cond_list: + text_cond, options = entry[0], entry[1] + if "condition" not in options: + new_list.append(entry) + continue + new_options = options.copy() + new_options["condition"] = _slice_collapsed_4d_along_t( + new_options["condition"], t_start, t_end, + SEEDVR2_COND_CHANNELS, + ) + new_list.append([text_cond, new_options]) + return new_list + + +def _slice_seedvr2_noise_mask_along_t(noise_mask: torch.Tensor, + samples_4d: torch.Tensor, + t_start: int, + t_end: int): + """Slice only masks already expanded to collapsed ``(B, 16*T, H, W)``; pass standard ``(B, 1, H, W)`` ``SetLatentNoiseMask`` outputs through for KSampler to expand.""" + if noise_mask.ndim == samples_4d.ndim and noise_mask.shape[1] == samples_4d.shape[1]: + return _slice_collapsed_4d_along_t( + noise_mask, t_start, t_end, SEEDVR2_LATENT_CHANNELS, + ) + return noise_mask + + +def _concat_chunks_along_t(chunks_4d, channels: int) -> torch.Tensor: + """Concatenate collapsed ``(B, channels*T_i, H, W)`` chunks along latent T: un-collapse to 5D, cat on ``dim=2``, re-collapse to 4D.""" + if len(chunks_4d) == 0: + raise ValueError("_concat_chunks_along_t: empty chunk list.") + fives = [] + for ch in chunks_4d: + B, CT, H, W = ch.shape + if CT % channels != 0: + raise ValueError( + f"_concat_chunks_along_t: chunk shape {tuple(ch.shape)} " + f"channel dim {CT} not divisible by channels={channels}." 
+ ) + T = CT // channels + fives.append(ch.reshape(B, channels, T, H, W)) + cat = torch.cat(fives, dim=2).contiguous() + B, C, T_total, H, W = cat.shape + return cat.reshape(B, C * T_total, H, W) + + +def _hann_blend_weights_1d(overlap: int, device, dtype) -> torch.Tensor: + """1D length-``overlap`` crossfade weights for the previous chunk (current = ``1 - w_prev``): + Hann window with a ``[1/3, 2/3]`` dead-band for ``overlap >= 3``, linear ramp for ``overlap < 3`` + (dead-band would collapse a tiny transition). Window shape matched to the reference + overlapping-frame blend for parity; caller broadcasts across ``(B, C, T_overlap, H, W)``. + """ + if overlap < 1: + raise ValueError( + f"_hann_blend_weights_1d: overlap must be >= 1; got {overlap}." + ) + if overlap >= 3: + t = torch.linspace(0.0, 1.0, steps=overlap, device=device, dtype=dtype) + blend_start = 1.0 / 3.0 + blend_end = 2.0 / 3.0 + u = ((t - blend_start) / (blend_end - blend_start)).clamp(0.0, 1.0) + return 0.5 + 0.5 * torch.cos(torch.pi * u) + return torch.linspace(1.0, 0.0, steps=overlap, device=device, dtype=dtype) + + +def _blend_overlap_region(prev_tail_5d: torch.Tensor, + cur_head_5d: torch.Tensor) -> torch.Tensor: + """Blend two equal-shape 5D ``(B, C, T_overlap, H, W)`` tensors with a 1D Hann/linear T-ramp: ``prev_tail_5d`` takes the descending weight, ``cur_head_5d`` takes ``1 - w_prev`` (caller ensures matching shape/dtype/device).""" + if prev_tail_5d.shape != cur_head_5d.shape: + raise ValueError( + f"_blend_overlap_region: shape mismatch " + f"prev {tuple(prev_tail_5d.shape)} vs " + f"cur {tuple(cur_head_5d.shape)}." + ) + overlap = int(prev_tail_5d.shape[2]) + w_prev_1d = _hann_blend_weights_1d( + overlap, prev_tail_5d.device, prev_tail_5d.dtype, + ) + # Reshape to (1, 1, overlap, 1, 1) for broadcast across B, C, H, W. 
+ w_prev = w_prev_1d.view(1, 1, overlap, 1, 1) + w_cur = 1.0 - w_prev + return prev_tail_5d * w_prev + cur_head_5d * w_cur + + +def _concat_chunks_with_overlap_blend(chunk_specs, channels: int, + overlap_latent: int) -> torch.Tensor: + """Concatenate overlapping ``(t_start, t_end, chunk_4d)`` specs (source-latent T coords) into one collapsed 4D tensor, Hann/linear-blending overlaps; ``overlap_latent == 0`` fast-paths to plain concat (bit-identical to ``_concat_chunks_along_t``). Each blend uses the actual width ``min(prev_end - cur_start, chunk length)``, smaller than ``overlap_latent`` for a runt final chunk.""" + if len(chunk_specs) == 0: + raise ValueError("_concat_chunks_with_overlap_blend: empty chunk list.") + if overlap_latent < 0: + raise ValueError( + f"_concat_chunks_with_overlap_blend: overlap_latent must be " + f">= 0; got {overlap_latent}." + ) + + # Validate channel divisibility once and capture per-chunk T. + chunk_5d = [] + for t_start, t_end, ch in chunk_specs: + B, CT, H, W = ch.shape + if CT % channels != 0: + raise ValueError( + f"_concat_chunks_with_overlap_blend: chunk shape " + f"{tuple(ch.shape)} channel dim {CT} not divisible " + f"by channels={channels}." + ) + T = CT // channels + if t_end - t_start != T: + raise ValueError( + f"_concat_chunks_with_overlap_blend: chunk T={T} mismatches " + f"declared range [{t_start}:{t_end}]." + ) + chunk_5d.append((t_start, t_end, ch.reshape(B, channels, T, H, W))) + + if overlap_latent == 0: + # Fast path: pure concat in the caller-provided chunk order. 
+ return _concat_chunks_along_t( + [c.reshape(c.shape[0], channels * c.shape[2], c.shape[3], c.shape[4]) + for _, _, c in chunk_5d], + channels, + ) + + T_total = max(t_end for _, t_end, _ in chunk_5d) + first_5d = chunk_5d[0][2] + B = first_5d.shape[0] + H = first_5d.shape[3] + W = first_5d.shape[4] + result = torch.empty( + (B, channels, T_total, H, W), + device=first_5d.device, dtype=first_5d.dtype, + ) + filled_until = 0 + for i, (cs, ce, ct_5d) in enumerate(chunk_5d): + chunk_T = int(ct_5d.shape[2]) + if i == 0: + result[:, :, cs:ce, :, :] = ct_5d + filled_until = ce + continue + # Overlap region width is bounded by both the previous fill + # frontier and the current chunk's actual length (for runt + # final chunks shorter than the configured overlap). + overlap_len = min(filled_until - cs, chunk_T) + if overlap_len > 0: + prev_tail = result[:, :, cs:cs + overlap_len, :, :].contiguous() + cur_head = ct_5d[:, :, :overlap_len, :, :].contiguous() + blended = _blend_overlap_region(prev_tail, cur_head) + result[:, :, cs:cs + overlap_len, :, :] = blended + tail_start = cs + overlap_len + tail_end = ce + if tail_end > tail_start: + result[:, :, tail_start:tail_end, :, :] = ( + ct_5d[:, :, overlap_len:, :, :] + ) + else: + # Disjoint chunks (overlap_latent set but this pair did not + # actually overlap, e.g. step_latent equal to chunk_latent + # in a degenerate config). Treat as concat. 
+ result[:, :, cs:ce, :, :] = ct_5d + filled_until = ce + + return result.contiguous().reshape(B, channels * T_total, H, W) + + +def _run_standard_sample(model, seed: int, steps: int, cfg: float, + sampler_name: str, scheduler: str, + positive, negative, latent: dict, + denoise: float) -> dict: + """Single-shot mirror of ``nodes.py:common_ksampler`` (seed -> noise, ``comfy.sample.sample``, latent dict); used by the ProgressiveSampler short-circuit when the whole sequence fits one chunk.""" + samples_in = latent["samples"] + samples_in = comfy.sample.fix_empty_latent_channels( + model, samples_in, latent.get("downscale_ratio_spacial", None), + ) + batch_inds = latent.get("batch_index", None) + noise = comfy.sample.prepare_noise(samples_in, seed, batch_inds) + noise_mask = latent.get("noise_mask", None) + samples = comfy.sample.sample( + model, noise, steps, cfg, sampler_name, scheduler, + positive, negative, samples_in, + denoise=denoise, noise_mask=noise_mask, seed=seed, + ) + out = latent.copy() + out.pop("downscale_ratio_spacial", None) + out["samples"] = samples + return out + + +class SeedVR2ProgressiveSampler(io.ComfyNode): + """Sequential temporal chunking sampler for SeedVR2 native. + + Drop-in replacement for ``KSampler`` in SeedVR2 native workflows that + OOM on long sequences. The latent enters the sampler in SeedVR2's + collapsed form ``(B, 16*T, H, W)`` (collapsed by ``SeedVR2Conditioning`` + at ``rearrange(b c t h w -> b (c t) h w)``); this node slices that + tensor along the temporal axis, runs the configured inner sampler + sequentially per chunk against the standard ``comfy.sample.sample`` + entry point, and concatenates per-chunk outputs back into a single + ``(B, 16*T_total, H, W)`` latent. + + ``frames_per_chunk`` is expressed in pixel-frame units to match the + SeedVR2 4n+1 constraint enforced upstream by ``cut_videos`` and the + VAE's ``temporal_downsample_factor=4``. A pixel chunk size ``F`` + maps to ``(F - 1) // 4 + 1`` latent-frame chunks. 
+ + Determinism contract: a single noise tensor is generated once from + the user seed and sliced per chunk (rather than re-seeding each + chunk), so a workflow that fits in a single chunk produces output + identical to a workflow that fits in N chunks at the same seed, + modulo the inherent T-axis chunk-boundary independence of the model. + """ + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SeedVR2ProgressiveSampler", + display_name="Sample SeedVR2 (Progressive)", + category="sampling", + description="Sample a SeedVR2 latent in sequential temporal chunks to allow longer videos to fit into VRAM via frame blending the resulting upscaled latents.", + inputs=[ + io.Model.Input("model", tooltip="The model used for denoising the input latent."), + io.Int.Input("seed", default=0, min=0, + max=0xffffffffffffffff, + control_after_generate=True, + tooltip="The random seed used for creating the noise."), + io.Int.Input("steps", default=20, min=1, max=10000, + tooltip="The number of steps used in the denoising process."), + io.Float.Input("cfg", default=1.0, min=0.0, max=100.0, + step=0.1, round=0.01, + tooltip="The Classifier-Free Guidance scale balances creativity and adherence to the prompt. 
Higher values result in images more closely matching the prompt however too high values will negatively impact quality."), + io.Combo.Input("sampler_name", + options=comfy.samplers.SAMPLER_NAMES, + tooltip="The algorithm used when sampling, this can affect the quality, speed, and style of the generated output."), + io.Combo.Input("scheduler", + options=comfy.samplers.SCHEDULER_NAMES, + tooltip="The scheduler controls how noise is gradually removed to form the image."), + io.Conditioning.Input("positive", + tooltip="The conditioning describing the attributes you want to include in the image."), + io.Conditioning.Input("negative", + tooltip="The conditioning describing the attributes you want to exclude from the image."), + io.Latent.Input("latent", + tooltip="The latent image to denoise."), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, + step=0.01, + tooltip="The amount of denoising applied, lower values will maintain the structure of the initial image allowing for image to image sampling."), + io.Int.Input("frames_per_chunk", default=21, min=1, + max=16384, step=4, + tooltip="Pixel frames per temporal chunk (4n+1: 1, 5, 9, 13, ...)."), + io.Int.Input("temporal_overlap", default=0, min=0, + max=16384, + tooltip="Latent frames blended between adjacent chunks to hide the seam; 0 = no blend."), + io.Combo.Input("chunking_mode", + options=["manual", "auto"], + default="manual", + tooltip="manual = use frames_per_chunk exactly; auto = shrink the chunk until it fits in VRAM."), + ], + outputs=[io.Latent.Output(display_name="latent")], + ) + + @classmethod + def execute(cls, model, seed, steps, cfg, sampler_name, scheduler, + positive, negative, latent, denoise, + frames_per_chunk, temporal_overlap, + chunking_mode="manual") -> io.NodeOutput: + # 4n+1 validation in pixel-frame domain. 
The SeedVR2 native pipeline + # requires pixel-frame counts of the form 4n+1 (1, 5, 9, 13, ...), + # imposed at ``cut_videos`` upstream and propagated through the VAE's + # temporal_downsample_factor=4. Reject violations explicitly before + # any model invocation; a silent rounding would mis-align chunk + # boundaries with the 4n+1 lattice. + if frames_per_chunk < 1 or (frames_per_chunk - 1) % 4 != 0: + raise ValueError( + f"SeedVR2ProgressiveSampler: frames_per_chunk must be a " + f"4n+1 pixel-frame count (1, 5, 9, 13, 17, 21, ...); " + f"got {frames_per_chunk}." + ) + + samples_4d = latent["samples"] + samples_4d = comfy.sample.fix_empty_latent_channels( + model, samples_4d, + latent.get("downscale_ratio_spacial", None), + ) + if samples_4d.ndim != 4: + raise ValueError( + f"SeedVR2ProgressiveSampler: expected 4D collapsed latent " + f"(B, 16*T, H, W); got shape {tuple(samples_4d.shape)}." + ) + B, CT, H, W = samples_4d.shape + if CT % SEEDVR2_LATENT_CHANNELS != 0: + raise ValueError( + f"SeedVR2ProgressiveSampler: collapsed channel dim {CT} is " + f"not divisible by SeedVR2 latent channels " + f"{SEEDVR2_LATENT_CHANNELS}; latent does not appear to be " + f"SeedVR2-shaped." + ) + T_latent = CT // SEEDVR2_LATENT_CHANNELS + T_pixel = 4 * (T_latent - 1) + 1 + + if chunking_mode not in ("manual", "auto"): + raise ValueError( + f"SeedVR2ProgressiveSampler: chunking_mode must be " + f"'manual' or 'auto'; got {chunking_mode!r}." 
+ ) + + if chunking_mode == "auto": + attempts = _seedvr2_auto_chunk_attempts( + T_latent, T_pixel, frames_per_chunk, + ) + for i, attempt_frames_per_chunk in enumerate(attempts): + retry = False + try: + return cls.execute( + model=model, seed=seed, steps=steps, cfg=cfg, + sampler_name=sampler_name, scheduler=scheduler, + positive=positive, negative=negative, + latent=latent, denoise=denoise, + frames_per_chunk=attempt_frames_per_chunk, + temporal_overlap=temporal_overlap, + chunking_mode="manual", + ) + except Exception as e: + comfy.model_management.raise_non_oom(e) + if i == len(attempts) - 1: + raise RuntimeError( + "SeedVR2ProgressiveSampler: exhausted auto " + "chunking attempts after OOM. Tried " + f"frames_per_chunk values {attempts}." + ) from e + retry = True + + if retry: + logging.warning( + "SeedVR2ProgressiveSampler auto chunking OOM at " + "frames_per_chunk=%s; retrying with " + "frames_per_chunk=%s.", + attempt_frames_per_chunk, attempts[i + 1], + ) + comfy.model_management.soft_empty_cache() + + # Short-circuit: total fits in one chunk -> standard path with no + # chunking overhead. Output of this branch is byte-identical to the + # built-in KSampler given the same (model, seed, steps, cfg, + # sampler_name, scheduler, positive, negative, latent, + # denoise) tuple. + if T_pixel <= frames_per_chunk: + return io.NodeOutput(_run_standard_sample( + model, seed, steps, cfg, sampler_name, scheduler, + positive, negative, latent, denoise, + )) + + # Map pixel chunk -> latent chunk. Each chunk's latent length is + # at most ``chunk_latent``; the final chunk may be a runt that + # is automatically 4n+1-aligned in the pixel domain by the + # T_pixel = 4*(T_latent-1) + 1 mapping (every positive integer + # T_latent corresponds to a valid 4n+1 pixel count). + chunk_latent = (frames_per_chunk - 1) // 4 + 1 + + # ``temporal_overlap`` is exposed in latent-frame units, but users + # do not know the derived latent chunk length. 
Treat oversized + # values as "maximum valid overlap" while preserving a strictly + # positive chunk-loop stride. + if temporal_overlap < 0: + raise ValueError( + f"SeedVR2ProgressiveSampler: temporal_overlap must be >= 0; " + f"got {temporal_overlap}." + ) + temporal_overlap = min(temporal_overlap, chunk_latent - 1) + step_latent = chunk_latent - temporal_overlap + + # Generate full noise once from the user seed, then slice along T + # per chunk. Using one global noise tensor (rather than re-seeding + # per chunk) preserves seed-determinism across chunk-count + # variations: the same (seed, total T_latent) always produces the + # same noise samples regardless of how the work is partitioned. + batch_inds = latent.get("batch_index", None) + noise_full = comfy.sample.prepare_noise(samples_4d, seed, batch_inds) + + noise_mask = latent.get("noise_mask", None) + + # Build the flat list of chunk ranges first so the chunking + # geometry is fully known before any sample call. + chunk_ranges = [] + for chunk_start in range(0, T_latent, step_latent): + chunk_end = min(chunk_start + chunk_latent, T_latent) + if chunk_start >= chunk_end: + # The final iteration of a stride that lands exactly on + # T_latent produces a zero-length chunk; skip it. + break + chunk_ranges.append((chunk_start, chunk_end)) + if chunk_end >= T_latent: + break + + def _sample_one_chunk(chunk_start, chunk_end): + samples_chunk = _slice_collapsed_4d_along_t( + samples_4d, chunk_start, chunk_end, + SEEDVR2_LATENT_CHANNELS, + ) + noise_chunk = _slice_collapsed_4d_along_t( + noise_full, chunk_start, chunk_end, + SEEDVR2_LATENT_CHANNELS, + ) + positive_chunk = _slice_seedvr2_cond_along_t( + positive, chunk_start, chunk_end, + ) + negative_chunk = _slice_seedvr2_cond_along_t( + negative, chunk_start, chunk_end, + ) + + # Per-chunk noise_mask handling: standard masks are passed + # through for KSampler expansion; pre-expanded collapsed + # masks are sliced. 
+ chunk_noise_mask = None + if noise_mask is not None: + chunk_noise_mask = _slice_seedvr2_noise_mask_along_t( + noise_mask, samples_4d, chunk_start, chunk_end, + ) + + return comfy.sample.sample( + model, noise_chunk, steps, cfg, sampler_name, scheduler, + positive_chunk, negative_chunk, samples_chunk, + denoise=denoise, noise_mask=chunk_noise_mask, seed=seed, + ) + + chunk_specs = [] + for chunk_start, chunk_end in chunk_ranges: + chunk_samples = _sample_one_chunk(chunk_start, chunk_end) + chunk_specs.append((chunk_start, chunk_end, chunk_samples)) + + final = _concat_chunks_with_overlap_blend( + chunk_specs, SEEDVR2_LATENT_CHANNELS, temporal_overlap, + ) + + out = latent.copy() + out.pop("downscale_ratio_spacial", None) + out["samples"] = final + return io.NodeOutput(out) + + +class SeedVRExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SeedVR2Conditioning, + SeedVR2Preprocess, + SeedVR2PostProcessing, + SeedVR2ProgressiveSampler, + ] + +async def comfy_entrypoint() -> SeedVRExtension: + return SeedVRExtension() diff --git a/comfy_extras/nodes_stable3d.py b/comfy_extras/nodes_stable3d.py index 829c837a1..8a6e5b726 100644 --- a/comfy_extras/nodes_stable3d.py +++ b/comfy_extras/nodes_stable3d.py @@ -27,7 +27,7 @@ class StableZero123_Conditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableZero123_Conditioning", - category="conditioning/3d_models", + category="model/conditioning/3d_models", inputs=[ io.ClipVision.Input("clip_vision"), io.Image.Input("init_image"), @@ -65,7 +65,7 @@ class StableZero123_Conditioning_Batched(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableZero123_Conditioning_Batched", - category="conditioning/3d_models", + category="model/conditioning/3d_models", inputs=[ io.ClipVision.Input("clip_vision"), io.Image.Input("init_image"), @@ -112,7 +112,7 @@ class SV3D_Conditioning(io.ComfyNode): def define_schema(cls): return io.Schema( 
node_id="SV3D_Conditioning", - category="conditioning/3d_models", + category="model/conditioning/3d_models", inputs=[ io.ClipVision.Input("clip_vision"), io.Image.Input("init_image"), diff --git a/comfy_extras/nodes_stable_cascade.py b/comfy_extras/nodes_stable_cascade.py index 0dc6c9fcd..e55f248ae 100644 --- a/comfy_extras/nodes_stable_cascade.py +++ b/comfy_extras/nodes_stable_cascade.py @@ -29,7 +29,7 @@ class StableCascade_EmptyLatentImage(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableCascade_EmptyLatentImage", - category="latent/stable_cascade", + category="model/latent/stable_cascade", inputs=[ io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8), @@ -58,7 +58,7 @@ class StableCascade_StageC_VAEEncode(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableCascade_StageC_VAEEncode", - category="latent/stable_cascade", + category="model/latent/stable_cascade", inputs=[ io.Image.Input("image"), io.Vae.Input("vae"), @@ -93,7 +93,7 @@ class StableCascade_StageB_Conditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="StableCascade_StageB_Conditioning", - category="conditioning/stable_cascade", + category="model/conditioning/stable_cascade", inputs=[ io.Conditioning.Input("conditioning"), io.Latent.Input("stage_c"), diff --git a/comfy_extras/nodes_tomesd.py b/comfy_extras/nodes_tomesd.py index 87bf29b8f..3667fac3a 100644 --- a/comfy_extras/nodes_tomesd.py +++ b/comfy_extras/nodes_tomesd.py @@ -151,7 +151,7 @@ class TomePatchModel(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="TomePatchModel", - category="model_patches/unet", + category="model/patch/unet", inputs=[ io.Model.Input("model"), io.Float.Input("ratio", default=0.3, min=0.0, max=1.0, step=0.01), diff --git a/comfy_extras/nodes_toolkit.py b/comfy_extras/nodes_toolkit.py index 0548a0cf8..9f709bbe3 100644 --- 
a/comfy_extras/nodes_toolkit.py +++ b/comfy_extras/nodes_toolkit.py @@ -13,7 +13,7 @@ class CreateList(io.ComfyNode): return io.Schema( node_id="CreateList", display_name="Create List", - category="utils", + category="utilities", is_input_list=True, search_aliases=["Image Iterator", "Text Iterator", "Iterator"], inputs=[io.Autogrow.Input("inputs", template=template_autogrow)], diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index e9871369b..046eeaaf5 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -951,7 +951,7 @@ class TrainLoraNode(io.ComfyNode): return io.Schema( node_id="TrainLoraNode", display_name="Train LoRA", - category="training", + category="model/training", is_experimental=True, is_input_list=True, # All inputs become lists inputs=[ @@ -1309,7 +1309,7 @@ class LoraModelLoader(io.ComfyNode): return io.Schema( node_id="LoraModelLoader", display_name="Load LoRA Model", - category="loaders", + category="model/loaders", is_experimental=True, inputs=[ io.Model.Input( @@ -1405,7 +1405,7 @@ class LossGraphNode(io.ComfyNode): node_id="LossGraphNode", search_aliases=["training chart", "training visualization", "plot loss"], display_name="Plot Loss Graph", - category="training", + category="model/training", is_experimental=True, is_output_node=True, inputs=[ diff --git a/comfy_extras/nodes_triposplat.py b/comfy_extras/nodes_triposplat.py new file mode 100644 index 000000000..1848ad31a --- /dev/null +++ b/comfy_extras/nodes_triposplat.py @@ -0,0 +1,270 @@ +# TripoSplat nodes: image -> 3D gaussian splat + +import logging + +import torch +import torch.nn.functional as F +from typing_extensions import override + +import comfy.model_management +import comfy.nested_tensor +import comfy.patcher_extension +import comfy.utils +from comfy_api.latest import ComfyExtension, IO, Types + + +_Q_TOKEN_LENGTH = 8192 +_LATENT_CHANNELS = 16 +_CAM_CHANNELS = 5 +_DINOV3_MEAN = [0.485, 0.456, 0.406] +_DINOV3_STD = [0.229, 0.224, 
0.225] +_NUM_GAUSSIANS_MIN = 32768 +_NUM_GAUSSIANS_MAX = 1048576 + + +def _preprocess(image: torch.Tensor, mask: torch.Tensor, erode_radius: int, size: int) -> torch.Tensor: + # Match original preprocessing: + # resize min side to `size` -> erode alpha -> alpha bbox -> 1.2x square crop -> resize -> composite on black. + rgb = image[..., :3].clamp(0, 1).movedim(-1, 0) # (3, H, W) + alpha = mask.clamp(0, 1)[None] # (1, H, W) + rgba = torch.cat([rgb, alpha], 0)[None] # (1, 4, H, W) + + h, w = rgba.shape[-2:] + s = size / min(w, h) + rgba = comfy.utils.common_upscale(rgba, max(1, round(w * s)), max(1, round(h * s)), "lanczos", "disabled").clamp(0, 1) + + a = rgba[:, 3:4] + if erode_radius > 0: + # min filter over a (2r+1) window == morphological erosion of the alpha matte. + a = -F.max_pool2d(-a, 2 * erode_radius + 1, stride=1, padding=erode_radius) + rgba = torch.cat([rgba[:, :3], a], 1) + + ys, xs = torch.nonzero(a[0, 0] > 0, as_tuple=True) + if xs.numel() == 0: + raise ValueError("TripoSplatPreprocessImage: mask is empty (no foreground pixels).") + x0, x1 = int(xs.min()), int(xs.max()) + y0, y1 = int(ys.min()), int(ys.max()) + cx, cy = (x0 + x1) / 2, (y0 + y1) / 2 + half = max(x1 - x0, y1 - y0) / 2 * 1.2 + left, upper, right, lower = int(cx - half), int(cy - half), int(cx + half), int(cy + half) + + H, W = rgba.shape[-2:] + crop = rgba.new_zeros((1, 4, lower - upper, right - left)) # out-of-bounds stays 0, matching PIL.crop + sx0, sy0, sx1, sy1 = max(left, 0), max(upper, 0), min(right, W), min(lower, H) + if sx1 > sx0 and sy1 > sy0: + crop[:, :, sy0 - upper:sy1 - upper, sx0 - left:sx1 - left] = rgba[:, :, sy0:sy1, sx0:sx1] + + crop = comfy.utils.common_upscale(crop, size, size, "lanczos", "disabled").clamp(0, 1) + out = (crop[:, :3] * crop[:, 3:4])[0].movedim(0, -1) # composite over black == rgb * alpha + return out.unsqueeze(0) # (1, 1024, 1024, 3) + + +class TripoSplatPreprocessImage(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + 
node_id="TripoSplatPreprocessImage", + display_name="TripoSplat Preprocess Image", + category="3d/conditioning", + description="Crop center each image to a square canvas on a black background and add padding.", + inputs=[ + IO.Image.Input("image"), + IO.Mask.Input("mask"), + IO.Int.Input("erode_radius", default=1, min=0, max=16, + tooltip="Erode the alpha matte by this pixel radius before cropping (avoids border bleed)."), + IO.Int.Input("size", default=1024, min=256, max=4096, step=16, + tooltip="Square image size. The model is trained at 1024; other sizes run but are off-distribution."), + ], + outputs=[IO.Image.Output(display_name="image")], + ) + + @classmethod + def execute(cls, image, mask, erode_radius, size) -> IO.NodeOutput: + size = max(16, (int(size) // 16) * 16) # DINOv3 patch / Flux2 VAE stride is 16 + if mask.shape[0] != image.shape[0]: + mask = comfy.utils.repeat_to_batch_size(mask, image.shape[0]) + if tuple(mask.shape[1:]) != tuple(image.shape[1:3]): + mask = F.interpolate(mask[:, None].float(), size=tuple(image.shape[1:3]), mode="bilinear", align_corners=False)[:, 0] + prepared = torch.cat([_preprocess(image[i], mask[i], erode_radius, size) for i in range(image.shape[0])], dim=0) + return IO.NodeOutput(prepared) + + +class TripoSplatConditioning(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoSplatConditioning", + display_name="TripoSplat Conditioning", + category="3d/conditioning", + description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative " + "conditioning, and create the fixed size noise target (latent + camera) for the KSampler", + inputs=[ + IO.ClipVision.Input("clip_vision", tooltip="DINOv3 ViT-H/16+ image encoder"), + IO.Vae.Input("vae", tooltip="Flux2 VAE"), + IO.Image.Input("image"), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + IO.Latent.Output(display_name="latent", tooltip="The 
fixed size noise target (latent +camera)."), + ], + ) + + @classmethod + def execute(cls, clip_vision, vae, image) -> IO.NodeOutput: + # feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top + comfy.model_management.load_model_gpu(clip_vision.patcher) + device = clip_vision.load_device + img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1] + mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1) + std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1) + img = (img - mean) / std + seq = clip_vision.model(pixel_values=img.float())[0] + feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device()) + + # Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry + ref = vae.encode(image).to(comfy.model_management.intermediate_device()) # (B, 128, H, W) + b = ref.shape[0] + + positive = [[feature1, {"reference_latents": [ref]}]] + negative = [[torch.zeros_like(feature1), {"reference_latents": [torch.zeros_like(ref)]}]] + + # Fixed noise target: the latent is a constant-shape (8192, 16) shape-code + a (1, 5) camera token + dev = comfy.model_management.intermediate_device() + latent_seq = torch.zeros([b, _Q_TOKEN_LENGTH, _LATENT_CHANNELS], device=dev) + camera = torch.zeros([b, 1, _CAM_CHANNELS], device=dev) + samples = comfy.nested_tensor.NestedTensor((latent_seq, camera)) + return IO.NodeOutput(positive, negative, {"samples": samples}) + + +class VAEDecodeTripoSplat(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeTripoSplat", + display_name="TripoSplat Decode", + category="3d/latent", + description="Decode the sampled TripoSplat latent into a 3D gaussian splat. 
" + "Modify the number of gaussians to vary the density.", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"), + IO.Int.Input("num_gaussians", default=262144, min=_NUM_GAUSSIANS_MIN, max=_NUM_GAUSSIANS_MAX, step=32, + tooltip="Number of gaussians to produce (rounded to a multiple of 32). " + "262144 matches the octree's point density; higher oversamples the same points " + "(denser, but no new detail) and costs proportionally more VRAM/time."), + IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff, + tooltip="Seeds the octree point sampler (global RNG) for deterministic decodes."), + ], + outputs=[IO.Splat.Output(display_name="splat")], + ) + + @classmethod + def execute(cls, samples, vae, num_gaussians, seed) -> IO.NodeOutput: + s = samples["samples"] + latent = s.unbind()[0] if getattr(s, "is_nested", False) else s # take the latent stream, drop camera + + decoder = vae.first_stage_model + gpp = decoder.gaussians_per_point + n = max(_NUM_GAUSSIANS_MIN, min(_NUM_GAUSSIANS_MAX, int(num_gaussians))) + if n % gpp != 0: + n = round(n / gpp) * gpp + + dtype_size = comfy.model_management.dtype_size(vae.vae_dtype) + hidden = decoder.gs.model_channels + cond_tokens = latent.shape[1] + memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size + comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required) + latent = latent.to(device=vae.device, dtype=vae.vae_dtype) + generator = torch.Generator(device="cpu").manual_seed(seed) + parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)] + positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts)) + return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh)) + + +class TripoSplatSamplingPreview(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoSplatSamplingPreview", + display_name="TripoSplat Sampling 
Preview", + category="3d/latent", + description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded " + "gaussian splat preview at each step.", + inputs=[ + IO.Model.Input("model"), + IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"), + IO.Int.Input("octree_level", default=5, min=2, max=8, advanced=True, + tooltip="Octree depth for the preview decode (lower = cheaper/coarser)."), + IO.Int.Input("num_gaussians", default=16384, min=1024, max=262144, step=32, + tooltip="Number of gaussians to produce for the preview (rounded to a multiple of 32)."), + IO.Float.Input("yaw", default=90.0, min=-360.0, max=360.0, step=1.0, tooltip="Preview camera yaw in degrees.", advanced=True,), + IO.Float.Input("pitch", default=15.0, min=-89.0, max=89.0, step=1.0, tooltip="Preview camera pitch in degrees.", advanced=True,), + IO.Int.Input("point_size", default=3, min=1, max=16, + tooltip="Maximum splat radius in pixels. Each gaussian is sized from its scale and capped here; " + "lower = finer/pointier, higher = chunkier."), + ], + outputs=[IO.Model.Output()], + ) + + @classmethod + def execute(cls, model, vae, octree_level, num_gaussians, yaw, pitch, point_size) -> IO.NodeOutput: + from comfy.ldm.triposplat.preview import decode_x0_to_image + cfg = {"gaussians": num_gaussians, "level": octree_level, "yaw": yaw, "pitch": pitch, + "point_size": point_size} + + fsm = vae.first_stage_model + cond_tokens = model.model.diffusion_model.q_token_length + memory_required = (cond_tokens * 4 + (num_gaussians // fsm.gaussians_per_point) * 10) * fsm.gs.model_channels * comfy.model_management.dtype_size(vae.vae_dtype) + + # Live preview via WrappersMP.OUTER_SAMPLE + ProgressBar + # The wrapper augments the sampler's own callback to decode x0 -> gaussian splat -> preview image each step + def outer_sample_wrapper(executor, *args, **kwargs): + args = list(args) + cb_idx = 5 # outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) 
+ orig_cb = args[cb_idx] if len(args) > cb_idx else kwargs.get("callback") + state = {"ok": True, "pbar": None, "loaded": False} + + def callback(step, x0, x, total_steps): + if orig_cb is not None: + orig_cb(step, x0, x, total_steps) + if not state["ok"]: + return + try: + if not state["loaded"]: + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + loaded_models.append(vae.patcher) + comfy.model_management.load_models_gpu(loaded_models, memory_required=memory_required) + state["loaded"] = True + img = decode_x0_to_image(vae, x0, cfg) + if state["pbar"] is None: + state["pbar"] = comfy.utils.ProgressBar(total_steps) + state["pbar"].update_absolute(step + 1, total_steps, ("JPEG", img, 512)) + except Exception as e: + logging.warning("TripoSplatSamplingPreview: preview failed, disabling ({})".format(e)) + state["ok"] = False + + if len(args) > cb_idx: + args[cb_idx] = callback + else: + kwargs["callback"] = callback + return executor(*args, **kwargs) + + m = model.clone() + m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "triposplat_sampling_preview", outer_sample_wrapper) + return IO.NodeOutput(m) + + +class TripoSplatExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TripoSplatPreprocessImage, + TripoSplatConditioning, + VAEDecodeTripoSplat, + TripoSplatSamplingPreview, + ] + + +async def comfy_entrypoint() -> TripoSplatExtension: + return TripoSplatExtension() diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index d3ee3f1c1..1cf5a5d01 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -22,7 +22,7 @@ class UpscaleModelLoader(io.ComfyNode): return io.Schema( node_id="UpscaleModelLoader", display_name="Load Upscale Model", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input("model_name", options=folder_paths.get_filename_list("upscale_models")), ], 
diff --git a/comfy_extras/nodes_video.py b/comfy_extras/nodes_video.py index ae1d826d5..6f6c416a6 100644 --- a/comfy_extras/nodes_video.py +++ b/comfy_extras/nodes_video.py @@ -19,7 +19,7 @@ class SaveWEBM(io.ComfyNode): category="video", is_experimental=True, inputs=[ - io.Image.Input("images"), + io.Image.Input("images", tooltip="RGBA images are saved with their alpha channel as transparency (vp9 codec only)."), io.String.Input("filename_prefix", default="ComfyUI"), io.Combo.Input("codec", options=["vp9", "av1"]), io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01), @@ -45,18 +45,25 @@ class SaveWEBM(io.ComfyNode): for x in cls.hidden.extra_pnginfo: container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) + # Save transparency when the images carry an alpha channel (RGBA) and the codec supports it. + # vp9 -> yuva420p; other codecs have no usable alpha path, so the alpha is ignored. + save_alpha = images.shape[-1] == 4 and codec == "vp9" + codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"} stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000)) stream.width = images.shape[-2] stream.height = images.shape[-3] - stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p" + stream.pix_fmt = "yuva420p" if save_alpha else ("yuv420p10le" if codec == "av1" else "yuv420p") stream.bit_rate = 0 stream.options = {'crf': str(crf)} if codec == "av1": stream.options["preset"] = "6" for frame in images: - frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24") + if save_alpha: + frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :4] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgba") + else: + frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24") for packet 
in stream.encode(frame): container.mux(packet) container.mux(stream.encode()) diff --git a/comfy_extras/nodes_video_model.py b/comfy_extras/nodes_video_model.py index 8f19895a1..0d6cae6a8 100644 --- a/comfy_extras/nodes_video_model.py +++ b/comfy_extras/nodes_video_model.py @@ -15,7 +15,7 @@ class ImageOnlyCheckpointLoader: RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") FUNCTION = "load_checkpoint" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) @@ -41,7 +41,7 @@ class SVD_img2vid_Conditioning: FUNCTION = "encode" - CATEGORY = "conditioning/video_models" + CATEGORY = "model/conditioning/video_models" def encode(self, clip_vision, init_image, vae, width, height, video_frames, motion_bucket_id, fps, augmentation_level): output = clip_vision.encode_image(init_image) @@ -65,7 +65,7 @@ class VideoLinearCFGGuidance: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "sampling/guiders" + CATEGORY = "model/sampling/guiders" def patch(self, model, min_cfg): def linear_cfg(args): @@ -89,7 +89,7 @@ class VideoTriangleCFGGuidance: RETURN_TYPES = ("MODEL",) FUNCTION = "patch" - CATEGORY = "sampling/guiders" + CATEGORY = "model/sampling/guiders" def patch(self, model, min_cfg): def linear_cfg(args): @@ -138,7 +138,7 @@ class ConditioningSetAreaPercentageVideo: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def append(self, conditioning, width, height, temporal, x, y, z, strength): c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", temporal, height, width, z, y, x), diff --git a/comfy_extras/nodes_void.py b/comfy_extras/nodes_void.py index be724371a..b43154b8d 100644 --- a/comfy_extras/nodes_void.py +++ b/comfy_extras/nodes_void.py @@ -58,7 +58,7 @@ class OpticalFlowLoader(io.ComfyNode): return io.Schema( 
node_id="OpticalFlowLoader", display_name="Load Optical Flow Model", - category="loaders", + category="model/loaders", inputs=[ io.Combo.Input( "model_name", @@ -175,7 +175,7 @@ class VOIDInpaintConditioning(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="VOIDInpaintConditioning", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -288,7 +288,7 @@ class VOIDWarpedNoise(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="VOIDWarpedNoise", - category="latent/video", + category="model/latent/video", inputs=[ OpticalFlow.Input( "optical_flow", @@ -393,7 +393,7 @@ class VOIDWarpedNoiseSource(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="VOIDWarpedNoiseSource", - category="sampling/noise", + category="model/sampling/noise", inputs=[ io.Latent.Input("warped_noise", tooltip="Warped noise latent from VOIDWarpedNoise"), @@ -455,7 +455,7 @@ class VOIDSampler(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="VOIDSampler", - category="sampling/samplers", + category="model/sampling/samplers", inputs=[], outputs=[io.Sampler.Output()], ) diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py index e50bfcd2c..67d3a8443 100644 --- a/comfy_extras/nodes_wan.py +++ b/comfy_extras/nodes_wan.py @@ -18,7 +18,7 @@ class WanImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -66,7 +66,7 @@ class WanFunControlToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanFunControlToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -119,7 +119,7 @@ 
class Wan22FunControlToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Wan22FunControlToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -184,7 +184,7 @@ class WanFirstLastFrameToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanFirstLastFrameToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -256,7 +256,7 @@ class WanFunInpaintToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanFunInpaintToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -288,7 +288,7 @@ class WanVaceToVideo(io.ComfyNode): return io.Schema( node_id="WanVaceToVideo", search_aliases=["video conditioning", "video control"], - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -375,7 +375,7 @@ class TrimVideoLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="TrimVideoLatent", - category="latent/video", + category="model/latent/video", inputs=[ io.Latent.Input("samples"), io.Int.Input("trim_amount", default=0, min=0, max=99999), @@ -398,7 +398,7 @@ class WanCameraImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanCameraImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -452,7 +452,7 @@ class WanPhantomSubjectToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanPhantomSubjectToVideo", - category="conditioning/video_models", + 
category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -707,7 +707,7 @@ class WanTrackToVideo(io.ComfyNode): return io.Schema( node_id="WanTrackToVideo", search_aliases=["motion tracking", "trajectory video", "point tracking", "keypoint animation"], - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -951,7 +951,7 @@ class WanSoundImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanSoundImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -984,7 +984,7 @@ class WanSoundImageToVideoExtend(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanSoundImageToVideoExtend", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -1046,7 +1046,7 @@ class WanHuMoImageToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanHuMoImageToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -1112,7 +1112,7 @@ class WanAnimateToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanAnimateToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), @@ -1252,7 +1252,7 @@ class Wan22ImageToVideoLatent(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="Wan22ImageToVideoLatent", - category="conditioning/inpaint", + category="model/conditioning/inpaint", inputs=[ io.Vae.Input("vae"), io.Int.Input("width", default=1280, min=32, 
max=nodes.MAX_RESOLUTION, step=32), @@ -1302,7 +1302,7 @@ class WanInfiniteTalkToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanInfiniteTalkToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.DynamicCombo.Input("mode", options=[ io.DynamicCombo.Option("single_speaker", []), @@ -1461,7 +1461,7 @@ class WanSCAILToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanSCAILToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_wandancer.py b/comfy_extras/nodes_wandancer.py index fc005ed4c..a96885745 100644 --- a/comfy_extras/nodes_wandancer.py +++ b/comfy_extras/nodes_wandancer.py @@ -713,7 +713,7 @@ class WanDancerEncodeAudio(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanDancerEncodeAudio", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Audio.Input("audio"), io.Int.Input("video_frames", default=149, min=1, max=nodes.MAX_RESOLUTION, step=4), @@ -787,7 +787,7 @@ class WanDancerVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanDancerVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfy_extras/nodes_wanmove.py b/comfy_extras/nodes_wanmove.py index 5acae03eb..2db064922 100644 --- a/comfy_extras/nodes_wanmove.py +++ b/comfy_extras/nodes_wanmove.py @@ -247,7 +247,7 @@ class WanMoveVisualizeTracks(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanMoveVisualizeTracks", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Image.Input("images"), io.Tracks.Input("tracks", optional=True), @@ -283,7 +283,7 @@ class 
WanMoveTracksFromCoords(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanMoveTracksFromCoords", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.String.Input("track_coords", force_input=True, default="[]", optional=True), io.Mask.Input("track_mask", optional=True), @@ -325,7 +325,7 @@ class GenerateTracks(io.ComfyNode): return io.Schema( node_id="GenerateTracks", search_aliases=["motion paths", "camera movement", "trajectory"], - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Int.Input("width", default=832, min=16, max=4096, step=16), io.Int.Input("height", default=480, min=16, max=4096, step=16), @@ -434,7 +434,7 @@ class WanMoveConcatTrack(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanMoveConcatTrack", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Tracks.Input("tracks_1"), io.Tracks.Input("tracks_2", optional=True), @@ -463,7 +463,7 @@ class WanMoveTrackToVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="WanMoveTrackToVideo", - category="conditioning/video_models", + category="model/conditioning/video_models", inputs=[ io.Conditioning.Input("positive"), io.Conditioning.Input("negative"), diff --git a/comfyui_version.py b/comfyui_version.py index 0bb0f780c..4e3c924e6 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.22.0" +__version__ = "0.24.0" diff --git a/main.py b/main.py index bce451a83..239a52013 100644 --- a/main.py +++ b/main.py @@ -464,13 +464,6 @@ def start_comfyui(asyncio_loop=None): folder_paths.set_temp_directory(temp_dir) cleanup_temp() - if args.windows_standalone_build: - try: - import new_updater - new_updater.update_windows_updater() - except: - pass - if not asyncio_loop: asyncio_loop = asyncio.new_event_loop() asyncio.set_event_loop(asyncio_loop) diff --git a/new_updater.py b/new_updater.py deleted file mode 100644 index 9a203acdd..000000000 --- a/new_updater.py +++ /dev/null @@ -1,35 +0,0 @@ -import os -import shutil - -base_path = os.path.dirname(os.path.realpath(__file__)) - - -def update_windows_updater(): - top_path = os.path.dirname(base_path) - updater_path = os.path.join(base_path, ".ci/update_windows/update.py") - bat_path = os.path.join(base_path, ".ci/update_windows/update_comfyui.bat") - - dest_updater_path = os.path.join(top_path, "update/update.py") - dest_bat_path = os.path.join(top_path, "update/update_comfyui.bat") - dest_bat_deps_path = os.path.join(top_path, "update/update_comfyui_and_python_dependencies.bat") - - try: - with open(dest_bat_path, 'rb') as f: - contents = f.read() - except: - return - - if not contents.startswith(b"..\\python_embeded\\python.exe .\\update.py"): - return - - shutil.copy(updater_path, dest_updater_path) - try: - with open(dest_bat_deps_path, 'rb') as f: - contents = f.read() - contents = contents.replace(b'..\\python_embeded\\python.exe .\\update.py ..\\ComfyUI\\', b'call update_comfyui.bat nopause') - with open(dest_bat_deps_path, 'wb') as f: - f.write(contents) - except: - pass - shutil.copy(bat_path, dest_bat_path) - print("Updated the windows standalone package updater.") # noqa: T201 diff --git a/nodes.py b/nodes.py index 87d81b5b7..d9ac53ede 100644 --- a/nodes.py +++ b/nodes.py @@ -47,14 +47,18 @@ import node_helpers if args.enable_manager: import comfyui_manager + def 
before_node_execution(): comfy.model_management.throw_exception_if_processing_interrupted() + def interrupt_processing(value=True): comfy.model_management.interrupt_current_processing(value) + MAX_RESOLUTION=16384 + class CLIPTextEncode(ComfyNodeABC): @classmethod def INPUT_TYPES(s) -> InputTypeDict: @@ -68,7 +72,7 @@ class CLIPTextEncode(ComfyNodeABC): OUTPUT_TOOLTIPS = ("A conditioning containing the embedded text used to guide the diffusion model.",) FUNCTION = "encode" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" DESCRIPTION = "Encodes a text prompt using a CLIP model into an embedding that can be used to guide the diffusion model towards generating specific images." SEARCH_ALIASES = ["text", "prompt", "text prompt", "positive prompt", "negative prompt", "encode text", "text encoder", "encode prompt"] @@ -87,7 +91,7 @@ class ConditioningCombine: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "combine" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" SEARCH_ALIASES = ["combine", "merge conditioning", "combine prompts", "merge prompts", "mix prompts", "add prompt"] def combine(self, conditioning_1, conditioning_2): @@ -104,7 +108,7 @@ class ConditioningAverage : RETURN_TYPES = ("CONDITIONING",) FUNCTION = "addWeighted" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def addWeighted(self, conditioning_to, conditioning_from, conditioning_to_strength): out = [] @@ -143,7 +147,7 @@ class ConditioningConcat: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "concat" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def concat(self, conditioning_to, conditioning_from): out = [] @@ -176,7 +180,7 @@ class ConditioningSetArea: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def append(self, conditioning, width, height, x, y, strength): c = node_helpers.conditioning_set_values(conditioning, {"area": (height // 8, width // 8, y // 8, x // 8), @@ -197,7 +201,7 @@ 
class ConditioningSetAreaPercentage: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def append(self, conditioning, width, height, x, y, strength): c = node_helpers.conditioning_set_values(conditioning, {"area": ("percentage", height, width, y, x), @@ -214,7 +218,7 @@ class ConditioningSetAreaStrength: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def append(self, conditioning, strength): c = node_helpers.conditioning_set_values(conditioning, {"strength": strength}) @@ -234,7 +238,7 @@ class ConditioningSetMask: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def append(self, conditioning, mask, set_cond_area, strength): set_area_to_bounds = False @@ -303,7 +307,7 @@ class VAEDecode: OUTPUT_TOOLTIPS = ("The decoded image.",) FUNCTION = "decode" - CATEGORY = "latent" + CATEGORY = "model/latent" DESCRIPTION = "Decodes latent images back into pixel space images." SEARCH_ALIASES = ["decode", "decode latent", "latent to image", "render latent"] @@ -323,8 +327,8 @@ class VAEDecodeTiled: return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), - "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time.", "advanced": True}), - "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), + "temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time. 
SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), + "temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" @@ -334,18 +338,32 @@ class VAEDecodeTiled: def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): if tile_size < overlap * 4: overlap = tile_size // 4 - if temporal_size < temporal_overlap * 2: - temporal_overlap = temporal_overlap // 2 temporal_compression = vae.temporal_compression_decode() if temporal_compression is not None: - temporal_size = max(2, temporal_size // temporal_compression) - temporal_overlap = max(1, min(temporal_size // 2, temporal_overlap // temporal_compression)) + if temporal_size <= 0: + temporal_size = 0 + temporal_overlap = 0 + else: + requested_temporal_overlap = temporal_overlap + if temporal_size < temporal_overlap * 2: + temporal_overlap = temporal_overlap // 2 + temporal_size = max(2, temporal_size // temporal_compression) + temporal_overlap = min(temporal_size // 2, temporal_overlap // temporal_compression) + if requested_temporal_overlap > 0: + temporal_overlap = max(1, temporal_overlap) else: temporal_size = None temporal_overlap = None compression = vae.spacial_compression_decode() - images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap) + images = vae.decode_tiled( + samples["samples"], + tile_x=tile_size // compression, + tile_y=tile_size // compression, + overlap=overlap // compression, + tile_t=temporal_size, + overlap_t=temporal_overlap, + ) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -357,12 +375,12 @@ class VAEEncode: RETURN_TYPES = ("LATENT",) FUNCTION = "encode" - 
CATEGORY = "latent" + CATEGORY = "model/latent" SEARCH_ALIASES = ["encode", "encode image", "image to latent"] def encode(self, vae, pixels): t = vae.encode(pixels) - return ({"samples":t}, ) + return ({"samples": t}, ) class VAEEncodeTiled: @classmethod @@ -370,8 +388,8 @@ class VAEEncodeTiled: return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64, "advanced": True}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32, "advanced": True}), - "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time.", "advanced": True}), - "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), + "temporal_size": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time. 
SeedVR2 allows 0 to disable temporal slicing.", "advanced": True}), + "temporal_overlap": ("INT", {"default": 8, "min": 0, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap.", "advanced": True}), }} RETURN_TYPES = ("LATENT",) FUNCTION = "encode" @@ -379,6 +397,9 @@ class VAEEncodeTiled: CATEGORY = "experimental" def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): + if temporal_size <= 0: + temporal_size = 0 + temporal_overlap = 0 t = vae.encode_tiled(pixels, tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) return ({"samples": t}, ) @@ -389,7 +410,7 @@ class VAEEncodeForInpaint: RETURN_TYPES = ("LATENT",) FUNCTION = "encode" - CATEGORY = "latent/inpaint" + CATEGORY = "model/latent/inpaint" def encode(self, vae, pixels, mask, grow_mask_by=6): downscale_ratio = vae.spacial_compression_encode() @@ -438,7 +459,7 @@ class InpaintModelConditioning: RETURN_NAMES = ("positive", "negative", "latent") FUNCTION = "encode" - CATEGORY = "conditioning/inpaint" + CATEGORY = "model/conditioning/inpaint" def encode(self, positive, negative, pixels, vae, mask, noise_mask=True): x = (pixels.shape[1] // 8) * 8 @@ -598,7 +619,7 @@ class CheckpointLoaderSimple: "The VAE model used for encoding and decoding images to and from latent space.") FUNCTION = "load_checkpoint" - CATEGORY = "loaders" + CATEGORY = "model/loaders" DESCRIPTION = "Loads a diffusion model checkpoint, diffusion models are used to denoise latents." 
SEARCH_ALIASES = ["load model", "checkpoint", "model loader", "load checkpoint", "ckpt", "model"] @@ -644,7 +665,7 @@ class unCLIPCheckpointLoader: RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION") FUNCTION = "load_checkpoint" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): ckpt_path = folder_paths.get_full_path_or_raise("checkpoints", ckpt_name) @@ -660,7 +681,7 @@ class CLIPSetLastLayer: RETURN_TYPES = ("CLIP",) FUNCTION = "set_last_layer" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def set_last_layer(self, clip, stop_at_clip_layer): clip = clip.clone() @@ -689,7 +710,7 @@ class LoraLoader: OUTPUT_TOOLTIPS = ("The modified diffusion model.", "The modified CLIP model.") FUNCTION = "load_lora" - CATEGORY = "loaders" + CATEGORY = "model/loaders" DESCRIPTION = "This LoRA loader is used to modify both diffusion and CLIP models, altering the way in which latents are denoised such as applying styles. Multiple LoRA nodes can be linked together." SEARCH_ALIASES = ["lora", "load lora", "apply lora", "lora loader", "lora model"] @@ -789,7 +810,7 @@ class VAELoader: RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" - CATEGORY = "loaders" + CATEGORY = "model/loaders" #TODO: scale factor? 
def load_vae(self, vae_name): @@ -831,7 +852,7 @@ class ControlNetLoader: RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" - CATEGORY = "loaders" + CATEGORY = "model/loaders" SEARCH_ALIASES = ["controlnet", "control net", "cn", "load controlnet", "controlnet loader"] def load_controlnet(self, control_net_name): @@ -850,7 +871,7 @@ class DiffControlNetLoader: RETURN_TYPES = ("CONTROL_NET",) FUNCTION = "load_controlnet" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_controlnet(self, model, control_net_name): controlnet_path = folder_paths.get_full_path_or_raise("controlnet", control_net_name) @@ -870,7 +891,7 @@ class ControlNetApply: FUNCTION = "apply_controlnet" DEPRECATED = True - CATEGORY = "conditioning/controlnet" + CATEGORY = "model/conditioning/controlnet" def apply_controlnet(self, conditioning, control_net, image, strength): if strength == 0: @@ -908,7 +929,7 @@ class ControlNetApplyAdvanced: RETURN_NAMES = ("positive", "negative") FUNCTION = "apply_controlnet" - CATEGORY = "conditioning/controlnet" + CATEGORY = "model/conditioning/controlnet" SEARCH_ALIASES = ["controlnet", "apply controlnet", "use controlnet", "control net"] def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]): @@ -969,7 +990,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", "pixeldit"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis", "longcat_image", "cogvideox", "lens", 
"pixeldit", "ideogram4"], ), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -1030,7 +1051,7 @@ class CLIPVisionLoader: RETURN_TYPES = ("CLIP_VISION",) FUNCTION = "load_clip" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_clip(self, clip_name): clip_path = folder_paths.get_full_path_or_raise("clip_vision", clip_name) @@ -1049,7 +1070,7 @@ class CLIPVisionEncode: RETURN_TYPES = ("CLIP_VISION_OUTPUT",) FUNCTION = "encode" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def encode(self, clip_vision, image, crop): crop_image = True @@ -1066,7 +1087,7 @@ class StyleModelLoader: RETURN_TYPES = ("STYLE_MODEL",) FUNCTION = "load_style_model" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_style_model(self, style_model_name): style_model_path = folder_paths.get_full_path_or_raise("style_models", style_model_name) @@ -1088,7 +1109,7 @@ class StyleModelApply: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_stylemodel" - CATEGORY = "conditioning/style_model" + CATEGORY = "model/conditioning/style_model" def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength, strength_type): cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0) @@ -1148,7 +1169,7 @@ class unCLIPConditioning: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_adm" - CATEGORY = "conditioning" + CATEGORY = "model/conditioning" def apply_adm(self, conditioning, clip_vision_output, strength, noise_augmentation): if strength == 0: @@ -1165,7 +1186,7 @@ class GLIGENLoader: RETURN_TYPES = ("GLIGEN",) FUNCTION = "load_gligen" - CATEGORY = "loaders" + CATEGORY = "model/loaders" def load_gligen(self, gligen_name): gligen_path = folder_paths.get_full_path_or_raise("gligen", gligen_name) @@ -1187,7 +1208,7 @@ class GLIGENTextBoxApply: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "append" - CATEGORY = "conditioning/gligen" + CATEGORY = "model/conditioning/gligen" def append(self, 
conditioning_to, clip, gligen_textbox_model, text, width, height, x, y): c = [] @@ -1217,7 +1238,7 @@ class EmptyLatentImage: OUTPUT_TOOLTIPS = ("The empty latent image batch.",) FUNCTION = "generate" - CATEGORY = "latent" + CATEGORY = "model/latent" DESCRIPTION = "Create a new batch of empty latent images to be denoised via sampling." SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"] @@ -1238,7 +1259,7 @@ class LatentFromBatch: RETURN_TYPES = ("LATENT",) FUNCTION = "frombatch" - CATEGORY = "latent/batch" + CATEGORY = "model/latent/batch" def frombatch(self, samples, batch_index, length): s = samples.copy() @@ -1273,7 +1294,7 @@ class RepeatLatentBatch: RETURN_TYPES = ("LATENT",) FUNCTION = "repeat" - CATEGORY = "latent/batch" + CATEGORY = "model/latent/batch" def repeat(self, samples, amount): s = samples.copy() @@ -1305,7 +1326,7 @@ class LatentUpscale: RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" - CATEGORY = "latent" + CATEGORY = "model/latent" def upscale(self, samples, upscale_method, width, height, crop): if width == 0 and height == 0: @@ -1338,7 +1359,7 @@ class LatentUpscaleBy: RETURN_TYPES = ("LATENT",) FUNCTION = "upscale" - CATEGORY = "latent" + CATEGORY = "model/latent" def upscale(self, samples, upscale_method, scale_by): s = samples.copy() @@ -1356,7 +1377,7 @@ class LatentRotate: RETURN_TYPES = ("LATENT",) FUNCTION = "rotate" - CATEGORY = "latent/transform" + CATEGORY = "model/latent/transform" def rotate(self, samples, rotation): s = samples.copy() @@ -1382,7 +1403,7 @@ class LatentFlip: RETURN_TYPES = ("LATENT",) FUNCTION = "flip" - CATEGORY = "latent/transform" + CATEGORY = "model/latent/transform" def flip(self, samples, flip_method): s = samples.copy() @@ -1407,7 +1428,7 @@ class LatentComposite: RETURN_TYPES = ("LATENT",) FUNCTION = "composite" - CATEGORY = "latent" + CATEGORY = "model/latent" def composite(self, samples_to, samples_from, x, y, composite_method="normal", feather=0): x = x // 
8 @@ -1494,7 +1515,7 @@ class LatentCrop: RETURN_TYPES = ("LATENT",) FUNCTION = "crop" - CATEGORY = "latent/transform" + CATEGORY = "model/latent/transform" def crop(self, samples, width, height, x, y): s = samples.copy() @@ -1524,7 +1545,7 @@ class SetLatentNoiseMask: RETURN_TYPES = ("LATENT",) FUNCTION = "set_mask" - CATEGORY = "latent/inpaint" + CATEGORY = "model/latent/inpaint" def set_mask(self, samples, mask): s = samples.copy() @@ -1578,7 +1599,7 @@ class KSampler: OUTPUT_TOOLTIPS = ("The denoised latent.",) FUNCTION = "sample" - CATEGORY = "sampling" + CATEGORY = "model/sampling" DESCRIPTION = "Uses the provided model, positive and negative conditioning to denoise the latent image." SEARCH_ALIASES = ["sampler", "sample", "generate", "denoise", "diffuse", "txt2img", "img2img"] @@ -1608,7 +1629,7 @@ class KSamplerAdvanced: RETURN_TYPES = ("LATENT",) FUNCTION = "sample" - CATEGORY = "sampling" + CATEGORY = "model/sampling" def sample(self, model, add_noise, noise_seed, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0): force_full_denoise = True @@ -2362,6 +2383,7 @@ async def init_builtin_extra_nodes(): "nodes_model_downscale.py", "nodes_images.py", "nodes_video_model.py", + "nodes_ideogram4.py", "nodes_train.py", "nodes_dataset.py", "nodes_sag.py", @@ -2417,6 +2439,7 @@ async def init_builtin_extra_nodes(): "nodes_camera_trajectory.py", "nodes_edit_model.py", "nodes_tcfg.py", + "nodes_seedvr.py", "nodes_context_windows.py", "nodes_qwen.py", "nodes_chroma_radiance.py", @@ -2455,6 +2478,8 @@ async def init_builtin_extra_nodes(): "nodes_save_3d.py", "nodes_moge.py", "nodes_mediapipe.py", + "nodes_gaussian_splat.py", + "nodes_triposplat.py" ] import_failed = [] diff --git a/openapi.yaml b/openapi.yaml index f801a39d9..b7e21245f 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -1,11749 +1,4956 @@ -openapi: 3.1.0 -info: - title: ComfyUI API - description: | - API for ComfyUI - 
A powerful and modular stable diffusion GUI and backend. - - This API allows you to interact with ComfyUI programmatically, including: - - Submitting and managing workflow executions - - Querying node/object information - - Uploading and viewing files - - Managing user settings and data - - Asset management (feature-gated) - - ## Dual-path routing - Every route registered via `self.routes` in the ComfyUI server is available at - both its bare path (e.g. `/prompt`) and an `/api`-prefixed path (e.g. `/api/prompt`). - This spec uses the `/api`-prefixed versions as canonical. - - ## Multi-user mode - When ComfyUI is started with `--multi-user`, the `Comfy-User` header identifies - the active user for settings, userdata, and history isolation. This is **not** a - security mechanism — it is an organisational convenience with no authentication - or authorisation behind it. - version: 1.0.0 - license: - name: GNU General Public License v3.0 - url: https://github.com/comfyanonymous/ComfyUI/blob/master/LICENSE - -servers: - - url: / - description: Default ComfyUI server (typically http://127.0.0.1:8188) - -tags: - - name: prompt - description: Workflow submission and prompt info - - name: queue - description: Queue inspection and management - - name: history - description: Execution history - - name: upload - description: File upload endpoints - - name: view - description: File viewing / download - - name: system - description: System stats and feature flags - - name: node - description: Node / object_info definitions - - name: model - description: Model folder and file listing - - name: user - description: User management (multi-user mode) - - name: userdata - description: Per-user file storage - - name: settings - description: Per-user settings - - name: extensions - description: Frontend extension JS files - - name: subgraph - description: Global subgraph blueprints - - name: internal - description: Internal / debug endpoints - - name: assets - description: Asset 
management (feature-gated behind enable-assets) - - - name: auth - description: Authentication and session management (cloud-only) - - name: billing - description: Billing, subscriptions, and payment management (cloud-only) - - name: workspace - description: Workspace and team management (cloud-only) - - name: hub - description: "ComfyUI Hub: profiles, shared workflows, and labels (cloud-only)" - - name: workflows - description: Cloud workflow management and versioning (cloud-only) - - name: task - description: Background task management (cloud-only) - - name: runtime-only - description: Operations served exclusively by the cloud runtime with no local equivalent - -paths: - # --------------------------------------------------------------------------- - # WebSocket - # --------------------------------------------------------------------------- - /ws: - get: - operationId: connectWebSocket - tags: [system] - summary: WebSocket connection for real-time updates - description: | - Upgrades to a WebSocket connection that streams execution progress, - node status, and output messages. The server sends an initial `status` - message with the session ID (SID) on connect. - - ## Message types (server → client) - The server sends JSON messages with a `type` field. See the - `x-websocket-messages` list below for the schema of each message type. - parameters: - - name: clientId - in: query - required: false - schema: - type: string - description: Client identifier. If omitted the server assigns one. 
- responses: - "101": - description: WebSocket upgrade successful - '401': - description: Unauthorized - x-websocket-messages: - - type: status - schema: - $ref: "#/components/schemas/StatusWsMessage" - - type: progress - schema: - $ref: "#/components/schemas/ProgressWsMessage" - - type: progress_text - schema: - $ref: "#/components/schemas/ProgressTextWsMessage" - - type: progress_state - schema: - $ref: "#/components/schemas/ProgressStateWsMessage" - - type: executing - schema: - $ref: "#/components/schemas/ExecutingWsMessage" - - type: executed - schema: - $ref: "#/components/schemas/ExecutedWsMessage" - - type: execution_start - schema: - $ref: "#/components/schemas/ExecutionStartWsMessage" - - type: execution_success - schema: - $ref: "#/components/schemas/ExecutionSuccessWsMessage" - - type: execution_cached - schema: - $ref: "#/components/schemas/ExecutionCachedWsMessage" - - type: execution_interrupted - schema: - $ref: "#/components/schemas/ExecutionInterruptedWsMessage" - - type: execution_error - schema: - $ref: "#/components/schemas/ExecutionErrorWsMessage" - - type: logs - schema: - $ref: "#/components/schemas/LogsWsMessage" - - type: notification - schema: - $ref: "#/components/schemas/NotificationWsMessage" - - type: feature_flags - schema: - $ref: "#/components/schemas/FeatureFlagsWsMessage" - - type: asset_download - schema: - $ref: "#/components/schemas/AssetDownloadWsMessage" - - type: asset_export - schema: - $ref: "#/components/schemas/AssetExportWsMessage" - - # --------------------------------------------------------------------------- - # Prompt - # --------------------------------------------------------------------------- - /api/prompt: - get: - operationId: getPromptInfo - tags: [prompt] - summary: Get queue status - description: Returns how many items remain in the execution queue. 
- responses: - "200": - description: Queue info - content: - application/json: - schema: - $ref: "#/components/schemas/PromptInfo" - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: executePrompt - tags: [prompt] - summary: Submit a workflow for execution - description: Submits a workflow for execution. The server validates the graph, assigns a `prompt_id`, and enqueues it. Clients listen on `/ws` for execution progress and output messages. - requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/PromptRequest" - responses: - "200": - description: Prompt accepted - content: - application/json: - schema: - $ref: "#/components/schemas/PromptResponse" - "400": - description: Validation or node errors - content: - application/json: - schema: - $ref: "#/components/schemas/PromptErrorResponse" - - '402': - description: Payment required - Insufficient credits - content: - application/json: - schema: - $ref: '#/components/schemas/PromptErrorResponse' - '429': - description: Payment required - User has not paid - content: - application/json: - schema: - $ref: '#/components/schemas/PromptErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/PromptErrorResponse' - '503': - description: Service unavailable - content: - application/json: - schema: - $ref: '#/components/schemas/PromptErrorResponse' - # --------------------------------------------------------------------------- - # Queue - # --------------------------------------------------------------------------- - /api/queue: - get: - operationId: getQueueInfo - tags: [queue] - summary: Get running and pending queue items - description: Returns the server's current execution queue, 
split into the currently-running prompt and the list of pending prompts. - responses: - "200": - description: Queue contents - content: - application/json: - schema: - $ref: "#/components/schemas/QueueInfo" - '400': - description: Invalid request parameters - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Invalid request parameters - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: manageQueue - tags: [queue] - summary: Clear or delete items from the queue - description: Mutates the execution queue. Supports clearing all queued prompts or deleting individual prompts by ID. - requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/QueueManageRequest" - responses: - "200": - description: Queue updated - content: - application/json: - schema: - $ref: "#/components/schemas/QueueManageResponse" - '400': - description: Invalid request parameters - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/interrupt: - post: - operationId: interruptJob - tags: [queue] - summary: Interrupt current execution - description: Interrupts the prompt that is currently executing. The next queued prompt (if any) will start immediately after. - requestBody: - required: false - content: - application/json: - schema: - type: object - properties: - prompt_id: - type: string - format: uuid - description: "If provided, only interrupts this specific running prompt. Otherwise interrupts all." 
- responses: - "200": - description: Interrupt signal sent - - '401': - description: Unauthorized - Authentication required - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/free: - post: - operationId: freeMemory - tags: [queue] - summary: Free GPU memory and/or unload models - description: Frees GPU memory by unloading models and/or freeing the resident model cache, controlled by the request flags. - requestBody: - required: false - content: - application/json: - schema: - type: object - properties: - unload_models: - type: boolean - description: Unload all models from VRAM/RAM - free_memory: - type: boolean - description: Run garbage collection and free cached memory - responses: - "200": - description: Memory freed - - # --------------------------------------------------------------------------- - # Jobs - # --------------------------------------------------------------------------- - /api/jobs: - get: - operationId: listJobs - tags: [queue] - summary: List jobs with filtering and pagination - description: Returns a paginated list of completed prompt executions, newest first. 
- parameters: - - name: status - in: query - schema: - type: string - description: Filter by job status - - name: workflow_id - in: query - schema: - type: string - description: Filter by workflow ID - - name: sort_by - in: query - schema: - type: string - description: Field to sort by - - name: sort_order - in: query - schema: - type: string - enum: [asc, desc] - description: Sort direction - - name: limit - in: query - schema: - type: integer - description: Maximum number of results (default is unlimited/None) - - name: offset - in: query - schema: - type: integer - default: 0 - description: Pagination offset - responses: - "200": - description: Jobs list - content: - application/json: - schema: - type: object - properties: - jobs: - type: array - items: - $ref: "#/components/schemas/JobEntry" - pagination: - $ref: "#/components/schemas/PaginationInfo" - - '401': - description: Unauthorized - Authentication required - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/jobs/{job_id}: - get: - operationId: getJobDetail - tags: [queue] - summary: Get a single job by ID - description: Returns the full record for a single completed prompt execution, including its outputs, status, and metadata. - parameters: - - name: job_id - in: path - description: The job (prompt) ID to fetch. 
- required: true - schema: - type: string - format: uuid - responses: - "200": - description: Job detail - content: - application/json: - schema: - $ref: "#/components/schemas/JobDetailResponse" - "404": - description: Job not found - - '401': - description: Unauthorized - Authentication required - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '403': - description: Forbidden - Job does not belong to user - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # History - # --------------------------------------------------------------------------- - /api/history: - get: - operationId: getPromptHistory - tags: [history] - summary: Get execution history - deprecated: true - description: | - **Deprecated.** Superseded by `GET /api/jobs`, which returns the same - execution records in a paginated, filterable format. Planned for removal - no earlier than a future major release; sunset timeline TBD. - - Returns a dictionary keyed by prompt_id. Each value is a HistoryEntry - containing prompt metadata, outputs, status, and node meta. 
- parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - - name: max_items - in: query - schema: - type: integer - description: Maximum number of history entries to return - - name: offset - in: query - schema: - type: integer - description: Pagination offset (number of entries to skip) - responses: - "200": - description: History dictionary keyed by prompt_id - content: - application/json: - schema: - type: object - additionalProperties: - $ref: "#/components/schemas/HistoryEntry" - '404': - description: "Not Found \u2014 use /api/history_v2 instead" - post: - operationId: manageHistory - tags: [history] - summary: Clear or delete history entries - deprecated: true - description: | - **Deprecated.** Superseded by the forthcoming job-management endpoints - under `/api/jobs`. Planned for removal no earlier than a future major - release; sunset timeline TBD. - parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/HistoryManageRequest" - responses: - "200": - description: History updated - - '400': - description: Invalid request parameters - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - Authentication required - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/history/{prompt_id}: - get: - operationId: getHistoryByPromptId - tags: [history] - summary: Get history for a specific prompt - deprecated: true - description: | - **Deprecated.** Superseded by `GET /api/jobs/{job_id}`, which returns - the same execution record. Planned for removal no earlier than a future - major release; sunset timeline TBD. 
- parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - - name: prompt_id - in: path - description: The prompt ID to fetch history for. - required: true - schema: - type: string - format: uuid - responses: - "200": - description: Single-entry history dictionary. Returns an empty object `{}` if the prompt_id is not found. - content: - application/json: - schema: - type: object - additionalProperties: - $ref: "#/components/schemas/HistoryEntry" - - '404': - description: "Not Found \u2014 use /api/jobs/{prompt_id} instead" - # --------------------------------------------------------------------------- - # Upload - # --------------------------------------------------------------------------- - /api/upload/image: - post: - operationId: uploadImage - tags: [upload] - summary: Upload an image file - description: Uploads an image file into one of the input/output/temp directories so it can be referenced by workflow nodes. - requestBody: - required: true - content: - multipart/form-data: - schema: - type: object - required: - - image - properties: - image: - type: string - format: binary - description: Image file to upload - type: - type: string - enum: [input, temp, output] - default: input - description: Target directory type - overwrite: - type: string - description: 'Set to "true" to overwrite existing files' - subfolder: - type: string - description: Subfolder within the target directory - responses: - "200": - description: Upload result - content: - application/json: - schema: - $ref: "#/components/schemas/UploadResult" - "400": - description: No file provided or invalid request - - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/upload/mask: - post: - operationId: uploadMask - tags: [upload] - deprecated: true - summary: Upload a mask image 
(deprecated) - description: | - Deprecated. Clients should composite the mask onto the source image - client-side and upload the resulting image via POST /api/upload/image - instead. This endpoint will continue to function for older clients, - but will not receive new features. - - Uploads a mask image associated with a previously-uploaded reference image. - requestBody: - required: true - content: - multipart/form-data: - schema: - type: object - required: - - image - - original_ref - properties: - image: - type: string - format: binary - description: Mask image (alpha channel is used) - original_ref: - type: object - description: Reference to the original image file - required: - - filename - properties: - filename: - type: string - description: Filename of the original image - additionalProperties: true - type: - type: string - enum: [input, temp, output] - default: input - description: Target directory type - overwrite: - type: string - description: 'Set to "true" to overwrite existing files' - subfolder: - type: string - description: Subfolder within the target directory - responses: - "200": - description: Upload result - content: - application/json: - schema: - $ref: "#/components/schemas/UploadResult" - "400": - description: No file provided or invalid request - - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # View - # --------------------------------------------------------------------------- - /api/view: - get: - operationId: viewFile - tags: [view] - summary: View or download a file - description: Serves a file (image, audio, or video) from the input/output/temp directory identified by the query parameters. 
- parameters: - - name: filename - in: query - required: true - schema: - type: string - description: Name of the file to view - - name: type - in: query - schema: - type: string - enum: [input, output, temp] - default: output - description: Directory type - - name: subfolder - in: query - schema: - type: string - description: Subfolder within the directory - - name: preview - in: query - schema: - type: string - description: Preview format hint (e.g. "webp;90") - - name: channel - in: query - schema: - type: string - enum: [rgba, rgb, a] - description: Channel extraction mode - responses: - "200": - description: File content - content: - image/*: - schema: - type: string - format: binary - video/*: - schema: - type: string - format: binary - audio/*: - schema: - type: string - format: binary - application/octet-stream: - schema: - type: string - format: binary - "404": - description: File not found - - '302': - description: Redirect to GCS signed URL - headers: - Location: - description: Signed URL to access the file in GCS - schema: - type: string - Cache-Control: - description: Cache directive for the redirect response - schema: - type: string - Vary: - description: Headers that affect response caching - schema: - type: string - '400': - description: Invalid request parameters - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/view_metadata/{folder_name}: - get: - operationId: viewMetadata - tags: [view] - summary: Get metadata for a file (e.g. safetensors header) - description: Returns embedded metadata parsed from a file in the given folder — for example, the header of a safetensors model. - parameters: - - name: folder_name - in: path - required: true - schema: - type: string - description: Folder type (output, input, temp, etc.) 
- - name: filename - in: query - required: true - schema: - type: string - description: Filename to read metadata from - responses: - "200": - description: File metadata - content: - application/json: - schema: - type: object - additionalProperties: true - "404": - description: File or metadata not found - - # --------------------------------------------------------------------------- - # System - # --------------------------------------------------------------------------- - /api/system_stats: - get: - operationId: getSystemStats - tags: [system] - summary: Get system statistics - description: Returns hardware, Python, VRAM, and runtime statistics for the running ComfyUI process. - responses: - "200": - description: System stats - content: - application/json: - schema: - $ref: "#/components/schemas/SystemStatsResponse" - - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/features: - get: - operationId: getFeatures - tags: [system] - summary: Get enabled feature flags - description: Returns a dictionary of feature flag names to their enabled state. Cloud deployments may include additional typed fields alongside the boolean flags. - responses: - "200": - description: Feature flags - content: - application/json: - schema: - type: object - additionalProperties: - type: boolean - properties: - max_upload_size: - type: integer - format: int64 - minimum: 0 - description: "Maximum file upload size in bytes." - free_tier_credits: - type: integer - format: int32 - minimum: 0 - nullable: true - x-runtime: [cloud] - description: "[cloud-only] Credits available to free-tier users. Local ComfyUI returns null." - posthog_api_host: - type: string - format: uri - nullable: true - x-runtime: [cloud] - description: "[cloud-only] PostHog analytics proxy URL for frontend telemetry. Local ComfyUI returns null." 
- max_concurrent_jobs: - type: integer - format: int32 - minimum: 0 - nullable: true - x-runtime: [cloud] - description: "[cloud-only] Maximum concurrent jobs the authenticated user can run. Local ComfyUI returns null." - workflow_templates_version: - type: string - nullable: true - x-runtime: [cloud] - description: "[cloud-only] Version identifier for the workflow templates bundle. Local ComfyUI returns null." - workflow_templates_source: - type: string - nullable: true - enum: [dynamic_config_override, workflow_templates_version_json] - x-runtime: [cloud] - description: "[cloud-only] How the templates version was resolved. Local ComfyUI returns null." - - # --------------------------------------------------------------------------- - # Node / Object Info - # --------------------------------------------------------------------------- - /api/object_info: - get: - operationId: getNodeInfo - tags: [node] - summary: Get all node definitions - description: | - Returns a dictionary of every registered node class, keyed by class name. - Each value is a NodeInfo object describing inputs, outputs, category, etc. - responses: - "200": - description: All node definitions - content: - application/json: - schema: - type: object - additionalProperties: - $ref: "#/components/schemas/NodeInfo" - - /api/object_info/{node_class}: - get: - operationId: getObjectInfoByClass - tags: [node] - summary: Get a single node definition - description: Returns the `NodeInfo` definition for a single registered node class. - parameters: - - name: node_class - in: path - required: true - schema: - type: string - description: Node class name (e.g. 
"KSampler") - responses: - "200": - description: Single node definition - content: - application/json: - schema: - type: object - additionalProperties: - $ref: "#/components/schemas/NodeInfo" - "404": - description: Node class not found - - /api/embeddings: - get: - operationId: getEmbeddings - tags: [node] - summary: List available embedding names - description: Returns the list of text-encoder embeddings available on disk. - responses: - "200": - description: Embedding names - content: - application/json: - schema: - type: array - items: - type: string - - # --------------------------------------------------------------------------- - # Models - # --------------------------------------------------------------------------- - /api/models: - get: - operationId: getModelTypes - tags: [model] - summary: List model folder type names - description: Returns an array of model type names (e.g. checkpoints, loras, vae). - responses: - "200": - description: Model type names - content: - application/json: - schema: - type: array - items: - type: string - - '404': - description: "Not Found \u2014 use /api/experiment/models instead" - /api/models/{folder}: - get: - operationId: getModelsByFolder - tags: [model] - summary: List model filenames in a folder - description: Returns the names of model files in the given folder. This endpoint predates `/api/experiment/models/{folder}` and returns names only — prefer the experiment endpoint for new integrations. - parameters: - - name: folder - in: path - required: true - schema: - type: string - description: Model folder type name - responses: - "200": - description: Model filenames - content: - application/json: - schema: - type: array - items: - type: string - "404": - description: Unknown folder type - - /api/experiment/models: - get: - operationId: getModelFolders - tags: [model] - summary: List model folders with paths - description: Returns an array of model folder objects with name and folder paths. 
- responses: - "200": - description: Model folders - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/ModelFolder" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/experiment/models/{folder}: - get: - operationId: getModelsInFolder - tags: [model] - summary: List model files with metadata - description: Returns the model files in the given folder with richer metadata (path index, mtime, size) than the legacy `/api/models/{folder}` endpoint. - parameters: - - name: folder - in: path - required: true - schema: - type: string - description: Model folder type name - responses: - "200": - description: Model files with metadata - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/ModelFile" - "404": - description: Unknown folder type - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/experiment/models/preview/{folder}/{path_index}/{filename}: - get: - operationId: getModelPreview - tags: [model] - summary: Get model preview image - description: Returns the preview image associated with a model file, if one exists alongside the model on disk. 
- parameters: - - name: folder - in: path - required: true - schema: - type: string - description: Model folder type name - - name: path_index - in: path - required: true - schema: - type: integer - description: Path index within the folder - - name: filename - in: path - required: true - schema: - type: string - description: Model filename - responses: - "200": - description: Preview image (WebP) - content: - image/webp: - schema: - type: string - format: binary - "404": - description: Preview not found - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # Users - # --------------------------------------------------------------------------- - /api/users: - get: - operationId: getUsersInfo - tags: [user] - summary: Get user storage info - description: | - Returns user storage configuration. In single-user mode returns - `{"storage": "server", "migrated": true/false}`. In multi-user mode - returns `{"storage": "server", "users": {"user_id": "user_dir", ...}}`. - parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - responses: - "200": - description: User info - content: - application/json: - schema: - type: object - properties: - storage: - type: string - description: Storage backend type (always "server") - migrated: - type: boolean - description: Whether migration from browser storage is complete (single-user) - users: - type: object - additionalProperties: - type: string - description: Map of user_id to directory name (multi-user) - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: createUser - tags: [user] - summary: Create a new user (multi-user mode) - description: Creates a new user entry. Only meaningful when ComfyUI is running in multi-user mode. 
- parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - username - properties: - username: - type: string - description: Username for the new user - responses: - "200": - description: Created user ID - content: - application/json: - schema: - type: string - description: The generated user_id - "400": - description: Username already exists or invalid - - # --------------------------------------------------------------------------- - # Userdata - # --------------------------------------------------------------------------- - /api/userdata: - get: - operationId: getUserdata - tags: [userdata] - summary: List files in a userdata directory - description: Lists files in the authenticated user's data directory. Returns either filename strings or full objects depending on the `full_info` query parameter. - parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - - name: dir - in: query - required: true - schema: - type: string - description: Directory path relative to the user's data folder - - name: recurse - in: query - schema: - type: boolean - description: Recurse into subdirectories - - name: full_info - in: query - schema: - type: boolean - description: Return full file info objects instead of just names - - name: split - in: query - schema: - type: boolean - description: Split paths into directory components - responses: - "200": - description: File listing - content: - application/json: - schema: - $ref: "#/components/schemas/GetUserDataResponseFull" - "404": - description: Directory not found - - '400': - description: Bad request (e.g., invalid filename). - content: - text/plain: - schema: - type: string - '401': - description: Unauthorized. 
- content: - text/plain: - schema: - type: string - '500': - description: General error - content: - text/plain: - schema: - type: string - /api/v2/userdata: - get: - operationId: listUserdataV2 - tags: [userdata] - summary: List files in userdata (v2 format) - description: Lists files in the authenticated user's data directory using the v2 response shape, which always returns full objects. - parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - - name: path - in: query - schema: - type: string - description: Directory path relative to user data root - responses: - "200": - description: File listing with metadata - content: - application/json: - schema: - type: array - items: - type: object - properties: - name: - type: string - path: - type: string - type: - type: string - enum: [file, directory] - size: - type: integer - modified: - type: number - description: Unix timestamp - - '404': - description: "Not Found \u2014 use /api/userdata instead" - /api/userdata/{file}: - get: - operationId: getUserdataFile - tags: [userdata] - summary: Read a userdata file - description: Reads the contents of a file from the authenticated user's data directory. - parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - - name: file - in: path - required: true - schema: - type: string - description: File path relative to user data directory - responses: - "200": - description: File content - content: - application/octet-stream: - schema: - type: string - format: binary - "404": - description: File not found - '400': - description: Bad request (e.g., invalid filename). - content: - text/plain: - schema: - type: string - '401': - description: Unauthorized. 
- content: - text/plain: - schema: - type: string - '500': - description: General error - content: - text/plain: - schema: - type: string - post: - operationId: postUserdataFile - tags: [userdata] - summary: Write or create a userdata file - description: Writes (creates or replaces) a file in the authenticated user's data directory. - parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - - name: file - in: path - required: true - schema: - type: string - description: File path relative to user data directory - - name: overwrite - in: query - schema: - type: boolean - description: Allow overwriting existing files - - name: full_info - in: query - schema: - type: boolean - description: Return full file info in response - requestBody: - required: true - content: - application/octet-stream: - schema: - type: string - format: binary - application/json: - schema: {} - responses: - "200": - description: File written - content: - application/json: - schema: - $ref: "#/components/schemas/UserDataResponseFull" - "409": - description: File exists and overwrite not set - '400': - description: Missing or invalid 'file' parameter. - content: - text/plain: - schema: - type: string - '401': - description: Unauthorized. - content: - text/plain: - schema: - type: string - '403': - description: The requested path is not allowed. - content: - text/plain: - schema: - type: string - '500': - description: General error - content: - text/plain: - schema: - type: string - delete: - operationId: deleteUserdataFile - tags: [userdata] - summary: Delete a userdata file - description: Deletes a file from the authenticated user's data directory. - parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - - name: file - in: path - required: true - schema: - type: string - description: File path relative to user data directory - responses: - "204": - description: File deleted - "404": - description: File not found - - '401': - description: Unauthorized. 
- content: - text/plain: - schema: - type: string - '500': - description: Internal server error. - content: - text/plain: - schema: - type: string - /api/userdata/{file}/move/{dest}: - post: - operationId: moveUserdataFile - tags: [userdata] - summary: Move or rename a userdata file - description: Renames or moves a file within the authenticated user's data directory. - parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - - name: file - in: path - required: true - schema: - type: string - description: Source file path - - name: dest - in: path - required: true - schema: - type: string - description: Destination file path - - name: overwrite - in: query - schema: - type: boolean - description: Allow overwriting at destination - - name: full_info - in: query - schema: - type: boolean - description: Return full file info in response - responses: - "200": - description: File moved - content: - application/json: - schema: - $ref: "#/components/schemas/UserDataResponseFull" - "404": - description: Source file not found - "409": - description: Destination exists and overwrite not set - - '400': - description: Missing or invalid parameters. - content: - text/plain: - schema: - type: string - '401': - description: Unauthorized. - content: - text/plain: - schema: - type: string - '500': - description: General error - content: - text/plain: - schema: - type: string - # --------------------------------------------------------------------------- - # Settings - # --------------------------------------------------------------------------- - /api/settings: - get: - operationId: getAllSettings - tags: [settings] - summary: Get all user settings - description: Returns all settings for the authenticated user. 
- parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - responses: - "200": - description: Settings object - content: - application/json: - schema: - type: object - additionalProperties: true - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: updateMultipleSettings - tags: [settings] - summary: Update user settings (partial merge) - description: Replaces the authenticated user's settings with the provided object. - parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - requestBody: - required: true - content: - application/json: - schema: - type: object - additionalProperties: true - description: Partial settings to merge - responses: - "200": - description: Settings updated - - '400': - description: Invalid request - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/settings/{id}: - get: - operationId: getSettingById - tags: [settings] - summary: Get a single setting by key - description: Returns the value of a single setting, identified by key. 
- parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - - name: id - in: path - required: true - schema: - type: string - description: Setting key - responses: - "200": - description: Setting value (null if the setting does not exist) - content: - application/json: - schema: - nullable: true - description: The setting value (any JSON type), or null if not set - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '404': - description: Setting not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: updateSettingById - tags: [settings] - summary: Set a single setting value - description: Sets the value of a single setting, identified by key. - parameters: - - $ref: "#/components/parameters/ComfyUserHeader" - - name: id - in: path - required: true - schema: - type: string - description: Setting key - requestBody: - required: true - content: - application/json: - schema: - description: The setting value (any JSON type) - responses: - "200": - description: Setting updated - - '400': - description: Invalid request - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # Extensions / Templates / i18n - # --------------------------------------------------------------------------- - /api/extensions: - get: - operationId: getExtensions - tags: [extensions] - summary: List frontend extension JS file paths - description: Returns the list of frontend extension JS URLs registered by custom nodes, to be loaded by the frontend on startup. 
- responses: - "200": - description: Array of JS file paths - content: - application/json: - schema: - type: array - items: - type: string - description: Relative path to extension JS file - - /api/workflow_templates: - get: - operationId: getWorkflowTemplates - tags: [extensions] - summary: Get workflow template mappings - description: Returns a map of custom node names to their provided workflow template names. - responses: - "200": - description: Template mappings - content: - application/json: - schema: - type: object - additionalProperties: - type: array - items: - type: string - description: Map of node pack name to array of template names - - /api/i18n: - get: - operationId: getI18n - tags: [extensions] - summary: Get internationalisation translation strings - description: Returns the URLs of translation files contributed by custom nodes, keyed by locale. - responses: - "200": - description: Translation map - content: - application/json: - schema: - type: object - additionalProperties: true - description: Nested map of locale to translation key-value pairs - - # --------------------------------------------------------------------------- - # Subgraphs - # --------------------------------------------------------------------------- - /api/global_subgraphs: - get: - operationId: getGlobalSubgraphs - tags: [subgraph] - summary: List global subgraph blueprints - description: Returns a dictionary of subgraph IDs to their metadata. 
- responses: - "200": - description: Subgraph metadata dictionary - content: - application/json: - schema: - type: object - additionalProperties: - $ref: "#/components/schemas/GlobalSubgraphInfo" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/global_subgraphs/{id}: - get: - operationId: getGlobalSubgraph - tags: [subgraph] - summary: Get a global subgraph with full data - description: Returns the blueprint for a globally-registered subgraph, used by the frontend to materialize the subgraph node. - parameters: - - name: id - in: path - required: true - schema: - type: string - description: Subgraph identifier - responses: - "200": - description: Full subgraph data - content: - application/json: - schema: - $ref: "#/components/schemas/GlobalSubgraphData" - "404": - description: Subgraph not found - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # Node Replacements - # --------------------------------------------------------------------------- - /api/node_replacements: - get: - operationId: getNodeReplacements - tags: [node] - summary: Get node replacement mappings - description: | - Returns a dictionary mapping deprecated or replaced node class names - to their replacement node information. 
- responses: - "200": - description: Replacement mappings - content: - application/json: - schema: - type: object - additionalProperties: true - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # Internal (x-internal: true) - # --------------------------------------------------------------------------- - /internal/logs: - get: - operationId: getInternalLogs - tags: [internal] - summary: Get server logs as text - description: Returns structured ComfyUI log entries from the in-memory log buffer. - x-internal: true - responses: - "200": - description: Log text - content: - text/plain: - schema: - type: string - - /internal/logs/raw: - get: - operationId: getInternalLogsRaw - tags: [internal] - summary: Get raw structured log entries - description: Returns the raw ComfyUI log buffer as text, together with metadata about the current size limit. - x-internal: true - responses: - "200": - description: Structured log data - content: - application/json: - schema: - type: object - properties: - entries: - type: array - items: - type: object - properties: - t: - type: number - description: Timestamp - m: - type: string - description: Message - size: - type: object - properties: - cols: - type: integer - rows: - type: integer - - /internal/logs/subscribe: - patch: - operationId: subscribeToLogs - tags: [internal] - summary: Subscribe or unsubscribe a WebSocket client to log streaming - description: Subscribes or unsubscribes the current client from live log streaming over the WebSocket. 
- x-internal: true - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - clientId - - enabled - properties: - clientId: - type: string - description: WebSocket client ID - enabled: - type: boolean - description: Enable or disable log streaming for this client - responses: - "200": - description: Subscription updated - - /internal/folder_paths: - get: - operationId: getInternalFolderPaths - tags: [internal] - summary: Get configured folder paths - description: Returns the filesystem paths ComfyUI is configured to load models and other assets from, keyed by folder type. - x-internal: true - responses: - "200": - description: Dictionary of folder type to paths - content: - application/json: - schema: - type: object - additionalProperties: - type: array - items: - type: array - items: - type: string - description: Map of folder type name to list of [path, ...] entries - - /internal/files/{directory_type}: - get: - operationId: getFiles - tags: [internal] - summary: List files in a directory type - description: Lists the files present in one of ComfyUI's known directories (input, output, or temp). - x-internal: true - parameters: - - name: directory_type - in: path - required: true - schema: - type: string - description: Directory type (e.g. 
output, input, temp) - responses: - "200": - description: Array of filenames - content: - application/json: - schema: - type: array - items: - type: string - - '400': - description: Invalid directory type - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # Assets (x-feature-gate: enable-assets) - # --------------------------------------------------------------------------- - /api/assets/hash/{hash}: - head: - operationId: checkAssetByHash - tags: [assets] - summary: Check if an asset with the given hash exists - description: Returns 204 if an asset with the given content hash already exists, 404 otherwise. Used by clients to deduplicate uploads before transferring bytes. - x-feature-gate: enable-assets - parameters: - - name: hash - in: path - required: true - schema: - type: string - description: "Blake3 hash of the asset (e.g. blake3:abc123...)" - responses: - "200": - description: Asset exists - "404": - description: No asset with this hash - - '400': - description: Invalid hash format - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/assets: - get: - operationId: listAssets - tags: [assets] - summary: List assets with filtering and pagination - description: Returns a paginated list of assets, optionally filtered by tags, name, or other query parameters. 
- x-feature-gate: enable-assets - parameters: - - name: limit - in: query - schema: - type: integer - default: 50 - - name: offset - in: query - schema: - type: integer - default: 0 - - name: include_tags - in: query - schema: - type: array - items: - type: string - style: form - explode: true - description: Tags that assets must have (AND logic) - - name: exclude_tags - in: query - schema: - type: array - items: - type: string - style: form - explode: true - description: Tags that assets must not have - - name: name_contains - in: query - schema: - type: string - description: Filter assets whose name contains this substring - - name: metadata_filter - in: query - schema: - type: string - description: JSON-encoded metadata key/value filter - - name: sort - in: query - schema: - type: string - description: Field to sort by - - name: order - in: query - schema: - type: string - enum: [asc, desc] - description: Sort direction - - name: include_public - in: query - schema: - type: boolean - x-runtime: [cloud] - description: "[cloud-only] Include workspace-public assets in addition to the caller's own." - - name: asset_hash - in: query - schema: - type: string - x-runtime: [cloud] - description: "[cloud-only] Filter by exact content hash." - responses: - "200": - description: Asset list - content: - application/json: - schema: - $ref: "#/components/schemas/ListAssetsResponse" - '400': - description: Invalid request parameters - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: uploadAsset - tags: [assets] - summary: Upload a new asset - description: Uploads a new asset (binary content plus metadata) and registers it in the asset database. 
- x-feature-gate: enable-assets - requestBody: - required: true - content: - multipart/form-data: - schema: - type: object - required: - - file - properties: - file: - type: string - format: binary - description: Asset file to upload - name: - type: string - description: Display name for the asset - tags: - type: string - description: Comma-separated tags - user_metadata: - type: string - description: JSON-encoded user metadata - hash: - type: string - description: "Blake3 hash of the file content (e.g. blake3:abc123...)" - mime_type: - type: string - description: MIME type of the file (overrides auto-detected type) - preview_id: - type: string - format: uuid - description: ID of an existing asset to use as the preview image - id: - type: string - format: uuid - nullable: true - x-runtime: [cloud] - description: "[cloud-only] Client-supplied asset ID for idempotent creation. If an asset with this ID already exists, the existing asset is returned." - application/json: - schema: - type: object - x-runtime: [cloud] - description: "[cloud-only] URL-based asset upload. Caller supplies a URL instead of a file body; the server fetches the content." - required: - - url - properties: - url: - type: string - format: uri - description: "[cloud-only] URL of the file to import as an asset" - name: - type: string - description: Display name for the asset - tags: - type: string - description: Comma-separated tags - user_metadata: - type: string - description: JSON-encoded user metadata - hash: - type: string - description: "Blake3 hash of the file content (e.g. blake3:abc123...)" - mime_type: - type: string - description: MIME type of the file (overrides auto-detected type) - preview_id: - type: string - format: uuid - description: ID of an existing asset to use as the preview image - id: - type: string - format: uuid - nullable: true - x-runtime: [cloud] - description: "[cloud-only] Client-supplied asset ID for idempotent creation. 
If an asset with this ID already exists, the existing asset is returned." - responses: - "201": - description: Asset created - content: - application/json: - schema: - $ref: "#/components/schemas/AssetCreated" - - '200': - description: Asset already exists (returned existing asset) - content: - application/json: - schema: - $ref: '#/components/schemas/AssetCreated' - '400': - description: Invalid request (bad file, invalid URL, invalid content type, etc.) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '403': - description: Source URL requires authentication or access denied - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '404': - description: Source URL not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '413': - description: File too large - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '415': - description: Unsupported media type - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '422': - description: Download failed due to network error or timeout - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/assets/from-hash: - post: - operationId: createAssetFromHash - tags: [assets] - summary: Create an asset reference from an existing hash - description: Registers a new asset that references existing content by hash, without re-uploading the bytes. 
- x-feature-gate: enable-assets - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - hash - - name - properties: - hash: - type: string - description: Blake3 hash of existing content - name: - type: string - description: Display name - tags: - type: array - items: - type: string - user_metadata: - type: object - additionalProperties: true - mime_type: - type: string - nullable: true - x-runtime: [cloud] - description: "[cloud-only] MIME type of the content, so the type is preserved without re-inspecting content. Ignored by local ComfyUI." - responses: - "201": - description: Asset created from hash - content: - application/json: - schema: - $ref: "#/components/schemas/AssetCreated" - - '200': - description: Asset reference already exists (returned existing) - content: - application/json: - schema: - $ref: '#/components/schemas/AssetCreated' - '400': - description: Invalid request (bad hash format, invalid tags, etc.) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '404': - description: Source asset with given hash not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/assets/{id}: - get: - operationId: getAssetById - tags: [assets] - summary: Get asset metadata - description: Returns the metadata for a single asset. - x-feature-gate: enable-assets - parameters: - - name: id - in: path - description: The asset ID. 
- required: true - schema: - type: string - format: uuid - responses: - "200": - description: Asset metadata - content: - application/json: - schema: - $ref: "#/components/schemas/Asset" - "404": - description: Asset not found - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - put: - operationId: updateAsset - tags: [assets] - summary: Update asset metadata - description: Updates the mutable metadata of an asset (name, tags, etc.). Binary content is immutable. - x-feature-gate: enable-assets - parameters: - - name: id - in: path - description: The asset ID. - required: true - schema: - type: string - format: uuid - requestBody: - required: true - content: - application/json: - schema: - type: object - properties: - name: - type: string - description: New display name for the asset - user_metadata: - type: object - additionalProperties: true - description: Custom user metadata to set - preview_id: - type: string - format: uuid - description: ID of the asset to use as the preview - mime_type: - type: string - nullable: true - x-runtime: [cloud] - description: "[cloud-only] MIME type override when auto-detection was wrong. Ignored by local ComfyUI." 
- responses: - "200": - description: Asset updated - content: - application/json: - schema: - $ref: "#/components/schemas/AssetUpdated" - '400': - description: Invalid request (no fields provided) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '404': - description: Asset not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - delete: - operationId: deleteAsset - tags: [assets] - summary: Delete an asset - description: Removes an asset entry. Depending on the server configuration, the underlying content may also be deleted. - x-feature-gate: enable-assets - parameters: - - name: id - in: path - description: The asset ID. - required: true - schema: - type: string - format: uuid - - name: delete_content - in: query - schema: - type: boolean - description: Also delete the underlying content file - responses: - "204": - description: Asset deleted - - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '404': - description: Asset not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '409': - description: Asset cannot be deleted because it is referenced by another resource (e.g., workflow version) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/assets/{id}/content: - get: - operationId: getAssetContent - tags: [assets] - summary: Download asset file content - description: Returns the binary content of an asset. Supports range requests. 
- x-feature-gate: enable-assets - parameters: - - name: id - in: path - description: The asset ID. - required: true - schema: - type: string - format: uuid - responses: - "200": - description: Asset file content - content: - application/octet-stream: - schema: - type: string - format: binary - "404": - description: Asset not found - - /api/assets/{id}/tags: - post: - operationId: addAssetTags - tags: [assets] - summary: Add tags to an asset - description: Adds one or more tags to an asset. - x-feature-gate: enable-assets - parameters: - - name: id - in: path - description: The asset ID. - required: true - schema: - type: string - format: uuid - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - tags - properties: - tags: - type: array - items: - type: string - responses: - "200": - description: Tags added - content: - application/json: - schema: - $ref: "#/components/schemas/TagsModificationResponse" - '400': - description: Invalid request - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '404': - description: Asset not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '422': - description: Validation error (e.g., reserved tag) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - delete: - operationId: removeAssetTags - tags: [assets] - summary: Remove tags from an asset - description: Removes one or more tags from an asset. - x-feature-gate: enable-assets - parameters: - - name: id - in: path - description: The asset ID. 
- required: true - schema: - type: string - format: uuid - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - tags - properties: - tags: - type: array - items: - type: string - responses: - "200": - description: Tags removed - content: - application/json: - schema: - $ref: "#/components/schemas/TagsModificationResponse" - - '400': - description: Invalid request - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '404': - description: Asset not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '422': - description: Validation error (e.g., reserved tag) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/tags: - get: - operationId: listTags - tags: [assets] - summary: List all known tags with counts - description: Returns the list of all tags known to the asset database, with counts. 
- x-feature-gate: enable-assets - parameters: - - name: limit - in: query - schema: - type: integer - - name: offset - in: query - schema: - type: integer - - name: search - in: query - schema: - type: string - description: Search term for tag name - responses: - "200": - description: Tag list - content: - application/json: - schema: - $ref: "#/components/schemas/ListTagsResponse" - - '400': - description: Invalid request parameters - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/assets/tags/refine: - get: - operationId: getAssetTagHistogram - tags: [assets] - summary: Get tag counts for assets matching current filters - description: Returns suggested additional tags that would refine a filtered asset query, together with the count of assets each tag would select. 
- x-feature-gate: enable-assets - parameters: - - name: include_tags - in: query - schema: - type: array - items: - type: string - style: form - explode: true - description: Tags that assets must have (AND logic) - - name: exclude_tags - in: query - schema: - type: array - items: - type: string - style: form - explode: true - description: Tags that assets must not have - - name: name_contains - in: query - schema: - type: string - description: Filter assets whose name contains this substring - - name: metadata_filter - in: query - schema: - type: string - description: JSON-encoded metadata key/value filter - - name: limit - in: query - schema: - type: integer - - name: offset - in: query - schema: - type: integer - - name: sort - in: query - schema: - type: string - description: Field to sort by - - name: order - in: query - schema: - type: string - enum: [asc, desc] - description: Sort direction - responses: - "200": - description: Tag histogram - content: - application/json: - schema: - $ref: "#/components/schemas/AssetTagHistogramResponse" - - '400': - description: Invalid request parameters - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/assets/seed: - post: - operationId: seedAssets - tags: [assets] - summary: Trigger asset scan/seed from filesystem - description: Starts a background job that scans the configured directories and registers any assets not yet present in the asset database. 
- x-feature-gate: enable-assets - requestBody: - required: false - content: - application/json: - schema: - type: object - properties: - roots: - type: array - items: - type: string - description: Root folder paths to scan (if omitted, scans all) - responses: - "200": - description: Seed started - content: - application/json: - schema: - type: object - properties: - status: - type: string - - /api/assets/seed/status: - get: - operationId: getAssetSeedStatus - tags: [assets] - summary: Get asset scan progress - description: Returns the progress and status of the most recently-started asset seed job. - x-feature-gate: enable-assets - responses: - "200": - description: Scan progress - content: - application/json: - schema: - type: object - additionalProperties: true - description: Scan progress details (files scanned, total, status, etc.) - - /api/assets/seed/cancel: - post: - operationId: cancelAssetSeed - tags: [assets] - summary: Cancel an in-progress asset scan - description: Requests cancellation of the currently-running asset seed job. - x-feature-gate: enable-assets - responses: - "200": - description: Scan cancelled - content: - application/json: - schema: - type: object - properties: - status: - type: string - - /api/assets/prune: - post: - operationId: pruneAssets - tags: [assets] - summary: Mark assets whose backing files no longer exist on disk - description: Starts a background job that removes asset entries whose underlying content no longer exists on disk. - x-feature-gate: enable-assets - responses: - "200": - description: Prune result - content: - application/json: - schema: - type: object - properties: - status: - type: string - marked: - type: integer - description: Number of assets marked as missing - - # =========================================================================== - # Cloud-runtime FE-facing operations - # - # These operations are served by the cloud runtime. The local runtime returns - # 404 for all of these paths. 
Each operation is tagged x-runtime: [cloud]. - # =========================================================================== - - # --------------------------------------------------------------------------- - # Jobs / prompts (cloud) - # --------------------------------------------------------------------------- - /api/jobs/{job_id}/cancel: - post: - operationId: cancelJob - tags: [queue] - summary: Cancel a running or pending job - description: "[cloud-only] Requests cancellation of a job. If the job is currently executing, execution is interrupted. If it is pending in the queue, it is removed." - x-runtime: [cloud] - parameters: - - name: job_id - in: path - required: true - schema: - type: string - format: uuid - description: The job ID to cancel. - responses: - "200": - description: Cancellation accepted - content: - application/json: - schema: - $ref: "#/components/schemas/JobCancelResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '400': - description: Bad Request - job_id is not a valid UUID (emitted by request validation before the handler runs) - content: - application/json: - schema: - $ref: '#/components/schemas/BindingErrorResponse' - '500': - description: Internal server error - cancellation failed - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/job/{job_id}/status: - get: - operationId: getJobStatus - tags: [queue] - summary: Get status of a cloud job - deprecated: true - description: | - **Deprecated.** This endpoint is superseded by `GET /api/jobs/{job_id}`. - Clients should migrate; the endpoint is retained for backward - compatibility but will be removed in a future release. 
- x-runtime: [cloud] - parameters: - - name: job_id - in: path - required: true - schema: - type: string - format: uuid - description: The job ID to check status for. - responses: - "200": - description: Job status - content: - application/json: - schema: - $ref: "#/components/schemas/JobStatusResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '403': - description: Forbidden - job belongs to another user - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/prompt/{prompt_id}: - get: - operationId: getCloudPrompt - tags: [prompt] - summary: Get a cloud prompt by ID - description: "[cloud-only] Returns the full prompt record for a cloud-executed prompt, including the submitted workflow graph and execution metadata." - x-runtime: [cloud] - parameters: - - name: prompt_id - in: path - required: true - schema: - type: string - format: uuid - description: The prompt ID to fetch. - responses: - "200": - description: Cloud prompt detail - content: - application/json: - schema: - $ref: "#/components/schemas/CloudPrompt" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /api/history_v2: - get: - operationId: getHistory - tags: [history] - summary: Get paginated execution history (v2) - deprecated: true - description: | - **Deprecated.** This endpoint is superseded by `GET /api/jobs`. - Clients should migrate; the endpoint is retained for backward - compatibility but will be removed in a future release. 
- x-runtime: [cloud] - parameters: - - name: limit - in: query - schema: - type: integer - default: 20 - description: Maximum number of results - - name: offset - in: query - schema: - type: integer - default: 0 - description: Pagination offset - - name: status - in: query - schema: - type: string - description: Filter by execution status - responses: - "200": - description: History list - content: - application/json: - schema: - $ref: "#/components/schemas/HistoryResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/history_v2/{prompt_id}: - get: - operationId: getHistoryForPrompt - tags: [history] - summary: Get v2 history for a specific prompt - deprecated: true - description: | - **Deprecated.** This endpoint is superseded by `GET /api/jobs/{prompt_id}`. - Clients should migrate; the endpoint is retained for backward - compatibility but will be removed in a future release. - x-runtime: [cloud] - parameters: - - name: prompt_id - in: path - required: true - schema: - type: string - format: uuid - description: The prompt ID to fetch history for. 
- responses: - "200": - description: History entry - content: - application/json: - schema: - $ref: "#/components/schemas/HistoryDetailResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/logs: - get: - operationId: getLogs - tags: [system] - summary: Get cloud execution logs - deprecated: true - description: | - **Deprecated.** This endpoint returns a static placeholder response and - provides no real log data. It is retained only to avoid breaking clients - that still call it. Clients should remove their dependency; the endpoint - will be removed in a future release. - x-runtime: [cloud] - parameters: - - name: job_id - in: query - schema: - type: string - description: Filter logs by job ID - - name: limit - in: query - schema: - type: integer - default: 100 - description: Maximum number of log entries - - name: offset - in: query - schema: - type: integer - default: 0 - description: Pagination offset - responses: - "200": - description: Log entries - content: - application/json: - schema: - $ref: "#/components/schemas/LogsResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - # --------------------------------------------------------------------------- - # Assets extensions (cloud) - # --------------------------------------------------------------------------- - /api/assets/download: - post: - operationId: createAssetDownload - tags: [assets] - summary: Download assets to cloud runtime - description: "[cloud-only] Initiates a download of one or more assets to the cloud runtime environment. Returns a task ID for tracking download progress via WebSocket." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - assets - properties: - assets: - type: array - items: - $ref: "#/components/schemas/AssetDownloadRequest" - description: Assets to download - responses: - "202": - description: Download task accepted - content: - application/json: - schema: - type: object - required: - - task_id - - status - properties: - task_id: - type: string - format: uuid - description: ID of the download task; use to poll status. - status: - type: string - enum: [created, running, completed, failed] - description: Current task status (typically `created` on initial creation). - message: - type: string - description: Human-readable task message. - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '200': - description: File already exists in storage - asset created/returned immediately - content: - application/json: - schema: - $ref: '#/components/schemas/AssetCreated' - '422': - description: Validation errors - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/assets/export: - post: - operationId: createAssetExport - tags: [assets] - summary: Export assets as a downloadable archive - description: "[cloud-only] Initiates a bulk export of assets. Returns a task ID for tracking progress via WebSocket. When complete, the export can be downloaded via the exports endpoint." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - properties: - job_ids: - type: array - items: - type: string - description: Job IDs whose associated assets should all be included in the ZIP bundle. - asset_ids: - type: array - items: - type: string - format: uuid - description: Asset IDs to include in the ZIP bundle. Additive to assets associated with provided job IDs. - export_name: - type: string - description: Name for the export archive - naming_strategy: - type: string - enum: [group_by_job_id, preserve, asset_id, group_by_job_time] - default: group_by_job_time - description: "Strategy for naming files in the ZIP: group by job ID, preserve original names, use the asset ID, or group by job creation time." - job_asset_name_filters: - type: object - additionalProperties: - type: array - minItems: 1 - items: - type: string - description: Optional per-job asset name filters. When provided for a job ID, only assets whose name matches one of the listed names are included. - responses: - "202": - description: Export task accepted - content: - application/json: - schema: - type: object - required: - - task_id - - status - properties: - task_id: - type: string - format: uuid - description: ID of the export task; use to poll status. - status: - type: string - enum: [created, running, completed, failed] - description: Current task status (typically `created` on initial creation). - message: - type: string - description: Human-readable task message. 
- "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/assets/exports/{exportName}: - get: - operationId: getAssetExport - tags: [assets] - summary: Download a completed asset export - description: "[cloud-only] Returns the archive file for a completed asset export." - x-runtime: [cloud] - parameters: - - name: exportName - in: path - required: true - schema: - type: string - description: Name of the export to download - responses: - "200": - description: Export archive file - content: - application/zip: - schema: - type: string - format: binary - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '400': - description: Invalid export name - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/assets/from-workflow: - post: - operationId: postAssetsFromWorkflow - tags: [assets] - summary: Create asset records from a workflow execution - description: "[cloud-only] Registers output files from a workflow execution as assets in the asset database." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - prompt_id - properties: - prompt_id: - type: string - format: uuid - description: Prompt ID whose outputs should be registered as assets - tags: - type: array - items: - type: string - description: Tags to apply to the created assets - responses: - "200": - description: Assets created or referenced - content: - application/json: - schema: - type: object - properties: - assets: - type: array - items: - $ref: "#/components/schemas/Asset" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/assets/import: - post: - operationId: importPublishedAssets - tags: [assets] - summary: "[cloud-only] Import published assets into the caller's library" - description: | - [cloud-only] Imports the specified published assets into the caller's asset library. New DB records reference the same storage objects; no file copying occurs. Assets the caller already owns (by hash) are deduplicated. The `id` field on each returned `AssetInfo` is the caller's newly-created private asset ID, not the published asset ID supplied in the request. 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/ImportPublishedAssetsRequest" - responses: - "200": - description: Successfully imported assets - content: - application/json: - schema: - $ref: "#/components/schemas/ImportPublishedAssetsResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/assets/remote-metadata: - get: - operationId: getRemoteAssetMetadata - tags: [assets] - summary: Fetch metadata for a remote asset URL - description: "[cloud-only] Fetches and returns metadata (content type, size, filename) for a remote URL without downloading the full content." - x-runtime: [cloud] - parameters: - - name: url - in: query - required: true - schema: - type: string - format: uri - description: URL to inspect - responses: - "200": - description: Remote metadata - content: - application/json: - schema: - $ref: "#/components/schemas/AssetMetadataResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '422': - description: Failed to retrieve metadata from source - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # Custom nodes / hub (cloud) - # --------------------------------------------------------------------------- - 
/api/experiment/nodes: - get: - operationId: getNodeInfoSchema - tags: [runtime-only] - summary: Get pre-rendered node info schema - description: "[cloud-only] Returns the static ComfyUI object_info schema, identical for every caller, rendered once at startup with empty model/user-file context. Served by a raw HTTP handler that writes pre-rendered bytes with ETag + Cache-Control validators for RFC 7232 conditional GETs." - x-runtime: [cloud] - parameters: - - name: If-None-Match - in: header - required: false - schema: - type: string - description: Entity tag previously returned by this endpoint. When present and matching, the server returns 304 Not Modified. - responses: - "200": - description: Node info schema - headers: - ETag: - schema: - type: string - description: Entity tag for conditional request validation - Cache-Control: - schema: - type: string - description: Cache directives for the response - content: - application/json: - schema: - type: object - additionalProperties: - $ref: "#/components/schemas/NodeInfo" - "304": - description: Not Modified — returned when the client sends a matching If-None-Match header - post: - operationId: installCloudNode - tags: [node] - summary: Install a custom node package - description: "[cloud-only] Installs a custom node package in the cloud runtime by ID or repository URL." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - id - properties: - id: - type: string - description: Node package ID or repository URL - version: - type: string - description: Specific version to install - responses: - "200": - description: Node installed - content: - application/json: - schema: - $ref: "#/components/schemas/CloudNode" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /api/experiment/nodes/{id}: - get: - operationId: getNodeByID - tags: [runtime-only] - summary: Get a single node definition by ID - description: "[cloud-only] Returns one node's definition from the pre-indexed object_info schema. Served by a raw HTTP handler that writes pre-rendered bytes with ETag + Cache-Control validators for RFC 7232 conditional GETs." - x-runtime: [cloud] - parameters: - - name: id - in: path - required: true - schema: - type: string - description: Node class identifier - - name: If-None-Match - in: header - required: false - schema: - type: string - description: Entity tag previously returned by this endpoint. When present and matching, the server returns 304 Not Modified. 
- responses: - "200": - description: Single node definition - headers: - ETag: - schema: - type: string - description: Entity tag for conditional request validation - Cache-Control: - schema: - type: string - description: Cache directives for the response - content: - application/json: - schema: - $ref: "#/components/schemas/NodeInfo" - "304": - description: Not Modified — returned when the client sends a matching If-None-Match header - "404": - description: Node not found - delete: - operationId: uninstallCloudNode - tags: [node] - summary: Uninstall a custom node package - description: "[cloud-only] Removes a custom node package from the cloud runtime." - x-runtime: [cloud] - parameters: - - name: id - in: path - required: true - schema: - type: string - description: Custom node package ID - responses: - "204": - description: Node uninstalled - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /api/hub/assets/upload-url: - post: - operationId: createHubAssetUploadUrl - tags: [hub] - summary: Get a pre-signed upload URL for a hub asset - description: "[cloud-only] Returns a pre-signed URL that can be used to upload an asset file directly to storage." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - filename - - content_type - properties: - filename: - type: string - description: Name of the file to upload - content_type: - type: string - description: MIME type of the file - size: - type: integer - format: int64 - description: File size in bytes - responses: - "200": - description: Upload URL - content: - application/json: - schema: - type: object - properties: - upload_url: - type: string - format: uri - description: Pre-signed upload URL - asset_url: - type: string - format: uri - description: Public URL after upload completes - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '404': - description: Not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/hub/labels: - get: - operationId: listHubLabels - tags: [hub] - summary: List available hub labels - description: "[cloud-only] Returns the list of labels/categories available for tagging hub content." - x-runtime: [cloud] - responses: - "200": - description: Label list - content: - application/json: - schema: - $ref: "#/components/schemas/HubLabelListResponse" - '400': - description: Bad request (e.g. 
invalid type parameter) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/hub/profiles: - get: - operationId: listHubProfiles - tags: [hub] - summary: List hub user profiles - description: "[cloud-only] Returns a paginated list of public hub user profiles." - x-runtime: [cloud] - parameters: - - name: limit - in: query - schema: - type: integer - description: Maximum number of results - - name: offset - in: query - schema: - type: integer - description: Pagination offset - - name: search - in: query - schema: - type: string - description: Search by username or display name - responses: - "200": - description: Profile list - content: - application/json: - schema: - type: object - properties: - profiles: - type: array - items: - $ref: "#/components/schemas/HubProfile" - total: - type: integer - has_more: - type: boolean - post: - operationId: createHubProfile - tags: [hub] - summary: Create a Hub profile - description: "[cloud-only] Creates a hub profile for the specified workspace. Username is immutable after creation." - x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/CreateHubProfileRequest" - responses: - "201": - description: Hub profile created - content: - application/json: - schema: - $ref: "#/components/schemas/HubProfile" - "400": - description: Bad request (e.g. 
invalid username) - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "409": - description: Username already taken or profile already exists - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/hub/profiles/{username}: - get: - operationId: getHubProfile - tags: [hub] - summary: Get a hub profile by username - description: "[cloud-only] Returns the public hub profile for the given username." - x-runtime: [cloud] - parameters: - - name: username - in: path - required: true - schema: - type: string - description: Hub username - responses: - "200": - description: Profile - content: - application/json: - schema: - $ref: "#/components/schemas/HubProfile" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/hub/profiles/check: - get: - operationId: checkHubUsername - tags: [hub] - summary: Check if a hub username is available - description: "[cloud-only] Returns whether the given username is available for registration." 
- x-runtime: [cloud] - parameters: - - name: username - in: query - required: true - schema: - type: string - description: Username to check - responses: - "200": - description: Availability result - content: - application/json: - schema: - type: object - properties: - available: - type: boolean - username: - type: string - - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '404': - description: Not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/hub/profiles/me: - get: - operationId: getMyHubProfile - tags: [hub] - summary: Get the authenticated user's hub profile - description: "[cloud-only] Returns the hub profile of the currently authenticated user." - x-runtime: [cloud] - responses: - "200": - description: Profile - content: - application/json: - schema: - $ref: "#/components/schemas/HubProfile" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '404': - description: No hub profile exists - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - put: - operationId: updateMyHubProfile - tags: [hub] - summary: Update the authenticated user's hub profile - description: "[cloud-only] Updates the hub profile of the currently authenticated user." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - properties: - username: - type: string - display_name: - type: string - bio: - type: string - avatar_url: - type: string - format: uri - links: - type: array - items: - type: string - format: uri - responses: - "200": - description: Updated profile - content: - application/json: - schema: - $ref: "#/components/schemas/HubProfile" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "409": - description: Conflict - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /api/hub/workflows: - get: - operationId: listHubWorkflows - tags: [hub] - summary: List published hub workflows - description: "[cloud-only] Returns a paginated list of publicly shared workflows on the hub." - x-runtime: [cloud] - parameters: - - name: limit - in: query - schema: - type: integer - description: Maximum number of results - - name: offset - in: query - schema: - type: integer - description: Pagination offset - - name: sort - in: query - schema: - type: string - description: Sort field (e.g. created_at, likes) - - name: order - in: query - schema: - type: string - enum: [asc, desc] - description: Sort direction - - name: search - in: query - schema: - type: string - description: Search by title or description - - name: labels - in: query - schema: - type: string - description: Filter by label IDs (comma-separated) - responses: - "200": - description: Hub workflow list - content: - application/json: - schema: - $ref: "#/components/schemas/HubWorkflowListResponse" - '400': - description: Bad request (e.g. 
malformed pagination cursor) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '404': - description: Profile not found (when filtering by username) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: publishHubWorkflow - tags: [hub] - summary: Publish a workflow to the hub - description: "[cloud-only] Publishes a workflow to the hub with metadata, thumbnail, and sample images." - x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/PublishHubWorkflowRequest" - responses: - "200": - description: Workflow published to hub - content: - application/json: - schema: - $ref: "#/components/schemas/HubWorkflowDetail" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Workflow or profile not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/hub/workflows/{share_id}: - get: - operationId: getHubWorkflow - tags: [hub] - summary: Get a published hub workflow by share ID - description: "[cloud-only] Returns the full details of a published workflow on the hub." 
- x-runtime: [cloud] - parameters: - - name: share_id - in: path - required: true - schema: - type: string - description: Workflow share ID - responses: - "200": - description: Hub workflow - content: - application/json: - schema: - $ref: "#/components/schemas/HubWorkflowDetail" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '413': - description: Workflow JSON too large - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - delete: - operationId: deleteHubWorkflow - tags: [hub] - summary: Unpublish a workflow from the hub - description: "[cloud-only] Removes a workflow from the hub listing." - x-runtime: [cloud] - parameters: - - name: share_id - in: path - required: true - schema: - type: string - description: Workflow share ID - responses: - "204": - description: Successfully unpublished - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Workflow not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/hub/workflows/index: - get: - operationId: listHubWorkflowIndex - tags: [hub] - summary: Get the hub workflow index - description: "[cloud-only] Returns the lightweight index of all hub workflows for client-side search and navigation." 
- x-runtime: [cloud] - responses: - "200": - description: Workflow index - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/HubWorkflowIndexEntry" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # Workflows (cloud) - # --------------------------------------------------------------------------- - /api/workflows: - get: - operationId: listWorkflows - tags: [workflows] - summary: List cloud workflows - description: "[cloud-only] Returns a paginated list of the authenticated user's cloud workflows." - x-runtime: [cloud] - parameters: - - name: limit - in: query - schema: - type: integer - description: Maximum number of results - - name: offset - in: query - schema: - type: integer - description: Pagination offset - - name: sort - in: query - schema: - type: string - description: Sort field - - name: order - in: query - schema: - type: string - enum: [asc, desc] - description: Sort direction - - name: search - in: query - schema: - type: string - description: Search by workflow name - responses: - "200": - description: Workflow list - content: - application/json: - schema: - $ref: "#/components/schemas/WorkflowListResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: createWorkflow - tags: [workflows] - summary: Create a new cloud workflow - description: "[cloud-only] Creates a new cloud workflow with the provided name and optional initial content." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - name - properties: - name: - type: string - description: Workflow name - description: - type: string - description: Workflow description - content: - type: object - additionalProperties: true - description: Initial workflow graph JSON - responses: - "201": - description: Workflow created - content: - application/json: - schema: - $ref: "#/components/schemas/WorkflowResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '422': - description: Validation error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workflows/{workflow_id}: - get: - operationId: getWorkflow - tags: [workflows] - summary: Get a cloud workflow by ID - description: "[cloud-only] Returns the metadata for a cloud workflow." - x-runtime: [cloud] - parameters: - - name: workflow_id - in: path - required: true - schema: - type: string - format: uuid - description: The workflow ID. 
- responses: - "200": - description: Workflow detail - content: - application/json: - schema: - $ref: "#/components/schemas/WorkflowResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '403': - description: Forbidden - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - patch: - operationId: updateWorkflow - tags: [workflows] - summary: Update a cloud workflow - description: "[cloud-only] Updates the metadata (name, description) of an existing cloud workflow." - x-runtime: [cloud] - parameters: - - name: workflow_id - in: path - required: true - schema: - type: string - format: uuid - description: The workflow ID. - requestBody: - required: true - content: - application/json: - schema: - type: object - properties: - name: - type: string - description: - type: string - responses: - "200": - description: Workflow updated - content: - application/json: - schema: - $ref: "#/components/schemas/WorkflowResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '422': - description: Validation error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - delete: - operationId: deleteWorkflow - tags: [workflows] - summary: Delete a cloud workflow - 
description: "[cloud-only] Deletes a cloud workflow and all its versions." - x-runtime: [cloud] - parameters: - - name: workflow_id - in: path - required: true - schema: - type: string - format: uuid - description: The workflow ID. - responses: - "204": - description: Workflow deleted - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workflows/{workflow_id}/content: - get: - operationId: getWorkflowContent - tags: [workflows] - summary: Get the content of a cloud workflow - description: "[cloud-only] Returns the full workflow graph JSON for the latest version of a cloud workflow." - x-runtime: [cloud] - parameters: - - name: workflow_id - in: path - required: true - schema: - type: string - format: uuid - description: The workflow ID. 
- - name: version_id - in: query - schema: - type: string - description: Specific version ID to fetch - responses: - "200": - description: Workflow content - content: - application/json: - schema: - type: object - additionalProperties: true - description: The full workflow graph JSON - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '403': - description: Forbidden - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - put: - operationId: updateCloudWorkflowContent - tags: [workflows] - summary: Update the content of a cloud workflow - description: "[cloud-only] Saves new workflow graph JSON as a new version of the cloud workflow." - x-runtime: [cloud] - parameters: - - name: workflow_id - in: path - required: true - schema: - type: string - format: uuid - description: The workflow ID. 
- requestBody: - required: true - content: - application/json: - schema: - type: object - additionalProperties: true - description: The workflow graph JSON to save - responses: - "200": - description: Content updated - content: - application/json: - schema: - $ref: "#/components/schemas/CloudWorkflowVersion" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /api/workflows/{workflow_id}/fork: - post: - operationId: forkWorkflow - tags: [workflows] - summary: Fork a cloud workflow - description: "[cloud-only] Creates a copy of a cloud workflow under the authenticated user's account." - x-runtime: [cloud] - parameters: - - name: workflow_id - in: path - required: true - schema: - type: string - format: uuid - description: The workflow ID to fork. 
- requestBody: - required: false - content: - application/json: - schema: - type: object - properties: - name: - type: string - description: Name for the forked workflow (defaults to original name) - responses: - "201": - description: Forked workflow - content: - application/json: - schema: - $ref: "#/components/schemas/WorkflowResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '403': - description: Forbidden - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '422': - description: Validation error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workflows/{workflow_id}/versions: - get: - operationId: listCloudWorkflowVersions - tags: [workflows] - summary: List versions of a cloud workflow - description: "[cloud-only] Returns the version history of a cloud workflow." - x-runtime: [cloud] - parameters: - - name: workflow_id - in: path - required: true - schema: - type: string - format: uuid - description: The workflow ID. 
- - name: limit - in: query - schema: - type: integer - description: Maximum number of results - - name: offset - in: query - schema: - type: integer - description: Pagination offset - responses: - "200": - description: Version list - content: - application/json: - schema: - type: object - properties: - versions: - type: array - items: - $ref: "#/components/schemas/CloudWorkflowVersion" - total: - type: integer - has_more: - type: boolean - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - post: - operationId: createWorkflowVersion - tags: [workflows] - summary: Create a new cloud workflow version - description: "[cloud-only] Creates a new workflow version with updated workflow JSON. Uses optimistic concurrency via base_version." - x-runtime: [cloud] - parameters: - - name: workflow_id - in: path - required: true - schema: - type: string - format: uuid - description: The workflow ID. 
- requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/CreateWorkflowVersionRequest" - responses: - "201": - description: Version created - content: - application/json: - schema: - $ref: "#/components/schemas/WorkflowVersionResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden — not the workflow owner - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "409": - description: Version conflict — base_version does not match latest - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '422': - description: Validation error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workflows/published/{share_id}: - get: - operationId: getPublishedWorkflow - tags: [workflows] - summary: Get a published workflow by share ID - description: "[cloud-only] Returns a publicly published cloud workflow by its share identifier." - x-runtime: [cloud] - parameters: - - name: share_id - in: path - required: true - schema: - type: string - description: The workflow share ID. 
- responses: - "200": - description: Published workflow - content: - application/json: - schema: - $ref: "#/components/schemas/PublishedWorkflowDetail" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '413': - description: Workflow JSON too large - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # Auth / session (cloud) - # --------------------------------------------------------------------------- - /api/auth/session: - get: - operationId: getAuthSession - tags: [auth] - summary: Get the current authentication session - description: "[cloud-only] Returns the current session state for the authenticated user, including user identity and active workspace." - x-runtime: [cloud] - responses: - "200": - description: Session info - content: - application/json: - schema: - $ref: "#/components/schemas/AuthSession" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - post: - operationId: createSession - tags: [auth] - summary: Create a session cookie - description: "[cloud-only] Creates a session cookie from the bearer token in the Authorization header. Returns a Set-Cookie header with a secure HttpOnly session cookie. Cookie authentication is not allowed for this endpoint." 
- x-runtime: [cloud] - responses: - "200": - description: Session created - content: - application/json: - schema: - $ref: "#/components/schemas/CreateSessionResponse" - "400": - description: Bad request — invalid or expired ID token - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - delete: - operationId: deleteSession - tags: [auth] - summary: Delete session cookie (logout) - description: "[cloud-only] Clears the session cookie and optionally revokes the session on the server." - x-runtime: [cloud] - responses: - "200": - description: Session deleted - content: - application/json: - schema: - $ref: "#/components/schemas/DeleteSessionResponse" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/auth/token: - post: - operationId: exchangeToken - tags: [auth] - summary: Exchange credentials for an access token - description: "[cloud-only] Exchanges authentication credentials (e.g. an authorization code) for an access token." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - grant_type - properties: - grant_type: - type: string - enum: [authorization_code, refresh_token] - description: OAuth2 grant type - code: - type: string - description: Authorization code (for authorization_code grant) - refresh_token: - type: string - description: Refresh token (for refresh_token grant) - redirect_uri: - type: string - format: uri - description: Redirect URI used in the authorization request - responses: - "200": - description: Token response - content: - application/json: - schema: - $ref: "#/components/schemas/ExchangeTokenResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '404': - description: Workspace not found or user not a member - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /.well-known/jwks.json: - get: - operationId: getJwks - tags: [auth] - summary: Get JSON Web Key Set - description: "[cloud-only] Returns the JSON Web Key Set (JWKS) used to verify JWTs issued by the cloud authentication service." 
- x-runtime: [cloud] - responses: - "200": - description: JWKS - content: - application/json: - schema: - $ref: "#/components/schemas/JwksResponse" - - # --------------------------------------------------------------------------- - # OAuth 2.1 / RFC 7591 Dynamic Client Registration (cloud) - # --------------------------------------------------------------------------- - /.well-known/oauth-authorization-server: - get: - operationId: getOAuthAuthorizationServer - tags: [auth] - summary: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)" - description: "[cloud-only] Public metadata document for OAuth 2.1 clients. Cached 5 minutes." - x-runtime: [cloud] - security: [] - responses: - "200": - description: Authorization-server metadata - content: - application/json: - schema: - $ref: "#/components/schemas/OAuthAuthorizationServerMetadata" - "404": - description: OAuth disabled - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /.well-known/oauth-protected-resource: - get: - operationId: getOAuthProtectedResource - tags: [auth] - summary: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)" - description: "[cloud-only] Public metadata describing the currently advertised protected resource. Cached 5 minutes." 
- x-runtime: [cloud] - security: [] - responses: - "200": - description: Protected-resource metadata - content: - application/json: - schema: - $ref: "#/components/schemas/OAuthProtectedResourceMetadata" - "404": - description: OAuth disabled or no active resource configured - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /oauth/authorize: - get: - operationId: getOAuthAuthorize - tags: [auth] - summary: "[cloud-only] Begin or resume an OAuth 2.1 authorization request" - description: | - [cloud-only] Two modes: - - **Initial entry** (OAuth params present): validates client/redirect/resource/scopes, persists a server-side authorization-request row, and either redirects (no session / unverified email) to the configured frontend login URL carrying only the opaque `oauth_request_id`, or returns the JSON consent challenge for the frontend to render. - - **Resume** (`oauth_request_id` present): loads the server-side row, fails closed if expired/consumed/unknown, returns the JSON consent challenge. Browser-replayed OAuth params are intentionally ignored. - - The frontend renders the consent UI from the JSON payload and POSTs the user's decision back to this endpoint. - x-runtime: [cloud] - security: [] - parameters: - - { name: response_type, in: query, required: false, schema: { type: string } } - - { name: client_id, in: query, required: false, schema: { type: string } } - - { name: redirect_uri, in: query, required: false, schema: { type: string } } - - { name: scope, in: query, required: false, schema: { type: string } } - - name: state - in: query - required: false - schema: { type: string } - description: | - RFC 6749 §10.12 marks `state` as RECOMMENDED. Cloud hardening makes it REQUIRED on the initial-entry path (omitted only on the resume path where `oauth_request_id` is supplied instead). This parameter is `required: false` at the spec level only because the operation is dual-mode (initial entry vs. 
resume); the runtime rejects empty `state` on the initial-entry path with a stable `invalid_request` 400. - - { name: code_challenge, in: query, required: false, schema: { type: string } } - - { name: code_challenge_method, in: query, required: false, schema: { type: string } } - - { name: resource, in: query, required: false, schema: { type: string } } - - { name: oauth_request_id, in: query, required: false, schema: { type: string } } - responses: - "200": - description: Consent challenge payload (session present, email verified). Frontend renders the consent UI from this payload and POSTs back to /oauth/authorize. - content: - application/json: - schema: - $ref: "#/components/schemas/OAuthConsentChallenge" - "302": - description: Redirect to login (no session / unverified email) or to registered redirect_uri (pre-validated client error) - headers: - Location: - schema: - type: string - "400": - description: Invalid authorize request (pre-redirect failure — unknown client, redirect mismatch, malformed params) - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: OAuth disabled - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - post: - operationId: postOAuthAuthorize - tags: [auth] - summary: "[cloud-only] Submit OAuth consent decision" - description: | - [cloud-only] JSON-only consent submission. The handler verifies the per-row CSRF token, atomically marks the authorization request consumed (single-use covers both allow and deny paths), then returns the redirect URL the browser must navigate to. The URL contains either `code` + original `state` for allow, or the RFC 6749 §5.2 error and `state` for deny. - - Workspace membership is re-checked at submission time. Consent is persisted keyed by `(user_id, client_id, resource_id, workspace_id)`; broadening the previously approved scope set requires a fresh consent flow. 
- x-runtime: [cloud] - security: [] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: [oauth_request_id, csrf_token, decision, workspace_id] - properties: - oauth_request_id: { type: string, format: uuid } - csrf_token: { type: string } - decision: { type: string, enum: [allow, deny] } - workspace_id: { type: string } - responses: - "200": - description: Redirect URL for the frontend to navigate to (allow → with code+state; deny → with error+state) - content: - application/json: - schema: - $ref: "#/components/schemas/OAuthAuthorizeRedirectResponse" - "400": - description: Bad request (CSRF mismatch, expired/consumed request, inaccessible workspace) - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Scope broadening on consent re-grant — fresh consent flow required - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: OAuth disabled - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /oauth/token: - post: - operationId: postOAuthToken - tags: [auth] - summary: "[cloud-only] Exchange authorization code or refresh token for a resource-bound access token" - description: | - [cloud-only] OAuth 2.1 token endpoint (RFC 6749 §3.2). Public clients only — `client_secret` is rejected. - - Two grant types are supported: - - `authorization_code` — exchanges the code minted by `/oauth/authorize` (with PKCE verifier) for an access token + first refresh token. Single-use; reuse fails closed. - - `refresh_token` — rotates the refresh token. Old token immediately invalid; presenting an already-rotated token revokes the entire token family and emits a security metric. - - Both grant types re-validate canonical user state, current workspace membership, and the resource's active flag at every mint. A code or refresh token bound to a deactivated resource fails closed. 
- - Errors follow RFC 6749 §5.2. Logs never contain raw codes, refresh tokens, or minted tokens. - - Per RFC 6749 §5.1, every 200 and 400 response carries `Cache-Control: no-store` and `Pragma: no-cache` so intermediaries cannot cache token-bearing or state-change-reason responses. - x-runtime: [cloud] - security: [] - requestBody: - required: true - content: - application/x-www-form-urlencoded: - schema: - type: object - required: [grant_type, client_id] - properties: - grant_type: { type: string, enum: [authorization_code, refresh_token] } - client_id: { type: string } - code: { type: string } - redirect_uri: { type: string } - code_verifier: { type: string } - refresh_token: { type: string } - scope: { type: string } - client_secret: { type: string } - responses: - "200": - description: New token pair - headers: - Cache-Control: - schema: - type: string - description: 'Always "no-store" per RFC 6749 §5.1' - Pragma: - schema: - type: string - description: 'Always "no-cache" per RFC 6749 §5.1' - content: - application/json: - schema: - $ref: "#/components/schemas/OAuthTokenResponse" - "400": - description: RFC 6749 §5.2 error - headers: - Cache-Control: - schema: - type: string - description: 'Always "no-store" per RFC 6749 §5.1' - Pragma: - schema: - type: string - description: 'Always "no-cache" per RFC 6749 §5.1' - content: - application/json: - schema: - $ref: "#/components/schemas/OAuthTokenError" - "404": - description: OAuth disabled - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /oauth/register: - post: - operationId: postOAuthRegister - tags: [auth] - summary: "[cloud-only] Dynamic Client Registration (RFC 7591)" - description: | - [cloud-only] Public, unauthenticated, insert-only RFC 7591 §3.1 client registration. Used by MCP-spec-compliant clients to self-register a public OAuth client without operator involvement. - - Policy: - - - Public clients only — `token_endpoint_auth_method` is forced to `none`. 
Confidential-client registration is out of scope this phase. - - Server-owned `resource_grants`. Caller-supplied `scope` or `resource_grants` is rejected as `invalid_client_metadata` (would be a privilege-escalation surface). Dynamic clients receive the same scopes the active resource publishes. - - Application-type-aware redirect URI policy. `application_type=native` accepts loopback (`127.0.0.1`, `::1`, `localhost`) and reverse-DNS-shaped custom schemes; `application_type=web` accepts HTTPS to hosts in an operator-controlled allowlist only. `application_type` is REQUIRED on the request — missing or empty rejects with `invalid_client_metadata`. - - Anti-impersonation: reserved client names are rejected from third parties via NFKC-folded compare. - - Generated `client_id` carries a stable prefix to distinguish dynamic from seeded clients in audit logs. - - Cache-Control: `no-store` on every 201 and 400 response (the response carries fresh credentials and rejection reasons). - x-runtime: [cloud] - security: [] - requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/OAuthRegisterRequest" - responses: - "201": - description: Registered. Body echoes the metadata RFC 7591 §3.2.1 requires. 
- headers: - Cache-Control: - schema: - type: string - description: 'Always "no-store"' - Pragma: - schema: - type: string - description: 'Always "no-cache"' - content: - application/json: - schema: - $ref: "#/components/schemas/OAuthRegisterResponse" - "400": - description: RFC 7591 §3.2.2 invalid client metadata - headers: - Cache-Control: - schema: - type: string - description: 'Always "no-store"' - Pragma: - schema: - type: string - description: 'Always "no-cache"' - content: - application/json: - schema: - $ref: "#/components/schemas/OAuthRegisterError" - "404": - description: OAuth disabled - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "503": - description: No active resource is configured — DCR cannot mint a usable client until an active resource row is seeded. - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - # --------------------------------------------------------------------------- - # Billing (cloud) - # --------------------------------------------------------------------------- - /api/billing/balance: - get: - operationId: getBillingBalance - tags: [billing] - summary: Get current credit balance - description: "[cloud-only] Returns the authenticated user's current credit balance and usage summary." - x-runtime: [cloud] - responses: - "200": - description: Balance info - content: - application/json: - schema: - $ref: "#/components/schemas/BillingBalanceResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/billing/events: - get: - operationId: getBillingEvents - tags: [billing] - summary: List billing events - description: "[cloud-only] Returns a paginated list of billing events (charges, credits, refunds) for the authenticated user." 
- x-runtime: [cloud] - parameters: - - name: limit - in: query - schema: - type: integer - description: Maximum number of results - - name: offset - in: query - schema: - type: integer - description: Pagination offset - - name: type - in: query - schema: - type: string - description: Filter by event type - responses: - "200": - description: Billing events - content: - application/json: - schema: - $ref: "#/components/schemas/BillingEventsResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/billing/ops/{id}: - get: - operationId: getBillingOpStatus - tags: [billing] - summary: Get a billing operation by ID - description: "[cloud-only] Returns details of a specific billing operation." - x-runtime: [cloud] - parameters: - - name: id - in: path - required: true - schema: - type: string - description: The billing operation ID. - responses: - "200": - description: Billing operation - content: - application/json: - schema: - $ref: "#/components/schemas/BillingOpStatusResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/billing/payment-portal: - post: - operationId: getPaymentPortal - tags: [billing] - summary: Create a payment portal session - description: "[cloud-only] Creates a Stripe customer portal session for managing payment methods and invoices. Returns a URL to redirect the user to." 
- x-runtime: [cloud] - responses: - "200": - description: Portal session - content: - application/json: - schema: - type: object - properties: - url: - type: string - format: uri - description: Stripe portal URL - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '400': - description: Bad request (e.g., missing return_url) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/billing/plans: - get: - operationId: getBillingPlans - tags: [billing] - summary: List available billing plans - description: "[cloud-only] Returns the list of available subscription plans and their pricing." - x-runtime: [cloud] - responses: - "200": - description: Plan list - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/BillingPlan" - - '401': - description: Unauthorized - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/billing/preview-subscribe: - post: - operationId: previewSubscribe - tags: [billing] - summary: Preview a subscription change - description: "[cloud-only] Returns a preview of what a subscription change would cost, including prorations." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - plan_id - properties: - plan_id: - type: string - description: ID of the plan to preview - responses: - "200": - description: Subscription preview - content: - application/json: - schema: - $ref: "#/components/schemas/PreviewSubscribeResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/billing/status: - get: - operationId: getBillingStatus - tags: [billing] - summary: Get billing status - description: "[cloud-only] Returns the authenticated user's current billing and subscription status." - x-runtime: [cloud] - responses: - "200": - description: Billing status - content: - application/json: - schema: - $ref: "#/components/schemas/BillingStatusResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '404': - description: Workspace not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/billing/subscribe: - post: - operationId: subscribe - tags: [billing] - summary: Subscribe to a billing plan - description: "[cloud-only] Creates a new subscription to the specified billing plan." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - plan_id - properties: - plan_id: - type: string - description: ID of the plan to subscribe to - payment_method_id: - type: string - description: Stripe payment method ID - responses: - "200": - description: Subscription created - content: - application/json: - schema: - $ref: "#/components/schemas/SubscribeResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/billing/subscription/cancel: - post: - operationId: cancelSubscription - tags: [billing] - summary: Cancel the active subscription - description: "[cloud-only] Cancels the authenticated user's active subscription. The subscription remains active until the end of the current billing period." - x-runtime: [cloud] - responses: - "200": - description: Subscription cancelled - content: - application/json: - schema: - $ref: "#/components/schemas/CancelSubscriptionResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '400': - description: Invalid request (e.g., no active subscription) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/billing/subscription/resubscribe: - post: - operationId: resubscribe - tags: [billing] - summary: Resubscribe after cancellation - description: "[cloud-only] Reactivates a subscription that was previously cancelled but has not yet expired." 
- x-runtime: [cloud] - responses: - "200": - description: Subscription reactivated - content: - application/json: - schema: - $ref: "#/components/schemas/ResubscribeResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '400': - description: Invalid request (e.g., no active subscription, not in cancellation grace period) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/billing/topup: - post: - operationId: createTopup - tags: [billing] - summary: Purchase additional credits - description: "[cloud-only] Purchases a one-time credit top-up using the user's payment method on file." - x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - amount - properties: - amount: - type: integer - description: Number of credits to purchase - responses: - "200": - description: Top-up successful - content: - application/json: - schema: - $ref: "#/components/schemas/CreateTopupResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # Workspace (cloud) - # --------------------------------------------------------------------------- - /api/workspace/api-keys: - get: - operationId: listWorkspaceAPIKeys - tags: [workspace] - summary: List workspace API keys - description: "[cloud-only] Returns the list of API keys for the current workspace." 
- x-runtime: [cloud] - responses: - "200": - description: API key list - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/WorkspaceApiKey" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: createWorkspaceAPIKey - tags: [workspace] - summary: Create a workspace API key - description: "[cloud-only] Creates a new API key for the current workspace." - x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - name - properties: - name: - type: string - description: Display name for the API key - description: - type: string - description: User-provided description of the key's purpose - maxLength: 5000 - responses: - "201": - description: API key created - content: - application/json: - schema: - $ref: "#/components/schemas/CreateWorkspaceAPIKeyResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '404': - description: Workspace not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '422': - description: Validation error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '429': - description: Key limit reached - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - 
content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workspace/api-keys/{id}: - delete: - operationId: revokeWorkspaceAPIKey - tags: [workspace] - summary: Delete a workspace API key - description: "[cloud-only] Revokes and deletes a workspace API key." - x-runtime: [cloud] - parameters: - - name: id - in: path - required: true - schema: - type: string - description: The API key ID. - responses: - "204": - description: API key deleted - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workspace/invites: - get: - operationId: listWorkspaceInvites - tags: [workspace] - summary: List pending workspace invites - description: "[cloud-only] Returns the list of pending invitations for the current workspace." - x-runtime: [cloud] - responses: - "200": - description: Invite list - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/WorkspaceInvite" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: createWorkspaceInvite - tags: [workspace] - summary: Invite a user to the workspace - description: "[cloud-only] Creates an invitation for a user to join the current workspace." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - email - properties: - email: - type: string - format: email - description: Email address to invite - role: - type: string - enum: [admin, member] - description: Role to assign - responses: - "201": - description: Invite created - content: - application/json: - schema: - $ref: "#/components/schemas/PendingInvite" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "409": - description: Conflict - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '404': - description: Workspace not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '422': - description: Validation error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workspace/invites/{inviteId}: - delete: - operationId: revokeWorkspaceInvite - tags: [workspace] - summary: Cancel a workspace invite - description: "[cloud-only] Cancels a pending workspace invitation." - x-runtime: [cloud] - parameters: - - name: inviteId - in: path - required: true - schema: - type: string - description: The invite ID. 
- responses: - "204": - description: Invite cancelled - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workspace/leave: - post: - operationId: leaveWorkspace - tags: [workspace] - summary: Leave the current workspace - description: "[cloud-only] Removes the authenticated user from the current workspace." - x-runtime: [cloud] - responses: - "204": - description: Left workspace - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '404': - description: Workspace not found or not a member - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workspace/members: - get: - operationId: listWorkspaceMembers - tags: [workspace] - summary: List workspace members - description: "[cloud-only] Returns the list of members in the current workspace." 
- x-runtime: [cloud] - responses: - "200": - description: Member list - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/WorkspaceMember" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '404': - description: Workspace not found - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '422': - description: Validation error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workspace/members/{user_id}/api-keys: - get: - operationId: listMemberApiKeys - tags: [workspace] - summary: List API keys for a workspace member - description: "[cloud-only] Returns the API keys belonging to a specific workspace member. Requires admin role." - x-runtime: [cloud] - parameters: - - name: user_id - in: path - required: true - schema: - type: string - description: The member's user ID. - responses: - "200": - description: API key list - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/WorkspaceApiKey" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - delete: - operationId: bulkRevokeWorkspaceMemberAPIKeys - tags: [workspace] - summary: Bulk revoke a member's API keys - description: "[cloud-only] Revokes all active API keys for a specific workspace member. Only workspace owners can perform this action." 
- x-runtime: [cloud] - parameters: - - name: user_id - in: path - required: true - schema: - type: string - minLength: 1 - description: The member's user ID. - responses: - "200": - description: Keys revoked - content: - application/json: - schema: - $ref: "#/components/schemas/BulkRevokeAPIKeysResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden — must be workspace owner - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '422': - description: Validation error (e.g. empty user_id) - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workspace/members/{userId}: - patch: - operationId: updateWorkspaceMember - tags: [workspace] - summary: Update a workspace member's role - description: "[cloud-only] Updates the role of a workspace member. Requires admin role." - x-runtime: [cloud] - parameters: - - name: userId - in: path - required: true - schema: - type: string - description: The member's user ID. 
- requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - role - properties: - role: - type: string - enum: [admin, member] - description: New role to assign - responses: - "200": - description: Member updated - content: - application/json: - schema: - $ref: "#/components/schemas/WorkspaceMember" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - delete: - operationId: removeWorkspaceMember - tags: [workspace] - summary: Remove a member from the workspace - description: "[cloud-only] Removes a member from the current workspace. Requires admin role." - x-runtime: [cloud] - parameters: - - name: userId - in: path - required: true - schema: - type: string - description: The member's user ID. - responses: - "204": - description: Member removed - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workspaces: - get: - operationId: listWorkspaces - tags: [workspace] - summary: List workspaces the user belongs to - description: "[cloud-only] Returns the list of workspaces the authenticated user is a member of." 
- x-runtime: [cloud] - responses: - "200": - description: Workspace list - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/Workspace" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '404': - description: Feature not enabled for user - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: createWorkspace - tags: [workspace] - summary: Create a new workspace - description: "[cloud-only] Creates a new workspace. The authenticated user becomes the owner." - x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - name - properties: - name: - type: string - description: Workspace name - responses: - "201": - description: Workspace created - content: - application/json: - schema: - $ref: "#/components/schemas/Workspace" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '404': - description: Feature not enabled for user - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '422': - description: Validation error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/workspaces/{id}: - get: - operationId: getWorkspace - tags: [workspace] - summary: Get a workspace by ID - description: "[cloud-only] Returns details of a workspace the user is a member of." 
- x-runtime: [cloud] - parameters: - - name: id - in: path - required: true - schema: - type: string - description: The workspace ID. - responses: - "200": - description: Workspace detail - content: - application/json: - schema: - $ref: "#/components/schemas/Workspace" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - patch: - operationId: updateWorkspace - tags: [workspace] - summary: Update workspace settings - description: "[cloud-only] Updates the name or settings of a workspace. Requires admin role." - x-runtime: [cloud] - parameters: - - name: id - in: path - required: true - schema: - type: string - description: The workspace ID. 
- requestBody: - required: true - content: - application/json: - schema: - type: object - properties: - name: - type: string - description: New workspace name - responses: - "200": - description: Workspace updated - content: - application/json: - schema: - $ref: "#/components/schemas/Workspace" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '422': - description: Validation error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - delete: - operationId: deleteWorkspace - tags: [workspace] - summary: Delete a workspace - description: "[cloud-only] Soft-deletes a workspace. Requires owner role. Personal workspaces cannot be deleted." - x-runtime: [cloud] - parameters: - - name: id - in: path - required: true - schema: - type: string - description: The workspace ID. 
- responses: - "204": - description: Workspace deleted - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "403": - description: Forbidden — must be workspace owner - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - # --------------------------------------------------------------------------- - # User / settings / misc (cloud) - # --------------------------------------------------------------------------- - /api/feedback: - post: - operationId: submitFeedback - tags: [user] - summary: Submit user feedback - description: "[cloud-only] Submits feedback from the user about their experience with the cloud runtime." - x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/FeedbackRequest" - responses: - "201": - description: Feedback submitted - content: - application/json: - schema: - type: object - properties: - id: - type: string - status: - type: string - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/files/mask-layers: - get: - operationId: getMaskLayers - tags: [assets] - summary: Get related mask layer filenames - description: "[cloud-only] Given a mask file (any of the 4 layers), returns all related mask layer filenames. 
Used by the mask editor to load the paint, mask, and painted layers when reopening a previously edited mask." - x-runtime: [cloud] - parameters: - - name: filename - in: query - required: true - schema: - type: string - description: Hash filename of any mask layer file - responses: - "200": - description: Related mask layers - content: - application/json: - schema: - type: object - properties: - mask: - type: string - description: Filename of the mask layer - nullable: true - paint: - type: string - description: Filename of the paint strokes layer - nullable: true - painted: - type: string - description: Filename of the painted image layer - nullable: true - painted_masked: - type: string - description: Filename of the final composite layer - nullable: true - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: File not found or not a mask file - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /api/internal/cloud_analytics: - post: - operationId: postCloudAnalytics - tags: [internal] - summary: Post client analytics events - description: "[cloud-only] Receives analytics events from the frontend for processing by the cloud analytics pipeline." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - events - properties: - events: - type: array - items: - type: object - required: - - event_name - properties: - event_name: - type: string - timestamp: - type: string - format: date-time - properties: - type: object - additionalProperties: true - responses: - "200": - description: Events accepted - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/invites/{token}/accept: - post: - operationId: acceptWorkspaceInvite - tags: [workspace] - summary: Accept a workspace invitation - description: "[cloud-only] Accepts a workspace invitation using the invite token. The authenticated user is added to the workspace." - x-runtime: [cloud] - parameters: - - name: token - in: path - required: true - schema: - type: string - description: The invitation token. 
- responses: - "200": - description: Invite accepted - content: - application/json: - schema: - $ref: "#/components/schemas/AcceptInviteResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '403': - description: Email does not match invite - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '409': - description: Already a member of this workspace - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/secrets: - get: - operationId: listSecrets - tags: [settings] - summary: List user secrets - description: "[cloud-only] Returns the list of secrets (API keys for third-party services) stored for the authenticated user. Secret values are redacted." - x-runtime: [cloud] - responses: - "200": - description: Secret list - content: - application/json: - schema: - type: array - items: - $ref: "#/components/schemas/SecretMeta" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '503': - description: Service unavailable - feature is disabled - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: createSecret - tags: [settings] - summary: Create or update a secret - description: "[cloud-only] Stores a new secret or updates an existing one. Secrets are encrypted at rest." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - required: - - name - - value - properties: - name: - type: string - description: Secret name (unique per user) - value: - type: string - description: Secret value - responses: - "201": - description: Secret created - content: - application/json: - schema: - $ref: "#/components/schemas/SecretResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '409': - description: Conflict - secret with this name or provider already exists - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '422': - description: Validation error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '503': - description: Service unavailable - secrets feature disabled - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/secrets/{id}: - get: - operationId: getSecret - tags: [settings] - summary: Get secret metadata - description: "[cloud-only] Returns metadata for a specific secret. Does not return the plaintext secret value." - x-runtime: [cloud] - parameters: - - name: id - in: path - required: true - schema: - type: string - format: uuid - description: The secret ID. 
- responses: - "200": - description: Secret metadata - content: - application/json: - schema: - $ref: "#/components/schemas/SecretResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '403': - description: Forbidden - user does not own this secret - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '503': - description: Service unavailable - secrets feature disabled - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - patch: - operationId: updateSecret - tags: [settings] - summary: Update a secret - description: "[cloud-only] Updates an existing secret's name and/or value. Both fields are optional; only provided fields are updated." - x-runtime: [cloud] - parameters: - - name: id - in: path - required: true - schema: - type: string - format: uuid - description: The secret ID. 
- requestBody: - required: true - content: - application/json: - schema: - $ref: "#/components/schemas/UpdateSecretRequest" - responses: - "200": - description: Secret updated - content: - application/json: - schema: - $ref: "#/components/schemas/SecretResponse" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "409": - description: Conflict — a secret with this name already exists - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '403': - description: Forbidden - user does not own this secret - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '503': - description: Service unavailable - secrets feature disabled - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - delete: - operationId: deleteSecret - tags: [settings] - summary: Delete a secret - description: "[cloud-only] Permanently deletes a stored secret." - x-runtime: [cloud] - parameters: - - name: id - in: path - required: true - schema: - type: string - description: The secret ID. 
- responses: - "204": - description: Secret deleted - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '403': - description: Forbidden - user does not own this secret - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '503': - description: Service unavailable - secrets feature disabled - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/user: - get: - operationId: getUser - tags: [user] - summary: Get the authenticated cloud user - description: "[cloud-only] Returns the profile and account information for the currently authenticated user." - x-runtime: [cloud] - responses: - "200": - description: User profile - content: - application/json: - schema: - $ref: "#/components/schemas/UserResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - put: - operationId: updateCloudUser - tags: [user] - summary: Update the authenticated cloud user profile - description: "[cloud-only] Updates the profile information for the currently authenticated user." 
- x-runtime: [cloud] - requestBody: - required: true - content: - application/json: - schema: - type: object - properties: - display_name: - type: string - avatar_url: - type: string - format: uri - responses: - "200": - description: Updated profile - content: - application/json: - schema: - $ref: "#/components/schemas/CloudUser" - "400": - description: Bad request - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /api/userdata/{file}/publish: - get: - operationId: getUserdataFilePublish - tags: [userdata] - summary: Get publish info for a userdata file - description: "[cloud-only] Returns the publish status and share info for a userdata workflow file." - x-runtime: [cloud] - parameters: - - name: file - in: path - required: true - schema: - type: string - description: File path relative to user data directory - responses: - "200": - description: Publish info (publish_time is null if never published) - content: - application/json: - schema: - $ref: "#/components/schemas/WorkflowPublishInfo" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Workflow not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - post: - operationId: postUserdataFilePublish - tags: [userdata] - summary: Publish a userdata file to the cloud - description: "[cloud-only] Makes a userdata file available via a public URL for sharing or embedding." 
- x-runtime: [cloud] - parameters: - - name: file - in: path - required: true - schema: - type: string - description: File path relative to user data directory - responses: - "200": - description: Published file URL - content: - application/json: - schema: - type: object - properties: - url: - type: string - format: uri - description: Public URL of the published file - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '400': - description: Bad request - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/vhs/queryvideo: - get: - operationId: getVhsQueryVideo - tags: [view] - summary: Query VHS video metadata - description: "[cloud-only] Returns metadata about a video file processed by the VHS (Video Helper Suite) integration." - x-runtime: [cloud] - parameters: - - name: filename - in: query - required: true - schema: - type: string - description: Video filename - - name: type - in: query - schema: - type: string - enum: [input, output, temp] - description: Directory type - - name: subfolder - in: query - schema: - type: string - description: Subfolder within the directory - responses: - "200": - description: Video metadata - content: - application/json: - schema: - type: object - additionalProperties: true - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '400': - description: 'Missing required query parameter. 
Produced by the oapi-codegen - wrapper via echo.NewHTTPError, so the body shape matches Echo''s - default HTTPError serialization rather than ErrorResponse. - ' - content: - application/json: - schema: - $ref: '#/components/schemas/BindingErrorResponse' - /api/vhs/viewaudio: - get: - operationId: viewVhsAudio - tags: [view] - summary: View or download VHS audio - description: "[cloud-only] Returns audio content from a VHS-processed file." - x-runtime: [cloud] - parameters: - - name: filename - in: query - required: true - schema: - type: string - description: Audio filename - - name: type - in: query - schema: - type: string - enum: [input, output, temp] - description: Directory type - - name: subfolder - in: query - schema: - type: string - description: Subfolder within the directory - responses: - "200": - description: Audio content - content: - audio/*: - schema: - type: string - format: binary - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /api/vhs/viewvideo: - get: - operationId: viewVhsVideo - tags: [view] - summary: View or download VHS video - description: "[cloud-only] Returns video content from a VHS-processed file." 
- x-runtime: [cloud] - parameters: - - name: filename - in: query - required: true - schema: - type: string - description: Video filename - - name: type - in: query - schema: - type: string - enum: [input, output, temp] - description: Directory type - - name: subfolder - in: query - schema: - type: string - description: Subfolder within the directory - responses: - "200": - description: Video content - content: - video/*: - schema: - type: string - format: binary - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /api/viewvideo: - get: - operationId: viewVideo - tags: [view] - summary: View or download a video file - deprecated: true - description: | - **Deprecated.** This endpoint is an alias of `GET /api/view` added for - legacy history-queue video playback. Callers should use `/api/view` - directly; the endpoint is retained for backward compatibility but will - be removed in a future release. 
- x-runtime: [cloud] - parameters: - - name: filename - in: query - required: true - schema: - type: string - description: Video filename - - name: type - in: query - schema: - type: string - enum: [input, output, temp] - description: Directory type - - name: subfolder - in: query - schema: - type: string - description: Subfolder within the directory - responses: - "200": - description: Video content - content: - video/*: - schema: - type: string - format: binary - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - /api/tasks: - get: - operationId: listTasks - tags: [task] - summary: List background tasks - description: "[cloud-only] Retrieve a paginated list of background tasks for the authenticated user. Supports filtering by task type, status, and creation time." - x-runtime: [cloud] - parameters: - - name: task_name - in: query - schema: - type: string - description: Filter by task type name (exact match). - - name: idempotency_key - in: query - schema: - type: string - description: Filter by idempotency key (exact match). - - name: status - in: query - schema: - type: string - description: Filter by one or more statuses (comma-separated). - - name: created_after - in: query - schema: - type: string - format: date-time - description: Filter tasks created after this timestamp. - - name: created_before - in: query - schema: - type: string - format: date-time - description: Filter tasks created before this timestamp. - - name: sort_order - in: query - schema: - type: string - enum: [asc, desc] - default: desc - description: Sort direction by create_time. - - name: offset - in: query - schema: - type: integer - minimum: 0 - default: 0 - description: Pagination offset (0-based). 
- - name: limit - in: query - schema: - type: integer - minimum: 1 - maximum: 100 - default: 20 - description: Maximum items per page (1-100). - responses: - "200": - description: Tasks retrieved - content: - application/json: - schema: - $ref: "#/components/schemas/TasksListResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "422": - description: Validation error - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' - /api/tasks/{task_id}: - get: - operationId: getTask - tags: [task] - summary: Get task details - description: "[cloud-only] Retrieve full details for a specific background task." - x-runtime: [cloud] - parameters: - - name: task_id - in: path - required: true - schema: - type: string - format: uuid - description: Task identifier (UUID). - responses: - "200": - description: Task details - content: - application/json: - schema: - $ref: "#/components/schemas/TaskResponse" - "401": - description: Unauthorized - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - "404": - description: Task not found - content: - application/json: - schema: - $ref: "#/components/schemas/CloudError" - - '500': - description: Internal server error - content: - application/json: - schema: - $ref: '#/components/schemas/ErrorResponse' components: - parameters: - ComfyUserHeader: - name: Comfy-User - in: header - required: false - schema: - type: string - description: | - Identifies the active user in multi-user mode. Used for settings, - userdata, and history isolation. This is not a security mechanism — - it is an organisational convenience with no authentication behind it. 
- - schemas: - # ------------------------------------------------------------------- - # Prompt - # ------------------------------------------------------------------- - PromptRequest: - type: object - description: A workflow submission. Wraps the prompt graph plus optional client identifier and extra per-request data. - required: - - prompt - properties: - prompt: - type: object - description: | - The workflow graph to execute. Keys are node IDs (strings); - values are objects with class_type and inputs. - additionalProperties: true - number: - type: number - description: Priority number for the queue (lower numbers have higher priority) - front: - type: boolean - description: If true, adds the prompt to the front of the queue - extra_data: - type: object - description: Extra data associated with the prompt (e.g. extra_pnginfo) - additionalProperties: true - client_id: - type: string - description: WebSocket client ID to receive progress updates - prompt_id: - type: string - format: uuid - description: "Client-supplied prompt ID. Server generates a UUID if omitted." - partial_execution_targets: - type: array - items: - type: string - description: List of node IDs to execute (partial graph execution) - workflow_id: - type: string - format: uuid - nullable: true - x-runtime: [cloud] - description: "[cloud-only] Cloud workflow entity ID for tracking and gallery association. Ignored by local ComfyUI." - workflow_version_id: - type: string - format: uuid - nullable: true - x-runtime: [cloud] - description: "[cloud-only] Cloud workflow version ID for pinning execution to a specific version. Ignored by local ComfyUI." - - PromptResponse: - type: object - description: Server acknowledgement of a workflow submission. Includes the assigned `prompt_id` and current queue position. 
- properties: - prompt_id: - type: string - format: uuid - description: Unique identifier for the prompt execution - number: - type: number - description: Priority number in the queue - node_errors: - type: object - description: Validation errors keyed by node ID - additionalProperties: - $ref: "#/components/schemas/NodeError" - error: - description: Top-level prompt error (string message or structured error) - oneOf: - - type: string - - $ref: "#/components/schemas/PromptError" - - PromptErrorResponse: - type: object - description: Error response when prompt validation fails - additionalProperties: true - - PromptError: - type: object - description: Structured prompt validation error - properties: - type: - type: string - message: - type: string - details: - type: string - - Error: - type: object - description: Detailed node-level error - properties: - type: - type: string - message: - type: string - details: - type: string - extra_info: - type: object - properties: - input_name: - type: string - additionalProperties: true - - NodeError: - type: object - description: Error details for a single node - properties: - errors: - type: array - items: - $ref: "#/components/schemas/Error" - class_type: - type: string - description: The node's class type - dependent_outputs: - type: array - items: {} - - PromptInfo: - type: object - description: Summary of a queued or recently-executed prompt, as returned by the queue and history endpoints. 
- properties: - exec_info: - type: object - properties: - queue_remaining: - type: integer - description: Number of items remaining in the queue - - # ------------------------------------------------------------------- - # Queue - # ------------------------------------------------------------------- - QueueInfo: - type: object - description: Queue information with pending and running items - properties: - queue_running: - type: array - description: Currently running queue items - items: - type: array - description: | - Queue item tuple: [number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive] - items: {} - prefixItems: - - type: number - description: Priority number - - type: string - format: uuid - description: prompt_id - - type: object - description: prompt graph - additionalProperties: true - - type: object - description: extra_data - additionalProperties: true - - type: array - description: outputs_to_execute (list of output node IDs) - items: - type: string - - type: object - description: sensitive data (may be omitted) - additionalProperties: true - queue_pending: - type: array - description: Pending queue items (oldest first) - items: - type: array - description: | - Queue item tuple: [number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive] - items: {} - prefixItems: - - type: number - description: Priority number - - type: string - format: uuid - description: prompt_id - - type: object - description: prompt graph - additionalProperties: true - - type: object - description: extra_data - additionalProperties: true - - type: array - description: outputs_to_execute (list of output node IDs) - items: - type: string - - type: object - description: sensitive data (may be omitted) - additionalProperties: true - - QueueManageRequest: - type: object - description: Request to clear or delete from queue - properties: - clear: - type: boolean - description: If true, clear all pending items - delete: - type: array - items: - type: string - 
description: Array of prompt IDs to delete from queue - - QueueManageResponse: - type: object - x-runtime: [cloud] - description: >- - [cloud-only] Result of a queue mutation. The Cloud runtime returns which - items were deleted and whether the queue was cleared; local ComfyUI - returns an empty 200 body. - properties: - deleted: - type: array - nullable: true - items: - type: string - description: Prompt IDs that were deleted from the queue. - cleared: - type: boolean - nullable: true - description: Whether the queue was cleared. - - # ------------------------------------------------------------------- - # History - # ------------------------------------------------------------------- - HistoryEntry: - type: object - description: A single execution history entry - properties: - prompt: - type: array - description: | - Prompt tuple: [number, prompt_id, prompt_graph, extra_data, output_node_ids] - items: {} - outputs: - type: object - description: Output data from execution keyed by node ID - additionalProperties: true - status: - type: object - description: Execution status (status_str, completed, messages, etc.) 
- additionalProperties: true - meta: - type: object - description: Metadata about the execution and nodes - additionalProperties: true - - HistoryManageRequest: - type: object - description: Request to clear or delete history entries - properties: - clear: - type: boolean - description: If true, clear all history - delete: - type: array - items: - type: string - description: Array of prompt IDs to delete from history - - # ------------------------------------------------------------------- - # Jobs - # ------------------------------------------------------------------- - JobEntry: - type: object - description: Lightweight job data for list views - required: - - id - - status - properties: - id: - type: string - format: uuid - description: Unique job identifier (same as prompt_id) - status: - type: string - enum: - - pending - - in_progress - - completed - - failed - - cancelled - description: Current job status - create_time: - type: integer - format: int64 - description: Job creation timestamp (Unix milliseconds). - execution_start_time: - type: integer - format: int64 - description: Workflow execution start timestamp (Unix milliseconds, terminal states only). - execution_end_time: - type: integer - format: int64 - description: Workflow execution end timestamp (Unix milliseconds, terminal states only). - preview_output: - type: object - additionalProperties: true - description: Primary preview output - outputs_count: - type: integer - description: Total number of output files - workflow_id: - type: string - nullable: true - x-runtime: [cloud] - description: "[cloud-only] UUID of the Cloud workflow entity this job is associated with. Local ComfyUI returns null." - execution_error: - x-runtime: [cloud] - description: "[cloud-only] Detailed execution error from ComfyUI for failed jobs. Absent on local ComfyUI." 
- allOf: - - $ref: "#/components/schemas/ExecutionError" - - JobDetailResponse: - type: object - description: Full job details including workflow and outputs - required: - - id - - status - properties: - id: - type: string - format: uuid - status: - type: string - enum: - - pending - - in_progress - - completed - - failed - - cancelled - workflow: - type: object - additionalProperties: true - description: Full ComfyUI workflow - outputs: - type: object - additionalProperties: true - description: Full outputs object from execution - execution_error: - $ref: "#/components/schemas/ExecutionError" - create_time: - type: integer - format: int64 - description: Job creation timestamp (Unix milliseconds). - update_time: - type: integer - format: int64 - description: Last state-change timestamp (Unix milliseconds). - execution_start_time: - type: integer - format: int64 - description: Workflow execution start timestamp (Unix milliseconds, terminal states only). - execution_end_time: - type: integer - format: int64 - description: Workflow execution end timestamp (Unix milliseconds, terminal states only). 
- preview_output: - type: object - additionalProperties: true - outputs_count: - type: integer - execution_status: - type: object - additionalProperties: true - execution_meta: - type: object - additionalProperties: true - - ExecutionError: - type: object - description: Detailed execution error from ComfyUI - properties: - node_id: - type: string - description: ID of the node that failed - node_type: - type: string - description: Type name of the node - exception_message: - type: string - description: Human-readable error message - exception_type: - type: string - description: Python exception type - traceback: - type: array - items: - type: string - description: Traceback lines - current_inputs: - type: object - additionalProperties: true - current_outputs: - type: object - additionalProperties: true - - PaginationInfo: - type: object - description: Pagination metadata returned alongside list responses. - properties: - offset: - type: integer - limit: - type: integer - total: - type: integer - has_more: - type: boolean - - # ------------------------------------------------------------------- - # Upload / View - # ------------------------------------------------------------------- - UploadResult: - type: object - description: Response body returned by the image/mask upload endpoints, describing where the uploaded file now lives. 
- properties: - name: - type: string - description: Saved filename (may be renamed to avoid collisions) - subfolder: - type: string - description: Subfolder the file was saved to - type: - type: string - description: Directory type (input, temp) - - # ------------------------------------------------------------------- - # System - # ------------------------------------------------------------------- - DeviceStats: - type: object - description: GPU/compute device statistics - required: - - name - - type - - index - properties: - name: - type: string - description: Device name - type: - type: string - description: Device type (cuda, mps, cpu, etc.) - index: - type: number - nullable: true - description: | - Device index within its type (e.g. CUDA ordinal for `cuda:0`, - `cuda:1`). `null` for devices with no index, including the CPU - device returned in `--cpu` mode (PyTorch's `torch.device('cpu').index` - is `None`). - vram_total: - type: number - description: Total VRAM in bytes - vram_free: - type: number - description: Free VRAM in bytes - torch_vram_total: - type: number - description: Total PyTorch-managed VRAM in bytes - torch_vram_free: - type: number - description: Free PyTorch-managed VRAM in bytes - - SystemStatsResponse: - type: object - description: Hardware, VRAM, Python, and ComfyUI version information for the running process. 
- required: - - system - - devices - properties: - system: - type: object - required: - - os - - python_version - - embedded_python - - comfyui_version - - pytorch_version - - argv - - ram_total - - ram_free - properties: - os: - type: string - description: Operating system - python_version: - type: string - description: Python version - embedded_python: - type: boolean - description: Whether using embedded Python - comfyui_version: - type: string - description: ComfyUI version string - pytorch_version: - type: string - description: PyTorch version - required_frontend_version: - type: string - description: Required frontend version - argv: - type: array - items: - type: string - description: Command line arguments - ram_total: - type: number - description: Total RAM in bytes - ram_free: - type: number - description: Free RAM in bytes - installed_templates_version: - type: string - nullable: true - description: Version of the currently installed workflow templates - required_templates_version: - type: string - nullable: true - description: Minimum required workflow templates version for this ComfyUI build - comfy_package_versions: - type: array - description: Installed and required versions for every comfy* package pinned in requirements.txt - items: - type: object - required: - - name - - installed - - required - properties: - name: + schemas: + Asset: + description: Represents a user-owned asset (image, video, or other generated output). + properties: + asset_hash: + deprecated: true + description: 'Deprecated: use hash instead. Blake3 hash of the asset content.' + pattern: ^blake3:[a-f0-9]{64}$ type: string - installed: + created_at: + description: Timestamp when the asset was created + format: date-time type: string + display_name: + description: Display name of the asset. Mirrors name for backwards compatibility. nullable: true + type: string + hash: + description: Blake3 hash of the asset content. Preferred over asset_hash. 
+ pattern: ^blake3:[a-f0-9]{64}$ + type: string + id: + description: Unique identifier for the asset + format: uuid + type: string + is_immutable: + description: Whether this asset is immutable (cannot be modified or deleted) + type: boolean + job_id: + description: ID of the job that created this asset, if available + format: uuid + nullable: true + type: string + last_access_time: + description: Timestamp when the asset was last accessed + format: date-time + type: string + metadata: + additionalProperties: true + description: System-managed metadata from download sources (HuggingFace, CivitAI, etc.) - read-only, not user-modifiable + readOnly: true + type: object + mime_type: + description: MIME type of the asset + type: string + name: + description: Name of the asset file + type: string + preview_id: + description: ID of the preview asset if available + format: uuid + nullable: true + type: string + preview_url: + description: URL for asset preview/thumbnail + format: uri + type: string + size: + description: Size of the asset in bytes + format: int64 + type: integer + tags: + description: Tags associated with the asset + items: + type: string + type: array + updated_at: + description: Timestamp when the asset was last updated + format: date-time + type: string + user_metadata: + additionalProperties: true + description: Custom user metadata for the asset + type: object + required: + - id + - name + - created_at + - updated_at + type: object + AssetCreated: + allOf: + - $ref: '#/components/schemas/Asset' + - properties: + created_new: + description: Whether this was a new asset creation (true) or returned existing (false) + type: boolean required: + - created_new + type: object + description: Response returned when a new asset is successfully created. + AssetInfo: + description: Lightweight asset reference used in workflow publishing payloads. + properties: + id: + description: Asset identifier. 
type: string + in_library: + description: Whether the caller already owns this asset. + type: boolean + model: + description: Whether this asset is a model. + type: boolean + name: + type: string + preview_url: + description: Signed URL for previewing the asset. + type: string + public: + description: Whether this is a public (platform-provided) asset. + type: boolean + storage_url: + type: string + required: + - id + - name + - preview_url + - storage_url + - model + - public + - in_library + type: object + AssetTagHistogramResponse: + description: Histogram of tag counts used for refining asset search results. + properties: + tag_counts: + additionalProperties: + type: integer + description: Map of tag names to their occurrence counts on matching assets + example: + checkpoint: 32 + lora: 193 + vae: 6 + type: object + required: + - tag_counts + type: object + AssetUpdated: + description: Response returned when an existing asset is successfully updated. + properties: + asset_hash: + deprecated: true + description: 'Deprecated: use hash instead. Blake3 hash of the asset content.' + pattern: ^blake3:[a-f0-9]{64}$ + type: string + display_name: + description: Display name of the asset. Mirrors name for backwards compatibility. nullable: true - devices: - type: array - items: - $ref: "#/components/schemas/DeviceStats" - - # ------------------------------------------------------------------- - # Node / Object Info - # ------------------------------------------------------------------- - NodeInfo: - type: object - description: 'Definition of a registered node class: its inputs, outputs, category, and display metadata.' 
- properties: - input: - type: object - description: Input specifications (required and optional groups) - additionalProperties: true - input_order: - type: object - description: Ordered input names per group - additionalProperties: - type: array - items: - type: string - output: - type: array - items: - type: string - description: Output type names - output_is_list: - type: array - items: - type: boolean - description: Whether each output is a list - output_name: - type: array - items: - type: string - description: Display names of outputs - name: - type: string - description: Internal class name - display_name: - type: string - description: Human-readable display name - description: - type: string - description: Node description - python_module: - type: string - description: Python module implementing the node - category: - type: string - description: Node category path - output_node: - type: boolean - description: Whether this is an output node - output_tooltips: - type: array - items: - type: string - description: Tooltips for each output - deprecated: - type: boolean - description: Whether the node is deprecated - experimental: - type: boolean - description: Whether the node is experimental - api_node: - type: boolean - description: Whether this is an API node - is_input_list: - type: boolean - description: Whether the node accepts list inputs - dev_only: - type: boolean - description: Whether the node is developer-only (hidden in production UI) - has_intermediate_output: - type: boolean - description: Whether the node emits intermediate output during execution - search_aliases: - type: array - items: - type: string - description: Alternative search terms for finding this node - essentials_category: - type: string - nullable: true - description: | - Category override used by the essentials pack. 
The - `essentials_category` key may be present with a string value, - present and `null`, or absent entirely: - - - V1 nodes: `essentials_category` is **omitted** when the node - class doesn't define an `ESSENTIALS_CATEGORY` attribute, and - **`null`** if the attribute is explicitly set to `None`. - - V3 nodes (`comfy_api.latest.io`): `essentials_category` is - **always present**, and **`null`** for nodes whose `Schema` - doesn't populate it. - - # ------------------------------------------------------------------- - # Models - # ------------------------------------------------------------------- - ModelFolder: - type: object - description: A configured model folder and the list of disk paths it resolves to. - required: - - name - - folders - properties: - name: - type: string - description: Model folder type name (e.g. "checkpoints") - folders: - type: array - items: - type: string - description: Filesystem paths for this model type - - ModelFile: - type: object - description: A single model file in a folder, with filesystem metadata. 
- required: - - name - - pathIndex - properties: - name: - type: string - description: Model filename - pathIndex: - type: integer - description: Index into the folder's paths array - modified: - type: number - description: File modification timestamp - created: - type: number - description: File creation timestamp - size: - type: integer - format: int64 - description: File size in bytes - - # ------------------------------------------------------------------- - # Subgraphs - # ------------------------------------------------------------------- - GlobalSubgraphInfo: - type: object - description: Metadata for a global subgraph blueprint (without full data) - required: - - source - - name - - info - properties: - source: - type: string - description: Source type ("templates" or "custom_node") - name: - type: string - description: Display name of the subgraph blueprint - info: - type: object - description: Additional information about the subgraph - required: - - node_pack - properties: - node_pack: - type: string - description: The node pack/module providing this subgraph - data: - type: string - description: The full subgraph JSON data (may be empty in list view) - - GlobalSubgraphData: - type: object - description: Full data for a global subgraph blueprint - required: - - source - - name - - info - - data - properties: - source: - type: string - description: Source type ("templates" or "custom_node") - name: - type: string - description: Display name of the subgraph blueprint - info: - type: object - description: Additional information about the subgraph - required: - - node_pack - properties: - node_pack: - type: string - description: The node pack/module providing this subgraph - data: - type: string - description: The full subgraph JSON data as a string - - # ------------------------------------------------------------------- - # Userdata - # ------------------------------------------------------------------- - UserDataResponse: - description: | - Response body 
for the POST endpoints `/api/userdata/{file}` and - `/api/userdata/{file}/move/{dest}`. Returns a single item whose - shape depends on the `full_info` query parameter. - x-variant-selector: - full_info=true: file-info object (`GetUserDataResponseFullFile`) - default: relative path string - oneOf: - - $ref: "#/components/schemas/GetUserDataResponseFullFile" - - type: string - description: Relative path of the written or moved file. Returned when `full_info` is absent or false. - - ListUserdataResponse: - description: | - Response body for `GET /api/userdata`. The array item shape is - determined by the `full_info` and `split` query parameters. - x-variant-selector: - full_info=true: array of file-info objects (`GetUserDataResponseFullFile`) - split=true: array of `[relative_path, ...path_components]` arrays - default: array of relative path strings - oneOf: - - type: array - items: - $ref: "#/components/schemas/GetUserDataResponseFullFile" - description: Returned when `full_info=true`. - - type: array - items: - type: array - items: - type: string - minItems: 2 - description: | - Returned when `split=true` and `full_info=false`. Each inner - array is `[relative_path, ...path_components]`. - - type: array - items: - type: string - description: Default shape — array of file paths relative to the user data root. - - GetUserDataResponseFullFile: - type: object - description: A single entry in a full-info user data listing. 
- properties: - path: - type: string - description: File name or path relative to the user directory - created: - type: number - description: Unix timestamp of file creation - size: - type: integer - description: File size in bytes - modified: - type: integer - format: int64 - description: Unix timestamp of last modification in milliseconds - - # ------------------------------------------------------------------- - # Assets - # ------------------------------------------------------------------- - Asset: - type: object - description: A registered asset — an input/output file tracked in the asset database with content hash and metadata. - required: - - id - - name - - size - - created_at - - updated_at - properties: - id: - type: string - format: uuid - description: Unique identifier for the asset - name: - type: string - description: Name of the asset file - hash: - type: string - nullable: true - description: Blake3 content hash of the asset (preferred over asset_hash) - pattern: "^blake3:[a-f0-9]{64}$" - asset_hash: - type: string - nullable: true - deprecated: true - description: "Deprecated: use `hash` instead. Blake3 hash of the asset content." - pattern: "^blake3:[a-f0-9]{64}$" - size: - type: integer - format: int64 - description: Size of the asset in bytes - mime_type: - type: string - description: MIME type of the asset - tags: - type: array - items: - type: string - description: Tags associated with the asset - user_metadata: - type: object - description: Custom user metadata - additionalProperties: true - metadata: - type: object - description: System-managed metadata (read-only) - additionalProperties: true - readOnly: true - preview_url: - type: string - format: uri - description: URL for asset preview/thumbnail - preview_id: - type: string - format: uuid - description: ID of the preview asset if available - prompt_id: - type: string - format: uuid - nullable: true - deprecated: true - description: "Deprecated: use job_id instead. 
ID of the prompt that created this asset." - job_id: - type: string - format: uuid - nullable: true - description: ID of the job that created this asset - created_at: - type: string - format: date-time - updated_at: - type: string - format: date-time - last_access_time: - type: string - format: date-time - is_immutable: - type: boolean - description: Whether this asset is immutable - - AssetCreated: - description: Response body returned after successfully registering a new asset. - allOf: - - $ref: "#/components/schemas/Asset" - - type: object - required: - - created_new - properties: - created_new: - type: boolean - description: Whether this was a new creation (true) or returned existing (false) - - AssetUpdated: - type: object - description: Response body returned after updating an asset's metadata. - required: - - id - - updated_at - properties: - id: - type: string - format: uuid - name: - type: string - hash: - type: string - nullable: true - description: Blake3 content hash of the asset (preferred over asset_hash) - pattern: "^blake3:[a-f0-9]{64}$" - asset_hash: - type: string - nullable: true - deprecated: true - description: "Deprecated: use `hash` instead. Blake3 hash of the asset content." - pattern: "^blake3:[a-f0-9]{64}$" - tags: - type: array - items: - type: string - mime_type: - type: string - user_metadata: - type: object - additionalProperties: true - prompt_id: - type: string - format: uuid - nullable: true - deprecated: true - description: "Deprecated: use job_id instead. ID of the prompt that created this asset." - job_id: - type: string - format: uuid - nullable: true - description: ID of the job that created this asset - updated_at: - type: string - format: date-time - - ListAssetsResponse: - type: object - description: Paginated list of assets. 
- required: - - assets - - total - - has_more - properties: - assets: - type: array - items: - $ref: "#/components/schemas/Asset" - total: - type: integer - has_more: - type: boolean - - TagInfo: - type: object - description: A tag known to the asset database, with the number of assets bearing it. - required: - - name - - count - properties: - name: - type: string - count: - type: integer - - ListTagsResponse: - type: object - description: Flat list of all tags, with counts. - required: - - tags - - total - - has_more - properties: - tags: - type: array - items: - $ref: "#/components/schemas/TagInfo" - total: - type: integer - has_more: - type: boolean - - AssetTagHistogramResponse: - type: object - description: Tags that would refine a filtered asset query, with the count of assets each tag would additionally select. - required: - - tag_counts - properties: - tag_counts: - type: object - additionalProperties: - type: integer - description: Map of tag names to occurrence counts - - TagsModificationResponse: - type: object - description: Response body returned after adding or removing tags on an asset. 
- required: - - total_tags - properties: - added: - type: array - items: - type: string - description: Tags successfully added - removed: - type: array - items: - type: string - description: Tags successfully removed - already_present: - type: array - items: - type: string - description: Tags already present (for add) - not_present: - type: array - items: - type: string - description: Tags not present (for remove) - total_tags: - type: array - items: - type: string - description: All tags on the asset after the operation - - # ------------------------------------------------------------------- - # Result / Output types - # ------------------------------------------------------------------- - ResultItem: - type: object - description: A single output file reference - properties: - filename: - type: string - subfolder: - type: string - type: - type: string - enum: [input, output, temp] - display_name: - type: string - - NodeOutputs: - type: object - description: | - Outputs from a single node execution. Known keys are listed below, - but custom nodes may add arbitrary keys (additionalProperties). - properties: - images: - type: array - items: - $ref: "#/components/schemas/ResultItem" - audio: - type: array - items: - $ref: "#/components/schemas/ResultItem" - video: - type: array - items: - $ref: "#/components/schemas/ResultItem" - animated: - type: array - items: - type: boolean - text: - oneOf: - - type: string - - type: array - items: - type: string - additionalProperties: true - - TerminalSize: - type: object - description: Terminal dimensions - properties: - cols: - type: number - row: - type: number - - LogEntry: - type: object - description: A single log entry - properties: - t: - type: string - description: Timestamp - m: - type: string - description: Log message - - StatusWsMessageStatus: - type: object - description: Inner payload of a `status` WebSocket message, describing the execution queue state. 
- properties: - exec_info: - type: object - required: - - queue_remaining - properties: - queue_remaining: - type: integer - - StatusWsMessage: - type: object - description: Initial status message sent on connect + queue status updates - properties: - status: - $ref: "#/components/schemas/StatusWsMessageStatus" - sid: - type: string - description: Session ID assigned by the server - - ProgressWsMessage: - type: object - description: Node execution progress (step N of M) - required: - - value - - max - - prompt_id - - node - properties: - value: - type: integer - description: Current step - max: - type: integer - description: Total steps - prompt_id: - type: string - node: - type: string - description: Node ID currently executing - - ProgressTextWsMessage: - type: object - description: Text-based progress update from a node - properties: - nodeId: - type: string - text: - type: string - prompt_id: - type: string - - NodeProgressState: - type: object - description: Progress state for a single node - properties: - value: - type: number - max: - type: number - state: - type: string - enum: [pending, running, finished, error] - node_id: - type: string - prompt_id: - type: string - display_node_id: - type: string - parent_node_id: - type: string - real_node_id: - type: string - - ProgressStateWsMessage: - type: object - description: Bulk progress state for all nodes in a prompt - required: - - prompt_id - - nodes - properties: - prompt_id: - type: string - nodes: - type: object - description: Map of node ID to progress state - additionalProperties: - $ref: "#/components/schemas/NodeProgressState" - - ExecutingWsMessage: - type: object - description: Fired when a node begins execution - required: - - node - - display_node - - prompt_id - properties: - node: - type: string - description: Node ID - display_node: - type: string - description: Display node ID (may differ for subgraphs) - prompt_id: - type: string - - ExecutedWsMessage: - type: object - description: Fired when 
a node completes execution with output - required: - - node - - display_node - - prompt_id - - output - properties: - node: - type: string - display_node: - type: string - prompt_id: - type: string - output: - $ref: "#/components/schemas/NodeOutputs" - merge: - type: boolean - description: Whether to merge with existing output - - ExecutionWsMessageBase: - type: object - description: Base fields for execution lifecycle messages - required: - - prompt_id - - timestamp - properties: - prompt_id: - type: string - timestamp: - type: integer - description: Unix timestamp in milliseconds - - ExecutionStartWsMessage: - allOf: - - $ref: "#/components/schemas/ExecutionWsMessageBase" - description: Fired when prompt execution begins - - ExecutionSuccessWsMessage: - allOf: - - $ref: "#/components/schemas/ExecutionWsMessageBase" - description: Fired when prompt execution completes successfully - - ExecutionCachedWsMessage: - allOf: - - $ref: "#/components/schemas/ExecutionWsMessageBase" - - type: object - properties: - nodes: - type: array - items: - type: string - description: List of node IDs that were cached - description: Fired when nodes are served from cache - - ExecutionInterruptedWsMessage: - allOf: - - $ref: "#/components/schemas/ExecutionWsMessageBase" - - type: object - properties: - node_id: - type: string - node_type: - type: string - executed: - type: array - items: - type: string - description: Node IDs that completed before interruption - description: Fired when execution is interrupted by user - - ExecutionErrorWsMessage: - allOf: - - $ref: "#/components/schemas/ExecutionWsMessageBase" - - type: object - properties: - node_id: - type: string - node_type: - type: string - executed: - type: array - items: - type: string - exception_message: - type: string - exception_type: - type: string - traceback: - type: array - items: - type: string - current_inputs: {} - current_outputs: {} - description: Fired when a node throws an exception during execution - - 
LogsWsMessage: - type: object - description: Streaming log entries from the server - properties: - size: - $ref: "#/components/schemas/TerminalSize" - entries: - type: array - items: - $ref: "#/components/schemas/LogEntry" - - NotificationWsMessage: - type: object - description: Server notification (e.g. model download complete) - properties: - value: - type: string - id: - type: string - - FeatureFlagsWsMessage: - type: object - description: Feature flags sent on connect - additionalProperties: true - - AssetDownloadWsMessage: - type: object - description: Asset download progress - required: - - task_id - - asset_name - - bytes_total - - bytes_downloaded - - progress - - status - properties: - task_id: - type: string - asset_name: - type: string - bytes_total: - type: number - bytes_downloaded: - type: number - progress: - type: number - description: 0.0 to 1.0 - status: - type: string - enum: [created, running, completed, failed] - asset_id: - type: string - error: - type: string - - AssetExportWsMessage: - type: object - description: Bulk asset export progress - required: - - task_id - - assets_total - - assets_attempted - - assets_failed - - bytes_total - - bytes_processed - - progress - - status - properties: - task_id: - type: string - export_name: - type: string - assets_total: - type: number - assets_attempted: - type: number - assets_failed: - type: number - bytes_total: - type: number - bytes_processed: - type: number - progress: - type: number - description: 0.0 to 1.0 - status: - type: string - enum: [created, running, completed, failed] - error: - type: string - - # ------------------------------------------------------------------- - # Cloud-runtime schemas - # - # These schemas are exclusively referenced by cloud-runtime operations. - # Tagged x-runtime: [cloud]. 
- # ------------------------------------------------------------------- - CloudError: - type: object - x-runtime: [cloud] - description: "[cloud-only] Standard error response from cloud endpoints." - required: - - error - properties: - error: - type: string - description: Error message - code: - type: string - description: Machine-readable error code - details: - type: object - additionalProperties: true - description: Additional error context - - CloudJobStatus: - type: object - x-runtime: [cloud] - description: "[cloud-only] Status of a cloud job." - required: - - id - - status - properties: - id: - type: string - format: uuid - status: - type: string - enum: [pending, running, completed, failed, cancelled] - progress: - type: number - minimum: 0 - maximum: 1 - description: "Execution progress (0.0 to 1.0)" - started_at: - type: string - format: date-time - nullable: true - completed_at: - type: string - format: date-time - nullable: true - - CloudPrompt: - type: object - x-runtime: [cloud] - description: "[cloud-only] A cloud-executed prompt record." - required: - - id - - status - properties: - id: - type: string - format: uuid - status: - type: string - workflow: - type: object - additionalProperties: true - outputs: - type: object - additionalProperties: true - created_at: - type: string - format: date-time - completed_at: - type: string - format: date-time - nullable: true - - HistoryV2Response: - type: object - x-runtime: [cloud] - description: "[cloud-only] Paginated execution history in v2 format." - required: - - items - - total - - has_more - properties: - items: - type: array - items: - $ref: "#/components/schemas/HistoryV2Entry" - total: - type: integer - has_more: - type: boolean - - HistoryV2Entry: - type: object - x-runtime: [cloud] - description: "[cloud-only] A single execution history entry in v2 format." 
- required: - - id - - status - properties: - id: - type: string - format: uuid - status: - type: string - workflow: - type: object - additionalProperties: true - outputs: - type: object - additionalProperties: true - created_at: - type: string - format: date-time - started_at: - type: string - format: date-time - nullable: true - completed_at: - type: string - format: date-time - nullable: true - preview_output: - type: object - additionalProperties: true - - CloudLogsResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] Paginated cloud execution logs." - required: - - entries - properties: - entries: - type: array - items: - type: object - properties: - timestamp: - type: string - format: date-time - level: - type: string - enum: [debug, info, warn, error] - message: - type: string - job_id: - type: string - format: uuid - total: - type: integer - has_more: - type: boolean - - AssetDownloadRequest: - type: object - x-runtime: [cloud] - description: "[cloud-only] A single asset to download to the cloud runtime." - required: - - asset_id - properties: - asset_id: - type: string - format: uuid - description: ID of the asset to download - target_path: - type: string - description: Target path on the runtime filesystem - - ImportPublishedAssetsRequest: - type: object - x-runtime: [cloud] - description: "[cloud-only] Request body for importing published assets into the caller's library." - required: - - published_asset_ids - properties: - published_asset_ids: - type: array - description: IDs of published assets (inputs and models) to import. - items: - type: string - share_id: - type: string - nullable: true - description: | - Optional. Share ID of the published workflow these assets belong to. When provided (non-null, non-empty): all `published_asset_ids` must belong to this share's workflow version; returns 400 if the share is not found or any asset does not belong to it. 
When omitted, null, or empty string: no share-scoped validation is performed and the assets are validated only against global rules (preserved for clients that have not yet adopted `share_id`). - - ImportPublishedAssetsResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] Response after importing published assets. Each returned `AssetInfo.id` is the caller's newly-created private asset ID, not the published asset ID supplied in the request." - required: - - assets - properties: - assets: - type: array - items: - $ref: "#/components/schemas/AssetInfo" - - RemoteAssetMetadata: - type: object - x-runtime: [cloud] - description: "[cloud-only] Metadata fetched from a remote asset URL." - properties: - content_type: - type: string - description: MIME type of the remote file - content_length: - type: integer - format: int64 - description: Size in bytes - filename: - type: string - description: Suggested filename from Content-Disposition or URL - - CloudNode: - type: object - x-runtime: [cloud] - description: "[cloud-only] An installed custom node package in the cloud runtime." - required: - - id - - name - properties: - id: - type: string - name: - type: string - version: - type: string - description: - type: string - author: - type: string - repository: - type: string - format: uri - installed_at: - type: string - format: date-time - enabled: - type: boolean - - HubLabel: - type: object - x-runtime: [cloud] - description: "[cloud-only] A label/category used for tagging hub content." - required: - - id - - name - properties: - id: - type: string - name: - type: string - description: - type: string - color: - type: string - description: Hex color code for the label - - HubProfile: - type: object - x-runtime: [cloud] - description: "[cloud-only] A public user profile on the ComfyUI Hub." 
- required: - - username - properties: - username: - type: string - display_name: - type: string - bio: - type: string - avatar_url: - type: string - format: uri - links: - type: array - items: - type: string - format: uri - workflow_count: - type: integer - created_at: - type: string - format: date-time - - HubWorkflow: - type: object - x-runtime: [cloud] - description: "[cloud-only] A published workflow on the ComfyUI Hub." - required: - - share_id - - name - properties: - share_id: - type: string - name: - type: string - description: - type: string - author: - $ref: "#/components/schemas/HubProfile" - labels: - type: array - items: - $ref: "#/components/schemas/HubLabel" - thumbnail_url: - type: string - format: uri - content: - type: object - additionalProperties: true - description: Workflow graph JSON - likes: - type: integer - views: - type: integer - forks: - type: integer - created_at: - type: string - format: date-time - updated_at: - type: string - format: date-time - - HubWorkflowList: - type: object - x-runtime: [cloud] - description: "[cloud-only] Paginated list of hub workflows." - required: - - workflows - - total - - has_more - properties: - workflows: - type: array - items: - $ref: "#/components/schemas/HubWorkflow" - total: - type: integer - has_more: - type: boolean - - HubWorkflowIndexEntry: - type: object - x-runtime: [cloud] - description: "[cloud-only] Lightweight entry in the hub workflow index for client-side search." - required: - - share_id - - name - properties: - share_id: - type: string - name: - type: string - author_username: - type: string - labels: - type: array - items: - type: string - likes: - type: integer - updated_at: - type: string - format: date-time - - CloudWorkflow: - type: object - x-runtime: [cloud] - description: "[cloud-only] A cloud-managed workflow with version history." 
- required: - - id - - name - properties: - id: - type: string - format: uuid - name: - type: string - description: - type: string - share_id: - type: string - nullable: true - description: Public share identifier if published - latest_version_id: - type: string - format: uuid - nullable: true - thumbnail_url: - type: string - format: uri - nullable: true - created_at: - type: string - format: date-time - updated_at: - type: string - format: date-time - - CloudWorkflowList: - type: object - x-runtime: [cloud] - description: "[cloud-only] Paginated list of cloud workflows." - required: - - workflows - - total - - has_more - properties: - workflows: - type: array - items: - $ref: "#/components/schemas/CloudWorkflow" - total: - type: integer - has_more: - type: boolean - - CloudWorkflowVersion: - type: object - x-runtime: [cloud] - description: "[cloud-only] A version of a cloud workflow." - required: - - id - - workflow_id - properties: - id: - type: string - format: uuid - workflow_id: - type: string - format: uuid - version_number: - type: integer - created_at: - type: string - format: date-time - - AuthSession: - type: object - x-runtime: [cloud] - description: "[cloud-only] Current authentication session state." - required: - - user - properties: - user: - $ref: "#/components/schemas/CloudUser" - workspace: - $ref: "#/components/schemas/Workspace" - expires_at: - type: string - format: date-time - - AuthTokenResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] OAuth2 token response." - required: - - access_token - - token_type - properties: - access_token: - type: string - token_type: - type: string - description: Always "Bearer" - expires_in: - type: integer - description: Token lifetime in seconds - refresh_token: - type: string - nullable: true - scope: - type: string - - JwksResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] JSON Web Key Set for JWT verification." 
- required: - - keys - properties: - keys: - type: array - items: - type: object + type: string + hash: + description: Blake3 hash of the asset content. Preferred over asset_hash. + pattern: ^blake3:[a-f0-9]{64}$ + type: string + id: + description: Asset ID + format: uuid + type: string + job_id: + description: ID of the job that created this asset, if available + format: uuid + nullable: true + type: string + mime_type: + description: Updated MIME type of the asset + type: string + name: + description: Updated name of the asset + type: string + tags: + description: Tags associated with the asset + items: + type: string + type: array + updated_at: + description: Timestamp of the update + format: date-time + type: string + user_metadata: + additionalProperties: true + description: Updated custom metadata + type: object required: - - kty - - kid - - use + - id + - updated_at + type: object + CreateWorkflowRequest: + description: Request body for creating a new saved workflow. properties: - kty: - type: string - description: Key type (e.g. RSA) - kid: - type: string - description: Key ID - use: - type: string - description: Key use (e.g. sig) - alg: - type: string - description: Algorithm (e.g. RS256) - n: - type: string - description: RSA modulus (base64url) - e: - type: string - description: RSA exponent (base64url) - additionalProperties: true - - OAuthAuthorizationServerMetadata: - type: object - x-runtime: [cloud] - description: "[cloud-only] OAuth 2.1 authorization-server metadata (RFC 8414)." 
- required: - - issuer - - authorization_endpoint - - token_endpoint - - jwks_uri - - response_types_supported - - grant_types_supported - - code_challenge_methods_supported - - token_endpoint_auth_methods_supported - properties: - issuer: - type: string - format: uri - authorization_endpoint: - type: string - format: uri - token_endpoint: - type: string - format: uri - jwks_uri: - type: string - format: uri - registration_endpoint: - type: string - format: uri - description: "[cloud-only] RFC 7591 §3.1 Dynamic Client Registration endpoint. Advertised so MCP-spec-compliant clients can auto-discover and self-register without operator involvement. Present only when DCR is enabled." - response_types_supported: - type: array - items: - type: string - grant_types_supported: - type: array - items: - type: string - code_challenge_methods_supported: - type: array - items: - type: string - token_endpoint_auth_methods_supported: - type: array - items: - type: string - scopes_supported: - type: array - items: - type: string - - OAuthProtectedResourceMetadata: - type: object - x-runtime: [cloud] - description: "[cloud-only] OAuth 2.1 protected-resource metadata (RFC 9728)." - required: - - resource - - authorization_servers - - scopes_supported - properties: - resource: - type: string - format: uri - authorization_servers: - type: array - items: - type: string - format: uri - scopes_supported: - type: array - items: - type: string - bearer_methods_supported: - type: array - items: - type: string - - OAuthConsentChallenge: - type: object - x-runtime: [cloud] - description: "[cloud-only] Server-side state describing the OAuth consent decision the user is being asked to make. Returned by GET /oauth/authorize when a valid session exists; the frontend renders the consent UI from this payload and POSTs the decision back. Browser never sees the original OAuth params on resume." 
- required: - - oauth_request_id - - csrf_token - - client_display_name - - resource_display_name - - scopes - - workspaces - properties: - oauth_request_id: - type: string - format: uuid - description: Opaque server-side identifier for the authorization-request row. Carried back unchanged in the consent submission. - csrf_token: - type: string - description: Per-row CSRF token bound to this authorization request (not to the session). Must be echoed back on POST. - client_display_name: - type: string - description: Human-readable name of the OAuth client requesting authorization. - resource_display_name: - type: string - description: Human-readable name of the protected resource. - scopes: - type: array - description: Scopes the client is requesting for this resource. The frontend should present these for the user to approve. - items: - type: string - workspaces: - type: array - description: Workspaces the user can select from. Membership is re-checked on POST. - items: - $ref: "#/components/schemas/OAuthConsentChallengeWorkspace" - - OAuthConsentChallengeWorkspace: - type: object - x-runtime: [cloud] - description: "[cloud-only] One workspace option presented in the OAuth consent challenge." - required: [id, name, type, role] - properties: - id: { type: string } - name: { type: string } - type: { type: string, enum: [personal, team] } - role: { type: string, enum: [owner, member] } - - OAuthAuthorizeRedirectResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] Redirect target produced after a JSON consent submission. The frontend must navigate the browser to this URL so custom-scheme client callbacks work without relying on fetch-visible 302 headers." - required: - - redirect_url - properties: - redirect_url: - type: string - format: uri - description: OAuth client redirect URI with either code+state for allow, or error+state for deny. 
- - OAuthTokenResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] RFC 6749 §5.1 successful token response." - required: [access_token, token_type, expires_in, refresh_token, scope] - properties: - access_token: - type: string - description: Resource-bound access token (audience matches the protected resource). - token_type: - type: string - enum: [Bearer] - expires_in: - type: integer - description: Access token lifetime in seconds. - refresh_token: - type: string - description: Opaque refresh token. Rotates on every successful refresh; presenting an already-rotated token revokes the entire family. - scope: - type: string - description: Space-delimited scopes granted with this token. - - OAuthTokenError: - type: object - x-runtime: [cloud] - description: "[cloud-only] RFC 6749 §5.2 error response." - required: [error] - properties: - error: - type: string - description: 'RFC 6749 §5.2 error code: invalid_request, invalid_client, invalid_grant, unauthorized_client, unsupported_grant_type, invalid_scope.' - error_description: - type: string - description: Human-readable, no leak of internal storage state. - - OAuthRegisterRequest: - type: object - x-runtime: [cloud] - additionalProperties: false - description: "[cloud-only] RFC 7591 §2 client metadata document. Only the fields the server honors are listed; presence of `scope` or `resource_grants` in the request is rejected (`invalid_client_metadata`) because those are server-owned for dynamic clients." - required: - - redirect_uris - - application_type - properties: - redirect_uris: - type: array - items: - type: string - minItems: 1 - maxItems: 5 - description: 1–5 redirect URIs. Validated against `application_type` policy. - client_name: - type: string - maxLength: 100 - description: Human-readable name shown in the consent UI. Reserved-name list rejects impersonation of major clients. - application_type: - type: string - enum: [native, web] - description: | - RFC 7591 §2 application_type. 
**REQUIRED** — clients MUST declare intent; the server does not default this field. `native` for desktop / CLI / MCP-spec-strict clients (loopback redirects); `web` for hosted clients (HTTPS only, host must be allowlisted). A missing or explicitly empty `application_type` rejects with `invalid_client_metadata`. - token_endpoint_auth_method: - type: string - enum: [none] - description: 'Public clients only this phase — must be `none` if present. The server forces `none` regardless.' - grant_types: - type: array - items: - type: string - enum: [authorization_code, refresh_token] - description: Optional. Defaults to `["authorization_code","refresh_token"]`. - response_types: - type: array - items: - type: string - enum: [code] - description: Optional. Defaults to `["code"]`. - scope: - type: string - nullable: true - description: "**REJECTED IF PRESENT.** Dynamic clients do not pick scopes — the server assigns scopes from the active resource's published list. Sending `scope` in the registration body is treated as a privilege-escalation attempt and returns `invalid_client_metadata`." - resource_grants: - type: object - nullable: true - additionalProperties: - type: array - items: - type: string - description: "**REJECTED IF PRESENT.** Same reason as `scope`. The set of resources and scopes a dynamic client may request is server-policy, not request-driven." - client_uri: - type: string - nullable: true - description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." - logo_uri: - type: string - nullable: true - description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." - tos_uri: - type: string - nullable: true - description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." - policy_uri: - type: string - nullable: true - description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." 
- software_id: - type: string - nullable: true - description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." - software_version: - type: string - nullable: true - description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." - contacts: - type: array - nullable: true - items: - type: string - description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." - jwks: - type: object - nullable: true - additionalProperties: true - description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." - jwks_uri: - type: string - nullable: true - description: "**REJECTED IF PRESENT.** Unsupported RFC 7591 metadata for this public-client phase." - - OAuthRegisterResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] RFC 7591 §3.2.1 successful registration response." - required: - - client_id - - client_id_issued_at - - redirect_uris - - grant_types - - response_types - - token_endpoint_auth_method - - application_type - properties: - client_id: - type: string - description: Server-generated client_id. - client_id_issued_at: - type: integer - format: int64 - description: Unix timestamp (seconds) when the client was registered. - client_name: - type: string - redirect_uris: - type: array - items: - type: string - grant_types: - type: array - items: - type: string - response_types: - type: array - items: - type: string - token_endpoint_auth_method: - type: string - enum: [none] - application_type: - type: string - enum: [native, web] - - OAuthRegisterError: - type: object - x-runtime: [cloud] - description: "[cloud-only] RFC 7591 §3.2.2 error response." 
- required: - - error - properties: - error: - type: string - enum: [invalid_redirect_uri, invalid_client_metadata] - error_description: - type: string - nullable: true - - BillingBalance: - type: object - x-runtime: [cloud] - description: "[cloud-only] Current credit balance and usage summary." - required: - - credits_remaining - properties: - credits_remaining: - type: integer - description: Available credits - credits_used: - type: integer - description: Credits used in current billing period - credits_total: - type: integer - description: Total credits allocated in current period - - BillingEvent: - type: object - x-runtime: [cloud] - description: "[cloud-only] A billing event (charge, credit, refund)." - required: - - id - - type - - amount - - created_at - properties: - id: - type: string - type: - type: string - enum: [charge, credit, refund, topup, subscription] - amount: - type: integer - description: Amount in credits - description: - type: string - job_id: - type: string - format: uuid - nullable: true - created_at: - type: string - format: date-time - - BillingEventList: - type: object - x-runtime: [cloud] - description: "[cloud-only] Paginated list of billing events." - required: - - events - - total - - has_more - properties: - events: - type: array - items: - $ref: "#/components/schemas/BillingEvent" - total: - type: integer - has_more: - type: boolean - - BillingOp: - type: object - x-runtime: [cloud] - description: "[cloud-only] A billing operation record." - required: - - id - - status - properties: - id: - type: string - status: - type: string - enum: [pending, completed, failed] - type: - type: string - amount: - type: integer - created_at: - type: string - format: date-time - completed_at: - type: string - format: date-time - nullable: true - - BillingPlan: - type: object - x-runtime: [cloud] - description: "[cloud-only] A subscription plan with pricing details." 
- required: - - id - - name - properties: - id: - type: string - name: - type: string - description: - type: string - credits_per_month: - type: integer - price_cents: - type: integer - description: Monthly price in cents (USD) - currency: - type: string - default: usd - features: - type: array - items: - type: string - description: List of plan features - - BillingStatus: - type: string - x-runtime: [cloud] - description: "[cloud-only] Overall billing/payment lifecycle status." - enum: - - awaiting_payment_method - - pending_payment - - paid - - payment_failed - - inactive - - BillingSubscription: - type: object - x-runtime: [cloud] - description: "[cloud-only] Active subscription details." - required: - - id - - status - - plan_id - properties: - id: - type: string - status: - type: string - enum: [active, cancelled, past_due, trialing] - plan_id: - type: string - plan_name: - type: string - current_period_start: - type: string - format: date-time - current_period_end: - type: string - format: date-time - cancel_at_period_end: - type: boolean - - SubscriptionPreview: - type: object - x-runtime: [cloud] - description: "[cloud-only] Preview of a subscription change including prorations." - properties: - plan_id: - type: string - plan_name: - type: string - amount_due: - type: integer - description: Amount due in cents - proration_amount: - type: integer - description: Proration adjustment in cents - currency: - type: string - next_billing_date: - type: string - format: date-time - - Workspace: - type: object - x-runtime: [cloud] - description: "[cloud-only] A cloud workspace for team collaboration." - required: - - id - - name - properties: - id: - type: string - name: - type: string - type: - type: string - enum: - - personal - - team - description: Workspace type (personal vs. team). 
- owner_id: - type: string - member_count: - type: integer - created_at: - type: string - format: date-time - updated_at: - type: string - format: date-time - - WorkspaceMember: - type: object - x-runtime: [cloud] - description: "[cloud-only] A member of a cloud workspace." - required: - - user_id - - role - properties: - user_id: - type: string - email: - type: string - format: email - display_name: - type: string - avatar_url: - type: string - format: uri - role: - type: string - enum: [owner, admin, member] - joined_at: - type: string - format: date-time - - WorkspaceInvite: - type: object - x-runtime: [cloud] - description: "[cloud-only] A pending workspace invitation." - required: - - id - - email - - role - properties: - id: - type: string - email: - type: string - format: email - role: - type: string - enum: [admin, member] - invited_by: - type: string - created_at: - type: string - format: date-time - expires_at: - type: string - format: date-time - - WorkspaceApiKey: - type: object - x-runtime: [cloud] - description: "[cloud-only] A workspace API key (secret value redacted)." - required: - - id - - name - - description - properties: - id: - type: string - name: - type: string - description: - type: string - maxLength: 5000 - description: User-provided description of the key's purpose. Always present in responses; empty string when no description was supplied on create. - prefix: - type: string - description: First few characters of the key for identification - created_at: - type: string - format: date-time - last_used_at: - type: string - format: date-time - nullable: true - created_by: - type: string - - WorkspaceApiKeyCreated: - type: object - x-runtime: [cloud] - description: "[cloud-only] A newly created workspace API key, including the full secret value (shown only once)." 
- required: - - id - - name - - description - - key - properties: - id: - type: string - name: - type: string - description: - type: string - maxLength: 5000 - description: User-provided description of the key's purpose. Always present in responses; empty string when no description was supplied on create. - key: - type: string - description: Full API key value (only returned on creation) - prefix: - type: string - created_at: - type: string - format: date-time - - CloudUser: - type: object - x-runtime: [cloud] - description: "[cloud-only] A cloud-authenticated user profile." - required: - - id - - email - properties: - id: - type: string - email: - type: string - format: email - display_name: - type: string - avatar_url: - type: string - format: uri - created_at: - type: string - format: date-time - - SecretMeta: - type: object - x-runtime: [cloud] - description: "[cloud-only] Metadata for a stored secret (value is never returned)." - required: - - id - - name - properties: - id: - type: string - name: - type: string - provider: - type: string - description: "[cloud-only] Provider identifier (e.g., huggingface, civitai)." - x-runtime: [cloud] - last_used_at: - type: string - format: date-time - description: "[cloud-only] When the secret was last used for decryption." - x-runtime: [cloud] - created_at: - type: string - format: date-time - updated_at: - type: string - format: date-time - - UpdateSecretRequest: - type: object - x-runtime: [cloud] - description: "[cloud-only] Request body for updating an existing user secret." - properties: - name: - type: string - description: New name for the secret - secret_value: - type: string - description: New secret value (API key, token, etc.) - - CreateSessionResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] Response after creating a session cookie." - required: - - success - properties: - success: - type: boolean - expiresIn: - type: integer - description: Session expiration time in seconds. 
- - DeleteSessionResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] Response after deleting a session cookie." - required: - - success - properties: - success: - type: boolean - - CreateHubProfileRequest: - type: object - x-runtime: [cloud] - description: "[cloud-only] Request body for creating a new Hub profile." - required: - - workspace_id - - username - properties: - workspace_id: - type: string - username: - type: string - description: Unique URL-safe slug. Immutable after creation. - display_name: - type: string - description: - type: string - avatar_token: - type: string - website_urls: - type: array - items: - type: string - - PublishHubWorkflowRequest: - type: object - x-runtime: [cloud] - description: "[cloud-only] Request body for publishing or updating a workflow on the Hub." - required: - - username - - name - - workflow_filename - - asset_ids - properties: - username: - type: string - name: - type: string - workflow_filename: - type: string - asset_ids: - type: array - items: - type: string - description: - type: string - tags: - type: array - items: - type: string - models: - type: array - items: - type: string - custom_nodes: - type: array - items: - type: string - tutorial_url: - type: string - metadata: - type: object - additionalProperties: true - thumbnail_type: - type: string - enum: [image, video, image_comparison] - thumbnail_token_or_url: - type: string - thumbnail_comparison_token_or_url: - type: string - sample_image_tokens_or_urls: - type: array - items: - type: string - - HubWorkflowDetail: - type: object - x-runtime: [cloud] - description: "[cloud-only] Full Hub workflow detail including versions, assets, and statistics." 
- required: - - share_id - - workflow_id - - name - - workflow_json - - assets - - profile - - status - properties: - share_id: - type: string - workflow_id: - type: string - name: - type: string - status: - type: string - enum: [pending, approved, rejected, deprecated] - description: - type: string - thumbnail_type: - type: string - enum: [image, video, image_comparison] - thumbnail_url: - type: string - thumbnail_comparison_url: - type: string - tutorial_url: - type: string - metadata: - type: object - additionalProperties: true - sample_image_urls: - type: array - items: - type: string - publish_time: - type: string - format: date-time - nullable: true - workflow_json: - type: object - additionalProperties: true - assets: - type: array - items: - $ref: "#/components/schemas/AssetInfo" - profile: - $ref: "#/components/schemas/HubProfile" - - AssetInfo: - type: object - x-runtime: [cloud] - description: "[cloud-only] Lightweight asset reference used in workflow publishing payloads." - required: - - id - - filename - properties: - id: - type: string - filename: - type: string - mime_type: - type: string - size_bytes: - type: integer - format: int64 - - BulkRevokeAPIKeysResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] Response after bulk-revoking API keys for a workspace member." - required: - - revoked_count - properties: - revoked_count: - type: integer - minimum: 0 - - CreateWorkflowVersionRequest: - type: object - x-runtime: [cloud] - description: "[cloud-only] Request body for creating a new version of a saved workflow." - required: - - base_version - - workflow_json - properties: - base_version: - type: integer - description: Version number this change is based on (for optimistic concurrency). - workflow_json: - type: object - additionalProperties: true - - WorkflowVersionResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] Metadata for a single workflow version." 
- required: - - id - - version - - latest_version - - created_by - - created_at - properties: - id: - type: string - version: - type: integer - latest_version: - type: integer - created_by: - type: string - created_at: - type: string - format: date-time - - WorkflowPublishInfo: - type: object - x-runtime: [cloud] - description: "[cloud-only] Publishing metadata for a workflow shared to the Hub." - required: - - workflow_id - - share_id - - listed - - assets - properties: - workflow_id: - type: string - share_id: - type: string - publish_time: - type: string - format: date-time - nullable: true - listed: - type: boolean - assets: - type: array - items: - $ref: "#/components/schemas/AssetInfo" - - TaskEntry: - type: object - x-runtime: [cloud] - description: "[cloud-only] Task data for list views." - required: - - id - - task_name - - status - - create_time - properties: - id: - type: string - format: uuid - task_name: - type: string - status: - type: string - enum: [created, running, completed, failed] - create_time: - type: string - format: date-time - started_at: - type: string - format: date-time - completed_at: - type: string - format: date-time - - TaskResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] Full task details including payload and result." 
- required: - - id - - idempotency_key - - task_name - - payload - - status - - create_time - - update_time - properties: - id: - type: string - format: uuid - idempotency_key: - type: string - task_name: - type: string - payload: - type: object - additionalProperties: true - status: - type: string - enum: [created, running, completed, failed] - result: - type: object - additionalProperties: true - create_time: - type: string - format: date-time - update_time: - type: string - format: date-time - started_at: - type: string - format: date-time - completed_at: - type: string - format: date-time - error: - type: string - - TasksListResponse: - type: object - x-runtime: [cloud] - description: "[cloud-only] Paginated list of background tasks for the authenticated user." - required: - - tasks - - pagination - properties: - tasks: - type: array - items: - $ref: "#/components/schemas/TaskEntry" - pagination: - $ref: "#/components/schemas/PaginationInfo" - - # ===== Cloud-only schemas (Comfy-Org/cloud runtime, BE-1106) ===== - AssetDownloadResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Acknowledgement of an async asset download task; clients poll GET /api/tasks/{task_id} for status.' - required: - - task_id - - status - properties: - task_id: - type: string - format: uuid - description: Task ID for tracking download progress via GET /api/tasks/{task_id} - status: - type: string - enum: - - created - - running - - completed - - failed - description: Current task status - message: - type: string - description: Human-readable message - example: Download task created. Use task_id to track progress. - - AssetMetadataResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Metadata for a remotely hosted asset resolved by URL.' 
- required: - - content_length - properties: - content_length: - type: integer - format: int64 - description: Size of the asset in bytes (-1 if unknown) - example: 4294967296 - content_type: - type: string - description: MIME type of the asset - example: application/octet-stream - filename: - type: string - description: Suggested filename for the asset from source - example: realistic-vision-v5.safetensors - name: - type: string - description: Display name or title for the asset from source - example: Realistic Vision v5.0 - tags: - type: array - items: - type: string - description: Tags for categorization from source - example: - - models - - checkpoint - preview_image: - type: string - description: Preview image as base64-encoded data URL - example: data:image/jpeg;base64,/9j/4AAQSkZJRg... - validation: - description: Validation results for the file - allOf: - - $ref: '#/components/schemas/ValidationResult' - - BillingBalanceResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Current credit balance and usage details for a workspace.' - required: - - amount_micros - - currency - properties: - amount_micros: - type: number - format: double - description: The total remaining balance in microamount (1/1,000,000 of the currency unit) - prepaid_balance_micros: - type: number - format: double - description: The remaining balance from prepaid commits in microamount - cloud_credit_balance_micros: - type: number - format: double - description: The remaining balance from cloud credits in microamount - pending_charges_micros: - type: number - format: double - description: The total amount of pending/unbilled charges from draft invoices in microamount - effective_balance_micros: - type: number - format: double - description: The effective balance (total balance minus pending charges). Can be negative if pending charges exceed - the balance. 
- currency: - type: string - example: usd - description: Currency code - - BillingPlansResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] List of available billing plans for subscription.' - required: - - plans - properties: - current_plan_slug: - type: string - description: Current plan slug if subscribed - plans: - type: array - items: - $ref: '#/components/schemas/Plan' - - BillingStatusResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Current billing and subscription status for a workspace.' - required: - - is_active - - has_funds - properties: - is_active: - type: boolean - description: Whether the workspace has an active subscription - subscription_status: - type: string - enum: - - active - - ended - - canceled - description: Subscription activity status (scheduled subscriptions are not returned) - subscription_tier: - $ref: '#/components/schemas/SubscriptionTier' - subscription_duration: - $ref: '#/components/schemas/SubscriptionDuration' - plan_slug: - type: string - description: Plan identifier (e.g., standard-monthly, team-pro-annual) - billing_status: - $ref: '#/components/schemas/BillingStatus' - has_funds: - type: boolean - description: Whether the workspace has available credits - cancel_at: - type: string - format: date-time - description: When the subscription will become inactive (if canceled) - renewal_date: - type: string - format: date-time - description: When the current billing period ends and the next one begins - - GetUserDataResponseFull: - type: array - x-runtime: [cloud] - description: '[cloud-only] List of user data file entries (each with path, size, and modification time) returned when full_info=true.' 
- items: - $ref: '#/components/schemas/GetUserDataResponseFullFile' - - HistoryDetailEntry: - type: object - x-runtime: [cloud] - description: '[cloud-only] History entry with full prompt data' - properties: - prompt: - type: object - description: Full prompt execution data - properties: - priority: - type: number - format: double - description: Execution priority - prompt_id: - type: string - description: The prompt ID - prompt: - type: object - description: The workflow nodes - additionalProperties: true - extra_data: - type: object - description: Additional execution data - additionalProperties: true - outputs_to_execute: - type: array - items: - type: string - description: Output nodes to execute - outputs: - type: object - description: Output data from execution (generated images, files, etc.) - additionalProperties: true - status: - type: object - description: Execution status and timeline information - additionalProperties: true - meta: - type: object - description: Metadata about the execution and nodes - additionalProperties: true - - HistoryDetailResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Detailed execution history response for a specific prompt. - - Returns a dictionary with prompt_id as key and full history data as value. - - ' - additionalProperties: - $ref: '#/components/schemas/HistoryDetailEntry' - - HistoryResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Execution history response with history array. - - Returns an object with a "history" key containing an array of history entries. - - Each entry includes prompt_id as a property along with execution data. - - ' - required: - - history - properties: - history: - type: array - description: Array of history entries ordered by creation time (newest first) - items: - $ref: '#/components/schemas/HistoryEntry' - - HubLabelInfo: - type: object - x-runtime: [cloud] - description: '[cloud-only] Metadata for a single Hub label.' 
- required: - - name - - display_name - - type - properties: - name: - type: string - description: Slug identifier. - display_name: - type: string - description: Human-readable display name. - description: - type: string - description: Optional description of the label. - type: - type: string - enum: - - tag - - model - - custom_node - description: Label category. - - HubLabelListResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Response wrapper for the available Hub label catalog.' - required: - - labels - properties: - labels: - type: array - items: - $ref: '#/components/schemas/HubLabelInfo' - description: Available labels, optionally filtered by type. - - HubProfileSummary: - type: object - x-runtime: [cloud] - description: '[cloud-only] Abbreviated Hub profile used in workflow listings.' - required: - - username - properties: - username: - type: string - display_name: - type: string - avatar_url: - type: string - description: Public URL of the profile avatar image. - - HubWorkflowListResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Paginated list of Hub workflows matching search criteria.' - required: - - workflows - properties: - workflows: - type: array - items: - anyOf: - - $ref: '#/components/schemas/HubWorkflowSummary' - - $ref: '#/components/schemas/HubWorkflowDetail' - description: Array of HubWorkflowSummary (default) or HubWorkflowDetail (when detail=true). - next_cursor: - type: string - description: Cursor for the next page, empty if no more results. - - HubWorkflowStatus: - type: string - x-runtime: [cloud] - description: '[cloud-only] Public workflow status. NULL in the database is represented as pending in API responses.' - enum: - - pending - - approved - - rejected - - deprecated - - HubWorkflowSummary: - type: object - x-runtime: [cloud] - description: '[cloud-only] Abbreviated Hub workflow metadata used in search and listing results.' 
- required: - - share_id - - name - - profile - - status - properties: - share_id: - type: string - name: - type: string - status: - $ref: '#/components/schemas/HubWorkflowStatus' - description: - type: string - tags: - type: array - items: - $ref: '#/components/schemas/LabelRef' - models: - type: array - items: - $ref: '#/components/schemas/LabelRef' - custom_nodes: - type: array - items: - $ref: '#/components/schemas/LabelRef' - thumbnail_type: - type: string - enum: - - image - - video - - image_comparison - thumbnail_url: - type: string - thumbnail_comparison_url: - type: string - publish_time: - type: string - format: date-time - nullable: true - profile: - $ref: '#/components/schemas/HubProfileSummary' - metadata: - type: object - additionalProperties: true - tutorial_url: - type: string - sample_image_urls: - type: array - items: - type: string - - HubWorkflowTemplateEntry: - type: object - x-runtime: [cloud] - description: '[cloud-only] Entry in the curated workflow template gallery shown on the home page.' - required: - - name - - title - - status - properties: - name: - type: string - description: Slug identifier for the template - title: - type: string - status: - $ref: '#/components/schemas/HubWorkflowStatus' - description: - type: string - tags: - type: array - items: - type: string - models: - type: array - items: - type: string - requiresCustomNodes: - type: array - items: - type: string - thumbnailVariant: - type: string - mediaType: - type: string - mediaSubtype: - type: string - size: - type: integer - format: int64 - description: Workflow asset size in bytes. - vram: - type: integer - format: int64 - description: Approximate VRAM requirement in bytes. - usage: - type: integer - format: int64 - description: Usage count reported upstream. - searchRank: - type: integer - format: int64 - description: Search ranking score reported upstream. - isEssential: - type: boolean - description: Whether the template belongs to a module marked as essential. 
- openSource: - type: boolean - profile: - $ref: '#/components/schemas/HubProfileSummary' - tutorialUrl: - type: string - logos: - type: array - items: - type: object - additionalProperties: true - date: - type: string - description: Publication date in YYYY-MM-DD format - io: - type: object - properties: - inputs: - type: array - items: - type: object - additionalProperties: true - outputs: - type: array - items: - type: object - additionalProperties: true - includeOnDistributions: - type: array - items: - type: string - thumbnailUrl: - type: string - description: Public URL of the primary thumbnail - thumbnailComparisonUrl: - type: string - description: Public URL of the comparison thumbnail - shareId: - type: string - description: Share ID for linking to the hub workflow detail - extendedDescription: - type: string - description: AI-generated extended description of the workflow - metaDescription: - type: string - description: AI-generated SEO meta description (under 160 chars) - howToUse: - type: array - items: - type: string - description: AI-generated step-by-step usage instructions - suggestedUseCases: - type: array - items: - type: string - description: AI-generated suggested use cases - faqItems: - type: array - items: - type: object + default_view: + description: Default view mode + enum: + - workflow + - app + type: string + description: + description: Description of the workflow + type: string + forked_from_workflow_id: + description: ID of the source workflow if forked + type: string + forked_from_workflow_version_id: + description: ID of the source workflow version if forked + type: string + name: + description: Display name for the workflow + type: string + workflow_json: + additionalProperties: true + description: The ComfyUI workflow JSON + type: object required: - - question - - answer - properties: - question: - type: string - answer: - type: string - description: AI-generated FAQ items - contentTemplate: - type: string - description: Content 
template used for generation (tutorial, showcase, comparison, breakthrough) - - JobStatusResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Job status information' - properties: - id: - type: string - format: uuid - description: The job ID - status: - type: string - enum: - - waiting_to_dispatch - - pending - - in_progress - - completed - - error - - cancelled - description: Current job status - created_at: - type: string - format: date-time - description: When the job was created - updated_at: - type: string - format: date-time - description: When the job was last updated - last_state_update: - type: string - format: date-time - description: When the job status was last changed - assigned_inference: - type: string - nullable: true - description: The inference instance assigned to this job (if any) - error_message: - type: string - nullable: true - description: Error message if the job failed - required: - - id - - status - - created_at - - updated_at - - JobsListResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Paginated list of jobs for the authenticated user.' - required: - - jobs - - pagination - properties: - jobs: - type: array - description: Array of jobs ordered by specified sort field - items: - $ref: '#/components/schemas/JobEntry' - pagination: - $ref: '#/components/schemas/PaginationInfo' - - LabelRef: - type: object - x-runtime: [cloud] - description: '[cloud-only] Reference to a Hub label by ID.' - required: - - name - - display_name - properties: - name: - type: string - description: Slug identifier (e.g. "video-generation", "flux"). - display_name: - type: string - description: Human-readable display name (e.g. "Video Generation", "Flux"). 
- - LogsResponse: - type: array - x-runtime: [cloud] - description: '[cloud-only] System logs response' - items: - type: object - properties: - timestamp: - type: string - format: date-time - description: When the log entry was created - level: - type: string - enum: - - debug - - info - - warn - - error - description: Log level - message: - type: string - description: Log message - source: - type: string - description: Source of the log entry - metadata: + - workflow_json type: object + CreateWorkflowVersionRequest: + description: Request body for creating a new version of a saved workflow. + properties: + base_version: + description: The version number this change is based on (for optimistic concurrency) + type: integer + workflow_json: + additionalProperties: true + description: The updated ComfyUI workflow JSON + type: object + required: + - base_version + - workflow_json + type: object + ErrorResponse: + description: Standard error response with a machine-readable code and human-readable message. + properties: + code: + type: string + details: + additionalProperties: true + description: Optional open object carrying structured, machine-readable context about the error (e.g. offending field names, validation specifics). Absent for most errors; consumers must not assume any particular shape. 
+ type: object + message: + type: string + required: + - code + - message + type: object + ExecutionError: + description: Detailed execution error information from ComfyUI + properties: + current_inputs: + additionalProperties: true + description: Input values at time of failure (empty object if not available) + type: object + current_outputs: + additionalProperties: true + description: Output values at time of failure (empty object if not available) + type: object + exception_message: + description: Human-readable error message + type: string + exception_type: + description: Python exception type (e.g., "RuntimeError") + type: string + node_id: + description: ID of the node that failed + type: string + node_type: + description: Type name of the node (e.g., "KSampler") + type: string + traceback: + description: Array of traceback lines (empty array if not available) + items: + type: string + type: array + required: + - node_id + - node_type + - exception_message + - exception_type + - traceback + - current_inputs + - current_outputs + type: object + FeedbackRequest: + description: Request to submit user feedback + properties: + content: + description: The feedback content or message + type: string + metadata: + additionalProperties: true + description: Additional metadata about the feedback + type: object + rating: + description: User's rating of ComfyUI Cloud experience (1-5 stars) + maximum: 5 + minimum: 1 + type: integer + type: + description: Type of feedback being submitted + enum: + - missing_nodes + - general + - missing_models + type: string + required: + - type + type: object + FeedbackResponse: + description: Response after submitting feedback + type: object + ForkWorkflowRequest: + description: Request body for forking an existing workflow into the user's account. 
+ properties: + name: + description: Name for the forked workflow + type: string + source_version: + description: Version number to fork from + type: integer + required: + - source_version + type: object + GetUserDataResponseFull: + description: List of user data file entries (each with path, size, and modification time) returned when full_info=true. + items: + $ref: '#/components/schemas/GetUserDataResponseFullFile' + type: array + GetUserDataResponseFullFile: + description: Individual file entry within a full user data response. + properties: + modified: + description: UNIX timestamp of the last modification in milliseconds. + format: int64 + type: integer + path: + description: File name or path relative to the user directory. + type: string + size: + description: File size in bytes. + type: integer + type: object + GlobalSubgraphData: + description: Full data for a global subgraph blueprint + properties: + data: + description: The full subgraph JSON data as a string + type: string + info: + description: Additional information about the subgraph + properties: + node_pack: + description: The node pack/module that provides this subgraph + type: string + required: + - node_pack + type: object + name: + description: Display name of the subgraph blueprint + type: string + source: + description: Source type of the subgraph - "templates" for workflow templates or "custom_node" for custom node subgraphs + type: string + required: + - source + - name + - info + - data + type: object + GlobalSubgraphInfo: + description: Metadata for a global subgraph blueprint (without full data) + properties: + data: + description: The full subgraph JSON data (may be empty in list view) + type: string + info: + description: Additional information about the subgraph + properties: + node_pack: + description: The node pack/module that provides this subgraph + type: string + required: + - node_pack + type: object + name: + description: Display name of the subgraph blueprint + type: string + 
source: + description: Source type of the subgraph - "templates" for workflow templates or "custom_node" for custom node subgraphs + type: string + required: + - source + - name + - info + type: object + HistoryDetailEntry: + description: History entry with full prompt data + properties: + meta: + additionalProperties: true + description: Metadata about the execution and nodes + type: object + outputs: + additionalProperties: true + description: Output data from execution (generated images, files, etc.) + type: object + prompt: + description: Full prompt execution data + properties: + extra_data: + additionalProperties: true + description: Additional execution data + type: object + outputs_to_execute: + description: Output nodes to execute + items: + type: string + type: array + priority: + description: Execution priority + format: double + type: number + prompt: + additionalProperties: true + description: The workflow nodes + type: object + prompt_id: + description: The prompt ID + type: string + type: object + status: + additionalProperties: true + description: Execution status and timeline information + type: object + type: object + HistoryDetailResponse: + additionalProperties: + $ref: '#/components/schemas/HistoryDetailEntry' + description: | + Detailed execution history response for a specific prompt. + Returns a dictionary with prompt_id as key and full history data as value. + type: object + HistoryEntry: + description: History entry with prompt_id and execution data + properties: + create_time: + description: Job creation timestamp (Unix timestamp in milliseconds) + format: int64 + type: integer + meta: + additionalProperties: true + description: Metadata about the execution and nodes + type: object + outputs: + additionalProperties: true + description: Output data from execution (generated images, files, etc.) 
+ type: object + prompt: + description: Filtered prompt execution data (lightweight format) + properties: + extra_data: + additionalProperties: true + description: Additional execution data (workflow removed from extra_pnginfo) + type: object + priority: + description: Execution priority + format: double + type: number + prompt_id: + description: The prompt ID + type: string + type: object + prompt_id: + description: Unique identifier for this prompt execution + type: string + status: + additionalProperties: true + description: Execution status and timeline information + type: object + workflow_id: + description: UUID identifying the workflow graph definition + type: string + required: + - prompt_id + type: object + HistoryManageRequest: + additionalProperties: false + description: Request to manage history operations + properties: + clear: + description: If true, clear all history for the authenticated user + type: boolean + delete: + description: Array of job IDs to delete from history + items: + type: string + type: array + type: object + HistoryResponse: + description: | + Execution history response with history array. + Returns an object with a "history" key containing an array of history entries. + Each entry includes prompt_id as a property along with execution data. + properties: + history: + description: Array of history entries ordered by creation time (newest first) + items: + $ref: '#/components/schemas/HistoryEntry' + type: array + required: + - history + type: object + JobCancelResponse: + description: Response for POST /api/jobs/{job_id}/cancel. Returned on both fresh cancels and idempotent no-ops. + properties: + cancelled: + description: | + True when a cancel event was successfully dispatched by this call. + False when the job was already in a terminal or cancelling state, + in which case the call is a no-op (still 200 — idempotent). 
+ type: boolean + required: + - cancelled + type: object + JobDetailResponse: + description: Full job details including workflow and outputs + properties: + create_time: + description: Job creation timestamp (Unix timestamp in milliseconds) + format: int64 + type: integer + execution_error: + allOf: + - $ref: '#/components/schemas/ExecutionError' + description: Detailed execution error from ComfyUI (only for failed jobs with structured error data) + execution_meta: + additionalProperties: true + description: Node-level execution metadata (only for terminal states) + type: object + execution_status: + additionalProperties: true + description: ComfyUI execution status and timeline (only for terminal states) + type: object + id: + description: Unique job identifier + format: uuid + type: string + outputs: + additionalProperties: true + description: Full outputs object from ComfyUI (only for terminal states) + type: object + outputs_count: + description: Total number of output files (omitted for non-terminal states) + type: integer + preview_output: + additionalProperties: true + description: Primary preview output (only for terminal states) + type: object + status: + description: User-friendly job status + enum: + - pending + - in_progress + - completed + - failed + - cancelled + type: string + update_time: + description: Last update timestamp (Unix timestamp in milliseconds) + format: int64 + type: integer + workflow: + additionalProperties: true + description: | + Full ComfyUI workflow (10-100KB, omitted if not available). + + Sensitive credentials are redacted before the response is returned: + `extra_data.api_key_comfy_org`, when present, is replaced with the + literal string `"[REDACTED]"`. The field is preserved (not removed) + so existence checks still pass, but the value is not usable. 
+ type: object + workflow_id: + description: UUID identifying the workflow graph definition + type: string + required: + - id + - status + - create_time + - update_time + type: object + JobEntry: + description: Lightweight job data for list views (workflow and full outputs excluded) + properties: + create_time: + description: Job creation timestamp (Unix timestamp in milliseconds) + format: int64 + type: integer + execution_end_time: + description: Workflow execution completion timestamp (Unix milliseconds, only present for terminal states) + format: int64 + type: integer + execution_error: + allOf: + - $ref: '#/components/schemas/ExecutionError' + description: Detailed execution error from ComfyUI (only for failed jobs with structured error data) + execution_start_time: + description: Workflow execution start timestamp (Unix milliseconds, only present for terminal states) + format: int64 + type: integer + id: + description: Unique job identifier + format: uuid + type: string + outputs_count: + description: Total number of output files (omitted for non-terminal states) + type: integer + preview_output: + additionalProperties: true + description: Primary preview output (only present for terminal states) + type: object + status: + description: User-friendly job status + enum: + - pending + - in_progress + - completed + - failed + - cancelled + type: string + workflow_id: + description: UUID identifying the workflow graph definition + type: string + required: + - id + - status + - create_time + type: object + JobStatusResponse: + description: Job status information + properties: + assigned_inference: + description: The inference instance assigned to this job (if any) + nullable: true + type: string + created_at: + description: When the job was created + format: date-time + type: string + error_message: + description: Error message if the job failed + nullable: true + type: string + id: + description: The job ID + format: uuid + type: string + last_state_update: + 
description: When the job status was last changed + format: date-time + type: string + status: + description: Current job status + enum: + - waiting_to_dispatch + - pending + - in_progress + - completed + - error + - cancelled + type: string + updated_at: + description: When the job was last updated + format: date-time + type: string + required: + - id + - status + - created_at + - updated_at + type: object + JobsListResponse: + description: Paginated list of jobs for the authenticated user. + properties: + jobs: + description: Array of jobs ordered by specified sort field + items: + $ref: '#/components/schemas/JobEntry' + type: array + pagination: + $ref: '#/components/schemas/PaginationInfo' + required: + - jobs + - pagination + type: object + ListAssetsResponse: + description: Paginated list of assets belonging to the authenticated user. + properties: + assets: + description: List of assets matching the query + items: + $ref: '#/components/schemas/Asset' + type: array + has_more: + description: Whether more assets are available beyond this page + type: boolean + next_cursor: + description: | + Opaque cursor to pass as the `after` query parameter to fetch the + next page. Omitted from the response when there are no more results. + type: string + total: + description: Total number of assets matching the filters + type: integer + required: + - assets + - total + - has_more + type: object + ListTagsResponse: + description: Paginated list of available asset tags. 
+ properties: + has_more: + description: Whether more tags are available + type: boolean + tags: + description: List of tags + items: + $ref: '#/components/schemas/TagInfo' + type: array + total: + description: Total number of tags + type: integer + required: + - tags + - total + - has_more + type: object + ModelFile: + description: Represents a model file with metadata + properties: + name: + description: The filename of the model + example: model.safetensors + type: string + pathIndex: + description: Index of the path where this model is located + example: 0 + type: integer + required: + - name + - pathIndex + type: object + ModelFolder: + description: Represents a folder containing models + properties: + folders: + description: List of paths where models of this type are stored + example: + - checkpoints + items: + type: string + type: array + name: + description: The name of the model folder + example: checkpoints + type: string + required: + - name + - folders + type: object + NodeInfo: + description: Metadata describing a single ComfyUI node type and its inputs/outputs. 
+ properties: + api_node: + description: Whether this is an API node + type: boolean + category: + description: Category of the node + type: string + deprecated: + description: Whether the node is deprecated + type: boolean + description: + description: Description of the node + type: string + display_name: + description: Display name of the node + type: string + experimental: + description: Whether the node is experimental + type: boolean + input: + additionalProperties: true + description: Input specifications for the node + type: object + input_order: + additionalProperties: + items: + type: string + type: array + description: Order of inputs for display + type: object + name: + description: Internal name of the node + type: string + output: + description: Output types of the node + items: + type: string + type: array + output_is_list: + description: Whether each output is a list + items: + type: boolean + type: array + output_name: + description: Names of the outputs + items: + type: string + type: array + output_node: + description: Whether this is an output node + type: boolean + output_tooltips: + description: Tooltips for outputs + items: + type: string + type: array + python_module: + description: Python module implementing the node + type: string + type: object + PaginationInfo: + description: Offset/limit-based pagination metadata included in list responses. + properties: + has_more: + description: Whether more items are available beyond this page + type: boolean + limit: + description: Items per page + minimum: 1 + type: integer + offset: + description: Current offset (0-based) + minimum: 0 + type: integer + total: + description: Total number of items matching filters + minimum: 0 + type: integer + required: + - offset + - limit + - total + - has_more + type: object + PromptErrorResponse: additionalProperties: true - description: Additional log metadata + description: Error response for ComfyUI prompt execution. 
+ type: object + PromptInfo: + description: Metadata about the currently running and queued prompts. + properties: + exec_info: + properties: + queue_remaining: + description: Number of items remaining in the queue + type: integer + type: object + type: object + PromptRequest: + description: Request body for submitting a ComfyUI workflow prompt for execution. + properties: + extra_data: + additionalProperties: true + description: Extra data to be associated with the prompt + type: object + front: + description: If true, adds the prompt to the front of the queue + type: boolean + number: + description: Priority number for the queue (lower numbers have higher priority) + type: number + partial_execution_targets: + description: List of node names to execute + items: + type: string + type: array + prompt: + additionalProperties: true + description: The workflow graph to execute + type: object + workflow_id: + description: UUID identifying the cloud workflow entity to associate with this job + type: string + workflow_version_id: + description: UUID identifying the workflow version to associate with this job + type: string + required: + - prompt + type: object + PromptResponse: + description: Response returned after successfully queuing a workflow prompt. + properties: + node_errors: + additionalProperties: true + description: Any errors in the nodes of the prompt + type: object + number: + description: Priority number in the queue + type: number + prompt_id: + description: Unique identifier for the prompt execution + format: uuid + type: string + type: object + PublishWorkflowAssetsRequest: + description: Request body for publishing workflow assets to the Hub. + properties: + asset_ids: + description: IDs of assets (inputs and models) to snapshot. + items: + type: string + type: array + required: + - asset_ids + type: object + PublishedWorkflowDetail: + description: Full detail of a publicly published workflow on the Hub. 
+ properties: + assets: + description: Published assets with their library status for the caller. + items: + $ref: '#/components/schemas/AssetInfo' + type: array + listed: + type: boolean + name: + description: Human-readable workflow name. + type: string + publish_time: + format: date-time + nullable: true + type: string + share_id: + type: string + workflow_id: + type: string + workflow_json: + additionalProperties: true + description: The workflow JSON content at publish time. + type: object + required: + - share_id + - workflow_id + - name + - listed + - workflow_json + - assets + type: object + QueueInfo: + description: Queue information with pending and running jobs + properties: + queue_pending: + description: Array of pending job items (ordered by creation time, oldest first) + items: + description: | + Queue item tuple format: [job_number, prompt_id, workflow_json, output_node_ids, metadata] + - [0] job_number (integer): Position in queue (1-based) + - [1] prompt_id (string): Job UUID + - [2] workflow_json (object): Full ComfyUI workflow + - [3] output_node_ids (array): Node IDs to return results from + - [4] metadata (object): Contains {create_time: } + items: {} + maxItems: 5 + minItems: 5 + type: array + type: array + queue_running: + description: Array of currently running job items + items: + description: | + Queue item tuple format: [job_number, prompt_id, workflow_json, output_node_ids, metadata] + - [0] job_number (integer): Position in queue (1-based) + - [1] prompt_id (string): Job UUID + - [2] workflow_json (object): Full ComfyUI workflow + - [3] output_node_ids (array): Node IDs to return results from + - [4] metadata (object): Contains {create_time: } + items: {} + maxItems: 5 + minItems: 5 + type: array + type: array + type: object + QueueManageRequest: + additionalProperties: false + description: Request to manage queue operations + properties: + clear: + description: If true, clear all pending jobs from the queue + type: boolean + delete: + 
description: Array of PENDING job IDs to cancel + items: + type: string + type: array + type: object + QueueManageResponse: + description: Response after a queue management action (delete or clear). + properties: + cleared: + description: Whether the queue was cleared + type: boolean + deleted: + description: Array of job IDs that were successfully cancelled + items: + type: string + type: array + type: object + SystemStatsResponse: + description: System statistics response + properties: + devices: + items: + properties: + name: + description: Device name + type: string + type: + description: Device type + type: string + vram_free: + description: Free VRAM in bytes + type: number + vram_total: + description: Total VRAM in bytes + type: number + required: + - name + - type + type: object + type: array + system: + properties: + argv: + description: Command line arguments + items: + type: string + type: array + cloud_version: + description: Cloud ingest service version (commit hash) + type: string + comfyui_frontend_version: + description: ComfyUI frontend version (commit hash or tag) + type: string + comfyui_version: + description: ComfyUI version + type: string + embedded_python: + description: Whether using embedded Python + type: boolean + os: + description: Operating system + type: string + python_version: + description: Python version + type: string + pytorch_version: + description: PyTorch version + type: string + ram_free: + description: Free RAM in bytes + type: number + ram_total: + description: Total RAM in bytes + type: number + workflow_templates_version: + description: Workflow templates version + type: string + required: + - os + - python_version + - embedded_python + - comfyui_version + - pytorch_version + - argv + - ram_total + - ram_free + type: object + required: + - system + - devices + type: object + TagInfo: + description: Metadata for a single tag that can be applied to assets. 
+ properties: + count: + description: Number of assets using this tag + type: integer + name: + description: Tag name + type: string + required: + - name + - count + type: object + TagsModificationResponse: + description: Response after adding, updating, or removing tags on an asset. + properties: + added: + description: Tags that were successfully added (for add operation) + items: + type: string + type: array + already_present: + description: Tags that were already present (for add operation) + items: + type: string + type: array + not_present: + description: Tags that were not present (for remove operation) + items: + type: string + type: array + removed: + description: Tags that were successfully removed (for remove operation) + items: + type: string + type: array + total_tags: + description: All tags on the asset after the operation + items: + type: string + type: array + required: + - total_tags + type: object + TaskEntry: + description: Task data for list views + properties: + completed_at: + description: When task completed or failed (null if not finished) + format: date-time + type: string + create_time: + description: Task creation timestamp + format: date-time + type: string + id: + description: Unique task identifier + format: uuid + type: string + started_at: + description: When task execution started (null if not started) + format: date-time + type: string + status: + description: Current task status + enum: + - created + - running + - completed + - failed + type: string + task_name: + description: Task type name (e.g., model_upload) + type: string + required: + - id + - task_name + - status + - create_time + type: object + TaskResponse: + description: Full task details including payload and result + properties: + completed_at: + description: When task completed or failed (null if not finished) + format: date-time + type: string + create_time: + description: Task creation timestamp + format: date-time + type: string + error_message: + description: 
Error message on failure (null if not failed) + type: string + id: + description: Unique task identifier + format: uuid + type: string + idempotency_key: + description: Caller-provided key for idempotent task creation + type: string + payload: + additionalProperties: true + description: Task input data + type: object + result: + additionalProperties: true + description: Task output data (null if not completed) + type: object + started_at: + description: When task execution started (null if not started) + format: date-time + type: string + status: + description: Current task status + enum: + - created + - running + - completed + - failed + type: string + task_name: + description: Task type name (e.g., model_upload) + type: string + update_time: + description: Task last update timestamp + format: date-time + type: string + required: + - id + - idempotency_key + - task_name + - payload + - status + - create_time + - update_time + type: object + TasksListResponse: + description: Paginated list of background tasks for the authenticated user. + properties: + pagination: + $ref: '#/components/schemas/PaginationInfo' + tasks: + description: Array of tasks ordered by create_time + items: + $ref: '#/components/schemas/TaskEntry' + type: array + required: + - tasks + - pagination + type: object + UpdateWorkflowRequest: + description: Request body for updating an existing saved workflow. + properties: + default_view: + description: New default view mode + enum: + - workflow + - app + type: string + description: + description: New description + type: string + name: + description: New display name + type: string + type: object + UserDataResponseFull: + description: User data listing entry with file metadata (path, size, modification time). + properties: + modified: + description: UNIX timestamp of the last modification in milliseconds. 
+ format: int64 + type: integer + path: + type: string + size: + type: integer + type: object + UserResponse: + description: User information response + properties: + id: + description: Firebase UID of the authenticated user + type: string + status: + description: User status (always "active" for authenticated users) + type: string + required: + - id + - status + type: object + WorkflowForkedFrom: + description: Reference to the parent workflow from which this workflow was forked. + properties: + workflow_id: + type: string + workflow_version_id: + type: string + type: object + WorkflowListResponse: + description: Paginated list of saved workflows. + properties: + data: + items: + $ref: '#/components/schemas/WorkflowResponse' + type: array + pagination: + $ref: '#/components/schemas/PaginationInfo' + required: + - data + - pagination + type: object + WorkflowPublishInfo: + description: Publishing metadata for a workflow shared to the Hub. + properties: + assets: + description: Published assets (inputs and models). + items: + $ref: '#/components/schemas/AssetInfo' + type: array + listed: + type: boolean + publish_time: + format: date-time + nullable: true + type: string + share_id: + type: string + workflow_id: + type: string + required: + - workflow_id + - share_id + - listed + - assets + type: object + WorkflowResponse: + description: Full workflow entity including metadata and version history. 
+ properties: + created_at: + format: date-time + type: string + created_by: + type: string + default_view: + enum: + - workflow + - app + type: string + description: + type: string + forked_from: + $ref: '#/components/schemas/WorkflowForkedFrom' + id: + type: string + latest_version: + type: integer + name: + type: string + updated_at: + format: date-time + type: string + required: + - id + - latest_version + - created_by + - created_at + - updated_at + type: object + WorkflowVersionContentResponse: + description: Full workflow version including the serialized workflow JSON. + properties: + created_at: + format: date-time + type: string + created_by: + type: string + dependency_asset_ids: + items: + type: string + type: array + id: + type: string + version: + type: integer + workflow_json: + additionalProperties: true + type: object + required: + - id + - version + - workflow_json + - created_by + - created_at + type: object + WorkflowVersionResponse: + description: Metadata for a single workflow version. + properties: + created_at: + format: date-time + type: string + created_by: + type: string + id: + type: string + latest_version: + type: integer + version: + type: integer + required: + - id + - version + - latest_version + - created_by + - created_at + type: object + securitySchemes: + ApiKeyAuth: + description: | + API key authentication. Keys are prefixed with 'comfyui-' and can be + generated from user account settings. Example: 'comfyui-abc123...' + in: header + name: X-API-Key + type: apiKey + BearerAuth: + bearerFormat: JWT + description: | + Firebase JWT token authentication. Obtain a token by authenticating + with Firebase and pass it in the Authorization header. + scheme: bearer + type: http + CookieAuth: + description: | + Session cookie authentication. Set automatically after successful + login via the /api/auth/session endpoint. 
+ in: cookie + name: session + type: apiKey +info: + description: | + API for ComfyUI - A powerful and modular UI for Stable Diffusion. - Member: - type: object - x-runtime: [cloud] - description: '[cloud-only] Workspace member with profile and role information.' - required: - - id - - name - - email - - role - - joined_at - properties: - id: - type: string - description: User ID - name: - type: string - description: User's display name - email: - type: string - format: email - description: User's email address - role: - type: string - enum: - - owner - - member - description: User's role in the workspace - joined_at: - type: string - format: date-time - description: When the user joined the workspace + This API allows you to interact with ComfyUI programmatically, including: + - Retrieving prompt information + - Retrieving node information + license: + name: GNU General Public License v3.0 + url: https://github.com/Comfy-Org/ComfyUI/blob/master/LICENSE + title: ComfyUI API + version: 1.0.0 +openapi: 3.0.3 +paths: + /api/assets: + get: + description: | + Retrieves a paginated list of assets belonging to the authenticated user. + Supports filtering by tags, name, metadata, and sorting options. 
+ operationId: listAssets + parameters: + - description: Filter assets that have ALL of these tags + explode: false + in: query + name: include_tags + schema: + items: + type: string + type: array + style: form + - description: Exclude assets that have ANY of these tags + explode: false + in: query + name: exclude_tags + schema: + items: + type: string + type: array + style: form + - description: Filter assets where name contains this substring (case-insensitive) + in: query + name: name_contains + schema: + type: string + - description: JSON object for filtering by metadata fields + in: query + name: metadata_filter + schema: + type: string + - description: Maximum number of assets to return (1-500) + in: query + name: limit + schema: + default: 20 + maximum: 500 + minimum: 1 + type: integer + - description: Number of assets to skip for pagination + in: query + name: offset + schema: + default: 0 + minimum: 0 + type: integer + - description: Field to sort by + in: query + name: sort + schema: + default: created_at + enum: + - name + - created_at + - updated_at + - size + - last_access_time + type: string + - description: Sort order + in: query + name: order + schema: + default: desc + enum: + - asc + - desc + type: string + - description: Whether to include public/shared assets in results + in: query + name: include_public + schema: + default: true + type: boolean + - description: Filter assets by exact content hash. Preferred over asset_hash. + in: query + name: hash + schema: + type: string + - deprecated: true + description: 'Deprecated: use hash instead. Filter assets by exact content hash.' + in: query + name: asset_hash + schema: + type: string + - description: | + Opaque cursor for keyset pagination. Pass the `next_cursor` value + from the previous response to fetch the next page. When provided, + `offset` is ignored. 
Cursor pagination is only supported with + `sort` values `created_at`, `updated_at`, `name`, or `size`; + requests combining `after` with other sort fields return 400. + The cursor must have been minted under the same `sort` value used + in the follow-up request. + in: query + name: after + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/ListAssetsResponse' + description: Success - Assets returned + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request parameters + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: List user assets + tags: + - file + post: + description: | + Uploads a new asset to the system with associated metadata. + Supports two upload methods: + 1. Direct file upload (multipart/form-data) + 2. URL-based upload (application/json with source: "url") - OAuthRegisterBadRequestResponse: - x-runtime: [cloud] - description: "[cloud-only] Union of the two 400 shapes /oauth/register can emit. `OAuthRegisterError` is the handler-shaped\ - \ RFC 7591 \xA73.2.2 error; `BindingErrorResponse` is the strict-server binding-layer error fired when the request body\ - \ fails OpenAPI-schema validation before the handler runs.\n" - oneOf: - - $ref: '#/components/schemas/OAuthRegisterError' - - $ref: '#/components/schemas/BindingErrorResponse' + If an asset with the same hash already exists, returns the existing asset. 
+ operationId: uploadAsset + requestBody: + content: + application/json: + schema: + properties: + name: + description: Display name for the asset (used to determine file extension) + type: string + preview_id: + description: Optional preview asset ID + format: uuid + type: string + tags: + description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order. + items: + type: string + type: array + url: + description: HTTP/HTTPS URL to download the asset from + format: uri + type: string + user_metadata: + additionalProperties: true + description: Custom metadata to store with the asset + type: object + required: + - url + - name + type: object + multipart/form-data: + schema: + properties: + file: + description: The asset file to upload + format: binary + type: string + id: + description: Optional asset ID for idempotent creation. If provided and asset exists, returns existing asset. + format: uuid + type: string + mime_type: + description: MIME type of the asset (e.g., "image/png", "video/mp4") + type: string + name: + description: Display name for the asset + type: string + preview_id: + description: Optional preview asset ID. If not provided, images will use their own ID as preview. + format: uuid + type: string + tags: + description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order. + items: + type: string + type: array + user_metadata: + description: Custom JSON metadata as a string + type: string + required: + - file + type: object + required: true + responses: + "201": + content: + application/json: + schema: + $ref: '#/components/schemas/AssetCreated' + description: Asset created successfully + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request (bad file, invalid URL, invalid content type, etc.) 
+ "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "403": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Source URL requires authentication or access denied + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Source URL not found + "413": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: File too large + "415": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unsupported media type + "422": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Download failed due to network error or timeout + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Upload a new asset + tags: + - file + /api/assets/{id}: + delete: + description: Deletes the asset record. 
+ operationId: deleteAsset + parameters: + - description: Asset ID + in: path + name: id + required: true + schema: + format: uuid + type: string + responses: + "204": + description: Asset record deleted successfully + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Asset not found + "409": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Asset cannot be deleted because it is referenced by another resource (e.g., workflow version) + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Delete asset + tags: + - file + get: + description: Retrieves detailed information about a specific asset + operationId: getAssetById + parameters: + - description: Asset ID + in: path + name: id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/Asset' + description: Asset details retrieved successfully + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Asset not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get asset details + tags: + - file + put: + description: | + Updates an asset's metadata. At least one field must be provided. + Only name, mime_type, preview_id, and user_metadata can be updated. + For tag management, use the dedicated PUT /api/assets/{id}/tags endpoint. 
+ operationId: updateAsset + parameters: + - description: Asset ID + in: path + name: id + required: true + schema: + format: uuid + type: string + requestBody: + content: + application/json: + schema: + minProperties: 1 + properties: + mime_type: + description: Updated MIME type of the asset + type: string + name: + description: New display name for the asset + type: string + preview_id: + description: Updated preview asset ID + format: uuid + type: string + user_metadata: + additionalProperties: true + description: Updated custom metadata + type: object + type: object + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/AssetUpdated' + description: Asset updated successfully + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request (no fields provided) + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Asset not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Update asset metadata + tags: + - file + /api/assets/{id}/tags: + delete: + description: Removes one or more tags from an existing asset + operationId: removeAssetTags + parameters: + - description: Asset ID + in: path + name: id + required: true + schema: + format: uuid + type: string + requestBody: + content: + application/json: + schema: + properties: + tags: + description: Tags to remove from the asset + items: + type: string + minItems: 1 + type: array + required: + - tags + type: object + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/TagsModificationResponse' + description: Tags removed successfully + "400": + content: + application/json: + 
schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Asset not found + "422": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Validation error (e.g., reserved tag) + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Remove tags from asset + tags: + - file + post: + description: Adds one or more tags to an existing asset + operationId: addAssetTags + parameters: + - description: Asset ID + in: path + name: id + required: true + schema: + format: uuid + type: string + requestBody: + content: + application/json: + schema: + properties: + tags: + description: Tags to add to the asset + items: + type: string + minItems: 1 + type: array + required: + - tags + type: object + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/TagsModificationResponse' + description: Tags added successfully + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Asset not found + "422": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Validation error (e.g., reserved tag) + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Add tags to asset + tags: + - file + put: + description: Adds and removes 
tags from an asset in a single operation + operationId: updateAssetTags + parameters: + - description: Asset ID + in: path + name: id + required: true + schema: + format: uuid + type: string + requestBody: + content: + application/json: + schema: + description: At least one of add or remove must contain items. Empty arrays are allowed when the other array has items. + minProperties: 1 + properties: + add: + description: Tags to add to the asset. Can be empty if remove has items. + items: + type: string + type: array + remove: + description: Tags to remove from the asset. Can be empty if add has items. + items: + type: string + type: array + type: object + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/TagsModificationResponse' + description: Tags updated successfully + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Asset not found + "422": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Reserved tag validation error + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Update asset tags + tags: + - file + /api/assets/from-hash: + post: + description: | + Creates a new asset reference using an existing asset's hash. + This avoids re-uploading the file content when the asset already exists in storage. + The user can provide their own metadata and tags for the reference. + operationId: createAssetFromHash + requestBody: + content: + application/json: + schema: + properties: + hash: + description: Hash of the existing asset. 
Supports Blake3 (blake3:) or SHA256 (sha256:) formats + pattern: ^(blake3|sha256):[a-f0-9]{64}$ + type: string + mime_type: + description: MIME type of the asset (e.g., "image/png", "video/mp4") + type: string + name: + description: Display name for the asset reference (optional) + type: string + tags: + description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order. + items: + type: string + minItems: 1 + type: array + user_metadata: + additionalProperties: true + description: Custom metadata for this asset reference + type: object + required: + - hash + - tags + type: object + required: true + responses: + "201": + content: + application/json: + schema: + $ref: '#/components/schemas/AssetCreated' + description: Asset reference created successfully + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request (bad hash format, invalid tags, etc.) + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Source asset with given hash not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Create asset reference from existing hash + tags: + - file + /api/assets/hash/{hash}: + head: + description: | + Checks if an asset exists in the system by its blake3 hash. + Returns 200 if the asset exists, 404 if it doesn't. 
+ operationId: checkAssetByHash + parameters: + - description: Blake3 hash of the asset in format 'blake3:hex_digest' + in: path + name: hash + required: true + schema: + example: blake3:a1b2c3d4e5f67890123456789012345678901234567890123456789012345678 + pattern: ^blake3:[a-f0-9]{64}$ + type: string + responses: + "200": + description: Asset exists + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid hash format + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + description: Asset not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Check if asset exists by hash + tags: + - file + /api/assets/prune: + post: + description: Starts a background job that marks as missing any asset entries whose underlying content no longer exists on disk. + operationId: pruneAssets + responses: + "200": + content: + application/json: + schema: + properties: + marked: + description: Number of assets marked as missing + type: integer + status: + type: string + type: object + description: Prune result + summary: Mark assets whose backing files no longer exist on disk + /api/assets/seed: + post: + description: Starts a background job that scans configured directories and registers assets not yet in the asset database. + operationId: seedAssets + requestBody: + content: + application/json: + schema: + properties: + roots: + description: Root folder paths to scan (if omitted, scans all) + items: + type: string + type: array + type: object + responses: + "200": + content: + application/json: + schema: + properties: + status: + type: string + type: object + description: Seed started + summary: Trigger asset scan/seed from filesystem + /api/assets/seed/cancel: + post: + description: Requests cancellation of the currently-running asset seed job.
+ operationId: cancelAssetSeed + responses: + "200": + content: + application/json: + schema: + properties: + status: + type: string + type: object + description: Scan cancelled + summary: Cancel an in-progress asset scan + /api/assets/seed/status: + get: + description: Returns progress/status of the most recent asset seed job. + operationId: getAssetSeedStatus + responses: + "200": + content: + application/json: + schema: + additionalProperties: true + description: Scan progress details (files scanned, total, status, etc.) + type: object + description: Scan progress + summary: Get asset scan progress + /api/assets/tags/refine: + get: + description: | + Returns a histogram of tags appearing on assets matching the given filters. + Useful for refining asset searches by showing available tags and their counts. + Only returns tags with non-zero counts (tags that exist on matching assets). + operationId: getAssetTagHistogram + parameters: + - description: Filter assets that have ALL of these tags + explode: false + in: query + name: include_tags + schema: + items: + type: string + type: array + style: form + - description: Exclude assets that have ANY of these tags + explode: false + in: query + name: exclude_tags + schema: + items: + type: string + type: array + style: form + - description: Filter assets where name contains this substring (case-insensitive) + in: query + name: name_contains + schema: + type: string + - description: JSON object for filtering by metadata fields + in: query + name: metadata_filter + schema: + type: string + - description: Maximum number of tags to return (1-1000, default 100) + in: query + name: limit + schema: + default: 100 + maximum: 1000 + minimum: 1 + type: integer + - description: Whether to include public/shared assets in results + in: query + name: include_public + schema: + default: true + type: boolean + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/AssetTagHistogramResponse' + 
description: Success - Tag histogram returned + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request parameters + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get tag histogram for filtered assets + tags: + - file + /api/embeddings: + get: + description: Returns the list of text-encoder embeddings available on disk. + operationId: getEmbeddings + responses: + "200": + content: + application/json: + schema: + items: + type: string + type: array + description: Embedding names + summary: List available embedding names + /api/experiment/models: + get: + description: | + Returns a list of model folders available in the system. + This is an experimental endpoint that replaces the legacy /models endpoint. + operationId: getModelFolders + responses: + "200": + content: + application/json: + schema: + items: + $ref: '#/components/schemas/ModelFolder' + type: array + description: Success - List of model folders + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + security: [] + summary: Get available model folders + tags: + - file + /api/experiment/models/{folder}: + get: + description: | + Returns a list of models available in the specified folder. + This is an experimental endpoint that provides enhanced model information. 
+ operationId: getModelsInFolder + parameters: + - description: The folder name to list models from + in: path + name: folder + required: true + schema: + example: checkpoints + type: string + responses: + "200": + content: + application/json: + schema: + items: + $ref: '#/components/schemas/ModelFile' + type: array + description: Success - List of models in the folder + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Folder not found or no models in folder + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + security: [] + summary: Get models in a specific folder + tags: + - file + /api/extensions: + get: + description: | + Returns the list of custom node web extension JS files available for + loading by the ComfyUI frontend. Paths are relative to the web root + (e.g. `/extensions/VHS.core.js`). + operationId: getExtensions + responses: + "200": + content: + application/json: + schema: + description: URL paths (relative to web root) of available extension JS files + items: + type: string + type: array + description: JSON array of extension file paths + security: [] + summary: List custom node JS extensions + tags: + - node + /api/features: + get: + description: Returns the server's feature capabilities + operationId: getFeatures + responses: + "200": + content: + application/json: + schema: + additionalProperties: true + properties: + max_upload_size: + description: Maximum upload size in bytes + type: integer + supports_preview_metadata: + description: Whether the server supports preview metadata + type: boolean + type: object + description: Success + headers: + Cache-Control: + description: Short-lived private cache to deduplicate rapid-fire calls from the frontend + schema: + type: string + Vary: + description: Cache key includes auth headers so anonymous and authenticated responses are stored separately + schema: + type: 
string + security: + - ApiKeyAuth: [] + - BearerAuth: [] + - CookieAuth: [] + - {} + summary: Get server feature flags + tags: + - node + /api/feedback: + post: + description: Submit feedback about the ComfyUI service + operationId: submitFeedback + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/FeedbackRequest' + required: true + responses: + "201": + content: + application/json: + schema: + $ref: '#/components/schemas/FeedbackResponse' + description: Feedback submitted successfully + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Submit user feedback + tags: + - feedback + /api/files/mask-layers: + get: + description: | + Given a mask file (any of the 4 layers), returns all related mask layer files. + This is used by the mask editor to load the paint, mask, and painted layers + when reopening a previously edited mask. 
+ operationId: getMaskLayers + parameters: + - description: Hash filename of any mask layer file + in: query + name: filename + required: true + schema: + example: abc123def456.png + type: string + responses: + "200": + content: + application/json: + schema: + properties: + mask: + description: Filename of the mask layer + nullable: true + type: string + paint: + description: Filename of the paint strokes layer + nullable: true + type: string + painted: + description: Filename of the painted image layer + nullable: true + type: string + painted_masked: + description: Filename of the final composite layer + nullable: true + type: string + type: object + description: Success - Related mask layers returned + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: File not found or not a mask file + summary: Get related mask layer files + tags: + - file + /api/free: + post: + description: Frees GPU memory by unloading models and/or freeing the resident model cache. + operationId: freeMemory + requestBody: + content: + application/json: + schema: + properties: + free_memory: + description: Run garbage collection and free cached memory + type: boolean + unload_models: + description: Unload all models from VRAM/RAM + type: boolean + type: object + responses: + "200": + description: Memory freed + summary: Free GPU memory and/or unload models + /api/global_subgraphs: + get: + description: | + Returns a list of globally available subgraph blueprints. + These are pre-built workflow components that can be used as nodes. + The data field contains a promise that resolves to the full subgraph JSON. 
+ operationId: getGlobalSubgraphs + responses: + "200": + content: + application/json: + schema: + additionalProperties: + $ref: '#/components/schemas/GlobalSubgraphInfo' + type: object + description: Success - Map of subgraph IDs to their metadata + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + security: [] + summary: Get available subgraph blueprints + tags: + - workflow + /api/global_subgraphs/{id}: + get: + description: Returns the full data for a specific subgraph blueprint by ID + operationId: getGlobalSubgraph + parameters: + - description: The unique identifier of the subgraph blueprint + in: path + name: id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/GlobalSubgraphData' + description: Success - Full subgraph data + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Subgraph not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + security: [] + summary: Get a specific subgraph blueprint + tags: + - workflow + /api/history: + post: + deprecated: true + description: | + **Deprecated.** Superseded by the job-management endpoints under + `/api/jobs`. Planned for removal no earlier than a future major + release; sunset timeline TBD. - PendingInvite: - type: object - x-runtime: [cloud] - description: '[cloud-only] An outstanding workspace invitation that has not yet been accepted.' - required: - - id - - email - - invited_at - - expires_at - properties: - id: - type: string - description: Invite ID - email: - type: string - format: email - description: Email address of the invited user - token: - type: string - description: Invite token for constructing invite links. Empty for expired invites. 
- invited_at: - type: string - format: date-time - description: When the invite was created - expires_at: - type: string - format: date-time - description: When the invite expires + Clears all execution history for the authenticated user, or deletes only + the history entries for the specified job IDs. + operationId: manageHistory + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/HistoryManageRequest' + required: true + responses: + "200": + description: Success - History management operation completed + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request parameters + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized - Authentication required + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Manage execution history + tags: + - workflow + /api/history_v2: + get: + deprecated: true + description: | + **Deprecated.** Superseded by `GET /api/jobs`, which returns the same + execution records in a paginated, filterable format. Planned for removal + no earlier than a future major release; sunset timeline TBD. - Plan: - type: object - x-runtime: [cloud] - description: '[cloud-only] Billing plan details including pricing, limits, and features.'
- required: - - slug - - tier - - duration - - price_cents - - credits_cents - - max_seats - - availability - - seat_summary - properties: - slug: - type: string - description: Plan identifier (e.g., "pro-monthly", "team-standard-annual") - example: pro-monthly - tier: - $ref: '#/components/schemas/SubscriptionTier' - duration: - $ref: '#/components/schemas/SubscriptionDuration' - price_cents: - type: integer - format: int64 - description: Per-member price in cents (base + one seat) - example: 10000 - credits_cents: - type: integer - format: int64 - description: Per-member credits in cents (base + one seat) - example: 10000 - max_seats: - type: integer - format: int64 - description: Maximum number of seats allowed for this plan - example: 20 - availability: - $ref: '#/components/schemas/PlanAvailability' - seat_summary: - $ref: '#/components/schemas/PlanSeatSummary' + Retrieve execution history for the authenticated user with pagination support. + Returns a lightweight history format with filtered prompt data (workflow removed from extra_pnginfo). 
+ operationId: getHistory + parameters: + - description: Maximum number of items to return + in: query + name: max_items + schema: + type: integer + - description: Starting position (default 0) + in: query + name: offset + schema: + default: 0 + type: integer + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/HistoryResponse' + description: Success - Execution history retrieved + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized - Authentication required + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get execution history (v2) + tags: + - workflow + /api/history_v2/{prompt_id}: + get: + deprecated: true + description: | + **Deprecated.** Superseded by `GET /api/jobs/{job_id}`, which returns + the same execution record. Planned for removal no earlier than a future + major release; sunset timeline TBD. - PlanAvailability: - type: object - x-runtime: [cloud] - description: '[cloud-only] Availability and eligibility information for a billing plan.' - required: - - available - properties: - available: - type: boolean - description: Whether the workspace can subscribe to this plan - reason: - $ref: '#/components/schemas/PlanAvailabilityReason' + Retrieve detailed execution history for a specific prompt ID. + Returns full history data including complete prompt information. 
+ operationId: getHistoryForPrompt + parameters: + - description: The prompt ID to retrieve history for + in: path + name: prompt_id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/HistoryDetailResponse' + description: Success - History for prompt retrieved + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized - Authentication required + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Prompt not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get history for specific prompt + tags: + - workflow + /api/i18n: + get: + description: Returns translation file URLs contributed by custom nodes, keyed by locale. + operationId: getI18n + responses: + "200": + content: + application/json: + schema: + additionalProperties: true + description: Nested map of locale to translation key-value pairs + type: object + description: Translation map + summary: Get internationalisation translation strings + /api/interrupt: + post: + description: | + Cancel all currently RUNNING jobs for the authenticated user. + This will interrupt any job that is currently in 'in_progress' status. + Note: This endpoint only affects running jobs. To cancel pending jobs, use /api/queue. 
+ operationId: interruptJob + responses: + "200": + description: Success - Job interrupted or no running job found + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized - Authentication required + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Interrupt currently running jobs + tags: + - queue + /api/job/{job_id}/status: + get: + deprecated: true + description: | + **Deprecated.** Superseded by `GET /api/jobs/{job_id}` (plural path). + Clients should migrate; the endpoint is retained for backward + compatibility but will be removed in a future release. + operationId: getJobStatus + parameters: + - description: The unique ID of the job + in: path + name: job_id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/JobStatusResponse' + description: Success - Job status returned + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "403": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Forbidden - job belongs to another user + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Job not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get job status (deprecated) + tags: + - job + /api/jobs: + get: + description: | + Retrieve a paginated list of jobs for the authenticated user. + Returns lightweight job data optimized for list views. + Workflow and full outputs are excluded to reduce payload size. + operationId: listJobs + parameters: + - description: Filter by one or more statuses (comma-separated). If not provided, returns all jobs. 
+ example: pending,in_progress + in: query + name: status + schema: + type: string + - description: Filter by workflow ID (exact match) + example: 550e8400-e29b-41d4-a716-446655440000 + in: query + name: workflow_id + schema: + type: string + - description: Filter by output media type (only applies to completed jobs with outputs) + example: image + in: query + name: output_type + schema: + enum: + - image + - video + - audio + - 3d + type: string + - description: Field to sort by (create_time = when job was submitted, execution_time = how long workflow took to run) + example: execution_time + in: query + name: sort_by + schema: + default: create_time + enum: + - create_time + - execution_time + type: string + - description: Sort direction (asc = ascending, desc = descending) + in: query + name: sort_order + schema: + default: desc + enum: + - asc + - desc + type: string + - description: Pagination offset (0-based) + in: query + name: offset + schema: + default: 0 + minimum: 0 + type: integer + - description: Maximum items per page (1-1000) + in: query + name: limit + schema: + default: 100 + maximum: 1000 + minimum: 1 + type: integer + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/JobsListResponse' + description: Success - Jobs retrieved + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized - Authentication required + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: List jobs with pagination and filtering + tags: + - workflow + /api/jobs/{job_id}: + get: + description: | + Retrieve complete details for a specific job including workflow and outputs. + Used for detail views, workflow re-execution, and debugging. 
+ operationId: getJobDetail + parameters: + - description: Job identifier (UUID) + in: path + name: job_id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/JobDetailResponse' + description: Success - Job details retrieved + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized - Authentication required + "403": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Forbidden - Job does not belong to user + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Job not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get full job details + tags: + - workflow + /api/jobs/{job_id}/cancel: + post: + description: | + Cancel a specific job for the authenticated user. - PlanAvailabilityReason: - type: string - x-runtime: [cloud] - enum: - - same_plan - - incompatible_transition - - requires_team - - requires_personal - - exceeds_max_seats - description: '[cloud-only] Reason why a plan is unavailable' + Idempotent: a job that is already in a terminal state (completed, failed, + cancelled) or already cancelling is treated as a successful no-op and + returns 200. Only truly missing or cross-user jobs return 404. 
+ operationId: cancelJob + parameters: + - description: Job identifier (UUID) + in: path + name: job_id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/JobCancelResponse' + description: Success - Cancel request accepted (or job was already terminal) + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Bad Request - job_id is not a valid UUID (emitted by request validation before the handler runs) + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized - Authentication required + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Job not found for this user + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error - cancellation failed + summary: Cancel a job + tags: + - workflow + /api/node_replacements: + get: + description: | + Returns mappings of unsupported node class names to their cloud-installed replacements. + Used by the frontend to offer "Quick Fix" when a workflow contains missing nodes. 
+ operationId: getNodeReplacements + responses: + "200": + content: + application/json: + schema: + additionalProperties: true + type: object + description: Success - Node replacement mappings + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + security: [] + summary: Get node replacement mappings + tags: + - node + /api/object_info: + get: + description: Returns information about all available nodes + operationId: getNodeInfo + responses: + "200": + content: + application/json: + schema: + additionalProperties: + $ref: '#/components/schemas/NodeInfo' + type: object + description: Success + summary: Get all node information + tags: + - node + /api/prompt: + get: + description: Returns information about the current prompt in the execution queue + operationId: getPromptInfo + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/PromptInfo' + description: Success + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get information about current prompt execution + tags: + - workflow + post: + description: | + Submit a workflow to be executed by the backend. + The workflow is a JSON object describing the nodes and their connections. 
+ operationId: executePrompt + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PromptRequest' + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/PromptResponse' + description: Success - Prompt accepted + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/PromptErrorResponse' + description: Invalid prompt + "402": + content: + application/json: + schema: + $ref: '#/components/schemas/PromptErrorResponse' + description: Payment required - Insufficient credits + "429": + content: + application/json: + schema: + $ref: '#/components/schemas/PromptErrorResponse' + description: Payment required - User has not paid + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/PromptErrorResponse' + description: Internal server error + "503": + content: + application/json: + schema: + $ref: '#/components/schemas/PromptErrorResponse' + description: Service unavailable + summary: Submit a workflow for execution + tags: + - workflow + /api/queue: + get: + description: Returns information about running and pending items in the queue + operationId: getQueueInfo + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/QueueInfo' + description: Success + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request parameters + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get queue information + tags: + - queue + post: + description: | + Cancel specific PENDING jobs by ID or clear all pending jobs in the queue. + Note: This endpoint only affects pending jobs. To cancel running jobs, use /api/interrupt. 
+ operationId: manageQueue + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/QueueManageRequest' + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/QueueManageResponse' + description: Success + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request parameters + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Manage queue operations + tags: + - queue + /api/settings: + get: + description: Returns all settings for the authenticated user + operationId: getAllSettings + responses: + "200": + content: + application/json: + schema: + additionalProperties: true + description: User settings as key-value pairs + type: object + description: Success + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + summary: Get all user settings + tags: + - settings + post: + description: Update multiple settings (merge with existing) + operationId: updateMultipleSettings + requestBody: + content: + application/json: + schema: + additionalProperties: true + description: Settings to update as key-value pairs + type: object + text/plain: + schema: + description: JSON string of settings to update + type: string + required: true + responses: + "200": + content: + application/json: + schema: + additionalProperties: true + description: Updated user settings + type: object + description: Success + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + summary: 
Update multiple settings + tags: + - settings + /api/settings/{id}: + get: + description: Returns a specific setting value by its id + operationId: getSettingById + parameters: + - description: Setting id to retrieve + in: path + name: id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + description: Setting value response + properties: + value: + description: The setting value + type: object + description: Success + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Setting not found + summary: Get a specific setting by id + tags: + - settings + post: + description: Update a specific setting by its id + operationId: updateSettingById + parameters: + - description: Setting id to update + in: path + name: id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + description: New value for the setting + text/plain: + schema: + description: JSON string of the new setting value + type: string + required: true + responses: + "200": + content: + application/json: + schema: + description: Updated setting value response + properties: + value: + description: The updated setting value + type: object + description: Success + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + summary: Update a specific setting by id + tags: + - settings + /api/system_stats: + get: + description: Returns system statistics including ComfyUI version, device info, and system resources + operationId: getSystemStats + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/SystemStatsResponse' 
+ description: Success + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + security: [] + summary: Get system statistics + tags: + - system + /api/tags: + get: + description: | + Retrieves a list of all tags used across assets. + Includes usage counts and filtering options. + operationId: listTags + parameters: + - description: Filter tags by prefix + in: query + name: prefix + schema: + type: string + - description: Maximum number of tags to return (1-1000) + in: query + name: limit + schema: + default: 100 + maximum: 1000 + minimum: 1 + type: integer + - description: Number of tags to skip for pagination + in: query + name: offset + schema: + default: 0 + minimum: 0 + type: integer + - description: Sort order for tags + in: query + name: order + schema: + default: count_desc + enum: + - count_desc + - name_asc + type: string + - description: Include tags with zero usage count + in: query + name: include_zero + schema: + default: false + type: boolean + - description: Whether to include public/shared assets when counting tags + in: query + name: include_public + schema: + default: true + type: boolean + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/ListTagsResponse' + description: Tags retrieved successfully + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request parameters + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: List all tags + tags: + - file + /api/tasks: + get: + description: | + Retrieve a paginated list of background tasks for the authenticated user. + Supports filtering by task type, status, and creation time. 
+ operationId: listTasks + parameters: + - description: Filter by task type name (exact match) + example: model_upload + in: query + name: task_name + schema: + type: string + - description: Filter by idempotency key (exact match). For best performance, specify task_name as well. + example: upload-model-abc123 + in: query + name: idempotency_key + schema: + type: string + - description: Filter by one or more statuses (comma-separated) + example: created,running + in: query + name: status + schema: + type: string + - description: Filter tasks created after this timestamp (RFC3339 format) + example: "2024-01-01T00:00:00Z" + in: query + name: created_after + schema: + format: date-time + type: string + - description: Filter tasks created before this timestamp (RFC3339 format) + example: "2024-12-31T23:59:59Z" + in: query + name: created_before + schema: + format: date-time + type: string + - description: Sort direction (asc = ascending, desc = descending by create_time) + in: query + name: sort_order + schema: + default: desc + enum: + - asc + - desc + type: string + - description: Pagination offset (0-based) + in: query + name: offset + schema: + default: 0 + minimum: 0 + type: integer + - description: Maximum items per page (1-100) + in: query + name: limit + schema: + default: 20 + maximum: 100 + minimum: 1 + type: integer + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/TasksListResponse' + description: Success - Tasks retrieved + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized - Authentication required + "422": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Validation error - Invalid filter values + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: List background tasks + tags: + - task + 
/api/tasks/{task_id}: + get: + description: | + Retrieve full details for a specific background task. + operationId: getTask + parameters: + - description: Task identifier (UUID) + in: path + name: task_id + required: true + schema: + format: uuid + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/TaskResponse' + description: Success - Task details retrieved + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized - Authentication required + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Task not found (also returned for ownership failures to avoid leaking task existence) + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get task details + tags: + - task + /api/upload/image: + post: + description: | + Upload an image file to cloud storage. 
- PlanSeatSummary: - type: object - x-runtime: [cloud] - description: '[cloud-only] Summary of seat costs based on current workspace members' - required: - - seat_count - - total_cost_cents - - total_credits_cents - properties: - seat_count: - type: integer - description: Total number of seats (owner + members) that would be charged - example: 5 - total_cost_cents: - type: integer - format: int64 - description: Total cost for all seats in cents - example: 50000 - total_credits_cents: - type: integer - format: int64 - description: Total credits granted for all seats in cents - example: 50000 + Image limits: + - Maximum file size: 50 MB + - Maximum width/height per edge: 16384 px + - Maximum total pixel count: 64 megapixels (67108864 pixels) - PreviewPlanInfo: - type: object - x-runtime: [cloud] - description: '[cloud-only] Plan information for preview display' - required: - - slug - - tier - - duration - - price_cents - - credits_cents - - seat_summary - properties: - slug: - type: string - description: Plan slug - example: team-pro-monthly - tier: - $ref: '#/components/schemas/SubscriptionTier' - duration: - $ref: '#/components/schemas/SubscriptionDuration' - price_cents: - type: integer - format: int64 - description: Per-seat price in cents - example: 10000 - credits_cents: - type: integer - format: int64 - description: Per-seat credits in cents - example: 10000 - seat_summary: - $ref: '#/components/schemas/PlanSeatSummary' - period_start: - type: string - format: date-time - description: Current billing period start (only for current_plan) - period_end: - type: string - format: date-time - description: Current billing period end (only for current_plan) + Uploads that exceed any of these limits are rejected with HTTP 400. 
+ operationId: uploadImage + requestBody: + content: + multipart/form-data: + schema: + properties: + image: + description: The image file to upload + format: binary + type: string + overwrite: + description: Whether to overwrite existing file (true/false) + type: string + subfolder: + description: Optional subfolder path + type: string + type: + description: Upload type (defaults to "output") + type: string + required: + - image + type: object + required: true + responses: + "200": + content: + application/json: + schema: + properties: + name: + description: Filename of the uploaded image + type: string + subfolder: + description: Subfolder path where image was saved + type: string + type: + description: Type of upload (e.g., "output") + type: string + type: object + description: Image uploaded successfully + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Bad request + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Upload an image file + tags: + - file + /api/upload/mask: + post: + description: | + Upload a mask image to be applied to an existing image. - PreviewSubscribeResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Itemized cost preview for a pending subscription change.' 
- required: - - allowed - - transition_type - - effective_at - - is_immediate - - cost_today_cents - - cost_next_period_cents - - credits_today_cents - - credits_next_period_cents - - new_plan - properties: - allowed: - type: boolean - description: Whether this subscription change is allowed - reason: - type: string - description: Reason why the change is not allowed (only present if allowed=false) - transition_type: - type: string - enum: - - new_subscription - - upgrade - - downgrade - - duration_change - description: Type of subscription transition - effective_at: - type: string - format: date-time - description: When the change takes effect - is_immediate: - type: boolean - description: Whether the change takes effect immediately (true) or at period end (false) - cost_today_cents: - type: integer - format: int64 - description: Amount to charge today in cents (0 for downgrades) - example: 5000 - cost_next_period_cents: - type: integer - format: int64 - description: Amount that will be charged at next billing period in cents - example: 10000 - credits_today_cents: - type: integer - format: int64 - description: Credits granted today in cents (prorated for mid-period upgrades) - example: 5000 - credits_next_period_cents: - type: integer - format: int64 - description: Credits that will be granted at next billing period in cents - example: 10000 - current_plan: - $ref: '#/components/schemas/PreviewPlanInfo' - new_plan: - $ref: '#/components/schemas/PreviewPlanInfo' + Image limits apply to both the uploaded mask and the referenced + original image: + - Maximum file size: 50 MB + - Maximum width/height per edge: 16384 px + - Maximum total pixel count: 64 megapixels (67108864 pixels) - PublishedWorkflowDetail: - type: object - x-runtime: [cloud] - description: '[cloud-only] Full detail of a publicly published workflow on the Hub.' 
- required: - - share_id - - workflow_id - - name - - listed - - workflow_json - - assets - properties: - share_id: - type: string - workflow_id: - type: string - name: - type: string - description: Human-readable workflow name. - listed: - type: boolean - publish_time: - type: string - format: date-time - nullable: true - workflow_json: - type: object - additionalProperties: true - description: The workflow JSON content at publish time. - assets: - type: array - description: Published assets with their library status for the caller. - items: - $ref: '#/components/schemas/AssetInfo' - - SecretResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] User secret metadata (the secret value itself is never returned after creation).' - required: - - id - - name - - created_at - - updated_at - properties: - id: - type: string - format: uuid - description: Unique identifier for the secret - name: - type: string - description: User-provided label for the secret - provider: - type: string - description: Provider identifier (e.g., huggingface, civitai) - last_used_at: - type: string - format: date-time - description: When the secret was last used for decryption - created_at: - type: string - format: date-time - description: When the secret was created - updated_at: - type: string - format: date-time - description: When the secret was last updated - - SubscriptionDuration: - type: string - x-runtime: [cloud] - enum: - - MONTHLY - - ANNUAL - description: '[cloud-only] Billing period (uppercase to match comfy-api)' - - SubscriptionTier: - type: string - x-runtime: [cloud] - enum: - - FREE - - STANDARD - - CREATOR - - PRO - - FOUNDERS_EDITION - description: '[cloud-only] Subscription tier (uppercase to match comfy-api)' - - UserDataResponseFull: - type: object - x-runtime: [cloud] - description: '[cloud-only] User data listing entry with file metadata (path, size, modification time).' 
- properties: - path: - type: string - size: - type: integer - modified: - type: integer - format: int64 - description: UNIX timestamp of the last modification in milliseconds. - - ValidationError: - type: object - x-runtime: [cloud] - description: '[cloud-only] Details of a single validation error encountered during asset operations.' - required: - - code - - message - - field - properties: - code: - type: string - description: Machine-readable error code - example: FORMAT_NOT_ALLOWED - message: - type: string - description: Human-readable error message - example: 'File format "PickleTensor" is not allowed. Allowed formats: [SafeTensor]' - field: - type: string - description: Field that failed validation - example: format - - ValidationResult: - type: object - x-runtime: [cloud] - description: '[cloud-only] Result of validating a set of asset operations.' - required: - - is_valid - properties: - is_valid: - type: boolean - description: Overall validation status (true if all checks passed) - example: true - errors: - type: array - items: - $ref: '#/components/schemas/ValidationError' - description: Blocking validation errors that prevent download - warnings: - type: array - items: - $ref: '#/components/schemas/ValidationError' - description: Non-blocking validation warnings (informational only) - - WorkflowForkedFrom: - type: object - x-runtime: [cloud] - description: '[cloud-only] Reference to the parent workflow from which this workflow was forked.' - properties: - workflow_id: - type: string - workflow_version_id: - type: string - - WorkflowResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Full workflow entity including metadata and version history.' 
- required: - - id - - latest_version - - created_by - - created_at - - updated_at - properties: - id: - type: string - name: - type: string - description: - type: string - default_view: - type: string - enum: - - workflow - - app - latest_version: - type: integer - forked_from: - $ref: '#/components/schemas/WorkflowForkedFrom' - created_by: - type: string - created_at: - type: string - format: date-time - updated_at: - type: string - format: date-time - - WorkflowVersionContentResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Full workflow version including the serialized workflow JSON.' - required: - - id - - version - - workflow_json - - created_by - - created_at - properties: - id: - type: string - version: - type: integer - workflow_json: - type: object - additionalProperties: true - created_by: - type: string - created_at: - type: string - format: date-time - dependency_asset_ids: - type: array - items: - type: string - - WorkspaceAPIKeyInfo: - type: object - x-runtime: [cloud] - description: '[cloud-only] Metadata for a workspace-scoped API key (secret is never returned).' - required: - - id - - workspace_id - - user_id - - name - - description - - key_prefix - - created_at - properties: - id: - type: string - format: uuid - description: API key ID - workspace_id: - type: string - description: Workspace this key belongs to - user_id: - type: string - description: User who created this key - name: - type: string - description: User-provided label - description: - type: string - description: User-provided description of the key's purpose. Limit is byte-based (UTF-8 encoding); 5000 bytes equals - 5000 ASCII characters or fewer multi-byte characters. 
- maxLength: 5000 - key_prefix: - type: string - description: First 8 chars after prefix for display - expires_at: - type: string - format: date-time - description: When the key expires (if set) - last_used_at: - type: string - format: date-time - description: Last time the key was used - revoked_at: - type: string - format: date-time - description: When the key was revoked (if revoked) - created_at: - type: string - format: date-time - description: When the key was created - - WorkspaceSummary: - type: object - x-runtime: [cloud] - description: '[cloud-only] Abbreviated workspace metadata used in list responses.' - required: - - id - - name - - type - properties: - id: - type: string - example: w-a1b2c3d4-5678-90ab-cdef-1234567890ab - name: - type: string - example: My Team - type: - type: string - enum: - - personal - - team - - WorkspaceWithRole: - type: object - x-runtime: [cloud] - description: '[cloud-only] Workspace entity annotated with the requesting user''s role.' - required: - - id - - name - - type - - role - - created_at - - joined_at - properties: - id: - type: string - example: w-a1b2c3d4-5678-90ab-cdef-1234567890ab - name: - type: string - example: My Team - type: - type: string - enum: - - personal - - team - role: - type: string - enum: - - owner - - member - created_at: - type: string - format: date-time - description: When the workspace was created - joined_at: - type: string - format: date-time - description: When the user joined the workspace (same as created_at for the workspace creator) - subscription_tier: - $ref: '#/components/schemas/SubscriptionTier' - - BindingErrorResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Error shape returned when request binding or validation fails before the handler runs.' 
- required: - - message - properties: - message: - type: string - - ErrorResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Standard error response from cloud endpoints with a machine-readable code and human-readable message.' - required: - - code - - message - properties: - code: - type: string - description: Machine-readable error code - message: - type: string - description: Human-readable error message - - AcceptInviteResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Response returned after successfully accepting a workspace invitation.' - required: - - workspace_id - - workspace_name - properties: - workspace_id: - type: string - description: ID of the workspace joined - workspace_name: - type: string - description: Name of the workspace joined - - BillingEventsResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Paginated list of billing events for a workspace.' - required: - - total - - events - - page - - limit - - totalPages - properties: - total: - type: integer - description: Total number of events - events: - type: array - items: - $ref: '#/components/schemas/BillingEvent' - page: - type: integer - description: Current page number (1-indexed) - limit: - type: integer - description: Items per page - totalPages: - type: integer - description: Total number of pages - - BillingOpStatusResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Status of an asynchronous billing operation.' 
- required: - - id - - status - - started_at - properties: - id: - type: string - description: Unique identifier for the billing operation - status: - type: string - enum: - - pending - - succeeded - - failed - description: Current status of the operation - error_message: - type: string - description: Error message if status is failed - started_at: - type: string - format: date-time - description: When the operation was initiated - completed_at: - type: string - format: date-time - description: When the operation completed (success or failure) - - CancelSubscriptionResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Response after successfully cancelling a subscription.' - required: - - cancel_at - - billing_op_id - properties: - billing_op_id: - type: string - description: Billing operation ID to poll for status via GET /api/billing/ops/{id} - cancel_at: - type: string - format: date-time - description: The date when the subscription will end (end of current billing period) - - CreateTopupResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Response after successfully purchasing a credit top-up.' - required: - - topup_id - - status - - amount_cents - - billing_op_id - properties: - billing_op_id: - type: string - description: Billing operation ID to poll for status via GET /api/billing/ops/{id} - topup_id: - type: string - description: Unique identifier for the top-up request (same as billing_op_id, deprecated) - status: - type: string - enum: - - pending - - completed - - failed - description: Current status of the top-up - amount_cents: - type: integer - format: int64 - description: Amount being charged in cents - - CreateWorkspaceAPIKeyResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Response containing the newly created workspace API key.' 
- required: - - id - - name - - description - - key - - key_prefix - - created_at - properties: - id: - type: string - format: uuid - description: API key ID - name: - type: string - description: User-provided label - description: - type: string - description: User-provided description of the key's purpose. Limit is byte-based (UTF-8 encoding); 5000 bytes equals - 5000 ASCII characters or fewer multi-byte characters. - maxLength: 5000 - key: - type: string - description: The full plaintext API key (only shown once) - key_prefix: - type: string - description: First 8 chars after prefix for display - expires_at: - type: string - format: date-time - description: When the key expires (if set) - created_at: - type: string - format: date-time - description: When the key was created - - ExchangeTokenResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Response containing the issued Cloud JWT and its expiry.' - required: - - token - - expires_at - - workspace - - role - - permissions - properties: - token: - type: string - description: Cloud JWT token - expires_at: - type: string - format: date-time - description: Token expiration time (RFC 3339) - workspace: - $ref: '#/components/schemas/WorkspaceSummary' - role: - type: string - enum: - - owner - - member - description: User's role in the workspace - permissions: - type: array - items: - type: string - description: Permission strings for the role - example: - - owner:* - - JobCancelResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Response for POST /api/jobs/{job_id}/cancel. Returned on both fresh cancels and idempotent no-ops.' 
- required: - - cancelled - properties: - cancelled: - type: boolean - description: "True when a cancel event was successfully dispatched by this call.\nFalse when the job was already in\ - \ a terminal or cancelling state,\nin which case the call is a no-op (still 200 \u2014 idempotent).\n" - - ResubscribeResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Response after successfully resubscribing to a billing plan.' - required: - - status - - billing_op_id - properties: - billing_op_id: - type: string - description: Billing operation ID to poll for status via GET /api/billing/ops/{id} - status: - type: string - enum: - - active - description: The subscription status after resubscribing - message: - type: string - description: Human-readable confirmation message - - SubscribeResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Response after successfully subscribing to a billing plan.' - required: - - status - - billing_op_id - properties: - billing_op_id: - type: string - description: Billing operation ID to poll for status via GET /api/billing/ops/{id} - status: - type: string - enum: - - subscribed - - needs_payment_method - - pending_payment - description: 'Status of the subscription operation: - - - subscribed: Subscription is active immediately - - - needs_payment_method: User must add payment method via payment_method_url - - - pending_payment: Upgrade initiated, waiting for payment to complete - - ' - effective_at: - type: string - format: date-time - description: When the subscription became/becomes active (present when status=subscribed or pending_payment) - payment_method_url: - type: string - description: URL to redirect user to add payment method (present when status=needs_payment_method) - - UserResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] User information response' - required: - - id - - status - properties: - id: - type: string - description: Firebase UID of the 
authenticated user - status: - type: string - description: User status (always "active" for authenticated users) - - WorkflowListResponse: - type: object - x-runtime: [cloud] - description: '[cloud-only] Paginated list of saved workflows.' - required: - - data - - pagination - properties: - data: - type: array - items: - $ref: '#/components/schemas/WorkflowResponse' - pagination: - $ref: '#/components/schemas/PaginationInfo' - - FeedbackRequest: - type: object - x-runtime: [cloud] - description: "[cloud-only] User feedback submission body." - required: - - message - properties: - type: - type: string - enum: - - missing_nodes - - general - - missing_models - description: Feedback category - category: - type: string - description: Additional category metadata - message: - type: string - description: User-provided feedback message + Uploads that exceed any of these limits are rejected with HTTP 400. + operationId: uploadMask + requestBody: + content: + multipart/form-data: + schema: + properties: + image: + description: The mask image file to upload + format: binary + type: string + original_ref: + description: JSON string containing reference to the original image + type: string + required: + - image + - original_ref + type: object + required: true + responses: + "200": + content: + application/json: + schema: + properties: + name: + description: Filename of the uploaded mask + type: string + subfolder: + description: Subfolder path where mask was saved + type: string + type: + description: Type of upload (e.g., "output") + type: string + type: object + description: Mask uploaded successfully + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Bad request + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + 
summary: Upload a mask image + tags: + - file + /api/user: + get: + description: Returns information about the currently authenticated user + operationId: getUser + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/UserResponse' + description: Success + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + summary: Get current user information + tags: + - user + /api/userdata: + get: + description: Returns a list of user data files in the specified directory, optionally recursively and with full metadata. + operationId: getUserdata + parameters: + - description: The directory to list files from. + in: query + name: dir + schema: + type: string + - description: Whether to list files recursively. + in: query + name: recurse + schema: + default: false + type: boolean + - description: Whether to split file information by type. + in: query + name: split + schema: + default: false + type: boolean + - description: Whether to return full file metadata. + in: query + name: full_info + schema: + default: false + type: boolean + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/GetUserDataResponseFull' + description: A list of user data files. + "400": + content: + text/plain: + schema: + type: string + description: Bad request (e.g., invalid filename). + "401": + content: + text/plain: + schema: + type: string + description: Unauthorized. + "404": + content: + text/plain: + schema: + type: string + description: File not found or invalid path. + "500": + content: + text/plain: + schema: + type: string + description: General error + summary: List user data files + tags: + - user + /api/userdata/{file}: + delete: + description: | + Delete a user data file from the database. The file parameter should be + the relative path within the user's data directory. 
+ operationId: deleteUserdataFile + parameters: + - description: The file path to delete (URL encoded if necessary). + in: path + name: file + required: true + schema: + type: string + responses: + "204": + description: File deleted successfully (No Content). + "401": + content: + text/plain: + schema: + type: string + description: Unauthorized. + "404": + content: + text/plain: + schema: + type: string + description: File not found. + "500": + content: + text/plain: + schema: + type: string + description: Internal server error. + summary: Delete a user data file + tags: + - user + get: + description: Returns the requested user data file if it exists. + operationId: getUserdataFile + parameters: + - description: The filename of the user data to retrieve. + in: path + name: file + required: true + schema: + type: string + responses: + "200": + content: + application/octet-stream: + schema: + format: binary + type: string + description: Successfully retrieved the file. + "400": + content: + text/plain: + schema: + type: string + description: Bad request (e.g., invalid filename). + "401": + content: + text/plain: + schema: + type: string + description: Unauthorized. + "404": + content: + text/plain: + schema: + type: string + description: File not found or invalid path. + "500": + content: + text/plain: + schema: + type: string + description: General error + summary: Get user data file + tags: + - user + post: + description: | + Upload a file to a user's data directory. Optional query parameters allow + control over overwrite behavior and response detail. + operationId: postUserdataFile + parameters: + - description: The target file path (URL encoded if necessary). + in: path + name: file + required: true + schema: + type: string + - description: If "false", prevents overwriting existing files. Defaults to "true". 
+ in: query + name: overwrite + schema: + default: "true" + enum: + - "true" + - "false" + type: string + - description: If "true", returns detailed file info; if "false", returns only the relative path. + in: query + name: full_info + schema: + default: "false" + enum: + - "true" + - "false" + type: string + requestBody: + content: + application/octet-stream: + schema: + format: binary + type: string + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/UserDataResponseFull' + description: File uploaded successfully. + "400": + content: + text/plain: + schema: + type: string + description: Missing or invalid 'file' parameter. + "401": + content: + text/plain: + schema: + type: string + description: Unauthorized. + "403": + content: + text/plain: + schema: + type: string + description: The requested path is not allowed. + "409": + content: + text/plain: + schema: + type: string + description: File already exists and overwrite is set to false. + "500": + content: + text/plain: + schema: + type: string + description: General error + summary: Upload or update a user data file + tags: + - user + /api/userdata/{file}/move/{dest}: + post: + description: | + Move or rename a file within a user's data directory, with options for + controlling overwrite behavior and response format. + operationId: moveUserdataFile + parameters: + - description: The source file path (URL encoded if necessary). + in: path + name: file + required: true + schema: + type: string + - description: The destination file path (URL encoded if necessary). + in: path + name: dest + required: true + schema: + type: string + - description: If "false", prevents overwriting existing files. Defaults to "true". 
+ in: query + name: overwrite + schema: + default: "true" + enum: + - "true" + - "false" + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/UserDataResponseFull' + description: File moved successfully. + "400": + content: + text/plain: + schema: + type: string + description: Missing or invalid parameters. + "401": + content: + text/plain: + schema: + type: string + description: Unauthorized. + "404": + content: + text/plain: + schema: + type: string + description: Source file not found. + "409": + content: + text/plain: + schema: + type: string + description: Destination file already exists and overwrite is set to false. + "500": + content: + text/plain: + schema: + type: string + description: General error + summary: Move or rename a user data file + tags: + - user + /api/userdata/{file}/publish: + get: + description: Returns the publish status and share info for a workflow identified by its userdata path. + operationId: getUserdataFilePublish + parameters: + - description: The workflow file path within the user's data directory (URL encoded if necessary). + in: path + name: file + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/WorkflowPublishInfo' + description: Publish info (publish_time is null if never published) + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Workflow not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get publish info for a workflow file + tags: + - workflows + post: + description: Creates a new published_workflow record from the latest version and snapshots the provided assets. 
+ operationId: postUserdataFilePublish + parameters: + - description: The workflow file path within the user's data directory (URL encoded if necessary). + in: path + name: file + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/PublishWorkflowAssetsRequest' + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/WorkflowPublishInfo' + description: Workflow published + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Bad request + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Workflow not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Publish a workflow file + tags: + - workflows + /api/users: + get: + description: | + ComfyUI legacy users endpoint. Returns information about how user + data is stored. In cloud this is always server-managed, so callers + receive a constant response indicating server-side storage. 
+ operationId: getUsersInfo + responses: + "200": + content: + application/json: + schema: + properties: + migrated: + description: Whether user data has been migrated (always true in cloud) + type: boolean + storage: + description: Where user data is stored (always "server" in cloud) + type: string + required: + - storage + - migrated + type: object + description: Userdata storage information + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + summary: ComfyUI userdata storage info + tags: + - user + /api/vhs/queryvideo: + get: + description: | + VHS custom node endpoint that returns metadata about a video file + (frame count, fps, resolution, duration). Currently returns default + placeholder values; real ffprobe integration is a follow-up. + operationId: getVhsQueryVideo + parameters: + - description: Name of the video file to query + in: query + name: filename + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + properties: + source: + description: Source video metadata + properties: + duration: + description: Duration in seconds + type: number + fps: + description: Frames per second + type: number + frames: + description: Total frame count + type: integer + size: + description: '[width, height] in pixels' + items: + type: integer + maxItems: 2 + minItems: 2 + type: array + required: + - size + - fps + - frames + - duration + type: object + required: + - source + type: object + description: Video metadata + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: | + Missing required query parameter. Produced by the oapi-codegen + wrapper via echo.NewHTTPError; the custom Echo HTTPErrorHandler + normalizes it to the standard ErrorResponse {code, message} shape + (BE-1178). 
+ "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + security: + - ApiKeyAuth: [] + - BearerAuth: [] + - CookieAuth: [] + summary: Query VHS video metadata + tags: + - file + /api/view: + get: + description: | + Retrieve and view a file from the ComfyUI file system. + This endpoint is typically used to view generated images or other output files. + Cookie auth is allowed on this endpoint because it's used by img/video tags in browsers. + operationId: viewFile + parameters: + - description: Name of the file to view + in: query + name: filename + required: true + schema: + example: ComfyUI_00004_.png + type: string + - description: Subfolder path where the file is located + in: query + name: subfolder + schema: + example: tests/foo/bar + type: string + - description: Type of file (e.g., output, input, temp) + in: query + name: type + schema: + example: output + type: string + - description: Full path to the file (used for temp files) + in: query + name: fullpath + schema: + type: string + - description: Format of the file + in: query + name: format + schema: + type: string + - description: Frame rate for video files + in: query + name: frame_rate + schema: + type: integer + - description: Workflow identifier + in: query + name: workflow + schema: + type: string + - description: Timestamp parameter + in: query + name: timestamp + schema: + example: 1234567890 + type: integer + - description: | + Image channel to extract from PNG images. + - 'rgb': Return only RGB channels (alpha set to fully opaque) + - 'a' or 'alpha': Return alpha channel as grayscale image + - If not specified, return original image unchanged via redirect + in: query + name: channel + schema: + example: rgb + type: string + - description: | + Maximum dimension (width or height) to resize the image to, preserving aspect ratio. + The image is fit within a res x res box. Returns a JPEG thumbnail. 
+ Only applies to raster image files (PNG, JPEG, WebP, GIF). + in: query + name: res + schema: + example: 256 + maximum: 1024 + minimum: 64 + type: integer + responses: + "200": + content: + image/jpeg: + schema: + description: Resized JPEG thumbnail (returned when res parameter is used) + format: binary + type: string + image/png: + schema: + description: Processed PNG image with extracted channel + format: binary + type: string + description: Success - File content returned (used when channel or res parameter is present) + "302": + description: Redirect to GCS signed URL + headers: + Cache-Control: + description: Cache directive for the redirect response + schema: + type: string + Location: + description: Signed URL to access the file in GCS + schema: + type: string + Vary: + description: Headers that affect response caching + schema: + type: string + "400": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Invalid request parameters + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: File not found or unauthorized + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + security: + - ApiKeyAuth: [] + - BearerAuth: [] + - CookieAuth: [] + summary: View a file + tags: + - file + /api/workflow_templates: + get: + description: Returns available workflow templates + operationId: getWorkflowTemplates + responses: + "200": + content: + application/json: + schema: + description: Empty object for workflow templates + type: object + description: Success + security: [] + summary: Get available workflow templates + tags: + - workflow + /api/workflows: + get: + description: Returns a paginated list of workflows for the authenticated user in the current workspace. 
+ operationId: listWorkflows + parameters: + - in: query + name: limit + schema: + default: 20 + maximum: 100 + type: integer + - in: query + name: offset + schema: + default: 0 + type: integer + - description: Search workflows by name (case-insensitive substring match) + in: query + name: name + schema: + type: string + - description: Filter by default view type + in: query + name: default_view + schema: + enum: + - workflow + - app + type: string + - description: Sort field + in: query + name: sort + schema: + default: create_time + enum: + - create_time + - update_time + - name + type: string + - description: Sort order + in: query + name: order + schema: + default: desc + enum: + - asc + - desc + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/WorkflowListResponse' + description: Success + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: List workflows + tags: + - workflows + post: + description: Creates a new workflow with its first version. 
+ operationId: createWorkflow + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CreateWorkflowRequest' + required: true + responses: + "201": + content: + application/json: + schema: + $ref: '#/components/schemas/WorkflowResponse' + description: Workflow created successfully + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "422": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Validation error + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Create a new workflow + tags: + - workflows + /api/workflows/{workflow_id}: + delete: + description: Soft-deletes a workflow. + operationId: deleteWorkflow + parameters: + - description: The UUID of the workflow to delete. + in: path + name: workflow_id + required: true + schema: + type: string + responses: + "204": + description: Workflow deleted successfully + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Workflow not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Delete workflow + tags: + - workflows + get: + description: Retrieves workflow metadata by ID. + operationId: getWorkflow + parameters: + - description: The UUID of the workflow. 
+ in: path + name: workflow_id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/WorkflowResponse' + description: Success + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "403": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Forbidden + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Workflow not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get workflow + tags: + - workflows + patch: + description: Updates mutable workflow metadata (name, description, default_view). + operationId: updateWorkflow + parameters: + - description: The UUID of the workflow to update. + in: path + name: workflow_id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/UpdateWorkflowRequest' + required: true + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/WorkflowResponse' + description: Success + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Workflow not found + "422": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Validation error + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Update workflow metadata + tags: + - workflows + /api/workflows/{workflow_id}/content: + get: + description: Retrieves the latest version of a workflow and its JSON content. 
+ operationId: getWorkflowContent + parameters: + - description: The UUID of the workflow whose content should be retrieved. + in: path + name: workflow_id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/WorkflowVersionContentResponse' + description: Success + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "403": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Forbidden + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Workflow not found + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get workflow content + tags: + - workflows + /api/workflows/{workflow_id}/fork: + post: + description: Creates a new workflow by forking from an existing version. + operationId: forkWorkflow + parameters: + - description: The UUID of the source workflow to fork from. 
+ in: path + name: workflow_id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ForkWorkflowRequest' + required: true + responses: + "201": + content: + application/json: + schema: + $ref: '#/components/schemas/WorkflowResponse' + description: Workflow forked successfully + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "403": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Forbidden + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Source workflow or version not found + "422": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Validation error + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Fork a workflow + tags: + - workflows + /api/workflows/{workflow_id}/versions: + post: + description: Creates a new workflow version with updated workflow JSON. Uses optimistic concurrency via base_version. + operationId: createWorkflowVersion + parameters: + - description: The UUID of the workflow to create a new version for. 
+ in: path + name: workflow_id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/CreateWorkflowVersionRequest' + required: true + responses: + "201": + content: + application/json: + schema: + $ref: '#/components/schemas/WorkflowVersionResponse' + description: Version created successfully + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "403": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Forbidden - not the workflow owner + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Workflow not found + "409": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Version conflict - base_version does not match latest + "422": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Validation error + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Create a new version + tags: + - workflows + /api/workflows/published/{share_id}: + get: + description: | + Returns the published workflow details including the status of each + published asset relative to the caller's library. Authentication is required. + operationId: getPublishedWorkflow + parameters: + - description: The share ID of the published workflow. 
+ in: path + name: share_id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/PublishedWorkflowDetail' + description: Published workflow details with asset statuses + "401": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Unauthorized + "404": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Share not found + "413": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Workflow JSON too large + "500": + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + description: Internal server error + summary: Get a published workflow by share ID + tags: + - workflows + /health: + get: + description: | + Returns `200 OK` if the database is reachable and dynamic config has + loaded, otherwise `503 Service Unavailable`. Used by the GKE ingress + for health checks. Response body is plain text for probe simplicity. + operationId: getHealth + responses: + "200": + content: + text/plain: + schema: + example: OK + type: string + description: Service is healthy + "503": + content: + text/plain: + schema: + example: Service Unavailable + type: string + description: Service is unhealthy + security: [] + summary: Health probe for Kubernetes readiness/liveness + tags: + - system + /internal/folder_paths: + get: + description: Returns the filesystem paths ComfyUI loads models and assets from, keyed by folder type. 
+ operationId: getInternalFolderPaths + responses: + "200": + content: + application/json: + schema: + additionalProperties: + items: + items: + type: string + type: array + type: array + description: Map of folder type name to list of path entries + type: object + description: Dictionary of folder type to paths + summary: Get configured folder paths + /internal/logs: + get: + description: Returns ComfyUI log entries from the in-memory log buffer. + operationId: getInternalLogs + responses: + "200": + content: + text/plain: + schema: + type: string + description: Log text + summary: Get server logs as text + /internal/logs/raw: + get: + description: Returns the raw ComfyUI log buffer plus size metadata. + operationId: getInternalLogsRaw + responses: + "200": + content: + application/json: + schema: + properties: + entries: + items: + properties: + m: + description: Message + type: string + t: + description: Timestamp + type: number + type: object + type: array + size: + properties: + cols: + type: integer + rows: + type: integer + type: object + type: object + description: Structured log data + summary: Get raw structured log entries + /internal/logs/subscribe: + patch: + description: Subscribes or unsubscribes the current client from live log streaming over the WebSocket. 
+ operationId: subscribeToLogs + requestBody: + content: + application/json: + schema: + properties: + clientId: + description: WebSocket client ID + type: string + enabled: + description: Enable or disable log streaming for this client + type: boolean + required: + - clientId + - enabled + type: object + required: true + responses: + "200": + description: Subscription updated + summary: Subscribe or unsubscribe a WebSocket client to log streaming +security: + - ApiKeyAuth: [] + - BearerAuth: [] +servers: + - description: Default ComfyUI server + url: / +tags: + - description: Workflow execution and management + name: workflow + - description: Node information + name: node + - description: File operations + name: file + - description: User settings management + name: settings + - description: User feedback management + name: feedback + - description: System operations and monitoring + name: system + - description: User information and management + name: user + - description: Background task management + name: task + - description: Workflow storage and version management + name: workflows + - description: Job queue state and control + name: queue + - description: Job lifecycle queries + name: job diff --git a/pyproject.toml b/pyproject.toml index 1e449b4a3..4107b4911 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.22.0" +version = "0.24.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index 2ca6d8929..a49d968af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -comfyui-frontend-package==1.44.19 -comfyui-workflow-templates==0.9.82 -comfyui-embedded-docs==0.5.1 +comfyui-frontend-package==1.45.15 +comfyui-workflow-templates==0.9.98 +comfyui-embedded-docs==0.5.3 torch torchsde torchvision @@ -21,9 +21,9 @@ psutil alembic SQLAlchemy>=2.0.0 filelock -av>=14.2.0 -comfy-kitchen>=0.2.8 -comfy-aimdo==0.4.5 +av>=16.0.0 
+comfy-kitchen==0.2.10 +comfy-aimdo==0.4.9 requests simpleeval>=1.0.0 blake3 diff --git a/tests-unit/comfy_extras_test/nodes_math_test.py b/tests-unit/comfy_extras_test/nodes_math_test.py index 714e37c32..030accc5e 100644 --- a/tests-unit/comfy_extras_test/nodes_math_test.py +++ b/tests-unit/comfy_extras_test/nodes_math_test.py @@ -197,3 +197,10 @@ class TestMathExpressionExecute: def test_pow_huge_exponent_raises(self): with pytest.raises(ValueError, match="Exponent .* exceeds maximum"): self._exec("pow(a, b)", a=10, b=10000000) + + def test_huge_int_result_raises_value_error(self): + # Exponent is within the allowed MAX_EXPONENT range, so the result is a + # finite Python int that is nonetheless too large to convert to float. + # This must raise a clean ValueError, not an uncaught OverflowError. + with pytest.raises(ValueError, match="too large to represent as a float"): + self._exec("2 ** 3999") diff --git a/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py new file mode 100644 index 000000000..2a6e3d430 --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_conditioning.py @@ -0,0 +1,213 @@ +"""Consolidated SeedVR2 conditioning and refactor regression tests. + +Merges the prior test_seedvr2_refactor_nodes.py and +test_seedvr_conditioning_hardening.py modules. Refactor tests use the +top-level comfy_extras.nodes_seedvr import; conditioning-hardening tests +use _import_nodes_seedvr_isolated() for sys.modules isolation when +mocking comfy.model_management. 
+""" + +import importlib +import sys +from unittest.mock import MagicMock + +import pytest +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + + +_SENTINEL = object() +_TARGETS = ( + ("comfy.model_management", "comfy"), + ("comfy_extras.nodes_seedvr", "comfy_extras"), +) + + +def _import_nodes_seedvr_isolated(): + """Import comfy_extras.nodes_seedvr with comfy.model_management mocked.""" + priors = [] + for mod_name, parent_name in _TARGETS: + prior_mod = sys.modules.get(mod_name, _SENTINEL) + parent = sys.modules.get(parent_name) + attr = mod_name.split(".")[-1] + prior_attr = ( + getattr(parent, attr, _SENTINEL) if parent is not None else _SENTINEL + ) + priors.append((mod_name, parent_name, attr, prior_mod, prior_attr)) + + mock_mm = MagicMock() + for fn in ( + "xformers_enabled", "xformers_enabled_vae", + "pytorch_attention_enabled", "pytorch_attention_enabled_vae", + "sage_attention_enabled", "flash_attention_enabled", + "is_intel_xpu", + ): + getattr(mock_mm, fn).return_value = False + tv = torch.version.__version__.split(".") + mock_mm.torch_version_numeric = (int(tv[0]), int(tv[1])) + mock_mm.WINDOWS = False + sys.modules["comfy.model_management"] = mock_mm + if sys.modules.get("comfy") is None: + import comfy as _comfy_pkg # noqa: F401 + comfy_pkg = sys.modules.get("comfy") + if comfy_pkg is not None: + setattr(comfy_pkg, "model_management", mock_mm) + nodes_seedvr = sys.modules.get("comfy_extras.nodes_seedvr") or ( + importlib.import_module("comfy_extras.nodes_seedvr") + ) + + def _restore(): + for mod_name, parent_name, attr, prior_mod, prior_attr in priors: + if prior_mod is _SENTINEL: + sys.modules.pop(mod_name, None) + else: + sys.modules[mod_name] = prior_mod + parent = sys.modules.get(parent_name) + if parent is None: + continue + if prior_attr is _SENTINEL: + if hasattr(parent, attr): + delattr(parent, attr) + else: + setattr(parent, attr, prior_attr) + + 
return nodes_seedvr, _restore + + +class _Rope(nn.Module): + """Minimal RoPE stub exposing a `freqs` parameter.""" + def __init__(self): + super().__init__() + self.freqs = nn.Parameter(torch.zeros(4)) + + +class _Block(nn.Module): + """Minimal transformer block stub holding a `_Rope`.""" + def __init__(self): + super().__init__() + self.rope = _Rope() + + +class _DiffusionModel(nn.Module): + """Stub diffusion model with N blocks and pos/neg conditioning buffers.""" + def __init__(self, n_blocks=3, zero_conditioning=False, conditioning_dtype=torch.float32): + super().__init__() + self.blocks = nn.ModuleList([_Block() for _ in range(n_blocks)]) + pos = torch.zeros if zero_conditioning else torch.ones + self.register_buffer("positive_conditioning", pos((2, 4), dtype=conditioning_dtype)) + self.register_buffer("negative_conditioning", torch.zeros((3, 4), dtype=conditioning_dtype)) + + +class _ModelInner: + """Inner model wrapper exposing `.diffusion_model`.""" + def __init__(self, diffusion_model): + self.diffusion_model = diffusion_model + + +class _ModelPatcher: + """ModelPatcher stub exposing `.model._ModelInner`.""" + def __init__(self, diffusion_model): + self.model = _ModelInner(diffusion_model) + + +def test_seedvr2_conditioning_schema_exposes_model_passthrough_output(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + schema = nodes_seedvr.SeedVR2Conditioning.define_schema() + assert [input_item.id for input_item in schema.inputs] == [ + "model", + "vae_conditioning", + ] + assert schema.inputs[1].display_name == "latent" + assert [output.display_name for output in schema.outputs] == [ + "model", + "positive", + "negative", + "latent", + ] + finally: + restore() + + +def test_seedvr2_conditioning_returns_packed_input_latent_deterministically(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel() + patcher = _ModelPatcher(diffusion_model) + samples = torch.arange(1, 25, 
dtype=torch.float32).reshape(1, 2, 3, 2, 2) + vae_conditioning = {"samples": samples} + + _, first_positive, first_negative, first_latent = ( + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, + vae_conditioning, + ) + ) + _, second_positive, second_negative, second_latent = ( + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, + vae_conditioning, + ) + ) + + expected_latent = samples.reshape(1, 6, 2, 2) + channel_last = samples.movedim(1, -1).contiguous() + expected_condition = torch.cat( + [ + channel_last, + torch.ones((*channel_last.shape[:-1], 1)), + ], + dim=-1, + ).movedim(-1, 1).reshape(1, 9, 2, 2) + + assert torch.equal(first_latent["samples"], expected_latent) + assert torch.equal(second_latent["samples"], expected_latent) + assert torch.equal( + first_positive[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + second_positive[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + first_negative[0][1]["condition"], + expected_condition, + ) + assert torch.equal( + second_negative[0][1]["condition"], + expected_condition, + ) + finally: + restore() + + +def test_seedvr2_conditioning_fails_loud_on_zero_buffers(): + nodes_seedvr, restore = _import_nodes_seedvr_isolated() + try: + diffusion_model = _DiffusionModel(zero_conditioning=True) + patcher = _ModelPatcher(diffusion_model) + vae_conditioning = {"samples": torch.zeros((1, 2, 1, 1, 1))} + + with pytest.raises(RuntimeError) as excinfo: + nodes_seedvr.SeedVR2Conditioning.execute( + patcher, vae_conditioning, + ) + + message = str(excinfo.value) + assert message.startswith( + nodes_seedvr._SEEDVR2_INVALID_MODEL_MSG_PREFIX + ), ( + "Fail-loud message must use the standard " + "_SEEDVR2_INVALID_MODEL_MSG_PREFIX so callers/log scrapers " + f"can match it. 
Got: {message!r}" + ) + assert "positive_conditioning" in message + assert "negative_conditioning" in message + finally: + restore() diff --git a/tests-unit/comfy_extras_test/test_seedvr2_nodes.py b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py new file mode 100644 index 000000000..f7d9a4f65 --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_nodes.py @@ -0,0 +1,55 @@ +import importlib +import inspect +import sys +from unittest.mock import MagicMock, patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + + +def test_seedvr_node_signature_matches_schema(): + mock_mm = MagicMock() + mock_mm.xformers_enabled.return_value = False + mock_mm.xformers_enabled_vae.return_value = False + mock_mm.sage_attention_enabled.return_value = False + mock_mm.flash_attention_enabled.return_value = False + + sentinel = object() + prior_cpu = cli_args.cpu + cli_args.cpu = True + prior_module = sys.modules.get("comfy_extras.nodes_seedvr", sentinel) + comfy_pkg = sys.modules.get("comfy") + prior_mm_attr = getattr(comfy_pkg, "model_management", sentinel) if comfy_pkg else sentinel + + with patch.dict(sys.modules, {"comfy.model_management": mock_mm}): + if comfy_pkg is not None: + setattr(comfy_pkg, "model_management", mock_mm) + sys.modules.pop("comfy_extras.nodes_seedvr", None) + try: + nodes_seedvr = importlib.import_module("comfy_extras.nodes_seedvr") + for node_cls in (nodes_seedvr.SeedVR2Preprocess, nodes_seedvr.SeedVR2PostProcessing, nodes_seedvr.SeedVR2Conditioning, nodes_seedvr.SeedVR2ProgressiveSampler): + schema_ids = [i.id for i in node_cls.define_schema().inputs] + exec_params = [ + p for p in inspect.signature(node_cls.execute).parameters.keys() + if p != "cls" + ] + assert schema_ids == exec_params, ( + f"{node_cls.__name__} schema/execute drift: " + f"schema_ids={schema_ids}, exec_params={exec_params}" + ) + finally: + cli_args.cpu = prior_cpu + if prior_module is sentinel: + 
sys.modules.pop("comfy_extras.nodes_seedvr", None) + else: + sys.modules["comfy_extras.nodes_seedvr"] = prior_module + if comfy_pkg is not None: + if prior_mm_attr is sentinel: + if hasattr(comfy_pkg, "model_management"): + delattr(comfy_pkg, "model_management") + else: + setattr(comfy_pkg, "model_management", prior_mm_attr) diff --git a/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py new file mode 100644 index 000000000..a27a8f8df --- /dev/null +++ b/tests-unit/comfy_extras_test/test_seedvr2_post_processing.py @@ -0,0 +1,57 @@ +from unittest.mock import patch + +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy_extras import nodes_seedvr # noqa: E402 + + +def _schema_ids(items): + return [item.id for item in items] + + +def test_seedvr2_post_processing_schema(): + schema = nodes_seedvr.SeedVR2PostProcessing.define_schema() + + assert _schema_ids(schema.inputs) == ["images", "original_resized_images", "color_correction_method"] + assert schema.inputs[2].options == ["lab", "wavelet", "adain", "none"] + assert schema.inputs[2].default == "lab" + assert schema.outputs[0].get_io_type() == "IMAGE" + + +def test_seedvr2_post_processing_oom_error_uses_color_correction_method(monkeypatch): + decoded = torch.full((1, 3, 4, 4), 0.25) + reference = torch.full((1, 3, 4, 4), 0.75) + + def _lab(content, style): + raise torch.cuda.OutOfMemoryError("CUDA out of memory") + + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "vae_device", lambda: torch.device("cpu")) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "get_free_memory", lambda device: 1_000_000) + monkeypatch.setattr(nodes_seedvr.comfy.model_management, "soft_empty_cache", lambda: None) + + with patch.object(nodes_seedvr, "lab_color_transfer", _lab): + try: + nodes_seedvr.SeedVR2PostProcessing._color_transfer_chunked( + decoded, reference, 
torch.device("cpu"), "lab", + ) + except RuntimeError as exc: + assert "color_correction_method=lab" in str(exc) + assert " method=lab" not in str(exc) + else: + raise AssertionError("expected RuntimeError for one-frame LAB OOM") + + +def test_seedvr2_post_processing_unknown_color_correction_method_raises(): + decoded = torch.zeros(1, 2, 4, 4, 3) + original = torch.zeros(1, 2, 4, 4, 3) + try: + nodes_seedvr.SeedVR2PostProcessing.execute(decoded, original, "bogus") + except ValueError as exc: + assert "color_correction_method" in str(exc) + else: + raise AssertionError("expected ValueError for unknown color_correction_method") diff --git a/tests-unit/comfy_test/model_detection_test.py b/tests-unit/comfy_test/model_detection_test.py index 4e9350602..c63f69a0d 100644 --- a/tests-unit/comfy_test/model_detection_test.py +++ b/tests-unit/comfy_test/model_detection_test.py @@ -73,6 +73,24 @@ def _make_flux_schnell_comfyui_sd(): return sd +def _make_seedvr2_7b_separate_mm_sd(): + return { + "blocks.35.mlp.vid.proj_in.weight": torch.empty(1, 3072), + } + + +def _make_seedvr2_7b_shared_mm_sd(): + return { + "blocks.35.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + } + + +def _make_seedvr2_3b_shared_mm_sd(): + return { + "blocks.31.mlp.all.proj_in_gate.weight": torch.empty(1, 1), + } + + class TestModelDetection: """Verify that first-match model detection selects the correct model based on list ordering and unet_config specificity.""" @@ -125,6 +143,48 @@ class TestModelDetection: assert model_config is not None assert type(model_config).__name__ == "FluxSchnell" + def test_seedvr2_7b_separate_mm_detection_config(self): + sd = _make_seedvr2_7b_separate_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 3072 + assert unet_config["heads"] == 24 + assert unet_config["num_layers"] == 36 + assert unet_config["mm_layers"] == 36 + assert unet_config["mlp_type"] 
== "normal" + assert unet_config["qk_rope"] is True + assert unet_config["rope_type"] == "rope3d" + assert unet_config["rope_dim"] == 64 + + def test_seedvr2_7b_shared_mm_detection_config(self): + sd = _make_seedvr2_7b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 3072 + assert unet_config["heads"] == 24 + assert unet_config["num_layers"] == 36 + assert unet_config["mm_layers"] == 10 + assert unet_config["mlp_type"] == "swiglu" + assert unet_config["qk_rope"] is True + assert unet_config["rope_type"] == "rope3d" + assert unet_config["rope_dim"] == 64 + + def test_seedvr2_3b_shared_mm_detection_config(self): + sd = _make_seedvr2_3b_shared_mm_sd() + unet_config = detect_unet_config(sd, "") + + assert unet_config is not None + assert unet_config["image_model"] == "seedvr2" + assert unet_config["vid_dim"] == 2560 + assert unet_config["heads"] == 20 + assert unet_config["num_layers"] == 32 + assert unet_config["mlp_type"] == "swiglu" + assert unet_config["qk_rope"] is None + def test_unet_config_and_required_keys_combination_is_unique(self): """Each model in the registry must have a unique combination of ``unet_config`` and ``required_keys``. If two models share the same diff --git a/tests-unit/comfy_test/seedvr_vae_forward_test.py b/tests-unit/comfy_test/seedvr_vae_forward_test.py new file mode 100644 index 000000000..f9dbd6890 --- /dev/null +++ b/tests-unit/comfy_test/seedvr_vae_forward_test.py @@ -0,0 +1,90 @@ +"""Regression: ``comfy.ldm.seedvr.vae.VideoAutoencoderKL.forward`` must +honor the actual tensor/tuple return contract of ``encode()`` and +``decode_()`` and must NOT dereference diffusers-style ``.latent_dist`` +or ``.sample`` attributes on those returns. 
+ +The pre-fix body raised ``AttributeError: 'Tensor' object has no +attribute 'latent_dist'`` for ``mode in {"encode", "all"}`` and +``AttributeError: 'VideoAutoencoderKL' object has no attribute 'decode'`` +for ``mode == "decode"`` (the class only defines ``decode_`` with a +trailing underscore). The post-fix body unwraps the optional one-element +tuple shape that ``return_dict=False`` produces and returns the tensor +directly. + +Tests construct a stub subclass of ``VideoAutoencoderKL`` that bypasses +the heavy ``__init__`` via ``torch.nn.Module.__init__(self)`` and +overrides ``encode``/``decode_`` with known tensors so the contract can +be probed without loading any real VAE weights. +""" + +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +from comfy.ldm.seedvr.vae import VideoAutoencoderKL # noqa: E402 + + +_LATENT_SHAPE = (1, 16, 2, 2, 2) +_DECODED_SHAPE = (1, 3, 5, 16, 16) +_INPUT_ENCODE_SHAPE = (1, 3, 5, 16, 16) +_INPUT_DECODE_SHAPE = (1, 16, 2, 2, 2) + + +class _StubVAE(VideoAutoencoderKL): + def __init__(self): + nn.Module.__init__(self) + self._encode_out = torch.zeros(*_LATENT_SHAPE) + self._decode_out = torch.zeros(*_DECODED_SHAPE) + + def encode(self, x, return_dict=True): + return self._encode_out + + def decode_(self, z, return_dict=True): + return self._decode_out + + +def test_forward_encode_returns_tensor(): + vae = _StubVAE() + x = torch.zeros(*_INPUT_ENCODE_SHAPE) + result = vae.forward(x, mode="encode") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_LATENT_SHAPE) + + +def test_forward_decode_returns_tensor(): + vae = _StubVAE() + z = torch.zeros(*_INPUT_DECODE_SHAPE) + result = vae.forward(z, mode="decode") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_DECODED_SHAPE) + + +class _TupleReturningStubVAE(VideoAutoencoderKL): + """Stub variant whose ``encode``/``decode_`` return the + 
``(tensor,)`` one-element tuple shape ``return_dict=False`` produces + in the parent class. Exercises the unwrap branch of + ``VideoAutoencoderKL.forward``. + """ + + def __init__(self): + nn.Module.__init__(self) + self._encode_tensor = torch.zeros(*_LATENT_SHAPE) + self._decode_tensor = torch.zeros(*_DECODED_SHAPE) + + def encode(self, x, return_dict=True): + return (self._encode_tensor,) + + def decode_(self, z, return_dict=True): + return (self._decode_tensor,) + + +def test_forward_all_unwraps_one_tuple_at_each_step(): + vae = _TupleReturningStubVAE() + x = torch.zeros(*_INPUT_ENCODE_SHAPE) + result = vae.forward(x, mode="all") + assert type(result) is torch.Tensor + assert result.shape == torch.Size(_DECODED_SHAPE) diff --git a/tests-unit/comfy_test/test_seedvr2_dtype.py b/tests-unit/comfy_test/test_seedvr2_dtype.py new file mode 100644 index 000000000..e5d79a306 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_dtype.py @@ -0,0 +1,47 @@ +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.sd +import comfy.supported_models +import comfy.ldm.seedvr.model as seedvr_model + + +def test_seedvr2_fp16_manual_cast_only_for_bf16_device(monkeypatch): + bf16_device = object() + fp16_device = object() + + monkeypatch.setattr( + comfy.supported_models.comfy.model_management, + "should_use_bf16", + lambda device=None: device is bf16_device, + ) + + bf16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) + bf16_config.set_inference_dtype(torch.float16, None, device=bf16_device) + assert bf16_config.manual_cast_dtype is torch.bfloat16 + + fp16_config = comfy.supported_models.SeedVR2({"image_model": "seedvr2"}) + fp16_config.set_inference_dtype(torch.float16, None, device=fp16_device) + assert fp16_config.manual_cast_dtype is None + + +def test_seedvr2_text_conditioning_accepts_cfg1_single_branch(): + context = torch.arange(6, dtype=torch.float32).reshape(1, 3, 2) + + 
txt, txt_shape = seedvr_model.NaDiT._resolve_text_conditioning(object(), context, [0]) + + torch.testing.assert_close(txt, context.squeeze(0)) + torch.testing.assert_close(txt_shape, torch.tensor([[3]], device=context.device)) + + +def test_seedvr2_vae_decode_memory_covers_full_frame_lab_transfer(): + estimate = comfy.sd._seedvr2_vae_decode_memory_used((1, 16, 26, 120, 160)) + old_estimate = 16 * 120 * 160 * (4 * 8 * 8) * 2 + + assert estimate == 101 * 960 * 1280 * 160 + assert estimate > 15 * 1024 ** 3 + assert estimate > old_estimate * 100 diff --git a/tests-unit/comfy_test/test_seedvr2_internals.py b/tests-unit/comfy_test/test_seedvr2_internals.py new file mode 100644 index 000000000..5b008ea6e --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_internals.py @@ -0,0 +1,341 @@ +"""Consolidated SeedVR2 internals regression tests. + +Sources (all merged verbatim, helper names disambiguated where colliding): + + * RoPE rewrite — NaMMRotaryEmbedding3d.forward must match the legacy + apply_rotary_emb wrapper oracle at fp32. + * GroupNorm limit gate — causal_norm_wrapper at vae.py:509 must compare + memory_occupy against get_norm_limit(), not float('inf'). + * SeedVR2 variable-length attention split-loop contract. + +Pre-import CPU-only guard is required because comfy.ldm.seedvr.model and +comfy.ldm.modules.attention transitively pull in comfy.model_management, +which probes torch.cuda.current_device() at import time unless args.cpu is +set first. 
+""" + +from __future__ import annotations + +from unittest.mock import patch + +import pytest +import torch + +from comfy.cli_args import args + +if not torch.cuda.is_available(): + args.cpu = True + +import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 +import comfy.ldm.modules.attention as attention # noqa: E402 +import comfy.ops as comfy_ops # noqa: E402 +from comfy.ldm.seedvr.model import ( # noqa: E402 + Cache, + NaMMRotaryEmbedding3d, +) +from comfy.ldm.seedvr.vae import ( # noqa: E402 + causal_norm_wrapper, + set_norm_limit, +) +from comfy.ldm.modules.attention import var_attention_optimized_split # noqa: E402 + + +# --------------------------------------------------------------------------- +# RoPE rewrite tests (test_seedvr_rope_rewrite.py) +# --------------------------------------------------------------------------- + +# Test rig dimensions. dim=192 → per-axis rope dim = 64 (even, lucidrains +# requirement). vid_shape=(2,4,4) → L_vid = 32. txt_shape=(8,) → L_txt = 8. +_DIM = 192 +_HEADS = 4 +_VID_T, _VID_H, _VID_W = 2, 4, 4 +_TXT_L = 8 +_L_VID = _VID_T * _VID_H * _VID_W +_SEED = 0 + + +def _make_inputs(dtype=torch.float32, device="cpu"): + """Construct the 6 forward inputs + cache. Deterministic via local + Generator so global RNG state is not mutated. 
+ """ + g = torch.Generator(device=device).manual_seed(_SEED) + vid_q = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + vid_k = torch.randn(_L_VID, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + txt_q = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + txt_k = torch.randn(_TXT_L, _HEADS, _DIM, dtype=dtype, device=device, generator=g) + vid_shape = torch.tensor([[_VID_T, _VID_H, _VID_W]], dtype=torch.long, device=device) + txt_shape = torch.tensor([[_TXT_L]], dtype=torch.long, device=device) + cache = Cache(disable=True) + return vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache + + +def _legacy_get_freqs(rope: NaMMRotaryEmbedding3d, vid_shape, txt_shape): + """Reproduce the pre-rewrite ``get_freqs`` body verbatim against + ``self.get_axial_freqs`` (parent ``RotaryEmbeddingBase`` method, + unchanged by the rewrite). + """ + max_temporal = 0 + max_height = 0 + max_width = 0 + max_txt_len = 0 + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + max_temporal = max(max_temporal, l + f) + max_height = max(max_height, h) + max_width = max(max_width, w) + max_txt_len = max(max_txt_len, l) + with torch.amp.autocast(device_type="cuda", enabled=False): + vid_freqs_full = rope.get_axial_freqs( + min(max_temporal + 16, 1024), + min(max_height + 4, 128), + min(max_width + 4, 128), + ).float() + txt_freqs_full = rope.get_axial_freqs(min(max_txt_len + 16, 1024)) + vid_freq_list, txt_freq_list = [], [] + for (f, h, w), l in zip(vid_shape.tolist(), txt_shape[:, 0].tolist()): + vid_freq = vid_freqs_full[l : l + f, :h, :w].reshape(-1, vid_freqs_full.size(-1)) + txt_freq = txt_freqs_full[:l].repeat(1, 3).reshape(-1, vid_freqs_full.size(-1)) + vid_freq_list.append(vid_freq) + txt_freq_list.append(txt_freq) + return torch.cat(vid_freq_list, dim=0), torch.cat(txt_freq_list, dim=0) + + +def _legacy_forward(rope: NaMMRotaryEmbedding3d, vid_q, vid_k, vid_shape, + txt_q, txt_k, txt_shape): + 
"""Compute expected forward output via the unchanged + ``apply_rotary_emb`` wrapper fed with legacy-shape freqs. This is the + oracle. The wrapper itself is out of scope for the rewrite (Shape B). + """ + vid_freqs, txt_freqs = _legacy_get_freqs(rope, vid_shape, txt_shape) + vid_freqs = vid_freqs.to(vid_q.device) + txt_freqs = txt_freqs.to(txt_q.device) + + from einops import rearrange + + vid_q = rearrange(vid_q, "L h d -> h L d") + vid_k = rearrange(vid_k, "L h d -> h L d") + vid_q_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_q.float()).to(vid_q.dtype) + vid_k_out = seedvr_model.apply_rotary_emb(vid_freqs, vid_k.float()).to(vid_k.dtype) + vid_q_out = rearrange(vid_q_out, "h L d -> L h d") + vid_k_out = rearrange(vid_k_out, "h L d -> L h d") + + txt_q = rearrange(txt_q, "L h d -> h L d") + txt_k = rearrange(txt_k, "L h d -> h L d") + txt_q_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_q.float()).to(txt_q.dtype) + txt_k_out = seedvr_model.apply_rotary_emb(txt_freqs, txt_k.float()).to(txt_k.dtype) + txt_q_out = rearrange(txt_q_out, "h L d -> L h d") + txt_k_out = rearrange(txt_k_out, "h L d -> L h d") + return vid_q_out, vid_k_out, txt_q_out, txt_k_out + + +def test_namm_forward_output_tensor_equal_against_legacy_oracle(): + rope = NaMMRotaryEmbedding3d(dim=_DIM) + vid_q, vid_k, vid_shape, txt_q, txt_k, txt_shape, cache = _make_inputs() + + expected_vid_q, expected_vid_k, expected_txt_q, expected_txt_k = _legacy_forward( + rope, + vid_q.clone(), vid_k.clone(), vid_shape, + txt_q.clone(), txt_k.clone(), txt_shape, + ) + + actual_vid_q, actual_vid_k, actual_txt_q, actual_txt_k = rope.forward( + vid_q.clone(), vid_k.clone(), vid_shape, + txt_q.clone(), txt_k.clone(), txt_shape, cache, + ) + + torch.testing.assert_close(actual_vid_q, expected_vid_q, rtol=0, atol=0, + msg="vid_q output diverges from wrapper oracle") + torch.testing.assert_close(actual_vid_k, expected_vid_k, rtol=0, atol=0, + msg="vid_k output diverges from wrapper oracle") + 
torch.testing.assert_close(actual_txt_q, expected_txt_q, rtol=0, atol=0, + msg="txt_q output diverges from wrapper oracle") + torch.testing.assert_close(actual_txt_k, expected_txt_k, rtol=0, atol=0, + msg="txt_k output diverges from wrapper oracle") + + +# --------------------------------------------------------------------------- +# GroupNorm limit tests (test_seedvr_groupnorm_limit.py) +# --------------------------------------------------------------------------- + +_NUM_CHANNELS = 8 +_NUM_GROUPS = 4 +_TENSOR_SHAPE = (1, 8, 2, 4, 4) + +_GROUPNORM_SUBCLASSES = [ + pytest.param(comfy_ops.disable_weight_init.GroupNorm, id="disable_weight_init"), + pytest.param(comfy_ops.manual_cast.GroupNorm, id="manual_cast"), +] + + +@pytest.mark.parametrize("groupnorm_cls", _GROUPNORM_SUBCLASSES) +def test_seedvr_groupnorm_low_limit_uses_chunked_groupnorm_path(groupnorm_cls): + real_group_norm = vae_mod.F.group_norm + set_norm_limit(1e-9) + try: + gn = groupnorm_cls(num_channels=_NUM_CHANNELS, num_groups=_NUM_GROUPS) + gn.eval() + + forward_hook_calls = [] + + def _hook(module, inputs, output): + forward_hook_calls.append(tuple(inputs[0].shape)) + + spy_calls = [] + + def _group_norm_spy(input_tensor, num_groups_arg, *args, **kwargs): + spy_calls.append({"num_groups": int(num_groups_arg)}) + return real_group_norm(input_tensor, num_groups_arg, *args, **kwargs) + + handle = gn.register_forward_hook(_hook) + try: + with patch.object(vae_mod.F, "group_norm", side_effect=_group_norm_spy): + out_tensor = causal_norm_wrapper(gn, torch.randn(*_TENSOR_SHAPE)) + finally: + handle.remove() + + full_calls = len(forward_hook_calls) + chunked_calls = sum(1 for entry in spy_calls if entry["num_groups"] < _NUM_GROUPS) + + assert tuple(int(s) for s in out_tensor.shape) == _TENSOR_SHAPE + assert full_calls == 0, ( + f"low-limit GroupNorm gate must NOT take the full-forward path; got full_calls={full_calls}" + ) + assert chunked_calls > 0, ( + f"low-limit GroupNorm gate must take the chunked path; 
got chunked_calls={chunked_calls}" + ) + finally: + set_norm_limit(None) + + +# --------------------------------------------------------------------------- +# SeedVR2 var_attention split-loop tests +# --------------------------------------------------------------------------- + +def test_var_attention_registry_contains_always_available_entries(): + assert ( + attention.REGISTERED_ATTENTION_FUNCTIONS["var_attention_optimized_split"] + is attention.var_attention_optimized_split + ) + + +def test_seedvr2_7b_swin_attention_forward_uses_optimized_var_attention(monkeypatch): + dim = 8 + heads = 2 + head_dim = 4 + attn = seedvr_model.NaSwinAttention( + vid_dim=dim, + txt_dim=dim, + heads=heads, + head_dim=head_dim, + qk_bias=False, + qk_norm=seedvr_model.CustomRMSNorm, + qk_norm_eps=1e-6, + rope_type=None, + rope_dim=head_dim, + shared_weights=False, + window=(2, 1, 1), + window_method="720pwin_by_size_bysize", + version=True, + device="cpu", + dtype=torch.float32, + operations=comfy_ops.disable_weight_init, + ) + generator = torch.Generator(device="cpu").manual_seed(11) + vid = torch.randn(8, dim, generator=generator) + txt = torch.randn(3, dim, generator=generator) + vid_shape = torch.tensor([[2, 2, 2]], dtype=torch.long) + txt_shape = torch.tensor([[3]], dtype=torch.long) + calls = [] + + def fake_optimized_var_attention(**kwargs): + calls.append(kwargs) + return kwargs["q"] + + monkeypatch.setattr(seedvr_model, "optimized_var_attention", fake_optimized_var_attention) + + vid_out, txt_out = attn(vid, txt, vid_shape, txt_shape, seedvr_model.Cache(disable=True)) + + assert tuple(vid_out.shape) == (8, dim) + assert tuple(txt_out.shape) == (3, dim) + assert len(calls) == 1 + call = calls[0] + assert tuple(call["q"].shape) == (14, heads, head_dim) + assert tuple(call["k"].shape) == (14, heads, head_dim) + assert tuple(call["v"].shape) == (14, heads, head_dim) + assert call["heads"] == heads + assert call["skip_reshape"] is True + assert call["skip_output_reshape"] is True + 
torch.testing.assert_close( + call["cu_seqlens_q"], + torch.tensor([0, 7, 14], dtype=torch.int32), + rtol=0, + atol=0, + ) + torch.testing.assert_close( + call["cu_seqlens_k"], + torch.tensor([0, 7, 14], dtype=torch.int32), + rtol=0, + atol=0, + ) + + +def test_var_attention_optimized_split_calls_dense_backend_per_window(monkeypatch): + heads = 2 + head_dim = 3 + q = torch.arange(30, dtype=torch.float32).reshape(5, heads, head_dim) + k = q + 100 + v = q + 200 + cu = torch.tensor([0, 2, 5], dtype=torch.int32) + calls = [] + + def fake_optimized_attention(q_arg, k_arg, v_arg, heads_arg, **kwargs): + calls.append( + { + "q_shape": tuple(q_arg.shape), + "k_shape": tuple(k_arg.shape), + "v_shape": tuple(v_arg.shape), + "heads": heads_arg, + "kwargs": kwargs, + } + ) + return q_arg + v_arg + + monkeypatch.setattr(attention, "optimized_attention", fake_optimized_attention) + + out = var_attention_optimized_split( + q, + k, + v, + heads, + cu, + cu, + skip_reshape=True, + skip_output_reshape=True, + ) + + assert tuple(out.shape) == (5, heads, head_dim) + assert len(calls) == 2 + assert calls[0]["q_shape"] == (1, heads, 2, head_dim) + assert calls[1]["q_shape"] == (1, heads, 3, head_dim) + assert all(call["heads"] == heads for call in calls) + assert all(call["kwargs"]["skip_reshape"] is True for call in calls) + assert all(call["kwargs"]["skip_output_reshape"] is True for call in calls) + torch.testing.assert_close(out, q + v, rtol=0, atol=0) + + +def test_var_attention_optimized_split_rejects_bad_offsets(): + q = torch.randn(5, 2, 3) + cu_bad = torch.tensor([0, 2, 6], dtype=torch.int32) + cu_ok = torch.tensor([0, 2, 5], dtype=torch.int32) + + with pytest.raises(ValueError, match="cu_seqlens_q does not match token count"): + var_attention_optimized_split( + q, + q, + q, + 2, + cu_bad, + cu_ok, + skip_reshape=True, + skip_output_reshape=True, + ) diff --git a/tests-unit/comfy_test/test_seedvr2_model.py b/tests-unit/comfy_test/test_seedvr2_model.py new file mode 100644 index 
000000000..f2b9bcbbe --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_model.py @@ -0,0 +1,308 @@ +"""Consolidated SeedVR2 model/graph/forward regression tests. + +Merged from: +- seedvr_model_test.py +- test_seedvr_7b_final_block_text_path.py +- test_seedvr_forward_no_device_cast.py +- test_seedvr_latent_format.py +- test_seedvr2_vae_graph_boundaries.py +""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import torch +from torch import nn + +from comfy.cli_args import args + +if not torch.cuda.is_available(): + args.cpu = True + +import comfy # noqa: E402 +import comfy.latent_formats # noqa: E402 +import comfy.ldm.seedvr.model # noqa: E402 +import comfy.ldm.seedvr.model as seedvr_model # noqa: E402 +import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 +import comfy.model_management # noqa: E402 +import comfy.sample # noqa: E402 +import comfy.sd as sd_mod # noqa: E402 +import nodes as nodes_mod # noqa: E402 +from comfy.ldm.seedvr.model import NaDiT # noqa: E402 + + +# --------------------------------------------------------------------------- +# Helpers from seedvr_model_test.py +# --------------------------------------------------------------------------- + + +def _make_standin(positive_conditioning): + class _StandIn(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "positive_conditioning", positive_conditioning + ) + + _resolve_text_conditioning = NaDiT._resolve_text_conditioning + + return _StandIn() + + +# --------------------------------------------------------------------------- +# Helpers from test_seedvr_7b_final_block_text_path.py +# --------------------------------------------------------------------------- + + +class _StubModule(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + + +def _capture_last_layer_flags(monkeypatch, vid_dim: int, txt_in_dim: int) -> list[bool]: + flags = [] + + class _Block(_StubModule): + def __init__(self, *args, **kwargs): 
+ flags.append(kwargs["is_last_layer"]) + super().__init__() + + monkeypatch.setattr(seedvr_model, "NaPatchIn", _StubModule) + monkeypatch.setattr(seedvr_model, "NaPatchOut", _StubModule) + monkeypatch.setattr(seedvr_model, "TimeEmbedding", _StubModule) + monkeypatch.setattr(seedvr_model, "NaMMSRTransformerBlock", _Block) + + seedvr_model.NaDiT( + norm_eps=1e-5, + qk_rope=None, + num_layers=4, + mlp_type="normal", + vid_dim=vid_dim, + txt_in_dim=txt_in_dim, + heads=24, + mm_layers=3, + ) + + return flags + + +# --------------------------------------------------------------------------- +# Helpers from test_seedvr_latent_format.py +# --------------------------------------------------------------------------- + + +class _Model: + def __init__(self, latent_format): + self._latent_format = latent_format + + def get_model_object(self, name): + assert name == "latent_format" + return self._latent_format + + +# --------------------------------------------------------------------------- +# Helpers from test_seedvr2_vae_graph_boundaries.py +# --------------------------------------------------------------------------- + + +class _Patcher: + def get_free_memory(self, device): + return 1024 * 1024 * 1024 + + +class _EncodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): + def __init__(self, encoded): + nn.Module.__init__(self) + self.encoded = encoded + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.seen = [] + + def encode(self, x): + self.seen.append(tuple(x.shape)) + return self.encoded.to(device=x.device, dtype=x.dtype) + + +class _DecodeWrapper(seedvr_vae_mod.VideoAutoencoderKLWrapper): + def __init__(self): + nn.Module.__init__(self) + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.calls = [] + + def decode(self, z, seedvr2_tiling=None): + self.calls.append({"shape": tuple(z.shape), "seedvr2_tiling": seedvr2_tiling}) + if z.ndim == 4: + b, tc, h, w = z.shape + t = tc // 16 + else: + b, _, t, h, w = 
z.shape + return torch.zeros(b, 3, t, h * 8, w * 8, dtype=z.dtype, device=z.device) + + +def _make_vae(wrapper): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = wrapper + vae.device = torch.device("cpu") + vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.latent_channels = 16 + vae.latent_dim = 3 + vae.downscale_ratio = (lambda a: max(0, (a + 3) // 4), 8, 8) + vae.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + vae.output_channels = 3 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.crop_input = False + vae.not_video = False + vae.patcher = _Patcher() + vae.process_input = lambda image: image + vae.process_output = lambda image: image.add(1.0).div(2.0).clamp(0.0, 1.0) + vae.vae_output_dtype = lambda: torch.float32 + vae.memory_used_encode = lambda shape, dtype: 1 + vae.memory_used_decode = lambda shape, dtype: 1 + vae.throw_exception_if_invalid = lambda: None + vae.vae_encode_crop_pixels = lambda pixels: pixels + vae.spacial_compression_decode = lambda: 8 + vae.temporal_compression_decode = lambda: 4 + return vae + + +# --------------------------------------------------------------------------- +# Tests from seedvr_model_test.py +# --------------------------------------------------------------------------- + + +def test_missing_context_falls_back_to_positive_buffer(): + """AC: ``context is None`` falls back to the registered + ``positive_conditioning`` buffer and runs to completion — no + silent zero substitution, no raised exception. 
+ """ + pos_buffer = torch.full((58, 5120), 7.0) + standin = _make_standin(pos_buffer) + txt, txt_shape = standin._resolve_text_conditioning(None) + assert txt.shape == (58, 5120) + assert (txt == 7.0).all(), ( + "fallback path must use the positive_conditioning buffer " + "verbatim, not a zero tensor" + ) + assert txt_shape.shape == (1, 1) + assert txt_shape[0, 0].item() == 58 + + +# --------------------------------------------------------------------------- +# Tests from test_seedvr_7b_final_block_text_path.py +# --------------------------------------------------------------------------- + + +def test_seedvr2_7b_keeps_final_block_text_path(monkeypatch): + assert _capture_last_layer_flags(monkeypatch, vid_dim=3072, txt_in_dim=3072) == [ + False, + False, + False, + False, + ] + + +def test_seedvr2_7b_rope3d_matches_wrapper_oracle(): + rope = seedvr_model.get_na_rope("rope3d", dim=64) + generator = torch.Generator(device="cpu").manual_seed(0) + q = torch.randn(4, 2, 128, generator=generator) + k = torch.randn(4, 2, 128, generator=generator) + shape = torch.tensor([[1, 2, 2]], dtype=torch.long) + freqs = rope.get_axial_freqs(1, 2, 2).reshape(4, -1) + + expected_q = seedvr_model._apply_seedvr2_rotary_emb( + freqs, + q.permute(1, 0, 2).float(), + ).to(q.dtype).permute(1, 0, 2) + expected_k = seedvr_model._apply_seedvr2_rotary_emb( + freqs, + k.permute(1, 0, 2).float(), + ).to(k.dtype).permute(1, 0, 2) + + actual_q, actual_k = rope(q.clone(), k.clone(), shape, seedvr_model.Cache(disable=True)) + + torch.testing.assert_close(actual_q, expected_q, rtol=0, atol=0) + torch.testing.assert_close(actual_k, expected_k, rtol=0, atol=0) + + +# --------------------------------------------------------------------------- +# Tests from test_seedvr_latent_format.py +# --------------------------------------------------------------------------- + + +def test_seedvr2_latent_format_uses_16_channels_without_3d_empty_latent_expansion(): + latent_format = comfy.latent_formats.SeedVR2() + 
latent_image = torch.zeros(1, 1, 4, 5) + + fixed = comfy.sample.fix_empty_latent_channels(_Model(latent_format), latent_image) + + assert latent_format.latent_channels == 16 + assert latent_format.latent_dimensions == 2 + assert fixed.shape == (1, 16, 4, 5) + + +# --------------------------------------------------------------------------- +# Tests from test_seedvr2_vae_graph_boundaries.py +# --------------------------------------------------------------------------- + + +def test_seedvr2_encode_and_encode_tiled_preserve_native_latent_contract(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + + encoded = torch.full((1, 16, 2, 4, 5), 2.0) + vae = _make_vae(_EncodeWrapper(encoded)) + pixels = torch.zeros(1, 5, 32, 40, 3) + + node_output = nodes_mod.VAEEncode().encode(vae, pixels)[0] + node_latent = node_output["samples"] + assert set(node_output) == {"samples"} + assert tuple(node_latent.shape) == (1, 16, 2, 4, 5) + assert node_latent.dtype == torch.float32 + assert node_latent.stride()[-1] == 1 + assert torch.equal(node_latent, torch.full_like(node_latent, 2.0 * 0.9152)) + + tiled = torch.full((1, 16, 2, 4, 5), 3.0) + monkeypatch.setattr(seedvr_vae_mod, "tiled_vae", MagicMock(return_value=tiled)) + tiled_output = nodes_mod.VAEEncodeTiled().encode( + vae, + pixels, + tile_size=512, + overlap=64, + temporal_size=16, + temporal_overlap=4, + )[0] + tiled_latent = tiled_output["samples"] + assert set(tiled_output) == {"samples"} + assert tuple(tiled_latent.shape) == (1, 16, 2, 4, 5) + assert tiled_latent.dtype == torch.float32 + assert torch.equal(tiled_latent, torch.full_like(tiled_latent, 3.0 * 0.9152)) + + +def test_vaedecode_tiled_visible_inputs_are_seedvr2_decode_tiling_authority(monkeypatch): + monkeypatch.setattr(sd_mod.model_management, "load_models_gpu", lambda *a, **k: None) + vae = _make_vae(_DecodeWrapper()) + + nodes_mod.VAEDecodeTiled().decode( + vae, + {"samples": torch.zeros(1, 16, 2, 4, 5)}, + 
tile_size=512, + overlap=64, + temporal_size=16, + temporal_overlap=4, + ) + + assert vae.first_stage_model.calls == [ + { + "shape": (1, 16, 2, 4, 5), + "seedvr2_tiling": { + "enable_tiling": True, + "tile_size": (512, 512), + "tile_overlap": (64, 64), + "temporal_size": 16, + "temporal_overlap": 4, + }, + } + ] diff --git a/tests-unit/comfy_test/test_seedvr2_vae_decode.py b/tests-unit/comfy_test/test_seedvr2_vae_decode.py new file mode 100644 index 000000000..ea9f978f3 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_vae_decode.py @@ -0,0 +1,91 @@ +from unittest.mock import patch + +import pytest +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 +from comfy_extras import nodes_seedvr # noqa: E402 + + +def _make_wrapper() -> vae_mod.VideoAutoencoderKLWrapper: + wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( + vae_mod.VideoAutoencoderKLWrapper + ) + nn.Module.__init__(wrapper) + return wrapper + + +def _fingerprint_decode_(self, z, return_dict=True): + b = int(z.shape[0]) + t = int(z.shape[2]) + h = int(z.shape[3]) + w = int(z.shape[4]) + out = torch.empty(b, 3, t, h * 8, w * 8) + for batch_idx in range(b): + out[batch_idx].fill_(float(batch_idx + 1)) + return out + + +def _decode_with_patches(wrapper, z): + with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _fingerprint_decode_): + return wrapper.decode(z) + + +def test_decode_b2_t3_multi_frame_batch_unchanged(): + wrapper = _make_wrapper() + + out = _decode_with_patches(wrapper, torch.zeros(2, 16 * 3, 2, 2)) + + assert tuple(out.shape) == (2, 3, 3, 16, 16) + + +class _Wrapper(vae_mod.VideoAutoencoderKLWrapper): + def __init__(self): + nn.Module.__init__(self) + self.calls = [] + + def parameters(self): + return iter([torch.nn.Parameter(torch.zeros(()))]) + +def _decode_stub(self, latent): + self.calls.append(tuple(latent.shape)) + return 
torch.zeros(latent.shape[0], 3, latent.shape[2], latent.shape[3] * 8, latent.shape[4] * 8) + + +def test_seedvr2_wrapper_decode_accepts_5d_channel_first_latents_without_preprocessor_state(): + wrapper = _Wrapper() + + with patch.object(vae_mod.VideoAutoencoderKL, "decode_", _decode_stub): + out = wrapper.decode(torch.zeros(1, 16, 2, 4, 5)) + + assert tuple(out.shape) == (1, 3, 2, 32, 40) + assert wrapper.calls == [(1, 16, 2, 4, 5)] + + +def test_seedvr2_wrapper_decode_rejects_wrong_rank_latents(): + wrapper = _Wrapper() + + with pytest.raises(RuntimeError, match=r"latent input must be 4-D collapsed .* or 5-D"): + wrapper.decode(torch.zeros(1, 16, 4)) + + +def _t_padded(t_in: int) -> int: + if t_in == 1: + return 1 + if t_in <= 4: + return 5 + if (t_in - 1) % 4 == 0: + return t_in + return t_in + (4 - ((t_in - 1) % 4)) + + +@pytest.mark.parametrize("t_in", [1, 5, 9]) +def test_t_padded_matches_cut_videos(t_in): + dummy = torch.zeros(1, t_in, 1, 1, 1) + assert nodes_seedvr.cut_videos(dummy).shape[1] == _t_padded(t_in) diff --git a/tests-unit/comfy_test/test_seedvr2_vae_tiled.py b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py new file mode 100644 index 000000000..40079bbe2 --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr2_vae_tiled.py @@ -0,0 +1,347 @@ +from contextlib import ExitStack +from unittest.mock import MagicMock, patch + +import torch +import torch.nn as nn + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.ldm.seedvr.vae as vae_mod # noqa: E402 +import comfy.ldm.seedvr.vae as seedvr_vae_mod # noqa: E402 +import comfy.sd as sd_mod # noqa: E402 +from comfy.ldm.seedvr.vae import MemoryState, tiled_vae # noqa: E402 + + +# --------------------------------------------------------------------------- +# From test_seedvr_vae_tiled_decode_latent_min_size_override.py +# --------------------------------------------------------------------------- + + +def 
test_runtime_decode_zero_temporal_size_disables_slicing_for_call(): + from comfy.ldm.seedvr.vae import MemoryState, VideoAutoencoderKL, tiled_vae + + class StubVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_latent_min_size = 2 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self.use_slicing = True + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.decode_min_sizes = [] + self.memory_states = [] + + def decode_(self, t_chunk): + self.decode_min_sizes.append(self.slicing_latent_min_size) + return VideoAutoencoderKL.slicing_decode(self, t_chunk) + + def _decode(self, z, memory_state=MemoryState.DISABLED): + self.memory_states.append(memory_state) + b, c, d, h, w = z.shape + return torch.zeros((b, 3, d, h * 8, w * 8), dtype=z.dtype) + + vae = StubVAEModel() + z = torch.zeros((1, 16, 5, 8, 8), dtype=torch.float32) + + tiled_vae( + z, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=False, + ) + + assert vae.decode_min_sizes == [5] + assert vae.memory_states == [MemoryState.DISABLED] + assert vae.slicing_latent_min_size == 2 + + +# --------------------------------------------------------------------------- +# From test_seedvr_vae_tiled_encode_runt_slice_override.py +# --------------------------------------------------------------------------- + + +def test_zero_temporal_size_preserves_min_size_when_encode_raises(): + from comfy.ldm.seedvr.vae import tiled_vae + + class RaisingVAEModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.slicing_sample_min_size = 4 + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self._dummy = torch.nn.Parameter(torch.zeros(1, dtype=torch.float32)) + + def encode(self, t_chunk): + raise RuntimeError("simulated encode failure") + + vae = RaisingVAEModel() + x = torch.zeros((1, 3, 
12, 64, 64), dtype=torch.float32) + + raised = False + try: + tiled_vae( + x, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=0, + temporal_overlap=0, + encode=True, + ) + except RuntimeError as exc: + if "simulated encode failure" not in str(exc): + raise + raised = True + + assert raised + assert vae.slicing_sample_min_size == 4 + + +# --------------------------------------------------------------------------- +# From test_seedvr_vae_tiled_temporal_slicing.py +# --------------------------------------------------------------------------- + + +class _SlicingDecodeVAE(nn.Module): + def __init__(self, slicing_latent_min_size): + super().__init__() + self.slicing_latent_min_size = slicing_latent_min_size + self.spatial_downsample_factor = 8 + self.temporal_downsample_factor = 4 + self.device = torch.device("cpu") + self.use_slicing = True + self._dummy = nn.Parameter(torch.zeros(1, dtype=torch.float32)) + self.decode_min_sizes = [] + self.memory_states = [] + + def decode_(self, z): + self.decode_min_sizes.append(self.slicing_latent_min_size) + return vae_mod.VideoAutoencoderKL.slicing_decode(self, z) + + def _decode(self, z, memory_state=MemoryState.DISABLED): + self.memory_states.append(memory_state) + x = z[:, :1].repeat( + 1, + 3, + 1, + self.spatial_downsample_factor, + self.spatial_downsample_factor, + ) + return x + + +def test_decode_tiled_vae_maps_temporal_args_to_latent_slicing_min_size(): + vae = _SlicingDecodeVAE(slicing_latent_min_size=2) + z = torch.arange(1 * 16 * 5 * 8 * 8, dtype=torch.float32).reshape(1, 16, 5, 8, 8) + + tiled_vae( + z, + vae, + tile_size=(64, 64), + tile_overlap=(0, 0), + temporal_size=12, + temporal_overlap=4, + encode=False, + ) + + assert vae.decode_min_sizes == [2] + assert vae.memory_states == [MemoryState.INITIALIZING, MemoryState.ACTIVE] + assert vae.slicing_latent_min_size == 2 + + wrapper = vae_mod.VideoAutoencoderKLWrapper.__new__( + vae_mod.VideoAutoencoderKLWrapper + ) + nn.Module.__init__(wrapper) + 
seedvr2_tiling = { + "enable_tiling": True, + "tile_size": (64, 64), + "tile_overlap": (0, 0), + "temporal_size": 8, + "temporal_overlap": 7, + } + + captured = {} + + def _fake_tiled_vae(latent, model, **kwargs): + captured.update(kwargs) + return torch.zeros(1, 3, 1, 16, 16) + + with patch.object(vae_mod, "tiled_vae", side_effect=_fake_tiled_vae): + wrapper.decode(torch.zeros(1, 16, 2, 2), seedvr2_tiling=seedvr2_tiling) + + assert captured["temporal_overlap"] == 7 + + +# --------------------------------------------------------------------------- +# From test_vae_decode_tiled_dispatcher_seedvr2_4d.py +# --------------------------------------------------------------------------- + + +def _force_oom(*a, **k): + raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") + + +def _make_vae(first_stage_model, latent_channels, latent_dim): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = first_stage_model + vae.patcher = MagicMock() + vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) + vae.device = vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.disable_offload = True + vae.extra_1d_channel = None + vae.upscale_ratio = vae.downscale_ratio = 8 + vae.upscale_index_formula = vae.downscale_index_formula = None + vae.output_channels = 3 + vae.latent_channels = latent_channels + vae.latent_dim = latent_dim + vae.vae_output_dtype = lambda: torch.float32 + vae.spacial_compression_decode = lambda: 8 + vae.process_input = lambda x: x + vae.process_output = lambda x: x + vae.throw_exception_if_invalid = lambda: None + vae.memory_used_decode = lambda *a, **k: 1 + return vae + + +def _dispatch(vae, samples, seedvr2_call, generic_call, patch_wrapper_decode): + mm = sd_mod.model_management + with ExitStack() as stack: + stack.enter_context(patch.object(mm, "raise_non_oom", lambda e: None)) + stack.enter_context(patch.object(mm, "load_models_gpu", lambda *a, **k: None)) + 
stack.enter_context(patch.object(mm, "soft_empty_cache", lambda: None)) + stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_seedvr2", seedvr2_call)) + stack.enter_context(patch.object(sd_mod.VAE, "decode_tiled_", generic_call)) + if patch_wrapper_decode: + stack.enter_context(patch.object( + seedvr_vae_mod.VideoAutoencoderKLWrapper, "decode", + side_effect=_force_oom)) + vae.decode(samples) + + +def test_4d_seedvr2_latent_routes_to_decode_tiled_seedvr2(): + wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( + seedvr_vae_mod.VideoAutoencoderKLWrapper) + vae = _make_vae(wrapper, latent_channels=16, latent_dim=3) + seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) + generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) + _dispatch(vae, torch.zeros(1, 16 * 3, 8, 8), seedvr2_call, generic_call, True) + assert seedvr2_call.call_count == 1 + assert generic_call.call_count == 0 + + +def test_4d_non_seedvr2_latent_still_routes_to_generic_decode_tiled(): + first_stage = MagicMock() + first_stage.decode = MagicMock(side_effect=_force_oom) + vae = _make_vae(first_stage, latent_channels=4, latent_dim=2) + seedvr2_call = MagicMock(return_value=torch.zeros(1, 3, 9, 64, 64)) + generic_call = MagicMock(return_value=torch.zeros(1, 3, 64, 64)) + _dispatch(vae, torch.zeros(1, 4, 8, 8), seedvr2_call, generic_call, False) + assert generic_call.call_count == 1 + assert seedvr2_call.call_count == 0 + + +# --------------------------------------------------------------------------- +# From test_vae_encode_tiled_fallback_dispatcher_seedvr2.py +# --------------------------------------------------------------------------- + + +def _populate_common_vae_attrs_fallback(vae): + vae.patcher = MagicMock() + vae.patcher.get_free_memory = MagicMock(return_value=8 * 1024 * 1024 * 1024) + vae.device = torch.device("cpu") + vae.output_device = torch.device("cpu") + vae.vae_dtype = torch.float32 + vae.disable_offload = True + vae.extra_1d_channel = None + 
vae.upscale_ratio = 8 + vae.upscale_index_formula = None + vae.output_channels = 3 + vae.latent_channels = 16 + vae.latent_dim = 3 + vae.downscale_ratio = 8 + vae.downscale_index_formula = None + vae.not_video = False + vae.crop_input = False + vae.pad_channel_value = None + + vae.vae_output_dtype = lambda: torch.float32 + vae.spacial_compression_encode = lambda: 8 + vae.process_input = lambda x: x + vae.process_output = lambda x: x + vae.throw_exception_if_invalid = lambda: None + vae.memory_used_encode = lambda *a, **k: 1 + + +def _make_seedvr2_vae_fallback(): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + wrapper = seedvr_vae_mod.VideoAutoencoderKLWrapper.__new__( + seedvr_vae_mod.VideoAutoencoderKLWrapper + ) + vae.first_stage_model = wrapper + _populate_common_vae_attrs_fallback(vae) + return vae + + +def _make_non_seedvr2_vae_fallback(): + vae = sd_mod.VAE.__new__(sd_mod.VAE) + vae.first_stage_model = MagicMock() + _populate_common_vae_attrs_fallback(vae) + return vae + + +def _force_regular_encode_oom(*args, **kwargs): + raise torch.cuda.OutOfMemoryError("forced OOM for dispatcher test") + + +def test_seedvr2_3d_routes_to_encode_tiled_seedvr2_on_oom(): + vae = _make_seedvr2_vae_fallback() + pixel_samples = torch.zeros((1, 8, 64, 64, 3)) + + seedvr2_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + + with patch.object(sd_mod.model_management, "raise_non_oom", + lambda e: None), \ + patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.model_management, "soft_empty_cache", + lambda: None), \ + patch.object(seedvr_vae_mod.VideoAutoencoderKLWrapper, "encode", + side_effect=_force_regular_encode_oom), \ + patch.object(sd_mod.VAE, "encode_tiled_seedvr2", seedvr2_call, + create=True), \ + patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): + vae.encode(pixel_samples) + + assert seedvr2_call.call_count == 1, ( + f"Expected 
encode_tiled_seedvr2 to be called once for a SeedVR2 3D " + f"input under OOM fallback; got {seedvr2_call.call_count} calls." + ) + assert generic_call.call_count == 0, ( + f"encode_tiled_3d must NOT be called for a SeedVR2 input; got " + f"{generic_call.call_count} calls." + ) + + +def test_non_seedvr2_encode_tiled_3d_default_overlap_is_concrete(): + vae = _make_non_seedvr2_vae_fallback() + vae.downscale_ratio = (lambda a: max(1, a // 4), 8, 8) + vae.upscale_ratio = (lambda a: a * 4, 8, 8) + generic_call = MagicMock(return_value=torch.zeros(1, 16, 2, 8, 8)) + pixel_samples = torch.zeros((1, 8, 64, 64, 3)) + + with patch.object(sd_mod.model_management, "load_models_gpu", + lambda *a, **k: None), \ + patch.object(sd_mod.VAE, "encode_tiled_3d", generic_call): + vae.encode_tiled(pixel_samples) + + assert generic_call.call_args.kwargs["overlap"] == (1, 64, 64) diff --git a/tests-unit/comfy_test/test_seedvr_progressive_sampler.py b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py new file mode 100644 index 000000000..05291989e --- /dev/null +++ b/tests-unit/comfy_test/test_seedvr_progressive_sampler.py @@ -0,0 +1,126 @@ +"""Unit tests for ``comfy_extras.nodes_seedvr.SeedVR2ProgressiveSampler``.""" + +from unittest.mock import patch + +import pytest +import torch + +from comfy.cli_args import args as cli_args + +if not torch.cuda.is_available(): + cli_args.cpu = True + +import comfy.sample # noqa: E402 +import comfy_extras.nodes_seedvr as nodes_seedvr_mod # noqa: E402 +from comfy_extras.nodes_seedvr import SeedVR2ProgressiveSampler # noqa: E402 + +_LAT_C = 16 +_COND_C = 17 + + +def _make_inputs(B: int = 1, T: int = 5, H: int = 8, W: int = 8): + """Build minimal SeedVR2-shaped sampling inputs.""" + samples_5d = torch.arange( + B * _LAT_C * T * H * W, dtype=torch.float32 + ).reshape(B, _LAT_C, T, H, W) + samples = samples_5d.reshape(B, _LAT_C * T, H, W).contiguous() + + cond_5d = torch.arange( + B * _COND_C * T * H * W, dtype=torch.float32 + ).reshape(B, _COND_C, 
T, H, W) + 10000.0 + cond = cond_5d.reshape(B, _COND_C * T, H, W).contiguous() + + text_pos = torch.zeros(1, 4, 32) + text_neg = torch.zeros(1, 4, 32) + positive = [[text_pos, {"condition": cond.clone()}]] + negative = [[text_neg, {"condition": cond.clone()}]] + latent_image = {"samples": samples} + return latent_image, positive, negative, samples_5d, cond_5d + + +def _identity_fix_empty(model, latent_image, downscale_ratio_spacial=None): + return latent_image + + +def _fingerprinted_prepare_noise(latent_image, seed, batch_inds=None): + """Return a tensor whose values encode ``(seed, position)``.""" + base = torch.arange( + latent_image.numel(), dtype=torch.float32 + ).reshape(latent_image.shape) + return base + float(seed) * 1e6 + + +def test_progressive_sampler_schema_exposes_manual_default_auto_chunking(): + schema = SeedVR2ProgressiveSampler.define_schema() + inputs = {item.id: item for item in schema.inputs} + + assert inputs["chunking_mode"].options == ["manual", "auto"] + assert inputs["chunking_mode"].default == "manual" + + +def test_auto_chunking_walks_two_three_four_chunk_ladder(): + """Auto mode must walk 2-, 3-, then 4-chunk geometries on OOM.""" + latent, pos, neg, _, _ = _make_inputs(T=17) + calls = [] + + def _oom_until_four_chunks(model, noise, steps, cfg, sampler_name, + scheduler, positive, negative, + latent_image, denoise=1.0, + noise_mask=None, seed=None): + calls.append(tuple(latent_image.shape)) + if latent_image.shape[1] > _LAT_C * 5: + raise torch.cuda.OutOfMemoryError("chunk too large") + return latent_image.clone() + + with patch.object(comfy.sample, "sample", + side_effect=_oom_until_four_chunks), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise), \ + patch.object(nodes_seedvr_mod.comfy.model_management, + "soft_empty_cache") as soft_empty: + out = SeedVR2ProgressiveSampler.execute( + model=None, 
seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent=latent, + denoise=1.0, frames_per_chunk=65, temporal_overlap=0, + chunking_mode="auto", + ) + + assert calls[:4] == [ + (1, _LAT_C * 17, 8, 8), + (1, _LAT_C * 9, 8, 8), + (1, _LAT_C * 6, 8, 8), + (1, _LAT_C * 5, 8, 8), + ] + assert torch.equal(out.result[0]["samples"], latent["samples"]) + assert soft_empty.call_count == 3 + + +@pytest.mark.parametrize("bad_chunk", [0, -1, 2]) +def test_t3_invalid_frames_per_chunk_raises_value_error(bad_chunk): + """``frames_per_chunk`` violating 4n+1 (or <1) must raise ``ValueError`` before any model invocation.""" + latent, pos, neg, _, _ = _make_inputs(T=5) + + sampler_called = {"n": 0} + + def _should_not_be_called(*args, **kwargs): + sampler_called["n"] += 1 + return torch.zeros(1) + + with patch.object(comfy.sample, "sample", + side_effect=_should_not_be_called), \ + patch.object(comfy.sample, "fix_empty_latent_channels", + side_effect=_identity_fix_empty), \ + patch.object(comfy.sample, "prepare_noise", + side_effect=_fingerprinted_prepare_noise): + with pytest.raises(ValueError) as excinfo: + SeedVR2ProgressiveSampler.execute( + model=None, seed=0, steps=2, cfg=1.0, + sampler_name="euler", scheduler="simple", + positive=pos, negative=neg, latent=latent, + denoise=1.0, frames_per_chunk=bad_chunk, temporal_overlap=0, + ) + assert str(bad_chunk) in str(excinfo.value) + assert sampler_called["n"] == 0