Merge branch 'master' into drop-vestigial-tag-type
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

This commit is contained in:
guill 2026-06-09 20:48:28 -07:00 committed by GitHub
commit 098e07fdfd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
60 changed files with 5020 additions and 433 deletions

View File

@ -1,5 +1,4 @@
As of the time of writing this you need this driver for best results:
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
As of the time of writing this you need a recent driver. Updating to the latest driver is recommended.
HOW TO RUN:
@ -7,9 +6,9 @@ If you have a AMD gpu:
run_amd_gpu.bat
If you have memory issues you can try disabling the smart memory management by running comfyui with:
If you have memory issues you can try enabling the new dynamic memory management by running comfyui with:
run_amd_gpu_disable_smart_memory.bat
run_amd_gpu_enable_dynamic_vram.bat
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints

View File

@ -17,7 +17,7 @@ jobs:
- name: Check for Windows line endings (CRLF)
run: |
# Get the list of changed files in the PR
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} -- ':!.ci')
# Flag to track if CRLF is found
CRLF_FOUND=false

View File

@ -33,6 +33,7 @@ from app.assets.services.file_utils import (
verify_file_unchanged,
)
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
@ -506,6 +507,10 @@ def enrich_asset(
if extract_metadata and metadata:
system_metadata = metadata.to_user_metadata()
if mime_type and mime_type.startswith("image/"):
dims = extract_image_dimensions(file_path, mime_type=mime_type)
if dims:
system_metadata.update(dims)
set_reference_system_metadata(session, reference_id, system_metadata)
if full_hash:

View File

@ -0,0 +1,63 @@
"""Image dimension extraction for asset ingest.
Reads only the image header via Pillow to capture width/height cheaply,
without a full pixel decode. Returns a metadata dict suitable for merging
into ``AssetReference.system_metadata``.
"""
from __future__ import annotations
import logging
from typing import Any
logger = logging.getLogger(__name__)
def extract_image_dimensions(
file_path: str, mime_type: str | None = None
) -> dict[str, Any] | None:
"""Extract image dimensions for the file at ``file_path``.
Args:
file_path: Absolute path to a file on disk.
mime_type: Optional MIME type hint. When provided and not prefixed
with ``image/``, extraction is skipped without touching the file.
Returns:
``{"kind": "image", "width": W, "height": H}`` when the file is a
recognizable image with positive dimensions, otherwise ``None``.
The dict shape is intended to be merged into ``system_metadata`` so the
asset response surfaces ``metadata.kind`` plus dimension fields for image
assets. Forward-compatible: future media kinds (e.g. ``"video"`` with
duration/fps) can extend this shape without schema changes.
"""
if mime_type is not None and not mime_type.startswith("image/"):
return None
try:
from PIL import Image, UnidentifiedImageError
except ImportError:
logger.debug(
"Pillow not available; skipping image dimension extraction for %s",
file_path,
)
return None
try:
with Image.open(file_path) as img:
width, height = img.size
except (OSError, UnidentifiedImageError, ValueError) as exc:
logger.debug(
"Failed to read image dimensions from %s: %s", file_path, exc
)
return None
if (
not isinstance(width, int)
or not isinstance(height, int)
or width <= 0
or height <= 0
):
return None
return {"kind": "image", "width": width, "height": height}

View File

@ -17,9 +17,11 @@ from app.assets.database.queries import (
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
list_references_by_asset_id,
reference_exists,
remove_missing_tag_for_asset_id,
set_reference_metadata,
set_reference_system_metadata,
set_reference_tags,
update_asset_hash_and_mime,
upsert_asset,
@ -29,6 +31,7 @@ from app.assets.database.queries import (
from app.assets.helpers import get_utc_now, normalize_tags
from app.assets.services.bulk_ingest import batch_insert_seed_assets
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.image_dimensions import extract_image_dimensions
from app.assets.services.path_utils import (
compute_relative_filename,
get_name_and_tags_from_asset_path,
@ -118,6 +121,14 @@ def _ingest_file_from_path(
user_metadata=user_metadata,
)
_maybe_store_image_dimensions(
session,
reference_id=reference_id,
file_path=locator,
mime_type=mime_type,
current_system_metadata=ref.system_metadata,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
@ -288,6 +299,13 @@ def _register_existing_asset(
user_metadata=new_meta,
)
_backfill_image_dimensions_from_siblings(
session,
asset_id=asset.id,
new_reference_id=ref.id,
current_system_metadata=ref.system_metadata,
)
if tags is not None:
set_reference_tags(
session,
@ -334,6 +352,87 @@ def _update_metadata_with_filename(
)
_IMAGE_DIMENSION_KEYS = ("kind", "width", "height")
def _maybe_store_image_dimensions(
session: Session,
reference_id: str,
file_path: str,
mime_type: str | None,
current_system_metadata: dict | None,
) -> None:
"""Populate ``kind``/``width``/``height`` on system_metadata for image refs.
Non-image MIME types are a no-op. Pre-existing keys (e.g. enricher-written
safetensors metadata, download provenance) are preserved by merge.
"""
if not mime_type or not mime_type.startswith("image/"):
return
dims = extract_image_dimensions(file_path, mime_type=mime_type)
if not dims:
return
current = current_system_metadata or {}
merged = dict(current)
merged.update(dims)
if merged != current:
set_reference_system_metadata(
session,
reference_id=reference_id,
system_metadata=merged,
)
def _backfill_image_dimensions_from_siblings(
session: Session,
asset_id: str,
new_reference_id: str,
current_system_metadata: dict | None,
) -> None:
"""Copy image dimension keys from any sibling reference of the same asset.
The from-hash path doesn't read the file bytes, so dimensions can't be
extracted there directly. When another reference of the same asset already
carries image dimensions, copy them onto the new reference so consumers
see consistent metadata regardless of how the asset was registered.
Best-effort: missing siblings, non-image siblings, or absent dimension
keys leave the target reference unchanged.
"""
current = current_system_metadata or {}
if current.get("kind") == "image" and "width" in current and "height" in current:
return
for sibling in list_references_by_asset_id(session, asset_id):
if sibling.id == new_reference_id:
continue
meta = sibling.system_metadata or {}
if meta.get("kind") != "image":
continue
width = meta.get("width")
height = meta.get("height")
if (
type(width) is not int
or type(height) is not int
or width <= 0
or height <= 0
):
continue
merged = dict(current)
merged["kind"] = "image"
merged["width"] = width
merged["height"] = height
if merged != current:
set_reference_system_metadata(
session,
reference_id=new_reference_id,
system_metadata=merged,
)
return
def _sanitize_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
return n if n else fallback

View File

@ -166,6 +166,8 @@ class PerformanceFeature(enum.Enum):
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--debug-hang", action="store_true", help="Enable stack trace dumps on Ctrl-C for debugging hangs.")
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")

View File

@ -1,7 +1,13 @@
import torch
import torch.nn.functional as F
from comfy.text_encoders.bert import BertAttention
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.ldm.depth_anything_3.reference_view_selector import (
select_reference_view, reorder_by_reference, restore_original_order,
THRESH_FOR_REF_SELECTION,
)
class Dino2AttentionOutput(torch.nn.Module):
@ -14,13 +20,41 @@ class Dino2AttentionOutput(torch.nn.Module):
class Dino2AttentionBlock(torch.nn.Module):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations,
qk_norm=False):
super().__init__()
self.heads = heads
self.head_dim = embed_dim // heads
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
if qk_norm:
self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
else:
self.q_norm = None
self.k_norm = None
def forward(self, x, mask, optimized_attention):
return self.output(self.attention(x, mask, optimized_attention))
def forward(self, x, mask, optimized_attention, pos=None, rope=None):
# Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions).
if self.q_norm is None and rope is None:
return self.output(self.attention(x, mask, optimized_attention))
# DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE.
attn = self.attention
B, N, C = x.shape
h = self.heads
d = self.head_dim
q = attn.query(x).view(B, N, h, d).transpose(1, 2)
k = attn.key(x).view(B, N, h, d).transpose(1, 2)
v = attn.value(x).view(B, N, h, d).transpose(1, 2)
if self.q_norm is not None:
q = self.q_norm(q)
k = self.k_norm(k)
if rope is not None and pos is not None:
q = rope(q, pos)
k = rope(k, pos)
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
return self.output(out)
class LayerScale(torch.nn.Module):
@ -64,9 +98,11 @@ class SwiGLUFFN(torch.nn.Module):
class Dino2Block(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn,
qk_norm=False):
super().__init__()
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations,
qk_norm=qk_norm)
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
if use_swiglu_ffn:
@ -76,19 +112,90 @@ class Dino2Block(torch.nn.Module):
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, x, optimized_attention):
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None):
x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention,
pos=pos, rope=rope))
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
return x
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
# -----------------------------------------------------------------------------
# 2D Rotary position embedding (DA3 extension)
# -----------------------------------------------------------------------------
class _PositionGetter:
"""Cache (h, w) -> flat (y, x) position grid used to feed ``rope``."""
def __init__(self):
self._cache: dict = {}
def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
key = (height, width, device)
if key not in self._cache:
y = torch.arange(height, device=device)
x = torch.arange(width, device=device)
self._cache[key] = torch.cartesian_prod(y, x)
cached = self._cache[key]
return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
class RotaryPositionEmbedding2D(torch.nn.Module):
"""2D RoPE used by DA3-Small/Base. No learnable parameters."""
def __init__(self, frequency: float = 100.0):
super().__init__()
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
for _ in range(num_layers)])
self.base_frequency = frequency
self._freq_cache: dict = {}
def _components(self, dim: int, seq_len: int, device, dtype):
key = (dim, seq_len, device, dtype)
if key not in self._freq_cache:
exp = torch.arange(0, dim, 2, device=device).float() / dim
inv_freq = 1.0 / (self.base_frequency ** exp)
pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
ang = torch.einsum("i,j->ij", pos, inv_freq)
ang = ang.to(dtype)
ang = torch.cat((ang, ang), dim=-1)
self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
return self._freq_cache[key]
@staticmethod
def _rotate(x: torch.Tensor) -> torch.Tensor:
d = x.shape[-1]
x1, x2 = x[..., : d // 2], x[..., d // 2:]
return torch.cat((-x2, x1), dim=-1)
def _apply_1d(self, tokens, positions, cos_c, sin_c):
cos = F.embedding(positions, cos_c)[:, None, :, :]
sin = F.embedding(positions, sin_c)[:, None, :, :]
return (tokens * cos) + (self._rotate(tokens) * sin)
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
feature_dim = tokens.size(-1) // 2
max_pos = int(positions.max()) + 1
cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
v, h = tokens.chunk(2, dim=-1)
v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
return torch.cat((v, h), dim=-1)
class Dino2Encoder(torch.nn.Module):
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn,
qknorm_start: int = -1):
super().__init__()
self.layer = torch.nn.ModuleList([
Dino2Block(
dim, num_heads, layer_norm_eps, dtype, device, operations,
use_swiglu_ffn=use_swiglu_ffn,
qk_norm=(qknorm_start != -1 and i >= qknorm_start),
)
for i in range(num_layers)
])
def forward(self, x, intermediate_output=None):
# Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions).
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
if intermediate_output is not None:
@ -122,16 +229,27 @@ class Dino2PatchEmbeddings(torch.nn.Module):
class Dino2Embeddings(torch.nn.Module):
def __init__(self, dim, dtype, device, operations):
def __init__(self, dim, dtype, device, operations,
patch_size: int = 14, image_size: int = 518,
use_mask_token: bool = True,
num_camera_tokens: int = 0):
super().__init__()
patch_size = 14
image_size = 518
self.patch_size = patch_size
self.image_size = image_size
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key.
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
if use_mask_token:
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
else:
self.mask_token = None
if num_camera_tokens > 0:
# DA3 stores (ref_token, src_token) pairs that get injected at the
# alt-attn boundary; see ``Dinov2Model._inject_camera_token``.
self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device))
else:
self.camera_token = None
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
@ -140,12 +258,22 @@ class Dino2Embeddings(torch.nn.Module):
patch_pos = pos_embed[:, 1:]
N = patch_pos.shape[1]
M = int(N ** 0.5)
assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})"
h0 = h_pixels // self.patch_size
w0 = w_pixels // self.patch_size
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
# +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
# scale_factor is (height_scale, width_scale) -- height MUST come first;
# swapping these only happens to work for square inputs and breaks
# non-square paths like DA3-Small / DA3-Base multi-view.
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M)
patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
assert (h0, w0) == patch_pos.shape[-2:], (
f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match "
f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); "
f"check scale_factor axis order and +0.1 rounding workaround"
)
patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)
@ -168,12 +296,51 @@ class Dinov2Model(torch.nn.Module):
heads = config_dict["num_attention_heads"]
layer_norm_eps = config_dict["layer_norm_eps"]
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
patch_size = config_dict.get("patch_size", 14)
image_size = config_dict.get("image_size", 518)
use_mask_token = config_dict.get("use_mask_token", True)
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
# DA3 extensions (all default to disabled).
self.alt_start = config_dict.get("alt_start", -1)
self.qknorm_start = config_dict.get("qknorm_start", -1)
self.rope_start = config_dict.get("rope_start", -1)
self.cat_token = config_dict.get("cat_token", False)
rope_freq = config_dict.get("rope_freq", 100.0)
self.embed_dim = dim
self.patch_size = patch_size
self.num_register_tokens = 0
self.patch_start_idx = 1
if self.rope_start != -1 and rope_freq > 0:
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
self._position_getter = _PositionGetter()
else:
self.rope = None
self._position_getter = None
# camera_token shape: (1, 2, dim) -> (ref_token, src_token).
num_cam_tokens = 2 if self.alt_start != -1 else 0
self.embeddings = Dino2Embeddings(
dim, dtype, device, operations,
patch_size=patch_size, image_size=image_size,
use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens,
)
self.encoder = Dino2Encoder(
dim, heads, layer_norm_eps, num_layers, dtype, device, operations,
use_swiglu_ffn=use_swiglu_ffn,
qknorm_start=self.qknorm_start,
)
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
if self.alt_start != -1:
raise RuntimeError(
"Dinov2Model.forward() is the backward-compatible CLIP-vision path and does not "
"apply DA3 extensions (RoPE, alternating attention, camera-token injection). "
"Use get_intermediate_layers_da3() for Depth Anything 3 models."
)
x = self.embeddings(pixel_values)
x, i = self.encoder(x, intermediate_output=intermediate_output)
x = self.layernorm(x)
@ -181,6 +348,7 @@ class Dinov2Model(torch.nn.Module):
return x, i, pooled_output, None
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
"""Single-view multi-layer feature extraction."""
x = self.embeddings(pixel_values)
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
n_layers = len(self.encoder.layer)
@ -197,3 +365,132 @@ class Dinov2Model(torch.nn.Module):
if i >= max_idx:
break
return [cache[i] for i in resolved]
# ------------------------------------------------------------------
# Depth Anything 3 forward
# ------------------------------------------------------------------
def _prepare_rope_positions(self, B, S, H, W, device):
if self.rope is None:
return None, None
ph, pw = H // self.patch_size, W // self.patch_size
pos = self._position_getter(B * S, ph, pw, device=device)
# Shift so the cls/cam token at position 0 is reserved for "no diff".
pos = pos + 1
cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
# Per-view local: real grid positions for patches, 0 for cls token.
pos_local = torch.cat([cls_pos, pos], dim=1)
# Global (across views): same grid positions; cls token still at 0,
# but patches share the same positions in every view.
pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1)
return pos_local, pos_global
def _inject_camera_token(self, x: torch.Tensor, B: int, S: int, cam_token: "torch.Tensor | None") -> torch.Tensor:
# x: (B, S, N, C). Replace token at index 0 with the camera token.
if cam_token is not None:
inj = cam_token
else:
ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype)
ref_token = ct[:, :1].expand(B, -1, -1)
src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
inj = torch.cat([ref_token, src_token], dim=1)
x = x.clone()
x[:, :, 0] = inj
return x
def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None, ref_view_strategy="saddle_balanced", export_feat_layers=None):
"""Multi-view multi-layer feature extraction used by Depth Anything 3."""
if pixel_values.ndim == 4:
pixel_values = pixel_values.unsqueeze(1)
assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}"
B, S, _, H, W = pixel_values.shape
# Patch + cls + (interpolated) pos embed for each view.
x = pixel_values.reshape(B * S, 3, H, W)
x = self.embeddings(x) # (B*S, 1+N, C)
x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C)
pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)
# optimized_attention is only used by blocks without QK-norm/RoPE
# (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA.
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
out_set = set(out_layers)
export_set = set(export_feat_layers) if export_feat_layers else set()
outputs: list[torch.Tensor] = []
aux_outputs: list[torch.Tensor] = []
local_x = x
b_idx = None
for i, blk in enumerate(self.encoder.layer):
apply_rope = self.rope is not None and i >= self.rope_start
block_rope = self.rope if apply_rope else None
l_pos = pos_local if apply_rope else None
g_pos = pos_global if apply_rope else None
# Reference-view selection threshold: matches the upstream constant
# THRESH_FOR_REF_SELECTION = 3. Skipped when a user-supplied
# cam_token is provided (camera info already pins the geometry).
if (self.alt_start != -1 and i == self.alt_start - 1 and S >= THRESH_FOR_REF_SELECTION and cam_token is None):
b_idx = select_reference_view(x, strategy=ref_view_strategy)
x = reorder_by_reference(x, b_idx)
local_x = reorder_by_reference(local_x, b_idx)
if self.alt_start != -1 and i == self.alt_start:
x = self._inject_camera_token(x, B, S, cam_token)
if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
# Global attention across views: flatten S into the seq dim.
t = x.reshape(B, S * x.shape[-2], x.shape[-1])
p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
else:
# Per-view local attention.
t = x.reshape(B * S, x.shape[-2], x.shape[-1])
p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
local_x = x
if i in out_set:
if self.cat_token:
out_x = torch.cat([local_x, x], dim=-1)
else:
out_x = x
# Restore original view order on the way out so heads see views
# in the user's expected order.
if b_idx is not None and self.alt_start != -1:
out_x = restore_original_order(out_x, b_idx)
outputs.append(out_x)
if i in export_set:
aux = x
if b_idx is not None and self.alt_start != -1:
aux = restore_original_order(aux, b_idx)
aux_outputs.append(aux)
# Apply final norm. When cat_token is set, only the right half
# ("global" features) is normalised; the left half is left as-is to
# match the upstream DA3 head signature.
normed: list[torch.Tensor] = []
cls_tokens: list[torch.Tensor] = []
for out_x in outputs:
cls_tokens.append(out_x[:, :, 0])
if out_x.shape[-1] == self.embed_dim:
normed.append(self.layernorm(out_x))
elif out_x.shape[-1] == self.embed_dim * 2:
left = out_x[..., :self.embed_dim]
right = self.layernorm(out_x[..., self.embed_dim:])
normed.append(torch.cat([left, right], dim=-1))
else:
raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
# Drop cls/cam token from the patch sequence.
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
# Final layernorm + drop cls token from auxiliary features too.
aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :]
for o in aux_outputs]
return list(zip(normed, cls_tokens)), aux_normed

25
comfy/ldm/colormap.py Normal file
View File

@ -0,0 +1,25 @@
"""Colormap utilities for depth and geometry visualisation."""
from __future__ import annotations
import torch
def turbo(x: torch.Tensor) -> torch.Tensor:
"""Anton Mikhailov polynomial approximation of the Turbo colormap.
Args:
x: Float tensor with values in [0, 1].
Returns:
RGB tensor of the same shape as ``x`` with a trailing size-3 dimension.
"""
x = x.clamp(0.0, 1.0)
x2 = x * x
x3 = x2 * x
x4 = x2 * x2
x5 = x4 * x
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)

View File

@ -0,0 +1,177 @@
"""Camera-token encoder and decoder for Depth Anything 3."""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention_for_device
from .transform import affine_inverse, extri_intri_to_pose_encoding
# -----------------------------------------------------------------------
# Building blocks (mirror depth_anything_3.model.utils.{attention,block})
# -----------------------------------------------------------------------
class _Mlp(nn.Module):
"""Standard 2-layer MLP with GELU. Matches upstream ``utils.attention.Mlp``."""
def __init__(self, in_features, hidden_features=None, out_features=None, *, device=None, dtype=None, operations=None):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = operations.Linear(in_features, hidden_features, bias=True, device=device, dtype=dtype)
self.fc2 = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
def forward(self, x):
return self.fc2(F.gelu(self.fc1(x)))
class _LayerScale(nn.Module):
"""Per-channel learnable scaling. Matches upstream LayerScale."""
def __init__(self, dim, *, device=None, dtype=None):
super().__init__()
self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
def forward(self, x):
return x * self.gamma.to(dtype=x.dtype, device=x.device)
class _Attention(nn.Module):
""" Self-attention with fused QKV projection. Mirrors upstream utils.attention.Attention;
Layout matches the HF safetensors (attn.qkv.{weight,bias} and attn.proj.{weight,bias})."""
def __init__(self, dim, num_heads, *, device=None, dtype=None, operations=None):
super().__init__()
assert dim % num_heads == 0
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = operations.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype)
self.proj = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, C)
q, k, v = qkv.unbind(2) # each (B, N, C)
attn_fn = optimized_attention_for_device(x.device, small_input=True)
out = attn_fn(q, k, v, heads=self.num_heads)
return self.proj(out)
class _Block(nn.Module):
"""Pre-norm transformer block with LayerScale. Used by :class:CameraEnc. Layout follows upstream utils.block.Block."""
def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01, *, device=None, dtype=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.attn = _Attention(dim, num_heads, device=device, dtype=dtype, operations=operations)
self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
def forward(self, x):
x = x + self.ls1(self.attn(self.norm1(x)))
x = x + self.ls2(self.mlp(self.norm2(x)))
return x
class CameraEnc(nn.Module):
"""Encode per-view (extrinsics, intrinsics) into a camera token.
Maps a 9-D pose-encoding vector through a small MLP up to the backbone's
``embed_dim``, then runs ``trunk_depth`` transformer blocks. The output
has shape ``(B, S, embed_dim)`` and is injected at block ``alt_start``
of the DINOv2 backbone in place of the cls token.
Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly.
"""
def __init__(
self,
dim_out: int = 1024,
dim_in: int = 9,
trunk_depth: int = 4,
target_dim: int = 9,
num_heads: int = 16,
mlp_ratio: int = 4,
init_values: float = 0.01,
*,
device=None, dtype=None, operations=None,
**_kwargs,
):
super().__init__()
self.target_dim = target_dim
self.trunk_depth = trunk_depth
self.trunk = nn.Sequential(*[
_Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio,
init_values=init_values,
device=device, dtype=dtype, operations=operations)
for _ in range(trunk_depth)
])
self.token_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
self.trunk_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
self.pose_branch = _Mlp(
in_features=dim_in,
hidden_features=dim_out // 2,
out_features=dim_out,
device=device, dtype=dtype, operations=operations,
)
def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor,
image_size_hw) -> torch.Tensor:
"""Encode camera parameters into ``(B, S, dim_out)`` tokens."""
c2ws = affine_inverse(extrinsics)
pose_encoding = extri_intri_to_pose_encoding(c2ws, intrinsics, image_size_hw)
tokens = self.pose_branch(pose_encoding.to(self.pose_branch.fc1.weight.dtype))
tokens = self.token_norm(tokens)
tokens = self.trunk(tokens)
tokens = self.trunk_norm(tokens)
return tokens
class CameraDec(nn.Module):
"""Decode the final cam token into a 9-D pose encoding.
Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is
always predicted by the network; the quaternion and FoV can either be
predicted or supplied via ``camera_encoding`` (used at training time
when GT cameras are available -- not exercised at inference here).
Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly.
"""
def __init__(self, dim_in: int = 1536,
*, device=None, dtype=None, operations=None, **_kwargs):
super().__init__()
d = dim_in
self.backbone = nn.Sequential(
operations.Linear(d, d, device=device, dtype=dtype),
nn.ReLU(),
operations.Linear(d, d, device=device, dtype=dtype),
nn.ReLU(),
)
self.fc_t = operations.Linear(d, 3, device=device, dtype=dtype)
self.fc_qvec = operations.Linear(d, 4, device=device, dtype=dtype)
self.fc_fov = nn.Sequential(
operations.Linear(d, 2, device=device, dtype=dtype),
nn.ReLU(),
)
def forward(self, feat: torch.Tensor,
camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor:
"""Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc."""
B, N = feat.shape[:2]
feat = feat.reshape(B * N, -1)
feat = self.backbone(feat)
out_t = self.fc_t(feat.float()).reshape(B, N, 3)
if camera_encoding is None:
out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4)
out_fov = self.fc_fov(feat.float()).reshape(B, N, 2)
else:
out_qvec = camera_encoding[..., 3:7]
out_fov = camera_encoding[..., -2:]
return torch.cat([out_t, out_qvec, out_fov], dim=-1)

