mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-20 15:32:32 +08:00
Merge f4ae7b8391 into 138571da95
This commit is contained in:
commit
b9f2709a14
@ -9,6 +9,7 @@ import comfy.model_management
|
||||
import comfy.utils
|
||||
import comfy.clip_model
|
||||
import comfy.image_encoders.dino2
|
||||
import comfy.image_encoders.dino3
|
||||
|
||||
class Output:
|
||||
def __getitem__(self, key):
|
||||
@ -23,6 +24,7 @@ IMAGE_ENCODERS = {
|
||||
"siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
|
||||
"dinov2": comfy.image_encoders.dino2.Dinov2Model,
|
||||
"dinov3": comfy.image_encoders.dino3.DINOv3ViTModel
|
||||
}
|
||||
|
||||
class ClipVisionModel():
|
||||
@ -134,6 +136,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
|
||||
elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
|
||||
elif 'layer.9.attention.o_proj.bias' in sd: # dinov3
|
||||
json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino3_large.json")
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
285
comfy/image_encoders/dino3.py
Normal file
285
comfy/image_encoders/dino3.py
Normal file
@ -0,0 +1,285 @@
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale
|
||||
|
||||
class DINOv3ViTMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.up_proj(x)))
|
||||
|
||||
def rotate_half(x):
|
||||
x1 = x[..., : x.shape[-1] // 2]
|
||||
x2 = x[..., x.shape[-1] // 2 :]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
|
||||
num_tokens = q.shape[-2]
|
||||
num_patches = sin.shape[-2]
|
||||
num_prefix_tokens = num_tokens - num_patches
|
||||
|
||||
q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
|
||||
|
||||
q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
|
||||
k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
|
||||
|
||||
q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
|
||||
k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
|
||||
|
||||
return q, k
|
||||
|
||||
class DINOv3ViTAttention(nn.Module):
|
||||
def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.embed_dim = hidden_size
|
||||
self.num_heads = num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
|
||||
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
|
||||
batch_size, patches, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
if position_embeddings is not None:
|
||||
cos, sin = position_embeddings
|
||||
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
|
||||
|
||||
attn = optimized_attention_for_device(query_states.device, mask=False)
|
||||
|
||||
attn_output = attn(
|
||||
query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output
|
||||
|
||||
class DINOv3ViTGatedMLP(nn.Module):
|
||||
def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
|
||||
self.act_fn = torch.nn.GELU()
|
||||
|
||||
def forward(self, x):
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
def get_patches_center_coordinates(
|
||||
num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
|
||||
) -> torch.Tensor:
|
||||
|
||||
coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
|
||||
coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
|
||||
coords_h = coords_h / num_patches_h
|
||||
coords_w = coords_w / num_patches_w
|
||||
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
|
||||
coords = coords.flatten(0, 1)
|
||||
coords = 2.0 * coords - 1.0
|
||||
return coords
|
||||
|
||||
class DINOv3ViTRopePositionEmbedding(nn.Module):
|
||||
inv_freq: torch.Tensor
|
||||
|
||||
def __init__(self, rope_theta, hidden_size, num_attention_heads, image_size, patch_size, device, dtype):
|
||||
super().__init__()
|
||||
self.base = rope_theta
|
||||
self.head_dim = hidden_size // num_attention_heads
|
||||
self.num_patches_h = image_size // patch_size
|
||||
self.num_patches_w = image_size // patch_size
|
||||
self.patch_size = patch_size
|
||||
|
||||
inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
|
||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
_, _, height, width = pixel_values.shape
|
||||
num_patches_h = height // self.patch_size
|
||||
num_patches_w = width // self.patch_size
|
||||
|
||||
device = pixel_values.device
|
||||
device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
|
||||
with torch.amp.autocast(device_type = device_type, enabled=False):
|
||||
patch_coords = get_patches_center_coordinates(
|
||||
num_patches_h, num_patches_w, dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
self.inv_freq = self.inv_freq.to(device)
|
||||
angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
|
||||
angles = angles.flatten(1, 2)
|
||||
angles = angles.tile(2)
|
||||
|
||||
cos = torch.cos(angles)
|
||||
sin = torch.sin(angles)
|
||||
|
||||
dtype = pixel_values.dtype
|
||||
return cos.to(dtype=dtype), sin.to(dtype=dtype)
|
||||
|
||||
|
||||
class DINOv3ViTEmbeddings(nn.Module):
|
||||
def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
|
||||
super().__init__()
|
||||
self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype))
|
||||
self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
|
||||
self.patch_embeddings = operations.Conv2d(
|
||||
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
|
||||
)
|
||||
|
||||
def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None):
|
||||
batch_size = pixel_values.shape[0]
|
||||
target_dtype = self.patch_embeddings.weight.dtype
|
||||
|
||||
patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
|
||||
patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
|
||||
|
||||
if bool_masked_pos is not None:
|
||||
mask_token = self.mask_token.to(patch_embeddings.dtype)
|
||||
patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
|
||||
|
||||
cls_token = self.cls_token.expand(batch_size, -1, -1)
|
||||
register_tokens = self.register_tokens.expand(batch_size, -1, -1)
|
||||
device = patch_embeddings.device
|
||||
cls_token = cls_token.to(device)
|
||||
register_tokens = register_tokens.to(device)
|
||||
embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
|
||||
|
||||
return embeddings
|
||||
|
||||
class DINOv3ViTLayer(nn.Module):
|
||||
|
||||
def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size, num_attention_heads,
|
||||
device, dtype, operations):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
|
||||
|
||||
if use_gated_mlp:
|
||||
self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
|
||||
self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.attention(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
hidden_states = self.layer_scale1(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = self.layer_scale2(hidden_states)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DINOv3ViTModel(nn.Module):
|
||||
def __init__(self, config, dtype, device, operations):
|
||||
super().__init__()
|
||||
use_bf16 = comfy.model_management.should_use_bf16(device, prioritize_performance=True)
|
||||
if dtype == torch.float16 and use_bf16:
|
||||
dtype = torch.bfloat16
|
||||
elif dtype == torch.float16 and not use_bf16:
|
||||
dtype = torch.float32
|
||||
num_hidden_layers = config["num_hidden_layers"]
|
||||
hidden_size = config["hidden_size"]
|
||||
num_attention_heads = config["num_attention_heads"]
|
||||
num_register_tokens = config["num_register_tokens"]
|
||||
intermediate_size = config["intermediate_size"]
|
||||
layer_norm_eps = config["layer_norm_eps"]
|
||||
num_channels = config["num_channels"]
|
||||
patch_size = config["patch_size"]
|
||||
rope_theta = config["rope_theta"]
|
||||
|
||||
self.embeddings = DINOv3ViTEmbeddings(
|
||||
hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
|
||||
rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device
|
||||
)
|
||||
self.layer = nn.ModuleList(
|
||||
[DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, mlp_bias=True,
|
||||
intermediate_size=intermediate_size,num_attention_heads = num_attention_heads,
|
||||
dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(num_hidden_layers)])
|
||||
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings.patch_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: torch.Tensor,
|
||||
bool_masked_pos: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
|
||||
hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
|
||||
position_embeddings = self.rope_embeddings(pixel_values)
|
||||
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
hidden_states = layer_module(
|
||||
hidden_states,
|
||||
position_embeddings=position_embeddings,
|
||||
)
|
||||
|
||||
if kwargs.get("skip_norm_elementwise", False):
|
||||
sequence_output= F.layer_norm(hidden_states, hidden_states.shape[-1:])
|
||||
else:
|
||||
norm = self.norm.to(hidden_states.device)
|
||||
sequence_output = norm(hidden_states)
|
||||
pooled_output = sequence_output[:, 0, :]
|
||||
|
||||
return sequence_output, None, pooled_output, None
|
||||
23
comfy/image_encoders/dino3_large.json
Normal file
23
comfy/image_encoders/dino3_large.json
Normal file
@ -0,0 +1,23 @@
|
||||
{
|
||||
"model_type": "dinov3",
|
||||
"hidden_size": 1024,
|
||||
"image_size": 224,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 4096,
|
||||
"key_bias": false,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"mlp_bias": true,
|
||||
"num_attention_heads": 16,
|
||||
"num_channels": 3,
|
||||
"num_hidden_layers": 24,
|
||||
"num_register_tokens": 4,
|
||||
"patch_size": 16,
|
||||
"pos_embed_rescale": 2.0,
|
||||
"proj_bias": true,
|
||||
"query_bias": true,
|
||||
"rope_theta": 100.0,
|
||||
"use_gated_mlp": false,
|
||||
"value_bias": true,
|
||||
"image_mean": [0.485, 0.456, 0.406],
|
||||
"image_std": [0.229, 0.224, 0.225]
|
||||
}
|
||||
@ -746,6 +746,8 @@ class Hunyuan3Dv2_1(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
class Trellis2(LatentFormat): # TODO
|
||||
latent_channels = 32
|
||||
class Hunyuan3Dv2mini(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
282
comfy/ldm/trellis2/attention.py
Normal file
282
comfy/ldm/trellis2/attention.py
Normal file
@ -0,0 +1,282 @@
|
||||
import torch
|
||||
import math
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from typing import Tuple, Union, List
|
||||
from comfy.ldm.trellis2.vae import VarLenTensor
|
||||
import comfy.ops
|
||||
|
||||
|
||||
# replica of the seedvr2 code
|
||||
def var_attn_arg(kwargs):
|
||||
cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
|
||||
max_seqlen_q = kwargs.get("max_seqlen_q", None)
|
||||
cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
|
||||
max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q)
|
||||
assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
|
||||
return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
var_length = True
|
||||
if var_length:
|
||||
cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
|
||||
if not skip_reshape:
|
||||
# assumes 2D q, k,v [total_tokens, embed_dim]
|
||||
total_tokens, embed_dim = q.shape
|
||||
head_dim = embed_dim // heads
|
||||
q = q.view(total_tokens, heads, head_dim)
|
||||
k = k.view(k.shape[0], heads, head_dim)
|
||||
v = v.view(v.shape[0], heads, head_dim)
|
||||
|
||||
b = q.size(0)
|
||||
dim_head = q.shape[-1]
|
||||
q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
|
||||
k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
|
||||
v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
|
||||
|
||||
mask = None
|
||||
q = q.transpose(1, 2)
|
||||
k = k.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
if mask is not None:
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if var_length:
|
||||
return out.transpose(1, 2).values()
|
||||
if not skip_output_reshape:
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
return out
|
||||
|
||||
def scaled_dot_product_attention(*args, **kwargs):
|
||||
num_all_args = len(args) + len(kwargs)
|
||||
|
||||
q = None
|
||||
if num_all_args == 1:
|
||||
qkv = args[0] if len(args) > 0 else kwargs.get('qkv')
|
||||
elif num_all_args == 2:
|
||||
q = args[0] if len(args) > 0 else kwargs.get('q')
|
||||
kv = args[1] if len(args) > 1 else kwargs.get('kv')
|
||||
elif num_all_args == 3:
|
||||
q = args[0] if len(args) > 0 else kwargs.get('q')
|
||||
k = args[1] if len(args) > 1 else kwargs.get('k')
|
||||
v = args[2] if len(args) > 2 else kwargs.get('v')
|
||||
|
||||
if q is not None:
|
||||
heads = q.shape[2]
|
||||
else:
|
||||
heads = qkv.shape[3]
|
||||
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=2)
|
||||
|
||||
q = q.permute(0, 2, 1, 3)
|
||||
k = k.permute(0, 2, 1, 3)
|
||||
v = v.permute(0, 2, 1, 3)
|
||||
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs)
|
||||
|
||||
out = out.permute(0, 2, 1, 3)
|
||||
|
||||
return out
|
||||
|
||||
def sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv,
|
||||
window_size: int,
|
||||
shift_window: Tuple[int, int, int] = (0, 0, 0)
|
||||
):
|
||||
|
||||
serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}'
|
||||
serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
|
||||
if serialization_spatial_cache is None:
|
||||
fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window)
|
||||
qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args))
|
||||
else:
|
||||
fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache
|
||||
|
||||
qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C]
|
||||
heads = qkv_feats.shape[2]
|
||||
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
q, k, v = qkv_feats.unbind(dim=1)
|
||||
q = q.unsqueeze(0) # [1, M, H, C]
|
||||
k = k.unsqueeze(0) # [1, M, H, C]
|
||||
v = v.unsqueeze(0) # [1, M, H, C]
|
||||
#out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C]
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
if 'flash_attn' not in globals():
|
||||
import flash_attn
|
||||
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C]
|
||||
else:
|
||||
out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
|
||||
|
||||
out = out[bwd_indices] # [T, H, C]
|
||||
|
||||
return qkv.replace(out)
|
||||
|
||||
def calc_window_partition(
|
||||
tensor,
|
||||
window_size: Union[int, Tuple[int, ...]],
|
||||
shift_window: Union[int, Tuple[int, ...]] = 0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
|
||||
|
||||
DIM = tensor.coords.shape[1] - 1
|
||||
shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
|
||||
window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
|
||||
shifted_coords = tensor.coords.clone().detach()
|
||||
shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)
|
||||
|
||||
MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)]
|
||||
NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
|
||||
OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]
|
||||
|
||||
shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
|
||||
shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
|
||||
fwd_indices = torch.argsort(shifted_indices)
|
||||
bwd_indices = torch.empty_like(fwd_indices)
|
||||
bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
|
||||
seq_lens = torch.bincount(shifted_indices)
|
||||
mask = seq_lens != 0
|
||||
seq_lens = seq_lens[mask]
|
||||
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
if 'xops' not in globals():
|
||||
import xformers.ops as xops
|
||||
attn_func_args = {
|
||||
'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
|
||||
}
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
attn_func_args = {
|
||||
'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
|
||||
'max_seqlen': torch.max(seq_lens)
|
||||
}
|
||||
|
||||
return fwd_indices, bwd_indices, seq_lens, attn_func_args
|
||||
|
||||
|
||||
def sparse_scaled_dot_product_attention(*args, **kwargs):
|
||||
q=None
|
||||
arg_names_dict = {
|
||||
1: ['qkv'],
|
||||
2: ['q', 'kv'],
|
||||
3: ['q', 'k', 'v']
|
||||
}
|
||||
num_all_args = len(args) + len(kwargs)
|
||||
for key in arg_names_dict[num_all_args][len(args):]:
|
||||
assert key in kwargs, f"Missing argument {key}"
|
||||
|
||||
if num_all_args == 1:
|
||||
qkv = args[0] if len(args) > 0 else kwargs['qkv']
|
||||
device = qkv.device
|
||||
|
||||
s = qkv
|
||||
q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
|
||||
kv_seqlen = q_seqlen
|
||||
qkv = qkv.feats # [T, 3, H, C]
|
||||
|
||||
elif num_all_args == 2:
|
||||
q = args[0] if len(args) > 0 else kwargs['q']
|
||||
kv = args[1] if len(args) > 1 else kwargs['kv']
|
||||
device = q.device
|
||||
|
||||
if isinstance(q, VarLenTensor):
|
||||
s = q
|
||||
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
|
||||
q = q.feats # [T_Q, H, C]
|
||||
else:
|
||||
s = None
|
||||
N, L, H, C = q.shape
|
||||
q_seqlen = [L] * N
|
||||
q = q.reshape(N * L, H, C) # [T_Q, H, C]
|
||||
|
||||
if isinstance(kv, VarLenTensor):
|
||||
kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
|
||||
kv = kv.feats # [T_KV, 2, H, C]
|
||||
else:
|
||||
N, L, _, H, C = kv.shape
|
||||
kv_seqlen = [L] * N
|
||||
kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C]
|
||||
|
||||
elif num_all_args == 3:
|
||||
q = args[0] if len(args) > 0 else kwargs['q']
|
||||
k = args[1] if len(args) > 1 else kwargs['k']
|
||||
v = args[2] if len(args) > 2 else kwargs['v']
|
||||
device = q.device
|
||||
|
||||
if isinstance(q, VarLenTensor):
|
||||
s = q
|
||||
q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
|
||||
q = q.feats # [T_Q, H, Ci]
|
||||
else:
|
||||
s = None
|
||||
N, L, H, CI = q.shape
|
||||
q_seqlen = [L] * N
|
||||
q = q.reshape(N * L, H, CI) # [T_Q, H, Ci]
|
||||
|
||||
if isinstance(k, VarLenTensor):
|
||||
kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
|
||||
k = k.feats # [T_KV, H, Ci]
|
||||
v = v.feats # [T_KV, H, Co]
|
||||
else:
|
||||
N, L, H, CI, CO = *k.shape, v.shape[-1]
|
||||
kv_seqlen = [L] * N
|
||||
k = k.reshape(N * L, H, CI) # [T_KV, H, Ci]
|
||||
v = v.reshape(N * L, H, CO) # [T_KV, H, Co]
|
||||
|
||||
# TODO: change
|
||||
if q is not None:
|
||||
heads = q
|
||||
else:
|
||||
heads = qkv
|
||||
heads = heads.shape[2]
|
||||
if optimized_attention.__name__ == 'attention_xformers':
|
||||
if 'xops' not in globals():
|
||||
import xformers.ops as xops
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=1)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=1)
|
||||
q = q.unsqueeze(0)
|
||||
k = k.unsqueeze(0)
|
||||
v = v.unsqueeze(0)
|
||||
mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
|
||||
out = xops.memory_efficient_attention(q, k, v, mask)[0]
|
||||
elif optimized_attention.__name__ == 'attention_flash':
|
||||
if 'flash_attn' not in globals():
|
||||
import flash_attn
|
||||
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args in [2, 3]:
|
||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args == 1:
|
||||
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
|
||||
elif num_all_args == 2:
|
||||
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||
elif num_all_args == 3:
|
||||
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
||||
|
||||
elif optimized_attention.__name__ == "attention_pytorch":
|
||||
cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
|
||||
if num_all_args in [2, 3]:
|
||||
cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
|
||||
else:
|
||||
cu_seqlens_kv = cu_seqlens_q
|
||||
if num_all_args == 1:
|
||||
q, k, v = qkv.unbind(dim=1)
|
||||
elif num_all_args == 2:
|
||||
k, v = kv.unbind(dim=1)
|
||||
out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen),
|
||||
skip_reshape=True, skip_output_reshape=True)
|
||||
|
||||
if s is not None:
|
||||
return s.replace(out)
|
||||
else:
|
||||
return out.reshape(N, L, H, -1)
|
||||
433
comfy/ldm/trellis2/cumesh.py
Normal file
433
comfy/ldm/trellis2/cumesh.py
Normal file
@ -0,0 +1,433 @@
|
||||
# will contain every cuda -> pytorch operation
|
||||
|
||||
import math
|
||||
import torch
|
||||
from typing import Callable
|
||||
import logging
|
||||
|
||||
NO_TRITON = False
|
||||
try:
|
||||
allow_tf32 = torch.cuda.is_tf32_supported()
|
||||
except Exception:
|
||||
allow_tf32 = False
|
||||
try:
|
||||
import triton
|
||||
import triton.language as tl
|
||||
heuristics = {
|
||||
'valid_kernel': lambda args: args['valid_kernel'](args['B1']),
|
||||
'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']),
|
||||
}
|
||||
|
||||
#@triton_autotune(
|
||||
# configs=config.autotune_config,
|
||||
# key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
|
||||
#)
|
||||
@triton.heuristics(heuristics)
|
||||
@triton.jit
|
||||
def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel(
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
neighbor,
|
||||
sorted_idx,
|
||||
output,
|
||||
# Tensor dimensions
|
||||
N, LOGN, Ci, Co, V: tl.constexpr,
|
||||
# Meta-parameters
|
||||
B1: tl.constexpr, # Block size for N dimension
|
||||
B2: tl.constexpr, # Block size for Co dimension
|
||||
BK: tl.constexpr, # Block size for K dimension (V * Ci)
|
||||
allow_tf32: tl.constexpr, # Allow TF32 precision for matmuls
|
||||
# Huristic parameters
|
||||
valid_kernel,
|
||||
valid_kernel_seg,
|
||||
):
|
||||
|
||||
block_id = tl.program_id(axis=0)
|
||||
block_dim_co = tl.cdiv(Co, B2)
|
||||
block_id_co = block_id % block_dim_co
|
||||
block_id_n = block_id // block_dim_co
|
||||
|
||||
# Create pointers for submatrices of A and B.
|
||||
num_k = tl.cdiv(Ci, BK) # Number of blocks in K dimension
|
||||
valid_kernel_start = tl.load(valid_kernel_seg + block_id_n)
|
||||
valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start
|
||||
offset_n = block_id_n * B1 + tl.arange(0, B1)
|
||||
n_mask = offset_n < N
|
||||
offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0) # (B1,)
|
||||
offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co # (B2,)
|
||||
offset_k = tl.arange(0, BK) # (BK,)
|
||||
|
||||
# Create a block of the output matrix C.
|
||||
accumulator = tl.zeros((B1, B2), dtype=tl.float32)
|
||||
|
||||
# Iterate along V*Ci dimension.
|
||||
for k in range(num_k * valid_kernel_seglen):
|
||||
v = k // num_k
|
||||
bk = k % num_k
|
||||
v = tl.load(valid_kernel + valid_kernel_start + v)
|
||||
# Calculate pointers to input matrix.
|
||||
neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v) # (B1,)
|
||||
input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :]) # (B1, BK)
|
||||
# Calculate pointers to weight matrix.
|
||||
weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None]) # (BK, B2)
|
||||
# Load the next block of input and weight.
|
||||
neigh_mask = neighbor_offset_n != 0xffffffff
|
||||
k_mask = offset_k < Ci - bk * BK
|
||||
input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
|
||||
weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
|
||||
# Accumulate along the K dimension.
|
||||
accumulator = tl.dot(input_block, weight_block, accumulator,
|
||||
input_precision='tf32' if allow_tf32 else 'ieee') # (B1, B2)
|
||||
c = accumulator.to(input.type.element_ty)
|
||||
|
||||
# add bias
|
||||
if bias is not None:
|
||||
bias_block = tl.load(bias + offset_co)
|
||||
c += bias_block[None, :]
|
||||
|
||||
# Write back the block of the output matrix with masks.
|
||||
out_offset_n = offset_sorted_n
|
||||
out_offset_co = block_id_co * B2 + tl.arange(0, B2)
|
||||
out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :])
|
||||
out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co)
|
||||
tl.store(out_ptr, c, mask=out_mask)
|
||||
def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
neighbor: torch.Tensor,
|
||||
sorted_idx: torch.Tensor,
|
||||
valid_kernel: Callable[[int], torch.Tensor],
|
||||
valid_kernel_seg: Callable[[int], torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
|
||||
LOGN = int(math.log2(N))
|
||||
output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
|
||||
grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
|
||||
sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
|
||||
input, weight, bias, neighbor, sorted_idx, output,
|
||||
N, LOGN, Ci, Co, V,
|
||||
B1=128,
|
||||
B2=64,
|
||||
BK=32,
|
||||
valid_kernel=valid_kernel,
|
||||
valid_kernel_seg=valid_kernel_seg,
|
||||
allow_tf32=allow_tf32,
|
||||
)
|
||||
return output
|
||||
except Exception:
|
||||
NO_TRITON = True
|
||||
|
||||
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
|
||||
# offsets in same order as CUDA kernel
|
||||
offsets = []
|
||||
for vx in range(Kw):
|
||||
for vy in range(Kh):
|
||||
for vz in range(Kd):
|
||||
offsets.append((
|
||||
vx * Dw,
|
||||
vy * Dh,
|
||||
vz * Dd
|
||||
))
|
||||
return torch.tensor(offsets, device=device)
|
||||
|
||||
def build_submanifold_neighbor_map(
|
||||
hashmap,
|
||||
coords: torch.Tensor,
|
||||
W, H, D,
|
||||
Kw, Kh, Kd,
|
||||
Dw, Dh, Dd,
|
||||
):
|
||||
device = coords.device
|
||||
M = coords.shape[0]
|
||||
V = Kw * Kh * Kd
|
||||
half_V = V // 2 + 1
|
||||
|
||||
INVALID = hashmap.default_value
|
||||
|
||||
neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)
|
||||
|
||||
b = coords[:, 0].long()
|
||||
x = coords[:, 1].long()
|
||||
y = coords[:, 2].long()
|
||||
z = coords[:, 3].long()
|
||||
|
||||
offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)
|
||||
|
||||
ox = x - (Kw // 2) * Dw
|
||||
oy = y - (Kh // 2) * Dh
|
||||
oz = z - (Kd // 2) * Dd
|
||||
|
||||
for v in range(half_V):
|
||||
if v == half_V - 1:
|
||||
neighbor[:, v] = torch.arange(M, device=device)
|
||||
continue
|
||||
|
||||
dx, dy, dz = offsets[v]
|
||||
|
||||
kx = ox + dx
|
||||
ky = oy + dy
|
||||
kz = oz + dz
|
||||
|
||||
# Check spatial bounds
|
||||
valid = (
|
||||
(kx >= 0) & (kx < W) &
|
||||
(ky >= 0) & (ky < H) &
|
||||
(kz >= 0) & (kz < D)
|
||||
)
|
||||
|
||||
flat = (
|
||||
b[valid] * (W * H * D) +
|
||||
kx[valid] * (H * D) +
|
||||
ky[valid] * D +
|
||||
kz[valid]
|
||||
)
|
||||
|
||||
if flat.numel() > 0:
|
||||
found = hashmap.lookup_flat(flat)
|
||||
idx_in_M = torch.where(valid)[0]
|
||||
neighbor[idx_in_M, v] = found
|
||||
|
||||
valid_found_mask = (found != INVALID)
|
||||
if valid_found_mask.any():
|
||||
src_points = idx_in_M[valid_found_mask]
|
||||
dst_points = found[valid_found_mask]
|
||||
neighbor[dst_points, V - 1 - v] = src_points
|
||||
|
||||
return neighbor
|
||||
|
||||
class TorchHashMap:
|
||||
def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
|
||||
device = keys.device
|
||||
# use long for searchsorted
|
||||
self.sorted_keys, order = torch.sort(keys.to(torch.long))
|
||||
self.sorted_vals = values.to(torch.long)[order]
|
||||
self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
|
||||
self._n = self.sorted_keys.numel()
|
||||
|
||||
def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
|
||||
flat = flat_keys.to(torch.long)
|
||||
if self._n == 0:
|
||||
return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
|
||||
idx = torch.searchsorted(self.sorted_keys, flat)
|
||||
idx_safe = torch.clamp(idx, max=self._n - 1)
|
||||
found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
|
||||
out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
|
||||
if found.any():
|
||||
out[found] = self.sorted_vals[idx_safe[found]]
|
||||
return out
|
||||
|
||||
|
||||
UINT32_SENTINEL = 0xFFFFFFFF
|
||||
|
||||
def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
|
||||
device = neighbor_map.device
|
||||
N, V = neighbor_map.shape
|
||||
|
||||
sentinel = UINT32_SENTINEL
|
||||
|
||||
neigh_map_T = neighbor_map.t().reshape(-1)
|
||||
neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)
|
||||
|
||||
mask = (neighbor_map != sentinel).to(torch.long)
|
||||
gray_code = torch.zeros(N, dtype=torch.long, device=device)
|
||||
|
||||
for v in range(V):
|
||||
gray_code |= (mask[:, v] << v)
|
||||
|
||||
binary_code = gray_code.clone()
|
||||
for v in range(1, V):
|
||||
binary_code ^= (gray_code >> v)
|
||||
|
||||
sorted_idx = torch.argsort(binary_code)
|
||||
|
||||
prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0)
|
||||
|
||||
total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0
|
||||
|
||||
if total_valid_signal > 0:
|
||||
pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
|
||||
to = (prefix_sum_neighbor_mask[pos] - 1).long()
|
||||
|
||||
valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
|
||||
valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
|
||||
|
||||
valid_signal_i[to] = (pos % N).to(torch.long)
|
||||
valid_signal_o[to] = neigh_map_T[pos].to(torch.long)
|
||||
else:
|
||||
valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
|
||||
valid_signal_o = torch.empty((0,), dtype=torch.long, device=device)
|
||||
|
||||
seg = torch.empty((V + 1,), dtype=torch.long, device=device)
|
||||
seg[0] = 0
|
||||
if V > 0:
|
||||
idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
|
||||
seg[1:] = prefix_sum_neighbor_mask[idxs]
|
||||
|
||||
return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg
|
||||
|
||||
def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
x = x.to(torch.int64)
|
||||
|
||||
m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device)
|
||||
m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device)
|
||||
m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device)
|
||||
h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device)
|
||||
|
||||
x = x - ((x >> 1) & m1)
|
||||
x = (x & m2) + ((x >> 2) & m2)
|
||||
x = (x + (x >> 4)) & m4
|
||||
x = (x * h01) >> 56
|
||||
return x.to(torch.int32)
|
||||
|
||||
|
||||
def neighbor_map_post_process_for_masked_implicit_gemm_2(
|
||||
gray_code: torch.Tensor,
|
||||
sorted_idx: torch.Tensor,
|
||||
block_size: int
|
||||
):
|
||||
device = gray_code.device
|
||||
N = gray_code.numel()
|
||||
num_blocks = (N + block_size - 1) // block_size
|
||||
|
||||
pad = num_blocks * block_size - N
|
||||
if pad > 0:
|
||||
pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device)
|
||||
gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0)
|
||||
else:
|
||||
gray_padded = gray_code[sorted_idx]
|
||||
|
||||
gray_blocks = gray_padded.view(num_blocks, block_size)
|
||||
|
||||
reduced_code = gray_blocks
|
||||
while reduced_code.shape[1] > 1:
|
||||
half = reduced_code.shape[1] // 2
|
||||
remainder = reduced_code.shape[1] % 2
|
||||
|
||||
left = reduced_code[:, :half * 2:2]
|
||||
right = reduced_code[:, 1:half * 2:2]
|
||||
merged = left | right
|
||||
|
||||
if remainder:
|
||||
reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1)
|
||||
else:
|
||||
reduced_code = merged
|
||||
|
||||
reduced_code = reduced_code.squeeze(1)
|
||||
|
||||
seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32)
|
||||
|
||||
seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
|
||||
seg[0] = 0
|
||||
if num_blocks > 0:
|
||||
seg[1:] = torch.cumsum(seglen_counts, dim=0)
|
||||
|
||||
total = int(seg[-1].item())
|
||||
|
||||
if total == 0:
|
||||
return torch.empty((0,), dtype=torch.int32, device=device), seg
|
||||
|
||||
V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0
|
||||
|
||||
if V == 0:
|
||||
return torch.empty((0,), dtype=torch.int32, device=device), seg
|
||||
|
||||
bit_pos = torch.arange(0, V, dtype=torch.int32, device=device)
|
||||
shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0)
|
||||
bits = (shifted & 1).to(torch.bool)
|
||||
|
||||
positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
|
||||
valid_kernel_idx = positions[bits].to(torch.int32).contiguous()
|
||||
|
||||
return valid_kernel_idx, seg
|
||||
|
||||
|
||||
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
|
||||
if NO_TRITON: # TODO
|
||||
raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
|
||||
if feats.shape[0] == 0:
|
||||
logging.warning("Found feats to be empty!")
|
||||
Co = weight.shape[0]
|
||||
return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None
|
||||
if len(shape) == 5:
|
||||
N, C, W, H, D = shape
|
||||
else:
|
||||
W, H, D = shape
|
||||
|
||||
Co, Kw, Kh, Kd, Ci = weight.shape
|
||||
|
||||
b_stride = W * H * D
|
||||
x_stride = H * D
|
||||
y_stride = D
|
||||
z_stride = 1
|
||||
|
||||
flat_keys = (coords[:, 0].long() * b_stride +
|
||||
coords[:, 1].long() * x_stride +
|
||||
coords[:, 2].long() * y_stride +
|
||||
coords[:, 3].long() * z_stride)
|
||||
|
||||
vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device)
|
||||
|
||||
hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF)
|
||||
|
||||
if neighbor_cache is None:
|
||||
neighbor = build_submanifold_neighbor_map(
|
||||
hashmap, coords, W, H, D, Kw, Kh, Kd,
|
||||
dilation[0], dilation[1], dilation[2]
|
||||
)
|
||||
else:
|
||||
neighbor = neighbor_cache
|
||||
|
||||
block_size = 128
|
||||
|
||||
gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \
|
||||
neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor)
|
||||
|
||||
valid_kernel, valid_kernel_seg = \
|
||||
neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size)
|
||||
|
||||
valid_kernel_fn = lambda b_size: valid_kernel
|
||||
valid_kernel_seg_fn = lambda b_size: valid_kernel_seg
|
||||
|
||||
weight_flat = weight.contiguous().view(Co, -1, Ci)
|
||||
|
||||
out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
|
||||
feats,
|
||||
weight_flat,
|
||||
bias,
|
||||
neighbor,
|
||||
sorted_idx,
|
||||
valid_kernel_fn,
|
||||
valid_kernel_seg_fn
|
||||
)
|
||||
|
||||
return out, neighbor
|
||||
|
||||
class Mesh:
|
||||
def __init__(self,
|
||||
vertices,
|
||||
faces,
|
||||
vertex_attrs=None
|
||||
):
|
||||
self.vertices = vertices.float()
|
||||
self.faces = faces.int()
|
||||
self.vertex_attrs = vertex_attrs
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self.vertices.device
|
||||
|
||||
def to(self, device, non_blocking=False):
|
||||
return Mesh(
|
||||
self.vertices.to(device, non_blocking=non_blocking),
|
||||
self.faces.to(device, non_blocking=non_blocking),
|
||||
self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
|
||||
)
|
||||
|
||||
def cuda(self, non_blocking=False):
|
||||
return self.to('cuda', non_blocking=non_blocking)
|
||||
|
||||
def cpu(self):
|
||||
return self.to('cpu')
|
||||
875
comfy/ldm/trellis2/model.py
Normal file
875
comfy/ldm/trellis2/model.py
Normal file
@ -0,0 +1,875 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor
|
||||
from typing import Optional, Tuple, Literal, Union, List
|
||||
from comfy.ldm.trellis2.attention import (
|
||||
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
|
||||
)
|
||||
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
|
||||
from comfy.ldm.flux.math import apply_rope, apply_rope1
|
||||
|
||||
class SparseGELU(nn.GELU):
|
||||
def forward(self, input: VarLenTensor) -> VarLenTensor:
|
||||
return input.replace(super().forward(input.feats))
|
||||
|
||||
class SparseFeedForwardNet(nn.Module):
|
||||
def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
SparseLinear(channels, int(channels * mlp_ratio), device=device, dtype=dtype, operations=operations),
|
||||
SparseGELU(approximate="tanh"),
|
||||
SparseLinear(int(channels * mlp_ratio), channels, device=device, dtype=dtype, operations=operations),
|
||||
)
|
||||
|
||||
def forward(self, x: VarLenTensor) -> VarLenTensor:
|
||||
return self.mlp(x)
|
||||
|
||||
def manual_cast(obj, dtype):
|
||||
return obj.to(dtype=dtype)
|
||||
|
||||
class LayerNorm32(nn.LayerNorm):
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x_dtype = x.dtype
|
||||
x = manual_cast(x, torch.float32)
|
||||
o = super().forward(x)
|
||||
return manual_cast(o, x_dtype)
|
||||
|
||||
|
||||
class SparseMultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, heads: int, device, dtype):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||
x_type = x.dtype
|
||||
x = x.float()
|
||||
if isinstance(x, VarLenTensor):
|
||||
x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale)
|
||||
else:
|
||||
x = F.normalize(x, dim=-1) * self.gamma * self.scale
|
||||
return x.to(x_type)
|
||||
|
||||
class SparseRotaryPositionEmbedder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
head_dim: int,
|
||||
dim: int = 3,
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
device=None
|
||||
):
|
||||
super().__init__()
|
||||
self.head_dim = head_dim
|
||||
self.dim = dim
|
||||
self.rope_freq = rope_freq
|
||||
self.freq_dim = head_dim // 2 // dim
|
||||
self.freqs = torch.arange(self.freq_dim, dtype=torch.float32, device=device) / self.freq_dim
|
||||
self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs))
|
||||
|
||||
def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor:
|
||||
phases_list = []
|
||||
for i in range(self.dim):
|
||||
phases_list.append(torch.outer(coords[..., i], self.freqs.to(coords.device)))
|
||||
|
||||
phases = torch.cat(phases_list, dim=-1)
|
||||
|
||||
if phases.shape[-1] < self.head_dim // 2:
|
||||
padn = self.head_dim // 2 - phases.shape[-1]
|
||||
phases = torch.cat([phases, torch.zeros(*phases.shape[:-1], padn, device=phases.device)], dim=-1)
|
||||
|
||||
cos = torch.cos(phases)
|
||||
sin = torch.sin(phases)
|
||||
|
||||
f_cis_0 = torch.stack([cos, sin], dim=-1)
|
||||
f_cis_1 = torch.stack([-sin, cos], dim=-1)
|
||||
freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1)
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def _get_phases(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
self.freqs = self.freqs.to(indices.device)
|
||||
phases = torch.outer(indices, self.freqs)
|
||||
phases = torch.polar(torch.ones_like(phases), phases)
|
||||
return phases
|
||||
|
||||
def forward(self, q, k=None):
|
||||
cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}'
|
||||
freqs_cis = q.get_spatial_cache(cache_name)
|
||||
|
||||
if freqs_cis is None:
|
||||
coords = q.coords[..., 1:].to(torch.float32)
|
||||
freqs_cis = self._get_freqs_cis(coords)
|
||||
q.register_spatial_cache(cache_name, freqs_cis)
|
||||
|
||||
if q.feats.ndim == 3:
|
||||
f_cis = freqs_cis.unsqueeze(1)
|
||||
else:
|
||||
f_cis = freqs_cis
|
||||
|
||||
if k is None:
|
||||
return q.replace(apply_rope1(q.feats, f_cis))
|
||||
|
||||
q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis)
|
||||
return q.replace(q_feats), k.replace(k_feats)
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
|
||||
x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
|
||||
x_rotated = x_complex * phases.unsqueeze(-2)
|
||||
x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype)
|
||||
return x_embed
|
||||
|
||||
class RotaryPositionEmbedder(SparseRotaryPositionEmbedder):
|
||||
def forward(self, indices: torch.Tensor) -> torch.Tensor:
|
||||
phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1)
|
||||
if torch.is_complex(phases):
|
||||
phases = phases.to(torch.complex64)
|
||||
else:
|
||||
phases = phases.to(torch.float32)
|
||||
if phases.shape[-1] < self.head_dim // 2:
|
||||
padn = self.head_dim // 2 - phases.shape[-1]
|
||||
phases = torch.cat([phases, torch.polar(
|
||||
torch.ones(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32),
|
||||
torch.zeros(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32)
|
||||
)], dim=-1)
|
||||
return phases
|
||||
|
||||
class SparseMultiHeadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
num_heads: int,
|
||||
ctx_channels: Optional[int] = None,
|
||||
type: Literal["self", "cross"] = "self",
|
||||
attn_mode: Literal["full", "windowed", "double_windowed"] = "full",
|
||||
window_size: Optional[int] = None,
|
||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||
qkv_bias: bool = True,
|
||||
use_rope: bool = False,
|
||||
rope_freq: Tuple[int, int] = (1.0, 10000.0),
|
||||
qk_rms_norm: bool = False,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.channels = channels
|
||||
self.head_dim = channels // num_heads
|
||||
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||
self.num_heads = num_heads
|
||||
self._type = type
|
||||
self.attn_mode = attn_mode
|
||||
self.window_size = window_size
|
||||
self.shift_window = shift_window
|
||||
self.use_rope = use_rope
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
|
||||
if self._type == "self":
|
||||
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, device=device, dtype=dtype)
|
||||
else:
|
||||
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
if self.qk_rms_norm:
|
||||
self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
|
||||
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||
|
||||
if use_rope:
|
||||
self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device)
|
||||
|
||||
@staticmethod
|
||||
def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]:
|
||||
if isinstance(x, VarLenTensor):
|
||||
return x.replace(module(x.feats))
|
||||
else:
|
||||
return module(x)
|
||||
|
||||
@staticmethod
|
||||
def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]:
|
||||
if isinstance(x, VarLenTensor):
|
||||
return x.reshape(*shape)
|
||||
else:
|
||||
return x.reshape(*x.shape[:2], *shape)
|
||||
|
||||
def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]:
|
||||
if isinstance(x, VarLenTensor):
|
||||
x_feats = x.feats.unsqueeze(0)
|
||||
else:
|
||||
x_feats = x
|
||||
x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1)
|
||||
return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats
|
||||
|
||||
def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor:
|
||||
if self._type == "self":
|
||||
dtype = next(self.to_qkv.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
qkv = self._linear(self.to_qkv, x)
|
||||
qkv = self._fused_pre(qkv, num_fused=3)
|
||||
if self.qk_rms_norm or self.use_rope:
|
||||
q, k, v = qkv.unbind(dim=-3)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
if self.use_rope:
|
||||
q, k = self.rope(q, k)
|
||||
qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1))
|
||||
if self.attn_mode == "full":
|
||||
h = sparse_scaled_dot_product_attention(qkv)
|
||||
elif self.attn_mode == "windowed":
|
||||
h = sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv, self.window_size, shift_window=self.shift_window
|
||||
)
|
||||
elif self.attn_mode == "double_windowed":
|
||||
qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:])
|
||||
qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2])
|
||||
h0 = sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv0, self.window_size, shift_window=(0, 0, 0)
|
||||
)
|
||||
h1 = sparse_windowed_scaled_dot_product_self_attention(
|
||||
qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3)
|
||||
)
|
||||
h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1))
|
||||
else:
|
||||
q = self._linear(self.to_q, x)
|
||||
q = self._reshape_chs(q, (self.num_heads, -1))
|
||||
dtype = next(self.to_kv.parameters()).dtype
|
||||
context = context.to(dtype)
|
||||
kv = self._linear(self.to_kv, context)
|
||||
kv = self._fused_pre(kv, num_fused=2)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k, v = kv.unbind(dim=-3)
|
||||
k = self.k_rms_norm(k)
|
||||
h = sparse_scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
h = sparse_scaled_dot_product_attention(q, kv)
|
||||
h = self._reshape_chs(h, (-1,))
|
||||
h = self._linear(self.to_out, h)
|
||||
return h
|
||||
|
||||
class ModulatedSparseTransformerCrossBlock(nn.Module):
|
||||
"""
|
||||
Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
ctx_channels: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
attn_mode: Literal["full", "swin"] = "full",
|
||||
window_size: Optional[int] = None,
|
||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||
use_checkpoint: bool = False,
|
||||
use_rope: bool = False,
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
qk_rms_norm: bool = False,
|
||||
qk_rms_norm_cross: bool = False,
|
||||
qkv_bias: bool = True,
|
||||
share_mod: bool = False,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.self_attn = SparseMultiHeadAttention(
|
||||
channels,
|
||||
num_heads=num_heads,
|
||||
type="self",
|
||||
attn_mode=attn_mode,
|
||||
window_size=window_size,
|
||||
shift_window=shift_window,
|
||||
qkv_bias=qkv_bias,
|
||||
use_rope=use_rope,
|
||||
rope_freq=rope_freq,
|
||||
qk_rms_norm=qk_rms_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.cross_attn = SparseMultiHeadAttention(
|
||||
channels,
|
||||
ctx_channels=ctx_channels,
|
||||
num_heads=num_heads,
|
||||
type="cross",
|
||||
attn_mode="full",
|
||||
qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.mlp = SparseFeedForwardNet(
|
||||
channels,
|
||||
mlp_ratio=mlp_ratio,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype)
|
||||
)
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
|
||||
def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||
h = x.replace(self.norm1(x.feats))
|
||||
h = h * (1 + scale_msa) + shift_msa
|
||||
h = self.self_attn(h)
|
||||
h = h * gate_msa
|
||||
x = x + h
|
||||
h = x.replace(self.norm2(x.feats))
|
||||
h = self.cross_attn(h, context)
|
||||
x = x + h
|
||||
h = x.replace(self.norm3(x.feats))
|
||||
h = h * (1 + scale_mlp) + shift_mlp
|
||||
h = self.mlp(h)
|
||||
h = h * gate_mlp
|
||||
x = x + h
|
||||
return x
|
||||
|
||||
def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor:
|
||||
return self._forward(x, mod, context)
|
||||
|
||||
|
||||
class SLatFlowModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
cond_channels: int,
|
||||
out_channels: int,
|
||||
num_blocks: int,
|
||||
num_heads: Optional[int] = None,
|
||||
num_head_channels: Optional[int] = 64,
|
||||
mlp_ratio: float = 4,
|
||||
pe_mode: Literal["ape", "rope"] = "rope",
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
use_checkpoint: bool = False,
|
||||
share_mod: bool = False,
|
||||
initialization: str = 'vanilla',
|
||||
qk_rms_norm: bool = False,
|
||||
qk_rms_norm_cross: bool = False,
|
||||
dtype = None,
|
||||
device = None,
|
||||
operations = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.pe_mode = pe_mode
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.initialization = initialization
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.dtype = dtype
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
self.input_layer = SparseLinear(in_channels, model_channels, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
ModulatedSparseTransformerCrossBlock(
|
||||
model_channels,
|
||||
cond_channels,
|
||||
num_heads=self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
attn_mode='full',
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
use_rope=(pe_mode == "rope"),
|
||||
rope_freq=rope_freq,
|
||||
share_mod=self.share_mod,
|
||||
qk_rms_norm=self.qk_rms_norm,
|
||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
self.out_layer = SparseLinear(model_channels, out_channels, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return next(self.parameters()).device
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: SparseTensor,
|
||||
t: torch.Tensor,
|
||||
cond: Union[torch.Tensor, List[torch.Tensor]],
|
||||
concat_cond: Optional[SparseTensor] = None,
|
||||
**kwargs
|
||||
) -> SparseTensor:
|
||||
if concat_cond is not None:
|
||||
x = sparse_cat([x, concat_cond], dim=-1)
|
||||
if isinstance(cond, list):
|
||||
cond = VarLenTensor.from_tensor_list(cond)
|
||||
|
||||
dtype = next(self.input_layer.parameters()).dtype
|
||||
x = x.to(dtype)
|
||||
h = self.input_layer(x)
|
||||
h = manual_cast(h, self.dtype)
|
||||
t = t.to(dtype)
|
||||
t_embedder = self.t_embedder.to(dtype)
|
||||
t_emb = t_embedder(t, out_dtype = t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
t_emb = manual_cast(t_emb, self.dtype)
|
||||
cond = manual_cast(cond, self.dtype)
|
||||
|
||||
for block in self.blocks:
|
||||
h = block(h, t_emb, cond)
|
||||
|
||||
h = manual_cast(h, x.dtype)
|
||||
h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:]))
|
||||
h = self.out_layer(h)
|
||||
return h
|
||||
|
||||
class FeedForwardNet(nn.Module):
|
||||
def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(channels, int(channels * mlp_ratio), device=device, dtype=dtype),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(int(channels * mlp_ratio), channels, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.mlp(x)
|
||||
|
||||
class MultiHeadRMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, heads: int, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.scale = dim ** 0.5
|
||||
self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype)
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
num_heads: int,
|
||||
ctx_channels: Optional[int]=None,
|
||||
type: Literal["self", "cross"] = "self",
|
||||
attn_mode: Literal["full", "windowed"] = "full",
|
||||
window_size: Optional[int] = None,
|
||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||
qkv_bias: bool = True,
|
||||
use_rope: bool = False,
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
qk_rms_norm: bool = False,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.channels = channels
|
||||
self.head_dim = channels // num_heads
|
||||
self.ctx_channels = ctx_channels if ctx_channels is not None else channels
|
||||
self.num_heads = num_heads
|
||||
self._type = type
|
||||
self.attn_mode = attn_mode
|
||||
self.window_size = window_size
|
||||
self.shift_window = shift_window
|
||||
self.use_rope = use_rope
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
|
||||
if self._type == "self":
|
||||
self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype)
|
||||
self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype)
|
||||
|
||||
if self.qk_rms_norm:
|
||||
self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype)
|
||||
|
||||
self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
B, L, C = x.shape
|
||||
if self._type == "self":
|
||||
x = x.to(next(self.to_qkv.parameters()).dtype)
|
||||
qkv = self.to_qkv(x)
|
||||
qkv = qkv.reshape(B, L, 3, self.num_heads, -1)
|
||||
|
||||
if self.attn_mode == "full":
|
||||
if self.qk_rms_norm or self.use_rope:
|
||||
q, k, v = qkv.unbind(dim=2)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k = self.k_rms_norm(k)
|
||||
if self.use_rope:
|
||||
assert phases is not None, "Phases must be provided for RoPE"
|
||||
q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases)
|
||||
k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases)
|
||||
h = scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
h = scaled_dot_product_attention(qkv)
|
||||
else:
|
||||
Lkv = context.shape[1]
|
||||
q = self.to_q(x)
|
||||
context = context.to(next(self.to_kv.parameters()).dtype)
|
||||
kv = self.to_kv(context)
|
||||
q = q.reshape(B, L, self.num_heads, -1)
|
||||
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
|
||||
if self.qk_rms_norm:
|
||||
q = self.q_rms_norm(q)
|
||||
k, v = kv.unbind(dim=2)
|
||||
k = self.k_rms_norm(k)
|
||||
h = scaled_dot_product_attention(q, k, v)
|
||||
else:
|
||||
h = scaled_dot_product_attention(q, kv)
|
||||
h = h.reshape(B, L, -1)
|
||||
h = self.to_out(h)
|
||||
return h
|
||||
|
||||
class ModulatedTransformerCrossBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
ctx_channels: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
attn_mode: Literal["full", "windowed"] = "full",
|
||||
window_size: Optional[int] = None,
|
||||
shift_window: Optional[Tuple[int, int, int]] = None,
|
||||
use_checkpoint: bool = False,
|
||||
use_rope: bool = False,
|
||||
rope_freq: Tuple[int, int] = (1.0, 10000.0),
|
||||
qk_rms_norm: bool = False,
|
||||
qk_rms_norm_cross: bool = False,
|
||||
qkv_bias: bool = True,
|
||||
share_mod: bool = False,
|
||||
device=None, dtype=None, operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device)
|
||||
self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device)
|
||||
self.self_attn = MultiHeadAttention(
|
||||
channels,
|
||||
num_heads=num_heads,
|
||||
type="self",
|
||||
attn_mode=attn_mode,
|
||||
window_size=window_size,
|
||||
shift_window=shift_window,
|
||||
qkv_bias=qkv_bias,
|
||||
use_rope=use_rope,
|
||||
rope_freq=rope_freq,
|
||||
qk_rms_norm=qk_rms_norm,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.cross_attn = MultiHeadAttention(
|
||||
channels,
|
||||
ctx_channels=ctx_channels,
|
||||
num_heads=num_heads,
|
||||
type="cross",
|
||||
attn_mode="full",
|
||||
qkv_bias=qkv_bias,
|
||||
qk_rms_norm=qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
self.mlp = FeedForwardNet(
|
||||
channels,
|
||||
mlp_ratio=mlp_ratio,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
if not share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
else:
|
||||
self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5)
|
||||
|
||||
def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if self.share_mod:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1)
|
||||
h = self.norm1(x)
|
||||
h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
|
||||
h = self.self_attn(h, phases=phases)
|
||||
h = h * gate_msa.unsqueeze(1)
|
||||
x = x + h
|
||||
h = self.norm2(x)
|
||||
h = self.cross_attn(h, context)
|
||||
x = x + h
|
||||
h = self.norm3(x)
|
||||
h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
||||
h = self.mlp(h)
|
||||
h = h * gate_mlp.unsqueeze(1)
|
||||
x = x + h
|
||||
return x
|
||||
|
||||
def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return self._forward(x, mod, context, phases)
|
||||
|
||||
|
||||
class SparseStructureFlowModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
resolution: int,
|
||||
in_channels: int,
|
||||
model_channels: int,
|
||||
cond_channels: int,
|
||||
out_channels: int,
|
||||
num_blocks: int,
|
||||
num_heads: Optional[int] = None,
|
||||
num_head_channels: Optional[int] = 64,
|
||||
mlp_ratio: float = 4,
|
||||
pe_mode: Literal["ape", "rope"] = "rope",
|
||||
rope_freq: Tuple[float, float] = (1.0, 10000.0),
|
||||
use_checkpoint: bool = False,
|
||||
share_mod: bool = False,
|
||||
initialization: str = 'vanilla',
|
||||
qk_rms_norm: bool = False,
|
||||
qk_rms_norm_cross: bool = False,
|
||||
operations=None,
|
||||
device = None,
|
||||
dtype = torch.float32,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.cond_channels = cond_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads or model_channels // num_head_channels
|
||||
self.mlp_ratio = mlp_ratio
|
||||
self.pe_mode = pe_mode
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.share_mod = share_mod
|
||||
self.initialization = initialization
|
||||
self.qk_rms_norm = qk_rms_norm
|
||||
self.qk_rms_norm_cross = qk_rms_norm_cross
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.t_embedder = TimestepEmbedder(model_channels, dtype=dtype, device=device, operations=operations)
|
||||
if share_mod:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype)
|
||||
)
|
||||
|
||||
pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3, device=device)
|
||||
coords = torch.meshgrid(*[torch.arange(res, device=self.device, dtype=dtype) for res in [resolution] * 3], indexing='ij')
|
||||
coords = torch.stack(coords, dim=-1).reshape(-1, 3)
|
||||
rope_phases = pos_embedder(coords)
|
||||
self.register_buffer("rope_phases", rope_phases, persistent=False)
|
||||
|
||||
if pe_mode != "rope":
|
||||
self.rope_phases = None
|
||||
|
||||
self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
ModulatedTransformerCrossBlock(
|
||||
model_channels,
|
||||
cond_channels,
|
||||
num_heads=self.num_heads,
|
||||
mlp_ratio=self.mlp_ratio,
|
||||
attn_mode='full',
|
||||
use_checkpoint=self.use_checkpoint,
|
||||
use_rope=(pe_mode == "rope"),
|
||||
rope_freq=rope_freq,
|
||||
share_mod=share_mod,
|
||||
qk_rms_norm=self.qk_rms_norm,
|
||||
qk_rms_norm_cross=self.qk_rms_norm_cross,
|
||||
device=device, dtype=dtype, operations=operations
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
||||
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
|
||||
|
||||
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
|
||||
|
||||
h = h.to(next(self.input_layer.parameters()).dtype)
|
||||
h = self.input_layer(h)
|
||||
t_emb = self.t_embedder(t, out_dtype = t.dtype)
|
||||
if self.share_mod:
|
||||
t_emb = self.adaLN_modulation(t_emb)
|
||||
t_emb = manual_cast(t_emb, self.dtype)
|
||||
h = manual_cast(h, self.dtype)
|
||||
cond = manual_cast(cond, self.dtype)
|
||||
for block in self.blocks:
|
||||
h = block(h, t_emb, cond, self.rope_phases)
|
||||
h = manual_cast(h, x.dtype)
|
||||
h = F.layer_norm(h, h.shape[-1:])
|
||||
h = h.to(next(self.out_layer.parameters()).dtype)
|
||||
h = self.out_layer(h)
|
||||
|
||||
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
|
||||
|
||||
return h
|
||||
|
||||
def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0):
|
||||
t_shifted = t_shifted / 1000.0
|
||||
t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1))
|
||||
t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear)
|
||||
t_new *= 1000.0
|
||||
return t_new
|
||||
|
||||
class Trellis2(nn.Module):
|
||||
def __init__(self, resolution,
|
||||
in_channels = 32,
|
||||
out_channels = 32,
|
||||
model_channels = 1536,
|
||||
cond_channels = 1024,
|
||||
num_blocks = 30,
|
||||
num_heads = 12,
|
||||
mlp_ratio = 5.3334,
|
||||
share_mod = True,
|
||||
qk_rms_norm = True,
|
||||
qk_rms_norm_cross = True,
|
||||
init_txt_model=False, # for now
|
||||
dtype=None, device=None, operations=None, **kwargs):
|
||||
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operations = operations or nn
|
||||
# for some reason it passes num_heads = -1
|
||||
if num_heads == -1:
|
||||
num_heads = 12
|
||||
args = {
|
||||
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
|
||||
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
|
||||
"qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations
|
||||
}
|
||||
self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args)
|
||||
self.shape2txt = None
|
||||
if init_txt_model:
|
||||
self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args)
|
||||
self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args)
|
||||
args.pop("out_channels")
|
||||
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
|
||||
self.guidance_interval = [0.6, 1.0]
|
||||
self.guidance_interval_txt = [0.6, 0.9]
|
||||
|
||||
def forward(self, x, timestep, context, **kwargs):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
embeds = kwargs.get("embeds")
|
||||
if embeds is None:
|
||||
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
|
||||
is_1024 = self.img2shape.resolution == 1024
|
||||
coords = transformer_options.get("coords", None)
|
||||
mode = transformer_options.get("generation_mode", "structure_generation")
|
||||
is_512_run = False
|
||||
timestep = timestep.to(self.dtype)
|
||||
if mode == "shape_generation_512":
|
||||
is_512_run = True
|
||||
mode = "shape_generation"
|
||||
if coords is not None:
|
||||
x = x.squeeze(-1).transpose(1, 2)
|
||||
not_struct_mode = True
|
||||
else:
|
||||
mode = "structure_generation"
|
||||
not_struct_mode = False
|
||||
|
||||
if is_1024 and not_struct_mode and not is_512_run:
|
||||
context = embeds
|
||||
|
||||
sigmas = transformer_options.get("sigmas")[0].item()
|
||||
if sigmas < 1.00001:
|
||||
timestep *= 1000.0
|
||||
if context.size(0) > 1:
|
||||
cond = context.chunk(2)[1]
|
||||
else:
|
||||
cond = context
|
||||
shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1]
|
||||
txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1]
|
||||
|
||||
if not_struct_mode:
|
||||
orig_bsz = x.shape[0]
|
||||
rule = txt_rule if mode == "texture_generation" else shape_rule
|
||||
|
||||
if rule and orig_bsz > 1:
|
||||
x_eval = x[1].unsqueeze(0)
|
||||
t_eval = timestep[1].unsqueeze(0) if timestep.shape[0] > 1 else timestep
|
||||
c_eval = cond
|
||||
else:
|
||||
x_eval = x
|
||||
t_eval = timestep
|
||||
c_eval = context
|
||||
|
||||
B, N, C = x_eval.shape
|
||||
|
||||
if mode in ["shape_generation", "texture_generation"]:
|
||||
feats_flat = x_eval.reshape(-1, C)
|
||||
|
||||
# inflate coords [N, 4] -> [B*N, 4]
|
||||
coords_list = []
|
||||
for i in range(B):
|
||||
c = coords.clone()
|
||||
c[:, 0] = i
|
||||
coords_list.append(c)
|
||||
|
||||
batched_coords = torch.cat(coords_list, dim=0)
|
||||
else:
|
||||
batched_coords = coords
|
||||
feats_flat = x_eval
|
||||
|
||||
x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32))
|
||||
|
||||
if mode == "shape_generation":
|
||||
if is_512_run:
|
||||
out = self.img2shape_512(x_st, t_eval, c_eval)
|
||||
else:
|
||||
out = self.img2shape(x_st, t_eval, c_eval)
|
||||
elif mode == "texture_generation":
|
||||
if self.shape2txt is None:
|
||||
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
|
||||
slat = transformer_options.get("shape_slat")
|
||||
if slat is None:
|
||||
raise ValueError("shape_slat can't be None")
|
||||
|
||||
base_slat_feats = slat[:N]
|
||||
slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device)
|
||||
x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1))
|
||||
out = self.shape2txt(x_st, t_eval, c_eval)
|
||||
else: # structure
|
||||
orig_bsz = x.shape[0]
|
||||
if shape_rule:
|
||||
x = x[1].unsqueeze(0)
|
||||
timestep = timestep[1].unsqueeze(0)
|
||||
out = self.structure_model(x, timestep, context if not shape_rule else cond)
|
||||
if shape_rule:
|
||||
out = out.repeat(orig_bsz, 1, 1, 1, 1)
|
||||
|
||||
if not_struct_mode:
|
||||
out = out.feats
|
||||
out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1)
|
||||
if rule and orig_bsz > 1:
|
||||
out = out.repeat(orig_bsz, 1, 1, 1)
|
||||
return out
|
||||
1444
comfy/ldm/trellis2/vae.py
Normal file
1444
comfy/ldm/trellis2/vae.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -51,6 +51,7 @@ import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
import comfy.ldm.anima.model
|
||||
import comfy.ldm.trellis2.model
|
||||
import comfy.ldm.ace.ace_step15
|
||||
import comfy.ldm.rt_detr.rtdetr_v4
|
||||
import comfy.ldm.ernie.model
|
||||
@ -1537,6 +1538,16 @@ class WAN22(WAN21):
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
return latent_image
|
||||
|
||||
class Trellis2(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2):
|
||||
super().__init__(model_config, model_type, device, unet_model)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
embeds = kwargs.get("embeds")
|
||||
out["embeds"] = comfy.conds.CONDRegular(embeds)
|
||||
return out
|
||||
|
||||
class WAN21_FlowRVS(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
|
||||
model_config.unet_config["model_type"] = "t2v"
|
||||
@ -1578,7 +1589,6 @@ class WAN21_SCAIL(WAN21):
|
||||
pose_latents = kwargs.get("pose_video_latent", None)
|
||||
if pose_latents is not None:
|
||||
out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
|
||||
|
||||
return out
|
||||
|
||||
class Hunyuan3Dv2(BaseModel):
|
||||
|
||||
@ -113,6 +113,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
|
||||
return unet_config
|
||||
|
||||
if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config = {}
|
||||
unet_config["image_model"] = "trellis2"
|
||||
|
||||
unet_config["init_txt_model"] = False
|
||||
if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
|
||||
unet_config["init_txt_model"] = True
|
||||
|
||||
unet_config["resolution"] = 64
|
||||
if metadata is not None:
|
||||
if "is_512" in metadata:
|
||||
unet_config["resolution"] = 32
|
||||
|
||||
unet_config["num_heads"] = 12
|
||||
return unet_config
|
||||
|
||||
if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
|
||||
unet_config = {}
|
||||
unet_config["audio_model"] = "dit1.0"
|
||||
|
||||
10
comfy/sd.py
10
comfy/sd.py
@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model
|
||||
import comfy.ldm.lightricks.vae.causal_video_autoencoder
|
||||
import comfy.ldm.cosmos.vae
|
||||
import comfy.ldm.wan.vae
|
||||
import comfy.ldm.trellis2.vae
|
||||
import comfy.ldm.wan.vae2_2
|
||||
import comfy.ldm.hunyuan3d.vae
|
||||
import comfy.ldm.ace.vae.music_dcae_pipeline
|
||||
@ -507,6 +508,15 @@ class VAE:
|
||||
self.first_stage_model = StageC_coder()
|
||||
self.downscale_ratio = 32
|
||||
self.latent_channels = 16
|
||||
elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2
|
||||
init_txt_model = False
|
||||
if "txt_dec.blocks.1.16.norm1.weight" in sd:
|
||||
init_txt_model = True
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
# TODO
|
||||
self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
|
||||
elif "decoder.conv_in.weight" in sd:
|
||||
if sd['decoder.conv_in.weight'].shape[1] == 64:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
|
||||
@ -1273,6 +1273,29 @@ class WAN22_T2V(WAN21_T2V):
|
||||
out = model_base.WAN22(self, image_to_video=True, device=device)
|
||||
return out
|
||||
|
||||
class Trellis2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "trellis2"
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 3.5
|
||||
|
||||
latent_format = latent_formats.Trellis2
|
||||
vae_key_prefix = ["vae."]
|
||||
clip_vision_prefix = "conditioner.main_image_encoder.model."
|
||||
# this is only needed for the texture model
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.Trellis2(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
class WAN21_FlowRVS(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1293,6 +1316,7 @@ class WAN21_SCAIL(WAN21_T2V):
|
||||
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
|
||||
class Hunyuan3Dv2(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan3d2",
|
||||
@ -1664,6 +1688,7 @@ class Kandinsky5Image(Kandinsky5):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
|
||||
class ACEStep15(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"audio_model": "ace1.5",
|
||||
@ -1703,7 +1728,6 @@ class ACEStep15(supported_models_base.BASE):
|
||||
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
|
||||
|
||||
|
||||
class LongCatImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "flux",
|
||||
@ -1781,6 +1805,6 @@ class ErnieImage(supported_models_base.BASE):
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, Trellis2]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -484,7 +484,7 @@ class VoxelToMesh(IO.ComfyNode):
|
||||
decode = execute # TODO: remove
|
||||
|
||||
|
||||
def save_glb(vertices, faces, filepath, metadata=None):
|
||||
def save_glb(vertices, faces, filepath, metadata=None, colors=None):
|
||||
"""
|
||||
Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
|
||||
|
||||
@ -515,6 +515,13 @@ def save_glb(vertices, faces, filepath, metadata=None):
|
||||
indices_byte_length = len(indices_buffer)
|
||||
indices_byte_offset = len(vertices_buffer_padded)
|
||||
|
||||
if colors is not None:
|
||||
colors_np = colors.cpu().numpy().astype(np.float32)
|
||||
colors_buffer = colors_np.tobytes()
|
||||
colors_byte_length = len(colors_buffer)
|
||||
colors_byte_offset = len(buffer_data)
|
||||
buffer_data += pad_to_4_bytes(colors_buffer)
|
||||
|
||||
gltf = {
|
||||
"asset": {"version": "2.0", "generator": "ComfyUI"},
|
||||
"buffers": [
|
||||
@ -580,6 +587,14 @@ def save_glb(vertices, faces, filepath, metadata=None):
|
||||
"scene": 0
|
||||
}
|
||||
|
||||
if colors is not None:
|
||||
gltf["bufferViews"].append({"buffer": 0, "byteOffset": colors_byte_offset, "byteLength": colors_byte_length, "target": 34962})
|
||||
gltf["accessors"].append({"bufferView": 2, "byteOffset": 0, "componentType": 5126, "count": len(colors_np), "type": "VEC3"})
|
||||
gltf["meshes"][0]["primitives"][0]["attributes"]["COLOR_0"] = 2
|
||||
# Define a base material so Three.js actually activates vertex coloring
|
||||
gltf["materials"] =[{"pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0]}}]
|
||||
gltf["meshes"][0]["primitives"][0]["material"] = 0
|
||||
|
||||
if metadata is not None:
|
||||
gltf["asset"]["extras"] = metadata
|
||||
|
||||
@ -669,7 +684,8 @@ class SaveGLB(IO.ComfyNode):
|
||||
# Handle Mesh input - save vertices and faces as GLB
|
||||
for i in range(mesh.vertices.shape[0]):
|
||||
f = f"{filename}_{counter:05}_.glb"
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
|
||||
v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None
|
||||
save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors)
|
||||
results.append({
|
||||
"filename": f,
|
||||
"subfolder": subfolder,
|
||||
|
||||
693
comfy_extras/nodes_trellis2.py
Normal file
693
comfy_extras/nodes_trellis2.py
Normal file
@ -0,0 +1,693 @@
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, Types
|
||||
from comfy.ldm.trellis2.vae import SparseTensor
|
||||
import comfy.model_management
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import torch
|
||||
import scipy
|
||||
import copy
|
||||
|
||||
shape_slat_normalization = {
|
||||
"mean": torch.tensor([
|
||||
0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218,
|
||||
-0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944,
|
||||
0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667,
|
||||
-0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665
|
||||
])[None],
|
||||
"std": torch.tensor([
|
||||
5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004,
|
||||
5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578,
|
||||
4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194,
|
||||
5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691
|
||||
])[None]
|
||||
}
|
||||
|
||||
tex_slat_normalization = {
|
||||
"mean": torch.tensor([
|
||||
3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075,
|
||||
0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149,
|
||||
-1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717,
|
||||
1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862
|
||||
])[None],
|
||||
"std": torch.tensor([
|
||||
2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822,
|
||||
2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588,
|
||||
2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999,
|
||||
2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190
|
||||
])[None]
|
||||
}
|
||||
|
||||
def shape_norm(shape_latent, coords):
|
||||
std = shape_slat_normalization["std"].to(shape_latent)
|
||||
mean = shape_slat_normalization["mean"].to(shape_latent)
|
||||
samples = SparseTensor(feats = shape_latent, coords=coords)
|
||||
samples = samples * std + mean
|
||||
return samples
|
||||
|
||||
def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution):
|
||||
"""
|
||||
Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field.
|
||||
"""
|
||||
device = comfy.model_management.vae_offload_device()
|
||||
|
||||
origin = torch.tensor([-0.5, -0.5, -0.5], device=device)
|
||||
# TODO: generic independent node? if so: figure how pass the resolution parameter
|
||||
voxel_size = 1.0 / resolution
|
||||
|
||||
# map voxels
|
||||
voxel_pos = voxel_coords.to(device).float() * voxel_size + origin
|
||||
verts = mesh.vertices.to(device).squeeze(0)
|
||||
voxel_colors = voxel_colors.to(device)
|
||||
|
||||
voxel_pos_np = voxel_pos.numpy()
|
||||
verts_np = verts.numpy()
|
||||
|
||||
tree = scipy.spatial.cKDTree(voxel_pos_np)
|
||||
|
||||
# nearest neighbour k=1
|
||||
_, nearest_idx_np = tree.query(verts_np, k=1, workers=-1)
|
||||
|
||||
nearest_idx = torch.from_numpy(nearest_idx_np).long()
|
||||
v_colors = voxel_colors[nearest_idx]
|
||||
|
||||
# to [0, 1]
|
||||
srgb_colors = v_colors.clamp(0, 1)#(v_colors * 0.5 + 0.5).clamp(0, 1)
|
||||
|
||||
# to Linear RGB (required for GLTF)
|
||||
linear_colors = torch.pow(srgb_colors, 2.2)
|
||||
|
||||
final_colors = linear_colors.unsqueeze(0)
|
||||
|
||||
out_mesh = copy.deepcopy(mesh)
|
||||
out_mesh.colors = final_colors
|
||||
|
||||
return out_mesh
|
||||
|
||||
class VaeDecodeShapeTrellis(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VaeDecodeShapeTrellis",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Combo.Input("resolution", options=["512", "1024"], default="1024")
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("mesh"),
|
||||
IO.AnyType.Output("shape_subs"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae, resolution):
|
||||
|
||||
resolution = int(resolution)
|
||||
patcher = vae.patcher
|
||||
device = comfy.model_management.get_torch_device()
|
||||
comfy.model_management.load_model_gpu(patcher)
|
||||
|
||||
vae = vae.first_stage_model
|
||||
coords = samples["coords"]
|
||||
|
||||
samples = samples["samples"]
|
||||
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||
samples = shape_norm(samples, coords)
|
||||
|
||||
mesh, subs = vae.decode_shape_slat(samples, resolution)
|
||||
faces = torch.stack([m.faces for m in mesh])
|
||||
verts = torch.stack([m.vertices for m in mesh])
|
||||
mesh = Types.MESH(vertices=verts, faces=faces)
|
||||
return IO.NodeOutput(mesh, subs)
|
||||
|
||||
class VaeDecodeTextureTrellis(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VaeDecodeTextureTrellis",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Mesh.Input("shape_mesh"),
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.AnyType.Input("shape_subs"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("mesh"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, shape_mesh, samples, vae, shape_subs):
|
||||
|
||||
resolution = 1024
|
||||
patcher = vae.patcher
|
||||
device = comfy.model_management.get_torch_device()
|
||||
comfy.model_management.load_model_gpu(patcher)
|
||||
|
||||
vae = vae.first_stage_model
|
||||
coords = samples["coords"]
|
||||
|
||||
samples = samples["samples"]
|
||||
samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||
std = tex_slat_normalization["std"].to(samples)
|
||||
mean = tex_slat_normalization["mean"].to(samples)
|
||||
samples = SparseTensor(feats = samples, coords=coords)
|
||||
samples = samples * std + mean
|
||||
|
||||
voxel = vae.decode_tex_slat(samples, shape_subs)
|
||||
color_feats = voxel.feats[:, :3]
|
||||
voxel_coords = voxel.coords[:, 1:]
|
||||
|
||||
out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution)
|
||||
return IO.NodeOutput(out_mesh)
|
||||
|
||||
class VaeDecodeStructureTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VaeDecodeStructureTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Combo.Input("resolution", options=["32", "64"], default="32")
|
||||
],
|
||||
outputs=[
|
||||
IO.Voxel.Output("structure_output"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, samples, vae, resolution):
|
||||
resolution = int(resolution)
|
||||
vae = vae.first_stage_model
|
||||
decoder = vae.struct_dec
|
||||
load_device = comfy.model_management.get_torch_device()
|
||||
offload_device = comfy.model_management.vae_offload_device()
|
||||
decoder = decoder.to(load_device)
|
||||
samples = samples["samples"]
|
||||
samples = samples.to(load_device)
|
||||
decoded = decoder(samples)>0
|
||||
decoder.to(offload_device)
|
||||
current_res = decoded.shape[2]
|
||||
|
||||
if current_res != resolution:
|
||||
ratio = current_res // resolution
|
||||
decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5
|
||||
out = Types.VOXEL(decoded.squeeze(1).float())
|
||||
return IO.NodeOutput(out)
|
||||
|
||||
class Trellis2UpsampleCascade(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Trellis2UpsampleCascade",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Latent.Input("shape_latent_512"),
|
||||
IO.Vae.Input("vae"),
|
||||
IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024"),
|
||||
IO.Int.Input("max_tokens", default=49152, min=1024, max=100000)
|
||||
],
|
||||
outputs=[
|
||||
IO.AnyType.Output("hr_coords"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, shape_latent_512, vae, target_resolution, max_tokens):
|
||||
device = comfy.model_management.get_torch_device()
|
||||
comfy.model_management.load_model_gpu(vae.patcher)
|
||||
|
||||
feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device)
|
||||
coords_512 = shape_latent_512["coords"].to(device)
|
||||
|
||||
slat = shape_norm(feats, coords_512)
|
||||
|
||||
decoder = vae.first_stage_model.shape_dec
|
||||
|
||||
slat.feats = slat.feats.to(next(decoder.parameters()).dtype)
|
||||
hr_coords = decoder.upsample(slat, upsample_times=4)
|
||||
|
||||
lr_resolution = 512
|
||||
hr_resolution = int(target_resolution)
|
||||
|
||||
while True:
|
||||
quant_coords = torch.cat([
|
||||
hr_coords[:, :1],
|
||||
((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(),
|
||||
], dim=1)
|
||||
final_coords = quant_coords.unique(dim=0)
|
||||
num_tokens = final_coords.shape[0]
|
||||
|
||||
if num_tokens < max_tokens or hr_resolution <= 1024:
|
||||
break
|
||||
hr_resolution -= 128
|
||||
|
||||
return IO.NodeOutput(final_coords,)
|
||||
|
||||
dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
||||
dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
||||
|
||||
def run_conditioning(model, cropped_img_tensor, include_1024=True):
|
||||
model_internal = model.model
|
||||
device = comfy.model_management.intermediate_device()
|
||||
torch_device = comfy.model_management.get_torch_device()
|
||||
|
||||
def prepare_tensor(pil_img, size):
|
||||
resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS)
|
||||
img_np = np.array(resized_pil).astype(np.float32) / 255.0
|
||||
img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device)
|
||||
return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device)
|
||||
|
||||
model_internal.image_size = 512
|
||||
input_512 = prepare_tensor(cropped_img_tensor, 512)
|
||||
cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0]
|
||||
|
||||
cond_1024 = None
|
||||
if include_1024:
|
||||
model_internal.image_size = 1024
|
||||
input_1024 = prepare_tensor(cropped_img_tensor, 1024)
|
||||
cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0]
|
||||
|
||||
conditioning = {
|
||||
'cond_512': cond_512.to(device),
|
||||
'neg_cond': torch.zeros_like(cond_512).to(device),
|
||||
}
|
||||
if cond_1024 is not None:
|
||||
conditioning['cond_1024'] = cond_1024.to(device)
|
||||
|
||||
return conditioning
|
||||
|
||||
class Trellis2Conditioning(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Trellis2Conditioning",
|
||||
category="conditioning/video_models",
|
||||
inputs=[
|
||||
IO.ClipVision.Input("clip_vision_model"),
|
||||
IO.Image.Input("image"),
|
||||
IO.Mask.Input("mask"),
|
||||
IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black")
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput:
|
||||
|
||||
if image.ndim == 4:
|
||||
image = image[0]
|
||||
if mask.ndim == 3:
|
||||
mask = mask[0]
|
||||
|
||||
img_np = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
mask_np = (mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
|
||||
|
||||
pil_img = Image.fromarray(img_np)
|
||||
pil_mask = Image.fromarray(mask_np)
|
||||
|
||||
max_size = max(pil_img.size)
|
||||
scale = min(1.0, 1024 / max_size)
|
||||
if scale < 1.0:
|
||||
new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale)
|
||||
pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||
pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST)
|
||||
|
||||
rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8)
|
||||
rgba_np[:, :, :3] = np.array(pil_img)
|
||||
rgba_np[:, :, 3] = np.array(pil_mask)
|
||||
|
||||
alpha = rgba_np[:, :, 3]
|
||||
bbox_coords = np.argwhere(alpha > 0.8 * 255)
|
||||
|
||||
if len(bbox_coords) > 0:
|
||||
y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1])
|
||||
y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1])
|
||||
|
||||
center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0
|
||||
size = max(y_max - y_min, x_max - x_min)
|
||||
|
||||
crop_x1 = int(center_x - size // 2)
|
||||
crop_y1 = int(center_y - size // 2)
|
||||
crop_x2 = int(center_x + size // 2)
|
||||
crop_y2 = int(center_y + size // 2)
|
||||
|
||||
rgba_pil = Image.fromarray(rgba_np)
|
||||
cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2))
|
||||
cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0
|
||||
else:
|
||||
import logging
|
||||
logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.")
|
||||
cropped_np = rgba_np.astype(np.float32) / 255.0
|
||||
|
||||
bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]}
|
||||
bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32)
|
||||
|
||||
fg = cropped_np[:, :, :3]
|
||||
alpha_float = cropped_np[:, :, 3:4]
|
||||
composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float)
|
||||
|
||||
# to match trellis2 code (quantize -> dequantize)
|
||||
composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8)
|
||||
|
||||
cropped_pil = Image.fromarray(composite_uint8)
|
||||
|
||||
conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True)
|
||||
|
||||
embeds = conditioning["cond_1024"]
|
||||
positive = [[conditioning["cond_512"], {"embeds": embeds}]]
|
||||
negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]]
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
class EmptyShapeLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyShapeLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.AnyType.Input("structure_or_coords"),
|
||||
IO.Model.Input("model")
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
IO.Model.Output()
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, structure_or_coords, model):
|
||||
# to accept the upscaled coords
|
||||
is_512_pass = False
|
||||
|
||||
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
|
||||
decoded = structure_or_coords.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
is_512_pass = True
|
||||
|
||||
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
|
||||
coords = structure_or_coords.int()
|
||||
is_512_pass = False
|
||||
|
||||
else:
|
||||
raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}")
|
||||
in_channels = 32
|
||||
# image like format
|
||||
latent = torch.randn(1, in_channels, coords.shape[0], 1)
|
||||
model = model.clone()
|
||||
model.model_options = model.model_options.copy()
|
||||
if "transformer_options" in model.model_options:
|
||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||
else:
|
||||
model.model_options["transformer_options"] = {}
|
||||
|
||||
model.model_options["transformer_options"]["coords"] = coords
|
||||
if is_512_pass:
|
||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512"
|
||||
else:
|
||||
model.model_options["transformer_options"]["generation_mode"] = "shape_generation"
|
||||
return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model)
|
||||
|
||||
class EmptyTextureLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyTextureLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Voxel.Input("structure_or_coords"),
|
||||
IO.Latent.Input("shape_latent"),
|
||||
IO.Model.Input("model")
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
IO.Model.Output()
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, structure_or_coords, shape_latent, model):
|
||||
channels = 32
|
||||
if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4:
|
||||
decoded = structure_or_coords.data.unsqueeze(1)
|
||||
coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
|
||||
|
||||
elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2:
|
||||
coords = structure_or_coords.int()
|
||||
|
||||
shape_latent = shape_latent["samples"]
|
||||
if shape_latent.ndim == 4:
|
||||
shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels)
|
||||
|
||||
latent = torch.randn(1, channels, coords.shape[0], 1)
|
||||
model = model.clone()
|
||||
model.model_options = model.model_options.copy()
|
||||
if "transformer_options" in model.model_options:
|
||||
model.model_options["transformer_options"] = model.model_options["transformer_options"].copy()
|
||||
else:
|
||||
model.model_options["transformer_options"] = {}
|
||||
|
||||
model.model_options["transformer_options"]["coords"] = coords
|
||||
model.model_options["transformer_options"]["generation_mode"] = "texture_generation"
|
||||
model.model_options["transformer_options"]["shape_slat"] = shape_latent
|
||||
return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model)
|
||||
|
||||
|
||||
class EmptyStructureLatentTrellis2(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyStructureLatentTrellis2",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
|
||||
],
|
||||
outputs=[
|
||||
IO.Latent.Output(),
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, batch_size):
|
||||
in_channels = 8
|
||||
resolution = 16
|
||||
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
|
||||
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
|
||||
|
||||
def simplify_fn(vertices, faces, colors=None, target=100000):
|
||||
if vertices.ndim == 3:
|
||||
v_list, f_list, c_list = [], [], []
|
||||
for i in range(vertices.shape[0]):
|
||||
c_in = colors[i] if colors is not None else None
|
||||
v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target)
|
||||
v_list.append(v_i)
|
||||
f_list.append(f_i)
|
||||
if c_i is not None:
|
||||
c_list.append(c_i)
|
||||
|
||||
c_out = torch.stack(c_list) if len(c_list) > 0 else None
|
||||
return torch.stack(v_list), torch.stack(f_list), c_out
|
||||
|
||||
if faces.shape[0] <= target:
|
||||
return vertices, faces, colors
|
||||
|
||||
device = vertices.device
|
||||
target_v = max(target / 4.0, 1.0)
|
||||
|
||||
min_v = vertices.min(dim=0)[0]
|
||||
max_v = vertices.max(dim=0)[0]
|
||||
extent = max_v - min_v
|
||||
|
||||
volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8)
|
||||
cell_size = (volume / target_v) ** (1/3.0)
|
||||
|
||||
quantized = ((vertices - min_v) / cell_size).round().long()
|
||||
unique_coords, inverse_indices = torch.unique(quantized, dim=0, return_inverse=True)
|
||||
num_cells = unique_coords.shape[0]
|
||||
|
||||
new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device)
|
||||
counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device)
|
||||
new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices)
|
||||
counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1]))
|
||||
new_vertices = new_vertices / counts.clamp(min=1)
|
||||
|
||||
new_colors = None
|
||||
if colors is not None:
|
||||
new_colors = torch.zeros((num_cells, colors.shape[1]), dtype=colors.dtype, device=device)
|
||||
new_colors.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, colors.shape[1]), colors)
|
||||
new_colors = new_colors / counts.clamp(min=1)
|
||||
|
||||
new_faces = inverse_indices[faces]
|
||||
valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \
|
||||
(new_faces[:, 1] != new_faces[:, 2]) & \
|
||||
(new_faces[:, 2] != new_faces[:, 0])
|
||||
new_faces = new_faces[valid_mask]
|
||||
|
||||
unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True)
|
||||
final_vertices = new_vertices[unique_face_indices]
|
||||
final_faces = inv_face.reshape(-1, 3)
|
||||
|
||||
# assign colors
|
||||
final_colors = new_colors[unique_face_indices] if new_colors is not None else None
|
||||
|
||||
return final_vertices, final_faces, final_colors
|
||||
|
||||
def fill_holes_fn(vertices, faces, max_perimeter=0.03):
|
||||
is_batched = vertices.ndim == 3
|
||||
if is_batched:
|
||||
v_list, f_list = [],[]
|
||||
for i in range(vertices.shape[0]):
|
||||
v_i, f_i = fill_holes_fn(vertices[i], faces[i], max_perimeter)
|
||||
v_list.append(v_i)
|
||||
f_list.append(f_i)
|
||||
return torch.stack(v_list), torch.stack(f_list)
|
||||
|
||||
device = vertices.device
|
||||
v = vertices
|
||||
f = faces
|
||||
|
||||
if f.numel() == 0:
|
||||
return v, f
|
||||
|
||||
edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0)
|
||||
edges_sorted, _ = torch.sort(edges, dim=1)
|
||||
|
||||
max_v = v.shape[0]
|
||||
packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long()
|
||||
|
||||
unique_packed, counts = torch.unique(packed_undirected, return_counts=True)
|
||||
boundary_packed = unique_packed[counts == 1]
|
||||
|
||||
if boundary_packed.numel() == 0:
|
||||
return v, f
|
||||
|
||||
packed_directed_sorted = edges[:, 0].min(edges[:, 1]).long() * max_v + edges[:, 0].max(edges[:, 1]).long()
|
||||
is_boundary = torch.isin(packed_directed_sorted, boundary_packed)
|
||||
b_edges = edges[is_boundary]
|
||||
|
||||
adj = {u.item(): v_idx.item() for u, v_idx in b_edges}
|
||||
|
||||
loops =[]
|
||||
visited = set()
|
||||
|
||||
for start_node in adj.keys():
|
||||
if start_node in visited:
|
||||
continue
|
||||
|
||||
curr = start_node
|
||||
loop = []
|
||||
|
||||
while curr not in visited:
|
||||
visited.add(curr)
|
||||
loop.append(curr)
|
||||
curr = adj.get(curr, -1)
|
||||
|
||||
if curr == -1:
|
||||
loop = []
|
||||
break
|
||||
if curr == start_node:
|
||||
loops.append(loop)
|
||||
break
|
||||
|
||||
new_verts =[]
|
||||
new_faces = []
|
||||
v_idx = v.shape[0]
|
||||
|
||||
for loop in loops:
|
||||
loop_t = torch.tensor(loop, device=device, dtype=torch.long)
|
||||
loop_v = v[loop_t]
|
||||
|
||||
diffs = loop_v - torch.roll(loop_v, shifts=-1, dims=0)
|
||||
perimeter = torch.norm(diffs, dim=1).sum().item()
|
||||
|
||||
if perimeter <= max_perimeter:
|
||||
new_verts.append(loop_v.mean(dim=0))
|
||||
|
||||
for i in range(len(loop)):
|
||||
new_faces.append([loop[(i + 1) % len(loop)], loop[i], v_idx])
|
||||
v_idx += 1
|
||||
|
||||
if new_verts:
|
||||
v = torch.cat([v, torch.stack(new_verts)], dim=0)
|
||||
f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0)
|
||||
|
||||
return v, f
|
||||
|
||||
|
||||
def make_double_sided(vertices, faces):
|
||||
is_batched = vertices.ndim == 3
|
||||
if is_batched:
|
||||
f_list = []
|
||||
for i in range(faces.shape[0]):
|
||||
f_inv = faces[i][:, [0, 2, 1]]
|
||||
f_list.append(torch.cat([faces[i], f_inv], dim=0))
|
||||
return vertices, torch.stack(f_list)
|
||||
|
||||
faces_inv = faces[:, [0, 2, 1]]
|
||||
return vertices, torch.cat([faces, faces_inv], dim=0)
|
||||
|
||||
class PostProcessMesh(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="PostProcessMesh",
|
||||
category="latent/3d",
|
||||
inputs=[
|
||||
IO.Mesh.Input("mesh"),
|
||||
IO.Int.Input("simplify", default=1_000_000, min=0, max=50_000_000),
|
||||
IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001)
|
||||
],
|
||||
outputs=[
|
||||
IO.Mesh.Output("output_mesh"),
|
||||
]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mesh, simplify, fill_holes_perimeter):
|
||||
# TODO: batched mode may break
|
||||
verts, faces = mesh.vertices, mesh.faces
|
||||
colors = None
|
||||
if hasattr(mesh, "colors"):
|
||||
colors = mesh.colors
|
||||
|
||||
actual_face_count = faces.shape[1] if faces.ndim == 3 else faces.shape[0]
|
||||
if fill_holes_perimeter > 0:
|
||||
verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter)
|
||||
|
||||
if simplify > 0 and actual_face_count > simplify:
|
||||
verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors)
|
||||
|
||||
verts, faces = make_double_sided(verts, faces)
|
||||
|
||||
mesh = type(mesh)(vertices=verts, faces=faces)
|
||||
mesh.vertices = verts
|
||||
mesh.faces = faces
|
||||
if colors is not None:
|
||||
mesh.colors = colors
|
||||
return IO.NodeOutput(mesh)
|
||||
|
||||
class Trellis2Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
Trellis2Conditioning,
|
||||
EmptyShapeLatentTrellis2,
|
||||
EmptyStructureLatentTrellis2,
|
||||
EmptyTextureLatentTrellis2,
|
||||
VaeDecodeTextureTrellis,
|
||||
VaeDecodeShapeTrellis,
|
||||
VaeDecodeStructureTrellis2,
|
||||
Trellis2UpsampleCascade,
|
||||
PostProcessMesh
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> Trellis2Extension:
|
||||
return Trellis2Extension()
|
||||
Loading…
Reference in New Issue
Block a user