Merge remote-tracking branch 'upstream/master' into bernini

This commit is contained in:
kijai 2026-06-01 21:20:48 +03:00
commit 886b2e5102
22 changed files with 1904 additions and 113 deletions

View File

@ -2,6 +2,7 @@ from .utils import load_torch_file, transformers_convert, state_dict_prefix_repl
import os
import json
import logging
import torch
import comfy.ops
import comfy.model_patcher
@ -9,6 +10,7 @@ import comfy.model_management
import comfy.utils
import comfy.clip_model
import comfy.image_encoders.dino2
import comfy.image_encoders.dino3
class Output:
def __getitem__(self, key):
@ -23,12 +25,16 @@ IMAGE_ENCODERS = {
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel,
}
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
config = json.load(f)
if isinstance(json_config, dict):
config = json_config
else:
with open(json_config) as f:
config = json.load(f)
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
@ -44,6 +50,10 @@ class ClipVisionModel():
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
if self.model_type == "dinov3" and self.dtype == torch.float16:
# DINOv3's activations borderline fits fp16, preferring bf16 if available for better stability #TODO: further fp16 tests in practice
if comfy.model_management.should_use_bf16(self.load_device, prioritize_performance=True):
self.dtype = torch.bfloat16
self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
@ -134,6 +144,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
elif 'layer.0.mlp.gate_proj.weight' in sd and 'layer.31.norm1.weight' in sd: # Dinov3 ViT-H/16+ (SwiGLU gated MLP, 32 layers)
json_config = comfy.image_encoders.dino3.DINOV3_VITH_CONFIG
else:
return None

View File

@ -0,0 +1,260 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
# DINOv3 ViT-H/16+ (SwiGLU)
DINOV3_VITH_CONFIG = {
"model_type": "dinov3",
"num_hidden_layers": 32,
"hidden_size": 1280,
"num_attention_heads": 20,
"num_register_tokens": 4,
"intermediate_size": 5120,
"layer_norm_eps": 1e-5,
"num_channels": 3,
"patch_size": 16,
"rope_theta": 100.0,
"use_gated_mlp": True,
"gated_mlp_act": "silu",
"image_size": 1024,
"image_mean": [0.485, 0.456, 0.406],
"image_std": [0.229, 0.224, 0.225],
}
class DINOv3ViTMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
self.act_fn = torch.nn.GELU()
def forward(self, x):
return self.down_proj(self.act_fn(self.up_proj(x)))
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
num_tokens = q.shape[-2]
num_patches = sin.shape[-2]
num_prefix_tokens = num_tokens - num_patches
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
return q, k
class DINOv3ViTAttention(nn.Module):
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
super().__init__()
self.embed_dim = hidden_size
self.num_heads = num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
def forward(self, hidden_states, attention_mask=None, position_embeddings=None, **kwargs):
batch_size, patches, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
if position_embeddings is not None:
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
attn = optimized_attention_for_device(query_states.device, mask=False)
attn_output = attn(
query_states, key_states, value_states, self.num_heads, attention_mask,
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output
class DINOv3ViTGatedMLP(nn.Module):
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations, act="silu"):
super().__init__()
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
self.act_fn = torch.nn.SiLU() if act == "silu" else torch.nn.GELU()
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
def get_patches_center_coordinates(num_patches_h, num_patches_w, dtype, device):
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
coords_h = coords_h / num_patches_h
coords_w = coords_w / num_patches_w
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
coords = coords.flatten(0, 1)
coords = 2.0 * coords - 1.0
return coords
class DINOv3ViTRopePositionEmbedding(nn.Module):
inv_freq: torch.Tensor
def __init__(self, rope_theta, hidden_size, num_attention_heads, patch_size, device, dtype):
super().__init__()
self.base = rope_theta
self.head_dim = hidden_size // num_attention_heads
self.patch_size = patch_size
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
def forward(self, pixel_values):
_, _, height, width = pixel_values.shape
num_patches_h = height // self.patch_size
num_patches_w = width // self.patch_size
patch_coords = get_patches_center_coordinates(num_patches_h, num_patches_w, dtype=torch.float32, device=pixel_values.device)
self.inv_freq = self.inv_freq.to(pixel_values.device)
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
angles = angles.flatten(1, 2)
angles = angles.tile(2)
cos = torch.cos(angles).to(dtype=pixel_values.dtype)
sin = torch.sin(angles).to(dtype=pixel_values.dtype)
return cos, sin
class DINOv3ViTEmbeddings(nn.Module):
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
super().__init__()
self.cls_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
self.mask_token = nn.Parameter(torch.empty(1, 1, hidden_size, device=device, dtype=dtype))
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
self.patch_embeddings = operations.Conv2d(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
)
def forward(self, pixel_values, bool_masked_pos=None):
batch_size = pixel_values.shape[0]
target_dtype = self.patch_embeddings.weight.dtype
patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
if bool_masked_pos is not None:
mask_token = self.mask_token.to(patch_embeddings.dtype)
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
cls_token = self.cls_token.expand(batch_size, -1, -1).to(patch_embeddings.device)
register_tokens = self.register_tokens.expand(batch_size, -1, -1).to(patch_embeddings.device)
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
return embeddings
class DINOv3ViTLayer(nn.Module):
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size,
num_attention_heads, device, dtype, operations, gated_mlp_act="silu"):
super().__init__()
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
if use_gated_mlp:
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations, act=gated_mlp_act)
else:
self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
def forward(self, hidden_states, attention_mask=None, position_embeddings=None):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.attention(hidden_states, attention_mask=attention_mask, position_embeddings=position_embeddings)
hidden_states = self.layer_scale1(hidden_states)
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = self.layer_scale2(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class DINOv3ViTModel(nn.Module):
def __init__(self, config, dtype, device, operations):
super().__init__()
num_hidden_layers = config["num_hidden_layers"]
hidden_size = config["hidden_size"]
num_attention_heads = config["num_attention_heads"]
num_register_tokens = config["num_register_tokens"]
intermediate_size = config["intermediate_size"]
layer_norm_eps = config["layer_norm_eps"]
num_channels = config["num_channels"]
patch_size = config["patch_size"]
rope_theta = config["rope_theta"]
use_gated_mlp = config.get("use_gated_mlp", False)
gated_mlp_act = config.get("gated_mlp_act", "silu")
self.embeddings = DINOv3ViTEmbeddings(
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size,
dtype=dtype, device=device, operations=operations
)
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
rope_theta, hidden_size, num_attention_heads, patch_size=patch_size, dtype=dtype, device=device
)
self.layer = nn.ModuleList([
DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=use_gated_mlp, mlp_bias=True,
intermediate_size=intermediate_size, num_attention_heads=num_attention_heads,
dtype=dtype, device=device, operations=operations, gated_mlp_act=gated_mlp_act)
for _ in range(num_hidden_layers)])
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
def get_input_embeddings(self):
return self.embeddings.patch_embeddings
def forward(self, pixel_values, bool_masked_pos=None, **kwargs):
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
position_embeddings = self.rope_embeddings(pixel_values)
for layer_module in self.layer:
hidden_states = layer_module(hidden_states, position_embeddings=position_embeddings)
if kwargs.get("skip_norm_elementwise", False):
sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:])
else:
norm = self.norm.to(hidden_states.device)
sequence_output = norm(hidden_states)
pooled_output = sequence_output[:, 0, :]
return sequence_output, None, pooled_output, None

View File

@ -239,6 +239,16 @@ class Flux2(LatentFormat):
def process_out(self, latent):
return latent
class TripoSplat(LatentFormat):
# Sequence latent (B, 8192, 16) the camera token rides alongside as a second nested latent
latent_channels = 16
def process_in(self, latent):
return latent
def process_out(self, latent):
return latent
class Mochi(LatentFormat):
latent_channels = 12
latent_dimensions = 3

View File