View File

@ -0,0 +1,489 @@
"""DPT / DualDPT heads for Depth Anything 3."""
from __future__ import annotations
from typing import List, Optional, Sequence, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class Permute(nn.Module):
def __init__(self, dims: Tuple[int, ...]):
super().__init__()
self.dims = dims
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.permute(*self.dims)
def _custom_interpolate(
x: torch.Tensor,
size: Optional[Tuple[int, int]] = None,
scale_factor: Optional[float] = None,
mode: str = "bilinear",
align_corners: bool = True,
) -> torch.Tensor:
if size is None:
assert scale_factor is not None
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
INT_MAX = 1610612736
total = size[0] * size[1] * x.shape[0] * x.shape[1]
if total > INT_MAX:
chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0)
outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks]
return torch.cat(outs, dim=0).contiguous()
return F.interpolate(x, size=size, mode=mode, align_corners=align_corners)
def _create_uv_grid(width: int, height: int, aspect_ratio: float, dtype, device) -> torch.Tensor:
"""Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span)."""
diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5
span_x = aspect_ratio / diag_factor
span_y = 1.0 / diag_factor
left_x = -span_x * (width - 1) / width
right_x = span_x * (width - 1) / width
top_y = -span_y * (height - 1) / height
bottom_y = span_y * (height - 1) / height
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
return torch.stack((uu, vv), dim=-1) # (H, W, 2)
def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor:
omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0))
pos = pos.reshape(-1)
out = torch.einsum("m,d->md", pos, omega)
return torch.cat([out.sin(), out.cos()], dim=1).float()
def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100.0) -> torch.Tensor:
H, W, _ = pos_grid.shape
pos_flat = pos_grid.reshape(-1, 2)
emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)
emb_y = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)
emb = torch.cat([emb_x, emb_y], dim=-1)
return emb.view(H, W, embed_dim)
def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
"""Stateless UV positional embedding added to a feature map (B, C, h, w)."""
pw, ph = x.shape[-1], x.shape[-2]
pe = _create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
pe = _position_grid_to_embed(pe, x.shape[1]) * ratio
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1).to(dtype=x.dtype)
return x + pe
def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor:
act = (activation or "linear").lower()
if act == "exp":
return torch.exp(x)
if act == "expp1":
return torch.exp(x) + 1
if act == "expm1":
return torch.expm1(x)
if act == "relu":
return torch.relu(x)
if act == "sigmoid":
return torch.sigmoid(x)
if act == "softplus":
return F.softplus(x)
if act == "tanh":
return torch.tanh(x)
return x
# -----------------------------------------------------------------------------
# Fusion building blocks
# -----------------------------------------------------------------------------
class ResidualConvUnit(nn.Module):
def __init__(self, features: int, device=None, dtype=None, operations=None):
super().__init__()
self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype)
self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype)
self.activation = nn.ReLU(inplace=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.activation(x)
out = self.conv1(out)
out = self.activation(out)
out = self.conv2(out)
return out + x
class FeatureFusionBlock(nn.Module):
def __init__(self, features: int, has_residual: bool = True, align_corners: bool = True, device=None, dtype=None, operations=None):
super().__init__()
self.align_corners = align_corners
self.has_residual = has_residual
if has_residual:
self.resConfUnit1 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
else:
self.resConfUnit1 = None
self.resConfUnit2 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True, device=device, dtype=dtype)
def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
y = xs[0]
if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
y = y + self.resConfUnit1(xs[1])
y = self.resConfUnit2(y)
if size is None:
up_kwargs = {"scale_factor": 2.0}
else:
up_kwargs = {"size": size}
y = _custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners)
y = self.out_conv(y)
return y
class _Scratch(nn.Module):
"""Container that mirrors upstream ``scratch`` attribute layout."""
def _make_scratch(in_shape: List[int], out_shape: int, device=None, dtype=None, operations=None) -> _Scratch:
scratch = _Scratch()
scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
return scratch
def _make_fusion_block(features: int, has_residual: bool = True, device=None, dtype=None, operations=None) -> FeatureFusionBlock:
return FeatureFusionBlock(features, has_residual=has_residual, align_corners=True, device=device, dtype=dtype, operations=operations)
# -----------------------------------------------------------------------------
# DPT (single head + optional sky head) -- used by DA3Mono/Metric
# -----------------------------------------------------------------------------
class DPT(nn.Module):
"""Single-head DPT used by DA3Mono-Large and DA3Metric-Large."""
def __init__(
self,
dim_in: int,
patch_size: int = 14,
output_dim: int = 1,
activation: str = "exp",
conf_activation: str = "expp1",
features: int = 256,
out_channels: Sequence[int] = (256, 512, 1024, 1024),
pos_embed: bool = False,
down_ratio: int = 1,
head_name: str = "depth",
use_sky_head: bool = True,
sky_name: str = "sky",
sky_activation: str = "relu",
norm_type: str = "idt",
device=None, dtype=None, operations=None,
):
super().__init__()
self.patch_size = patch_size
self.activation = activation
self.conf_activation = conf_activation
self.pos_embed = pos_embed
self.down_ratio = down_ratio
self.head_main = head_name
self.sky_name = sky_name
self.out_dim = output_dim
self.has_conf = output_dim > 1
self.use_sky_head = use_sky_head
self.sky_activation = sky_activation
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
if norm_type == "layer":
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
else:
self.norm = nn.Identity()
out_channels = list(out_channels)
self.projects = nn.ModuleList([
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
for oc in out_channels
])
self.resize_layers = nn.ModuleList([
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
nn.Identity(),
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
])
self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
head_features_1 = features
head_features_2 = 32
self.scratch.output_conv1 = operations.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
device=device, dtype=dtype,
)
self.scratch.output_conv2 = nn.Sequential(
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
nn.ReLU(inplace=False),
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
)
if self.use_sky_head:
self.scratch.sky_output_conv2 = nn.Sequential(
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
nn.ReLU(inplace=False),
operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
)
def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
# feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C)
B, S, N, C = feats[0][0].shape
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
ph, pw = H // self.patch_size, W // self.patch_size
resized = []
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
x = feats_flat[take_idx][:, patch_start_idx:]
x = self.norm(x)
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
x = self.projects[stage_idx](x)
if self.pos_embed:
x = _add_pos_embed(x, W, H)
x = self.resize_layers[stage_idx](x)
resized.append(x)
l1_rn = self.scratch.layer1_rn(resized[0])
l2_rn = self.scratch.layer2_rn(resized[1])
l3_rn = self.scratch.layer3_rn(resized[2])
l4_rn = self.scratch.layer4_rn(resized[3])
out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
out = self.scratch.refinenet1(out, l1_rn)
h_out = int(ph * self.patch_size / self.down_ratio)
w_out = int(pw * self.patch_size / self.down_ratio)
fused = self.scratch.output_conv1(out)
fused = _custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
if self.pos_embed:
fused = _add_pos_embed(fused, W, H)
feat = fused
main_logits = self.scratch.output_conv2(feat)
outs = {}
if self.has_conf:
fmap = main_logits.permute(0, 2, 3, 1)
pred = _apply_activation(fmap[..., :-1], self.activation)
conf = _apply_activation(fmap[..., -1], self.conf_activation)
outs[self.head_main] = pred.squeeze(-1).view(B, S, *pred.shape[1:-1])
outs[f"{self.head_main}_conf"] = conf.view(B, S, *conf.shape[1:])
else:
pred = _apply_activation(main_logits, self.activation)
outs[self.head_main] = pred.squeeze(1).view(B, S, *pred.shape[2:])
if self.use_sky_head:
sky_logits = self.scratch.sky_output_conv2(feat)
if self.sky_activation.lower() == "sigmoid":
sky = torch.sigmoid(sky_logits)
elif self.sky_activation.lower() == "relu":
sky = F.relu(sky_logits)
else:
sky = sky_logits
outs[self.sky_name] = sky.squeeze(1).view(B, S, *sky.shape[2:])
return outs
# -----------------------------------------------------------------------------
# DualDPT (depth + auxiliary "ray" head) -- used by DA3-Small / DA3-Base
# -----------------------------------------------------------------------------
class DualDPT(nn.Module):
"""Two-head DPT used by DA3-Small / DA3-Base."""
def __init__(
self,
dim_in: int,
patch_size: int = 14,
output_dim: int = 2,
activation: str = "exp",
conf_activation: str = "expp1",
features: int = 256,
out_channels: Sequence[int] = (256, 512, 1024, 1024),
pos_embed: bool = True,
down_ratio: int = 1,
aux_pyramid_levels: int = 4,
aux_out1_conv_num: int = 5,
head_names: Tuple[str, str] = ("depth", "ray"),
device=None, dtype=None, operations=None,
):
super().__init__()
self.patch_size = patch_size
self.activation = activation
self.conf_activation = conf_activation
self.pos_embed = pos_embed
self.down_ratio = down_ratio
self.aux_levels = aux_pyramid_levels
self.aux_out1_conv_num = aux_out1_conv_num
self.head_main, self.head_aux = head_names
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
# Toggle the auxiliary ray branch at runtime. Default off (mono path).
# DepthAnything3Net flips this on when running multi-view + ray-pose.
self.enable_aux: bool = False
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
out_channels = list(out_channels)
self.projects = nn.ModuleList([
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
for oc in out_channels
])
self.resize_layers = nn.ModuleList([
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
nn.Identity(),
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
])
self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
# Main fusion chain
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
# Auxiliary fusion chain (separate copies)
self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
head_features_1 = features
head_features_2 = 32
# Main head neck + final projection
self.scratch.output_conv1 = operations.Conv2d(
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
device=device, dtype=dtype,
)
self.scratch.output_conv2 = nn.Sequential(
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
nn.ReLU(inplace=False),
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
)
# Aux pre-head per level (multi-level pyramid)
self.scratch.output_conv1_aux = nn.ModuleList([
self._make_aux_out1_block(head_features_1, device=device, dtype=dtype, operations=operations)
for _ in range(self.aux_levels)
])
# Aux final projection per level (includes LayerNorm permute path).
ln_seq = [Permute((0, 2, 3, 1)),
operations.LayerNorm(head_features_2, device=device, dtype=dtype),
Permute((0, 3, 1, 2))]
self.scratch.output_conv2_aux = nn.ModuleList([
nn.Sequential(
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
*ln_seq,
nn.ReLU(inplace=False),
operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
)
for _ in range(self.aux_levels)
])
@staticmethod
def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential:
# aux_out1_conv_num=5 in all Apache-2.0 variants.
return nn.Sequential(
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
)
def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
B, S, N, C = feats[0][0].shape
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
ph, pw = H // self.patch_size, W // self.patch_size
resized = []
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
x = feats_flat[take_idx][:, patch_start_idx:]
x = self.norm(x)
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
x = self.projects[stage_idx](x)
if self.pos_embed:
x = _add_pos_embed(x, W, H)
x = self.resize_layers[stage_idx](x)
resized.append(x)
l1_rn = self.scratch.layer1_rn(resized[0])
l2_rn = self.scratch.layer2_rn(resized[1])
l3_rn = self.scratch.layer3_rn(resized[2])
l4_rn = self.scratch.layer4_rn(resized[3])
# Main pyramid (output_conv1 is applied inside the upstream `_fuse`,
# before interpolation -- replicate that order here).
m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
if self.enable_aux:
a4 = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
aux_pyr = [a4]
m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:])
if self.enable_aux:
aux_pyr.append(self.scratch.refinenet3_aux(aux_pyr[-1], l3_rn, size=l2_rn.shape[2:]))
m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:])
if self.enable_aux:
aux_pyr.append(self.scratch.refinenet2_aux(aux_pyr[-1], l2_rn, size=l1_rn.shape[2:]))
m = self.scratch.refinenet1(m, l1_rn)
if self.enable_aux:
aux_pyr.append(self.scratch.refinenet1_aux(aux_pyr[-1], l1_rn))
m = self.scratch.output_conv1(m)
h_out = int(ph * self.patch_size / self.down_ratio)
w_out = int(pw * self.patch_size / self.down_ratio)
m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True)
if self.pos_embed:
m = _add_pos_embed(m, W, H)
main_logits = self.scratch.output_conv2(m)
fmap = main_logits.permute(0, 2, 3, 1)
depth_pred = _apply_activation(fmap[..., :-1], self.activation)
depth_conf = _apply_activation(fmap[..., -1], self.conf_activation)
outs = {
self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]),
f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]),
}
if self.enable_aux:
# Auxiliary "ray" head (multi-level inside) -- only the last level
# is returned. Mirrors upstream ``DualDPT._fuse`` + ``_forward_impl``:
# each aux pyramid level goes through ``output_conv1_aux[i]``
# (5-layer conv stack that ends at ``features // 2`` channels),
# then the last level optionally gets a pos-embed and finally
# ``output_conv2_aux[-1]``.
aux_processed = [
self.scratch.output_conv1_aux[i](a) for i, a in enumerate(aux_pyr)
]
last_aux = aux_processed[-1]
if self.pos_embed:
last_aux = _add_pos_embed(last_aux, W, H)
last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
fmap_last = last_aux_logits.permute(0, 2, 3, 1)
# Channels: [ray(6), ray_conf(1)]; ray uses 'linear' activation.
aux_pred = fmap_last[..., :-1]
aux_conf = _apply_activation(fmap_last[..., -1], self.conf_activation)
outs[self.head_aux] = aux_pred.view(B, S, *aux_pred.shape[1:])
outs[f"{self.head_aux}_conf"] = aux_conf.view(B, S, *aux_conf.shape[1:])
return outs

View File

@ -0,0 +1,236 @@
from __future__ import annotations
from typing import Dict, Optional, Sequence
import torch
import torch.nn as nn
from comfy.image_encoders.dino2 import Dinov2Model
from .camera import CameraDec, CameraEnc
from .dpt import DPT, DualDPT
from .ray_pose import get_extrinsic_from_camray
from .transform import affine_inverse, pose_encoding_to_extri_intri
_HEAD_REGISTRY = {
"dpt": DPT,
"dualdpt": DualDPT,
}
# Backbone presets (mirror the upstream DINOv2 ViT variants).
_BACKBONE_PRESETS = {
"vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False),
"vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False),
"vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False),
"vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True),
}
def _build_backbone_config(
backbone_name: str,
*,
alt_start: int,
qknorm_start: int,
rope_start: int,
cat_token: bool,
) -> dict:
if backbone_name not in _BACKBONE_PRESETS:
raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")
cfg = dict(_BACKBONE_PRESETS[backbone_name])
cfg.update(dict(
layer_norm_eps=1e-6,
patch_size=14,
image_size=518,
# No mask_token in DA3 weights; omit param to avoid load warnings.
use_mask_token=False,
alt_start=alt_start,
qknorm_start=qknorm_start,
rope_start=rope_start,
cat_token=cat_token,
rope_freq=100.0,
))
return cfg
class DepthAnything3Net(nn.Module):
PATCH_SIZE = 14
def __init__(
self,
# --- Backbone ---
backbone_name: str = "vitl",
out_layers: Sequence[int] = (4, 11, 17, 23),
alt_start: int = -1,
qknorm_start: int = -1,
rope_start: int = -1,
cat_token: bool = False,
# --- Head ---
head_type: str = "dpt", # dpt or dualdpt
head_dim_in: int = 1024,
head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf
head_features: int = 256,
head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
head_use_sky_head: bool = True, # ignored by DualDPT
head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
# --- Camera (multi-view) ---
has_cam_enc: bool = False,
has_cam_dec: bool = False,
cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim)
cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
# ComfyUI plumbing
device=None, dtype=None, operations=None,
**_ignored,
):
super().__init__()
head_cls = _HEAD_REGISTRY[head_type.lower()]
self.head_type = head_type.lower()
self.has_sky = (self.head_type == "dpt") and head_use_sky_head
self.has_conf = head_output_dim > 1
self.out_layers = list(out_layers)
backbone_cfg = _build_backbone_config(
backbone_name,
alt_start=alt_start,
qknorm_start=qknorm_start,
rope_start=rope_start,
cat_token=cat_token,
)
self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)
head_kwargs = dict(
dim_in=head_dim_in,
patch_size=self.PATCH_SIZE,
output_dim=head_output_dim,
features=head_features,
out_channels=tuple(head_out_channels),
device=device, dtype=dtype, operations=operations,
)
if self.head_type == "dpt":
head_kwargs.update(
use_sky_head=head_use_sky_head,
pos_embed=(False if head_pos_embed is None else head_pos_embed),
)
else: # dualdpt
head_kwargs.update(
pos_embed=(True if head_pos_embed is None else head_pos_embed),
)
self.head = head_cls(**head_kwargs)
# Built only if checkpoint has weights; cam_enc output dim == embed_dim.
embed_dim = backbone_cfg["hidden_size"]
if has_cam_enc:
self.cam_enc = CameraEnc(
dim_out=cam_dim_out if cam_dim_out is not None else embed_dim,
num_heads=max(1, embed_dim // 64),
device=device, dtype=dtype, operations=operations,
)
else:
self.cam_enc = None
if has_cam_dec:
default_dim = embed_dim * (2 if cat_token else 1)
self.cam_dec = CameraDec(
dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
device=device, dtype=dtype, operations=operations,
)
else:
self.cam_dec = None
self.dtype = dtype
def forward(
self,
image: torch.Tensor,
extrinsics: Optional[torch.Tensor] = None,
intrinsics: Optional[torch.Tensor] = None,
*,
use_ray_pose: bool = False,
ref_view_strategy: str = "saddle_balanced",
export_feat_layers: Optional[Sequence[int]] = None,
**_unused,
) -> Dict[str, torch.Tensor]:
"""Run depth and optionally pose prediction."""
if image.ndim == 4:
image = image.unsqueeze(1) # (B, 1, 3, H, W)
assert image.ndim == 5 and image.shape[2] == 3, \
f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}"
B, S, _, H, W = image.shape
assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
# Camera-token preparation (multi-view path).
cam_token = None
if extrinsics is not None and intrinsics is not None and self.cam_enc is not None:
cam_token = self.cam_enc(extrinsics, intrinsics, (H, W))
# Toggle aux ray output on/off depending on what the caller asked for.
if isinstance(self.head, DualDPT):
self.head.enable_aux = bool(use_ray_pose)
feats, aux_feats = self.backbone.get_intermediate_layers_da3(
image, self.out_layers, cam_token=cam_token,
ref_view_strategy=ref_view_strategy,
export_feat_layers=export_feat_layers,
)
head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
# Pose prediction.
out: Dict[str, torch.Tensor] = {}
if use_ray_pose and "ray" in head_out and "ray_conf" in head_out:
ray = head_out["ray"]
ray_conf = head_out["ray_conf"]
extr_c2w, focal, pp = get_extrinsic_from_camray(
ray, ray_conf, ray.shape[-3], ray.shape[-2],
)
# Match the upstream output: w2c, drop the homogeneous row.
extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :]
# Build pixel-space intrinsics from the normalised focal/pp output.
intr = torch.eye(3, device=ray.device, dtype=ray.dtype)
intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone()
intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W
intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H
intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5
intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5
out["extrinsics"] = extr_w2c
out["intrinsics"] = intr
elif self.cam_dec is not None and S > 1:
# Decode the cam-token of the final out_layer into a pose encoding.
cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec)
pose_enc = self.cam_dec(cam_feat)
c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W))
# Match the upstream output convention: w2c (world->camera), 3x4.
c2w_4x4 = torch.cat([
c2w_3x4,
torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype)
.view(1, 1, 1, 4).expand(B, S, 1, 4),
], dim=-2)
out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :]
out["intrinsics"] = intr
# Flatten the views axis for per-pixel outputs (depth/conf/sky) so the
# per-image consumer keeps its (B*S, H, W) interface.
for k, v in head_out.items():
if k in ("ray", "ray_conf"):
# Keep multi-view shape for downstream pose work.
out[k] = v
elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
out[k] = v.reshape(B * S, *v.shape[2:])
else:
out[k] = v
if export_feat_layers:
out["aux_features"] = self._reshape_aux_features(aux_feats, H, W)
return out
def _reshape_aux_features(self, aux_feats, H: int, W: int):
"""Reshape (B, S, N, C) aux features into (B, S, h_p, w_p, C)."""
ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
out = []
for f in aux_feats:
B, S, N, C = f.shape
assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}"
out.append(f.reshape(B, S, ph, pw, C))
return out

View File

