diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 1691fca81..71f2200b7 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -9,6 +9,7 @@ import comfy.model_management import comfy.utils import comfy.clip_model import comfy.image_encoders.dino2 +import comfy.image_encoders.dino3 class Output: def __getitem__(self, key): @@ -23,6 +24,7 @@ IMAGE_ENCODERS = { "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection, "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection, "dinov2": comfy.image_encoders.dino2.Dinov2Model, + "dinov3": comfy.image_encoders.dino3.DINOv3ViTModel } class ClipVisionModel(): @@ -134,6 +136,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json") elif 'encoder.layer.23.layer_scale2.lambda1' in sd: json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json") + elif 'layer.9.attention.o_proj.bias' in sd: # dinov3 + json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino3_large.json") else: return None diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py new file mode 100644 index 000000000..145bd5490 --- /dev/null +++ b/comfy/image_encoders/dino3.py @@ -0,0 +1,285 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +import comfy.model_management +from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale + +class DINOv3ViTMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.GELU() + + def forward(self, x): + return self.down_proj(self.act_fn(self.up_proj(x))) + +def rotate_half(x): + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) +def apply_rotary_pos_emb(q, k, cos, sin, **kwargs): + num_tokens = q.shape[-2] + num_patches = sin.shape[-2] + num_prefix_tokens = num_tokens - num_patches + + q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2) + k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2) + + q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin) + k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin) + + q = torch.cat((q_prefix_tokens, q_patches), dim=-2) + k = torch.cat((k_prefix_tokens, k_patches), dim=-2) + + return q, k + +class DINOv3ViTAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, device, dtype, operations): + super().__init__() + self.embed_dim = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + + self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False + self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + + self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + attn = optimized_attention_for_device(query_states.device, mask=False) + + attn_output = attn( + query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True + ) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output + +class DINOv3ViTGatedMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.GELU() + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +def get_patches_center_coordinates( + num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + + coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device) + coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device) + coords_h = coords_h / num_patches_h + coords_w = coords_w / num_patches_w + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) + coords = coords.flatten(0, 1) + coords = 2.0 * coords - 1.0 + return coords + +class DINOv3ViTRopePositionEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, rope_theta, hidden_size, num_attention_heads, image_size, patch_size, device, dtype): + super().__init__() + self.base = rope_theta + self.head_dim = hidden_size // num_attention_heads + self.num_patches_h = image_size // patch_size + self.num_patches_w = image_size // patch_size + self.patch_size = patch_size + + inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + _, _, height, width = pixel_values.shape + num_patches_h = height // self.patch_size + num_patches_w = width // self.patch_size + + device = pixel_values.device + device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu" + with torch.amp.autocast(device_type = device_type, enabled=False): + patch_coords = get_patches_center_coordinates( + num_patches_h, num_patches_w, dtype=torch.float32, device=device + ) + + self.inv_freq = self.inv_freq.to(device) + angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] + angles = angles.flatten(1, 2) + angles = angles.tile(2) + + cos = torch.cos(angles) + sin = torch.sin(angles) + + dtype = pixel_values.dtype + return cos.to(dtype=dtype), sin.to(dtype=dtype) + + +class DINOv3ViTEmbeddings(nn.Module): + def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations): + super().__init__() + self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype)) + self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype)) + self.patch_embeddings = operations.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype + ) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None): + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embeddings.weight.dtype + + patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) + + if bool_masked_pos is not None: + mask_token = self.mask_token.to(patch_embeddings.dtype) + patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) + + cls_token = self.cls_token.expand(batch_size, -1, -1) + register_tokens = self.register_tokens.expand(batch_size, -1, -1) + device = patch_embeddings.device + cls_token = cls_token.to(device) + register_tokens = register_tokens.to(device) + embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) + + return embeddings + +class DINOv3ViTLayer(nn.Module): + + def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size, num_attention_heads, + device, dtype, operations): + super().__init__() + + self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) + self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) + self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) + + self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) + + if use_gated_mlp: + self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations) + else: + self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations) + self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = self.attention( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + hidden_states = self.layer_scale1(hidden_states) + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.layer_scale2(hidden_states) + hidden_states = hidden_states + residual + + return hidden_states + + +class DINOv3ViTModel(nn.Module): + def __init__(self, config, dtype, device, operations): + super().__init__() + use_bf16 = comfy.model_management.should_use_bf16(device, prioritize_performance=True) + if dtype == torch.float16 and use_bf16: + dtype = torch.bfloat16 + elif dtype == torch.float16 and not use_bf16: + dtype = torch.float32 + num_hidden_layers = config["num_hidden_layers"] + hidden_size = config["hidden_size"] + num_attention_heads = config["num_attention_heads"] + num_register_tokens = config["num_register_tokens"] + intermediate_size = config["intermediate_size"] + layer_norm_eps = config["layer_norm_eps"] + num_channels = config["num_channels"] + patch_size = config["patch_size"] + rope_theta = config["rope_theta"] + + self.embeddings = DINOv3ViTEmbeddings( + hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size, dtype=dtype, device=device, operations=operations + ) + self.rope_embeddings = DINOv3ViTRopePositionEmbedding( + rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device + ) + self.layer = nn.ModuleList( + [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, mlp_bias=True, + intermediate_size=intermediate_size,num_attention_heads = num_attention_heads, + dtype=dtype, device=device, operations=operations) + for _ in range(num_hidden_layers)]) + self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: torch.Tensor | None = None, + **kwargs, + ): + + pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + position_embeddings = self.rope_embeddings(pixel_values) + + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + + if kwargs.get("skip_norm_elementwise", False): + sequence_output= F.layer_norm(hidden_states, hidden_states.shape[-1:]) + else: + norm = self.norm.to(hidden_states.device) + sequence_output = norm(hidden_states) + pooled_output = sequence_output[:, 0, :] + + return sequence_output, None, pooled_output, None diff --git a/comfy/image_encoders/dino3_large.json b/comfy/image_encoders/dino3_large.json new file mode 100644 index 000000000..b37b61dc8 --- /dev/null +++ b/comfy/image_encoders/dino3_large.json @@ -0,0 +1,23 @@ +{ + "model_type": "dinov3", + "hidden_size": 1024, + "image_size": 224, + "initializer_range": 0.02, + "intermediate_size": 4096, + "key_bias": false, + "layer_norm_eps": 1e-05, + "mlp_bias": true, + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "num_register_tokens": 4, + "patch_size": 16, + "pos_embed_rescale": 2.0, + "proj_bias": true, + "query_bias": true, + "rope_theta": 100.0, + "use_gated_mlp": false, + "value_bias": true, + "image_mean": [0.485, 0.456, 0.406], + "image_std": [0.229, 0.224, 0.225] +} diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 6a57bca1c..8ea84fa05 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -746,6 +746,8 @@ class Hunyuan3Dv2_1(LatentFormat): latent_channels = 64 latent_dimensions = 1 +class Trellis2(LatentFormat): # TODO + latent_channels = 32 class Hunyuan3Dv2mini(LatentFormat): latent_channels = 64 latent_dimensions = 1 diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py new file mode 100644 index 000000000..d95b071b5 --- /dev/null +++ b/comfy/ldm/trellis2/attention.py @@ -0,0 +1,282 @@ +import torch +import math +from comfy.ldm.modules.attention import optimized_attention +from typing import Tuple, Union, List +from comfy.ldm.trellis2.vae import VarLenTensor +import comfy.ops + + +# replica of the seedvr2 code +def var_attn_arg(kwargs): + cu_seqlens_q = kwargs.get("cu_seqlens_q", None) + max_seqlen_q = kwargs.get("max_seqlen_q", None) + cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q) + max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q) + assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True" + return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k + +def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + var_length = True + if var_length: + cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs) + if not skip_reshape: + # assumes 2D q, k,v [total_tokens, embed_dim] + total_tokens, embed_dim = q.shape + head_dim = embed_dim // heads + q = q.view(total_tokens, heads, head_dim) + k = k.view(k.shape[0], heads, head_dim) + v = v.view(v.shape[0], heads, head_dim) + + b = q.size(0) + dim_head = q.shape[-1] + q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long()) + k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long()) + v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long()) + + mask = None + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + if mask is not None: + if mask.ndim == 2: + mask = mask.unsqueeze(0) + if mask.ndim == 3: + mask = mask.unsqueeze(1) + + out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + if var_length: + return out.transpose(1, 2).values() + if not skip_output_reshape: + out = ( + out.transpose(1, 2).reshape(b, -1, heads * dim_head) + ) + return out + +def scaled_dot_product_attention(*args, **kwargs): + num_all_args = len(args) + len(kwargs) + + q = None + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs.get('qkv') + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs.get('q') + kv = args[1] if len(args) > 1 else kwargs.get('kv') + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs.get('q') + k = args[1] if len(args) > 1 else kwargs.get('k') + v = args[2] if len(args) > 2 else kwargs.get('v') + + if q is not None: + heads = q.shape[2] + else: + heads = qkv.shape[3] + + if num_all_args == 1: + q, k, v = qkv.unbind(dim=2) + elif num_all_args == 2: + k, v = kv.unbind(dim=2) + + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs) + + out = out.permute(0, 2, 1, 3) + + return out + +def sparse_windowed_scaled_dot_product_self_attention( + qkv, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +): + + serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args)) + else: + fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + heads = qkv_feats.shape[2] + + if optimized_attention.__name__ == 'attention_xformers': + q, k, v = qkv_feats.unbind(dim=1) + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + #out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) + elif optimized_attention.__name__ == 'attention_flash': + if 'flash_attn' not in globals(): + import flash_attn + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C] + else: + out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True) + + out = out[bwd_indices] # [T, H, C] + + return qkv.replace(out) + +def calc_window_partition( + tensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0, +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)] + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + mask = seq_lens != 0 + seq_lens = seq_lens[mask] + + if optimized_attention.__name__ == 'attention_xformers': + if 'xops' not in globals(): + import xformers.ops as xops + attn_func_args = { + 'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + } + elif optimized_attention.__name__ == 'attention_flash': + attn_func_args = { + 'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(), + 'max_seqlen': torch.max(seq_lens) + } + + return fwd_indices, bwd_indices, seq_lens, attn_func_args + + +def sparse_scaled_dot_product_attention(*args, **kwargs): + q=None + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + device = q.device + + if isinstance(q, VarLenTensor): + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, VarLenTensor): + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + device = q.device + + if isinstance(q, VarLenTensor): + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, VarLenTensor): + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + # TODO: change + if q is not None: + heads = q + else: + heads = qkv + heads = heads.shape[2] + if optimized_attention.__name__ == 'attention_xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif optimized_attention.__name__ == 'attention_flash': + if 'flash_attn' not in globals(): + import flash_attn + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + + elif optimized_attention.__name__ == "attention_pytorch": + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + else: + cu_seqlens_kv = cu_seqlens_q + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + out = attention_pytorch(q, k, v, heads=heads,cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen), + skip_reshape=True, skip_output_reshape=True) + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py new file mode 100644 index 000000000..047e785ff --- /dev/null +++ b/comfy/ldm/trellis2/cumesh.py @@ -0,0 +1,433 @@ +# will contain every cuda -> pytorch operation + +import math +import torch +from typing import Callable +import logging + +NO_TRITON = False +try: + allow_tf32 = torch.cuda.is_tf32_supported() +except Exception: + allow_tf32 = False +try: + import triton + import triton.language as tl + heuristics = { + 'valid_kernel': lambda args: args['valid_kernel'](args['B1']), + 'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']), + } + + #@triton_autotune( + # configs=config.autotune_config, + # key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'], + #) + @triton.heuristics(heuristics) + @triton.jit + def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel( + input, + weight, + bias, + neighbor, + sorted_idx, + output, + # Tensor dimensions + N, LOGN, Ci, Co, V: tl.constexpr, + # Meta-parameters + B1: tl.constexpr, # Block size for N dimension + B2: tl.constexpr, # Block size for Co dimension + BK: tl.constexpr, # Block size for K dimension (V * Ci) + allow_tf32: tl.constexpr, # Allow TF32 precision for matmuls + # Huristic parameters + valid_kernel, + valid_kernel_seg, + ): + + block_id = tl.program_id(axis=0) + block_dim_co = tl.cdiv(Co, B2) + block_id_co = block_id % block_dim_co + block_id_n = block_id // block_dim_co + + # Create pointers for submatrices of A and B. + num_k = tl.cdiv(Ci, BK) # Number of blocks in K dimension + valid_kernel_start = tl.load(valid_kernel_seg + block_id_n) + valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start + offset_n = block_id_n * B1 + tl.arange(0, B1) + n_mask = offset_n < N + offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0) # (B1,) + offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co # (B2,) + offset_k = tl.arange(0, BK) # (BK,) + + # Create a block of the output matrix C. + accumulator = tl.zeros((B1, B2), dtype=tl.float32) + + # Iterate along V*Ci dimension. + for k in range(num_k * valid_kernel_seglen): + v = k // num_k + bk = k % num_k + v = tl.load(valid_kernel + valid_kernel_start + v) + # Calculate pointers to input matrix. + neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v) # (B1,) + input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :]) # (B1, BK) + # Calculate pointers to weight matrix. + weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None]) # (BK, B2) + # Load the next block of input and weight. + neigh_mask = neighbor_offset_n != 0xffffffff + k_mask = offset_k < Ci - bk * BK + input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0) + weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0) + # Accumulate along the K dimension. + accumulator = tl.dot(input_block, weight_block, accumulator, + input_precision='tf32' if allow_tf32 else 'ieee') # (B1, B2) + c = accumulator.to(input.type.element_ty) + + # add bias + if bias is not None: + bias_block = tl.load(bias + offset_co) + c += bias_block[None, :] + + # Write back the block of the output matrix with masks. + out_offset_n = offset_sorted_n + out_offset_co = block_id_co * B2 + tl.arange(0, B2) + out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :]) + out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co) + tl.store(out_ptr, c, mask=out_mask) + def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk( + input: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + neighbor: torch.Tensor, + sorted_idx: torch.Tensor, + valid_kernel: Callable[[int], torch.Tensor], + valid_kernel_seg: Callable[[int], torch.Tensor], + ) -> torch.Tensor: + N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1] + LOGN = int(math.log2(N)) + output = torch.empty((N, Co), device=input.device, dtype=input.dtype) + grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),) + sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid]( + input, weight, bias, neighbor, sorted_idx, output, + N, LOGN, Ci, Co, V, + B1=128, + B2=64, + BK=32, + valid_kernel=valid_kernel, + valid_kernel_seg=valid_kernel_seg, + allow_tf32=allow_tf32, + ) + return output +except Exception: + NO_TRITON = True + +def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device): + # offsets in same order as CUDA kernel + offsets = [] + for vx in range(Kw): + for vy in range(Kh): + for vz in range(Kd): + offsets.append(( + vx * Dw, + vy * Dh, + vz * Dd + )) + return torch.tensor(offsets, device=device) + +def build_submanifold_neighbor_map( + hashmap, + coords: torch.Tensor, + W, H, D, + Kw, Kh, Kd, + Dw, Dh, Dd, +): + device = coords.device + M = coords.shape[0] + V = Kw * Kh * Kd + half_V = V // 2 + 1 + + INVALID = hashmap.default_value + + neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long) + + b = coords[:, 0].long() + x = coords[:, 1].long() + y = coords[:, 2].long() + z = coords[:, 3].long() + + offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device) + + ox = x - (Kw // 2) * Dw + oy = y - (Kh // 2) * Dh + oz = z - (Kd // 2) * Dd + + for v in range(half_V): + if v == half_V - 1: + neighbor[:, v] = torch.arange(M, device=device) + continue + + dx, dy, dz = offsets[v] + + kx = ox + dx + ky = oy + dy + kz = oz + dz + + # Check spatial bounds + valid = ( + (kx >= 0) & (kx < W) & + (ky >= 0) & (ky < H) & + (kz >= 0) & (kz < D) + ) + + flat = ( + b[valid] * (W * H * D) + + kx[valid] * (H * D) + + ky[valid] * D + + kz[valid] + ) + + if flat.numel() > 0: + found = hashmap.lookup_flat(flat) + idx_in_M = torch.where(valid)[0] + neighbor[idx_in_M, v] = found + + valid_found_mask = (found != INVALID) + if valid_found_mask.any(): + src_points = idx_in_M[valid_found_mask] + dst_points = found[valid_found_mask] + neighbor[dst_points, V - 1 - v] = src_points + + return neighbor + +class TorchHashMap: + def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): + device = keys.device + # use long for searchsorted + self.sorted_keys, order = torch.sort(keys.to(torch.long)) + self.sorted_vals = values.to(torch.long)[order] + self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) + self._n = self.sorted_keys.numel() + + def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: + flat = flat_keys.to(torch.long) + if self._n == 0: + return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) + idx = torch.searchsorted(self.sorted_keys, flat) + idx_safe = torch.clamp(idx, max=self._n - 1) + found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat) + out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) + if found.any(): + out[found] = self.sorted_vals[idx_safe[found]] + return out + + +UINT32_SENTINEL = 0xFFFFFFFF + +def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map): + device = neighbor_map.device + N, V = neighbor_map.shape + + sentinel = UINT32_SENTINEL + + neigh_map_T = neighbor_map.t().reshape(-1) + neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32) + + mask = (neighbor_map != sentinel).to(torch.long) + gray_code = torch.zeros(N, dtype=torch.long, device=device) + + for v in range(V): + gray_code |= (mask[:, v] << v) + + binary_code = gray_code.clone() + for v in range(1, V): + binary_code ^= (gray_code >> v) + + sorted_idx = torch.argsort(binary_code) + + prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0) + + total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0 + + if total_valid_signal > 0: + pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0] + to = (prefix_sum_neighbor_mask[pos] - 1).long() + + valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device) + valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device) + + valid_signal_i[to] = (pos % N).to(torch.long) + valid_signal_o[to] = neigh_map_T[pos].to(torch.long) + else: + valid_signal_i = torch.empty((0,), dtype=torch.long, device=device) + valid_signal_o = torch.empty((0,), dtype=torch.long, device=device) + + seg = torch.empty((V + 1,), dtype=torch.long, device=device) + seg[0] = 0 + if V > 0: + idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1 + seg[1:] = prefix_sum_neighbor_mask[idxs] + + return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg + +def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor: + + x = x.to(torch.int64) + + m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device) + m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device) + m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device) + h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device) + + x = x - ((x >> 1) & m1) + x = (x & m2) + ((x >> 2) & m2) + x = (x + (x >> 4)) & m4 + x = (x * h01) >> 56 + return x.to(torch.int32) + + +def neighbor_map_post_process_for_masked_implicit_gemm_2( + gray_code: torch.Tensor, + sorted_idx: torch.Tensor, + block_size: int +): + device = gray_code.device + N = gray_code.numel() + num_blocks = (N + block_size - 1) // block_size + + pad = num_blocks * block_size - N + if pad > 0: + pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device) + gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0) + else: + gray_padded = gray_code[sorted_idx] + + gray_blocks = gray_padded.view(num_blocks, block_size) + + reduced_code = gray_blocks + while reduced_code.shape[1] > 1: + half = reduced_code.shape[1] // 2 + remainder = reduced_code.shape[1] % 2 + + left = reduced_code[:, :half * 2:2] + right = reduced_code[:, 1:half * 2:2] + merged = left | right + + if remainder: + reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1) + else: + reduced_code = merged + + reduced_code = reduced_code.squeeze(1) + + seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32) + + seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device) + seg[0] = 0 + if num_blocks > 0: + seg[1:] = torch.cumsum(seglen_counts, dim=0) + + total = int(seg[-1].item()) + + if total == 0: + return torch.empty((0,), dtype=torch.int32, device=device), seg + + V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0 + + if V == 0: + return torch.empty((0,), dtype=torch.int32, device=device), seg + + bit_pos = torch.arange(0, V, dtype=torch.int32, device=device) + shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0) + bits = (shifted & 1).to(torch.bool) + + positions = bit_pos.unsqueeze(0).expand(num_blocks, V) + valid_kernel_idx = positions[bits].to(torch.int32).contiguous() + + return valid_kernel_idx, seg + + +def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation): + if NO_TRITON: # TODO + raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.") + if feats.shape[0] == 0: + logging.warning("Found feats to be empty!") + Co = weight.shape[0] + return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None + if len(shape) == 5: + N, C, W, H, D = shape + else: + W, H, D = shape + + Co, Kw, Kh, Kd, Ci = weight.shape + + b_stride = W * H * D + x_stride = H * D + y_stride = D + z_stride = 1 + + flat_keys = (coords[:, 0].long() * b_stride + + coords[:, 1].long() * x_stride + + coords[:, 2].long() * y_stride + + coords[:, 3].long() * z_stride) + + vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device) + + hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF) + + if neighbor_cache is None: + neighbor = build_submanifold_neighbor_map( + hashmap, coords, W, H, D, Kw, Kh, Kd, + dilation[0], dilation[1], dilation[2] + ) + else: + neighbor = neighbor_cache + + block_size = 128 + + gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \ + neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor) + + valid_kernel, valid_kernel_seg = \ + neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size) + + valid_kernel_fn = lambda b_size: valid_kernel + valid_kernel_seg_fn = lambda b_size: valid_kernel_seg + + weight_flat = weight.contiguous().view(Co, -1, Ci) + + out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk( + feats, + weight_flat, + bias, + neighbor, + sorted_idx, + valid_kernel_fn, + valid_kernel_seg_fn + ) + + return out, neighbor + +class Mesh: + def __init__(self, + vertices, + faces, + vertex_attrs=None + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.vertex_attrs = vertex_attrs + + @property + def device(self): + return self.vertices.device + + def to(self, device, non_blocking=False): + return Mesh( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None, + ) + + def cuda(self, non_blocking=False): + return self.to('cuda', non_blocking=non_blocking) + + def cpu(self): + return self.to('cpu') diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py new file mode 100644 index 000000000..a613fb325 --- /dev/null +++ b/comfy/ldm/trellis2/model.py @@ -0,0 +1,875 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from comfy.ldm.trellis2.vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor +from typing import Optional, Tuple, Literal, Union, List +from comfy.ldm.trellis2.attention import ( + sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention +) +from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder +from comfy.ldm.flux.math import apply_rope, apply_rope1 + +class SparseGELU(nn.GELU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio), device=device, dtype=dtype, operations=operations), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels, device=device, dtype=dtype, operations=operations), + ) + + def forward(self, x: VarLenTensor) -> VarLenTensor: + return self.mlp(x) + +def manual_cast(obj, dtype): + return obj.to(dtype=dtype) + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int, device, dtype): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) + + def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, VarLenTensor): + x = x.replace(F.normalize(x.feats, dim=-1) * self.gamma * self.scale) + else: + x = F.normalize(x, dim=-1) * self.gamma * self.scale + return x.to(x_type) + +class SparseRotaryPositionEmbedder(nn.Module): + def __init__( + self, + head_dim: int, + dim: int = 3, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + device=None + ): + super().__init__() + self.head_dim = head_dim + self.dim = dim + self.rope_freq = rope_freq + self.freq_dim = head_dim // 2 // dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32, device=device) / self.freq_dim + self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) + + def _get_freqs_cis(self, coords: torch.Tensor) -> torch.Tensor: + phases_list = [] + for i in range(self.dim): + phases_list.append(torch.outer(coords[..., i], self.freqs.to(coords.device))) + + phases = torch.cat(phases_list, dim=-1) + + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.zeros(*phases.shape[:-1], padn, device=phases.device)], dim=-1) + + cos = torch.cos(phases) + sin = torch.sin(phases) + + f_cis_0 = torch.stack([cos, sin], dim=-1) + f_cis_1 = torch.stack([-sin, cos], dim=-1) + freqs_cis = torch.stack([f_cis_0, f_cis_1], dim=-1) + + return freqs_cis + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def forward(self, q, k=None): + cache_name = f'rope_cis_{self.dim}d_f{self.rope_freq[1]}_hd{self.head_dim}' + freqs_cis = q.get_spatial_cache(cache_name) + + if freqs_cis is None: + coords = q.coords[..., 1:].to(torch.float32) + freqs_cis = self._get_freqs_cis(coords) + q.register_spatial_cache(cache_name, freqs_cis) + + if q.feats.ndim == 3: + f_cis = freqs_cis.unsqueeze(1) + else: + f_cis = freqs_cis + + if k is None: + return q.replace(apply_rope1(q.feats, f_cis)) + + q_feats, k_feats = apply_rope(q.feats, k.feats, f_cis) + return q.replace(q_feats), k.replace(k_feats) + + @staticmethod + def apply_rotary_embedding(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + +class RotaryPositionEmbedder(SparseRotaryPositionEmbedder): + def forward(self, indices: torch.Tensor) -> torch.Tensor: + phases = self._get_phases(indices.reshape(-1)).reshape(*indices.shape[:-1], -1) + if torch.is_complex(phases): + phases = phases.to(torch.complex64) + else: + phases = phases.to(torch.float32) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32), + torch.zeros(*phases.shape[:-1], padn, device=phases.device, dtype=torch.float32) + )], dim=-1) + return phases + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed", "double_windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, device=device, dtype=dtype) + else: + self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype) + self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + self.k_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + + self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) + + if use_rope: + self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq, device=device) + + @staticmethod + def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats + + def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: + if self._type == "self": + dtype = next(self.to_qkv.parameters()).dtype + x = x.to(dtype) + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=-3) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + q, k = self.rope(q, k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + elif self.attn_mode == "double_windowed": + qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:]) + qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2]) + h0 = sparse_windowed_scaled_dot_product_self_attention( + qkv0, self.window_size, shift_window=(0, 0, 0) + ) + h1 = sparse_windowed_scaled_dot_product_self_attention( + qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3) + ) + h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1)) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + dtype = next(self.to_kv.parameters()).dtype + context = context.to(dtype) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=-3) + k = self.k_rms_norm(k) + h = sparse_scaled_dot_product_attention(q, k, v) + else: + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. + """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + device=None, dtype=None, operations=None + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + device=device, dtype=dtype, operations=operations + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + device=device, dtype=dtype, operations=operations + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(channels, 6 * channels, bias=True, device=device, dtype=dtype) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.self_attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + return self._forward(x, mod, context) + + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "rope", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + dtype = None, + device = None, + operations = None, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = dtype + + self.t_embedder = TimestepEmbedder(model_channels, device=device, dtype=dtype, operations=operations) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype) + ) + + self.input_layer = SparseLinear(in_channels, model_channels, device=device, dtype=dtype, operations=operations) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations + ) + for _ in range(num_blocks) + ]) + + self.out_layer = SparseLinear(model_channels, out_channels, device=device, dtype=dtype, operations=operations) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward( + self, + x: SparseTensor, + t: torch.Tensor, + cond: Union[torch.Tensor, List[torch.Tensor]], + concat_cond: Optional[SparseTensor] = None, + **kwargs + ) -> SparseTensor: + if concat_cond is not None: + x = sparse_cat([x, concat_cond], dim=-1) + if isinstance(cond, list): + cond = VarLenTensor.from_tensor_list(cond) + + dtype = next(self.input_layer.parameters()).dtype + x = x.to(dtype) + h = self.input_layer(x) + h = manual_cast(h, self.dtype) + t = t.to(dtype) + t_embedder = self.t_embedder.to(dtype) + t_emb = t_embedder(t, out_dtype = t.dtype) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + cond = manual_cast(cond, self.dtype) + + for block in self.blocks: + h = block(h, t_emb, cond) + + h = manual_cast(h, x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return h + +class FeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0, device=None, dtype=None, operations=None): + super().__init__() + self.mlp = nn.Sequential( + operations.Linear(channels, int(channels * mlp_ratio), device=device, dtype=dtype), + nn.GELU(approximate="tanh"), + operations.Linear(int(channels * mlp_ratio), channels, device=device, dtype=dtype), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.mlp(x) + +class MultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int, device=None, dtype=None): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim, device=device, dtype=dtype)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return (F.normalize(x.float(), dim = -1) * self.gamma * self.scale).to(x.dtype) + + +class MultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int]=None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + device=None, dtype=None, operations=None + ): + super().__init__() + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = operations.Linear(channels, channels * 3, bias=qkv_bias, dtype=dtype, device=device) + else: + self.to_q = operations.Linear(channels, channels, bias=qkv_bias, device=device, dtype=dtype) + self.to_kv = operations.Linear(self.ctx_channels, channels * 2, bias=qkv_bias, device=device, dtype=dtype) + + if self.qk_rms_norm: + self.q_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + self.k_rms_norm = MultiHeadRMSNorm(self.head_dim, num_heads, device=device, dtype=dtype) + + self.to_out = operations.Linear(channels, channels, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + B, L, C = x.shape + if self._type == "self": + x = x.to(next(self.to_qkv.parameters()).dtype) + qkv = self.to_qkv(x) + qkv = qkv.reshape(B, L, 3, self.num_heads, -1) + + if self.attn_mode == "full": + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + assert phases is not None, "Phases must be provided for RoPE" + q = RotaryPositionEmbedder.apply_rotary_embedding(q, phases) + k = RotaryPositionEmbedder.apply_rotary_embedding(k, phases) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(qkv) + else: + Lkv = context.shape[1] + q = self.to_q(x) + context = context.to(next(self.to_kv.parameters()).dtype) + kv = self.to_kv(context) + q = q.reshape(B, L, self.num_heads, -1) + kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=2) + k = self.k_rms_norm(k) + h = scaled_dot_product_attention(q, k, v) + else: + h = scaled_dot_product_attention(q, kv) + h = h.reshape(B, L, -1) + h = self.to_out(h) + return h + +class ModulatedTransformerCrossBlock(nn.Module): + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + device=None, dtype=None, operations=None + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6, device=device) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6, device=device) + self.self_attn = MultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + device=device, dtype=dtype, operations=operations + ) + self.cross_attn = MultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations + ) + self.mlp = FeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + device=device, dtype=dtype, operations=operations + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(channels, 6 * channels, bias=True, dtype=dtype, device=device) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels, device=device, dtype=dtype) / channels ** 0.5) + + def _forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = self.norm1(x) + h = h * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + h = self.self_attn(h, phases=phases) + h = h * gate_msa.unsqueeze(1) + x = x + h + h = self.norm2(x) + h = self.cross_attn(h, context) + x = x + h + h = self.norm3(x) + h = h * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1) + h = self.mlp(h) + h = h * gate_mlp.unsqueeze(1) + x = x + h + return x + + def forward(self, x: torch.Tensor, mod: torch.Tensor, context: torch.Tensor, phases: Optional[torch.Tensor] = None) -> torch.Tensor: + return self._forward(x, mod, context, phases) + + +class SparseStructureFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "rope", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + operations=None, + device = None, + dtype = torch.float32, + **kwargs + ): + super().__init__() + self.device = device + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = qk_rms_norm_cross + self.dtype = dtype + self.device = device + + self.t_embedder = TimestepEmbedder(model_channels, dtype=dtype, device=device, operations=operations) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(model_channels, 6 * model_channels, bias=True, device=device, dtype=dtype) + ) + + pos_embedder = RotaryPositionEmbedder(self.model_channels // self.num_heads, 3, device=device) + coords = torch.meshgrid(*[torch.arange(res, device=self.device, dtype=dtype) for res in [resolution] * 3], indexing='ij') + coords = torch.stack(coords, dim=-1).reshape(-1, 3) + rope_phases = pos_embedder(coords) + self.register_buffer("rope_phases", rope_phases, persistent=False) + + if pe_mode != "rope": + self.rope_phases = None + + self.input_layer = operations.Linear(in_channels, model_channels, device=device, dtype=dtype) + + self.blocks = nn.ModuleList([ + ModulatedTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + device=device, dtype=dtype, operations=operations + ) + for _ in range(num_blocks) + ]) + + self.out_layer = operations.Linear(model_channels, out_channels, device=device, dtype=dtype) + + def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: + x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3) + + h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() + + h = h.to(next(self.input_layer.parameters()).dtype) + h = self.input_layer(h) + t_emb = self.t_embedder(t, out_dtype = t.dtype) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + h = manual_cast(h, self.dtype) + cond = manual_cast(cond, self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond, self.rope_phases) + h = manual_cast(h, x.dtype) + h = F.layer_norm(h, h.shape[-1:]) + h = h.to(next(self.out_layer.parameters()).dtype) + h = self.out_layer(h) + + h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous() + + return h + +def timestep_reshift(t_shifted, old_shift=3.0, new_shift=5.0): + t_shifted = t_shifted / 1000.0 + t_linear = t_shifted / (old_shift - t_shifted * (old_shift - 1)) + t_new = (new_shift * t_linear) / (1 + (new_shift - 1) * t_linear) + t_new *= 1000.0 + return t_new + +class Trellis2(nn.Module): + def __init__(self, resolution, + in_channels = 32, + out_channels = 32, + model_channels = 1536, + cond_channels = 1024, + num_blocks = 30, + num_heads = 12, + mlp_ratio = 5.3334, + share_mod = True, + qk_rms_norm = True, + qk_rms_norm_cross = True, + init_txt_model=False, # for now + dtype=None, device=None, operations=None, **kwargs): + + super().__init__() + self.dtype = dtype + operations = operations or nn + # for some reason it passes num_heads = -1 + if num_heads == -1: + num_heads = 12 + args = { + "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, + "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, + "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations + } + self.img2shape = SLatFlowModel(resolution=resolution, in_channels=in_channels, **args) + self.shape2txt = None + if init_txt_model: + self.shape2txt = SLatFlowModel(resolution=resolution, in_channels=in_channels*2, **args) + self.img2shape_512 = SLatFlowModel(resolution=32, in_channels=in_channels, **args) + args.pop("out_channels") + self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) + self.guidance_interval = [0.6, 1.0] + self.guidance_interval_txt = [0.6, 0.9] + + def forward(self, x, timestep, context, **kwargs): + transformer_options = kwargs.get("transformer_options", {}) + embeds = kwargs.get("embeds") + if embeds is None: + raise ValueError("Trellis2.forward requires 'embeds' in kwargs") + is_1024 = self.img2shape.resolution == 1024 + coords = transformer_options.get("coords", None) + mode = transformer_options.get("generation_mode", "structure_generation") + is_512_run = False + timestep = timestep.to(self.dtype) + if mode == "shape_generation_512": + is_512_run = True + mode = "shape_generation" + if coords is not None: + x = x.squeeze(-1).transpose(1, 2) + not_struct_mode = True + else: + mode = "structure_generation" + not_struct_mode = False + + if is_1024 and not_struct_mode and not is_512_run: + context = embeds + + sigmas = transformer_options.get("sigmas")[0].item() + if sigmas < 1.00001: + timestep *= 1000.0 + if context.size(0) > 1: + cond = context.chunk(2)[1] + else: + cond = context + shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] + txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] + + if not_struct_mode: + orig_bsz = x.shape[0] + rule = txt_rule if mode == "texture_generation" else shape_rule + + if rule and orig_bsz > 1: + x_eval = x[1].unsqueeze(0) + t_eval = timestep[1].unsqueeze(0) if timestep.shape[0] > 1 else timestep + c_eval = cond + else: + x_eval = x + t_eval = timestep + c_eval = context + + B, N, C = x_eval.shape + + if mode in ["shape_generation", "texture_generation"]: + feats_flat = x_eval.reshape(-1, C) + + # inflate coords [N, 4] -> [B*N, 4] + coords_list = [] + for i in range(B): + c = coords.clone() + c[:, 0] = i + coords_list.append(c) + + batched_coords = torch.cat(coords_list, dim=0) + else: + batched_coords = coords + feats_flat = x_eval + + x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) + + if mode == "shape_generation": + if is_512_run: + out = self.img2shape_512(x_st, t_eval, c_eval) + else: + out = self.img2shape(x_st, t_eval, c_eval) + elif mode == "texture_generation": + if self.shape2txt is None: + raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") + slat = transformer_options.get("shape_slat") + if slat is None: + raise ValueError("shape_slat can't be None") + + base_slat_feats = slat[:N] + slat_feats_batched = base_slat_feats.repeat(B, 1).to(x_st.device) + x_st = x_st.replace(feats=torch.cat([x_st.feats, slat_feats_batched], dim=-1)) + out = self.shape2txt(x_st, t_eval, c_eval) + else: # structure + orig_bsz = x.shape[0] + if shape_rule: + x = x[1].unsqueeze(0) + timestep = timestep[1].unsqueeze(0) + out = self.structure_model(x, timestep, context if not shape_rule else cond) + if shape_rule: + out = out.repeat(orig_bsz, 1, 1, 1, 1) + + if not_struct_mode: + out = out.feats + out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) + if rule and orig_bsz > 1: + out = out.repeat(orig_bsz, 1, 1, 1) + return out diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py new file mode 100644 index 000000000..30f902868 --- /dev/null +++ b/comfy/ldm/trellis2/vae.py @@ -0,0 +1,1444 @@ +import math +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from fractions import Fraction +from dataclasses import dataclass +from typing import List, Any, Dict, Optional, overload, Union, Tuple +from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, sparse_submanifold_conv3d + + +def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: + B, C, H, W, D = x.shape + C_ = C // scale_factor**3 + x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4) + x = x.reshape(B, C_, H*scale_factor, W*scale_factor, D*scale_factor) + return x + +class SparseConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + super(SparseConv3d, self).__init__() + sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride, dilation, padding, bias, indice_key) + + def forward(self, x): + return sparse_conv3d_forward(self, x) + + +def sparse_conv3d_init(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, padding=None, bias=True, indice_key=None): + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = tuple(kernel_size) if isinstance(kernel_size, (list, tuple)) else (kernel_size, ) * 3 + self.stride = tuple(stride) if isinstance(stride, (list, tuple)) else (stride, ) * 3 + self.dilation = tuple(dilation) if isinstance(dilation, (list, tuple)) else (dilation, ) * 3 + + self.weight = nn.Parameter(torch.empty((out_channels, in_channels, *self.kernel_size))) + if bias: + self.bias = nn.Parameter(torch.empty(out_channels)) + else: + self.register_parameter("bias", None) + + if self.bias is not None: + fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight) + if fan_in != 0: + bound = 1 / math.sqrt(fan_in) + torch.nn.init.uniform_(self.bias, -bound, bound) + + # Permute weight (Co, Ci, Kd, Kh, Kw) -> (Co, Kd, Kh, Kw, Ci) + self.weight = nn.Parameter(self.weight.permute(0, 2, 3, 4, 1).contiguous()) + + +def sparse_conv3d_forward(self, x): + # check if neighbor map is already computed + Co, Kd, Kh, Kw, Ci = self.weight.shape + neighbor_cache_key = f'SubMConv3d_neighbor_cache_{Kw}x{Kh}x{Kd}_dilation{self.dilation}' + neighbor_cache = x.get_spatial_cache(neighbor_cache_key) + x = x.to(self.weight.dtype).to(self.weight.device) + + out, neighbor_cache_ = sparse_submanifold_conv3d( + x.feats, + x.coords, + x.spatial_shape, + self.weight, + self.bias, + neighbor_cache, + self.dilation + ) + + if neighbor_cache is None: + x.register_spatial_cache(neighbor_cache_key, neighbor_cache_) + + out = x.replace(out) + return out + +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = x.to(torch.float32) + w = self.weight.to(torch.float32) if self.weight is not None else None + b = self.bias.to(torch.float32) if self.bias is not None else None + + o = F.layer_norm(x, self.normalized_shape, w, b, self.eps) + return o.to(x_dtype) + +class SparseConvNeXtBlock3d(nn.Module): + def __init__( + self, + channels: int, + mlp_ratio: float = 4.0, + use_checkpoint: bool = False, + ): + super().__init__() + self.channels = channels + self.use_checkpoint = use_checkpoint + + self.norm = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.conv = SparseConv3d(channels, channels, 3) + self.mlp = nn.Sequential( + nn.Linear(channels, int(channels * mlp_ratio)), + nn.SiLU(), + nn.Linear(int(channels * mlp_ratio), channels), + ) + + def _forward(self, x): + h = self.conv(x) + h = h.replace(self.norm(h.feats)) + h = h.replace(self.mlp(h.feats)) + return h + x + + def forward(self, x): + return self._forward(x) + +class SparseSpatial2Channel(nn.Module): + def __init__(self, factor: int = 2): + super(SparseSpatial2Channel, self).__init__() + self.factor = factor + + def forward(self, x): + DIM = x.coords.shape[-1] - 1 + cache = x.get_spatial_cache(f'spatial2channel_{self.factor}') + if cache is None: + coord = list(x.coords.unbind(dim=-1)) + for i in range(DIM): + coord[i+1] = coord[i+1] // self.factor + subidx = x.coords[:, 1:] % self.factor + subidx = sum([subidx[..., i] * self.factor ** i for i in range(DIM)]) + + MAX = [(s + self.factor - 1) // self.factor for s in x.spatial_shape] + OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] + code = sum([c * o for c, o in zip(coord, OFFSET)]) + code, idx = code.unique(return_inverse=True) + + new_coords = torch.stack( + [code // OFFSET[0]] + + [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], + dim=-1 + ) + else: + new_coords, idx, subidx = cache + + new_feats = torch.zeros(new_coords.shape[0] * self.factor ** DIM, x.feats.shape[1], device=x.feats.device, dtype=x.feats.dtype) + new_feats[idx * self.factor ** DIM + subidx] = x.feats + + out = SparseTensor(new_feats.reshape(new_coords.shape[0], -1), new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] * self.factor ** DIM])) + out._scale = tuple([s * self.factor for s in x._scale]) + out._spatial_cache = x._spatial_cache + + if cache is None: + x.register_spatial_cache(f'spatial2channel_{self.factor}', (new_coords, idx, subidx)) + out.register_spatial_cache(f'channel2spatial_{self.factor}', (x.coords, idx, subidx)) + out.register_spatial_cache('shape', torch.Size(MAX)) + + return out + +class SparseChannel2Spatial(nn.Module): + def __init__(self, factor: int = 2): + super(SparseChannel2Spatial, self).__init__() + self.factor = factor + + def forward(self, x, subdivision = None): + DIM = x.coords.shape[-1] - 1 + + cache = x.get_spatial_cache(f'channel2spatial_{self.factor}') + if cache is None: + if subdivision is None: + raise ValueError('Cache not found. Provide subdivision tensor or pair SparseChannel2Spatial with SparseSpatial2Channel.') + else: + sub = subdivision.feats # [N, self.factor ** DIM] + N_leaf = sub.sum(dim=-1) # [N] + subidx = sub.nonzero()[:, -1] + new_coords = x.coords.clone().detach() + new_coords[:, 1:] *= self.factor + new_coords = torch.repeat_interleave(new_coords, N_leaf, dim=0, output_size=subidx.shape[0]) + for i in range(DIM): + new_coords[:, i+1] += subidx // self.factor ** i % self.factor + idx = torch.repeat_interleave(torch.arange(x.coords.shape[0], device=x.device), N_leaf, dim=0, output_size=subidx.shape[0]) + else: + new_coords, idx, subidx = cache + + x_feats = x.feats.reshape(x.feats.shape[0] * self.factor ** DIM, -1) + new_feats = x_feats[idx * self.factor ** DIM + subidx] + out = SparseTensor(new_feats, new_coords, None if x._shape is None else torch.Size([x._shape[0], x._shape[1] // self.factor ** DIM])) + out._scale = tuple([s / self.factor for s in x._scale]) + if cache is not None: # only keep cache when subdiv following it + out._spatial_cache = x._spatial_cache + return out + +class SparseResBlockC2S3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + use_checkpoint: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_checkpoint = use_checkpoint + self.pred_subdiv = pred_subdiv + + self.norm1 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm2 = LayerNorm32(self.out_channels, elementwise_affine=False, eps=1e-6) + self.conv1 = SparseConv3d(channels, self.out_channels * 8, 3) + self.conv2 = SparseConv3d(self.out_channels, self.out_channels, 3) + self.skip_connection = lambda x: x.replace(x.feats.repeat_interleave(out_channels // (channels // 8), dim=1)) + if pred_subdiv: + self.to_subdiv = SparseLinear(channels, 8) + self.updown = SparseChannel2Spatial(2) + + def forward(self, x, subdiv = None): + if self.pred_subdiv: + dtype = next(self.to_subdiv.parameters()).dtype + x = x.to(dtype) + subdiv = self.to_subdiv(x) + h = x.replace(self.norm1(x.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv1(h) + subdiv_binarized = subdiv.replace(subdiv.feats > 0) if subdiv is not None else None + h = self.updown(h, subdiv_binarized) + x = self.updown(x, subdiv_binarized) + h = h.replace(self.norm2(h.feats)) + h = h.replace(F.silu(h.feats)) + h = self.conv2(h) + h = h + self.skip_connection(x) + if self.pred_subdiv: + return h, subdiv + else: + return h + +@dataclass +class config: + CONV = "flexgemm" + FLEX_GEMM_HASHMAP_RATIO = 2.0 + +class VarLenTensor: + + def __init__(self, feats: torch.Tensor, layout: List[slice]=None): + self.feats = feats + self.layout = layout if layout is not None else [slice(0, feats.shape[0])] + self._cache = {} + + @staticmethod + def layout_from_seqlen(seqlen: list) -> List[slice]: + """ + Create a layout from a tensor of sequence lengths. + """ + layout = [] + start = 0 + for l in seqlen: + layout.append(slice(start, start + l)) + start += l + return layout + + @staticmethod + def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor': + """ + Create a VarLenTensor from a list of tensors. + """ + feats = torch.cat(tensor_list, dim=0) + layout = [] + start = 0 + for tensor in tensor_list: + layout.append(slice(start, start + tensor.shape[0])) + start += tensor.shape[0] + return VarLenTensor(feats, layout) + + def __len__(self) -> int: + return len(self.layout) + + @property + def shape(self) -> torch.Size: + return torch.Size([len(self.layout), *self.feats.shape[1:]]) + + def dim(self) -> int: + return len(self.shape) + + @property + def ndim(self) -> int: + return self.dim() + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + if 'seqlen' not in self._cache: + self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + return self._cache['seqlen'] + + @property + def cum_seqlen(self) -> torch.LongTensor: + if 'cum_seqlen' not in self._cache: + self._cache['cum_seqlen'] = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + return self._cache['cum_seqlen'] + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. + """ + if 'batch_boardcast_map' not in self._cache: + self._cache['batch_boardcast_map'] = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + return self._cache['batch_boardcast_map'] + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... + + def to(self, *args, **kwargs) -> 'VarLenTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'VarLenTensor': + new_feats = self.feats.cpu() + return self.replace(new_feats) + + def cuda(self) -> 'VarLenTensor': + new_feats = self.feats.cuda() + return self.replace(new_feats) + + def half(self) -> 'VarLenTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'VarLenTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'VarLenTensor': + new_feats = self.feats.detach() + return self.replace(new_feats) + + def reshape(self, *shape) -> 'VarLenTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['VarLenTensor']: + return varlen_unbind(self, dim) + + def replace(self, feats: torch.Tensor) -> 'VarLenTensor': + new_tensor = VarLenTensor( + feats=feats, + layout=self.layout, + ) + new_tensor._cache = self._cache + return new_tensor + + def to_dense(self, max_length=None) -> torch.Tensor: + N = len(self) + L = max_length or self.seqlen.max().item() + spatial = self.feats.shape[1:] + idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L) + mask = (idx < self.seqlen.unsqueeze(1)) + mapping = mask.reshape(-1).cumsum(dim=0) - 1 + dense = self.feats[mapping] + dense = dense.reshape(N, L, *spatial) + return dense, mask + + def __neg__(self) -> 'VarLenTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_feats = [] + new_layout = [] + start = 0 + for new_idx, old_idx in enumerate(idx): + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_feats[-1]))) + start += len(new_feats[-1]) + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = VarLenTensor(feats=new_feats, layout=new_layout) + return new_tensor + + def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + if isinstance(dim, int): + dim = (dim,) + + if op =='mean': + red = self.feats.mean(dim=dim, keepdim=keepdim) + elif op =='sum': + red = self.feats.sum(dim=dim, keepdim=keepdim) + elif op == 'prod': + red = self.feats.prod(dim=dim, keepdim=keepdim) + else: + raise ValueError(f"Unsupported reduce operation: {op}") + + if dim is None or 0 in dim: + return red + + red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen) + return red + + def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='mean', dim=dim, keepdim=keepdim) + + def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='sum', dim=dim, keepdim=keepdim) + + def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='prod', dim=dim, keepdim=keepdim) + + def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + mean = self.mean(dim=dim, keepdim=True) + mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True) + std = (mean2 - mean ** 2).sqrt() + return std + + def __repr__(self) -> str: + return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + +def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]: + + if dim == 0: + return [input[i] for i in range(len(input))] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] + + +class SparseTensor(VarLenTensor): + + SparseTensorData = None + + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ... + + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + if self.SparseTensorData is None: + import importlib + if config.CONV == 'torchsparse': + self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif config.CONV == 'spconv': + self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape = args + (None,) * (3 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + if config.CONV == 'torchsparse': + self.data = self.SparseTensorData(feats, coords, **kwargs) + elif config.CONV == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1) + self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs) + self.data._features = feats + else: + self.data = { + 'feats': feats, + 'coords': coords, + } + elif method_id == 1: + data, shape = args + (None,) * (2 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + self.data = data + + self._shape = shape + self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1))) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + @staticmethod + def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor': + """ + Create a SparseTensor from a list of tensors. + """ + feats = torch.cat(feats_list, dim=0) + coords = [] + for i, coord in enumerate(coords_list): + coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1) + coords.append(coord) + coords = torch.cat(coords, dim=0) + return SparseTensor(feats, coords) + + def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Convert a SparseTensor to list of tensors. + """ + feats_list = [] + coords_list = [] + for s in self.layout: + feats_list.append(self.feats[s]) + coords_list.append(self.coords[s]) + return feats_list, coords_list + + def __len__(self) -> int: + return len(self.layout) + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + def __cal_spatial_shape(self, coords): + return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist()) + + @property + def shape(self) -> torch.Size: + if self._shape is None: + self._shape = self.__cal_shape(self.feats, self.coords) + return self._shape + + @property + def layout(self) -> List[slice]: + layout = self.get_spatial_cache('layout') + if layout is None: + layout = self.__cal_layout(self.coords, self.shape[0]) + self.register_spatial_cache('layout', layout) + return layout + + @property + def spatial_shape(self) -> torch.Size: + spatial_shape = self.get_spatial_cache('shape') + if spatial_shape is None: + spatial_shape = self.__cal_spatial_shape(self.coords) + self.register_spatial_cache('shape', spatial_shape) + return spatial_shape + + @property + def feats(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.F + elif config.CONV == 'spconv': + return self.data.features + else: + return self.data['feats'] + + @feats.setter + def feats(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.F = value + elif config.CONV == 'spconv': + self.data.features = value + else: + self.data['feats'] = value + + @property + def coords(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.C + elif config.CONV == 'spconv': + return self.data.indices + else: + return self.data['coords'] + + @coords.setter + def coords(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.C = value + elif config.CONV == 'spconv': + self.data.indices = value + else: + self.data['coords'] = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + seqlen = self.get_spatial_cache('seqlen') + if seqlen is None: + seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + self.register_spatial_cache('seqlen', seqlen) + return seqlen + + @property + def cum_seqlen(self) -> torch.LongTensor: + cum_seqlen = self.get_spatial_cache('cum_seqlen') + if cum_seqlen is None: + cum_seqlen = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + self.register_spatial_cache('cum_seqlen', cum_seqlen) + return cum_seqlen + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. + """ + batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map') + if batch_boardcast_map is None: + batch_boardcast_map = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map) + return batch_boardcast_map + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + if config.CONV == 'torchsparse': + new_data = self.SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif config.CONV == 'spconv': + new_data = self.SparseTensorData( + self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + else: + new_data = { + 'feats': feats, + 'coords': self.data['coords'] if coords is None else coords, + } + new_tensor = SparseTensor( + new_data, + shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None, + scale=self._scale, + spatial_cache=self._spatial_cache + ) + return new_tensor + + def to_dense(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.dense() + elif config.CONV == 'spconv': + return self.data.dense() + else: + spatial_shape = self.spatial_shape + ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device) + idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1) + ret[tuple(idx)] = self.feats + return ret + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_coords = [] + new_feats = [] + new_layout = [] + new_shape = torch.Size([len(idx)] + list(self.shape[1:])) + start = 0 + for new_idx, old_idx in enumerate(idx): + new_coords.append(self.coords[self.layout[old_idx]].clone()) + new_coords[-1][:, 0] = new_idx + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_coords[-1]))) + start += len(new_coords[-1]) + new_coords = torch.cat(new_coords, dim=0).contiguous() + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape) + new_tensor.register_spatial_cache('layout', new_layout) + return new_tensor + + def clear_spatial_cache(self) -> None: + """ + Clear all spatial caches. + """ + self._spatial_cache = {} + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + def __repr__(self) -> str: + return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] + +# allow operations.Linear inheritance +class SparseLinear: + def __new__(cls, in_features, out_features, bias=True, device=None, dtype=None, operations=nn, *args, **kwargs): + class _SparseLinear(operations.Linear): + def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): + super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + return _SparseLinear(in_features, out_features, bias=bias, device=device, dtype=dtype, *args, **kwargs) + +MIX_PRECISION_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + SparseConv3d, + SparseLinear, +) + + +def convert_module_to_f16(l): + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.half() + +class SparseUnetVaeDecoder(nn.Module): + def __init__( + self, + out_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.out_channels = out_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.use_fp16 = use_fp16 + self.pred_subdiv = pred_subdiv + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.low_vram = False + + self.output_layer = SparseLinear(model_channels[-1], out_channels) + self.from_latent = SparseLinear(latent_channels, model_channels[0]) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[up_block_type[i]]( + model_channels[i], + model_channels[i+1], + pred_subdiv=pred_subdiv, + **block_args[i], + ) + ) + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor: + + dtype = next(self.from_latent.parameters()).dtype + device = next(self.from_latent.parameters()).device + x.feats = x.feats.to(dtype).to(device) + h = self.from_latent(x) + h = h.type(self.dtype) + subs = [] + for i, res in enumerate(self.blocks): + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + if self.pred_subdiv: + h, sub = block(h) + subs.append(sub) + else: + h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None) + else: + h = block(h) + h = h.type(x.feats.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.output_layer(h) + if return_subs: + return h, subs + else: + return h + + def upsample(self, x: SparseTensor, upsample_times: int) -> torch.Tensor: + + h = self.from_latent(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + if i == upsample_times: + return h.coords + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + h, sub = block(h) + else: + h = block(h) + +class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): + def __init__( + self, + resolution: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + voxel_margin: float = 0.5, + use_fp16: bool = False, + ): + self.resolution = resolution + self.voxel_margin = voxel_margin + # cache for a TorchHashMap instance + self._torch_hashmap_cache = None + + super().__init__( + 7, + model_channels, + latent_channels, + num_blocks, + block_type, + up_block_type, + block_args, + use_fp16, + ) + + def set_resolution(self, resolution: int) -> None: + self.resolution = resolution + + def _build_or_get_hashmap(self, coords: torch.Tensor, grid_size: torch.Tensor): + device = coords.device + N = coords.shape[0] + # compute flat keys for all coords (prepend batch 0 same as original code) + b = torch.zeros((N,), dtype=torch.long, device=device) + x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32) + W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + flat_keys = b * (W * H * D) + x * (H * D) + y * D + z + values = torch.arange(N, dtype=torch.int32, device=device) + DEFAULT_VAL = 0xffffffff # sentinel used in original code + return TorchHashMap(flat_keys, values, DEFAULT_VAL) + + def forward(self, x: SparseTensor, gt_intersected: SparseTensor = None, **kwargs): + decoded = super().forward(x, **kwargs) + out_list = list(decoded) if isinstance(decoded, tuple) else [decoded] + h = out_list[0] + vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) + intersected = h.replace(h.feats[..., 3:6] > 0) + quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) + mesh = [Mesh(*flexible_dual_grid_to_mesh( + v.coords[:, 1:], v.feats, i.feats, q.feats, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + grid_size=self.resolution, + train=False, + hashmap_builder=self._build_or_get_hashmap, + )) for v, i, q in zip(vertices, intersected, quad_lerp)] + out_list[0] = mesh + return out_list[0] if len(out_list) == 1 else tuple(out_list) + +def flexible_dual_grid_to_mesh( + coords: torch.Tensor, + dual_vertices: torch.Tensor, + intersected_flag: torch.Tensor, + split_weight: Union[torch.Tensor, None], + aabb: Union[list, tuple, np.ndarray, torch.Tensor], + voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None, + grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None, + train: bool = False, + hashmap_builder=None, # optional callable for building/caching a TorchHashMap +): + + device = coords.device + if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset") \ + or flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset.device != device: + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([ + [[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis + [[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis + [[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis + ], dtype=torch.int, device=device).unsqueeze(0) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1") or flexible_dual_grid_to_mesh.quad_split_1.device != device: + flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2") or flexible_dual_grid_to_mesh.quad_split_2.device != device: + flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train") or flexible_dual_grid_to_mesh.quad_split_train.device != device: + flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=device, requires_grad=False) + + # AABB + if isinstance(aabb, (list, tuple)): + aabb = np.array(aabb) + if isinstance(aabb, np.ndarray): + aabb = torch.tensor(aabb, dtype=torch.float32, device=device) + + # Voxel size + if voxel_size is not None: + if isinstance(voxel_size, float): + voxel_size = [voxel_size, voxel_size, voxel_size] + if isinstance(voxel_size, (list, tuple)): + voxel_size = np.array(voxel_size) + if isinstance(voxel_size, np.ndarray): + voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=coords.device) + grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int() + else: + if isinstance(grid_size, int): + grid_size = [grid_size, grid_size, grid_size] + if isinstance(grid_size, (list, tuple)): + grid_size = np.array(grid_size) + if isinstance(grid_size, np.ndarray): + grid_size = torch.tensor(grid_size, dtype=torch.int32, device=coords.device) + voxel_size = (aabb[1] - aabb[0]) / grid_size + + # Extract mesh + N = dual_vertices.shape[0] + + if hashmap_builder is None: + # build local TorchHashMap + device = coords.device + b = torch.zeros((N,), dtype=torch.long, device=device) + x, y, z = coords[:, 0].to(torch.int32), coords[:, 1].to(torch.int32), coords[:, 2].to(torch.int32) + W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + flat_keys = b * (W * H * D) + x * (H * D) + y * D + z + values = torch.arange(N, dtype=torch.long, device=device) + DEFAULT_VAL = 0xffffffff + torch_hashmap = TorchHashMap(flat_keys, values, DEFAULT_VAL) + else: + torch_hashmap = hashmap_builder(coords, grid_size) + + # Find connected voxels + edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset # (N, 3, 4, 3) + connected_voxel = edge_neighbor_voxel[intersected_flag] # (M, 4, 3) + M = connected_voxel.shape[0] + # flatten connected voxel coords and lookup + conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device) + conn_x = connected_voxel.reshape(-1, 3)[:, 0].to(torch.int32) + conn_y = connected_voxel.reshape(-1, 3)[:, 1].to(torch.int32) + conn_z = connected_voxel.reshape(-1, 3)[:, 2].to(torch.int32) + W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z + + conn_indices = torch_hashmap.lookup_flat(conn_flat).reshape(M, 4).int() + connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1) + quad_indices = conn_indices[connected_voxel_valid].int() # (L, 4) + + mesh_vertices = (coords.float() + dual_vertices) * voxel_size + aabb[0].reshape(1, 3) + if split_weight is None: + # if split 1 + atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1] + normals0 = torch.cross(mesh_vertices[atempt_triangles_0[:, 1]] - mesh_vertices[atempt_triangles_0[:, 0]], mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 0]]) + normals1 = torch.cross(mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 1]], mesh_vertices[atempt_triangles_0[:, 3]] - mesh_vertices[atempt_triangles_0[:, 1]]) + align0 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() + # if split 2 + atempt_triangles_1 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + normals0 = torch.cross(mesh_vertices[atempt_triangles_1[:, 1]] - mesh_vertices[atempt_triangles_1[:, 0]], mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 0]]) + normals1 = torch.cross(mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 1]], mesh_vertices[atempt_triangles_1[:, 3]] - mesh_vertices[atempt_triangles_1[:, 1]]) + align1 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() + # select split + mesh_triangles = torch.where(align0 > align1, atempt_triangles_0, atempt_triangles_1).reshape(-1, 3) + else: + split_weight_ws = split_weight[quad_indices] + split_weight_ws_02 = split_weight_ws[:, 0] * split_weight_ws[:, 2] + split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3] + mesh_triangles = torch.where( + split_weight_ws_02 > split_weight_ws_13, + quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1], + quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + ).reshape(-1, 3) + + return mesh_vertices, mesh_triangles + +class ChannelLayerNorm32(LayerNorm32): + def forward(self, x: torch.Tensor) -> torch.Tensor: + DIM = x.dim() + x = x.permute(0, *range(2, DIM), 1).contiguous() + x = super().forward(x) + x = x.permute(0, DIM-1, *range(1, DIM-1)).contiguous() + return x + +class UpsampleBlock3d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + mode = "conv", + ): + assert mode in ["conv", "nearest"], f"Invalid mode {mode}" + + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + if mode == "conv": + self.conv = nn.Conv3d(in_channels, out_channels*8, 3, padding=1) + elif mode == "nearest": + assert in_channels == out_channels, "Nearest mode requires in_channels to be equal to out_channels" + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if hasattr(self, "conv"): + x = self.conv(x) + return pixel_shuffle_3d(x, 2) + else: + return F.interpolate(x, scale_factor=2, mode="nearest") + +def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module: + return ChannelLayerNorm32(*args, **kwargs) + +class ResBlock3d(nn.Module): + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + norm_type = "layer", + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + self.norm1 = norm_layer(norm_type, channels) + self.norm2 = norm_layer(norm_type, self.out_channels) + self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1) + self.conv2 = nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1) + self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + h = self.norm1(x) + h = F.silu(h) + dtype = next(self.conv1.parameters()).dtype + h = h.to(dtype) + h = self.conv1(h) + h = self.norm2(h) + h = F.silu(h) + h = self.conv2(h) + h = h + self.skip_connection(x) + return h + + +class SparseStructureDecoder(nn.Module): + def __init__( + self, + out_channels: int, + latent_channels: int, + num_res_blocks: int, + channels: List[int], + num_res_blocks_middle: int = 2, + norm_type = "layer", + use_fp16: bool = True, + ): + super().__init__() + self.out_channels = out_channels + self.latent_channels = latent_channels + self.num_res_blocks = num_res_blocks + self.channels = channels + self.num_res_blocks_middle = num_res_blocks_middle + self.norm_type = norm_type + self.use_fp16 = use_fp16 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = nn.Conv3d(latent_channels, channels[0], 3, padding=1) + + self.middle_block = nn.Sequential(*[ + ResBlock3d(channels[0], channels[0]) + for _ in range(num_res_blocks_middle) + ]) + + self.blocks = nn.ModuleList([]) + for i, ch in enumerate(channels): + self.blocks.extend([ + ResBlock3d(ch, ch) + for _ in range(num_res_blocks) + ]) + if i < len(channels) - 1: + self.blocks.append( + UpsampleBlock3d(ch, channels[i+1]) + ) + + self.out_layer = nn.Sequential( + norm_layer(norm_type, channels[-1]), + nn.SiLU(), + nn.Conv3d(channels[-1], out_channels, 3, padding=1) + ) + + if use_fp16: + self.convert_to_fp16() + + def device(self) -> torch.device: + return next(self.parameters()).device + + def convert_to_fp16(self) -> None: + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = next(self.input_layer.parameters()).dtype + x = x.to(dtype) + h = self.input_layer(x) + + h = h.type(self.dtype) + h = self.middle_block(h) + for block in self.blocks: + h = block(h) + + h = h.type(x.dtype) + h = self.out_layer(h) + return h + +class Vae(nn.Module): + def __init__(self, init_txt_model, operations=None): + super().__init__() + operations = operations or torch.nn + if init_txt_model: + self.txt_dec = SparseUnetVaeDecoder( + out_channels=6, + model_channels=[1024, 512, 256, 128, 64], + latent_channels=32, + num_blocks=[4, 16, 8, 4, 0], + block_type=["SparseConvNeXtBlock3d"] * 5, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], + pred_subdiv=False + ) + + self.shape_dec = FlexiDualGridVaeDecoder( + resolution=256, + model_channels=[1024, 512, 256, 128, 64], + latent_channels=32, + num_blocks=[4, 16, 8, 4, 0], + block_type=["SparseConvNeXtBlock3d"] * 5, + up_block_type=["SparseResBlockC2S3d"] * 4, + block_args=[{}, {}, {}, {}, {}], + ) + + self.struct_dec = SparseStructureDecoder( + out_channels=1, + latent_channels=8, + num_res_blocks=2, + num_res_blocks_middle=2, + channels=[512, 128, 32], + ) + + @torch.no_grad() + def decode_shape_slat(self, slat, resolution: int): + self.shape_dec.set_resolution(resolution) + return self.shape_dec(slat, return_subs=True) + + @torch.no_grad() + def decode_tex_slat(self, slat, subs): + if self.txt_dec is None: + raise ValueError("Checkpoint doesn't include texture model") + return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 + + # shouldn't be called (placeholder) + @torch.no_grad() + def decode( + self, + shape_slat: SparseTensor, + tex_slat: SparseTensor, + resolution: int, + ): + meshes, subs = self.decode_shape_slat(shape_slat, resolution) + tex_voxels = self.decode_tex_slat(tex_slat, subs) + return tex_voxels diff --git a/comfy/model_base.py b/comfy/model_base.py index 5c2668ba9..5f258178f 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -51,6 +51,7 @@ import comfy.ldm.omnigen.omnigen2 import comfy.ldm.qwen_image.model import comfy.ldm.kandinsky5.model import comfy.ldm.anima.model +import comfy.ldm.trellis2.model import comfy.ldm.ace.ace_step15 import comfy.ldm.rt_detr.rtdetr_v4 import comfy.ldm.ernie.model @@ -1537,6 +1538,16 @@ class WAN22(WAN21): def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs): return latent_image +class Trellis2(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2): + super().__init__(model_config, model_type, device, unet_model) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + embeds = kwargs.get("embeds") + out["embeds"] = comfy.conds.CONDRegular(embeds) + return out + class WAN21_FlowRVS(WAN21): def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None): model_config.unet_config["model_type"] = "t2v" @@ -1578,7 +1589,6 @@ class WAN21_SCAIL(WAN21): pose_latents = kwargs.get("pose_video_latent", None) if pose_latents is not None: out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]] - return out class Hunyuan3Dv2(BaseModel): diff --git a/comfy/model_detection.py b/comfy/model_detection.py index ca06cdd1e..8255683d1 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -113,6 +113,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]] return unet_config + if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + unet_config = {} + unet_config["image_model"] = "trellis2" + + unet_config["init_txt_model"] = False + if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys: + unet_config["init_txt_model"] = True + + unet_config["resolution"] = 64 + if metadata is not None: + if "is_512" in metadata: + unet_config["resolution"] = 32 + + unet_config["num_heads"] = 12 + return unet_config + if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit unet_config = {} unet_config["audio_model"] = "dit1.0" diff --git a/comfy/sd.py b/comfy/sd.py index e573804a5..0cba71e3f 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -14,6 +14,7 @@ import comfy.ldm.genmo.vae.model import comfy.ldm.lightricks.vae.causal_video_autoencoder import comfy.ldm.cosmos.vae import comfy.ldm.wan.vae +import comfy.ldm.trellis2.vae import comfy.ldm.wan.vae2_2 import comfy.ldm.hunyuan3d.vae import comfy.ldm.ace.vae.music_dcae_pipeline @@ -507,6 +508,15 @@ class VAE: self.first_stage_model = StageC_coder() self.downscale_ratio = 32 self.latent_channels = 16 + elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd: # trellis2 + init_txt_model = False + if "txt_dec.blocks.1.16.norm1.weight" in sd: + init_txt_model = True + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + # TODO + self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True} diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 58d4ce731..927ae79ba 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -1273,6 +1273,29 @@ class WAN22_T2V(WAN21_T2V): out = model_base.WAN22(self, image_to_video=True, device=device) return out +class Trellis2(supported_models_base.BASE): + unet_config = { + "image_model": "trellis2" + } + + sampling_settings = { + "shift": 3.0, + } + + memory_usage_factor = 3.5 + + latent_format = latent_formats.Trellis2 + vae_key_prefix = ["vae."] + clip_vision_prefix = "conditioner.main_image_encoder.model." + # this is only needed for the texture model + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + def get_model(self, state_dict, prefix="", device=None): + return model_base.Trellis2(self, device=device) + + def clip_target(self, state_dict={}): + return None + class WAN21_FlowRVS(WAN21_T2V): unet_config = { "image_model": "wan2.1", @@ -1293,6 +1316,7 @@ class WAN21_SCAIL(WAN21_T2V): out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device) return out + class Hunyuan3Dv2(supported_models_base.BASE): unet_config = { "image_model": "hunyuan3d2", @@ -1664,6 +1688,7 @@ class Kandinsky5Image(Kandinsky5): return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect)) + class ACEStep15(supported_models_base.BASE): unet_config = { "audio_model": "ace1.5", @@ -1703,7 +1728,6 @@ class ACEStep15(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect)) - class LongCatImage(supported_models_base.BASE): unet_config = { "image_model": "flux", @@ -1781,6 +1805,6 @@ class ErnieImage(supported_models_base.BASE): return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect)) -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage] +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, Trellis2] models += [SVD_img2vid] diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py index df0c3e4b1..ac91fe0a7 100644 --- a/comfy_extras/nodes_hunyuan3d.py +++ b/comfy_extras/nodes_hunyuan3d.py @@ -484,7 +484,7 @@ class VoxelToMesh(IO.ComfyNode): decode = execute # TODO: remove -def save_glb(vertices, faces, filepath, metadata=None): +def save_glb(vertices, faces, filepath, metadata=None, colors=None): """ Save PyTorch tensor vertices and faces as a GLB file without external dependencies. @@ -515,6 +515,13 @@ def save_glb(vertices, faces, filepath, metadata=None): indices_byte_length = len(indices_buffer) indices_byte_offset = len(vertices_buffer_padded) + if colors is not None: + colors_np = colors.cpu().numpy().astype(np.float32) + colors_buffer = colors_np.tobytes() + colors_byte_length = len(colors_buffer) + colors_byte_offset = len(buffer_data) + buffer_data += pad_to_4_bytes(colors_buffer) + gltf = { "asset": {"version": "2.0", "generator": "ComfyUI"}, "buffers": [ @@ -580,6 +587,14 @@ def save_glb(vertices, faces, filepath, metadata=None): "scene": 0 } + if colors is not None: + gltf["bufferViews"].append({"buffer": 0, "byteOffset": colors_byte_offset, "byteLength": colors_byte_length, "target": 34962}) + gltf["accessors"].append({"bufferView": 2, "byteOffset": 0, "componentType": 5126, "count": len(colors_np), "type": "VEC3"}) + gltf["meshes"][0]["primitives"][0]["attributes"]["COLOR_0"] = 2 + # Define a base material so Three.js actually activates vertex coloring + gltf["materials"] =[{"pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0]}}] + gltf["meshes"][0]["primitives"][0]["material"] = 0 + if metadata is not None: gltf["asset"]["extras"] = metadata @@ -669,7 +684,8 @@ class SaveGLB(IO.ComfyNode): # Handle Mesh input - save vertices and faces as GLB for i in range(mesh.vertices.shape[0]): f = f"{filename}_{counter:05}_.glb" - save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata) + v_colors = mesh.colors[i] if hasattr(mesh, "colors") and mesh.colors is not None else None + save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata, v_colors) results.append({ "filename": f, "subfolder": subfolder, diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py new file mode 100644 index 000000000..3479d5410 --- /dev/null +++ b/comfy_extras/nodes_trellis2.py @@ -0,0 +1,693 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, Types +from comfy.ldm.trellis2.vae import SparseTensor +import comfy.model_management +from PIL import Image +import numpy as np +import torch +import scipy +import copy + +shape_slat_normalization = { + "mean": torch.tensor([ + 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218, + -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944, + 0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667, + -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665 + ])[None], + "std": torch.tensor([ + 5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004, + 5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578, + 4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194, + 5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691 + ])[None] +} + +tex_slat_normalization = { + "mean": torch.tensor([ + 3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075, + 0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149, + -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717, + 1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862 + ])[None], + "std": torch.tensor([ + 2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822, + 2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588, + 2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999, + 2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190 + ])[None] +} + +def shape_norm(shape_latent, coords): + std = shape_slat_normalization["std"].to(shape_latent) + mean = shape_slat_normalization["mean"].to(shape_latent) + samples = SparseTensor(feats = shape_latent, coords=coords) + samples = samples * std + mean + return samples + +def paint_mesh_with_voxels(mesh, voxel_coords, voxel_colors, resolution): + """ + Generic function to paint a mesh using nearest-neighbor colors from a sparse voxel field. + """ + device = comfy.model_management.vae_offload_device() + + origin = torch.tensor([-0.5, -0.5, -0.5], device=device) + # TODO: generic independent node? if so: figure how pass the resolution parameter + voxel_size = 1.0 / resolution + + # map voxels + voxel_pos = voxel_coords.to(device).float() * voxel_size + origin + verts = mesh.vertices.to(device).squeeze(0) + voxel_colors = voxel_colors.to(device) + + voxel_pos_np = voxel_pos.numpy() + verts_np = verts.numpy() + + tree = scipy.spatial.cKDTree(voxel_pos_np) + + # nearest neighbour k=1 + _, nearest_idx_np = tree.query(verts_np, k=1, workers=-1) + + nearest_idx = torch.from_numpy(nearest_idx_np).long() + v_colors = voxel_colors[nearest_idx] + + # to [0, 1] + srgb_colors = v_colors.clamp(0, 1)#(v_colors * 0.5 + 0.5).clamp(0, 1) + + # to Linear RGB (required for GLTF) + linear_colors = torch.pow(srgb_colors, 2.2) + + final_colors = linear_colors.unsqueeze(0) + + out_mesh = copy.deepcopy(mesh) + out_mesh.colors = final_colors + + return out_mesh + +class VaeDecodeShapeTrellis(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VaeDecodeShapeTrellis", + category="latent/3d", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.Combo.Input("resolution", options=["512", "1024"], default="1024") + ], + outputs=[ + IO.Mesh.Output("mesh"), + IO.AnyType.Output("shape_subs"), + ] + ) + + @classmethod + def execute(cls, samples, vae, resolution): + + resolution = int(resolution) + patcher = vae.patcher + device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(patcher) + + vae = vae.first_stage_model + coords = samples["coords"] + + samples = samples["samples"] + samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + samples = shape_norm(samples, coords) + + mesh, subs = vae.decode_shape_slat(samples, resolution) + faces = torch.stack([m.faces for m in mesh]) + verts = torch.stack([m.vertices for m in mesh]) + mesh = Types.MESH(vertices=verts, faces=faces) + return IO.NodeOutput(mesh, subs) + +class VaeDecodeTextureTrellis(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VaeDecodeTextureTrellis", + category="latent/3d", + inputs=[ + IO.Mesh.Input("shape_mesh"), + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.AnyType.Input("shape_subs"), + ], + outputs=[ + IO.Mesh.Output("mesh"), + ] + ) + + @classmethod + def execute(cls, shape_mesh, samples, vae, shape_subs): + + resolution = 1024 + patcher = vae.patcher + device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(patcher) + + vae = vae.first_stage_model + coords = samples["coords"] + + samples = samples["samples"] + samples = samples.squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + std = tex_slat_normalization["std"].to(samples) + mean = tex_slat_normalization["mean"].to(samples) + samples = SparseTensor(feats = samples, coords=coords) + samples = samples * std + mean + + voxel = vae.decode_tex_slat(samples, shape_subs) + color_feats = voxel.feats[:, :3] + voxel_coords = voxel.coords[:, 1:] + + out_mesh = paint_mesh_with_voxels(shape_mesh, voxel_coords, color_feats, resolution=resolution) + return IO.NodeOutput(out_mesh) + +class VaeDecodeStructureTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="VaeDecodeStructureTrellis2", + category="latent/3d", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.Combo.Input("resolution", options=["32", "64"], default="32") + ], + outputs=[ + IO.Voxel.Output("structure_output"), + ] + ) + + @classmethod + def execute(cls, samples, vae, resolution): + resolution = int(resolution) + vae = vae.first_stage_model + decoder = vae.struct_dec + load_device = comfy.model_management.get_torch_device() + offload_device = comfy.model_management.vae_offload_device() + decoder = decoder.to(load_device) + samples = samples["samples"] + samples = samples.to(load_device) + decoded = decoder(samples)>0 + decoder.to(offload_device) + current_res = decoded.shape[2] + + if current_res != resolution: + ratio = current_res // resolution + decoded = torch.nn.functional.max_pool3d(decoded.float(), ratio, ratio, 0) > 0.5 + out = Types.VOXEL(decoded.squeeze(1).float()) + return IO.NodeOutput(out) + +class Trellis2UpsampleCascade(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Trellis2UpsampleCascade", + category="latent/3d", + inputs=[ + IO.Latent.Input("shape_latent_512"), + IO.Vae.Input("vae"), + IO.Combo.Input("target_resolution", options=["1024", "1536"], default="1024"), + IO.Int.Input("max_tokens", default=49152, min=1024, max=100000) + ], + outputs=[ + IO.AnyType.Output("hr_coords"), + ] + ) + + @classmethod + def execute(cls, shape_latent_512, vae, target_resolution, max_tokens): + device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(vae.patcher) + + feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) + coords_512 = shape_latent_512["coords"].to(device) + + slat = shape_norm(feats, coords_512) + + decoder = vae.first_stage_model.shape_dec + + slat.feats = slat.feats.to(next(decoder.parameters()).dtype) + hr_coords = decoder.upsample(slat, upsample_times=4) + + lr_resolution = 512 + hr_resolution = int(target_resolution) + + while True: + quant_coords = torch.cat([ + hr_coords[:, :1], + ((hr_coords[:, 1:] + 0.5) / lr_resolution * (hr_resolution // 16)).int(), + ], dim=1) + final_coords = quant_coords.unique(dim=0) + num_tokens = final_coords.shape[0] + + if num_tokens < max_tokens or hr_resolution <= 1024: + break + hr_resolution -= 128 + + return IO.NodeOutput(final_coords,) + +dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) +dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + +def run_conditioning(model, cropped_img_tensor, include_1024=True): + model_internal = model.model + device = comfy.model_management.intermediate_device() + torch_device = comfy.model_management.get_torch_device() + + def prepare_tensor(pil_img, size): + resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) + img_np = np.array(resized_pil).astype(np.float32) / 255.0 + img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device) + return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) + + model_internal.image_size = 512 + input_512 = prepare_tensor(cropped_img_tensor, 512) + cond_512 = model_internal(input_512, skip_norm_elementwise=True)[0] + + cond_1024 = None + if include_1024: + model_internal.image_size = 1024 + input_1024 = prepare_tensor(cropped_img_tensor, 1024) + cond_1024 = model_internal(input_1024, skip_norm_elementwise=True)[0] + + conditioning = { + 'cond_512': cond_512.to(device), + 'neg_cond': torch.zeros_like(cond_512).to(device), + } + if cond_1024 is not None: + conditioning['cond_1024'] = cond_1024.to(device) + + return conditioning + +class Trellis2Conditioning(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Trellis2Conditioning", + category="conditioning/video_models", + inputs=[ + IO.ClipVision.Input("clip_vision_model"), + IO.Image.Input("image"), + IO.Mask.Input("mask"), + IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black") + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ] + ) + + @classmethod + def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput: + + if image.ndim == 4: + image = image[0] + if mask.ndim == 3: + mask = mask[0] + + img_np = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + mask_np = (mask.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + + pil_img = Image.fromarray(img_np) + pil_mask = Image.fromarray(mask_np) + + max_size = max(pil_img.size) + scale = min(1.0, 1024 / max_size) + if scale < 1.0: + new_w, new_h = int(pil_img.width * scale), int(pil_img.height * scale) + pil_img = pil_img.resize((new_w, new_h), Image.Resampling.LANCZOS) + pil_mask = pil_mask.resize((new_w, new_h), Image.Resampling.NEAREST) + + rgba_np = np.zeros((pil_img.height, pil_img.width, 4), dtype=np.uint8) + rgba_np[:, :, :3] = np.array(pil_img) + rgba_np[:, :, 3] = np.array(pil_mask) + + alpha = rgba_np[:, :, 3] + bbox_coords = np.argwhere(alpha > 0.8 * 255) + + if len(bbox_coords) > 0: + y_min, x_min = np.min(bbox_coords[:, 0]), np.min(bbox_coords[:, 1]) + y_max, x_max = np.max(bbox_coords[:, 0]), np.max(bbox_coords[:, 1]) + + center_y, center_x = (y_min + y_max) / 2.0, (x_min + x_max) / 2.0 + size = max(y_max - y_min, x_max - x_min) + + crop_x1 = int(center_x - size // 2) + crop_y1 = int(center_y - size // 2) + crop_x2 = int(center_x + size // 2) + crop_y2 = int(center_y + size // 2) + + rgba_pil = Image.fromarray(rgba_np) + cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) + cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 + else: + import logging + logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.") + cropped_np = rgba_np.astype(np.float32) / 255.0 + + bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} + bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32) + + fg = cropped_np[:, :, :3] + alpha_float = cropped_np[:, :, 3:4] + composite_np = fg * alpha_float + bg_rgb * (1.0 - alpha_float) + + # to match trellis2 code (quantize -> dequantize) + composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) + + cropped_pil = Image.fromarray(composite_uint8) + + conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True) + + embeds = conditioning["cond_1024"] + positive = [[conditioning["cond_512"], {"embeds": embeds}]] + negative = [[conditioning["neg_cond"], {"embeds": torch.zeros_like(embeds)}]] + return IO.NodeOutput(positive, negative) + +class EmptyShapeLatentTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyShapeLatentTrellis2", + category="latent/3d", + inputs=[ + IO.AnyType.Input("structure_or_coords"), + IO.Model.Input("model") + ], + outputs=[ + IO.Latent.Output(), + IO.Model.Output() + ] + ) + + @classmethod + def execute(cls, structure_or_coords, model): + # to accept the upscaled coords + is_512_pass = False + + if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: + decoded = structure_or_coords.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + is_512_pass = True + + elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: + coords = structure_or_coords.int() + is_512_pass = False + + else: + raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}") + in_channels = 32 + # image like format + latent = torch.randn(1, in_channels, coords.shape[0], 1) + model = model.clone() + model.model_options = model.model_options.copy() + if "transformer_options" in model.model_options: + model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() + else: + model.model_options["transformer_options"] = {} + + model.model_options["transformer_options"]["coords"] = coords + if is_512_pass: + model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" + else: + model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) + +class EmptyTextureLatentTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyTextureLatentTrellis2", + category="latent/3d", + inputs=[ + IO.Voxel.Input("structure_or_coords"), + IO.Latent.Input("shape_latent"), + IO.Model.Input("model") + ], + outputs=[ + IO.Latent.Output(), + IO.Model.Output() + ] + ) + + @classmethod + def execute(cls, structure_or_coords, shape_latent, model): + channels = 32 + if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: + decoded = structure_or_coords.data.unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + + elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: + coords = structure_or_coords.int() + + shape_latent = shape_latent["samples"] + if shape_latent.ndim == 4: + shape_latent = shape_latent.squeeze(-1).transpose(1, 2).reshape(-1, channels) + + latent = torch.randn(1, channels, coords.shape[0], 1) + model = model.clone() + model.model_options = model.model_options.copy() + if "transformer_options" in model.model_options: + model.model_options["transformer_options"] = model.model_options["transformer_options"].copy() + else: + model.model_options["transformer_options"] = {} + + model.model_options["transformer_options"]["coords"] = coords + model.model_options["transformer_options"]["generation_mode"] = "texture_generation" + model.model_options["transformer_options"]["shape_slat"] = shape_latent + return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) + + +class EmptyStructureLatentTrellis2(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyStructureLatentTrellis2", + category="latent/3d", + inputs=[ + IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), + ], + outputs=[ + IO.Latent.Output(), + ] + ) + @classmethod + def execute(cls, batch_size): + in_channels = 8 + resolution = 16 + latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) + return IO.NodeOutput({"samples": latent, "type": "trellis2"}) + +def simplify_fn(vertices, faces, colors=None, target=100000): + if vertices.ndim == 3: + v_list, f_list, c_list = [], [], [] + for i in range(vertices.shape[0]): + c_in = colors[i] if colors is not None else None + v_i, f_i, c_i = simplify_fn(vertices[i], faces[i], c_in, target) + v_list.append(v_i) + f_list.append(f_i) + if c_i is not None: + c_list.append(c_i) + + c_out = torch.stack(c_list) if len(c_list) > 0 else None + return torch.stack(v_list), torch.stack(f_list), c_out + + if faces.shape[0] <= target: + return vertices, faces, colors + + device = vertices.device + target_v = max(target / 4.0, 1.0) + + min_v = vertices.min(dim=0)[0] + max_v = vertices.max(dim=0)[0] + extent = max_v - min_v + + volume = (extent[0] * extent[1] * extent[2]).clamp(min=1e-8) + cell_size = (volume / target_v) ** (1/3.0) + + quantized = ((vertices - min_v) / cell_size).round().long() + unique_coords, inverse_indices = torch.unique(quantized, dim=0, return_inverse=True) + num_cells = unique_coords.shape[0] + + new_vertices = torch.zeros((num_cells, 3), dtype=vertices.dtype, device=device) + counts = torch.zeros((num_cells, 1), dtype=vertices.dtype, device=device) + new_vertices.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), vertices) + counts.scatter_add_(0, inverse_indices.unsqueeze(1), torch.ones_like(vertices[:, :1])) + new_vertices = new_vertices / counts.clamp(min=1) + + new_colors = None + if colors is not None: + new_colors = torch.zeros((num_cells, colors.shape[1]), dtype=colors.dtype, device=device) + new_colors.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, colors.shape[1]), colors) + new_colors = new_colors / counts.clamp(min=1) + + new_faces = inverse_indices[faces] + valid_mask = (new_faces[:, 0] != new_faces[:, 1]) & \ + (new_faces[:, 1] != new_faces[:, 2]) & \ + (new_faces[:, 2] != new_faces[:, 0]) + new_faces = new_faces[valid_mask] + + unique_face_indices, inv_face = torch.unique(new_faces.reshape(-1), return_inverse=True) + final_vertices = new_vertices[unique_face_indices] + final_faces = inv_face.reshape(-1, 3) + + # assign colors + final_colors = new_colors[unique_face_indices] if new_colors is not None else None + + return final_vertices, final_faces, final_colors + +def fill_holes_fn(vertices, faces, max_perimeter=0.03): + is_batched = vertices.ndim == 3 + if is_batched: + v_list, f_list = [],[] + for i in range(vertices.shape[0]): + v_i, f_i = fill_holes_fn(vertices[i], faces[i], max_perimeter) + v_list.append(v_i) + f_list.append(f_i) + return torch.stack(v_list), torch.stack(f_list) + + device = vertices.device + v = vertices + f = faces + + if f.numel() == 0: + return v, f + + edges = torch.cat([f[:, [0, 1]], f[:, [1, 2]], f[:, [2, 0]]], dim=0) + edges_sorted, _ = torch.sort(edges, dim=1) + + max_v = v.shape[0] + packed_undirected = edges_sorted[:, 0].long() * max_v + edges_sorted[:, 1].long() + + unique_packed, counts = torch.unique(packed_undirected, return_counts=True) + boundary_packed = unique_packed[counts == 1] + + if boundary_packed.numel() == 0: + return v, f + + packed_directed_sorted = edges[:, 0].min(edges[:, 1]).long() * max_v + edges[:, 0].max(edges[:, 1]).long() + is_boundary = torch.isin(packed_directed_sorted, boundary_packed) + b_edges = edges[is_boundary] + + adj = {u.item(): v_idx.item() for u, v_idx in b_edges} + + loops =[] + visited = set() + + for start_node in adj.keys(): + if start_node in visited: + continue + + curr = start_node + loop = [] + + while curr not in visited: + visited.add(curr) + loop.append(curr) + curr = adj.get(curr, -1) + + if curr == -1: + loop = [] + break + if curr == start_node: + loops.append(loop) + break + + new_verts =[] + new_faces = [] + v_idx = v.shape[0] + + for loop in loops: + loop_t = torch.tensor(loop, device=device, dtype=torch.long) + loop_v = v[loop_t] + + diffs = loop_v - torch.roll(loop_v, shifts=-1, dims=0) + perimeter = torch.norm(diffs, dim=1).sum().item() + + if perimeter <= max_perimeter: + new_verts.append(loop_v.mean(dim=0)) + + for i in range(len(loop)): + new_faces.append([loop[(i + 1) % len(loop)], loop[i], v_idx]) + v_idx += 1 + + if new_verts: + v = torch.cat([v, torch.stack(new_verts)], dim=0) + f = torch.cat([f, torch.tensor(new_faces, device=device, dtype=torch.long)], dim=0) + + return v, f + + +def make_double_sided(vertices, faces): + is_batched = vertices.ndim == 3 + if is_batched: + f_list = [] + for i in range(faces.shape[0]): + f_inv = faces[i][:, [0, 2, 1]] + f_list.append(torch.cat([faces[i], f_inv], dim=0)) + return vertices, torch.stack(f_list) + + faces_inv = faces[:, [0, 2, 1]] + return vertices, torch.cat([faces, faces_inv], dim=0) + +class PostProcessMesh(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PostProcessMesh", + category="latent/3d", + inputs=[ + IO.Mesh.Input("mesh"), + IO.Int.Input("simplify", default=1_000_000, min=0, max=50_000_000), + IO.Float.Input("fill_holes_perimeter", default=0.03, min=0.0, step=0.0001) + ], + outputs=[ + IO.Mesh.Output("output_mesh"), + ] + ) + + @classmethod + def execute(cls, mesh, simplify, fill_holes_perimeter): + # TODO: batched mode may break + verts, faces = mesh.vertices, mesh.faces + colors = None + if hasattr(mesh, "colors"): + colors = mesh.colors + + actual_face_count = faces.shape[1] if faces.ndim == 3 else faces.shape[0] + if fill_holes_perimeter > 0: + verts, faces = fill_holes_fn(verts, faces, max_perimeter=fill_holes_perimeter) + + if simplify > 0 and actual_face_count > simplify: + verts, faces, colors = simplify_fn(verts, faces, target=simplify, colors=colors) + + verts, faces = make_double_sided(verts, faces) + + mesh = type(mesh)(vertices=verts, faces=faces) + mesh.vertices = verts + mesh.faces = faces + if colors is not None: + mesh.colors = colors + return IO.NodeOutput(mesh) + +class Trellis2Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + Trellis2Conditioning, + EmptyShapeLatentTrellis2, + EmptyStructureLatentTrellis2, + EmptyTextureLatentTrellis2, + VaeDecodeTextureTrellis, + VaeDecodeShapeTrellis, + VaeDecodeStructureTrellis2, + Trellis2UpsampleCascade, + PostProcessMesh + ] + + +async def comfy_entrypoint() -> Trellis2Extension: + return Trellis2Extension() diff --git a/nodes.py b/nodes.py index 299b3d758..bac3c1ce2 100644 --- a/nodes.py +++ b/nodes.py @@ -2452,6 +2452,7 @@ async def init_builtin_extra_nodes(): "nodes_toolkit.py", "nodes_replacements.py", "nodes_nag.py", + "nodes_trellis2.py", "nodes_sdpose.py", "nodes_math.py", "nodes_number_convert.py",