@ -0,0 +1,199 @@
# TripoSplat 3D gaussian container. Operates on already-decoded
# tensors and exposes them as render-ready tensors (render_tensors) for the generic SPLAT type.
import torch
import torch.nn.functional as F
import comfy.model_management
class GaussianModel:
def __init__(self, aabb: list, sh_degree: int = 0, mininum_kernel_size: float = 0.0,
scaling_bias: float = 0.01, opacity_bias: float = 0.1,
scaling_activation: str = "exp", device=None):
self.sh_degree = sh_degree
self.mininum_kernel_size = mininum_kernel_size
self.scaling_bias = scaling_bias
self.opacity_bias = opacity_bias
self.device = device
self.aabb = torch.tensor(aabb, dtype=torch.float32, device=device)
if scaling_activation == "exp":
self._scaling_activation = torch.exp
self._inverse_scaling_activation = torch.log
elif scaling_activation == "softplus":
self._scaling_activation = F.softplus
self._inverse_scaling_activation = lambda x: x + torch.log(-torch.expm1(-x))
self._opacity_activation = torch.sigmoid
self._inverse_opacity_activation = lambda x: torch.log(x / (1 - x))
self.scale_bias = self._inverse_scaling_activation(torch.tensor(self.scaling_bias)).to(self.device)
self.rots_bias = torch.zeros(4, device=self.device)
self.rots_bias[0] = 1
self.opacity_bias_val = self._inverse_opacity_activation(torch.tensor(self.opacity_bias)).to(self.device)
self._storage = {}
def _get_store(self, name):
return self._storage.get(name)
def _set_store(self, name, value):
self._storage[name] = value
@property
def _xyz(self):
return self._get_store("_xyz")
@_xyz.setter
def _xyz(self, value):
if value is None:
self._set_store("_xyz", None)
self._set_store("xyz", None)
return
self._set_store("_xyz", value)
self._set_store("xyz", value * self.aabb[None, 3:] + self.aabb[None, :3])
@property
def get_xyz(self):
return self._get_store("xyz")
@property
def _features_dc(self):
return self._get_store("_features_dc")
@_features_dc.setter
def _features_dc(self, value):
self._set_store("_features_dc", value)
@property
def _opacity(self):
return self._get_store("_opacity")
@_opacity.setter
def _opacity(self, value):
if value is None:
self._set_store("_opacity", None)
self._set_store("opacity", None)
return
self._set_store("_opacity", value)
self._set_store("opacity", self._opacity_activation(value + self.opacity_bias_val))
@property
def get_opacity(self):
return self._get_store("opacity")
@property
def _scaling(self):
return self._get_store("_scaling")
@_scaling.setter
def _scaling(self, value):
if value is None:
self._set_store("_scaling", None)
self._set_store("scaling", None)
return
self._set_store("_scaling", value)
s = self._scaling_activation(value + self.scale_bias)
s = torch.square(s) + self.mininum_kernel_size ** 2
self._set_store("scaling", torch.sqrt(s))
@property
def get_scaling(self):
return self._get_store("scaling")
@property
def _rotation(self):
return self._get_store("_rotation")
@_rotation.setter
def _rotation(self, value):
self._set_store("_rotation", value)
_DEFAULT_TRANSFORM = [[1, 0, 0], [0, 0, -1], [0, 1, 0]]
def render_tensors(self):
# Render-ready (activated, world-space) tensors for the generic SPLAT type. The axis transform
# (a 3x3 rotation, object frame -> viewer Y-up) is baked into positions and rotations.
# Returns float tensors on the intermediate device: positions (N,3), scales (N,3) linear,
# rotations (N,4) wxyz, opacities (N,1) in [0,1], sh (N,K,3) coefficients.
xyz = self.get_xyz.float()
scaling = self.get_scaling.float()
opacity = self.get_opacity.float()
rotation = (self._rotation + self.rots_bias[None, :]).float()
sh = self._features_dc.float() # (N, K, 3)
T = torch.as_tensor(self._DEFAULT_TRANSFORM, dtype=torch.float32, device=xyz.device)
xyz = xyz @ T.T
rotation = _matrix_to_quat(torch.matmul(T, _quat_to_matrix(rotation)))
rotation = rotation / torch.linalg.norm(rotation, dim=-1, keepdim=True)
out_device = comfy.model_management.intermediate_device()
return (
xyz.to(out_device).contiguous(), scaling.to(out_device).contiguous(),
rotation.to(out_device).contiguous(), opacity.to(out_device).contiguous(),
sh.to(out_device).contiguous(),
)
def _quat_to_matrix(q):
q = q / torch.linalg.norm(q, dim=-1, keepdim=True)
w, x, y, z = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
R = torch.stack([
1 - 2*(y*y + z*z), 2*(x*y - w*z), 2*(x*z + w*y),
2*(x*y + w*z), 1 - 2*(x*x + z*z), 2*(y*z - w*x),
2*(x*z - w*y), 2*(y*z + w*x), 1 - 2*(x*x + y*y),
], dim=-1).reshape(-1, 3, 3)
return R
def _matrix_to_quat(R):
trace = R[:, 0, 0] + R[:, 1, 1] + R[:, 2, 2]
q = torch.zeros((R.shape[0], 4), dtype=R.dtype, device=R.device)
s = torch.sqrt(torch.clamp(trace + 1, min=0)) * 2
q[:, 0] = 0.25 * s
denom = torch.where(s != 0, s, torch.ones_like(s))
q[:, 1] = (R[:, 2, 1] - R[:, 1, 2]) / denom
q[:, 2] = (R[:, 0, 2] - R[:, 2, 0]) / denom
q[:, 3] = (R[:, 1, 0] - R[:, 0, 1]) / denom
m01 = (R[:, 0, 0] >= R[:, 1, 1]) & (R[:, 0, 0] >= R[:, 2, 2]) & (s == 0)
s1 = torch.sqrt(torch.clamp(1 + R[:, 0, 0] - R[:, 1, 1] - R[:, 2, 2], min=0)) * 2
q[m01, 0] = (R[m01, 2, 1] - R[m01, 1, 2]) / s1[m01]
q[m01, 1] = 0.25 * s1[m01]
q[m01, 2] = (R[m01, 0, 1] + R[m01, 1, 0]) / s1[m01]
q[m01, 3] = (R[m01, 0, 2] + R[m01, 2, 0]) / s1[m01]
m11 = (R[:, 1, 1] > R[:, 0, 0]) & (R[:, 1, 1] >= R[:, 2, 2]) & (s == 0)
s2 = torch.sqrt(torch.clamp(1 + R[:, 1, 1] - R[:, 0, 0] - R[:, 2, 2], min=0)) * 2
q[m11, 0] = (R[m11, 0, 2] - R[m11, 2, 0]) / s2[m11]
q[m11, 1] = (R[m11, 0, 1] + R[m11, 1, 0]) / s2[m11]
q[m11, 2] = 0.25 * s2[m11]
q[m11, 3] = (R[m11, 1, 2] + R[m11, 2, 1]) / s2[m11]
m21 = (R[:, 2, 2] > R[:, 0, 0]) & (R[:, 2, 2] > R[:, 1, 1]) & (s == 0)
s3 = torch.sqrt(torch.clamp(1 + R[:, 2, 2] - R[:, 0, 0] - R[:, 1, 1], min=0)) * 2
q[m21, 0] = (R[m21, 1, 0] - R[m21, 0, 1]) / s3[m21]
q[m21, 1] = (R[m21, 0, 2] + R[m21, 2, 0]) / s3[m21]
q[m21, 2] = (R[m21, 1, 2] + R[m21, 2, 1]) / s3[m21]
q[m21, 3] = 0.25 * s3[m21]
return q / torch.linalg.norm(q, dim=-1, keepdim=True)
def build_gaussian_models(decoder, points_pred: dict, pred: dict):
# Assemble GaussianModels from the elastic decoder layout. decoder is the ElasticGaussianFixedlenDecoder
# (carries layout / rep_config / _get_offset)
x = points_pred
offset = decoder._get_offset(pred['features'])
h = pred["features"]
ret = []
for i in range(h.shape[0]):
g = GaussianModel(
sh_degree=0,
aabb=[-0.5, -0.5, -0.5, 1.0, 1.0, 1.0],
mininum_kernel_size=decoder.rep_config['filter_kernel_size_3d'],
scaling_bias=decoder.rep_config['scaling_bias'],
opacity_bias=decoder.rep_config['opacity_bias'],
scaling_activation=decoder.rep_config['scaling_activation'],
device=h.device,
)
_x = x["points"][i, :, None, :]
for k, v in decoder.layout.items():
if k == '_xyz':
setattr(g, k, (offset[i] + _x).flatten(0, 1))
elif k in ('_xyz_center', '_offset_scale'):
continue
else:
feats = h[i][:, v['range'][0]:v['range'][1]].reshape(-1, *v['shape']).flatten(0, 1)
setattr(g, k, feats * decoder.rep_config['lr'][k])
ret.append(g)
return ret

View File