@ -0,0 +1,128 @@
"""Input/output preprocessing helpers for Depth Anything 3."""
from __future__ import annotations
from typing import Tuple
import torch
import comfy.utils
PATCH_SIZE = 14
# ImageNet normalization constants used during DA3 training.
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])
def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int:
down = (x // patch) * patch
up = down + patch
return up if abs(up - x) <= abs(x - down) else down
def compute_target_size(orig_h: int, orig_w: int, process_res: int, method: str = "upper_bound_resize") -> Tuple[int, int]:
"""Compute (target_h, target_w) for a single image.
upper_bound_resize: scale longest side to process_res, then round each dim to nearest multiple of 14 (default upstream method).
lower_bound_resize: scale shortest side to process_res, then round."""
if method == "upper_bound_resize":
longest = max(orig_h, orig_w)
scale = process_res / float(longest)
elif method == "lower_bound_resize":
shortest = min(orig_h, orig_w)
scale = process_res / float(shortest)
else:
raise ValueError(f"Unsupported process_res_method: {method}")
new_w = max(1, _round_to_patch(int(round(orig_w * scale))))
new_h = max(1, _round_to_patch(int(round(orig_h * scale))))
return new_h, new_w
def preprocess_image(image: torch.Tensor, process_res: int = 504, method: str = "upper_bound_resize") -> torch.Tensor:
assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
B, H, W, _ = image.shape
target_h, target_w = compute_target_size(H, W, process_res, method)
# (B, H, W, 3) -> (B, 3, H, W)
x = image.movedim(-1, 1).contiguous()
if (target_h, target_w) != (H, W):
# Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale).
# Lanczos in ``common_upscale`` is anti-aliased and produces the
# closest pixel-wise match in a sweep across {bilinear, bicubic,
# area, lanczos, bislerp}. Used in both directions for simplicity.
x = comfy.utils.common_upscale(x.float(), target_w, target_h, "lanczos", "disabled",)
x = x.clamp(0.0, 1.0)
mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
x = (x - mean) / std
return x
# -----------------------------------------------------------------------------
# Output post-processing (sky-aware clipping for Mono/Metric variants)
# -----------------------------------------------------------------------------
def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
"""Boolean mask: True for non-sky pixels (sky probability < threshold)."""
return sky_prediction < threshold
def apply_sky_aware_clip(depth: torch.Tensor, sky: torch.Tensor, threshold: float = 0.3, quantile: float = 0.99) -> torch.Tensor:
"""Clips sky regions to the 99th percentile of non-sky depth. Returns a new depth tensor."""
non_sky = compute_non_sky_mask(sky, threshold=threshold)
if non_sky.sum() <= 10 or (~non_sky).sum() <= 10:
return depth.clone()
non_sky_depth = depth[non_sky]
if non_sky_depth.numel() > 100_000:
idx = torch.randint(0, non_sky_depth.numel(), (100_000,), device=non_sky_depth.device)
sampled = non_sky_depth[idx]
else:
sampled = non_sky_depth
max_depth = torch.quantile(sampled, quantile)
out = depth.clone()
out[~non_sky] = max_depth
return out
def normalize_depth_v2_style(depth: torch.Tensor, sky: torch.Tensor | None = None, low_quantile: float = 0.01, high_quantile: float = 0.99) -> torch.Tensor:
"""V2-style normalization computes percentile bounds over non-sky pixels (when available), then maps depth into [0, 1] with near = white (1.0)."""
if sky is not None:
mask = compute_non_sky_mask(sky)
if mask.any():
valid = depth[mask]
else:
valid = depth.flatten()
else:
valid = depth.flatten()
if valid.numel() > 100_000:
idx = torch.randint(0, valid.numel(), (100_000,), device=valid.device)
sample = valid[idx]
else:
sample = valid
lo = torch.quantile(sample, low_quantile)
hi = torch.quantile(sample, high_quantile)
rng = (hi - lo).clamp(min=1e-6)
norm = ((depth - lo) / rng).clamp(0.0, 1.0)
# Nearer pixels are brighter (1.0)
norm = 1.0 - norm
if sky is not None:
# Sky pixels become black (far / unknown)
sky_mask = ~compute_non_sky_mask(sky)
norm = torch.where(sky_mask, torch.zeros_like(norm), norm)
return norm
def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor:
"""Simple per-frame min/max normalization with near=1.0 convention."""
lo = depth.amin(dim=(-2, -1), keepdim=True)
hi = depth.amax(dim=(-2, -1), keepdim=True)
rng = (hi - lo).clamp(min=1e-6)
return 1.0 - ((depth - lo) / rng).clamp(0.0, 1.0)

View File

@ -0,0 +1,272 @@
"""Ray-to-pose conversion for the multi-view path of Depth Anything 3."""
from __future__ import annotations
from typing import Optional, Tuple
import torch
# qr/svd use fp32: CUDA often has no fp16/bf16 kernels for these ops.
def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Decompose A = Q @ L with Q orthogonal and L lower-triangular.
Implemented in terms of QR by reversing the columns/rows; the standard
trick from the upstream reference. Inputs A are (3, 3)."""
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device, dtype=A.dtype)
A_tilde = A @ P
# CUDA QR is not implemented for fp16/bf16; upcast just for this call.
Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float())
Q_tilde = Q_tilde.to(A.dtype)
R_tilde = R_tilde.to(A.dtype)
Q = Q_tilde @ P
L = P @ R_tilde @ P
d = torch.diag(L)
sign = torch.sign(d)
Q = Q * sign[None, :] # scale columns of Q
L = L * sign[:, None] # scale rows of L
return Q, L
def _homogenize_points(points: torch.Tensor) -> torch.Tensor:
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
# -----------------------------------------------------------------------------
# Weighted-LSQ + RANSAC homography (batched)
# -----------------------------------------------------------------------------
def _find_homography_weighted_lsq(src_pts: torch.Tensor, dst_pts: torch.Tensor, confident_weight: torch.Tensor,) -> torch.Tensor:
"""Solve a single H with weighted least-squares (DLT)."""
N = src_pts.shape[0]
if N < 4:
raise ValueError("At least 4 points are required to compute a homography.")
w = confident_weight.sqrt().unsqueeze(1) # (N, 1)
x = src_pts[:, 0:1]
y = src_pts[:, 1:2]
u = dst_pts[:, 0:1]
v = dst_pts[:, 1:2]
zeros = torch.zeros_like(x)
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1)
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1)
A = torch.cat([A1, A2], dim=0) # (2N, 9)
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
_, _, Vh = torch.linalg.svd(A.float())
Vh = Vh.to(A.dtype)
H = Vh[-1].reshape(3, 3)
return H / H[-1, -1]
def _find_homography_weighted_lsq_batched(src_pts_batch: torch.Tensor, dst_pts_batch: torch.Tensor, confident_weight_batch: torch.Tensor) -> torch.Tensor:
"""Batched DLT solver. Inputs (B, K, 2) / (B, K); output (B, 3, 3)."""
B, K, _ = src_pts_batch.shape
w = confident_weight_batch.sqrt().unsqueeze(2)
x = src_pts_batch[:, :, 0:1]
y = src_pts_batch[:, :, 1:2]
u = dst_pts_batch[:, :, 0:1]
v = dst_pts_batch[:, :, 1:2]
zeros = torch.zeros_like(x)
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2)
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2)
A = torch.cat([A1, A2], dim=1) # (B, 2K, 9)
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
_, _, Vh = torch.linalg.svd(A.float())
Vh = Vh.to(A.dtype)
H = Vh[:, -1].reshape(B, 3, 3)
return H / H[:, 2:3, 2:3]
def _ransac_find_homography_weighted_batched(
src_pts: torch.Tensor, # (B, N, 2)
dst_pts: torch.Tensor, # (B, N, 2)
confident_weight: torch.Tensor, # (B, N)
n_sample: int,
n_iter: int = 100,
reproj_threshold: float = 3.0,
num_sample_for_ransac: int = 8,
max_inlier_num: int = 10000,
rand_sample_iters_idx: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Batched weighted-RANSAC homography estimator. Returns (B, 3, 3) homography matrices."""
B, N, _ = src_pts.shape
assert N >= 4
device = src_pts.device
sorted_idx = torch.argsort(confident_weight, descending=True, dim=1)
candidate_idx = sorted_idx[:, :n_sample] # (B, n_sample)
if rand_sample_iters_idx is None:
rand_sample_iters_idx = torch.stack(
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac]
for _ in range(n_iter)],
dim=0,
)
rand_idx = candidate_idx[:, rand_sample_iters_idx] # (B, n_iter, k)
b_idx = (
torch.arange(B, device=device)
.view(B, 1, 1)
.expand(B, n_iter, num_sample_for_ransac)
)
src_b = src_pts[b_idx, rand_idx]
dst_b = dst_pts[b_idx, rand_idx]
w_b = confident_weight[b_idx, rand_idx]
cB, cN = src_b.shape[:2]
H_batch = _find_homography_weighted_lsq_batched(
src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1),
).unflatten(0, (cB, cN)) # (B, n_iter, 3, 3)
src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2)
proj = torch.bmm(
src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3),
H_batch.reshape(-1, 3, 3).transpose(1, 2),
) # (B*n_iter, N, 3)
proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2)
err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt() # (B, n_iter, N)
inlier_mask = err < reproj_threshold
score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2)
best_idx = torch.argmax(score, dim=1)
best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx]
# Refit with the inlier set (per-batch, since the inlier counts vary).
H_inlier_list = []
for b in range(B):
mask = best_inlier_mask[b]
in_src = src_pts[b][mask]
in_dst = dst_pts[b][mask]
in_w = confident_weight[b][mask]
if in_src.shape[0] < 4:
# Fall back to identity when RANSAC fails to find enough inliers.
H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype))
continue
sorted_w = torch.argsort(in_w, descending=True)
if len(sorted_w) > max_inlier_num:
keep = max(int(len(sorted_w) * 0.95), max_inlier_num)
sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]]
H_inlier_list.append(
_find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w])
)
return torch.stack(H_inlier_list, dim=0)
# -----------------------------------------------------------------------------
# Camera-ray utilities
# -----------------------------------------------------------------------------
def _unproject_identity(num_y: int, num_x: int, B: int, S: int, device, dtype) -> torch.Tensor:
"""Camera-space unit rays for an identity intrinsic on a 2x2 image plane."""
dx = 1.0 / num_x
dy = 1.0 / num_y
# Centered camera-space coords directly (skip the K^-1 step since it's
# just a translation by -1 on x and y when K is identity-with-center=1).
y = torch.linspace(-(1 - dy), (1 - dy), num_y, device=device, dtype=dtype)
x = torch.linspace(-(1 - dx), (1 - dx), num_x, device=device, dtype=dtype)
yy, xx = torch.meshgrid(y, x, indexing="ij")
grid = torch.stack((xx, yy), dim=-1) # (h, w, 2)
grid = grid.unsqueeze(0).unsqueeze(0).expand(B, S, num_y, num_x, 2)
return torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)
def _camray_to_caminfo(
camray: torch.Tensor, # (B, S, h, w, 6)
confidence: Optional[torch.Tensor] = None, # (B, S, h, w)
reproj_threshold: float = 0.2,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert per-pixel camera rays to per-view (R, T, focal, principal)."""
if confidence is None:
confidence = torch.ones_like(camray[..., 0])
B, S, h, w, _ = camray.shape
device = camray.device
dtype = camray.dtype
rays_target = camray[..., :3] # (B, S, h, w, 3)
rays_origin = _unproject_identity(h, w, B, S, device, dtype)
# Flatten (B*S, h*w, *) for the RANSAC routine.
rays_target = rays_target.flatten(0, 1).flatten(1, 2)
rays_origin = rays_origin.flatten(0, 1).flatten(1, 2)
weights = confidence.flatten(0, 1).flatten(1, 2).clone()
# Project to 2D in homogeneous form (the upstream calls this "perspective division").
z_thresh = 1e-4
mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh)
weights = torch.where(mask, weights, torch.zeros_like(weights))
src = rays_origin.clone()
dst = rays_target.clone()
src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0])
src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1])
dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0])
dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1])
src = src[..., :2]
dst = dst[..., :2]
N = src.shape[1]
n_iter = 100
sample_ratio = 0.3
num_sample_for_ransac = 8
n_sample = max(num_sample_for_ransac, int(N * sample_ratio))
rand_idx = torch.stack(
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
dim=0,
)
# Chunk along the view axis to keep peak memory predictable.
chunk = 2
A_list = []
for i in range(0, src.shape[0], chunk):
A = _ransac_find_homography_weighted_batched(
src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk],
n_sample=n_sample, n_iter=n_iter,
num_sample_for_ransac=num_sample_for_ransac,
reproj_threshold=reproj_threshold,
rand_sample_iters_idx=rand_idx,
max_inlier_num=8000,
)
# Flip sign on dets that come out < 0 (so that the QL produces a
# right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so
# do the comparison in fp32.
flip = torch.linalg.det(A.float()) < 0
A = torch.where(flip[:, None, None], -A, A)
A_list.append(A)
A = torch.cat(A_list, dim=0) # (B*S, 3, 3)
R_list, f_list, pp_list = [], [], []
for i in range(A.shape[0]):
R, L = _ql_decomposition(A[i])
L = L / L[2][2]
f_list.append(torch.stack((L[0][0], L[1][1])))
pp_list.append(torch.stack((L[2][0], L[2][1])))
R_list.append(R)
R = torch.stack(R_list).reshape(B, S, 3, 3)
focal = torch.stack(f_list).reshape(B, S, 2)
pp = torch.stack(pp_list).reshape(B, S, 2)
# Translation: confidence-weighted average of camray direction(s).
cf = confidence.flatten(0, 1).flatten(1, 2)
T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1)
T = T / cf.sum(dim=-1, keepdim=True)
T = T.reshape(B, S, 3)
# Match upstream output convention: focal -> 1/focal, pp + 1.
return R, T, 1.0 / focal, pp + 1.0
def get_extrinsic_from_camray(
camray: torch.Tensor, # (B, S, h, w, 6)
conf: torch.Tensor, # (B, S, h, w, 1) or (B, S, h, w)
patch_size_y: int,
patch_size_x: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Wrap a 4x4 extrinsic + per-view focal + principal-point output."""
if conf.ndim == 5 and conf.shape[-1] == 1:
conf = conf.squeeze(-1)
R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf)
extr = torch.cat([R, T.unsqueeze(-1)], dim=-1) # (B, S, 3, 4)
homo_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
homo_row = homo_row.view(1, 1, 1, 4).expand(R.shape[0], R.shape[1], 1, 4)
extr = torch.cat([extr, homo_row], dim=-2) # (B, S, 4, 4)
return extr, focal, pp

View File

@ -0,0 +1,87 @@
"""Reference-view selection for the multi-view path of Depth Anything 3."""
from __future__ import annotations
from typing import Literal
import torch
RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]
# Per the upstream constants module: ``THRESH_FOR_REF_SELECTION = 3``.
# Reference selection only runs when there are at least this many views.
THRESH_FOR_REF_SELECTION: int = 3
def select_reference_view(x: torch.Tensor, strategy: RefViewStrategy = "saddle_balanced") -> torch.Tensor:
"""Pick a reference view index per batch element."""
B, S, _, _ = x.shape
if S <= 1:
return torch.zeros(B, dtype=torch.long, device=x.device)
if strategy == "first":
return torch.zeros(B, dtype=torch.long, device=x.device)
if strategy == "middle":
return torch.full((B,), S // 2, dtype=torch.long, device=x.device)
# Feature-based strategies: normalised cls/cam token per view.
img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True) # (B,S,C)
if strategy == "saddle_balanced":
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # (B,S,S)
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
sim_score = sim_no_diag.sum(dim=-1) / (S - 1) # (B,S)
feat_norm = x[:, :, 0].norm(dim=-1) # (B,S)
feat_var = img_class_feat.var(dim=-1) # (B,S)
def _normalize(metric):
mn = metric.min(dim=1, keepdim=True).values
mx = metric.max(dim=1, keepdim=True).values
return (metric - mn) / (mx - mn + 1e-8)
sim_n, norm_n, var_n = _normalize(sim_score), _normalize(feat_norm), _normalize(feat_var)
balance = (sim_n - 0.5).abs() + (norm_n - 0.5).abs() + (var_n - 0.5).abs()
return balance.argmin(dim=1)
if strategy == "saddle_sim_range":
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
sim_max = sim_no_diag.max(dim=-1).values
sim_min = sim_no_diag.min(dim=-1).values
return (sim_max - sim_min).argmax(dim=1)
raise ValueError(
f"Unknown reference view selection strategy: {strategy!r}. "
f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
)
def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
"""Reorder x so the reference view is at position 0 in axis S."""
B, S = x.shape[0], x.shape[1]
if S <= 1:
return x
positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
b_idx_exp = b_idx.unsqueeze(1)
reorder = torch.where(
(positions > 0) & (positions <= b_idx_exp),
positions - 1,
positions,
)
reorder[:, 0] = b_idx
batch = torch.arange(B, device=x.device).unsqueeze(1)
return x[batch, reorder]
def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
"""Inverse of reorder_by_reference."""
B, S = x.shape[0], x.shape[1]
if S <= 1:
return x
target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
b_idx_exp = b_idx.unsqueeze(1)
restore = torch.where(target_positions < b_idx_exp, target_positions + 1, target_positions)
restore = torch.scatter(restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp))
batch = torch.arange(B, device=x.device).unsqueeze(1)
return x[batch, restore]

View File