@ -0,0 +1,326 @@
# TripoSplat flow-matching denoiser (LatentSeqMMFlowModel). Registered as a ModelType.FLOW arch and
# driven by the standard KSampler; jointly denoises the (B, 8192, 16) latent and a (B, 1, 5) camera token
# carried as a 2-element nested latent.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
import comfy.patcher_extension
import comfy.rmsnorm
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.math import apply_rope
class MultiHeadRMSNorm(nn.Module):
def __init__(self, dim, heads, dtype=None, device=None):
super().__init__()
self.gamma = nn.Parameter(torch.empty(heads, dim, dtype=dtype, device=device))
def forward(self, x):
x = comfy.rmsnorm.rms_norm(x)
return x * comfy.model_management.cast_to(self.gamma, x.dtype, x.device)
# Positional embeddings
class RePo3DRotaryEmbedding(nn.Module):
def __init__(self, model_channels, num_heads, head_dim, repo_hidden_ratio=0.125, max_freq=16.0,
dtype=None, device=None, operations=None):
super().__init__()
self.num_heads = num_heads
self.head_dim = head_dim
repo_hidden_size = int(model_channels * repo_hidden_ratio)
self.norm = operations.LayerNorm(model_channels, dtype=dtype, device=device)
self.gate_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
self.content_map = operations.Linear(model_channels, repo_hidden_size, bias=False, dtype=dtype, device=device)
self.act = nn.SiLU()
self.final_map = operations.Linear(repo_hidden_size, 3 * num_heads, bias=False, dtype=dtype, device=device)
self.dim_0 = 2 * (head_dim // 6)
self.dim_1 = 2 * (head_dim // 6)
self.dim_2 = head_dim - self.dim_0 - self.dim_1
dims = [self.dim_0, self.dim_1, self.dim_2]
freqs_list = []
for d in dims:
freq_dim = d // 2
freqs_list.append(torch.linspace(1.0, float(max_freq), steps=freq_dim, dtype=torch.float32))
self.freqs_0 = nn.Parameter(freqs_list[0])
self.freqs_1 = nn.Parameter(freqs_list[1])
self.freqs_2 = nn.Parameter(freqs_list[2])
def forward(self, hidden_states):
h = self.norm(hidden_states)
feat = self.act(self.gate_map(h)) * self.content_map(h)
out = self.final_map(feat)
B, L, _ = out.shape
delta_pos = out.reshape(B, L, self.num_heads, 3)
f0 = comfy.model_management.cast_to(self.freqs_0, torch.float32, out.device)
f1 = comfy.model_management.cast_to(self.freqs_1, torch.float32, out.device)
f2 = comfy.model_management.cast_to(self.freqs_2, torch.float32, out.device)
ang_0 = delta_pos[..., 0].unsqueeze(-1) * f0 * torch.pi
ang_1 = delta_pos[..., 1].unsqueeze(-1) * f1 * torch.pi
ang_2 = delta_pos[..., 2].unsqueeze(-1) * f2 * torch.pi
ang = torch.cat([ang_0, ang_1, ang_2], dim=-1).float() # (B, L, heads, head_dim/2)
cos, sin = ang.cos(), ang.sin()
return torch.stack([cos, -sin, sin, cos], dim=-1).reshape(*ang.shape, 2, 2)
class PcdAbsolutePositionEmbedder(nn.Module):
# Sinusoidal absolute position embedding. Two fixed schedules are used in TripoSplat:
# "pow2" (flow-model latent anchors) and "log2" (octree / gaussian decoders).
def __init__(self, channels: int, in_channels: int = 3, max_res: int = 16, schedule: str = "pow2"):
super().__init__()
self.channels = channels
self.in_channels = in_channels
self.max_res = max_res
self.schedule = schedule
self.freq_dim = channels // in_channels // 2
def _freqs(self, device):
if self.schedule == "pow2":
freqs_2exp = torch.arange(self.max_res, dtype=torch.float32, device=device)
res_dim = max(0, self.freq_dim - self.max_res)
freqs_res = (torch.arange(res_dim, dtype=torch.float32, device=device) / max(res_dim, 1) * self.max_res
if res_dim > 0 else torch.empty(0, device=device))
freqs = torch.cat([freqs_2exp, freqs_res], dim=0)[:self.freq_dim]
return torch.pow(2.0, freqs) * 2.0 # *2 folds this schedule's 2*pi into the shared *pi below
logs = torch.linspace(0.0, float(self.max_res), steps=self.freq_dim, dtype=torch.float32, device=device)
return torch.pow(2.0, logs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
orig_dtype = x.dtype
x = x.float()
*dims, D = x.shape
out = torch.outer(x.reshape(-1), self._freqs(x.device)) * torch.pi
out = torch.cat([out.sin(), out.cos()], dim=-1).reshape(*dims, -1)
if out.shape[-1] < self.channels:
out = torch.cat([out, torch.zeros(*dims, self.channels - out.shape[-1],
device=out.device, dtype=out.dtype)], dim=-1)
return out.to(orig_dtype)
def attention(q, k, v, transformer_options=None):
# q, k, v: (B, L, heads, dim) -> (B, L, heads, dim). Shared optimized_attention call convention.
out = optimized_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), heads=q.shape[2],
skip_reshape=True, skip_output_reshape=True, low_precision_attention=False,
transformer_options=transformer_options)
return out.transpose(1, 2)
# Transformer building blocks
class MLP(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(in_channels, hidden_channels, dtype=dtype, device=device),
nn.GELU(approximate="tanh"),
operations.Linear(hidden_channels, out_channels, dtype=dtype, device=device),
)
def forward(self, x):
return self.mlp(x)
class RopeMultiHeadAttention(nn.Module):
def __init__(self, channels, num_heads, qkv_bias=True, qk_rms_norm=False, use_rope=False,
dtype=None, device=None, operations=None):
super().__init__()
self.channels = channels
self.num_heads = num_heads
self.head_dim = channels // num_heads
self.qk_rms_norm = qk_rms_norm
self.use_rope = use_rope
self.qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
if self.qk_rms_norm:
self.q_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.k_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.out = operations.Linear(channels, channels, dtype=dtype, device=device)
def forward(self, x, rope_emb=None, transformer_options=None):
B, L, C = x.shape
qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim)
q, k, v = qkv.unbind(2)
if self.use_rope:
q, k = apply_rope(q, k, rope_emb)
if self.qk_rms_norm:
q = self.q_norm(q)
k = self.k_norm(k)
h = attention(q, k, v, transformer_options) # (B, L, heads, dim)
return self.out(h.reshape(B, L, C))
class UnifiedTransformerBlock(nn.Module):
def __init__(self, channels, num_heads, mlp_ratio=4.0,
use_rope=False, qk_rms_norm=False, qkv_bias=True,
modulation=True, share_mod=False,
dtype=None, device=None, operations=None):
super().__init__()
self.modulation = modulation
self.share_mod = share_mod
self.norm1 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=not modulation, eps=1e-6, dtype=dtype, device=device)
self.attn = RopeMultiHeadAttention(channels, num_heads=num_heads,
qkv_bias=qkv_bias, use_rope=use_rope, qk_rms_norm=qk_rms_norm,
dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
if modulation:
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
self.shift_table = nn.Parameter(torch.empty(1, 6 * channels, dtype=dtype, device=device))
def forward(self, x, mod=None, rotary_emb=None, transformer_options=None):
if self.modulation:
if not self.share_mod:
mod = self.adaLN_modulation(mod)
mod = mod + comfy.model_management.cast_to(self.shift_table, mod.dtype, mod.device)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
x = torch.addcmul(x, self.attn(h, rope_emb=rotary_emb, transformer_options=transformer_options), gate_msa.unsqueeze(1))
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
else:
x = x + self.attn(self.norm1(x), rope_emb=rotary_emb, transformer_options=transformer_options)
x = x + self.mlp(self.norm2(x))
return x
class TimestepEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
half = dim // 2
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
emb = self.timestep_embedding(t, self.frequency_embedding_size)
return self.mlp(emb.to(self.mlp[0].weight.dtype))
class LatentSeqMMFlowModel(nn.Module):
def __init__(self, image_model=None, q_token_length=8192, in_channels=16, model_channels=1024,
cond_channels=1280, out_channels=16, num_blocks=24, num_refiner_blocks=2,
num_heads=None, num_head_channels=64, cam_channels=5, cond2_channels=128,
mlp_ratio=4, share_mod=True, qk_rms_norm=True,
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.dtype = dtype
self.q_token_length = q_token_length
self.in_channels = in_channels
self.cam_channels = cam_channels
self.model_channels = model_channels
self.cond_channels = cond_channels
self.cond2_channels = cond2_channels
self.out_channels = out_channels
self.num_blocks = num_blocks
self.num_refiner_blocks = num_refiner_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.share_mod = share_mod
self.qk_rms_norm = qk_rms_norm
factory_kwargs = dict(dtype=dtype, device=device)
op_kwargs = dict(operations=operations, **factory_kwargs)
self.t_embedder = TimestepEmbedder(model_channels, **op_kwargs)
if share_mod:
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, **factory_kwargs))
self.input_layer = operations.Linear(in_channels, model_channels, **factory_kwargs)
self.cond_embedder = operations.Linear(cond_channels, model_channels, **factory_kwargs)
self.cond_embedder2 = operations.Linear(cond2_channels, model_channels, **factory_kwargs) if cond2_channels is not None else None
# Fixed Sobol (low-discrepancy) 3D anchor positions for the latent tokens, used as positional encoding.
# The embedder is parameter-free and the anchors are fixed, precompute once.
sobol_seq = torch.quasirandom.SobolEngine(dimension=3, scramble=True, seed=123).draw(q_token_length)
pos_emb = PcdAbsolutePositionEmbedder(model_channels)(sobol_seq.unsqueeze(0))
self.register_buffer("pos_emb", pos_emb, persistent=False)
# RePo3DRotaryEmbedding layers for the refiner and main blocks
repo_kwargs = dict(num_heads=self.num_heads, head_dim=num_head_channels, **op_kwargs)
self.noise_repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
self.context_repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_refiner_blocks)])
self.repo_layers = nn.ModuleList(
[RePo3DRotaryEmbedding(model_channels, **repo_kwargs) for _ in range(num_blocks)])
# Refiner blocks
block_kwargs = dict(num_heads=self.num_heads, mlp_ratio=self.mlp_ratio, use_rope=True, qk_rms_norm=self.qk_rms_norm, **op_kwargs)
self.noise_refiner = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_refiner_blocks)])
self.context_refiner = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=False, **block_kwargs) for _ in range(num_refiner_blocks)])
self.cam_refiner = MLP(self.cam_channels, model_channels, model_channels, **op_kwargs)
self.blocks = nn.ModuleList(
[UnifiedTransformerBlock(model_channels, modulation=True, share_mod=self.share_mod, **block_kwargs) for _ in range(num_blocks)])
self.shift_table = nn.Parameter(torch.empty(1, 2, model_channels, **factory_kwargs))
self.out_layer = operations.Linear(model_channels, out_channels, **factory_kwargs)
self.cam_out_layer = operations.Linear(model_channels, cam_channels, **factory_kwargs)
def forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
).execute(x, t, context, ref_latents, transformer_options, **kwargs)
def _forward(self, x, t, context=None, ref_latents=None, transformer_options={}, **kwargs):
# x is the unpacked nested latent: [latent (B,8192,in_channels), camera (B,1,cam_channels)].
# context == feature1.
z, camera = x[0], x[1]
feat1 = context
h_x = self.input_layer(z)
h_cond = self.cond_embedder(feat1)
if ref_latents is not None and self.cond_embedder2 is not None:
# Flatten the Flux2 VAE latent (B,128,h,w) to a token sequence and front-pad to feat1's length
# (the pad count = feat1's prefix tokens: DINOv3 cls + registers), then add to the context.
feat2 = ref_latents[0].flatten(2).transpose(1, 2)
feat2 = F.pad(feat2, (0, 0, feat1.shape[1] - feat2.shape[1], 0))
h_cond = h_cond + self.cond_embedder2(feat2.to(h_cond.dtype))
t_emb = self.t_embedder(t)
t_mod = self.adaLN_modulation(t_emb) if self.share_mod else t_emb
h_x = h_x + self.pos_emb.to(z)
for i, block in enumerate(self.noise_refiner):
h_x = block(h_x, mod=t_mod, rotary_emb=self.noise_repo_layers[i](h_x), transformer_options=transformer_options)
for i, block in enumerate(self.context_refiner):
h_cond = block(h_cond, mod=None, rotary_emb=self.context_repo_layers[i](h_cond), transformer_options=transformer_options)
cam = camera.to(z)
h_cam = self.cam_refiner(cam)
h = torch.cat([h_x, h_cond, h_cam], dim=1)
for i, block in enumerate(self.blocks):
h = block(h, mod=t_mod, rotary_emb=self.repo_layers[i](h), transformer_options=transformer_options)
h_x = F.layer_norm(h[:, :z.shape[1]].float(), h.shape[-1:]).to(z)
h_cam = F.layer_norm(h[:, -cam.shape[1]:].float(), h.shape[-1:]).to(z)
shift, scale = (comfy.model_management.cast_to(self.shift_table, t_emb.dtype, t_emb.device) + t_emb.unsqueeze(1)).chunk(2, dim=1)
scale = 1 + scale
h_x = torch.addcmul(shift, h_x, scale)
h_cam = torch.addcmul(shift, h_cam, scale)
return self.out_layer(h_x), self.cam_out_layer(h_cam)

View File