@ -0,0 +1,160 @@
"""Geometry / camera transform helpers for Depth Anything 3."""
from __future__ import annotations
from typing import Tuple
import torch
import torch.nn.functional as F
# -----------------------------------------------------------------------------
# Affine 4x4 helpers
# -----------------------------------------------------------------------------
def as_homogeneous(ext: torch.Tensor) -> torch.Tensor:
"""Promote (...,3,4) extrinsics to (...,4,4) homogeneous form. No-op when the input is already ``(...,4,4)``."""
if ext.shape[-2:] == (4, 4):
return ext
if ext.shape[-2:] == (3, 4):
ones = torch.zeros_like(ext[..., :1, :4])
ones[..., 0, 3] = 1.0
return torch.cat([ext, ones], dim=-2)
raise ValueError(f"Invalid affine shape: {ext.shape}")
def affine_inverse(A: torch.Tensor) -> torch.Tensor:
"""Inverse of an affine matrix ``[R|T; 0 0 0 1]``."""
R = A[..., :3, :3]
T = A[..., :3, 3:]
P = A[..., 3:, :]
return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2)
# -----------------------------------------------------------------------------
# Quaternion <-> rotation matrix (xyzw / scalar-last)
# -----------------------------------------------------------------------------
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""sqrt(max(0, x)) with a zero subgradient where x == 0."""
ret = torch.zeros_like(x)
positive_mask = x > 0
if torch.is_grad_enabled():
ret[positive_mask] = torch.sqrt(x[positive_mask])
else:
ret = torch.where(positive_mask, torch.sqrt(x), ret)
return ret
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
"""Force the real part of a unit quaternion (xyzw) to be non-negative."""
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
"""Convert quaternions (xyzw) to (...,3,3) rotation matrices."""
i, j, k, r = torch.unbind(quaternions, -1)
two_s = 2.0 / (quaternions * quaternions).sum(-1)
o = torch.stack(
(
1 - two_s * (j * j + k * k),
two_s * (i * j - k * r),
two_s * (i * k + j * r),
two_s * (i * j + k * r),
1 - two_s * (i * i + k * k),
two_s * (j * k - i * r),
two_s * (i * k - j * r),
two_s * (j * k + i * r),
1 - two_s * (i * i + j * j),
),
-1,
)
return o.reshape(quaternions.shape[:-1] + (3, 3))
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
"""Convert (...,3,3) rotation matrices to quaternions (xyzw)."""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9,)), dim=-1
)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
quat_by_rijk = torch.stack(
[
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
batch_dim + (4,)
)
# Reorder rijk -> xyzw (i.e. ijkr).
out = out[..., [1, 2, 3, 0]]
return standardize_quaternion(out)
# -----------------------------------------------------------------------------
# Pose-encoding <-> extrinsics + intrinsics
# -----------------------------------------------------------------------------
def extri_intri_to_pose_encoding(extrinsics: torch.Tensor, intrinsics: torch.Tensor, image_size_hw: Tuple[int, int]) -> torch.Tensor:
"""Pack (extr, intr, image_size) into the 9-D pose-encoding vector.
extrinsics: camera-to-world (c2w) (B,S,4,4) matrices,
intrinsics: pixel-space (B,S,3,3) matrices,
image_size_hw: is a (H, W) pair.
"""
R = extrinsics[..., :3, :3]
T = extrinsics[..., :3, 3]
quat = mat_to_quat(R)
H, W = image_size_hw
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
return torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
def pose_encoding_to_extri_intri(pose_encoding: torch.Tensor, image_size_hw: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]:
"""Inverse of extri_intri_to_pose_encoding."""
T = pose_encoding[..., :3]
quat = pose_encoding[..., 3:7]
fov_h = pose_encoding[..., 7]
fov_w = pose_encoding[..., 8]
# Normalize to unit quaternion. CameraDec outputs raw values; a near-zero
# quaternion causes two_s = 2/norm² → inf in quat_to_mat → NaN extrinsics.
quat = quat / quat.norm(dim=-1, keepdim=True).clamp(min=1e-6)
R = quat_to_mat(quat)
extrinsics = torch.cat([R, T[..., None]], dim=-1)
H, W = image_size_hw
fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device, dtype=pose_encoding.dtype)
intrinsics[..., 0, 0] = fx
intrinsics[..., 1, 1] = fy
intrinsics[..., 0, 2] = W / 2
intrinsics[..., 1, 2] = H / 2
intrinsics[..., 2, 2] = 1.0
return extrinsics, intrinsics

View File

@ -174,7 +174,7 @@ class Ideogram4Transformer(nn.Module):
llm = self.llm_cond_proj(llm) * text_mask
h[:, :L_text] = h[:, :L_text] + llm
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long))
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long), out_dtype=h.dtype)
# Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch).
freqs_cis = precompute_freqs_cis(
@ -235,7 +235,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
B = x_chunk.shape[0]
device = x_chunk.device
img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
img_tokens = self._img_to_tokens(x_chunk)
L_img = img_tokens.shape[1]
L_text = context_chunk.shape[1]
L = L_text + L_img
@ -268,7 +268,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
B = x_chunk.shape[0]
device = x_chunk.device
img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
img_tokens = self._img_to_tokens(x_chunk)
L_img = img_tokens.shape[1]
position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)

View File

@ -51,6 +51,18 @@ class FeedForward(nn.Module):
return hidden_states
# Addin this back because Nunchaku custom nodes rely on it, see comment here:
# https://github.com/Comfy-Org/ComfyUI/pull/14178#issuecomment-4640475161
# TODO: Eventually remove this once we natively support SVDQuants
def apply_rotary_emb(x, freqs_cis):
if x.shape[1] == 0:
return x
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
return t_out.reshape(*x.shape)
class QwenTimestepProjEmbeddings(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
super().__init__()

View File

@ -8,7 +8,7 @@ from einops import rearrange
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope1
from comfy.ldm.flux.math import apply_rope1, rope
import comfy.ldm.common_dit
import comfy.model_management
import comfy.patcher_extension
@ -570,6 +570,14 @@ class WanModel(torch.nn.Module):
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
x = torch.concat((full_ref, x), dim=1)
# In-context reference (Bernini)
context_latents = kwargs.get("context_latents", None)
main_len = x.shape[1]
if context_latents is not None:
for lat in context_latents:
cl = self.patch_embedding(lat.float().to(x.device)).to(x.dtype).flatten(2).transpose(1, 2)
x = torch.cat([x, cl], dim=1)
# context
context = self.text_embedding(context)
@ -599,6 +607,9 @@ class WanModel(torch.nn.Module):
# head
x = self.head(x, e)
if context_latents is not None:
x = x[:, :main_len]
if full_ref is not None:
x = x[:, full_ref.shape[1]:]
@ -606,7 +617,7 @@ class WanModel(torch.nn.Module):
x = self.unpatchify(x, grid_sizes)
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0):
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@ -638,6 +649,13 @@ class WanModel(torch.nn.Module):
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
freqs = self.rope_embedder(img_ids).movedim(1, 2)
# In-context reference: a non-zero source_id composes an extra rotation into the spatial rope
if source_id:
d = self.dim // self.num_heads
pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32)
id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype)
freqs = torch.einsum('...ij,...jk->...ik', freqs, id_rot)
return freqs
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
@ -661,6 +679,15 @@ class WanModel(torch.nn.Module):
t_len += 1
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
# In-context reference: one rope block per stream, each with it's own source_id (1, 2, ...) to distinguish from the target (id 0).
context_latents = kwargs.get("context_latents", None)
if context_latents is not None:
context_latents = [comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size) for lat in context_latents]
for i, lat in enumerate(context_latents):
freqs = torch.cat([freqs, self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=i + 1)], dim=1)
kwargs = {**kwargs, "context_latents": context_latents}
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
def unpatchify(self, x, grid_sizes):
@ -1631,13 +1658,15 @@ class SCAILWanModel(WanModel):
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
if reference_latent is not None:
x = torch.cat((reference_latent, x), dim=2)
# embeddings
x = self.patch_embedding(x.float()).to(x.dtype)
if ref_mask_latents is not None: # SCAIL-2 additive mask stream
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
grid_sizes = x.shape[2:]
transformer_options["grid_sizes"] = grid_sizes
x = x.flatten(2).transpose(1, 2)
@ -1645,6 +1674,8 @@ class SCAILWanModel(WanModel):
scail_pose_seq_len = 0
if pose_latents is not None:
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
if sam_latents is not None: # SCAIL-2 additive mask stream
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
scail_x = scail_x.flatten(2).transpose(1, 2)
scail_pose_seq_len = scail_x.shape[1]
x = torch.cat([x, scail_x], dim=1)
@ -1695,7 +1726,36 @@ class SCAILWanModel(WanModel):
return x
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
# ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode,
# which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset.
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
if ref_mask_flag is not None and not bool(ref_mask_flag):
REF_ROPE_H = 120.0
POSE_ROPE_W = 120.0
ref_t_patches = 0
if reference_latent is not None:
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
main_t_patches = t - ref_t_patches
parts = []
if ref_t_patches > 0:
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
if main_t_patches > 0:
parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options))
if pose_latents is not None:
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
h_scale = h / H_pose
w_scale = w / W_pose
h_shift = (h_scale - 1) / 2
w_shift = (w_scale - 1) / 2
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf))
return torch.cat(parts, dim=1)
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
if pose_latents is None:
@ -1719,12 +1779,16 @@ class SCAILWanModel(WanModel):
return torch.cat([main_freqs, pose_freqs], dim=1)
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
if pose_latents is not None:
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
if ref_mask_latents is not None: # SCAIL-2
ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size)
if sam_latents is not None: # SCAIL-2
sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size)
t_len = t
if time_dim_concat is not None:
@ -1737,5 +1801,15 @@ class SCAILWanModel(WanModel):
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
t_len += reference_latent.shape[2]
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag)
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w]
class SCAIL2WanModel(SCAILWanModel):
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)

View File

@ -357,6 +357,12 @@ def model_lora_keys_unet(model, key_map={}):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["transformer.{}".format(key_lora)] = k
if isinstance(model, (comfy.model_base.LTXV, comfy.model_base.LTXAV)):
for k in sdk:
if k.startswith("diffusion_model.") and k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")]
key_map["{}".format(key_lora)] = k
return key_map

View File

@ -65,6 +65,7 @@ import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
import comfy.ldm.hidream_o1.model
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
import comfy.ldm.depth_anything_3.model
import comfy.model_management
import comfy.patcher_extension
@ -1518,8 +1519,26 @@ class WAN21(BaseModel):
if reference_latents is not None:
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
# In-context reference conditioning (Bernini)
context_latents = kwargs.get("context_latents", None)
if context_latents is not None:
out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents])
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
# In-context cond slicing (Bernini)
if cond_key == "context_latents" and isinstance(getattr(cond_value, "cond", None), list):
dim = window.dim
out = []
for lat in cond_value.cond:
if lat.ndim > dim and lat.shape[dim] > 1 and lat.shape[dim] == x_in.shape[dim]:
out.append(window.get_tensor(lat, device, dim=dim, retain_index_list=retain_index_list))
else:
out.append(lat.to(device))
return cond_value._copy_with(out)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN21_CausalAR(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
@ -1754,6 +1773,80 @@ class WAN21_SCAIL(WAN21):
return out
class WAN21_SCAIL2(WAN21_SCAIL):
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAIL2WanModel)
self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents")
self.memory_usage_shape_process = {
"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
"sam_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
}
self.image_to_video = image_to_video
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
if driving_mask_28ch is not None:
out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous())
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
if ref_mask_28ch is not None:
out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous())
ref_mask_flag = kwargs.get("ref_mask_flag", None)
if ref_mask_flag is not None:
out['ref_mask_flag'] = comfy.conds.CONDConstant(ref_mask_flag)
return out
def extra_conds_shapes(self, **kwargs):
out = super().extra_conds_shapes(**kwargs)
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
if driving_mask_28ch is not None:
s = driving_mask_28ch.shape
out['sam_latents'] = [s[0], 28, s[1], s[3], s[4]]
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
if ref_mask_28ch is not None:
s = ref_mask_28ch.shape
out['ref_mask_latents'] = [s[0], 28, s[1], s[3], s[4]]
return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
if cond_key in ("sam_latents", "pose_latents"):
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
def concat_cond(self, **kwargs):
# The 4 extra channels are the history_mask (1 at clean-anchor frames).
noise = kwargs.get("noise", None)
extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
if extra_channels != 4:
return super().concat_cond(**kwargs)
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
if mask is None:
return torch.zeros_like(noise)[:, :4]
device = kwargs["device"]
if mask.shape[1] != 4:
mask = torch.mean(mask, dim=1, keepdim=True)
mask = 1.0 - mask
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
if mask.shape[-3] < noise.shape[-3]:
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
if mask.shape[1] == 1:
mask = mask.repeat(1, 4, 1, 1, 1)
mask = utils.resize_to_batch_size(mask, noise.shape[0])
return mask
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
# Hold anchor constant across all sigmas instead of base sigma*noise + (1-sigma)*latent_image.
return latent_image
class WAN22_WanDancer(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel)
@ -2227,6 +2320,12 @@ class RT_DETR_v4(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
class DepthAnything3(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device,
unet_model=comfy.ldm.depth_anything_3.model.DepthAnything3Net)
class ErnieImage(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)

View File

@ -630,6 +630,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["model_type"] = "humo"
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "animate"
elif '{}patch_embedding_mask.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail2"
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
dit_config["model_type"] = "scail"
elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys:
@ -860,6 +862,95 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
return dit_config
# Depth Anything 3 (repackaged to ComfyUI's native Dinov2Model layout via scripts/convert_da3.py)
if '{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix) in state_dict_keys:
dit_config = {}
dit_config["image_model"] = "DepthAnything3"
patch_w = state_dict['{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix)]
embed_dim = patch_w.shape[0]
depth = count_blocks(state_dict_keys, '{}backbone.encoder.layer.'.format(key_prefix) + '{}.')
# Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg).
backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim)
if backbone_name is None:
return None
dit_config["backbone_name"] = backbone_name
# Detect DA3 extensions on top of vanilla DINOv2.
has_camera_token = '{}backbone.embeddings.camera_token'.format(key_prefix) in state_dict_keys
# qk-norm shows up as `attention.q_norm.weight` on enabled blocks.
qknorm_indices = [
i for i in range(depth)
if '{}backbone.encoder.layer.{}.attention.q_norm.weight'.format(key_prefix, i) in state_dict_keys
]
qknorm_start = qknorm_indices[0] if qknorm_indices else -1
# The DA3 main-series configs always set alt_start == qknorm_start == rope_start.
# cat_token=True is implied by the presence of camera_token.
if has_camera_token:
dit_config["alt_start"] = qknorm_start
dit_config["rope_start"] = qknorm_start
dit_config["qknorm_start"] = qknorm_start
dit_config["cat_token"] = True
else:
dit_config["alt_start"] = -1
dit_config["rope_start"] = -1
dit_config["qknorm_start"] = -1
dit_config["cat_token"] = False
# Detect head type and config.
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
dit_config["head_dim_in"] = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
dit_config["head_features"] = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
dit_config["head_out_channels"] = [
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
for i in range(4)
]
if has_aux:
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
dit_config["head_type"] = "dualdpt"
dit_config["head_output_dim"] = 2
dit_config["head_use_sky_head"] = False
else:
dit_config["head_type"] = "dpt"
dit_config["head_output_dim"] = state_dict[
'{}head.scratch.output_conv2.2.weight'.format(key_prefix)
].shape[0]
dit_config["head_use_sky_head"] = (
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
)
# out_layers: hard-coded per upstream YAML config (depth-aware default).
if depth >= 24:
# vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT).
if has_aux:
dit_config["out_layers"] = [11, 15, 19, 23]
else:
dit_config["out_layers"] = [4, 11, 17, 23]
else:
# vits/vitb: 12 blocks
dit_config["out_layers"] = [5, 7, 9, 11]
# Camera encoder/decoder presence (multi-view + pose path).
has_cam_enc = '{}cam_enc.token_norm.weight'.format(key_prefix) in state_dict_keys
has_cam_dec = '{}cam_dec.fc_t.weight'.format(key_prefix) in state_dict_keys
dit_config["has_cam_enc"] = has_cam_enc
dit_config["has_cam_dec"] = has_cam_dec
if has_cam_enc:
cam_enc_w = state_dict.get(
'{}cam_enc.pose_branch.fc2.weight'.format(key_prefix)
)
if cam_enc_w is not None:
dit_config["cam_dim_out"] = cam_enc_w.shape[0]
if has_cam_dec:
cam_dec_w = state_dict.get(
'{}cam_dec.fc_t.weight'.format(key_prefix)
)
if cam_dec_w is not None:
dit_config["cam_dec_dim_in"] = cam_dec_w.shape[1]
return dit_config
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
dit_config = {}
dit_config["image_model"] = "ernie"

View File

@ -651,8 +651,7 @@ def ensure_pin_budget(size, evict_active=False):
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
return free_pins(to_free, evict_active=evict_active) >= shortfall
def ensure_pin_registerable(size, evict_active=True):
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
def free_registrations(shortfall, evict_active=True):
if MAX_PINNED_MEMORY <= 0:
return False
if shortfall <= 0:
@ -674,6 +673,9 @@ def ensure_pin_registerable(size, evict_active=True):
return True
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
def ensure_pin_registerable(size, evict_active=True):
return free_registrations(TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY, evict_active=evict_active)
class LoadedModel:
def __init__(self, model: ModelPatcher):
self._set_model(model)
@ -956,8 +958,6 @@ def loaded_models(only_currently_used=False):
def cleanup_models_gc():
do_gc = False
reset_cast_buffers()
for i in range(len(current_loaded_models)):
cur = current_loaded_models[i]
if cur.is_dead():

View File

@ -379,10 +379,11 @@ class ModelPatcher:
def get_clone_model_override(self):
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
def clone(self, disable_dynamic=False, model_override=None):
def clone(self, disable_dynamic=False, model_override=None, force_deepcopy=False):
class_ = self.__class__
if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher
if self.is_dynamic() and disable_dynamic or force_deepcopy:
if self.is_dynamic() and disable_dynamic:
class_ = ModelPatcher
if model_override is None:
if self.cached_patcher_init is None:
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")

View File

@ -89,13 +89,26 @@ def pin_memory(module, subset="weights", size=None):
not comfy.model_management.ensure_pin_registerable(registerable_size)):
return _steal_pin(module, stack, buckets, size, priority)
extended = False
try:
hostbuf.extend(size=size)
hostbuf.extend(size=size, register=False)
extended = True
pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
pin.untyped_storage()._comfy_hostbuf = hostbuf
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
comfy.model_management.free_registrations(size)
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
comfy.model_management.discard_cuda_async_error()
del pin
hostbuf.truncate(offset, do_unregister=False)
return _steal_pin(module, stack, buckets, size, priority)
except RuntimeError:
if extended:
hostbuf.truncate(offset, do_unregister=False)
return _steal_pin(module, stack, buckets, size, priority)
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
module._pin = pin
stack.append((module, offset))
module._pin_registered = True
module._pin_stack_index = len(stack) - 1

View File

@ -1450,6 +1450,17 @@ class WAN21_SCAIL(WAN21_T2V):
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
return out
class WAN21_SCAIL2(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
"model_type": "scail2",
}
def get_model(self, state_dict, prefix="", device=None):
out = model_base.WAN21_SCAIL2(self, image_to_video=False, device=device)
return out
class WAN22_WanDancer(WAN21_T2V):
unet_config = {
"image_model": "wan2.1",
@ -2045,6 +2056,23 @@ class RT_DETR_v4(supported_models_base.BASE):
return None
class DepthAnything3(supported_models_base.BASE):
unet_config = {
"image_model": "DepthAnything3",
}
# Mono path: no num_heads / num_head_channels needed.
unet_extra_config = {}
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
return model_base.DepthAnything3(self, device=device)
def clip_target(self, state_dict={}):
return None
class ErnieImage(supported_models_base.BASE):
unet_config = {
"image_model": "ernie",
@ -2259,6 +2287,7 @@ models = [
WAN22_Animate,
WAN21_FlowRVS,
WAN21_SCAIL,
WAN21_SCAIL2,
WAN22_WanDancer,
Hunyuan3Dv2mini,
Hunyuan3Dv2,
@ -2286,4 +2315,5 @@ models = [
CogVideoX_I2V,
CogVideoX_T2V,
SVD_img2vid,
DepthAnything3,
]

View File

@ -32,7 +32,9 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
if llama_template is None:
if text.startswith('<|im_start|>'):
llama_text = text
elif llama_template is None:
llama_text = self.llama_template.format(text)
else:
llama_text = llama_template.format(text)

View File

@ -755,6 +755,18 @@ class File3DKSPLAT(ComfyTypeIO):
Type = File3D
@comfytype(io_type="FILE_3D_SPLAT_ANY")
class File3DSplatAny(ComfyTypeIO):
"""General 3D Gaussian splat file type - accepts any supported splat container (.ply / .spz / .splat / .ksplat)."""
Type = File3D
@comfytype(io_type="FILE_3D_POINT_CLOUD_ANY")
class File3DPointCloudAny(ComfyTypeIO):
"""General point cloud file type - accepts any supported point cloud container (currently .ply)."""
Type = File3D
@comfytype(io_type="HOOKS")
class Hooks(ComfyTypeIO):
if TYPE_CHECKING:
@ -2336,6 +2348,8 @@ __all__ = [
"File3DSPLAT",
"File3DSPZ",
"File3DKSPLAT",
"File3DSplatAny",
"File3DPointCloudAny",
"Hooks",
"HookKeyframes",
"TimestepsRange",

View File

@ -285,7 +285,7 @@ class AudioSaveHelper:
results = []
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
file = f"{filename_with_batch_num}_{counter:05}.{format}"
output_path = os.path.join(full_output_folder, file)
# Use original sample rate initially

View File

@ -43,6 +43,7 @@ class BFLFluxEraseRequest(BaseModel):
"white (255) marks areas to remove, black (0) marks areas to preserve.",
)
dilate_pixels: int = Field(10)
seed: int | None = Field(None)
output_format: str = Field("png")

View File

@ -97,3 +97,28 @@ class BriaRemoveVideoBackgroundResult(BaseModel):
class BriaRemoveVideoBackgroundResponse(BaseModel):
status: str = Field(...)
result: BriaRemoveVideoBackgroundResult | None = Field(None)
class BriaVideoGreenScreenRequest(BaseModel):
video: str = Field(..., description="Publicly accessible URL of the input video.")
green_shade: str = Field(
default="broadcast_green",
description="Solid chroma-key shade applied behind the foreground "
"(broadcast_green, chroma_green, or blue_screen).",
)
output_container_and_codec: str = Field(...)
preserve_audio: bool = Field(True)
seed: int = Field(...)
class BriaVideoReplaceBackgroundRequest(BaseModel):
video: str = Field(..., description="Publicly accessible URL of the input (foreground) video.")
background_url: str = Field(
...,
description="Publicly accessible URL of the background image or video to composite behind "
"the foreground. Stretched to the foreground frame; match its aspect ratio for "
"undistorted results.",
)
output_container_and_codec: str = Field(...)
preserve_audio: bool = Field(True)
seed: int = Field(...)

View File

@ -108,13 +108,19 @@ class GeminiVideoMetadata(BaseModel):
startOffset: GeminiOffset | None = Field(None)
class GeminiThinkingConfig(BaseModel):
includeThoughts: bool | None = Field(None)
thinkingLevel: str = Field(...)
class GeminiGenerationConfig(BaseModel):
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
maxOutputTokens: int | None = Field(None, ge=16, le=65536)
seed: int | None = Field(None)
stopSequences: list[str] | None = Field(None)
temperature: float | None = Field(None, ge=0.0, le=2.0)
topK: int | None = Field(None, ge=1)
topP: float | None = Field(None, ge=0.0, le=1.0)
thinkingConfig: GeminiThinkingConfig | None = Field(None)
class GeminiImageOutputOptions(BaseModel):
@ -128,11 +134,6 @@ class GeminiImageConfig(BaseModel):
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
class GeminiThinkingConfig(BaseModel):
includeThoughts: bool | None = Field(None)
thinkingLevel: str = Field(...)
class GeminiImageGenerationConfig(GeminiGenerationConfig):
responseModalities: list[str] | None = Field(None)
imageConfig: GeminiImageConfig | None = Field(None)

View File

@ -534,6 +534,15 @@ class FluxEraseNode(IO.ComfyNode):
max=25,
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
optional=True,
),
],
outputs=[IO.Image.Output()],
hidden=[
@ -553,6 +562,7 @@ class FluxEraseNode(IO.ComfyNode):
image: Input.Image,
mask: Input.Image,
dilate_pixels: int = 10,
seed: int = 0,
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=256, min_height=256)
mask = resize_mask_to_image(mask, image)
@ -565,6 +575,7 @@ class FluxEraseNode(IO.ComfyNode):
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
mask=mask,
dilate_pixels=dilate_pixels,
seed=seed,
),
)

View File

@ -1,14 +1,19 @@
import av
import torch
from av.codec import CodecContext
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bria import (
BriaEditImageRequest,
BriaImageEditResponse,
BriaRemoveBackgroundRequest,
BriaRemoveBackgroundResponse,
BriaRemoveVideoBackgroundRequest,
BriaRemoveVideoBackgroundResponse,
BriaImageEditResponse,
BriaStatusResponse,
BriaVideoGreenScreenRequest,
BriaVideoReplaceBackgroundRequest,
InputModerationSettings,
)
from comfy_api_nodes.util import (
@ -316,6 +321,248 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
class BriaVideoGreenScreen(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaVideoGreenScreen",
display_name="Bria Video Green Screen",
category="partner/video/Bria",
description="Replace a video's background with a solid chroma-key screen using Bria.",
inputs=[
IO.Video.Input("video"),
IO.Combo.Input(
"green_shade",
options=["broadcast_green", "chroma_green", "blue_screen"],
tooltip="Solid chroma-key shade applied behind the foreground: "
"broadcast_green (#00B140), chroma_green (#00FF00), or blue_screen (#0000FF).",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
green_shade: str,
seed: int,
) -> IO.NodeOutput:
validate_video_duration(video, max_duration=60.0)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/video/edit/green_screen", method="POST"),
data=BriaVideoGreenScreenRequest(
video=await upload_video_to_comfyapi(cls, video),
green_shade=green_shade,
output_container_and_codec="mp4_h264",
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveVideoBackgroundResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
class BriaVideoReplaceBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaVideoReplaceBackground",
display_name="Bria Video Replace Background",
category="partner/video/Bria",
description="Replace a video's background with a supplied image or video using Bria. "
"The output keeps the foreground's resolution and frame rate; a background with a "
"different aspect ratio is stretched to fit, so match it for undistorted results.",
inputs=[
IO.Video.Input("video", tooltip="Foreground video whose background is replaced."),
IO.Image.Input(
"background_image",
optional=True,
tooltip="Background image to composite behind the foreground. "
"Provide either a background image or a background video, not both.",
),
IO.Video.Input(
"background_video",
optional=True,
tooltip="Background video to composite behind the foreground. "
"Provide either a background image or a background video, not both.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Video.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
seed: int,
background_image: Input.Image | None = None,
background_video: Input.Video | None = None,
) -> IO.NodeOutput:
if (background_image is None) == (background_video is None):
raise ValueError("Provide either a background image or a background video, not both.")
validate_video_duration(video, max_duration=60.0)
if background_video is not None:
validate_video_duration(background_video, max_duration=60.0)
background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background")
else:
background_url = await upload_image_to_comfyapi(cls, background_image, wait_label="Uploading background")
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"),
data=BriaVideoReplaceBackgroundRequest(
video=await upload_video_to_comfyapi(cls, video),
background_url=background_url,
output_container_and_codec="mp4_h264",
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveVideoBackgroundResponse,
)
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
def _video_to_images_and_mask(video: Input.Video) -> tuple[Input.Image, Input.Mask]:
"""Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask.
VP9 keeps its alpha in a side layer that PyAV's default vp9 decoder drops, so the frames
are decoded with libvpx-vp9. Returns RGB images [B,H,W,3] in 0..1 and a mask [B,H,W]
following the Load Image convention (1 = transparent) for compositing or Save WEBM.
"""
rgb_frames: list[torch.Tensor] = []
alpha_frames: list[torch.Tensor] = []
with av.open(video.get_stream_source(), mode="r") as container:
stream = container.streams.video[0]
decoder = CodecContext.create("libvpx-vp9", "r") if stream.codec_context.name == "vp9" else None
for packet in container.demux(stream):
for frame in (decoder.decode(packet) if decoder is not None else packet.decode()):
rgba = torch.from_numpy(frame.to_ndarray(format="rgba")).float() / 255.0
rgb_frames.append(rgba[..., :3])
alpha_frames.append(rgba[..., 3])
images = torch.stack(rgb_frames) if rgb_frames else torch.zeros(0, 0, 0, 3)
mask = (1.0 - torch.stack(alpha_frames)) if alpha_frames else torch.zeros((images.shape[0], 64, 64))
return images, mask
class BriaTransparentVideoBackground(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="BriaTransparentVideoBackground",
display_name="Bria Remove Video Background (Transparent)",
category="partner/video/Bria",
description="Remove the background from a video using Bria and return the cut-out frames "
"plus an alpha mask. Connect both to a compositing node, or feed them to Save WEBM to "
"write a transparent video.",
inputs=[
IO.Video.Input("video"),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.Image.Output(display_name="images"),
IO.Mask.Output(display_name="mask"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
),
)
@classmethod
async def execute(
cls,
video: Input.Video,
seed: int,
) -> IO.NodeOutput:
validate_video_duration(video, max_duration=60.0)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bria/v2/video/edit/remove_background", method="POST"),
data=BriaRemoveVideoBackgroundRequest(
video=await upload_video_to_comfyapi(cls, video),
background_color="Transparent",
output_container_and_codec="webm_vp9",
seed=seed,
),
response_model=BriaStatusResponse,
)
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
status_extractor=lambda r: r.status,
response_model=BriaRemoveVideoBackgroundResponse,
)
video_out = await download_url_to_video_output(response.result.video_url)
images, mask = _video_to_images_and_mask(video_out)
return IO.NodeOutput(images, mask)
class BriaExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
@ -323,6 +570,9 @@ class BriaExtension(ComfyExtension):
BriaImageEditNode,
BriaRemoveImageBackground,
BriaRemoveVideoBackground,
BriaVideoGreenScreen,
# BriaVideoReplaceBackground, # server returns Status 500 when we pass background video
BriaTransparentVideoBackground,
]

View File

@ -7,6 +7,7 @@ from io import BytesIO
import torch
from typing_extensions import override
from comfy.utils import common_upscale
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.bytedance import (
RECOMMENDED_PRESETS,
@ -131,6 +132,44 @@ def _prepare_seedance_image(image: Input.Image) -> Input.Image:
return image
# Supported output aspect ratios, used to pre-size FLF frames to matching pixel pair to avoid the 1080p stretch jump.
SEEDANCE2_RATIO_WH = {
"16:9": (16, 9),
"4:3": (4, 3),
"1:1": (1, 1),
"3:4": (3, 4),
"9:16": (9, 16),
"21:9": (21, 9),
}
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080}
def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]:
"""Exact supported output (width, height) for (resolution, ratio).
The shorter side equals the resolution number (e.g. 1080p 16:9 -> 1920x1080). For ratio
"adaptive" (or any unexpected value) the ratio is derived from the image's own aspect, snapped
to the nearest supported ratio, so the output keeps the frame's orientation.
"""
short = SEEDANCE2_RES_SHORT_SIDE[resolution]
if ratio not in SEEDANCE2_RATIO_WH:
aspect = image.shape[-2] / image.shape[-3] # W / H; tensor is (B, H, W, C)
ratio = min(SEEDANCE2_RATIO_WH, key=lambda k: abs(SEEDANCE2_RATIO_WH[k][0] / SEEDANCE2_RATIO_WH[k][1] - aspect))
rw, rh = SEEDANCE2_RATIO_WH[ratio]
if rw >= rh: # landscape or square: shorter side is the height
out_w, out_h = round(short * rw / rh), short
else: # portrait: shorter side is the width
out_w, out_h = short, round(short * rh / rw)
return out_w - out_w % 2, out_h - out_h % 2
def _resize_to_exact(image: torch.Tensor, width: int, height: int) -> torch.Tensor:
"""Center-crop to the target aspect and resize to exactly width x height (lanczos)."""
samples = image.movedim(-1, 1) # (B, H, W, C) -> (B, C, H, W)
resized = common_upscale(samples, width, height, "lanczos", "center")
return resized.movedim(1, -1)
async def _resolve_reference_assets(
cls: type[IO.ComfyNode],
asset_ids: list[str],
@ -1790,10 +1829,28 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
if last_frame is not None and last_frame_asset_id:
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
if first_frame is not None:
first_frame = _prepare_seedance_image(first_frame)
if last_frame is not None:
last_frame = _prepare_seedance_image(last_frame)
request_ratio = model["ratio"]
if first_frame_asset_id or last_frame_asset_id:
if first_frame is not None:
first_frame = _prepare_seedance_image(first_frame)
if last_frame is not None:
last_frame = _prepare_seedance_image(last_frame)
else:
# The 1080p FLF stretch fix (pre-size frames to a supported pixel pair + submit ratio="adaptive")
# only applies to local image inputs we can resize.
request_ratio = "adaptive"
target_dims: tuple[int, int] | None = None
if first_frame is not None:
validate_image_aspect_ratio(first_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_dimensions(first_frame, min_width=300, min_height=300)
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], first_frame)
first_frame = _resize_to_exact(first_frame, *target_dims)
if last_frame is not None:
validate_image_aspect_ratio(last_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
validate_image_dimensions(last_frame, min_width=300, min_height=300)
if target_dims is None:
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], last_frame)
last_frame = _resize_to_exact(last_frame, *target_dims)
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
image_assets: dict[str, str] = {}
@ -1844,7 +1901,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
content=content,
generate_audio=model["generate_audio"],
resolution=model["resolution"],
ratio=model["ratio"],
ratio=request_ratio,
duration=model["duration"],
seed=seed,
watermark=watermark,

View File

@ -8,7 +8,7 @@ import os
from enum import Enum
from fnmatch import fnmatch
from io import BytesIO
from typing import Literal
from typing import Any, Literal
import torch
from typing_extensions import override
@ -19,6 +19,7 @@ from comfy_api_nodes.apis.gemini import (
GeminiContent,
GeminiFileData,
GeminiGenerateContentRequest,
GeminiGenerationConfig,
GeminiGenerateContentResponse,
GeminiImageConfig,
GeminiImageGenerateContentRequest,
@ -40,13 +41,18 @@ from comfy_api_nodes.util import (
get_number_of_images,
sync_op,
tensor_to_base64_string,
upload_audio_to_comfyapi,
upload_image_to_comfyapi,
upload_images_to_comfyapi,
upload_video_to_comfyapi,
validate_string,
video_to_base64_string,
)
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
GEMINI_URL_INPUT_BUDGET = 10
GEMINI_MAX_INLINE_BYTES = 18 * 1024 * 1024
GEMINI_IMAGE_SYS_PROMPT = (
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
"Interpret all user input—regardless of "
@ -285,6 +291,140 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
return final_price / 1_000_000.0
def create_video_parts(video_input: Input.Video) -> list[GeminiPart]:
"""Convert a single video input to Gemini API compatible parts (inline MP4/H.264)."""
base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.video_mp4,
data=base_64_string,
)
)
]
def create_audio_parts(audio_input: Input.Audio) -> list[GeminiPart]:
"""Convert an audio input to Gemini API compatible parts (one inline MP3 part per batch item)."""
audio_parts: list[GeminiPart] = []
for batch_index in range(audio_input["waveform"].shape[0]):
# Recreate an IO.AUDIO object for the given batch dimension index
audio_at_index = Input.Audio(
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
sample_rate=audio_input["sample_rate"],
)
# Convert to MP3 format for compatibility with Gemini API
audio_bytes = audio_to_base64_string(
audio_at_index,
container_format="mp3",
codec_name="libmp3lame",
)
audio_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.audio_mp3,
data=audio_bytes,
)
)
)
return audio_parts
def _flatten_images(images: list[Input.Image]) -> list[torch.Tensor]:
"""Expand any batched image tensors into individual (H, W, C) frames, preserving order."""
frames: list[torch.Tensor] = []
for img in images:
if len(img.shape) == 4:
frames.extend(img[i] for i in range(img.shape[0]))
else:
frames.append(img)
return frames
def _flatten_audio(audios: list[Input.Audio]) -> list[Input.Audio]:
"""Expand any batched audio inputs into individual single-clip audio inputs, preserving order."""
clips: list[Input.Audio] = []
for audio in audios:
waveform = audio["waveform"]
for i in range(waveform.shape[0]):
clips.append(Input.Audio(waveform=waveform[i].unsqueeze(0), sample_rate=audio["sample_rate"]))
return clips
async def _media_url_part(cls: type[IO.ComfyNode], kind: str, payload: Any) -> GeminiPart:
"""Upload a single media unit to ComfyAPI storage and return a fileData (URL) part."""
if kind == "image":
url = await upload_image_to_comfyapi(cls, payload, mime_type="image/png", wait_label="Uploading image")
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.image_png, fileUri=url))
if kind == "audio":
url = await upload_audio_to_comfyapi(
cls, payload, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mp3"
)
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.audio_mp3, fileUri=url))
url = await upload_video_to_comfyapi(cls, payload, wait_label="Uploading video")
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.video_mp4, fileUri=url))
def _media_inline_part(kind: str, payload: Any) -> tuple[GeminiPart, int]:
"""Encode a single media unit as an inline base64 part; returns (part, base64_length)."""
if kind == "image":
data = tensor_to_base64_string(payload, mime_type="image/webp")
mime = GeminiMimeType.image_webp
elif kind == "audio":
data = audio_to_base64_string(payload, container_format="mp3", codec_name="libmp3lame")
mime = GeminiMimeType.audio_mp3
else:
data = video_to_base64_string(
payload, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
mime = GeminiMimeType.video_mp4
return GeminiPart(inlineData=GeminiInlineData(mimeType=mime, data=data)), len(data)
async def build_gemini_media_parts(
cls: type[IO.ComfyNode],
images: list[Input.Image],
audios: list[Input.Audio],
videos: list[Input.Video],
*,
url_budget: int = GEMINI_URL_INPUT_BUDGET,
max_inline_bytes: int = GEMINI_MAX_INLINE_BYTES,
) -> list[GeminiPart]:
"""Build Gemini parts for multimodal inputs (images, audio, video).
fileData URLs are preferred for every media type: the upload is fetched directly by the
model, keeping the request body tiny regardless of media size. The URL budget is shared
across all media and assigned largest-first (video, then audio, then images), so that if it
is ever exhausted the inline-base64 overflow is limited to the smallest items. Total inline
payload is capped by `max_inline_bytes`.
"""
units: list[tuple[str, Any]] = (
[("video", v) for v in videos]
+ [("audio", a) for a in _flatten_audio(audios)]
+ [("image", f) for f in _flatten_images(images)]
)
parts: list[GeminiPart] = []
url_used = 0
inline_bytes = 0
for kind, payload in units:
if url_used < url_budget:
parts.append(await _media_url_part(cls, kind, payload))
url_used += 1
continue
part, nbytes = _media_inline_part(kind, payload)
inline_bytes += nbytes
if inline_bytes > max_inline_bytes:
raise ValueError(
f"Too much media to send inline (over {max_inline_bytes // (1024 * 1024)}MB after the first "
f"{url_budget} inputs are uploaded as URLs). Reduce the number or size of attached media."
)
parts.append(part)
return parts
class GeminiNode(IO.ComfyNode):
"""
Node to generate text responses from a Gemini model.
@ -407,58 +547,9 @@ class GeminiNode(IO.ComfyNode):
)
""",
),
is_deprecated=True,
)
@classmethod
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""
base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.video_mp4,
data=base_64_string,
)
)
]
@classmethod
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
"""
Convert audio input to Gemini API compatible parts.
Args:
audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
Returns:
List of GeminiPart objects containing the encoded audio.
"""
audio_parts: list[GeminiPart] = []
for batch_index in range(audio_input["waveform"].shape[0]):
# Recreate an IO.AUDIO object for the given batch dimension index
audio_at_index = Input.Audio(
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
sample_rate=audio_input["sample_rate"],
)
# Convert to MP3 format for compatibility with Gemini API
audio_bytes = audio_to_base64_string(
audio_at_index,
container_format="mp3",
codec_name="libmp3lame",
)
audio_parts.append(
GeminiPart(
inlineData=GeminiInlineData(
mimeType=GeminiMimeType.audio_mp3,
data=audio_bytes,
)
)
)
return audio_parts
@classmethod
async def execute(
cls,
@ -482,9 +573,9 @@ class GeminiNode(IO.ComfyNode):
if images is not None:
parts.extend(await create_image_parts(cls, images))
if audio is not None:
parts.extend(cls.create_audio_parts(audio))
parts.extend(create_audio_parts(audio))
if video is not None:
parts.extend(cls.create_video_parts(video))
parts.extend(create_video_parts(video))
if files is not None:
parts.extend(files)
@ -512,6 +603,210 @@ class GeminiNode(IO.ComfyNode):
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
GEMINI_V2_MODELS: dict[str, str] = {
"Gemini 3.1 Pro": "gemini-3.1-pro-preview",
"Gemini 3.1 Flash-Lite": "gemini-3.1-flash-lite-preview",
}
def _gemini_text_model_inputs(thinking_default: str) -> list[Input]:
"""Per-model inputs revealed by the model DynamicCombo (shared media + sampling controls)."""
return [
IO.Autogrow.Input(
"images",
template=IO.Autogrow.TemplateNames(
IO.Image.Input("image"),
names=[f"image_{i}" for i in range(1, 17)],
min=0,
),
tooltip="Optional image(s) to use as context for the model. Up to 16 images.",
),
IO.Autogrow.Input(
"audio",
template=IO.Autogrow.TemplateNames(
IO.Audio.Input("audio"),
names=["audio_1"],
min=0,
),
tooltip="Optional audio clip to use as context for the model.",
),
IO.Autogrow.Input(
"video",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("video"),
names=["video_1"],
min=0,
),
tooltip="Optional video clip to use as context for the model.",
),
IO.Custom("GEMINI_INPUT_FILES").Input(
"files",
optional=True,
tooltip="Optional file(s) to use as context for the model. "
"Accepts inputs from the Gemini Input Files node.",
),
IO.Combo.Input(
"thinking_level",
options=["LOW", "HIGH"],
default=thinking_default,
tooltip="How hard the model reasons internally before answering. "
"HIGH improves quality on difficult tasks but costs more (thinking) tokens and is slower.",
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.01,
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more creative.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=0.95,
min=0.0,
max=1.0,
step=0.01,
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
advanced=True,
),
IO.Int.Input(
"max_output_tokens",
default=32768,
min=16,
max=65536,
tooltip="Maximum tokens to generate, including the model's internal thinking. "
"With thinking_level HIGH, a low value can leave no room for the answer; raise this if "
"responses come back empty or truncated. The model stops early when finished, so a higher "
"cap costs nothing extra for short replies.",
advanced=True,
),
]
class GeminiNodeV2(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="GeminiNodeV2",
display_name="Google Gemini",
category="partner/text/Gemini",
essentials_category="Text Generation",
description="Generate text responses with Google's Gemini models. Provide a text prompt and, "
"optionally, one or more images, audio clips, videos, or files as multimodal context.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text input to the model. Include detailed instructions, questions, or context.",
),
IO.DynamicCombo.Input(
"model",
options=[
IO.DynamicCombo.Option("Gemini 3.1 Pro", _gemini_text_model_inputs("HIGH")),
IO.DynamicCombo.Option("Gemini 3.1 Flash-Lite", _gemini_text_model_inputs("LOW")),
],
tooltip="The Gemini model used to generate the response.",
),
IO.Int.Input(
"seed",
default=42,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed for sampling. Set to 0 for a random seed. Deterministic output isn't guaranteed.",
),
IO.String.Input(
"system_prompt",
multiline=True,
default="",
optional=True,
advanced=True,
tooltip="Foundational instructions that dictate the model's behavior.",
),
],
outputs=[
IO.String.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
expr="""
(
$m := widgets.model;
$contains($m, "lite") ? {
"type": "list_usd",
"usd": [0.00025, 0.0015],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
} : {
"type": "list_usd",
"usd": [0.002, 0.012],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
)
""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
seed: int,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_id = GEMINI_V2_MODELS[model["model"]]
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
images = [t for t in (model.get("images") or {}).values() if t is not None]
audios = [a for a in (model.get("audio") or {}).values() if a is not None]
videos = [v for v in (model.get("video") or {}).values() if v is not None]
if images or audios or videos:
parts.extend(await build_gemini_media_parts(cls, images, audios, videos))
files = model.get("files")
if files is not None:
parts.extend(files)
gemini_system_prompt = None
if system_prompt:
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
response = await sync_op(
cls,
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
data=GeminiGenerateContentRequest(
contents=[
GeminiContent(
role=GeminiRole.user,
parts=parts,
)
],
generationConfig=GeminiGenerationConfig(
temperature=model["temperature"],
topP=model["top_p"],
maxOutputTokens=model["max_output_tokens"],
seed=seed if seed > 0 else None,
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
),
systemInstruction=gemini_system_prompt,
),
response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price,
)
output_text = get_text_from_response(response)
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
class GeminiInputFiles(IO.ComfyNode):
"""
Loads and formats input files for use with the Gemini API.
@ -1129,6 +1424,26 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
tooltip="Foundational instructions that dictate an AI's behavior.",
advanced=True,
),
IO.Float.Input(
"temperature",
default=1.0,
min=0.0,
max=2.0,
step=0.01,
optional=True,
tooltip="Controls randomness in generation. Lower is more focused/deterministic.",
advanced=True,
),
IO.Float.Input(
"top_p",
default=0.95,
min=0.0,
max=1.0,
step=0.01,
optional=True,
tooltip="Nucleus sampling threshold. Lower is more focused, higher more diverse.",
advanced=True,
),
],
outputs=[
IO.Image.Output(),
@ -1165,6 +1480,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
seed: int,
response_modalities: str,
system_prompt: str = "",
temperature: float = 1.0,
top_p: float = 0.95,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
model_choice = model["model"]
@ -1204,6 +1521,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
imageConfig=image_config,
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
temperature=temperature,
topP=top_p,
),
systemInstruction=gemini_system_prompt,
),
@ -1222,6 +1541,7 @@ class GeminiExtension(ComfyExtension):
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
GeminiNode,
GeminiNodeV2,
GeminiImage,
GeminiImage2,
GeminiNanoBanana2,

View File

@ -42,9 +42,11 @@ async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Ima
_MODEL_MEDIUM = "Krea 2 Medium"
_MODEL_MEDIUM_TURBO = "Krea 2 Medium Turbo"
_MODEL_LARGE = "Krea 2 Large"
_MODEL_ENDPOINTS: dict[str, str] = {
_MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
_MODEL_MEDIUM_TURBO: "/proxy/krea/generate/image/krea/krea-2/medium-turbo",
_MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
}
@ -57,7 +59,7 @@ _UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F
def _krea_model_inputs() -> list:
"""Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo."""
"""Nested inputs shared by Krea 2 Medium, Medium Turbo and Large under the DynamicCombo."""
return [
IO.Combo.Input(
"aspect_ratio",
@ -123,6 +125,7 @@ class Krea2ImageNode(IO.ComfyNode):
"model",
options=[
IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
IO.DynamicCombo.Option(_MODEL_MEDIUM_TURBO, _krea_model_inputs()),
IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
],
tooltip="Krea 2 Medium is best for expressive illustrations; "
@ -151,14 +154,15 @@ class Krea2ImageNode(IO.ComfyNode):
),
expr="""
(
$isLarge := widgets.model = "krea 2 large";
$rates := {
"krea 2 medium turbo": {"text": 0.015, "style": 0.0175, "moodboard": 0.02},
"krea 2 medium": {"text": 0.03, "style": 0.035, "moodboard": 0.04},
"krea 2 large": {"text": 0.06, "style": 0.065, "moodboard": 0.07}
};
$r := $lookup($rates, widgets.model);
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
$hasStyle := $lookup(inputs, "model.style_reference").connected;
$usd := $hasMoodboard
? ($isLarge ? 0.07 : 0.04)
: ($hasStyle
? ($isLarge ? 0.065 : 0.035)
: ($isLarge ? 0.06 : 0.03));
$usd := $hasMoodboard ? $r.moodboard : ($hasStyle ? $r.style : $r.text);
{"type":"usd","usd": $usd}
)
""",

View File

@ -158,7 +158,7 @@ class SaveAudio(IO.ComfyNode):
return IO.Schema(
node_id="SaveAudio",
search_aliases=["export flac"],
display_name="Save Audio (FLAC)",
display_name="Save Audio (FLAC) (Deprecated)",
category="audio",
essentials_category="Audio",
inputs=[
@ -167,6 +167,7 @@ class SaveAudio(IO.ComfyNode):
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
is_deprecated=True,
)
@classmethod
@ -186,7 +187,7 @@ class SaveAudioMP3(IO.ComfyNode):
return IO.Schema(
node_id="SaveAudioMP3",
search_aliases=["export mp3"],
display_name="Save Audio (MP3)",
display_name="Save Audio (MP3) (Deprecated)",
category="audio",
essentials_category="Audio",
inputs=[
@ -196,6 +197,7 @@ class SaveAudioMP3(IO.ComfyNode):
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
is_deprecated=True,
)
@classmethod
@ -217,7 +219,7 @@ class SaveAudioOpus(IO.ComfyNode):
return IO.Schema(
node_id="SaveAudioOpus",
search_aliases=["export opus"],
display_name="Save Audio (Opus)",
display_name="Save Audio (Opus) (Deprecated)",
category="audio",
inputs=[
IO.Audio.Input("audio"),
@ -226,6 +228,7 @@ class SaveAudioOpus(IO.ComfyNode):
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
is_deprecated=True,
)
@classmethod
@ -241,6 +244,54 @@ class SaveAudioOpus(IO.ComfyNode):
save_opus = execute # TODO: remove
class SaveAudioAdvanced(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="SaveAudioAdvanced",
search_aliases=["save audio", "export audio", "output audio", "write audio", "flac", "mp3", "opus"],
display_name="Save Audio (Advanced)",
description="Saves the input audio to your ComfyUI output directory.",
category="audio",
inputs=[
IO.Audio.Input("audio", tooltip="The audio to save."),
IO.String.Input(
"filename_prefix",
default="audio/ComfyUI",
tooltip=(
"The prefix for the file to save. May include formatting tokens "
"such as %date:yyyy-MM-dd%."
),
),
IO.DynamicCombo.Input(
"format",
options=[
IO.DynamicCombo.Option("flac", []),
IO.DynamicCombo.Option("mp3", [
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
]),
IO.DynamicCombo.Option("opus", [
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
]),
],
tooltip="The file format in which to save the audio.",
),
],
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
is_output_node=True,
)
@classmethod
def execute(cls, audio, filename_prefix: str, format: dict) -> IO.NodeOutput:
file_format = format.get("format", None)
quality = format.get("quality", None)
if quality:
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format, quality=quality)
else:
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format)
return IO.NodeOutput(ui=ui)
class PreviewAudio(IO.ComfyNode):
@classmethod
def define_schema(cls):
@ -822,6 +873,7 @@ class AudioExtension(ComfyExtension):
SaveAudio,
SaveAudioMP3,
SaveAudioOpus,
SaveAudioAdvanced,
LoadAudio,
PreviewAudio,
ConditioningStableAudio,

View File

@ -0,0 +1,115 @@
import torch
from typing_extensions import override
import comfy.model_management
import comfy.utils
import node_helpers
from comfy_api.latest import ComfyExtension, io
def _resize_long_edge(image, max_size, stride=16):
"""Resize (preserve aspect) so the long edge <= max_size, then snap each side to `stride`"""
h, w = image.shape[1], image.shape[2]
scale = min(max_size / max(h, w), 1.0)
nh = max(stride, round(h * scale / stride) * stride)
nw = max(stride, round(w * scale / stride) * stride)
return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "area", "disabled").movedim(1, -1)
class BerniniConditioning(io.ComfyNode):
"""Bernini in-context conditioning for a Wan2.2-A14B model.
Attaches the VAE-encoded source video / reference images to the conditioning
source video first, then each reference image
The task is inferred from which inputs are connected:
(nothing) -> t2v (text-to-video)
source_video -> v2v (video-to-video)
source_video + ref_images -> rv2v (reference-guided video editing)
ref_images only -> r2v (reference-to-video)
source_video + ref_video -> ads2v (insert image/video into video)
source_video is the edit base / canvas (resized to width x height).
reference_video is moving content to composite in.
Streams are ordered source_video, reference_video, then reference_images -> source_id (1, 2, 3, ...).
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="BerniniConditioning",
display_name="Bernini Conditioning",
category="conditioning/video_models",
description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video)."
"Reference images injected as in-context tokens (r2v, rv2v) are encoded independently at their own native aspect ratio (long edge capped at ref_max_size)",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=832, min=16, max=8192, step=16),
io.Int.Input("height", default=480, min=16, max=8192, step=16),
io.Int.Input("length", default=81, min=1, max=8192, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("source_video", optional=True, tooltip=(
"Source video to edit or restyle (v2v, rv2v). Resized to width/height and trimmed to length.")),
io.Image.Input("reference_video", optional=True, tooltip=(
"Video to insert into the source video (ads2v).")),
io.Autogrow.Input("reference_images", optional=True,
template=io.Autogrow.TemplatePrefix(
input=io.Image.Input("reference_image", tooltip=(
"Reference image injected as an in-context token (r2v, rv2v).")),
prefix="reference_image_", min=0, max=8)),
io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=(
"Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px.")),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent"),
],
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size,
source_video=None, reference_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
device=comfy.model_management.intermediate_device())
# source_video (1), reference_video (2), reference_images (3, 4, ...).
context = []
if source_video is not None:
vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
context.append(vae.encode(vid[:, :, :, :3]))
if reference_video is not None:
ref_vid = _resize_long_edge(reference_video[:length], ref_max_size) # moving content, native aspect
context.append(vae.encode(ref_vid[:, :, :, :3]))
# reference_images is an autogrow dict {reference_image_0: IMAGE, ...}; each slot is a
# separate stream at its own native aspect (a multi-image batch in one slot -> one stream per frame).
if reference_images:
for name in sorted(reference_images):
imgs = reference_images[name]
if imgs is None:
continue
for i in range(imgs.shape[0]):
img = _resize_long_edge(imgs[i:i + 1], ref_max_size) # native aspect per ref
context.append(vae.encode(img[:, :, :, :3]))
if context:
positive = node_helpers.conditioning_set_values(positive, {"context_latents": context})
negative = node_helpers.conditioning_set_values(negative, {"context_latents": context})
return io.NodeOutput(positive, negative, {"samples": latent})
class BerniniExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
BerniniConditioning,
]
async def comfy_entrypoint() -> BerniniExtension:
return BerniniExtension()

View File

@ -36,15 +36,15 @@ class RemoveBackground(IO.ComfyNode):
category="image/background removal",
description="Generates a foreground mask to remove the background from an image using a background removal model.",
inputs=[
IO.Image.Input("image", tooltip="Input image to remove the background from"),
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask"),
IO.Image.Input("image", tooltip="Input image to remove the background from")
],
outputs=[
IO.Mask.Output("mask", tooltip="Generated foreground mask")
]
)
@classmethod
def execute(cls, image, bg_removal_model):
def execute(cls, bg_removal_model, image):
mask = bg_removal_model.encode_image(image)
return IO.NodeOutput(mask)

View File

@ -7,29 +7,29 @@ class ColorToRGBInt(io.ComfyNode):
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="ColorToRGBInt",
display_name="Color to RGB Int",
display_name="Color Picker",
category="utilities",
description="Convert a color to a RGB integer value.",
description="Return a color RGB integer value and hexadecimal representation.",
inputs=[
io.Color.Input("color"),
],
outputs=[
io.Int.Output(display_name="rgb_int"),
io.Color.Output(display_name="hex")
],
)
@classmethod
def execute(
cls,
color: str,
) -> io.NodeOutput:
def execute(cls, color: str) -> io.NodeOutput:
# expect format #RRGGBB
if len(color) != 7 or color[0] != "#":
raise ValueError("Color must be in format #RRGGBB")
r = int(color[1:3], 16)
g = int(color[3:5], 16)
b = int(color[5:7], 16)
return io.NodeOutput(r * 256 * 256 + g * 256 + b)
rgb_int = r * 256 * 256 + g * 256 + b
return io.NodeOutput(rgb_int, color)
class ColorExtension(ComfyExtension):