@ -0,0 +1,91 @@
# Live preview for TripoSplat: decode an x0 estimate into a coarse gaussian splat and render it with a perspective orbit camera.
import numpy as np
from PIL import Image
_C0 = 0.28209479177387814
_LATENT_TOKENS = 8192 # q_token_length
_LATENT_CH = 16 # in_channels
_OBJECT_TO_VIEWER = np.array([[1, 0, 0], [0, 0, -1], [0, 1, 0]], np.float32) # object frame -> viewer Y-up frame
def _view_matrix(yaw_deg, pitch_deg):
y, p = np.radians(yaw_deg), np.radians(pitch_deg)
Ry = np.array([[np.cos(y), 0, np.sin(y)], [0, 1, 0], [-np.sin(y), 0, np.cos(y)]], np.float32)
Rx = np.array([[1, 0, 0], [0, np.cos(p), -np.sin(p)], [0, np.sin(p), np.cos(p)]], np.float32)
return Rx @ Ry
def render_splat(xyz, rgb, scale, opacity=None, yaw=35.0, pitch=30.0, size=320, min_px=2, gain=1.0,
max_px=9, min_opacity=0.0, fov=35.0, dist=2.2):
# Project gaussian centers with a perspective camera and paint each as a filled disk whose screen
# radius follows the gaussian's world-space scale, composited with a nearest-wins z-buffer.
# gain scales the footprint (≈ std spanned), `min_px`/`max_px` clamp the on-screen radius.
pts = xyz.astype(np.float32) @ _OBJECT_TO_VIEWER.T
v = pts @ _view_matrix(yaw, pitch).T
zc = v[:, 2] + dist
keep = zc > 1e-2
if opacity is not None and min_opacity > 0.0: # culls gaussians with very low opacity
keep = keep & (opacity > min_opacity)
v, zc, scale = v[keep], zc[keep], scale[keep]
col = (np.clip(rgb, 0, 1)[:, :3] * 255).astype(np.uint8)[keep]
if v.shape[0] == 0:
return Image.fromarray(np.zeros((size, size, 3), np.uint8))
f = (size / 2) / np.tan(np.radians(fov) / 2)
cx = size / 2 + f * v[:, 0] / zc
cy = size / 2 + f * v[:, 1] / zc
radius = np.clip(np.round(f * scale / zc * gain), min_px, max_px).astype(np.int32)
# Expand each splat to its disk pixels, bucketed by integer radius so it stays vectorized.
px, py, pz, pc = [], [], [], []
for r in range(int(radius.min()), int(radius.max()) + 1):
m = radius == r
if not m.any():
continue
dy, dx = np.mgrid[-r:r + 1, -r:r + 1]
disk = (dx * dx + dy * dy) <= r * r
ox, oy = dx[disk], dy[disk]
px.append((cx[m, None] + ox).ravel())
py.append((cy[m, None] + oy).ravel())
pz.append(np.repeat(zc[m], ox.size))
pc.append(np.repeat(col[m], ox.size, axis=0))
px, py = np.concatenate(px), np.concatenate(py)
pz, pc = np.concatenate(pz), np.concatenate(pc)
xi = np.clip(px, 0, size - 1).astype(np.int64)
yi = np.clip(py, 0, size - 1).astype(np.int64)
# Nearest-wins z-buffer: pack (quantized depth, source index), per-pixel min picks the closest
# splat, then decode the winning index back to its color.
pid = yi * size + xi
q = np.clip((pz * 1024.0).astype(np.int64), 0, (1 << 20) - 1) # near = small
key = (q << 32) | np.arange(pid.size, dtype=np.int64)
buf = np.full(size * size, 1 << 62, np.int64)
np.minimum.at(buf, pid, key)
img = np.zeros((size * size, 3), np.uint8)
hit = buf < (1 << 62)
img[hit] = pc[buf[hit] & 0xFFFFFFFF]
return Image.fromarray(img.reshape(size, size, 3))
def _extract_latent(x0):
# x0 from the sampler callback is the nested latent packed to (B, 1, TOKENS*CH + 1*5);
# the plain single-latent case is (B, TOKENS, CH). Return the (B, TOKENS, CH) latent stream.
if x0.ndim == 3 and x0.shape[1] == _LATENT_TOKENS and x0.shape[2] == _LATENT_CH:
return x0
flat = x0.reshape(x0.shape[0], -1)
return flat[:, :_LATENT_TOKENS * _LATENT_CH].reshape(x0.shape[0], _LATENT_TOKENS, _LATENT_CH)
def decode_x0_to_image(decoder, x0, cfg):
# Decode x0 at a coarse octree level / few gaussians and render a preview image.
latent = _extract_latent(x0)
fsm = decoder.first_stage_model
gaussian = fsm.decode(latent.to(decoder.device, decoder.vae_dtype),
num_gaussians=cfg.get("gaussians", 16384), level=cfg.get("level", 5))[0]
xyz = gaussian.get_xyz.float().cpu().numpy()
rgb = gaussian._features_dc.float().cpu().numpy()[:, 0, :] * _C0 + 0.5
scale = gaussian.get_scaling.float().cpu().numpy().max(axis=1) # per-splat world radius (largest axis)
opacity = gaussian.get_opacity.float().cpu().numpy()[:, 0]
return render_splat(xyz, rgb, scale, opacity=opacity, yaw=cfg.get("yaw", 35.0), pitch=cfg.get("pitch", 30.0),
size=cfg.get("size", 320), min_px=1, gain=1.0, max_px=cfg.get("point_size", 3),
min_opacity=0.01)

382
comfy/ldm/triposplat/vae.py Normal file
View File

@ -0,0 +1,382 @@
# TripoSplat gaussian decoder ("VAE"): an octree probability decoder picks point coords, then an
# elastic-gaussian decoder predicts per-point gaussian params. OctreeGaussianDecoder.decode() returns
# a Gaussian. The octree sampler uses the global torch RNG (no generator) like upstream, so seed it for repeatable decodes.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import comfy.model_management
import comfy.ops
from .gaussian import build_gaussian_models
from .model import MultiHeadRMSNorm, MLP, PcdAbsolutePositionEmbedder, attention
# Quasi-random sampling utilities (pure functions, dtype/device-agnostic)
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
def radical_inverse(base, n):
val = 0
inv_base = 1.0 / base
inv_base_n = inv_base
while n > 0:
digit = n % base
val += digit * inv_base_n
n //= base
inv_base_n *= inv_base
return val
def halton_sequence(dim, n):
return [radical_inverse(PRIMES[i], n) for i in range(dim)]
def hammersley_sequence(dim, n, num_samples):
return [n / num_samples] + halton_sequence(dim - 1, n)
def sample_probs(probs, counts, generator=None):
# Systematic resampling: distribute counts[r] draws across the P bins of row r
batch_shape = counts.shape
R = counts.numel()
P = probs.size(-1)
device = probs.device
probs = probs.reshape(R, P).to(torch.float32).clamp_min(0)
counts = counts.reshape(R).to(device=device, dtype=torch.long)
row_sums = probs.sum(1, keepdim=True)
probs = torch.where(row_sums == 0, probs.new_tensor(1.0 / P), probs / row_sums.clamp_min(1))
cdf = probs.cumsum(dim=1).clamp(max=1.0 - 1e-12)
Nmax = int(counts.max())
if Nmax == 0:
return counts.new_zeros(*batch_shape, P)
cnt = counts.clamp_min(1).float().unsqueeze(1) # (R, 1)
grid = torch.arange(Nmax, device=device, dtype=torch.float32).unsqueeze(0) # (1, Nmax)
u = (torch.rand(R, 1, generator=generator).to(device) + grid) / cnt # (R, Nmax) systematic samples (CPU-seeded)
idx = torch.searchsorted(cdf, u.clamp(max=1.0 - 1e-12)).clamp_max(P - 1)
weight = (grid < counts.unsqueeze(1)).to(cdf.dtype) # mask out j >= counts[r]
out = torch.zeros(R, P, dtype=torch.float32, device=device)
out.scatter_add_(1, idx, weight)
return out.to(torch.long).view(*batch_shape, P)
class MultiHeadAttention(nn.Module):
def __init__(self, channels, num_heads, ctx_channels=None, type="self", qkv_bias=True, qk_rms_norm=False,
dtype=None, device=None, operations=None):
super().__init__()
assert channels % num_heads == 0
self.channels = channels
self.head_dim = channels // num_heads
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
self.num_heads = num_heads
self._type = type
self.qk_rms_norm = qk_rms_norm
if self._type == "self":
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
else:
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, dtype=dtype, device=device)
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, dtype=dtype, device=device)
if self.qk_rms_norm:
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, dtype=dtype, device=device)
self.to_out = operations.Linear(channels, channels, dtype=dtype, device=device)
def forward(self, x, context=None):
B, L, C = x.shape
if self._type == "self":
q, k, v = self.to_qkv(x).reshape(B, L, 3, self.num_heads, -1).unbind(dim=2)
else:
Lkv = context.shape[1]
q = self.to_q(x).reshape(B, L, self.num_heads, -1)
k, v = self.to_kv(context).reshape(B, Lkv, 2, self.num_heads, -1).unbind(dim=2)
if self.qk_rms_norm:
q = self.q_rms_norm(q)
k = self.k_rms_norm(k)
h = attention(q, k, v)
return self.to_out(h.reshape(B, L, -1))
# Octree probability decoder
class LevelEmbedder(nn.Module):
def __init__(self, hidden_size, frequency_embedding_size=256, max_period=1024,
dtype=None, device=None, operations=None):
super().__init__()
self.mlp = nn.Sequential(
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
)
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
@staticmethod
def level_embedding(t, dim, max_period=1024):
half = dim // 2
freqs = torch.exp(-np.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(device=t.device)
args = t[:, None].float() * freqs[None] * 2 * torch.pi
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
emb = self.level_embedding(t, self.frequency_embedding_size, self.max_period)
return self.mlp(emb.to(self.mlp[0].weight.dtype))
class ModulatedTransformerCrossOnlyBlock(nn.Module):
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0, share_mod=False,
qk_rms_norm_cross=True, qkv_bias=True, dtype=None, device=None, operations=None):
super().__init__()
self.share_mod = share_mod
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads,
type="cross", qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
if not share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device))
def forward(self, x, mod, context):
if self.share_mod:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = mod.chunk(6, dim=1)
else:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
h = torch.addcmul(shift_msa.unsqueeze(1), self.norm1(x), 1 + scale_msa.unsqueeze(1))
x = torch.addcmul(x, self.cross_attn(h, context), gate_msa.unsqueeze(1))
h = torch.addcmul(shift_mlp.unsqueeze(1), self.norm2(x), 1 + scale_mlp.unsqueeze(1))
x = torch.addcmul(x, self.mlp(h), gate_mlp.unsqueeze(1))
return x
class OctreeProbabilityFixedlenDecoder(nn.Module):
# Cross-attention transformer over octree coords -> per-node 8-way child occupancy logits.
def __init__(self, model_channels=1024, cond_channels=16, num_blocks=4, num_heads=16,
num_head_channels=64, mlp_ratio=4.0, share_mod=True,
qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
super().__init__()
self.model_channels = model_channels
self.cond_channels = cond_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.share_mod = share_mod
self.qk_rms_norm_cross = qk_rms_norm_cross
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
self.l_embedder = LevelEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
if share_mod:
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), operations.Linear(model_channels, 6 * model_channels, bias=True, dtype=dtype, device=device))
if cond_channels is not None:
self.blocks = nn.ModuleList([
ModulatedTransformerCrossOnlyBlock(
model_channels, ctx_channels=cond_channels, num_heads=self.num_heads,
mlp_ratio=self.mlp_ratio, qk_rms_norm_cross=self.qk_rms_norm_cross,
share_mod=self.share_mod, dtype=dtype, device=device, operations=operations)
for _ in range(num_blocks)
])
self.out_proj = operations.Linear(model_channels, 8, dtype=dtype, device=device)
self.in_proj = operations.Linear(3, model_channels, dtype=dtype, device=device)
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
def forward(self, x, l, cond):
d = next(self.parameters()).dtype
B, L, _ = x.shape
h = self.in_proj(x.to(d)) + self.pos_embedder(x.reshape(-1, 3)).reshape(B, L, -1).to(d)
h = self.input_layer(h)
l_emb = self.l_embedder(l)
if self.share_mod:
l_emb = self.adaLN_modulation(l_emb)
cond = cond.to(d)
for block in self.blocks:
h = block(h, l_emb, cond)
h = F.layer_norm(h.float(), h.shape[-1:]).to(d)
logits = self.out_proj(h)
return {"logits": logits, "probs": torch.softmax(logits, dim=-1)}
@staticmethod
def sample(model, cond, num_points, level, temperature=1.0, generator=None):
B = cond.shape[0]
device = cond.device
child_offset = torch.tensor([[i, j, k] for k in [0, 1] for j in [0, 1] for i in [0, 1]],
dtype=torch.long, device=device)
prev_coords_int = torch.zeros(B, 1, 3, dtype=torch.long, device=device)
prev_counts = torch.full((B, 1), num_points, dtype=torch.long, device=device)
prev_log_probs = torch.zeros(B, 1, dtype=torch.float32, device=device)
batch_indices_range = torch.arange(B, device=device).unsqueeze(1)
for lv in range(1, level + 1):
res_p = 1 << (lv - 1)
res = 1 << lv
parent_coords_norm = (prev_coords_int.to(torch.float32) + 0.5) / res_p
res_tensor = torch.full((B,), res, dtype=torch.long, device=device)
pred_logits = model(parent_coords_norm, res_tensor, cond)["logits"] / temperature
pred_probs = torch.softmax(pred_logits, dim=-1)
pred_log_probs = torch.log_softmax(pred_logits, dim=-1)
sampled = sample_probs(pred_probs, prev_counts, generator=generator).flatten(1, 2)
pred_log_probs = pred_log_probs.flatten(1, 2)
prev_log_probs_expanded = prev_log_probs.repeat_interleave(8, dim=1)
child_coords_int = (prev_coords_int[:, :, None, :] * 2 + child_offset[None, None, :, :]).flatten(1, 2)
mask = sampled > 0
max_valid = mask.sum(dim=1).max().item()
scatter_indices = mask.cumsum(dim=1) - 1
valid_scatter_indices = scatter_indices[mask]
valid_batch_indices = batch_indices_range.expand_as(mask)[mask]
next_prev_coords_int = torch.zeros(B, max_valid, 3, dtype=child_coords_int.dtype, device=device)
next_prev_coords_int[valid_batch_indices, valid_scatter_indices] = child_coords_int[mask]
next_prev_counts = torch.zeros(B, max_valid, dtype=sampled.dtype, device=device)
next_prev_counts[valid_batch_indices, valid_scatter_indices] = sampled[mask]
next_prev_log_probs = torch.zeros(B, max_valid, dtype=prev_log_probs.dtype, device=device)
next_prev_log_probs[valid_batch_indices, valid_scatter_indices] = (prev_log_probs_expanded + pred_log_probs)[mask]
prev_coords_int = next_prev_coords_int
prev_counts = next_prev_counts
prev_log_probs = next_prev_log_probs
res = 1 << level
prev_log_probs = torch.repeat_interleave(prev_log_probs.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points)
coords_int = torch.repeat_interleave(prev_coords_int.flatten(0, 1), prev_counts.flatten(0, 1), dim=0).reshape(B, num_points, -1)
rand = torch.rand(coords_int.shape, dtype=torch.float32, generator=generator).to(device)
coords_norm = (coords_int.to(torch.float32) + rand) / res
return {"points": coords_norm, "log_probs": prev_log_probs}
# Elastic gaussian decoder
class TransformerCrossBlock(nn.Module):
def __init__(self, channels, ctx_channels, num_heads, mlp_ratio=4.0,
qk_rms_norm=True, qk_rms_norm_cross=True, qkv_bias=True,
dtype=None, device=None, operations=None):
super().__init__()
self.norm1 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.norm2 = operations.LayerNorm(channels, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
self.norm3 = operations.LayerNorm(channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.self_attn = MultiHeadAttention(channels, num_heads=num_heads, type="self", qkv_bias=qkv_bias,
qk_rms_norm=qk_rms_norm, dtype=dtype, device=device, operations=operations)
self.cross_attn = MultiHeadAttention(channels, ctx_channels=ctx_channels, num_heads=num_heads, type="cross",
qkv_bias=qkv_bias, qk_rms_norm=qk_rms_norm_cross, dtype=dtype, device=device, operations=operations)
self.mlp = MLP(channels, int(channels * mlp_ratio), channels, dtype=dtype, device=device, operations=operations)
def forward(self, x, context):
x = x + self.self_attn(self.norm1(x))
x = x + self.cross_attn(self.norm2(x), context)
x = x + self.mlp(self.norm3(x))
return x
class ElasticGaussianFixedlenDecoder(nn.Module):
# Cross-attention transformer over sampled octree points -> per-point gaussian params.
def __init__(self, in_channels=3, model_channels=1024, cond_channels=16, num_blocks=16, num_heads=16,
num_head_channels=64, mlp_ratio=4.0, *, representation_config=None,
qk_rms_norm=True, qk_rms_norm_cross=True, dtype=None, device=None, operations=None):
super().__init__()
self.rep_config = representation_config or dict(
lr=dict(_xyz=1.0, _features_dc=1.0, _opacity=1.0, _scaling=1.0, _rotation=0.1),
perturb_offset=True, perturbe_size=1.5, offset_scale=0.05, num_gaussians=32,
filter_kernel_size_3d=0.0009, scaling_bias=0.004, opacity_bias=0.1,
scaling_activation="softplus",
)
self.out_channels = self._calc_layout()
self.model_channels = model_channels
self.cond_channels = cond_channels
self.num_blocks = num_blocks
self.num_heads = num_heads or model_channels // num_head_channels
self.mlp_ratio = mlp_ratio
self.input_layer = operations.Linear(model_channels, model_channels, dtype=dtype, device=device)
if cond_channels is not None:
self.blocks = nn.ModuleList([
TransformerCrossBlock(model_channels, ctx_channels=cond_channels,
num_heads=self.num_heads, mlp_ratio=self.mlp_ratio,
qk_rms_norm=qk_rms_norm, qk_rms_norm_cross=qk_rms_norm_cross,
dtype=dtype, device=device, operations=operations)
for _ in range(num_blocks)
])
self.in_proj = operations.Linear(in_channels, model_channels, dtype=dtype, device=device)
self.pos_embedder = PcdAbsolutePositionEmbedder(channels=model_channels, in_channels=3, max_res=10, schedule="log2")
self.out_proj = operations.Linear(model_channels, self.out_channels, dtype=dtype, device=device)
self._build_perturbation()
def _calc_layout(self):
ng = self.rep_config['num_gaussians']
self.layout = {
'_xyz': {'shape': (ng, 3), 'size': ng * 3},
'_features_dc': {'shape': (ng, 1, 3), 'size': ng * 3},
'_scaling': {'shape': (ng, 3), 'size': ng * 3},
'_rotation': {'shape': (ng, 4), 'size': ng * 4},
'_opacity': {'shape': (ng, 1), 'size': ng},
}
self.layout['_offset_scale'] = {'shape': (ng, 1), 'size': ng}
start = 0
for k, v in self.layout.items():
v['range'] = (start, start + v['size'])
start += v['size']
return start
def _build_perturbation(self):
ng = self.rep_config['num_gaussians']
perturbation = torch.tensor([hammersley_sequence(3, i, ng) for i in range(ng)]).float()
perturbation = torch.atanh((perturbation * 2 - 1) / self.rep_config['perturbe_size'])
self.register_buffer('points_offset_perturbation', perturbation)
base = torch.tensor(self.rep_config['offset_scale'])
self.register_buffer('base_offset_scale', torch.log(torch.exp(base) - 1.0))
def _get_offset(self, h):
B = h.shape[0]
r = self.layout['_offset_scale']['range']
_offset_scale = F.softplus(
h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_offset_scale']['shape'])
+ comfy.model_management.cast_to(self.base_offset_scale, h.dtype, h.device))
r = self.layout['_xyz']['range']
offset = h[:, :, r[0]:r[1]].reshape(B, -1, *self.layout['_xyz']['shape'])
offset = offset * self.rep_config['lr']['_xyz']
if self.rep_config['perturb_offset']:
offset = offset + comfy.model_management.cast_to(self.points_offset_perturbation, offset.dtype, offset.device)
offset = torch.tanh(offset) * 0.5 * self.rep_config['perturbe_size']
offset = offset * _offset_scale
return offset
def forward(self, x=None, cond=None):
pcd = x["points"]
d = next(self.parameters()).dtype
B, L, _ = pcd.shape
h = self.in_proj(pcd.to(d)) + self.pos_embedder(pcd.reshape(-1, 3)).reshape(B, L, -1).to(d)
h = self.input_layer(h)
cond = cond.to(d)
for block in self.blocks:
h = block(h, cond)
h = F.layer_norm(h.float(), h.shape[-1:]).to(h.dtype)
return {"features": self.out_proj(h)}
# Combined octree gaussian decoder (comfy first-stage model)
class OctreeGaussianDecoder(nn.Module):
_MAX_VOXEL_LEVEL = 8
def __init__(self, dtype=None, device=None, operations=None):
super().__init__()
if operations is None:
operations = comfy.ops.disable_weight_init
self.octree = OctreeProbabilityFixedlenDecoder(dtype=dtype, device=device, operations=operations)
self.gs = ElasticGaussianFixedlenDecoder(dtype=dtype, device=device, operations=operations)
@property
def gaussians_per_point(self) -> int:
return self.gs.rep_config['num_gaussians']
def decode(self, latent: torch.Tensor, num_gaussians: int, level: int = None, generator=None):
# level defaults to the full octree depth, a lower level is cheaper (coarser) for live previews.
# generator (a CPU torch.Generator) makes the octree sampling reproducible without touching global RNG.
level = self._MAX_VOXEL_LEVEL if level is None else level
num_decoder_tokens = max(1, num_gaussians // self.gaussians_per_point)
points_pred = OctreeProbabilityFixedlenDecoder.sample(
self.octree, latent, num_points=num_decoder_tokens, level=level, temperature=1.0, generator=generator,
)
pred = self.gs(x=points_pred, cond=latent)
return build_gaussian_models(self.gs, points_pred, pred) # one GaussianModel per batch item

View File

@ -46,6 +46,7 @@ import comfy.ldm.wan.model_animate
import comfy.ldm.wan.ar_model
import comfy.ldm.wan.model_wandancer
import comfy.ldm.hunyuan3d.model
import comfy.ldm.triposplat.model
import comfy.ldm.hidream.model
import comfy.ldm.chroma.model
import comfy.ldm.chroma_radiance.model
@ -1813,6 +1814,24 @@ class Hunyuan3Dv2_1(BaseModel):
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class TripoSplat(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.triposplat.model.LatentSeqMMFlowModel)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
cross_attn = kwargs.get("cross_attn", None) # DINOv3 token sequence -> cross-attention context.
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
ref_latents = kwargs.get("reference_latents", None) # Flux2 VAE image latent -> additive second conditioning.
if ref_latents is not None:
out['ref_latents'] = comfy.conds.CONDList(list(ref_latents))
latent_shapes = kwargs.get("latent_shapes", None) # {latent, camera} nested latent
if latent_shapes is not None:
out['latent_shapes'] = comfy.conds.CONDConstant(latent_shapes)
return out
class HiDream(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)

View File

@ -676,6 +676,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["guidance_cond_proj_dim"] = None#f"{key_prefix}t_embedder.cond_proj.weight" in state_dict_keys
return dit_config
if '{}cam_out_layer.weight'.format(key_prefix) in state_dict_keys and '{}repo_layers.0.final_map.weight'.format(key_prefix) in state_dict_keys: # TripoSplat
return {"image_model": "triposplat"}
if '{}t_embedder1.mlp.0.weight'.format(key_prefix) in state_dict_keys and '{}x_embedder.proj1.weight'.format(key_prefix) in state_dict_keys: # HiDream-O1
return {"image_model": "hidream_o1"}

View File

@ -16,6 +16,7 @@ import comfy.ldm.cosmos.vae
import comfy.ldm.wan.vae
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.triposplat.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
import comfy.ldm.hunyuan_video.vae
@ -894,6 +895,16 @@ class VAE:
#Force cast it for --disable-dynamic-vram users until there is a true core fix.
if not comfy.memory_management.aimdo_enabled:
self.disable_offload = True
elif "gs.base_offset_scale" in sd and "octree.out_proj.weight" in sd: # TripoSplat octree gaussian decoder
self.first_stage_model = comfy.ldm.triposplat.vae.OctreeGaussianDecoder()
self.latent_channels = 16
self.latent_dim = 1
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# The generic VAE.encode/decode path isn't used: VAEDecodeTripoSplat calls the gaussian
# decoder directly (structured GaussianSplat objects, not a tensor and reserves VRAM itself from num_gaussians.
def _no_generic_io(*args, **kwargs):
raise RuntimeError("TripoSplat gaussian decoder: use the 'TripoSplat Decode' (VAEDecodeTripoSplat)")
self.memory_used_encode = self.memory_used_decode = _no_generic_io
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None

View File

@ -1538,6 +1538,30 @@ class Hunyuan3Dv2mini(Hunyuan3Dv2):
latent_format = latent_formats.Hunyuan3Dv2mini
class TripoSplat(supported_models_base.BASE):
# Image -> 3D gaussian splat flow denoiser
unet_config = {
"image_model": "triposplat",
}
unet_extra_config = {}
sampling_settings = {
"shift": 3.0,
}
memory_usage_factor = 0.6
latent_format = latent_formats.TripoSplat
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
def get_model(self, state_dict, prefix="", device=None):
return model_base.TripoSplat(self, device=device)
def clip_target(self, state_dict={}):
return None
class HiDream(supported_models_base.BASE):
unet_config = {
"image_model": "hidream",
@ -2200,6 +2224,7 @@ models = [
Hunyuan3Dv2mini,
Hunyuan3Dv2,
Hunyuan3Dv2_1,
TripoSplat,
HiDream,
HiDreamO1,
Chroma,

View File

@ -65,6 +65,12 @@ class VideoInput(ABC):
buffer.seek(0)
return buffer
def get_active_trim_window(self) -> tuple[float, float]:
"""Return the active trim as ``(start_time, duration)`` in seconds (start_time normalized
to ``>= 0``; ``duration == 0`` means "until the end"). Default: no trim; trimmable subclasses override.
"""
return 0.0, 0.0
# Provide a default implementation, but subclasses can provide optimized versions
# if possible.
def get_dimensions(self) -> tuple[int, int]:

View File

@ -75,6 +75,12 @@ class VideoFromFile(VideoInput):
self.__file.seek(0)
return self.__file
def get_active_trim_window(self) -> tuple[float, float]:
start_time = self.__start_time
if start_time < 0:
start_time = max(self._get_raw_duration() + start_time, 0.0)
return float(start_time), float(self.__duration)
def get_dimensions(self) -> tuple[int, int]:
"""
Returns the dimensions of the video input.

View File

@ -1,71 +1,71 @@
from enum import Enum
from typing import Any, Dict, Optional
from typing import Any
from pydantic import BaseModel, Field, confloat, conint
class BFLOutputFormat(str, Enum):
png = 'png'
jpeg = 'jpeg'
from pydantic import BaseModel, Field
class BFLFluxExpandImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to expand')
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
top: int = Field(...)
bottom: int = Field(...)
left: int = Field(...)
right: int = Field(...)
steps: int = Field(...)
guidance: float = Field(...)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image: str = Field(None, description="A Base64-encoded string representing the image you wish to expand")
class BFLFluxFillImageRequest(BaseModel):
prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
steps: int = Field(...)
guidance: float = Field(...)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image: str = Field(
None, description="Base64-encoded string representing the image to modify. Can contain alpha mask if desired.",
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
mask: str = Field(
None, description="Base64-encoded string representing the mask of the areas you wish to modify."
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
class BFLFluxEraseRequest(BaseModel):
image: str = Field(..., description="A Base64-encoded string representing the image to erase from.")
mask: str = Field(
...,
description="A Base64-encoded black/white mask matching the input dimensions; "
"white (255) marks areas to remove, black (0) marks areas to preserve.",
)
image: str = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.')
dilate_pixels: int = Field(10)
output_format: str = Field("png")
class BFLFluxVTORequest(BaseModel):
prompt: str = Field(
..., description="Natural-language styling instruction. Required field, but may be an empty string."
)
person: str = Field(..., description="A Base64-encoded string representing the person image.")
garment: str = Field(..., description="A Base64-encoded string representing the garment reference image.")
seed: int | None = Field(None)
safety_tolerance: int = Field(5)
output_format: str = Field("png")
class BFLFluxProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
# image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
# None, description='Blend between the prompt and the image prompt.'
# )
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
width: int = Field(1024, description="Must be a multiple of 32.")
height: int = Field(768, description="Must be a multiple of 32.")
safety_tolerance: int = Field(6)
output_format: str = Field("png")
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
class Flux2ProGenerateRequest(BaseModel):
@ -83,55 +83,37 @@ class Flux2ProGenerateRequest(BaseModel):
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
safety_tolerance: int | None = Field(
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
)
output_format: str | None = Field(
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
)
safety_tolerance: int = Field(5)
output_format: str = Field("png")
class BFLFluxKontextProGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
guidance: confloat(ge=0.1, le=99.0) = Field(..., description='Guidance strength for the image generation process')
steps: conint(ge=1, le=150) = Field(..., description='Number of steps for the image generation process')
safety_tolerance: Optional[conint(ge=0, le=2)] = Field(
2, description='Tolerance level for input and output moderation. Between 0 and 2, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
prompt: str = Field(...)
input_image: str | None = Field(None, description="Image to edit in base64 format")
seed: int | None = Field(None)
guidance: float = Field(...)
steps: int = Field(...)
safety_tolerance: int = Field(2)
output_format: str = Field("png")
aspect_ratio: str | None = Field(None)
prompt_upsampling: bool | None = Field(None)
class BFLFluxProUltraGenerateRequest(BaseModel):
prompt: str = Field(..., description='The text prompt for image generation.')
prompt_upsampling: Optional[bool] = Field(
None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
)
seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.'
)
output_format: Optional[BFLOutputFormat] = Field(
BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
)
raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
None, description='Blend between the prompt and the image prompt.'
)
prompt: str = Field(...)
prompt_upsampling: bool | None = Field(None)
seed: int | None = Field(None)
aspect_ratio: str | None = Field(None)
safety_tolerance: int = Field(6)
output_format: str = Field("png")
raw: bool | None = Field(None)
image_prompt: str | None = Field(None, description="Optional image to remix in base64 format")
image_prompt_strength: float | None = Field(None)
class BFLFluxProGenerateResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
polling_url: str = Field(..., description="URL to poll for the generation result.")
id: str = Field(...)
polling_url: str = Field(...)
cost: float | None = Field(None, description="Price in cents")
@ -145,7 +127,7 @@ class BFLStatus(str, Enum):
class BFLFluxStatusResponse(BaseModel):
id: str = Field(..., description="The unique identifier for the generation task.")
status: BFLStatus = Field(..., description="The status of the task.")
result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).")
progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0)
id: str = Field(...)
status: BFLStatus = Field(...)
result: dict[str, Any] | None = Field(None)
progress: float | None = Field(None, ge=0.0, le=1.0)

View File

@ -4,17 +4,20 @@ from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.bfl import (
BFLFluxEraseRequest,
BFLFluxExpandImageRequest,
BFLFluxFillImageRequest,
BFLFluxKontextProGenerateRequest,
BFLFluxProGenerateResponse,
BFLFluxProUltraGenerateRequest,
BFLFluxStatusResponse,
BFLFluxVTORequest,
BFLStatus,
Flux2ProGenerateRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
convert_mask_to_image,
download_url_to_image_tensor,
get_number_of_images,
poll_op,
@ -22,19 +25,11 @@ from comfy_api_nodes.util import (
sync_op,
tensor_to_base64_string,
validate_aspect_ratio_string,
validate_image_dimensions,
validate_string,
)
def convert_mask_to_image(mask: Input.Image):
"""
Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
"""
mask = mask.unsqueeze(-1)
mask = torch.cat([mask] * 3, dim=-1)
return mask
class FluxProUltraImageNode(IO.ComfyNode):
@classmethod
@ -519,6 +514,163 @@ class FluxProFillNode(IO.ComfyNode):
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxEraseNode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxEraseNode",
display_name="Flux Erase Image",
category="image/partner/BFL",
description="Removes the masked object from an image and reconstructs the background. "
"Paint the mask over what you want to erase.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask", tooltip="White areas are removed; black areas are preserved."),
IO.Int.Input(
"dilate_pixels",
default=10,
min=0,
max=25,
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.03,"max_usd":0.06,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
mask: Input.Image,
dilate_pixels: int = 10,
) -> IO.NodeOutput:
validate_image_dimensions(image, min_width=256, min_height=256)
mask = resize_mask_to_image(mask, image)
mask = tensor_to_base64_string(convert_mask_to_image(mask))
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/erase-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxEraseRequest(
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
mask=mask,
dilate_pixels=dilate_pixels,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class FluxVTONode(IO.ComfyNode):
@classmethod
def define_schema(cls) -> IO.Schema:
return IO.Schema(
node_id="FluxVTONode",
display_name="Flux Virtual Try-On",
category="image/partner/BFL",
description="Virtual try-on: dresses the person in the provided garment.",
inputs=[
IO.Image.Input("person", tooltip="Image of the person to dress."),
IO.Image.Input("garment", tooltip="Image of the garment to apply."),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Optional natural-language styling instruction (e.g. how the garment should fit).",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=0xFFFFFFFFFFFFFFFF,
control_after_generate=True,
tooltip="The random seed used for creating the noise.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"range_usd","min_usd":0.0375,"max_usd":0.075,"format":{"approximate":true}}""",
),
)
@classmethod
async def execute(
cls,
person: Input.Image,
garment: Input.Image,
prompt: str = "",
seed: int = 0,
) -> IO.NodeOutput:
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/bfl/v1/flux-tools/vto-v1", method="POST"),
response_model=BFLFluxProGenerateResponse,
data=BFLFluxVTORequest(
prompt=prompt,
person=tensor_to_base64_string(person[:, :, :, :3]),
garment=tensor_to_base64_string(garment[:, :, :, :3]),
seed=seed,
),
)
def price_extractor(_r: BaseModel) -> float | None:
return None if initial_response.cost is None else initial_response.cost / 100
response = await poll_op(
cls,
ApiEndpoint(initial_response.polling_url),
response_model=BFLFluxStatusResponse,
status_extractor=lambda r: r.status,
progress_extractor=lambda r: r.progress,
price_extractor=price_extractor,
completed_statuses=[BFLStatus.ready],
failed_statuses=[
BFLStatus.request_moderated,
BFLStatus.content_moderated,
BFLStatus.error,
BFLStatus.task_not_found,
],
queued_statuses=[],
)
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
class Flux2ProImageNode(IO.ComfyNode):
NODE_ID = "Flux2ProImageNode"
@ -853,6 +1005,8 @@ class BFLExtension(ComfyExtension):
FluxKontextMaxImageNode,
FluxProExpandNode,
FluxProFillNode,
FluxEraseNode,
FluxVTONode,
Flux2ProImageNode,
Flux2MaxImageNode,
Flux2ImageNode,

View File

@ -469,6 +469,11 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
input_container = None
output_container = None
# get_stream_source() is untrimmed, so apply the trim window in this same pass.
# start_time is normalized (>= 0); duration == 0 means "until the end".
start_time, duration = video.get_active_trim_window()
trimming = bool(start_time or duration)
try:
input_source = video.get_stream_source()
input_container = av.open(input_source, mode="r")
@ -487,16 +492,45 @@ def _apply_video_scale(video: Input.Video, scale_dims: tuple[int, int]) -> Input
audio_stream.layout = stream.layout
break
in_video = input_container.streams.video[0]
start_pts = int(start_time / in_video.time_base) if trimming else 0
end_pts = int((start_time + duration) / in_video.time_base) if duration else None
if start_pts:
input_container.seek(start_pts, stream=in_video)
encoded = 0
for frame in input_container.decode(video=0):
if trimming:
if frame.pts is None or frame.pts < start_pts:
continue
if end_pts is not None and frame.pts >= end_pts:
break
frame = frame.reformat(width=out_w, height=out_h, format="yuv420p")
# Re-wrap as a fresh frame: dropping irregular source timestamps (VFR/AVI/GIF/...)
# lets the encoder assign clean ones and avoids mp4 muxer errors.
frame = av.VideoFrame.from_ndarray(frame.to_ndarray(format="yuv420p"), format="yuv420p")
for packet in video_stream.encode(frame):
output_container.mux(packet)
encoded += 1
for packet in video_stream.encode():
output_container.mux(packet)
if encoded == 0:
raise ValueError(
f"resize produced no frames (start_time={start_time}, duration={duration} "
"selected nothing from the source)"
)
if audio_stream is not None:
input_container.seek(0)
for audio_frame in input_container.decode(audio=0):
if trimming:
if audio_frame.time is None or audio_frame.time < start_time:
continue
if duration and audio_frame.time > start_time + duration:
break
# Carry odd audio time bases the mp4 muxer rejects; reset pts, encoder assigns clean ones (MP3-in-AVI)
audio_frame.pts = None
for packet in audio_stream.encode(audio_frame):
output_container.mux(packet)
for packet in audio_stream.encode():

View File

@ -968,7 +968,8 @@ class RenderSplat(IO.ComfyNode):
bg = _hex_to_rgb(background)
bg_imgs = None
if bg_image is not None: # resize the plate(s) to the render size: (B,H,W,3)
bi = comfy.utils.common_upscale(bg_image.movedim(-1, 1), width, height, "bicubic", "disabled")
bi = bg_image[... , :3].movedim(-1, 1) # (B,3,H,W)
bi = comfy.utils.common_upscale(bi, width, height, "bicubic", "disabled")
bg_imgs = bi.movedim(1, -1).clamp(0, 1)
n_frames = abs(int(frames)) or 1 # magnitude = frame count (0 -> single still)
orbit_dir = -1.0 if frames < 0 else 1.0 # sign = orbit direction

View File

@ -0,0 +1,269 @@
# TripoSplat nodes: image -> 3D gaussian splat
import logging
import torch
import torch.nn.functional as F
from typing_extensions import override
import comfy.model_management
import comfy.nested_tensor
import comfy.patcher_extension
import comfy.utils
from comfy_api.latest import ComfyExtension, IO, Types
_Q_TOKEN_LENGTH = 8192
_LATENT_CHANNELS = 16
_CAM_CHANNELS = 5
_DINOV3_MEAN = [0.485, 0.456, 0.406]
_DINOV3_STD = [0.229, 0.224, 0.225]
_NUM_GAUSSIANS_MIN = 32768
_NUM_GAUSSIANS_MAX = 1048576
def _preprocess(image: torch.Tensor, mask: torch.Tensor, erode_radius: int, size: int) -> torch.Tensor:
# Match original preprocessing:
# resize min side to `size` -> erode alpha -> alpha bbox -> 1.2x square crop -> resize -> composite on black.
rgb = image[..., :3].clamp(0, 1).movedim(-1, 0) # (3, H, W)
alpha = mask.clamp(0, 1)[None] # (1, H, W)
rgba = torch.cat([rgb, alpha], 0)[None] # (1, 4, H, W)
h, w = rgba.shape[-2:]
s = size / min(w, h)
rgba = comfy.utils.common_upscale(rgba, max(1, round(w * s)), max(1, round(h * s)), "lanczos", "disabled").clamp(0, 1)
a = rgba[:, 3:4]
if erode_radius > 0:
# min filter over a (2r+1) window == morphological erosion of the alpha matte.
a = -F.max_pool2d(-a, 2 * erode_radius + 1, stride=1, padding=erode_radius)
rgba = torch.cat([rgba[:, :3], a], 1)
ys, xs = torch.nonzero(a[0, 0] > 0, as_tuple=True)
if xs.numel() == 0:
raise ValueError("TripoSplatPreprocessImage: mask is empty (no foreground pixels).")
x0, x1 = int(xs.min()), int(xs.max())
y0, y1 = int(ys.min()), int(ys.max())
cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
half = max(x1 - x0, y1 - y0) / 2 * 1.2
left, upper, right, lower = int(cx - half), int(cy - half), int(cx + half), int(cy + half)
H, W = rgba.shape[-2:]
crop = rgba.new_zeros((1, 4, lower - upper, right - left)) # out-of-bounds stays 0, matching PIL.crop
sx0, sy0, sx1, sy1 = max(left, 0), max(upper, 0), min(right, W), min(lower, H)
if sx1 > sx0 and sy1 > sy0:
crop[:, :, sy0 - upper:sy1 - upper, sx0 - left:sx1 - left] = rgba[:, :, sy0:sy1, sx0:sx1]
crop = comfy.utils.common_upscale(crop, size, size, "lanczos", "disabled").clamp(0, 1)
out = (crop[:, :3] * crop[:, 3:4])[0].movedim(0, -1) # composite over black == rgb * alpha
return out.unsqueeze(0) # (1, 1024, 1024, 3)
class TripoSplatPreprocessImage(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoSplatPreprocessImage",
display_name="TripoSplat Preprocess Image",
category="3d/conditioning",
description="Crop center each image to a square canvas on a black background and add padding.",
inputs=[
IO.Image.Input("image"),
IO.Mask.Input("mask"),
IO.Int.Input("erode_radius", default=1, min=0, max=16,
tooltip="Erode the alpha matte by this pixel radius before cropping (avoids border bleed)."),
IO.Int.Input("size", default=1024, min=256, max=4096, step=16,
tooltip="Square image size. The model is trained at 1024; other sizes run but are off-distribution."),
],
outputs=[IO.Image.Output(display_name="image")],
)
@classmethod
def execute(cls, image, mask, erode_radius, size) -> IO.NodeOutput:
size = max(16, (int(size) // 16) * 16) # DINOv3 patch / Flux2 VAE stride is 16
if mask.shape[0] != image.shape[0]:
mask = comfy.utils.repeat_to_batch_size(mask, image.shape[0])
if tuple(mask.shape[1:]) != tuple(image.shape[1:3]):
mask = F.interpolate(mask[:, None].float(), size=tuple(image.shape[1:3]), mode="bilinear", align_corners=False)[:, 0]
prepared = torch.cat([_preprocess(image[i], mask[i], erode_radius, size) for i in range(image.shape[0])], dim=0)
return IO.NodeOutput(prepared)
class TripoSplatConditioning(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoSplatConditioning",
display_name="TripoSplat Conditioning",
category="3d/conditioning",
description="Encode the image with DINOv3 and the Flux2 VAE into TripoSplat positive/negative "
"conditioning, and create the fixed size noise target (latent + camera) for the KSampler",
inputs=[
IO.ClipVision.Input("clip_vision", tooltip="DINOv3 ViT-H/16+ image encoder"),
IO.Vae.Input("vae", tooltip="Flux2 VAE"),
IO.Image.Input("image"),
],
outputs=[
IO.Conditioning.Output(display_name="positive"),
IO.Conditioning.Output(display_name="negative"),
IO.Latent.Output(display_name="latent", tooltip="The fixed size noise target (latent +camera)."),
],
)
@classmethod
def execute(cls, clip_vision, vae, image) -> IO.NodeOutput:
# feature1: DINOv3 token sequence (cls + registers + patches), ImageNet-normalized, with a final non-affine layer norm on top
comfy.model_management.load_model_gpu(clip_vision.patcher)
device = clip_vision.load_device
model_dtype = next(clip_vision.model.parameters()).dtype
img = image.movedim(-1, 1).to(device) # (B,3,H,W) in [0,1]
mean = torch.tensor(_DINOV3_MEAN, device=device).view(1, 3, 1, 1)
std = torch.tensor(_DINOV3_STD, device=device).view(1, 3, 1, 1)
img = (img - mean) / std
seq = clip_vision.model(pixel_values=img.to(model_dtype))[0]
feature1 = F.layer_norm(seq.float(), seq.shape[-1:]).to(comfy.model_management.intermediate_device())
# Second conditioning: the Flux2 VAE latent of the image, carried as a standard reference_latents entry
ref = vae.encode(image).to(comfy.model_management.intermediate_device()) # (B, 128, H, W)
b = ref.shape[0]
positive = [[feature1, {"reference_latents": [ref]}]]
negative = [[torch.zeros_like(feature1), {"reference_latents": [torch.zeros_like(ref)]}]]
# Fixed noise target: the latent is a constant-shape (8192, 16) shape-code + a (1, 5) camera token
dev = comfy.model_management.intermediate_device()
latent_seq = torch.zeros([b, _Q_TOKEN_LENGTH, _LATENT_CHANNELS], device=dev)
camera = torch.zeros([b, 1, _CAM_CHANNELS], device=dev)
samples = comfy.nested_tensor.NestedTensor((latent_seq, camera))
return IO.NodeOutput(positive, negative, {"samples": samples})
class VAEDecodeTripoSplat(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="VAEDecodeTripoSplat",
display_name="TripoSplat Decode",
category="3d/latent",
description="Decode the sampled TripoSplat latent into a 3D gaussian splat. "
"Modify the number of gaussians to vary the density.",
inputs=[
IO.Latent.Input("samples"),
IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
IO.Int.Input("num_gaussians", default=262144, min=_NUM_GAUSSIANS_MIN, max=_NUM_GAUSSIANS_MAX, step=32,
tooltip="Number of gaussians to produce (rounded to a multiple of 32). "
"262144 matches the octree's point density; higher oversamples the same points "
"(denser, but no new detail) and costs proportionally more VRAM/time."),
IO.Int.Input("seed", default=0, min=0, max=0xffffffffffffffff,
tooltip="Seeds the octree point sampler (global RNG) for deterministic decodes."),
],
outputs=[IO.Splat.Output(display_name="splat")],
)
@classmethod
def execute(cls, samples, vae, num_gaussians, seed) -> IO.NodeOutput:
s = samples["samples"]
latent = s.unbind()[0] if getattr(s, "is_nested", False) else s # take the latent stream, drop camera
decoder = vae.first_stage_model
gpp = decoder.gaussians_per_point
n = max(_NUM_GAUSSIANS_MIN, min(_NUM_GAUSSIANS_MAX, int(num_gaussians)))
if n % gpp != 0:
n = round(n / gpp) * gpp
dtype_size = comfy.model_management.dtype_size(vae.vae_dtype)
hidden = decoder.gs.model_channels
cond_tokens = latent.shape[1]
memory_required = (cond_tokens * 4 + (n // gpp) * 10) * hidden * dtype_size
comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
latent = latent.to(device=vae.device, dtype=vae.vae_dtype)
generator = torch.Generator(device="cpu").manual_seed(seed)
parts = [g.render_tensors() for g in decoder.decode(latent, num_gaussians=n, generator=generator)]
positions, scales, rotations, opacities, sh = (torch.stack(t) for t in zip(*parts))
return IO.NodeOutput(Types.SPLAT(positions, scales, rotations, opacities, sh))
class TripoSplatSamplingPreview(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TripoSplatSamplingPreview",
display_name="TripoSplat Sampling Preview",
category="3d/latent",
description="Patch the TripoSplat model for the standard Ksampler node to show a live decoded "
"gaussian splat preview at each step.",
inputs=[
IO.Model.Input("model"),
IO.Vae.Input("vae", tooltip="TripoSplat VAE decoder"),
IO.Int.Input("octree_level", default=5, min=2, max=8, advanced=True,
tooltip="Octree depth for the preview decode (lower = cheaper/coarser)."),
IO.Int.Input("num_gaussians", default=16384, min=1024, max=262144, step=32,
tooltip="Number of gaussians to produce for the preview (rounded to a multiple of 32)."),
IO.Float.Input("yaw", default=90.0, min=-360.0, max=360.0, step=1.0, tooltip="Preview camera yaw in degrees.", advanced=True,),
IO.Float.Input("pitch", default=15.0, min=-89.0, max=89.0, step=1.0, tooltip="Preview camera pitch in degrees.", advanced=True,),
IO.Int.Input("point_size", default=3, min=1, max=16,
tooltip="Maximum splat radius in pixels. Each gaussian is sized from its scale and capped here; "
"lower = finer/pointier, higher = chunkier."),
],
outputs=[IO.Model.Output()],
)
@classmethod
def execute(cls, model, vae, octree_level, num_gaussians, yaw, pitch, point_size) -> IO.NodeOutput:
from comfy.ldm.triposplat.preview import decode_x0_to_image
cfg = {"gaussians": num_gaussians, "level": octree_level, "yaw": yaw, "pitch": pitch,
"point_size": point_size}
fsm = vae.first_stage_model
cond_tokens = model.model.diffusion_model.q_token_length
memory_required = (cond_tokens * 4 + (num_gaussians // fsm.gaussians_per_point) * 10) * fsm.gs.model_channels * comfy.model_management.dtype_size(vae.vae_dtype)
# Live preview via WrappersMP.OUTER_SAMPLE + ProgressBar
# The wrapper augments the sampler's own callback to decode x0 -> gaussian splat -> preview image each step
def outer_sample_wrapper(executor, *args, **kwargs):
args = list(args)
cb_idx = 5 # outer_sample(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed)
orig_cb = args[cb_idx] if len(args) > cb_idx else kwargs.get("callback")
state = {"ok": True, "pbar": None, "loaded": False}
def callback(step, x0, x, total_steps):
if orig_cb is not None:
orig_cb(step, x0, x, total_steps)
if not state["ok"]:
return
try:
if not state["loaded"]:
comfy.model_management.load_models_gpu([vae.patcher], memory_required=memory_required)
state["loaded"] = True
img = decode_x0_to_image(vae, x0, cfg)
if state["pbar"] is None:
state["pbar"] = comfy.utils.ProgressBar(total_steps)
state["pbar"].update_absolute(step + 1, total_steps, ("JPEG", img, 512))
except Exception as e:
logging.warning("TripoSplatSamplingPreview: preview failed, disabling ({})".format(e))
state["ok"] = False
if len(args) > cb_idx:
args[cb_idx] = callback
else:
kwargs["callback"] = callback
return executor(*args, **kwargs)
m = model.clone()
m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "triposplat_sampling_preview", outer_sample_wrapper)
return IO.NodeOutput(m)
class TripoSplatExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TripoSplatPreprocessImage,
TripoSplatConditioning,
VAEDecodeTripoSplat,
TripoSplatSamplingPreview,
]
async def comfy_entrypoint() -> TripoSplatExtension:
return TripoSplatExtension()

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.22.0"
__version__ = "0.23.0"

View File

@ -2457,6 +2457,7 @@ async def init_builtin_extra_nodes():
"nodes_moge.py",
"nodes_mediapipe.py",
"nodes_gaussian_splat.py",
"nodes_triposplat.py"
]
import_failed = []

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.22.0"
version = "0.23.0"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.44.19
comfyui-workflow-templates==0.9.91
comfyui-workflow-templates==0.9.92
comfyui-embedded-docs==0.5.2
torch
torchsde