View File

@ -933,9 +933,10 @@ class Guider_DualModel(comfy.samplers.CFGGuider):
def predict_noise(self, x, timestep, model_options={}, seed=None):
positive = self.conds.get("positive", None)
if self.uncond_inner is None: # cfg == 1 or no negative -> single model, cond only
return comfy.samplers.calc_cond_batch(self.inner_model, [positive], x, timestep, model_options)[0]
cond = comfy.samplers.calc_cond_batch(self.inner_model, [positive], x, timestep, model_options)[0]
# uncond model not loaded (base cfg==1/no negative), or cfg driven to 1.0 this step -> single model, cond only
if self.uncond_inner is None or (math.isclose(self.cfg, 1.0) and not model_options.get("disable_cfg1_optimization", False)):
return cond
uncond_model_options = model_options
if "multigpu_clones" in model_options: # TODO: support multigpu instead of just running uncond on a single GPU
@ -1140,7 +1141,7 @@ class CFGOverride(io.ComfyNode):
return io.Schema(
node_id="CFGOverride",
display_name="CFG Override",
description="Override cfg to a fixed value over a [start, end] percent slice of the steps. "
description="Override cfg to a fixed value over a [start, end] percent (sigma) range. "
"With multiple overrides, the one nearest the sampler wins on overlap.",
category="sampling/custom_sampling",
inputs=[

View File

@ -411,6 +411,21 @@ class ImageProcessingNode(io.ComfyNode):
return has_group
@classmethod
def _ensure_image_list(cls, images):
"""Normalize to a flat list of [1, H, W, C] tensors."""
if isinstance(images, torch.Tensor):
if images.ndim != 4:
raise ValueError(f"Expected 4D image tensor, got shape {tuple(images.shape)}")
return [images[i:i+1] for i in range(images.shape[0])]
flat = []
for item in images:
if not isinstance(item, torch.Tensor) or item.ndim != 4:
raise ValueError(f"Expected 4D image tensor, got {type(item).__name__} shape {getattr(item, 'shape', None)}")
flat.extend([item[i:i+1] for i in range(item.shape[0])])
return flat
@classmethod
def define_schema(cls):
if cls.node_id is None:
@ -458,6 +473,9 @@ class ImageProcessingNode(io.ComfyNode):
"""Execute the node. Routes to _process or _group_process based on mode."""
is_group = cls._detect_processing_mode()
if is_group:
images = cls._ensure_image_list(images)
# Extract scalar values from lists for parameters
params = {}
for k, v in kwargs.items():

View File

@ -0,0 +1,681 @@
"""ComfyUI nodes for Depth Anything 3.
Model capability matrix:
Variant head_type has_sky has_conf cam_dec
DA3-Small dualdpt False True yes
DA3-Base dualdpt False True yes
DA3-Mono-Large dpt True False no
DA3-Metric-Large dpt True False no (raw output is metres)
"""
from __future__ import annotations
import logging
from typing_extensions import override
import torch
import comfy.model_management as mm
import comfy.sd
import folder_paths
from comfy.ldm.colormap import turbo as _turbo
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
from comfy_api.latest import ComfyExtension, Types, io
from comfy.ldm.moge.geometry import triangulate_grid_mesh
DA3ModelType = io.Custom("DA3_MODEL")
DA3Geometry = io.Custom("DA3_GEOMETRY")
DA3PointCloud = io.Custom("DA3_POINT_CLOUD")
# DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
#
# Per-frame tensors - B = batch size in mono mode; B = S (number of views) in multi-view mode.
# "depth": torch.Tensor (B, H, W) -- raw model depth (always present; matches MoGe convention)
# "image": torch.Tensor (B, H, W, 3) -- source image in [0, 1], CPU (always present)
# "mode": str -- "mono" or "multiview" (always present)
# "sky": torch.Tensor (B, H, W) -- sky probability in [0, 1] (Mono/Metric variants only)
# "confidence": torch.Tensor (B, H, W) -- raw model confidence output (Small/Base variants only)
#
# Multi-view only - S = number of views; the leading 1 is the scene dimension from the model.
# "extrinsics": torch.Tensor (1, S, 3, 4) -- world-to-camera [R|t] matrices
# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics
#
# DA3_POINT_CLOUD is a dict:
# "points": torch.Tensor (N, 3) -- 3-D coords in glTF convention (Y-up, Z-back)
# "colors": torch.Tensor (N, 3) -- RGB in [0, 1], or None
# "confidence": torch.Tensor (N,) -- raw confidence per point, or None
def _da3_unproject(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
"""Pixel-space K⁻¹ unprojection: (H,W) depth → (H,W,3) point map in OpenCV space."""
H, W = depth.shape
u = torch.arange(W, dtype=torch.float32, device=depth.device)
v = torch.arange(H, dtype=torch.float32, device=depth.device)
u, v = torch.meshgrid(u, v, indexing='xy') # both (H, W)
pix = torch.stack([u, v, torch.ones_like(u)], dim=-1) # (H, W, 3)
rays = torch.einsum('ij,hwj->hwi', torch.linalg.inv(K.to(depth.device)), pix)
return rays * depth.unsqueeze(-1) # (H, W, 3)
def _da3_default_K(H: int, W: int) -> torch.Tensor:
"""Fallback ~60° FOV pinhole K for mono-mode DA3 (no intrinsics in geometry)."""
fx = fy = float(W) * 0.7
return torch.tensor([[fx, 0.0, (W - 1) / 2.0],
[0.0, fy, (H - 1) / 2.0],
[0.0, 0.0, 1.0]], dtype=torch.float32)
def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor:
"""Return pixel-space K for batch element b, falling back to a default estimate."""
if "intrinsics" in geometry:
# shape (1, S, 3, 3) - leading scene dimension from the multiview head
return geometry["intrinsics"][0, b].float()
logging.getLogger("comfy").warning(
"DA3_GEOMETRY has no intrinsics (mono-mode model). "
"Using a ~60° FOV estimate; 3-D reconstruction may be inaccurate."
)
return _da3_default_K(H, W)
def _da3_get_extrinsic(geometry: dict, b: int) -> torch.Tensor | None:
"""Return the world-to-camera extrinsic for batch element b, or None in mono mode.
The model outputs (1, S, 3, 4) [R|t] matrices; the fallback identity is (4, 4).
_da3_apply_extrinsic handles both shapes via [:3, :3] / [:3, 3] slicing.
"""
if "extrinsics" not in geometry:
return None
return geometry["extrinsics"][0, b].float()
def _da3_apply_extrinsic(points_cam: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
"""Transform (H,W,3) OpenCV camera-space points to world space."""
E = E.to(points_cam.device).float()
if not torch.isfinite(E).all():
logging.getLogger("comfy").warning(
"DA3 extrinsic matrix contains non-finite values (pose estimation may have failed). "
"Falling back to camera-space coordinates."
)
return points_cam
H, W, _ = points_cam.shape
R = E[:3, :3] # (3, 3) rotation
t = E[:3, 3] # (3,) translation
R_inv = R.T # rotation inverse = transpose for orthogonal R
t_inv = -(R_inv @ t) # (3,)
pts = points_cam.reshape(-1, 3) # (N, 3)
pts_world = pts @ R_inv.T + t_inv # (N, 3)
return pts_world.reshape(H, W, 3)
def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
"""Map raw confidence to [0, 1] per image."""
B = conf.shape[0]
out = []
for i in range(B):
c = conf[i]
c_min, c_max = c.min(), c.max()
out.append((c - c_min) / (c_max - c_min) if c_max > c_min else torch.ones_like(c))
return torch.stack(out, dim=0)
def _da3_build_mask(geometry: dict, b: int, H: int, W: int, confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor:
"""Build (H,W) bool keep-mask from sky probability and confidence."""
mask = torch.ones(H, W, dtype=torch.bool)
if use_sky_mask and "sky" in geometry:
mask = mask & (geometry["sky"][b] < 0.5)
if "confidence" in geometry and confidence_threshold > 0.0:
conf_norm = _normalize_confidence(geometry["confidence"][b:b + 1])[0]
mask = mask & (conf_norm >= confidence_threshold)
return mask
class LoadDA3Model(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="LoadDA3Model",
display_name="Load Depth Anything 3",
category="model/loaders",
inputs=[
io.Combo.Input(
"model_name",
options=folder_paths.get_filename_list("geometry_estimation"),
),
io.Combo.Input(
"weight_dtype",
options=["default", "fp16", "bf16", "fp32"],
default="default",
),
],
outputs=[DA3ModelType.Output()],
)
@classmethod
def execute(cls, model_name, weight_dtype) -> io.NodeOutput:
model_options = {}
if weight_dtype == "fp16":
model_options["dtype"] = torch.float16
elif weight_dtype == "bf16":
model_options["dtype"] = torch.bfloat16
elif weight_dtype == "fp32":
model_options["dtype"] = torch.float32
path = folder_paths.get_full_path_or_raise("geometry_estimation", model_name)
model = comfy.sd.load_diffusion_model(path, model_options=model_options)
return io.NodeOutput(model)
def _run_da3(model_patcher, image: torch.Tensor, process_res: int, method: str = "upper_bound_resize"):
"""Run DA3 on (B,H,W,3), returns depth/conf/sky at original resolution (or None)."""
assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
B, H, W, _ = image.shape
mm.load_model_gpu(model_patcher)
diffusion = model_patcher.model.diffusion_model
device = mm.get_torch_device()
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
depths, confs, skies = [], [], []
for i in range(B):
single = image[i:i + 1].to(device)
x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method)
x = x.to(dtype=dtype)
with torch.no_grad():
out = diffusion(x)
depth_lr = out["depth"]
depth_full = torch.nn.functional.interpolate(
depth_lr.unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
depths.append(depth_full)
if "depth_conf" in out:
conf_full = torch.nn.functional.interpolate(
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
confs.append(conf_full)
if "sky" in out:
sky_full = torch.nn.functional.interpolate(
out["sky"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
skies.append(sky_full)
depth = torch.cat(depths, dim=0)
confidence = torch.cat(confs, dim=0) if confs else None
sky = torch.cat(skies, dim=0) if skies else None
return depth, confidence, sky
class DA3Inference(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DA3Inference",
search_aliases=["depth", "geometry", "da3", "depth anything", "monocular", "pointmap", "sky", "3d", "metric depth", "disparity"],
display_name="Run Depth Anything 3",
category="image/geometry estimation",
description="Run Depth Anything 3 on an image. In multi-view mode each image is treated as a separate view of the same scene.",
inputs=[
DA3ModelType.Input("da3_model"),
io.Image.Input("image"),
io.Int.Input("resolution", default=504, min=140, max=2520, step=14,
tooltip="Resolution the model runs at (longest side, multiple of 14).\n"
"Lower = faster / less VRAM.\n"
"Higher = more detail.\n"
"Output is upsampled back to the original size."),
io.Combo.Input("resize_method", options=["upper_bound_resize", "lower_bound_resize"], default="upper_bound_resize",
tooltip="upper_bound_resize: scale so the longest side = resolution (caps memory, default).\n"
"lower_bound_resize: scale so the shortest side = resolution (preserves more detail on tall/wide images, uses more memory)."),
io.DynamicCombo.Input("mode", tooltip="mono: single view image (works with any model variant).\n"
"multiview: all images processed together for geometric consistency + camera pose (for Small/Base models only).",
options=[
io.DynamicCombo.Option("mono", []),
io.DynamicCombo.Option("multiview", [
io.Combo.Input("ref_view_strategy", options=["saddle_balanced", "saddle_sim_range", "first", "middle"], default="saddle_balanced",
tooltip="Which view acts as the geometric anchor.\n"
"- saddle_balanced: the view most 'average' across all others (best general choice).\n"
"- saddle_sim_range: the view most visually distinct from the others.\n"
"- first / middle: fixed positional picks."),
io.Combo.Input("pose_method", options=["cam_dec", "ray_pose"], default="cam_dec",
tooltip="How the camera field-of-view is estimated (for Small/Base models only).\n"
"- cam_dec: learned from image features.\n"
"- ray_pose: derived geometrically from the model's 3D ray output.\n"
"Affects perspective correctness of the 3D output. Try both if results look distorted."),
]),
]),
],
outputs=[
DA3Geometry.Output("da3_geometry", tooltip="Dictionary of non-normalized tensors.\n"
"Always has the keys: depth, image, mode.\n"
"Optional keys: sky (for Mono/Metric), confidence (for Small/Base), extrinsics + intrinsics (for multi-view)."),
],
)
@classmethod
def execute(cls, da3_model, image, resolution, resize_method, mode) -> io.NodeOutput:
mode_val = mode["mode"] # "mono" or "multiview"
if mode_val == "mono":
return cls._execute_mono(da3_model, image, resolution, resize_method)
# Capability checks for multi-view mode.
diffusion = da3_model.model.diffusion_model
pose_method = mode["pose_method"]
ref_view_strategy = mode["ref_view_strategy"]
has_cam_dec = diffusion.cam_dec is not None
has_dualdpt = diffusion.head_type == "dualdpt"
if not has_cam_dec and not has_dualdpt:
raise ValueError(
"multi-view mode requires Small or Base model. The loaded model "
f"(head_type='{diffusion.head_type}') does not support cross-view "
"attention or camera pose estimation. Switch mode to 'mono', or "
"load Small or Base model for mult-view."
)
if pose_method == "cam_dec" and not has_cam_dec:
raise ValueError(
"pose_method='cam_dec' requires a camera decoder, but the loaded "
f"model (head_type='{diffusion.head_type}') does not have one. "
"Use pose_method='ray_pose' instead."
)
if pose_method == "ray_pose" and not has_dualdpt:
raise ValueError(
"pose_method='ray_pose' requires a DualDPT head, but the loaded "
f"model has a '{diffusion.head_type}' head. "
"Use pose_method='cam_dec' instead."
)
return cls._execute_multiview(
da3_model, image, resolution, resize_method,
ref_view_strategy, pose_method,
)
@classmethod
def _execute_mono(cls, model, image, resolution, resize_method) -> io.NodeOutput:
depth, confidence, sky = _run_da3(model, image, resolution, method=resize_method)
geometry: dict = {
"depth": depth.contiguous(),
"image": image[..., :3].cpu(),
"mode": "mono",
}
if sky is not None:
geometry["sky"] = sky.contiguous()
if confidence is not None:
geometry["confidence"] = confidence.contiguous()
return io.NodeOutput(geometry)
@classmethod
def _execute_multiview(cls, model, image, resolution, resize_method, ref_view_strategy, pose_method) -> io.NodeOutput:
assert image.ndim == 4 and image.shape[-1] == 3, \
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
S, H, W, _ = image.shape
mm.load_model_gpu(model)
diffusion = model.model.diffusion_model
device = mm.get_torch_device()
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
# All views in a single forward pass: (1, S, 3, H', W').
x = image.to(device)
x = da3_preprocess.preprocess_image(x, process_res=resolution, method=resize_method)
x = x.to(dtype=dtype).unsqueeze(0)
use_ray_pose = (pose_method == "ray_pose")
with torch.no_grad():
out = diffusion(x, use_ray_pose=use_ray_pose, ref_view_strategy=ref_view_strategy)
depth = torch.nn.functional.interpolate(
out["depth"].float().unsqueeze(1), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
sky = None
if "sky" in out:
sky = torch.nn.functional.interpolate(
out["sky"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
if "extrinsics" in out and "intrinsics" in out:
extrinsics = out["extrinsics"].float().cpu()
intrinsics = out["intrinsics"].float().cpu()
else:
extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone()
intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()
geometry: dict = {
"depth": depth.contiguous(),
"image": image[..., :3].cpu(),
"mode": "multiview",
"extrinsics": extrinsics.contiguous(),
"intrinsics": intrinsics.contiguous(),
}
if sky is not None:
geometry["sky"] = sky.contiguous()
if "depth_conf" in out:
conf = torch.nn.functional.interpolate(
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
mode="bilinear", align_corners=False,
).squeeze(1).cpu()
geometry["confidence"] = conf.contiguous()
return io.NodeOutput(geometry)
class DA3Render(io.ComfyNode):
"""Render a visualization from a DA3_GEOMETRY packet."""
_DEPTH_RENDER_INPUTS = [
io.Combo.Input("normalization",
options=["v2_style", "min_max", "raw"],
default="v2_style",
tooltip="- v2_style: mean/std normalisation for perceptually balanced results (default).\n"
"- min_max: stretches the full depth range to [0, 1] for maximum contrast.\n"
"- raw: no scaling,preserves metric units for Metric model."),
io.Boolean.Input("apply_sky_clip", default=False,
tooltip="Clip sky-region depth to the 99th percentile of foreground depth before normalisation. "
"Requires a sky key in the da3_geometry input (for Mono/Metric models only)."),
]
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DA3Render",
display_name="Render Depth Anything 3",
category="image/geometry estimation",
description="Render a depth map, confidence map, or sky mask from Depth Anything 3 geometry data.",
inputs=[
DA3Geometry.Input("da3_geometry"),
io.DynamicCombo.Input("output",
tooltip="- depth: normalised greyscale depth image.\n"
"- depth_colored: depth mapped through the Turbo colormap.\n"
"- sky_mask: sky probability in [0, 1] (for Mono/Metric models only).\n"
"- confidence: normalised depth confidence (for Small/Base models only).",
options=[
io.DynamicCombo.Option("depth", cls._DEPTH_RENDER_INPUTS),
io.DynamicCombo.Option("depth_colored", cls._DEPTH_RENDER_INPUTS),
io.DynamicCombo.Option("sky_mask", [
io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the sky mask."),
]),
io.DynamicCombo.Option("confidence", [
io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the confidence map."),
]),
]),
],
outputs=[io.Image.Output()],
)
@classmethod
def execute(cls, da3_geometry, output) -> io.NodeOutput:
output_val = output["output"]
if output_val in ("depth", "depth_colored"):
normalization = output["normalization"]
apply_sky_clip = output["apply_sky_clip"]
if apply_sky_clip and "sky" not in da3_geometry:
raise ValueError(
"apply_sky_clip=True requires a sky tensor in the da3_geometry input, but none is present. "
"Run with Mono/Metric models or set apply_sky_clip=False."
)
depth = da3_geometry["depth"]
sky = da3_geometry.get("sky")
if apply_sky_clip and sky is not None:
depth = torch.stack([
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
for i in range(depth.shape[0])
], dim=0)
grey = cls._depth_to_image(depth, sky, normalization) # (B,H,W,3) greyscale
result = _turbo(grey[..., 0]) if output_val == "depth_colored" else grey
elif output_val == "sky_mask":
if "sky" not in da3_geometry:
raise ValueError("geometry has no sky output; run with Mono/Metric models.")
sky = da3_geometry["sky"]
if output["colored"]:
result = _turbo(sky)
else:
result = sky.unsqueeze(-1).expand(*sky.shape, 3).contiguous()
elif output_val == "confidence":
if "confidence" not in da3_geometry:
raise ValueError("da3_geometry has no confidence output; run with Small/Base models.")
conf = _normalize_confidence(da3_geometry["confidence"])
if output["colored"]:
result = _turbo(conf)
else:
result = conf.unsqueeze(-1).expand(*conf.shape, 3).contiguous()
else:
raise ValueError(f"Unknown output mode: {output_val}")
return io.NodeOutput(result.float())
@staticmethod
def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None, normalization: str) -> torch.Tensor:
"""Normalise depth and pack as an (B,H,W,3) image tensor."""
N = depth.shape[0]
if normalization == "v2_style":
norm = torch.stack([
da3_preprocess.normalize_depth_v2_style(
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
for i in range(N)
], dim=0)
elif normalization == "min_max":
norm = da3_preprocess.normalize_depth_min_max(depth)
else:
norm = depth
out = norm.unsqueeze(-1).repeat(1, 1, 1, 3)
if normalization != "raw":
out = out.clamp(0.0, 1.0)
return out.contiguous()
class DA3GeometryToMesh(io.ComfyNode):
"""Convert a DA3_GEOMETRY packet into a Types.MESH by unprojecting depth and triangulating."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DA3GeometryToMesh",
search_aliases=["da3", "depth anything", "mesh", "geometry", "3d", "triangulate"],
display_name="Convert DA3 Geometry to Mesh",
category="image/geometry estimation",
description="Convert a depth map into a triangulated 3D mesh.",
inputs=[
DA3Geometry.Input("da3_geometry"),
io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert. Per-image vertex counts differ so batches cannot be stacked."),
io.Int.Input("decimation", default=1, min=1, max=8, tooltip="Vertex stride. 1 = full resolution, 2 = half, etc."),
io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01, tooltip="Drop triangles whose 3x3 depth span exceeds this fraction. 0 = off."),
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all, 1 = keep only the single most confident pixel). "
"Used when the geometry has a confidence map (Small/Base models)."),
io.Boolean.Input("use_sky_mask", default=True, tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. Used when the geometry has a sky map (Mono/Metric models)."),
io.Boolean.Input("texture", default=True, tooltip="Use the source image as a base color texture."),
],
outputs=[io.Mesh.Output()],
)
@classmethod
def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold, confidence_threshold, use_sky_mask, texture) -> io.NodeOutput:
depth_all = da3_geometry["depth"] # (B, H, W)
B = depth_all.shape[0]
if batch_index >= B:
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
depth = depth_all[batch_index] # (H, W)
H, W = depth.shape
# NaN/inf depth would propagate silently through unproject and produce an
# empty mesh; replace them with 0 here so those pixels are later excluded
# by the isfinite check inside triangulate_grid_mesh.
depth = depth.clone()
n_bad = (~torch.isfinite(depth)).sum().item()
if n_bad:
logging.getLogger("comfy").warning(
f"DA3GeometryToMesh: depth[{batch_index}] has {n_bad} non-finite pixels "
f"({100*n_bad/(H*W):.1f}%) - zeroed before unproject."
)
depth[~torch.isfinite(depth)] = 0.0
logging.getLogger("comfy").debug(
f"DA3GeometryToMesh: depth[{batch_index}] range "
f"[{depth.min():.4g}, {depth.max():.4g}], mean={depth.mean():.4g}"
)
K = _da3_get_K(da3_geometry, batch_index, H, W)
points = _da3_unproject(depth, K) # (H, W, 3) in OpenCV camera space
# Apply world-to-camera inverse so multi-view frames share a common world frame.
E = _da3_get_extrinsic(da3_geometry, batch_index)
if E is not None:
points = _da3_apply_extrinsic(points, E)
# Mask invalid pixels by setting them to inf so triangulate_grid_mesh skips them.
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
# Also exclude pixels where depth was invalid.
mask = mask & (depth_all[batch_index] > 0) & torch.isfinite(depth_all[batch_index])
points = points.clone()
points[~mask] = float('inf')
verts, faces, uvs = triangulate_grid_mesh(
points,
decimation=decimation,
discontinuity_threshold=discontinuity_threshold,
depth=depth,
)
if verts.shape[0] == 0 or faces.shape[0] == 0:
raise ValueError(
"DA3GeometryToMesh produced an empty mesh. "
"Try raising discontinuity_threshold, lowering confidence_threshold, "
"or disabling use_sky_mask."
)
# OpenCV (X right, Y down, Z forward) → glTF (X right, Y up, Z back).
# Same transform as MoGePointMapToMesh perspective branch.
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
faces = faces[:, [0, 2, 1]].contiguous()
tex = da3_geometry["image"][batch_index:batch_index + 1] if texture else None
mesh = Types.MESH(
vertices=verts.unsqueeze(0),
faces=faces.unsqueeze(0),
uvs=uvs.unsqueeze(0),
texture=tex,
)
return io.NodeOutput(mesh)
class DA3GeometryToPointCloud(io.ComfyNode):
"""Unproject a DA3_GEOMETRY depth map into a filtered DA3_POINT_CLOUD."""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="DA3GeometryToPointCloud",
search_aliases=["da3", "depth anything", "point cloud", "pointcloud", "3d", "geometry"],
display_name="Convert DA3 Geometry to Point Cloud",
category="image/geometry estimation",
description="Convert a depth map into a 3D point cloud.",
inputs=[
DA3Geometry.Input("da3_geometry"),
io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert."),
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all). Used when the geometry has a confidence map (Small/Base models)."),
io.Boolean.Input("use_sky_mask", default=True,
tooltip="Exclude sky-probability pixels (sky >= 0.5). Used when the geometry has a sky map (Mono/Metric models)."),
io.Int.Input("downsample", default=1, min=1, max=16,
tooltip="Take every Nth pixel (1 = full resolution). Higher values give fewer points and faster processing."),
],
# TODO: add a proper PointCloud output type
outputs=[DA3PointCloud.Output(display_name="point_cloud")],
)
@classmethod
def execute(cls, da3_geometry, batch_index, confidence_threshold, use_sky_mask, downsample) -> io.NodeOutput:
depth_all = da3_geometry["depth"] # (B, H, W)
B = depth_all.shape[0]
if batch_index >= B:
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
depth = depth_all[batch_index].clone() # (H, W)
depth[~torch.isfinite(depth)] = 0.0
H, W = depth.shape
K = _da3_get_K(da3_geometry, batch_index, H, W)
if downsample > 1:
depth = depth[::downsample, ::downsample].contiguous()
# Scale intrinsics to the downsampled grid.
K = K.clone()
K[0, :] /= downsample
K[1, :] /= downsample
H_ds, W_ds = depth.shape
points = _da3_unproject(depth, K) # (H_ds, W_ds, 3) in OpenCV camera space
# Apply world-to-camera inverse so multi-view frames share a common world frame.
E = _da3_get_extrinsic(da3_geometry, batch_index)
if E is not None:
points = _da3_apply_extrinsic(points, E)
# Rebuild mask at downsampled resolution.
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
if downsample > 1:
mask = mask[::downsample, ::downsample]
mask = mask & torch.isfinite(depth)
# OpenCV → glTF: flip Y and Z.
points_gltf = points.clone()
points_gltf[..., 1] *= -1.0
points_gltf[..., 2] *= -1.0
pts_flat = points_gltf.reshape(-1, 3)[mask.reshape(-1)]
colors_flat = None
if "image" in da3_geometry:
img = da3_geometry["image"][batch_index] # (H, W, 3)
if downsample > 1:
img = img[::downsample, ::downsample]
colors_flat = img.reshape(-1, 3)[mask.reshape(-1)]
conf_flat = None
if "confidence" in da3_geometry:
conf = da3_geometry["confidence"][batch_index] # (H, W)
if downsample > 1:
conf = conf[::downsample, ::downsample]
conf_flat = conf.reshape(-1)[mask.reshape(-1)]
if pts_flat.shape[0] == 0:
raise ValueError(
"DA3GeometryToPointCloud produced zero points after filtering. "
"Try lowering confidence_threshold or disabling use_sky_mask."
)
return io.NodeOutput({
"points": pts_flat,
"colors": colors_flat,
"confidence": conf_flat,
})
class DA3Extension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
LoadDA3Model,
DA3Inference,
DA3Render,
DA3GeometryToMesh,
# DA3GeometryToPointCloud, # Keep this commented out for now until we have a proper PointCloud output type
]
async def comfy_entrypoint() -> DA3Extension:
return DA3Extension()

View File

@ -488,7 +488,7 @@ class SplatToFile3D(IO.ComfyNode):
"spz: Niantic gzip-compressed (~10x smaller), base color only "
),
],
outputs=[IO.File3DAny.Output(display_name="model_3d")],
outputs=[IO.File3DSplatAny.Output(display_name="model_3d")],
)
@classmethod
@ -516,7 +516,7 @@ class File3DToSplat(IO.ComfyNode):
inputs=[
IO.MultiType.Input(
IO.File3DAny.Input("model_3d"),
types=[IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ],
types=[IO.File3DSplatAny, IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ],
tooltip="A gaussian splat 3D file",
),
],

View File

@ -51,6 +51,14 @@ class Load3D(IO.ComfyNode):
],
)
@classmethod
def validate_inputs(cls, model_file, **kwargs) -> bool | str:
if not model_file or model_file == "none":
return True
if not folder_paths.exists_annotated_filepath(model_file):
return f"Invalid 3D model file: {model_file}"
return True
@classmethod
def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
image_path = folder_paths.get_annotated_filepath(image['image'])
@ -136,7 +144,7 @@ class Preview3DAdvanced(IO.ComfyNode):
is_output_node=True,
inputs=[
IO.MultiType.Input(
"model_file",
"model_3d",
types=[
IO.File3DGLB,
IO.File3DGLTF,
@ -148,34 +156,161 @@ class Preview3DAdvanced(IO.ComfyNode):
],
tooltip="3D model file from an upstream 3D node.",
),
IO.Load3D.Input("image"),
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True),
IO.Load3D.Input("viewport_state"),
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
],
outputs=[
IO.File3DAny.Output(display_name="model_file"),
IO.Load3DCamera.Output(display_name="camera_info"),
IO.File3DAny.Output(display_name="model_3d"),
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Int.Output(display_name="width"),
IO.Int.Output(display_name="height"),
],
)
@classmethod
def execute(cls, model_file: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput:
filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_file.format}"
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_3d.format}"
model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename))
camera_info_input = kwargs.get("camera_info", None)
camera_info = camera_info_input if camera_info_input is not None else image['camera_info']
camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info']
model_3d_info_input = kwargs.get("model_3d_info", None)
model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', [])
model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', [])
return IO.NodeOutput(
model_file,
camera_info,
model_3d,
model_3d_info,
camera_info,
width,
height,
ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info),
)
class PreviewGaussianSplat(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="PreviewGaussianSplat",
display_name="Preview Splat",
category="3d",
is_experimental=True,
is_output_node=True,
search_aliases=[
"view splat",
"view gaussian",
"view gaussian splat",
"preview gaussian",
"preview gaussian splat",
"view 3dgs",
"preview 3dgs",
"preview ply",
"preview spz",
"preview splat",
"preview ksplat",
],
inputs=[
IO.MultiType.Input(
"model_3d",
types=[
IO.File3DSplatAny,
IO.File3DPLY,
IO.File3DSPLAT,
IO.File3DSPZ,
IO.File3DKSPLAT,
],
tooltip="A gaussian splat 3D file.",
),
IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True),
IO.Load3D.Input("viewport_state"),
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
],
outputs=[
IO.File3DSplatAny.Output(display_name="model_3d"),
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Int.Output(display_name="width"),
IO.Int.Output(display_name="height"),
],
)
@classmethod
def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
filename = f"preview_splat_{uuid.uuid4().hex}.{model_3d.format}"
model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename))
camera_info_input = kwargs.get("camera_info", None)
camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info']
model_3d_info_input = kwargs.get("model_3d_info", None)
model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', [])
return IO.NodeOutput(
model_3d,
model_3d_info,
camera_info,
width,
height,
ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info),
)
class PreviewPointCloud(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="PreviewPointCloud",
display_name="Preview Point Cloud",
category="3d",
is_experimental=True,
is_output_node=True,
search_aliases=[
"view point cloud",
"view pointcloud",
"preview point cloud",
"preview pointcloud",
"preview ply",
],
inputs=[
IO.MultiType.Input(
"model_3d",
types=[
IO.File3DPointCloudAny,
IO.File3DPLY,
],
tooltip="Point cloud file (.ply)",
),
IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True),
IO.Load3D.Input("viewport_state"),
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
],
outputs=[
IO.File3DPointCloudAny.Output(display_name="model_3d"),
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
IO.Load3DCamera.Output(display_name="camera_info"),
IO.Int.Output(display_name="width"),
IO.Int.Output(display_name="height"),
],
)
@classmethod
def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
filename = f"preview_pointcloud_{uuid.uuid4().hex}.{model_3d.format}"
model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename))
camera_info_input = kwargs.get("camera_info", None)
camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info']
model_3d_info_input = kwargs.get("model_3d_info", None)
model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', [])
return IO.NodeOutput(
model_3d,
model_3d_info,
camera_info,
width,
height,
ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info),
@ -189,6 +324,8 @@ class Load3DExtension(ComfyExtension):
Load3D,
Preview3D,
Preview3DAdvanced,
PreviewGaussianSplat,
PreviewPointCloud,
]

View File

@ -8,6 +8,7 @@ import folder_paths
from comfy_api.latest import ComfyExtension, Types, io
from typing_extensions import override
from comfy.ldm.colormap import turbo as _turbo
from comfy.ldm.moge.model import MoGeModel
from comfy.ldm.moge.geometry import triangulate_grid_mesh
from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid
@ -27,19 +28,6 @@ MoGeGeometry = io.Custom("MOGE_GEOMETRY")
# "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present)
def _turbo(x: torch.Tensor) -> torch.Tensor:
"""Anton Mikhailov polynomial approximation of the turbo colormap."""
x = x.clamp(0.0, 1.0)
x2 = x * x
x3 = x2 * x
x4 = x2 * x2
x5 = x4 * x
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
def _normals_from_points(points: torch.Tensor) -> torch.Tensor:
"""Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback)."""
finite = torch.isfinite(points).all(dim=-1)

View File

@ -6,24 +6,24 @@ from comfy_api.latest import ComfyExtension, io
class AspectRatio(str, Enum):
SQUARE = "1:1 (Square)"
PHOTO_V = "2:3 (Portrait Photo)"
PHOTO_H = "3:2 (Photo)"
STANDARD_V = "3:4 (Portrait Standard)"
STANDARD_H = "4:3 (Standard)"
WIDESCREEN_V = "9:16 (Portrait Widescreen)"
WIDESCREEN_H = "16:9 (Widescreen)"
ULTRAWIDE_H = "21:9 (Ultrawide)"
PHOTO_V = "2:3 (Portrait Photo)"
STANDARD_V = "3:4 (Portrait Standard)"
WIDESCREEN_V = "9:16 (Portrait Widescreen)"
ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = {
AspectRatio.SQUARE: (1, 1),
AspectRatio.PHOTO_V: (2, 3),
AspectRatio.PHOTO_H: (3, 2),
AspectRatio.STANDARD_V: (3, 4),
AspectRatio.STANDARD_H: (4, 3),
AspectRatio.WIDESCREEN_V: (9, 16),
AspectRatio.WIDESCREEN_H: (16, 9),
AspectRatio.ULTRAWIDE_H: (21, 9),
AspectRatio.PHOTO_V: (2, 3),
AspectRatio.STANDARD_V: (3, 4),
AspectRatio.WIDESCREEN_V: (9, 16),
}
@ -50,26 +50,35 @@ class ResolutionSelector(io.ComfyNode):
min=0.1,
max=16.0,
step=0.1,
tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.",
tooltip="Target total megapixels. 1.0 MP ≈ 1024x1024 for square.",
),
io.Int.Input(
id="multiple",
default=8,
min=8,
max=128,
step=4,
tooltip="Nearest multiple of the result to set the selected resolution to.",
advanced=True,
),
],
outputs=[
io.Int.Output(
"width", tooltip="Calculated width in pixels (multiple of 8)."
"width", tooltip="Calculated width in pixels multiplied by the selected multiple."
),
io.Int.Output(
"height", tooltip="Calculated height in pixels (multiple of 8)."
"height", tooltip="Calculated height in pixels multiplied by the selected multiple."
),
],
)
@classmethod
def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput:
def execute(cls, aspect_ratio: str, megapixels: float, multiple: int) -> io.NodeOutput:
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
total_pixels = megapixels * 1024 * 1024
scale = math.sqrt(total_pixels / (w_ratio * h_ratio))
width = round(w_ratio * scale / 8) * 8
height = round(h_ratio * scale / 8) * 8
width = round(w_ratio * scale / multiple) * multiple
height = round(h_ratio * scale / multiple) * multiple
return io.NodeOutput(width, height)

View File

@ -337,6 +337,12 @@ class SaveGLB(IO.ComfyNode):
IO.File3DFBX,
IO.File3DSTL,
IO.File3DUSDZ,
IO.File3DPLY,
IO.File3DSPLAT,
IO.File3DSPZ,
IO.File3DKSPLAT,
IO.File3DSplatAny,
IO.File3DPointCloudAny,
IO.File3DAny,
],
tooltip="Mesh or 3D file to save",

321
comfy_extras/nodes_scail.py Normal file
View File

@ -0,0 +1,321 @@
"""SCAIL / SCAIL-2 nodes: the WanSCAILToVideo conditioning node and the SAM3
preprocessing that turns video tracks into the bundle the SCAIL-2 model consumes."""
from typing_extensions import override
import torch
import torch.nn.functional as F
import nodes
import node_helpers
import comfy.model_management
import comfy.utils
from comfy_api.latest import ComfyExtension, io
from comfy.ldm.sam3.tracker import unpack_masks
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
# Model was trained on these exact colors; deviating degrades multi-identity quality.
DEFAULT_PALETTE = [
(0.0, 0.0, 1.0), # Blue
(1.0, 0.0, 0.0), # Red
(0.0, 1.0, 0.0), # Green
(1.0, 0.0, 1.0), # Magenta
(0.0, 1.0, 1.0), # Cyan
(1.0, 1.0, 0.0), # Yellow
]
def _unpack(track_data):
packed = track_data["packed_masks"]
if packed is None or packed.shape[1] == 0:
return None
return unpack_masks(packed)
def _first_frame_cx_area(masks_bool):
first = masks_bool[0].float()
H, W = first.shape[-2], first.shape[-1]
n_pixels = H * W
grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W)
area = first.sum(dim=(-1, -2)).clamp_(min=1)
cx = (first * grid_x).sum(dim=(-1, -2)) / area
return (cx / W).tolist(), (area / n_pixels).tolist()
def _subset_track_data(track_data, obj_indices):
out = dict(track_data)
packed = track_data["packed_masks"]
if packed is None or not obj_indices:
out["packed_masks"] = None
if "scores" in out:
out["scores"] = []
return out
out["packed_masks"] = packed[:, obj_indices].contiguous()
scores = track_data.get("scores")
if scores is not None:
out["scores"] = [scores[i] for i in obj_indices if i < len(scores)]
return out
def _render_colored_masks(track_data, background="black"):
packed = track_data["packed_masks"]
H, W = track_data["orig_size"]
device = comfy.model_management.intermediate_device()
dtype = comfy.model_management.intermediate_dtype()
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
if packed is None or packed.shape[1] == 0:
T = track_data.get("n_frames", 1) if packed is None else packed.shape[0]
out = torch.empty(T, H, W, 3, device=device, dtype=dtype)
out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2]
return out
T, N_obj = packed.shape[0], packed.shape[1]
colors = torch.tensor(
[DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)],
device=device, dtype=dtype,
)
masks_full = unpack_masks(packed.to(device)).float()
Hm, Wm = masks_full.shape[-2], masks_full.shape[-1]
masks_full = F.interpolate(
masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest"
).view(T, N_obj, H, W) > 0.5
any_mask = masks_full.any(dim=1)
obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1)
color_overlay = colors[obj_idx_map]
bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3)
return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay))
def _extract_mask_to_28ch(rgb_video):
"""Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent
(1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c)
threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking."""
T, H, W, _ = rgb_video.shape
_ON_THRESH = 225.0 / 255.0
mask = rgb_video.movedim(-1, 1).float()
R = (mask[:, 0:1] > _ON_THRESH).float()
G = (mask[:, 1:2] > _ON_THRESH).float()
B = (mask[:, 2:3] > _ON_THRESH).float()
nR, nG, nB = 1 - R, 1 - G, 1 - B
binary_7ch = torch.cat([
R * G * B, # white
R * nG * nB, # red
nR * G * nB, # green
nR * nG * B, # blue
R * G * nB, # yellow
R * nG * B, # magenta
nR * G * B, # cyan
], dim=1)
H_lat, W_lat = H, W
for _ in range(3):
H_lat = (H_lat + 1) // 2
W_lat = (W_lat + 1) // 2
binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area')
T_latent = (T - 1) // 4 + 1
padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0)
out = padded.view(T_latent, 28, H_lat, W_lat)
return out.unsqueeze(0)
class WanSCAILToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanSCAILToVideo",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
io.Image.Input("pose_video_mask", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video."),
io.Boolean.Input("replacement_mode", default=False, optional=True, tooltip="SCAIL-2 only. False = Animation Mode (pose_video_mask should have black background). True = Replacement Mode (pose_video_mask should have white background)."),
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."),
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."),
io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."),
io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."),
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."),
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."),
io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."),
io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the extension anchor."),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end,
video_frame_offset, previous_frame_count, replacement_mode=False, reference_image=None, clip_vision_output=None, pose_video=None,
pose_video_mask=None, reference_image_mask=None, previous_frames=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
noise_mask = None
ref_mask_flag = not replacement_mode
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag})
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag})
prev_trimmed = None
if previous_frames is not None and previous_frames.shape[0] > 0:
prev_trimmed = previous_frames[-previous_frame_count:]
video_frame_offset -= prev_trimmed.shape[0]
video_frame_offset = max(0, video_frame_offset)
ref_latent = None
if reference_image is not None:
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
# Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte
if replacement_mode and reference_image_mask is not None:
rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype)
reference_image = reference_image * is_char
ref_latent = vae.encode(reference_image[:, :, :, :3])
if ref_latent is not None:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
if pose_video is not None:
if pose_video.shape[0] <= video_frame_offset:
pose_video = None
else:
pose_video = pose_video[video_frame_offset:]
if pose_video_mask is not None:
if pose_video_mask.shape[0] <= video_frame_offset:
pose_video_mask = None
else:
pose_video_mask = pose_video_mask[video_frame_offset:]
# Truncate pose+mask jointly to the shorter of the two, capped at length.
ts = [v.shape[0] for v in (pose_video, pose_video_mask) if v is not None]
if ts:
T_kept = ((min(min(ts), length) - 1) // 4) * 4 + 1
if pose_video is not None:
pose_video = pose_video[:T_kept]
if pose_video_mask is not None:
pose_video_mask = pose_video_mask[:T_kept]
if pose_video is not None:
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
if pose_video_mask is not None:
mask_video_hw = comfy.utils.common_upscale(pose_video_mask[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
driving_mask_28ch = _extract_mask_to_28ch(mask_video_hw)
positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch})
negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch})
if reference_image_mask is not None:
ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw)
zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype)
ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1)
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch})
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch})
if prev_trimmed is not None:
pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
prev_latent = vae.encode(pf[:, :, :, :3])
prev_latent_frames = min(prev_latent.shape[2], latent.shape[2])
latent[:, :, :prev_latent_frames] = prev_latent[:, :, :prev_latent_frames].to(latent.dtype)
noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=latent.device, dtype=latent.dtype)
noise_mask[:, :, :prev_latent_frames] = 0.0
out_latent = {"samples": latent}
if noise_mask is not None:
out_latent["noise_mask"] = noise_mask
return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length)
class SCAIL2ColoredMask(io.ComfyNode):
"""Render SAM3 tracks for the driving pose video and (optionally) the reference
image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by`
across both outputs guarantees identity K maps to the same color on both
sides, for multi-person workflow consistency.
reference_image_mask is always rendered black-bg (model convention)
pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode
"""
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SCAIL2ColoredMask",
display_name="Create SCAIL-2 Colored Mask",
category="conditioning/video_models/scail",
inputs=[
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."),
SAM3TrackData.Input("ref_track_data", optional=True,
tooltip="SAM3 track of the reference image."),
io.String.Input("object_indices", default="",
tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."),
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."),
io.Boolean.Input("replacement_mode", default=False,
tooltip="False = mask_video has black bg (Animation Mode). True = white bg (Replacement Mode). Set the matching replacement_mode on WanSCAILToVideo. reference_image_mask is always black-bg regardless."),
],
outputs=[
io.Image.Output("pose_video_mask"),
io.Image.Output("reference_image_mask"),
],
is_experimental=True,
)
@classmethod
def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None):
def _prep(td):
masks_bool = _unpack(td)
if sort_by != "none" and masks_bool is not None:
cx, area = _first_frame_cx_area(masks_bool)
if sort_by == "left_to_right":
order = sorted(range(len(cx)), key=lambda i: cx[i])
else: # "area"
order = sorted(range(len(area)), key=lambda i: -area[i])
td = _subset_track_data(td, order)
if object_indices.strip():
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
packed = td.get("packed_masks")
n_obj = packed.shape[1] if packed is not None else 0
indices = [i for i in indices if 0 <= i < n_obj]
td = _subset_track_data(td, indices)
return td
drv = _prep(driving_track_data)
mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black")
if ref_track_data is not None:
ref = _prep(ref_track_data)
reference_image_mask = _render_colored_masks(ref, "black")
else:
H, W = drv["orig_size"]
reference_image_mask = torch.zeros(1, H, W, 3, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return io.NodeOutput(mask_video, reference_image_mask)
class SCAILExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
WanSCAILToVideo,
SCAIL2ColoredMask,
]
async def comfy_entrypoint() -> SCAILExtension:
return SCAILExtension()

View File

@ -15,6 +15,7 @@ import comfy.sampler_helpers
import comfy.sd
import comfy.utils
import comfy.model_management
from comfy.conds import CONDRegular, CONDList
from comfy.cli_args import args, PerformanceFeature
import comfy_extras.nodes_custom_sampler
import folder_paths
@ -120,6 +121,11 @@ def process_cond_list(d, prefix=""):
process_cond_list(v, f"{prefix}.{k}")
elif isinstance(v, torch.Tensor):
d[k] = v.clone()
elif isinstance(v, CONDList):
v.cond = [t.detach() if isinstance(t, torch.Tensor) else t for t in v.cond]
elif isinstance(v, CONDRegular):
if isinstance(v.cond, torch.Tensor):
v.cond = v.cond.detach()
elif isinstance(v, (list, tuple)):
for index, item in enumerate(v):
process_cond_list(item, f"{prefix}.{k}.{index}")
@ -1143,45 +1149,45 @@ class TrainLoraNode(io.ComfyNode):
# Process conditioning
positive = _process_conditioning(positive)
# Setup model and dtype
mp = model.clone()
use_grad_scaler = False
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
# GradScaler only supports float16 gradients, not bfloat16.
# Only enable it when lora params will also be in float16.
if lora_dtype != torch.bfloat16:
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
latents, latents_dtype, bucket_mode
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
with torch.inference_mode(False):
# Setup model and dtype
mp = model.clone(force_deepcopy=True)
use_grad_scaler = False
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
# GradScaler only supports float16 gradients, not bfloat16.
# Only enable it when lora params will also be in float16.
if lora_dtype != torch.bfloat16:
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
# Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count(
latents, latents_dtype, bucket_mode
)
# Validate and expand conditioning
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
# Setup models for training
mp.model.requires_grad_(False)
mp.model.requires_grad_(False).train()
# Load existing LoRA weights if provided
existing_weights, existing_steps = _load_existing_lora(existing_lora)

View File

@ -19,7 +19,7 @@ class SaveWEBM(io.ComfyNode):
category="video",
is_experimental=True,
inputs=[
io.Image.Input("images"),
io.Image.Input("images", tooltip="RGBA images are saved with their alpha channel as transparency (vp9 codec only)."),
io.String.Input("filename_prefix", default="ComfyUI"),
io.Combo.Input("codec", options=["vp9", "av1"]),
io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01),
@ -45,18 +45,25 @@ class SaveWEBM(io.ComfyNode):
for x in cls.hidden.extra_pnginfo:
container.metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x])
# Save transparency when the images carry an alpha channel (RGBA) and the codec supports it.
# vp9 -> yuva420p; other codecs have no usable alpha path, so the alpha is ignored.
save_alpha = images.shape[-1] == 4 and codec == "vp9"
codec_map = {"vp9": "libvpx-vp9", "av1": "libsvtav1"}
stream = container.add_stream(codec_map[codec], rate=Fraction(round(fps * 1000), 1000))
stream.width = images.shape[-2]
stream.height = images.shape[-3]
stream.pix_fmt = "yuv420p10le" if codec == "av1" else "yuv420p"
stream.pix_fmt = "yuva420p" if save_alpha else ("yuv420p10le" if codec == "av1" else "yuv420p")
stream.bit_rate = 0
stream.options = {'crf': str(crf)}
if codec == "av1":
stream.options["preset"] = "6"
for frame in images:
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
if save_alpha:
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :4] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgba")
else:
frame = av.VideoFrame.from_ndarray(torch.clamp(frame[..., :3] * 255, min=0, max=255).to(device=torch.device("cpu"), dtype=torch.uint8).numpy(), format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
container.mux(stream.encode())

View File

@ -1456,63 +1456,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
class WanSCAILToVideo(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="WanSCAILToVideo",
category="model/conditioning/video_models",
inputs=[
io.Conditioning.Input("positive"),
io.Conditioning.Input("negative"),
io.Vae.Input("vae"),
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
io.Int.Input("batch_size", default=1, min=1, max=4096),
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
io.Image.Input("reference_image", optional=True),
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
],
outputs=[
io.Conditioning.Output(display_name="positive"),
io.Conditioning.Output(display_name="negative"),
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
],
is_experimental=True,
)
@classmethod
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
ref_latent = None
if reference_image is not None:
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
ref_latent = vae.encode(reference_image[:, :, :, :3])
if ref_latent is not None:
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
if clip_vision_output is not None:
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
if pose_video is not None:
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
out_latent = {}
out_latent["samples"] = latent
return io.NodeOutput(positive, negative, out_latent)
class WanExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -1533,7 +1476,6 @@ class WanExtension(ComfyExtension):
WanAnimateToVideo,
Wan22ImageToVideoLatent,
WanInfiniteTalkToVideo,
WanSCAILToVideo,
]
async def comfy_entrypoint() -> WanExtension:

View File

@ -2,6 +2,7 @@ import os
import importlib.util
from comfy.cli_args import args, PerformanceFeature
import subprocess
import re
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
def get_gpu_names():
@ -77,11 +78,24 @@ try:
except:
pass
def get_raw_cuda_version(version_str):
match = re.search(r'\+cu(\d+)', version_str)
if match:
try:
return int(match.group(1))
except:
pass
return None
if not args.cuda_malloc:
try:
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
args.cuda_malloc = cuda_malloc_supported()
cuda_version = get_raw_cuda_version(version)
if cuda_version is not None and cuda_version >= 130:
args.cuda_malloc = True
else:
args.cuda_malloc = cuda_malloc_supported()
except:
pass

15
main.py
View File

@ -26,6 +26,7 @@ import utils.extra_config
from utils.mime_types import init_mime_types
import faulthandler
import logging
import signal
import sys
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
@ -37,7 +38,19 @@ if __name__ == "__main__":
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
os.environ['DO_NOT_TRACK'] = '1'
faulthandler.enable(file=sys.stderr, all_threads=False)
faulthandler.enable(file=sys.stderr, all_threads=args.debug_hang)
if __name__ == "__main__" and args.debug_hang:
dumping_traceback = False
def dump_traceback_on_sigint(signum, frame):
global dumping_traceback
if dumping_traceback:
raise KeyboardInterrupt
dumping_traceback = True
faulthandler.dump_traceback(file=sys.stderr, all_threads=True)
raise KeyboardInterrupt
signal.signal(signal.SIGINT, dump_traceback_on_sigint)
import comfy_aimdo.control

View File

@ -2404,6 +2404,7 @@ async def init_builtin_extra_nodes():
"nodes_video.py",
"nodes_lumina2.py",
"nodes_wan.py",
"nodes_bernini.py",
"nodes_lotus.py",
"nodes_hunyuan3d.py",
"nodes_primitive.py",
@ -2450,6 +2451,7 @@ async def init_builtin_extra_nodes():
"nodes_rtdetr.py",
"nodes_frame_interpolation.py",
"nodes_sam3.py",
"nodes_scail.py",
"nodes_void.py",
"nodes_wandancer.py",
"nodes_hidream_o1.py",
@ -2457,7 +2459,8 @@ async def init_builtin_extra_nodes():
"nodes_moge.py",
"nodes_mediapipe.py",
"nodes_gaussian_splat.py",
"nodes_triposplat.py"
"nodes_triposplat.py",
"nodes_depth_anything_3.py",
]
import_failed = []

View File

@ -3,11 +3,6 @@ components:
Asset:
description: Represents a user-owned asset (image, video, or other generated output).
properties:
asset_hash:
deprecated: true
description: 'Deprecated: use hash instead. Blake3 hash of the asset content.'
pattern: ^blake3:[a-f0-9]{64}$
type: string
created_at:
description: Timestamp when the asset was created
format: date-time
@ -16,8 +11,12 @@ components:
description: Display name of the asset. Mirrors name for backwards compatibility.
nullable: true
type: string
file_path:
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
nullable: true
type: string
hash:
description: Blake3 hash of the asset content. Preferred over asset_hash.
description: Blake3 hash of the asset content.
pattern: ^blake3:[a-f0-9]{64}$
type: string
id:
@ -139,17 +138,16 @@ components:
AssetUpdated:
description: Response returned when an existing asset is successfully updated.
properties:
asset_hash:
deprecated: true
description: 'Deprecated: use hash instead. Blake3 hash of the asset content.'
pattern: ^blake3:[a-f0-9]{64}$
type: string
display_name:
description: Display name of the asset. Mirrors name for backwards compatibility.
nullable: true
type: string
file_path:
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
nullable: true
type: string
hash:
description: Blake3 hash of the asset content. Preferred over asset_hash.
description: Blake3 hash of the asset content.
pattern: ^blake3:[a-f0-9]{64}$
type: string
id:
@ -828,7 +826,11 @@ components:
type: string
type: object
PaginationInfo:
description: Offset/limit-based pagination metadata included in list responses.
description: |
Pagination metadata included in list responses. Supports both legacy
offset/limit pagination and cursor-based pagination. When cursor-based
pagination is used, `next_cursor` is the primary pagination token and
`offset`/`total` may be zero.
properties:
has_more:
description: Whether more items are available beyond this page
@ -837,12 +839,19 @@ components:
description: Items per page
minimum: 1
type: integer
next_cursor:
description: |
Opaque cursor for the next page. Pass this value as the `after`
query parameter on the next request. Empty or absent when there
are no more results.
type: string
offset:
description: Current offset (0-based)
deprecated: true
description: 'Current offset (0-based). Deprecated: use cursor-based pagination.'
minimum: 0
type: integer
total:
description: Total number of items matching filters
description: Total number of items matching filters (may be 0 when using cursor pagination)
minimum: 0
type: integer
required:
@ -1518,17 +1527,11 @@ paths:
schema:
default: true
type: boolean
- description: Filter assets by exact content hash. Preferred over asset_hash.
- description: Filter assets by exact content hash.
in: query
name: hash
schema:
type: string
- deprecated: true
description: 'Deprecated: use hash instead. Filter assets by exact content hash.'
in: query
name: asset_hash
schema:
type: string
- description: |
Opaque cursor for keyset pagination. Pass the `next_cursor` value
from the previous response to fetch the next page. When provided,
@ -1571,42 +1574,12 @@ paths:
- file
post:
description: |
Uploads a new asset to the system with associated metadata.
Supports two upload methods:
1. Direct file upload (multipart/form-data)
2. URL-based upload (application/json with source: "url")
Creates a new asset from a direct file upload (multipart/form-data) with associated metadata.
If an asset with the same hash already exists, returns the existing asset.
operationId: uploadAsset
operationId: createAsset
requestBody:
content:
application/json:
schema:
properties:
name:
description: Display name for the asset (used to determine file extension)
type: string
preview_id:
description: Optional preview asset ID
format: uuid
type: string
tags:
description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
items:
type: string
type: array
url:
description: HTTP/HTTPS URL to download the asset from
format: uri
type: string
user_metadata:
additionalProperties: true
description: Custom metadata to store with the asset
type: object
required:
- url
- name
type: object
multipart/form-data:
schema:
properties:
@ -1614,6 +1587,10 @@ paths:
description: The asset file to upload
format: binary
type: string
hash:
description: Content hash of the file.
pattern: ^(blake3|sha256):[a-f0-9]{64}$
type: string
id:
description: Optional asset ID for idempotent creation. If provided and asset exists, returns existing asset.
format: uuid
@ -1629,10 +1606,8 @@ paths:
format: uuid
type: string
tags:
description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
items:
type: string
type: array
description: JSON-encoded array of freeform tag strings, e.g. '["models","checkpoint"]'. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
type: string
user_metadata:
description: Custom JSON metadata as a string
type: string
@ -1641,36 +1616,32 @@ paths:
type: object
required: true
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/AssetCreated'
description: |
Asset already existed for this user (deduplicated by content hash); the
existing asset is returned with created_new=false.
"201":
content:
application/json:
schema:
$ref: '#/components/schemas/AssetCreated'
description: Asset created successfully
description: Asset created successfully (created_new=true)
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Invalid request (bad file, invalid URL, invalid content type, etc.)
description: Invalid request (bad file, invalid content type, etc.)
"401":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Unauthorized
"403":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Source URL requires authentication or access denied
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Source URL not found
"413":
content:
application/json:
@ -1683,19 +1654,13 @@ paths:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Unsupported media type
"422":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Download failed due to network error or timeout
"500":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Internal server error
summary: Upload a new asset
summary: Create a new asset
tags:
- file
/api/assets/{id}:
@ -1730,7 +1695,7 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Asset cannot be deleted because it is referenced by another resource (e.g., workflow version)
description: 'Asset cannot be deleted because it is referenced by another resource, e.g. a workflow version (error code: ASSET_IN_USE)'
"500":
content:
application/json:
@ -1783,7 +1748,7 @@ paths:
description: |
Updates an asset's metadata. At least one field must be provided.
Only name, mime_type, preview_id, and user_metadata can be updated.
For tag management, use the dedicated PUT /api/assets/{id}/tags endpoint.
For tag management, use POST (add) and DELETE (remove) /api/assets/{id}/tags.
operationId: updateAsset
parameters:
- description: Asset ID
@ -1982,76 +1947,6 @@ paths:
summary: Add tags to asset
tags:
- file
put:
description: Adds and removes tags from an asset in a single operation
operationId: updateAssetTags
parameters:
- description: Asset ID
in: path
name: id
required: true
schema:
format: uuid
type: string
requestBody:
content:
application/json:
schema:
description: At least one of add or remove must contain items. Empty arrays are allowed when the other array has items.
minProperties: 1
properties:
add:
description: Tags to add to the asset. Can be empty if remove has items.
items:
type: string
type: array
remove:
description: Tags to remove from the asset. Can be empty if add has items.
items:
type: string
type: array
type: object
required: true
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/TagsModificationResponse'
description: Tags updated successfully
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Invalid request
"401":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Unauthorized
"404":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Asset not found
"422":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Reserved tag validation error
"500":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Internal server error
summary: Update asset tags
tags:
- file
/api/assets/from-hash:
post:
description: |
@ -2065,8 +1960,8 @@ paths:
schema:
properties:
hash:
description: Hash of the existing asset. Supports Blake3 (blake3:) or SHA256 (sha256:) formats
pattern: ^(blake3|sha256):[a-f0-9]{64}$
description: 'Blake3 content hash of the existing asset (blake3: prefix)'
pattern: ^blake3:[a-f0-9]{64}$
type: string
mime_type:
description: MIME type of the asset (e.g., "image/png", "video/mp4")
@ -2090,12 +1985,20 @@ paths:
type: object
required: true
responses:
"200":
content:
application/json:
schema:
$ref: '#/components/schemas/AssetCreated'
description: |
Asset reference already existed for this user (deduplicated by content
hash); the existing asset is returned with created_new=false.
"201":
content:
application/json:
schema:
$ref: '#/components/schemas/AssetCreated'
description: Asset reference created successfully
description: Asset reference created successfully (created_new=true)
"400":
content:
application/json:
@ -2887,7 +2790,21 @@ paths:
- asc
- desc
type: string
- description: Pagination offset (0-based)
- description: |
Opaque cursor for keyset pagination. Pass the `next_cursor` value
from a previous response to fetch the next page.
Cursor pagination is supported only when `sort_by=create_time`
(default). If `sort_by=execution_time`, `after` is ignored and
offset/limit pagination is used.
Cursors are opaque base64url payloads — clients should treat them
as strings and not parse the contents.
example: eyJzIjoiY3JlYXRlX3RpbWUiLCJ2IjoiMTcxNjIwMDAwMDAwMDAwMCIsImlkIjoiYTFiMmMzZDQtZTVmNi03YTg5LWIwYzEtZDJlM2Y0YTViNmM3In0
in: query
name: after
schema:
type: string
- deprecated: true
description: 'Pagination offset (0-based). Deprecated: prefer cursor-based pagination via `after`.'
in: query
name: offset
schema:
@ -2909,6 +2826,12 @@ paths:
schema:
$ref: '#/components/schemas/JobsListResponse'
description: Success - Jobs retrieved
"400":
content:
application/json:
schema:
$ref: '#/components/schemas/ErrorResponse'
description: Bad request (e.g. malformed pagination cursor).
"401":
content:
application/json:

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.44.19
comfyui-workflow-templates==0.9.94
comfyui-embedded-docs==0.5.2
comfyui-frontend-package==1.45.15
comfyui-workflow-templates==0.9.98
comfyui-embedded-docs==0.5.3
torch
torchsde
torchvision
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
filelock
av>=16.0.0
comfy-kitchen==0.2.10
comfy-aimdo==0.4.8
comfy-aimdo==0.4.9
requests
simpleeval>=1.0.0
blake3

View File

@ -1253,6 +1253,15 @@ class PromptServer():
if verbose:
logging.info("Starting server\n")
if args.debug_hang:
logging.info(
f"{'-' * 80}\n"
"ComfyUI has been started in debug-hang mode. Run your workflow as normal up to\n"
"the point of the hang or freeze, then use ctrl-C in the cmd or controlling\n"
"terminal to dump the python backtraces for debugging. Please attach the extra\n"
"debug info to your bug report.\n"
f"{'-' * 80}"
)
for addr in addresses:
address = addr[0]
port = addr[1]

View File

@ -0,0 +1,86 @@
"""Tests for the image_dimensions service."""
from __future__ import annotations
from pathlib import Path
import pytest
from PIL import Image
from app.assets.services.image_dimensions import extract_image_dimensions
def _make_png(path: Path, size: tuple[int, int]) -> Path:
img = Image.new("RGB", size, color=(123, 45, 67))
img.save(path, format="PNG")
return path
def _make_jpeg(path: Path, size: tuple[int, int]) -> Path:
img = Image.new("RGB", size, color=(10, 20, 30))
img.save(path, format="JPEG", quality=80)
return path
class TestExtractImageDimensions:
def test_extracts_png_dimensions(self, tmp_path: Path):
f = _make_png(tmp_path / "rect.png", (320, 240))
result = extract_image_dimensions(str(f), mime_type="image/png")
assert result == {"kind": "image", "width": 320, "height": 240}
def test_extracts_jpeg_dimensions(self, tmp_path: Path):
f = _make_jpeg(tmp_path / "shot.jpg", (1920, 1080))
result = extract_image_dimensions(str(f), mime_type="image/jpeg")
assert result == {"kind": "image", "width": 1920, "height": 1080}
def test_works_when_mime_type_is_none(self, tmp_path: Path):
f = _make_png(tmp_path / "no_mime.png", (50, 100))
result = extract_image_dimensions(str(f), mime_type=None)
assert result == {"kind": "image", "width": 50, "height": 100}
def test_skips_non_image_mime_without_touching_file(self, tmp_path: Path):
# Path doesn't need to exist — non-image MIME short-circuits.
result = extract_image_dimensions(
str(tmp_path / "model.safetensors"),
mime_type="application/octet-stream",
)
assert result is None
@pytest.mark.parametrize(
"mime",
["application/json", "text/plain", "video/mp4", "audio/mpeg"],
)
def test_skips_all_non_image_mime_types(self, tmp_path: Path, mime: str):
f = tmp_path / "file.bin"
f.write_bytes(b"\x00\x01\x02")
assert extract_image_dimensions(str(f), mime_type=mime) is None
def test_returns_none_for_missing_file(self, tmp_path: Path):
result = extract_image_dimensions(
str(tmp_path / "does_not_exist.png"), mime_type="image/png"
)
assert result is None
def test_returns_none_for_corrupt_image(self, tmp_path: Path):
f = tmp_path / "corrupt.png"
f.write_bytes(b"not actually a png file")
result = extract_image_dimensions(str(f), mime_type="image/png")
assert result is None
def test_returns_none_for_empty_file(self, tmp_path: Path):
f = tmp_path / "empty.png"
f.write_bytes(b"")
result = extract_image_dimensions(str(f), mime_type="image/png")
assert result is None

View File

@ -4,10 +4,12 @@ from pathlib import Path
from unittest.mock import patch
import pytest
from PIL import Image
from sqlalchemy.orm import Session as SASession, Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
from app.assets.database.queries import get_reference_tags
from app.assets.helpers import get_utc_now
from app.assets.services.ingest import (
_ingest_file_from_path,
_register_existing_asset,
@ -15,6 +17,11 @@ from app.assets.services.ingest import (
)
def _make_png(path: Path, size: tuple[int, int]) -> Path:
Image.new("RGB", size, color=(80, 120, 200)).save(path, format="PNG")
return path
class TestIngestFileFromPath:
def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session):
file_path = temp_dir / "test_file.bin"
@ -279,4 +286,203 @@ class TestIngestExistingFileTagFK:
ref_tags = sess.query(AssetReferenceTag).all()
ref_tag_names = {rt.tag_name for rt in ref_tags}
assert "output" in ref_tag_names
assert "my-job" in ref_tag_names
class TestIngestImageDimensions:
"""system_metadata should carry {kind, width, height} for image assets."""
def test_image_asset_emits_dimensions(
self, mock_create_session, temp_dir: Path, session: Session
):
f = _make_png(temp_dir / "shot.png", (640, 480))
result = _ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:img1",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000000,
mime_type="image/png",
)
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
assert ref.system_metadata == {
"kind": "image",
"width": 640,
"height": 480,
}
def test_non_image_asset_leaves_system_metadata_empty(
self, mock_create_session, temp_dir: Path, session: Session
):
f = temp_dir / "model.safetensors"
f.write_bytes(b"not an image")
result = _ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:safetensors1",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000000,
mime_type="application/octet-stream",
)
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
assert ref.system_metadata in (None, {})
def test_preserves_existing_system_metadata_keys(
self, mock_create_session, temp_dir: Path, session: Session
):
f = _make_png(temp_dir / "annotated.png", (100, 200))
# First pass populates a sentinel system_metadata key (simulating prior
# enricher write).
result = _ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:img-merge",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000000,
mime_type="image/png",
)
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/x.png"}
session.commit()
# Second pass with the same path triggers the merge code path again.
_ingest_file_from_path(
abs_path=str(f),
asset_hash="blake3:img-merge",
size_bytes=f.stat().st_size,
mtime_ns=1234567890000000001,
mime_type="image/png",
)
session.refresh(ref)
assert ref.system_metadata["kind"] == "image"
assert ref.system_metadata["width"] == 100
assert ref.system_metadata["height"] == 200
assert ref.system_metadata["source_url"] == "https://example/x.png"
class TestRegisterExistingAssetBackfill:
"""The from-hash path back-fills dimensions from a sibling reference."""
def _add_reference(
self,
session: Session,
asset: Asset,
name: str,
system_metadata: dict | None = None,
) -> AssetReference:
now = get_utc_now()
ref = AssetReference(
asset_id=asset.id,
name=name,
owner_id="",
created_at=now,
updated_at=now,
last_access_time=now,
system_metadata=system_metadata or {},
)
session.add(ref)
session.flush()
return ref
def test_backfills_dimensions_from_sibling_image_reference(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:shared", size_bytes=2048, mime_type="image/png")
session.add(asset)
session.flush()
self._add_reference(
session,
asset,
name="original.png",
system_metadata={"kind": "image", "width": 800, "height": 600},
)
session.commit()
result = _register_existing_asset(
asset_hash="blake3:shared",
name="from_hash.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
assert ref.system_metadata.get("kind") == "image"
assert ref.system_metadata.get("width") == 800
assert ref.system_metadata.get("height") == 600
def test_no_backfill_when_sibling_has_no_image_metadata(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:nodims", size_bytes=2048, mime_type="image/png")
session.add(asset)
session.flush()
self._add_reference(
session,
asset,
name="original.png",
system_metadata={"base_model": "flux"}, # no kind=image
)
session.commit()
result = _register_existing_asset(
asset_hash="blake3:nodims",
name="from_hash.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
meta = ref.system_metadata or {}
assert "kind" not in meta
assert "width" not in meta
assert "height" not in meta
def test_no_backfill_when_no_sibling_exists(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:lonely", size_bytes=1024, mime_type="image/png")
session.add(asset)
session.commit()
result = _register_existing_asset(
asset_hash="blake3:lonely",
name="solo.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
assert ref.system_metadata in (None, {})
def test_backfill_preserves_caller_supplied_keys(
self, mock_create_session, session: Session
):
asset = Asset(hash="blake3:preserve", size_bytes=2048, mime_type="image/png")
session.add(asset)
session.flush()
self._add_reference(
session,
asset,
name="original.png",
system_metadata={"kind": "image", "width": 1024, "height": 768},
)
session.commit()
# Simulate a from-hash path where the new reference already carries
# some system_metadata (e.g. a download-provenance source_url written
# by an earlier step). The back-fill must merge dim keys without
# clobbering existing keys.
result = _register_existing_asset(
asset_hash="blake3:preserve",
name="from_hash.png",
owner_id="user-x",
)
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
# Seed a sentinel key and re-run back-fill via a second register call
# to exercise the merge path with pre-existing data.
ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/p"}
session.commit()
assert ref.system_metadata.get("source_url") == "https://example/p"
assert ref.system_metadata.get("kind") == "image"
assert ref.system_metadata.get("width") == 1024
assert ref.system_metadata.get("height") == 768