diff --git a/comfy/image_encoders/dino3.py b/comfy/image_encoders/dino3.py new file mode 100644 index 000000000..d07c2c5b8 --- /dev/null +++ b/comfy/image_encoders/dino3.py @@ -0,0 +1,240 @@ +import math +import torch +import torch.nn as nn + +from comfy.ldm.modules.attention import optimized_attention_for_device +from comfy.ldm.flux.math import apply_rope +from dino2 import Dinov2MLP as DINOv3ViTMLP, LayerScale as DINOv3ViTLayerScale + +class DINOv3ViTAttention(nn.Module): + def __init__(self, hidden_size, num_attention_heads, device, dtype, operations): + super().__init__() + self.embed_dim = hidden_size + self.num_heads = num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + self.is_causal = False + + self.scaling = self.head_dim**-0.5 + self.is_causal = False + + self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False + self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + + self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + **kwargs, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + + batch_size, patches, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + position_embeddings = torch.stack([cos, sin], dim = -1) + query_states, key_states = apply_rope(query_states, key_states, position_embeddings) + + attn_output, attn_weights = optimized_attention_for_device( + query_states, key_states, value_states, attention_mask, skip_reshape=True, skip_output_reshape=True + ) + + attn_output = attn_output.reshape(batch_size, patches, -1).contiguous() + attn_output = self.o_proj(attn_output) + + return attn_output, attn_weights + +class DINOv3ViTGatedMLP(nn.Module): + def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype) + self.act_fn = torch.nn.GELU() + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +def get_patches_center_coordinates( + num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device +) -> torch.Tensor: + + coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device) + coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device) + coords_h = coords_h / num_patches_h + coords_w 
= coords_w / num_patches_w + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1) + coords = coords.flatten(0, 1) + coords = 2.0 * coords - 1.0 + return coords + +class DINOv3ViTRopePositionEmbedding(nn.Module): + inv_freq: torch.Tensor + + def __init__(self, rope_theta, hidden_size, num_attention_heads, image_size, patch_size, device, dtype): + super().__init__() + self.base = rope_theta + self.head_dim = hidden_size // num_attention_heads + self.num_patches_h = image_size // patch_size + self.num_patches_w = image_size // patch_size + + inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + _, _, height, width = pixel_values.shape + num_patches_h = height // self.patch_size + num_patches_w = width // self.patch_size + + device = pixel_values.device + device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu" + with torch.amp.autocast(device_type = device_type, enabled=False): + patch_coords = get_patches_center_coordinates( + num_patches_h, num_patches_w, dtype=torch.float32, device=device + ) + + angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :] + angles = angles.flatten(1, 2) + angles = angles.tile(2) + + cos = torch.cos(angles) + sin = torch.sin(angles) + + dtype = pixel_values.dtype + return cos.to(dtype=dtype), sin.to(dtype=dtype) + + +class DINOv3ViTEmbeddings(nn.Module): + def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations): + super().__init__() + self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype)) + self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype)) + self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype)) + self.patch_embeddings = operations.Conv2d( + num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype + ) + + def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None): + batch_size = pixel_values.shape[0] + target_dtype = self.patch_embeddings.weight.dtype + + patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2) + + if bool_masked_pos is not None: + mask_token = self.mask_token.to(patch_embeddings.dtype) + patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings) + + cls_token = self.cls_token.expand(batch_size, -1, -1) + register_tokens = self.register_tokens.expand(batch_size, -1, -1) + embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1) + + return embeddings + +class DINOv3ViTLayer(nn.Module): + + def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, layerscale_value, mlp_bias, intermediate_size, num_attention_heads, + device, dtype, operations): + super().__init__() + + self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps) + self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations) + self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype) + + self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype) + + if 
use_gated_mlp: + self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations) + else: + self.mlp = DINOv3ViTMLP(hidden_size, device=device, dtype=dtype, operations=operations) + self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, layerscale_value, device=device, dtype=dtype) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor | None = None, + position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states, _ = self.attention( + hidden_states, + attention_mask=attention_mask, + position_embeddings=position_embeddings, + ) + hidden_states = self.layer_scale1(hidden_states) + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = self.layer_scale2(hidden_states) + hidden_states = hidden_states + residual + + return hidden_states + + +class DINOv3ViTModel(nn.Module): + def __init__(self, config, device, dtype, operations): + super().__init__() + num_hidden_layers = config["num_hidden_layers"] + hidden_size = config["hidden_size"] + num_attention_heads = config["num_attention_heads"] + num_register_tokens = config["num_register_tokens"] + intermediate_size = config["intermediate_size"] + layer_norm_eps = config["layer_norm_eps"] + layerscale_value = config["layerscale_value"] + num_channels = config["num_channels"] + patch_size = config["patch_size"] + rope_theta = config["rope_theta"] + + self.embeddings = DINOv3ViTEmbeddings( + hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size, dtype=dtype, device=device, operations=operations + ) + self.rope_embeddings = DINOv3ViTRopePositionEmbedding( + rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device + ) + self.layer = nn.ModuleList( + [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, layerscale_value=layerscale_value, mlp_bias=True, + intermediate_size=intermediate_size,num_attention_heads = num_attention_heads, + dtype=dtype, device=device, operations=operations) + for _ in range(num_hidden_layers)]) + self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def forward( + self, + pixel_values: torch.Tensor, + bool_masked_pos: torch.Tensor | None = None, + **kwargs, + ): + + pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype) + hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + position_embeddings = self.rope_embeddings(pixel_values) + + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module( + hidden_states, + position_embeddings=position_embeddings, + ) + + sequence_output = self.norm(hidden_states) + pooled_output = sequence_output[:, 0, :] + + return sequence_output, None, pooled_output, None diff --git a/comfy/image_encoders/dino3_large.json b/comfy/image_encoders/dino3_large.json new file mode 100644 index 000000000..96263f0d6 --- /dev/null +++ b/comfy/image_encoders/dino3_large.json @@ -0,0 +1,24 @@ +{ + + "hidden_size": 384, + "image_size": 224, + "initializer_range": 0.02, + "intermediate_size": 1536, + "key_bias": false, + "layer_norm_eps": 1e-05, + "layerscale_value": 1.0, + "mlp_bias": true, + "num_attention_heads": 6, + 
"num_channels": 3, + "num_hidden_layers": 12, + "num_register_tokens": 4, + "patch_size": 16, + "pos_embed_rescale": 2.0, + "proj_bias": true, + "query_bias": true, + "rope_theta": 100.0, + "use_gated_mlp": false, + "value_bias": true, + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225] +} diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py new file mode 100644 index 000000000..9cd7d4995 --- /dev/null +++ b/comfy/ldm/trellis2/attention.py @@ -0,0 +1,194 @@ +import torch +import math +from comfy.ldm.modules.attention import optimized_attention +from typing import Tuple, Union, List +from vae import VarLenTensor + +def sparse_windowed_scaled_dot_product_self_attention( + qkv, + window_size: int, + shift_window: Tuple[int, int, int] = (0, 0, 0) +): + + serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}' + serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name) + if serialization_spatial_cache is None: + fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window) + qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args)) + else: + fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache + + qkv_feats = qkv.feats[fwd_indices] # [M, 3, H, C] + + if optimized_attention.__name__ == 'attention_xformers': + if 'xops' not in globals(): + import xformers.ops as xops + q, k, v = qkv_feats.unbind(dim=1) + q = q.unsqueeze(0) # [1, M, H, C] + k = k.unsqueeze(0) # [1, M, H, C] + v = v.unsqueeze(0) # [1, M, H, C] + out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0] # [M, H, C] + elif optimized_attention.__name__ == 'attention_flash': + if 'flash_attn' not in globals(): + import flash_attn + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args) # [M, H, C] + + out = out[bwd_indices] # [T, H, C] + + return qkv.replace(out) + +def calc_window_partition( + tensor, + window_size: Union[int, Tuple[int, ...]], + shift_window: Union[int, Tuple[int, ...]] = 0, +) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]: + + DIM = tensor.coords.shape[1] - 1 + shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window + window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size + shifted_coords = tensor.coords.clone().detach() + shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0) + + MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)] + NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)] + OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1] + + shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0) + shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1) + fwd_indices = torch.argsort(shifted_indices) + bwd_indices = torch.empty_like(fwd_indices) + bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device) + seq_lens = torch.bincount(shifted_indices) + mask = seq_lens != 0 + seq_lens = seq_lens[mask] + + if optimized_attention.__name__ == 'attention_xformers': + if 'xops' not in globals(): + import xformers.ops as xops + attn_func_args = { + 'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens) + } + elif 
optimized_attention.__name__ == 'attention_flash': + attn_func_args = { + 'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(), + 'max_seqlen': torch.max(seq_lens) + } + + return fwd_indices, bwd_indices, seq_lens, attn_func_args + + +def sparse_scaled_dot_product_attention(*args, **kwargs): + arg_names_dict = { + 1: ['qkv'], + 2: ['q', 'kv'], + 3: ['q', 'k', 'v'] + } + num_all_args = len(args) + len(kwargs) + for key in arg_names_dict[num_all_args][len(args):]: + assert key in kwargs, f"Missing argument {key}" + + if num_all_args == 1: + qkv = args[0] if len(args) > 0 else kwargs['qkv'] + device = qkv.device + + s = qkv + q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])] + kv_seqlen = q_seqlen + qkv = qkv.feats # [T, 3, H, C] + + elif num_all_args == 2: + q = args[0] if len(args) > 0 else kwargs['q'] + kv = args[1] if len(args) > 1 else kwargs['kv'] + device = q.device + + if isinstance(q, VarLenTensor): + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, C] + else: + s = None + N, L, H, C = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, C) # [T_Q, H, C] + + if isinstance(kv, VarLenTensor): + kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])] + kv = kv.feats # [T_KV, 2, H, C] + else: + N, L, _, H, C = kv.shape + kv_seqlen = [L] * N + kv = kv.reshape(N * L, 2, H, C) # [T_KV, 2, H, C] + + elif num_all_args == 3: + q = args[0] if len(args) > 0 else kwargs['q'] + k = args[1] if len(args) > 1 else kwargs['k'] + v = args[2] if len(args) > 2 else kwargs['v'] + device = q.device + + if isinstance(q, VarLenTensor): + s = q + q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])] + q = q.feats # [T_Q, H, Ci] + else: + s = None + N, L, H, CI = q.shape + q_seqlen = [L] * N + q = q.reshape(N * L, H, CI) # [T_Q, H, Ci] + + if isinstance(k, VarLenTensor): + kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])] + k = k.feats # [T_KV, H, Ci] + v = v.feats # [T_KV, H, Co] + else: + N, L, H, CI, CO = *k.shape, v.shape[-1] + kv_seqlen = [L] * N + k = k.reshape(N * L, H, CI) # [T_KV, H, Ci] + v = v.reshape(N * L, H, CO) # [T_KV, H, Co] + + if optimized_attention.__name__ == 'attention_xformers': + if 'xops' not in globals(): + import xformers.ops as xops + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + q = q.unsqueeze(0) + k = k.unsqueeze(0) + v = v.unsqueeze(0) + mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen) + out = xops.memory_efficient_attention(q, k, v, mask)[0] + elif optimized_attention.__name__ == 'attention_flash': + if 'flash_attn' not in globals(): + import flash_attn + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args in [2, 3]: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen)) + elif num_all_args == 2: + out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif num_all_args == 3: + out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen)) + elif optimized_attention.__name__ == 'flash_attn_3': # TODO + if 'flash_attn_3' not in 
globals(): + import flash_attn_interface as flash_attn_3 + cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device) + if num_all_args == 1: + q, k, v = qkv.unbind(dim=1) + cu_seqlens_kv = cu_seqlens_q.clone() + max_q_seqlen = max_kv_seqlen = max(q_seqlen) + elif num_all_args == 2: + k, v = kv.unbind(dim=1) + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + elif num_all_args == 3: + cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device) + max_q_seqlen = max(q_seqlen) + max_kv_seqlen = max(kv_seqlen) + out = flash_attn_3.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_q_seqlen, max_kv_seqlen) + + if s is not None: + return s.replace(out) + else: + return out.reshape(N, L, H, -1) diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py new file mode 100644 index 000000000..be8200341 --- /dev/null +++ b/comfy/ldm/trellis2/cumesh.py @@ -0,0 +1,149 @@ +# will contain every cuda -> pytorch operation + +import torch +from typing import Dict + + +class TorchHashMap: + def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int): + device = keys.device + # use long for searchsorted + self.sorted_keys, order = torch.sort(keys.long()) + self.sorted_vals = values.long()[order] + self.default_value = torch.tensor(default_value, dtype=torch.long, device=device) + self._n = self.sorted_keys.numel() + + def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor: + flat = flat_keys.long() + idx = torch.searchsorted(self.sorted_keys, flat) + found = (idx < self._n) & (self.sorted_keys[idx] == flat) + out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype) + if found.any(): + out[found] = self.sorted_vals[idx[found]] + return out + +class Voxel: + def __init__( + self, + origin: list, + voxel_size: float, + coords: torch.Tensor = None, + attrs: torch.Tensor = None, + layout: Dict = {}, + device: torch.device = 'cuda' + ): + self.origin = torch.tensor(origin, dtype=torch.float32, device=device) + self.voxel_size = voxel_size + self.coords = coords + self.attrs = attrs + self.layout = layout + self.device = device + + @property + def position(self): + return (self.coords + 0.5) * self.voxel_size + self.origin[None, :] + + def split_attrs(self): + return { + k: self.attrs[:, self.layout[k]] + for k in self.layout + } + +class Mesh: + def __init__(self, + vertices, + faces, + vertex_attrs=None + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.vertex_attrs = vertex_attrs + + @property + def device(self): + return self.vertices.device + + def to(self, device, non_blocking=False): + return Mesh( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None, + ) + + def cuda(self, non_blocking=False): + return self.to('cuda', non_blocking=non_blocking) + + def cpu(self): + return self.to('cpu') + + # TODO could be an option + def fill_holes(self, max_hole_perimeter=3e-2): + import cumesh + vertices = self.vertices.cuda() + faces = self.faces.cuda() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.get_edges() + mesh.get_boundary_info() + if mesh.num_boundaries == 0: + return + 
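+        # Build vertex/edge and boundary adjacency, then group boundary edges into loops;
+        # only meshes with at least one boundary loop reach the fill_holes call below.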
mesh.get_vertex_edge_adjacency() + mesh.get_vertex_boundary_adjacency() + mesh.get_manifold_boundary_adjacency() + mesh.read_manifold_boundary_adjacency() + mesh.get_boundary_connected_components() + mesh.get_boundary_loops() + if mesh.num_boundary_loops == 0: + return + mesh.fill_holes(max_hole_perimeter=max_hole_perimeter) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + + # TODO could be an option + def simplify(self, target=1000000, verbose: bool=False, options: dict={}): + import cumesh + vertices = self.vertices.cuda() + faces = self.faces.cuda() + + mesh = cumesh.CuMesh() + mesh.init(vertices, faces) + mesh.simplify(target, verbose=verbose, options=options) + new_vertices, new_faces = mesh.read() + + self.vertices = new_vertices.to(self.device) + self.faces = new_faces.to(self.device) + +class MeshWithVoxel(Mesh, Voxel): + def __init__(self, + vertices: torch.Tensor, + faces: torch.Tensor, + origin: list, + voxel_size: float, + coords: torch.Tensor, + attrs: torch.Tensor, + voxel_shape: torch.Size, + layout: Dict = {}, + ): + self.vertices = vertices.float() + self.faces = faces.int() + self.origin = torch.tensor(origin, dtype=torch.float32, device=self.device) + self.voxel_size = voxel_size + self.coords = coords + self.attrs = attrs + self.voxel_shape = voxel_shape + self.layout = layout + + def to(self, device, non_blocking=False): + return MeshWithVoxel( + self.vertices.to(device, non_blocking=non_blocking), + self.faces.to(device, non_blocking=non_blocking), + self.origin.tolist(), + self.voxel_size, + self.coords.to(device, non_blocking=non_blocking), + self.attrs.to(device, non_blocking=non_blocking), + self.voxel_shape, + self.layout, + ) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py new file mode 100644 index 000000000..a0889c4dd --- /dev/null +++ b/comfy/ldm/trellis2/model.py @@ -0,0 +1,499 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from vae import SparseTensor, SparseLinear, sparse_cat, VarLenTensor +from typing import Optional, Tuple, Literal, Union, List +from attention import sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention +from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder + +class SparseGELU(nn.GELU): + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + +class SparseFeedForwardNet(nn.Module): + def __init__(self, channels: int, mlp_ratio: float = 4.0): + super().__init__() + self.mlp = nn.Sequential( + SparseLinear(channels, int(channels * mlp_ratio)), + SparseGELU(approximate="tanh"), + SparseLinear(int(channels * mlp_ratio), channels), + ) + + def forward(self, x: VarLenTensor) -> VarLenTensor: + return self.mlp(x) + +def manual_cast(tensor, dtype): + if not torch.is_autocast_enabled(): + return tensor.type(dtype) + return tensor +class LayerNorm32(nn.LayerNorm): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x_dtype = x.dtype + x = manual_cast(x, torch.float32) + o = super().forward(x) + return manual_cast(o, x_dtype) + + +class SparseMultiHeadRMSNorm(nn.Module): + def __init__(self, dim: int, heads: int): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(heads, dim)) + + def forward(self, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + x_type = x.dtype + x = x.float() + if isinstance(x, VarLenTensor): + x = x.replace(F.normalize(x.feats, 
dim=-1) * self.gamma * self.scale) + else: + x = F.normalize(x, dim=-1) * self.gamma * self.scale + return x.to(x_type) + +# TODO: replace with apply_rope1 +class SparseRotaryPositionEmbedder(nn.Module): + def __init__( + self, + head_dim: int, + dim: int = 3, + rope_freq: Tuple[float, float] = (1.0, 10000.0) + ): + super().__init__() + assert head_dim % 2 == 0, "Head dim must be divisible by 2" + self.head_dim = head_dim + self.dim = dim + self.rope_freq = rope_freq + self.freq_dim = head_dim // 2 // dim + self.freqs = torch.arange(self.freq_dim, dtype=torch.float32) / self.freq_dim + self.freqs = rope_freq[0] / (rope_freq[1] ** (self.freqs)) + + def _get_phases(self, indices: torch.Tensor) -> torch.Tensor: + self.freqs = self.freqs.to(indices.device) + phases = torch.outer(indices, self.freqs) + phases = torch.polar(torch.ones_like(phases), phases) + return phases + + def _rotary_embedding(self, x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor: + x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) + x_rotated = x_complex * phases.unsqueeze(-2) + x_embed = torch.view_as_real(x_rotated).reshape(*x_rotated.shape[:-1], -1).to(x.dtype) + return x_embed + + def forward(self, q: SparseTensor, k: Optional[SparseTensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + q (SparseTensor): [..., N, H, D] tensor of queries + k (SparseTensor): [..., N, H, D] tensor of keys + """ + assert q.coords.shape[-1] == self.dim + 1, "Last dimension of coords must be equal to dim+1" + phases_cache_name = f'rope_phase_{self.dim}d_freq{self.rope_freq[0]}-{self.rope_freq[1]}_hd{self.head_dim}' + phases = q.get_spatial_cache(phases_cache_name) + if phases is None: + coords = q.coords[..., 1:] + phases = self._get_phases(coords.reshape(-1)).reshape(*coords.shape[:-1], -1) + if phases.shape[-1] < self.head_dim // 2: + padn = self.head_dim // 2 - phases.shape[-1] + phases = torch.cat([phases, torch.polar( + torch.ones(*phases.shape[:-1], padn, device=phases.device), + torch.zeros(*phases.shape[:-1], padn, device=phases.device) + )], dim=-1) + q.register_spatial_cache(phases_cache_name, phases) + q_embed = q.replace(self._rotary_embedding(q.feats, phases)) + if k is None: + return q_embed + k_embed = k.replace(self._rotary_embedding(k.feats, phases)) + return q_embed, k_embed + +class SparseMultiHeadAttention(nn.Module): + def __init__( + self, + channels: int, + num_heads: int, + ctx_channels: Optional[int] = None, + type: Literal["self", "cross"] = "self", + attn_mode: Literal["full", "windowed", "double_windowed"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + qkv_bias: bool = True, + use_rope: bool = False, + rope_freq: Tuple[int, int] = (1.0, 10000.0), + qk_rms_norm: bool = False, + ): + super().__init__() + + self.channels = channels + self.head_dim = channels // num_heads + self.ctx_channels = ctx_channels if ctx_channels is not None else channels + self.num_heads = num_heads + self._type = type + self.attn_mode = attn_mode + self.window_size = window_size + self.shift_window = shift_window + self.use_rope = use_rope + self.qk_rms_norm = qk_rms_norm + + if self._type == "self": + self.to_qkv = nn.Linear(channels, channels * 3, bias=qkv_bias) + else: + self.to_q = nn.Linear(channels, channels, bias=qkv_bias) + self.to_kv = nn.Linear(self.ctx_channels, channels * 2, bias=qkv_bias) + + if self.qk_rms_norm: + self.q_rms_norm = SparseMultiHeadRMSNorm(self.head_dim, num_heads) + self.k_rms_norm = 
SparseMultiHeadRMSNorm(self.head_dim, num_heads) + + self.to_out = nn.Linear(channels, channels) + + if use_rope: + self.rope = SparseRotaryPositionEmbedder(self.head_dim, rope_freq=rope_freq) + + @staticmethod + def _linear(module: nn.Linear, x: Union[VarLenTensor, torch.Tensor]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.replace(module(x.feats)) + else: + return module(x) + + @staticmethod + def _reshape_chs(x: Union[VarLenTensor, torch.Tensor], shape: Tuple[int, ...]) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + return x.reshape(*shape) + else: + return x.reshape(*x.shape[:2], *shape) + + def _fused_pre(self, x: Union[VarLenTensor, torch.Tensor], num_fused: int) -> Union[VarLenTensor, torch.Tensor]: + if isinstance(x, VarLenTensor): + x_feats = x.feats.unsqueeze(0) + else: + x_feats = x + x_feats = x_feats.reshape(*x_feats.shape[:2], num_fused, self.num_heads, -1) + return x.replace(x_feats.squeeze(0)) if isinstance(x, VarLenTensor) else x_feats + + def forward(self, x: SparseTensor, context: Optional[Union[VarLenTensor, torch.Tensor]] = None) -> SparseTensor: + if self._type == "self": + qkv = self._linear(self.to_qkv, x) + qkv = self._fused_pre(qkv, num_fused=3) + if self.qk_rms_norm or self.use_rope: + q, k, v = qkv.unbind(dim=-3) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k = self.k_rms_norm(k) + if self.use_rope: + q, k = self.rope(q, k) + qkv = qkv.replace(torch.stack([q.feats, k.feats, v.feats], dim=1)) + if self.attn_mode == "full": + h = sparse_scaled_dot_product_attention(qkv) + elif self.attn_mode == "windowed": + h = sparse_windowed_scaled_dot_product_self_attention( + qkv, self.window_size, shift_window=self.shift_window + ) + elif self.attn_mode == "double_windowed": + qkv0 = qkv.replace(qkv.feats[:, :, self.num_heads//2:]) + qkv1 = qkv.replace(qkv.feats[:, :, :self.num_heads//2]) + h0 = sparse_windowed_scaled_dot_product_self_attention( + qkv0, self.window_size, shift_window=(0, 0, 0) + ) + h1 = sparse_windowed_scaled_dot_product_self_attention( + qkv1, self.window_size, shift_window=tuple([self.window_size//2] * 3) + ) + h = qkv.replace(torch.cat([h0.feats, h1.feats], dim=1)) + else: + q = self._linear(self.to_q, x) + q = self._reshape_chs(q, (self.num_heads, -1)) + kv = self._linear(self.to_kv, context) + kv = self._fused_pre(kv, num_fused=2) + if self.qk_rms_norm: + q = self.q_rms_norm(q) + k, v = kv.unbind(dim=-3) + k = self.k_rms_norm(k) + h = sparse_scaled_dot_product_attention(q, k, v) + else: + h = sparse_scaled_dot_product_attention(q, kv) + h = self._reshape_chs(h, (-1,)) + h = self._linear(self.to_out, h) + return h + +class ModulatedSparseTransformerBlock(nn.Module): + """ + Sparse Transformer block (MSA + FFN) with adaptive layer norm conditioning. 
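+    The conditioning vector is projected to six chunks: shift/scale/gate for the self-attention branch and shift/scale/gate for the feed-forward branch.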
+ """ + def __init__( + self, + channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor) -> SparseTensor: + if self.use_checkpoint: + return torch.utils.checkpoint.checkpoint(self._forward, x, mod, use_reentrant=False) + else: + return self._forward(x, mod) + + +class ModulatedSparseTransformerCrossBlock(nn.Module): + """ + Sparse Transformer cross-attention block (MSA + MCA + FFN) with adaptive layer norm conditioning. 
+ """ + def __init__( + self, + channels: int, + ctx_channels: int, + num_heads: int, + mlp_ratio: float = 4.0, + attn_mode: Literal["full", "swin"] = "full", + window_size: Optional[int] = None, + shift_window: Optional[Tuple[int, int, int]] = None, + use_checkpoint: bool = False, + use_rope: bool = False, + rope_freq: Tuple[float, float] = (1.0, 10000.0), + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + qkv_bias: bool = True, + share_mod: bool = False, + + ): + super().__init__() + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.norm1 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.norm2 = LayerNorm32(channels, elementwise_affine=True, eps=1e-6) + self.norm3 = LayerNorm32(channels, elementwise_affine=False, eps=1e-6) + self.self_attn = SparseMultiHeadAttention( + channels, + num_heads=num_heads, + type="self", + attn_mode=attn_mode, + window_size=window_size, + shift_window=shift_window, + qkv_bias=qkv_bias, + use_rope=use_rope, + rope_freq=rope_freq, + qk_rms_norm=qk_rms_norm, + ) + self.cross_attn = SparseMultiHeadAttention( + channels, + ctx_channels=ctx_channels, + num_heads=num_heads, + type="cross", + attn_mode="full", + qkv_bias=qkv_bias, + qk_rms_norm=qk_rms_norm_cross, + ) + self.mlp = SparseFeedForwardNet( + channels, + mlp_ratio=mlp_ratio, + ) + if not share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(channels, 6 * channels, bias=True) + ) + else: + self.modulation = nn.Parameter(torch.randn(6 * channels) / channels ** 0.5) + + def _forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + if self.share_mod: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.modulation + mod).type(mod.dtype).chunk(6, dim=1) + else: + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(mod).chunk(6, dim=1) + h = x.replace(self.norm1(x.feats)) + h = h * (1 + scale_msa) + shift_msa + h = self.self_attn(h) + h = h * gate_msa + x = x + h + h = x.replace(self.norm2(x.feats)) + h = self.cross_attn(h, context) + x = x + h + h = x.replace(self.norm3(x.feats)) + h = h * (1 + scale_mlp) + shift_mlp + h = self.mlp(h) + h = h * gate_mlp + x = x + h + return x + + def forward(self, x: SparseTensor, mod: torch.Tensor, context: Union[torch.Tensor, VarLenTensor]) -> SparseTensor: + return self._forward(x, mod, context) + + +class SLatFlowModel(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + model_channels: int, + cond_channels: int, + out_channels: int, + num_blocks: int, + num_heads: Optional[int] = None, + num_head_channels: Optional[int] = 64, + mlp_ratio: float = 4, + pe_mode: Literal["ape", "rope"] = "rope", + rope_freq: Tuple[float, float] = (1.0, 10000.0), + use_checkpoint: bool = False, + share_mod: bool = False, + initialization: str = 'vanilla', + qk_rms_norm: bool = False, + qk_rms_norm_cross: bool = False, + dtype = None, + device = None, + operations = None, + ): + super().__init__() + self.resolution = resolution + self.in_channels = in_channels + self.model_channels = model_channels + self.cond_channels = cond_channels + self.out_channels = out_channels + self.num_blocks = num_blocks + self.num_heads = num_heads or model_channels // num_head_channels + self.mlp_ratio = mlp_ratio + self.pe_mode = pe_mode + self.use_checkpoint = use_checkpoint + self.share_mod = share_mod + self.initialization = initialization + self.qk_rms_norm = qk_rms_norm + self.qk_rms_norm_cross = 
qk_rms_norm_cross + self.dtype = dtype + + self.t_embedder = TimestepEmbedder(model_channels) + if share_mod: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + nn.Linear(model_channels, 6 * model_channels, bias=True) + ) + + self.input_layer = SparseLinear(in_channels, model_channels) + + self.blocks = nn.ModuleList([ + ModulatedSparseTransformerCrossBlock( + model_channels, + cond_channels, + num_heads=self.num_heads, + mlp_ratio=self.mlp_ratio, + attn_mode='full', + use_checkpoint=self.use_checkpoint, + use_rope=(pe_mode == "rope"), + rope_freq=rope_freq, + share_mod=self.share_mod, + qk_rms_norm=self.qk_rms_norm, + qk_rms_norm_cross=self.qk_rms_norm_cross, + ) + for _ in range(num_blocks) + ]) + + self.out_layer = SparseLinear(model_channels, out_channels) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward( + self, + x: SparseTensor, + t: torch.Tensor, + cond: Union[torch.Tensor, List[torch.Tensor]], + concat_cond: Optional[SparseTensor] = None, + **kwargs + ) -> SparseTensor: + if concat_cond is not None: + x = sparse_cat([x, concat_cond], dim=-1) + if isinstance(cond, list): + cond = VarLenTensor.from_tensor_list(cond) + + h = self.input_layer(x) + h = manual_cast(h, self.dtype) + t_emb = self.t_embedder(t) + if self.share_mod: + t_emb = self.adaLN_modulation(t_emb) + t_emb = manual_cast(t_emb, self.dtype) + cond = manual_cast(cond, self.dtype) + + if self.pe_mode == "ape": + pe = self.pos_embedder(h.coords[:, 1:]) + h = h + manual_cast(pe, self.dtype) + for block in self.blocks: + h = block(h, t_emb, cond) + + h = manual_cast(h, x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.out_layer(h) + return h + +class Trellis2(nn.Module): + def __init__(self, resolution, + in_channels = 32, + out_channels = 32, + model_channels = 1536, + cond_channels = 1024, + num_blocks = 30, + num_heads = 12, + mlp_ratio = 5.3334, + share_mod = True, + qk_rms_norm = True, + qk_rms_norm_cross = True, + dtype=None, device=None, operations=None): + args = { + "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, + "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, + "qk_rms_norm": qk_rms_norm, "qk_rms_norm_cross": qk_rms_norm_cross, "device": device, "dtype": dtype, "operations": operations + } + # TODO: update the names/checkpoints + self.img2shape = SLatFlowModel(resolution, in_channels=in_channels, *args) + self.shape2txt = SLatFlowModel(resolution, in_channels=in_channels*2, *args) + self.shape_generation = True + + def forward(self, x, timestep, context): + pass diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py new file mode 100644 index 000000000..1d564bca2 --- /dev/null +++ b/comfy/ldm/trellis2/vae.py @@ -0,0 +1,1185 @@ +import torch +import torch.nn as nn +from typing import List, Any, Dict, Optional, overload, Union, Tuple +from fractions import Fraction +import torch.nn.functional as F +from dataclasses import dataclass +import numpy as np +from cumesh import TorchHashMap, Mesh, MeshWithVoxel + +# TODO: determine which conv they actually use +@dataclass +class config: + CONV = "none" + +# TODO post processing +def simplify(self, target_num_faces: int, verbose: bool=False, options: dict={}): + + num_face = self.cu_mesh.num_faces() + if num_face <= target_num_faces: + return + + thresh = options.get('thresh', 1e-8) + lambda_edge_length = options.get('lambda_edge_length', 1e-2) + lambda_skinny = 
options.get('lambda_skinny', 1e-3) + while True: + new_num_vert, new_num_face = self.cu_mesh.simplify_step(lambda_edge_length, lambda_skinny, thresh, False) + + if new_num_face <= target_num_faces: + break + + del_num_face = num_face - new_num_face + if del_num_face / num_face < 1e-2: + thresh *= 10 + num_face = new_num_face + +class VarLenTensor: + + def __init__(self, feats: torch.Tensor, layout: List[slice]=None): + self.feats = feats + self.layout = layout if layout is not None else [slice(0, feats.shape[0])] + self._cache = {} + + @staticmethod + def layout_from_seqlen(seqlen: list) -> List[slice]: + """ + Create a layout from a tensor of sequence lengths. + """ + layout = [] + start = 0 + for l in seqlen: + layout.append(slice(start, start + l)) + start += l + return layout + + @staticmethod + def from_tensor_list(tensor_list: List[torch.Tensor]) -> 'VarLenTensor': + """ + Create a VarLenTensor from a list of tensors. + """ + feats = torch.cat(tensor_list, dim=0) + layout = [] + start = 0 + for tensor in tensor_list: + layout.append(slice(start, start + tensor.shape[0])) + start += tensor.shape[0] + return VarLenTensor(feats, layout) + + def __len__(self) -> int: + return len(self.layout) + + @property + def shape(self) -> torch.Size: + return torch.Size([len(self.layout), *self.feats.shape[1:]]) + + def dim(self) -> int: + return len(self.shape) + + @property + def ndim(self) -> int: + return self.dim() + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + if 'seqlen' not in self._cache: + self._cache['seqlen'] = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + return self._cache['seqlen'] + + @property + def cum_seqlen(self) -> torch.LongTensor: + if 'cum_seqlen' not in self._cache: + self._cache['cum_seqlen'] = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + return self._cache['cum_seqlen'] + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. + """ + if 'batch_boardcast_map' not in self._cache: + self._cache['batch_boardcast_map'] = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + return self._cache['batch_boardcast_map'] + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'VarLenTensor': ... 
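+    # Dispatch mirrors torch.Tensor.to: positional (dtype) or (device[, dtype]) plus
+    # optional dtype=/device= keywords; only feats is converted, the layout is reused.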
+ + def to(self, *args, **kwargs) -> 'VarLenTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'VarLenTensor': + new_feats = self.feats.cpu() + return self.replace(new_feats) + + def cuda(self) -> 'VarLenTensor': + new_feats = self.feats.cuda() + return self.replace(new_feats) + + def half(self) -> 'VarLenTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'VarLenTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'VarLenTensor': + new_feats = self.feats.detach() + return self.replace(new_feats) + + def reshape(self, *shape) -> 'VarLenTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['VarLenTensor']: + return varlen_unbind(self, dim) + + def replace(self, feats: torch.Tensor) -> 'VarLenTensor': + new_tensor = VarLenTensor( + feats=feats, + layout=self.layout, + ) + new_tensor._cache = self._cache + return new_tensor + + def to_dense(self, max_length=None) -> torch.Tensor: + N = len(self) + L = max_length or self.seqlen.max().item() + spatial = self.feats.shape[1:] + idx = torch.arange(L, device=self.device).unsqueeze(0).expand(N, L) + mask = (idx < self.seqlen.unsqueeze(1)) + mapping = mask.reshape(-1).cumsum(dim=0) - 1 + dense = self.feats[mapping] + dense = dense.reshape(N, L, *spatial) + return dense, mask + + def __neg__(self) -> 'VarLenTensor': + return self.replace(-self.feats) + + def __elemwise__(self, other: Union[torch.Tensor, 'VarLenTensor'], op: callable) -> 'VarLenTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + return new_tensor + + def __add__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __radd__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.add) + + def __sub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.sub) + + def __rsub__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.sub(y, x)) + + def __mul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __rmul__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, torch.mul) + + def __truediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) 
-> 'VarLenTensor': + return self.__elemwise__(other, torch.div) + + def __rtruediv__(self, other: Union[torch.Tensor, 'VarLenTensor', float]) -> 'VarLenTensor': + return self.__elemwise__(other, lambda x, y: torch.div(y, x)) + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_feats = [] + new_layout = [] + start = 0 + for new_idx, old_idx in enumerate(idx): + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_feats[-1]))) + start += len(new_feats[-1]) + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = VarLenTensor(feats=new_feats, layout=new_layout) + return new_tensor + + def reduce(self, op: str, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + if isinstance(dim, int): + dim = (dim,) + + if op =='mean': + red = self.feats.mean(dim=dim, keepdim=keepdim) + elif op =='sum': + red = self.feats.sum(dim=dim, keepdim=keepdim) + elif op == 'prod': + red = self.feats.prod(dim=dim, keepdim=keepdim) + else: + raise ValueError(f"Unsupported reduce operation: {op}") + + if dim is None or 0 in dim: + return red + + red = torch.segment_reduce(red, reduce=op, lengths=self.seqlen) + return red + + def mean(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='mean', dim=dim, keepdim=keepdim) + + def sum(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='sum', dim=dim, keepdim=keepdim) + + def prod(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + return self.reduce(op='prod', dim=dim, keepdim=keepdim) + + def std(self, dim: Optional[Union[int, Tuple[int,...]]] = None, keepdim: bool = False) -> torch.Tensor: + mean = self.mean(dim=dim, keepdim=True) + mean2 = self.replace(self.feats ** 2).mean(dim=dim, keepdim=True) + std = (mean2 - mean ** 2).sqrt() + return std + + def __repr__(self) -> str: + return f"VarLenTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + +def varlen_unbind(input: VarLenTensor, dim: int) -> Union[List[VarLenTensor]]: + + if dim == 0: + return [input[i] for i in range(len(input))] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] + + +class SparseTensor(VarLenTensor): + + SparseTensorData = None + + @overload + def __init__(self, feats: torch.Tensor, coords: torch.Tensor, shape: Optional[torch.Size] = None, **kwargs): ... + + @overload + def __init__(self, data, shape: Optional[torch.Size] = None, **kwargs): ... 
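+    # Two construction paths: raw (feats, coords[, shape]) tensors, or a prebuilt
+    # backend tensor passed as data= (torchsparse/spconv, selected by config.CONV).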
+ + def __init__(self, *args, **kwargs): + # Lazy import of sparse tensor backend + if self.SparseTensorData is None: + import importlib + if config.CONV == 'torchsparse': + self.SparseTensorData = importlib.import_module('torchsparse').SparseTensor + elif config.CONV == 'spconv': + self.SparseTensorData = importlib.import_module('spconv.pytorch').SparseConvTensor + + method_id = 0 + if len(args) != 0: + method_id = 0 if isinstance(args[0], torch.Tensor) else 1 + else: + method_id = 1 if 'data' in kwargs else 0 + + if method_id == 0: + feats, coords, shape = args + (None,) * (3 - len(args)) + if 'feats' in kwargs: + feats = kwargs['feats'] + del kwargs['feats'] + if 'coords' in kwargs: + coords = kwargs['coords'] + del kwargs['coords'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + if config.CONV == 'torchsparse': + self.data = self.SparseTensorData(feats, coords, **kwargs) + elif config.CONV == 'spconv': + spatial_shape = list(coords.max(0)[0] + 1) + self.data = self.SparseTensorData(feats.reshape(feats.shape[0], -1), coords, spatial_shape[1:], spatial_shape[0], **kwargs) + self.data._features = feats + else: + self.data = { + 'feats': feats, + 'coords': coords, + } + elif method_id == 1: + data, shape = args + (None,) * (2 - len(args)) + if 'data' in kwargs: + data = kwargs['data'] + del kwargs['data'] + if 'shape' in kwargs: + shape = kwargs['shape'] + del kwargs['shape'] + + self.data = data + + self._shape = shape + self._scale = kwargs.get('scale', (Fraction(1, 1), Fraction(1, 1), Fraction(1, 1))) + self._spatial_cache = kwargs.get('spatial_cache', {}) + + @staticmethod + def from_tensor_list(feats_list: List[torch.Tensor], coords_list: List[torch.Tensor]) -> 'SparseTensor': + """ + Create a SparseTensor from a list of tensors. + """ + feats = torch.cat(feats_list, dim=0) + coords = [] + for i, coord in enumerate(coords_list): + coord = torch.cat([torch.full_like(coord[:, :1], i), coord[:, 1:]], dim=1) + coords.append(coord) + coords = torch.cat(coords, dim=0) + return SparseTensor(feats, coords) + + def to_tensor_list(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Convert a SparseTensor to list of tensors. 
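+        Features and coordinates are split per batch item according to layout.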
+ """ + feats_list = [] + coords_list = [] + for s in self.layout: + feats_list.append(self.feats[s]) + coords_list.append(self.coords[s]) + return feats_list, coords_list + + def __len__(self) -> int: + return len(self.layout) + + def __cal_shape(self, feats, coords): + shape = [] + shape.append(coords[:, 0].max().item() + 1) + shape.extend([*feats.shape[1:]]) + return torch.Size(shape) + + def __cal_layout(self, coords, batch_size): + seq_len = torch.bincount(coords[:, 0], minlength=batch_size) + offset = torch.cumsum(seq_len, dim=0) + layout = [slice((offset[i] - seq_len[i]).item(), offset[i].item()) for i in range(batch_size)] + return layout + + def __cal_spatial_shape(self, coords): + return torch.Size((coords[:, 1:].max(0)[0] + 1).tolist()) + + @property + def shape(self) -> torch.Size: + if self._shape is None: + self._shape = self.__cal_shape(self.feats, self.coords) + return self._shape + + @property + def layout(self) -> List[slice]: + layout = self.get_spatial_cache('layout') + if layout is None: + layout = self.__cal_layout(self.coords, self.shape[0]) + self.register_spatial_cache('layout', layout) + return layout + + @property + def spatial_shape(self) -> torch.Size: + spatial_shape = self.get_spatial_cache('shape') + if spatial_shape is None: + spatial_shape = self.__cal_spatial_shape(self.coords) + self.register_spatial_cache('shape', spatial_shape) + return spatial_shape + + @property + def feats(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.F + elif config.CONV == 'spconv': + return self.data.features + else: + return self.data['feats'] + + @feats.setter + def feats(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.F = value + elif config.CONV == 'spconv': + self.data.features = value + else: + self.data['feats'] = value + + @property + def coords(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.C + elif config.CONV == 'spconv': + return self.data.indices + else: + return self.data['coords'] + + @coords.setter + def coords(self, value: torch.Tensor): + if config.CONV == 'torchsparse': + self.data.C = value + elif config.CONV == 'spconv': + self.data.indices = value + else: + self.data['coords'] = value + + @property + def dtype(self): + return self.feats.dtype + + @property + def device(self): + return self.feats.device + + @property + def seqlen(self) -> torch.LongTensor: + seqlen = self.get_spatial_cache('seqlen') + if seqlen is None: + seqlen = torch.tensor([l.stop - l.start for l in self.layout], dtype=torch.long, device=self.device) + self.register_spatial_cache('seqlen', seqlen) + return seqlen + + @property + def cum_seqlen(self) -> torch.LongTensor: + cum_seqlen = self.get_spatial_cache('cum_seqlen') + if cum_seqlen is None: + cum_seqlen = torch.cat([ + torch.tensor([0], dtype=torch.long, device=self.device), + self.seqlen.cumsum(dim=0) + ], dim=0) + self.register_spatial_cache('cum_seqlen', cum_seqlen) + return cum_seqlen + + @property + def batch_boardcast_map(self) -> torch.LongTensor: + """ + Get the broadcast map for the varlen tensor. 
+ """ + batch_boardcast_map = self.get_spatial_cache('batch_boardcast_map') + if batch_boardcast_map is None: + batch_boardcast_map = torch.repeat_interleave( + torch.arange(len(self.layout), device=self.device), + self.seqlen, + ) + self.register_spatial_cache('batch_boardcast_map', batch_boardcast_map) + return batch_boardcast_map + + @overload + def to(self, dtype: torch.dtype, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... + + @overload + def to(self, device: Optional[Union[str, torch.device]] = None, dtype: Optional[torch.dtype] = None, *, non_blocking: bool = False, copy: bool = False) -> 'SparseTensor': ... + + def to(self, *args, **kwargs) -> 'SparseTensor': + device = None + dtype = None + if len(args) == 2: + device, dtype = args + elif len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype = args[0] + else: + device = args[0] + if 'dtype' in kwargs: + assert dtype is None, "to() received multiple values for argument 'dtype'" + dtype = kwargs['dtype'] + if 'device' in kwargs: + assert device is None, "to() received multiple values for argument 'device'" + device = kwargs['device'] + non_blocking = kwargs.get('non_blocking', False) + copy = kwargs.get('copy', False) + + new_feats = self.feats.to(device=device, dtype=dtype, non_blocking=non_blocking, copy=copy) + new_coords = self.coords.to(device=device, non_blocking=non_blocking, copy=copy) + return self.replace(new_feats, new_coords) + + def type(self, dtype): + new_feats = self.feats.type(dtype) + return self.replace(new_feats) + + def cpu(self) -> 'SparseTensor': + new_feats = self.feats.cpu() + new_coords = self.coords.cpu() + return self.replace(new_feats, new_coords) + + def cuda(self) -> 'SparseTensor': + new_feats = self.feats.cuda() + new_coords = self.coords.cuda() + return self.replace(new_feats, new_coords) + + def half(self) -> 'SparseTensor': + new_feats = self.feats.half() + return self.replace(new_feats) + + def float(self) -> 'SparseTensor': + new_feats = self.feats.float() + return self.replace(new_feats) + + def detach(self) -> 'SparseTensor': + new_coords = self.coords.detach() + new_feats = self.feats.detach() + return self.replace(new_feats, new_coords) + + def reshape(self, *shape) -> 'SparseTensor': + new_feats = self.feats.reshape(self.feats.shape[0], *shape) + return self.replace(new_feats) + + def unbind(self, dim: int) -> List['SparseTensor']: + return sparse_unbind(self, dim) + + def replace(self, feats: torch.Tensor, coords: Optional[torch.Tensor] = None) -> 'SparseTensor': + if config.CONV == 'torchsparse': + new_data = self.SparseTensorData( + feats=feats, + coords=self.data.coords if coords is None else coords, + stride=self.data.stride, + spatial_range=self.data.spatial_range, + ) + new_data._caches = self.data._caches + elif config.CONV == 'spconv': + new_data = self.SparseTensorData( + self.data.features.reshape(self.data.features.shape[0], -1), + self.data.indices, + self.data.spatial_shape, + self.data.batch_size, + self.data.grid, + self.data.voxel_num, + self.data.indice_dict + ) + new_data._features = feats + new_data.benchmark = self.data.benchmark + new_data.benchmark_record = self.data.benchmark_record + new_data.thrust_allocator = self.data.thrust_allocator + new_data._timer = self.data._timer + new_data.force_algo = self.data.force_algo + new_data.int8_scale = self.data.int8_scale + if coords is not None: + new_data.indices = coords + else: + new_data = { + 'feats': feats, + 'coords': self.data['coords'] if coords is None else coords, + } + 
new_tensor = SparseTensor( + new_data, + shape=torch.Size([self._shape[0]] + list(feats.shape[1:])) if self._shape is not None else None, + scale=self._scale, + spatial_cache=self._spatial_cache + ) + return new_tensor + + def to_dense(self) -> torch.Tensor: + if config.CONV == 'torchsparse': + return self.data.dense() + elif config.CONV == 'spconv': + return self.data.dense() + else: + spatial_shape = self.spatial_shape + ret = torch.zeros(*self.shape, *spatial_shape, dtype=self.dtype, device=self.device) + idx = [self.coords[:, 0], slice(None)] + self.coords[:, 1:].unbind(1) + ret[tuple(idx)] = self.feats + return ret + + @staticmethod + def full(aabb, dim, value, dtype=torch.float32, device=None) -> 'SparseTensor': + N, C = dim + x = torch.arange(aabb[0], aabb[3] + 1) + y = torch.arange(aabb[1], aabb[4] + 1) + z = torch.arange(aabb[2], aabb[5] + 1) + coords = torch.stack(torch.meshgrid(x, y, z, indexing='ij'), dim=-1).reshape(-1, 3) + coords = torch.cat([ + torch.arange(N).view(-1, 1).repeat(1, coords.shape[0]).view(-1, 1), + coords.repeat(N, 1), + ], dim=1).to(dtype=torch.int32, device=device) + feats = torch.full((coords.shape[0], C), value, dtype=dtype, device=device) + return SparseTensor(feats=feats, coords=coords) + + def __merge_sparse_cache(self, other: 'SparseTensor') -> dict: + new_cache = {} + for k in set(list(self._spatial_cache.keys()) + list(other._spatial_cache.keys())): + if k in self._spatial_cache: + new_cache[k] = self._spatial_cache[k] + if k in other._spatial_cache: + if k not in new_cache: + new_cache[k] = other._spatial_cache[k] + else: + new_cache[k].update(other._spatial_cache[k]) + return new_cache + + def __elemwise__(self, other: Union[torch.Tensor, VarLenTensor], op: callable) -> 'SparseTensor': + if isinstance(other, torch.Tensor): + try: + other = torch.broadcast_to(other, self.shape) + other = other[self.batch_boardcast_map] + except: + pass + if isinstance(other, VarLenTensor): + other = other.feats + new_feats = op(self.feats, other) + new_tensor = self.replace(new_feats) + if isinstance(other, SparseTensor): + new_tensor._spatial_cache = self.__merge_sparse_cache(other) + return new_tensor + + def __getitem__(self, idx): + if isinstance(idx, int): + idx = [idx] + elif isinstance(idx, slice): + idx = range(*idx.indices(self.shape[0])) + elif isinstance(idx, list): + assert all(isinstance(i, int) for i in idx), f"Only integer indices are supported: {idx}" + elif isinstance(idx, torch.Tensor): + if idx.dtype == torch.bool: + assert idx.shape == (self.shape[0],), f"Invalid index shape: {idx.shape}" + idx = idx.nonzero().squeeze(1) + elif idx.dtype in [torch.int32, torch.int64]: + assert len(idx.shape) == 1, f"Invalid index shape: {idx.shape}" + else: + raise ValueError(f"Unknown index type: {idx.dtype}") + else: + raise ValueError(f"Unknown index type: {type(idx)}") + + new_coords = [] + new_feats = [] + new_layout = [] + new_shape = torch.Size([len(idx)] + list(self.shape[1:])) + start = 0 + for new_idx, old_idx in enumerate(idx): + new_coords.append(self.coords[self.layout[old_idx]].clone()) + new_coords[-1][:, 0] = new_idx + new_feats.append(self.feats[self.layout[old_idx]]) + new_layout.append(slice(start, start + len(new_coords[-1]))) + start += len(new_coords[-1]) + new_coords = torch.cat(new_coords, dim=0).contiguous() + new_feats = torch.cat(new_feats, dim=0).contiguous() + new_tensor = SparseTensor(feats=new_feats, coords=new_coords, shape=new_shape) + new_tensor.register_spatial_cache('layout', new_layout) + return new_tensor + + def 
clear_spatial_cache(self) -> None: + """ + Clear all spatial caches. + """ + self._spatial_cache = {} + + def register_spatial_cache(self, key, value) -> None: + """ + Register a spatial cache. + The spatial cache can be any thing you want to cache. + The registery and retrieval of the cache is based on current scale. + """ + scale_key = str(self._scale) + if scale_key not in self._spatial_cache: + self._spatial_cache[scale_key] = {} + self._spatial_cache[scale_key][key] = value + + def get_spatial_cache(self, key=None): + """ + Get a spatial cache. + """ + scale_key = str(self._scale) + cur_scale_cache = self._spatial_cache.get(scale_key, {}) + if key is None: + return cur_scale_cache + return cur_scale_cache.get(key, None) + + def __repr__(self) -> str: + return f"SparseTensor(shape={self.shape}, dtype={self.dtype}, device={self.device})" + +def sparse_cat(inputs: List[SparseTensor], dim: int = 0) -> SparseTensor: + if dim == 0: + start = 0 + coords = [] + for input in inputs: + coords.append(input.coords.clone()) + coords[-1][:, 0] += start + start += input.shape[0] + coords = torch.cat(coords, dim=0) + feats = torch.cat([input.feats for input in inputs], dim=0) + output = SparseTensor( + coords=coords, + feats=feats, + ) + else: + feats = torch.cat([input.feats for input in inputs], dim=dim) + output = inputs[0].replace(feats) + + return output + + +def sparse_unbind(input: SparseTensor, dim: int) -> List[SparseTensor]: + if dim == 0: + return [input[i] for i in range(input.shape[0])] + else: + feats = input.feats.unbind(dim) + return [input.replace(f) for f in feats] + +class SparseLinear(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super(SparseLinear, self).__init__(in_features, out_features, bias) + + def forward(self, input: VarLenTensor) -> VarLenTensor: + return input.replace(super().forward(input.feats)) + + + +class SparseUnetVaeEncoder(nn.Module): + """ + Sparse Swin Transformer Unet VAE model. 
+ """ + def __init__( + self, + in_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + down_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + ): + super().__init__() + self.in_channels = in_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.dtype = torch.float16 if use_fp16 else torch.float32 + + self.input_layer = SparseLinear(in_channels, model_channels[0]) + self.to_latent = SparseLinear(model_channels[-1], 2 * latent_channels) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[down_block_type[i]]( + model_channels[i], + model_channels[i+1], + **block_args[i], + ) + ) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward(self, x: SparseTensor, sample_posterior=False, return_raw=False): + h = self.input_layer(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + for j, block in enumerate(res): + h = block(h) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.to_latent(h) + + # Sample from the posterior distribution + mean, logvar = h.feats.chunk(2, dim=-1) + if sample_posterior: + std = torch.exp(0.5 * logvar) + z = mean + std * torch.randn_like(std) + else: + z = mean + z = h.replace(z) + + if return_raw: + return z, mean, logvar + else: + return z + + + +class FlexiDualGridVaeEncoder(SparseUnetVaeEncoder): + def __init__( + self, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + down_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + ): + super().__init__( + 6, + model_channels, + latent_channels, + num_blocks, + block_type, + down_block_type, + block_args, + use_fp16, + ) + + def forward(self, vertices: SparseTensor, intersected: SparseTensor, sample_posterior=False, return_raw=False): + x = vertices.replace(torch.cat([ + vertices.feats - 0.5, + intersected.feats.float() - 0.5, + ], dim=1)) + return super().forward(x, sample_posterior, return_raw) + +class SparseUnetVaeDecoder(nn.Module): + """ + Sparse Swin Transformer Unet VAE model. 
+ """ + def __init__( + self, + out_channels: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + use_fp16: bool = False, + pred_subdiv: bool = True, + ): + super().__init__() + self.out_channels = out_channels + self.model_channels = model_channels + self.num_blocks = num_blocks + self.use_fp16 = use_fp16 + self.pred_subdiv = pred_subdiv + self.dtype = torch.float16 if use_fp16 else torch.float32 + self.low_vram = False + + self.output_layer = SparseLinear(model_channels[-1], out_channels) + self.from_latent = SparseLinear(latent_channels, model_channels[0]) + + self.blocks = nn.ModuleList([]) + for i in range(len(num_blocks)): + self.blocks.append(nn.ModuleList([])) + for j in range(num_blocks[i]): + self.blocks[-1].append( + globals()[block_type[i]]( + model_channels[i], + **block_args[i], + ) + ) + if i < len(num_blocks) - 1: + self.blocks[-1].append( + globals()[up_block_type[i]]( + model_channels[i], + model_channels[i+1], + pred_subdiv=pred_subdiv, + **block_args[i], + ) + ) + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + def forward(self, x: SparseTensor, guide_subs: Optional[List[SparseTensor]] = None, return_subs: bool = False) -> SparseTensor: + + h = self.from_latent(x) + h = h.type(self.dtype) + subs = [] + for i, res in enumerate(self.blocks): + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + if self.pred_subdiv: + h, sub = block(h) + subs.append(sub) + else: + h = block(h, subdiv=guide_subs[i] if guide_subs is not None else None) + else: + h = block(h) + h = h.type(x.dtype) + h = h.replace(F.layer_norm(h.feats, h.feats.shape[-1:])) + h = self.output_layer(h) + if return_subs: + return h, subs + else: + return h + + def upsample(self, x: SparseTensor, upsample_times: int) -> torch.Tensor: + + h = self.from_latent(x) + h = h.type(self.dtype) + for i, res in enumerate(self.blocks): + if i == upsample_times: + return h.coords + for j, block in enumerate(res): + if i < len(self.blocks) - 1 and j == len(res) - 1: + h, sub = block(h) + else: + h = block(h) + +class FlexiDualGridVaeDecoder(SparseUnetVaeDecoder): + def __init__( + self, + resolution: int, + model_channels: List[int], + latent_channels: int, + num_blocks: List[int], + block_type: List[str], + up_block_type: List[str], + block_args: List[Dict[str, Any]], + voxel_margin: float = 0.5, + use_fp16: bool = False, + ): + self.resolution = resolution + self.voxel_margin = voxel_margin + # cache for a TorchHashMap instance + self._torch_hashmap_cache = None + + super().__init__( + 7, + model_channels, + latent_channels, + num_blocks, + block_type, + up_block_type, + block_args, + use_fp16, + ) + + def set_resolution(self, resolution: int) -> None: + self.resolution = resolution + + def _build_or_get_hashmap(self, coords: torch.Tensor, grid_size: torch.Tensor): + device = coords.device + N = coords.shape[0] + # compute flat keys for all coords (prepend batch 0 same as original code) + b = torch.zeros((N,), dtype=torch.long, device=device) + x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() + W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + flat_keys = b * (W * H * D) + x * (H * D) + y * D + z + values = torch.arange(N, dtype=torch.long, device=device) + DEFAULT_VAL = 0xffffffff # sentinel used in original code + return TorchHashMap(flat_keys, values, DEFAULT_VAL) + 
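+    # A hedged note on the flat-key scheme used by _build_or_get_hashmap above
+    # (shapes assumed from this file): a voxel at (b, x, y, z) on a W x H x D grid
+    # is packed into a single integer
+    #     key = b * (W * H * D) + x * (H * D) + y * D + z
+    # so every occupied voxel gets a unique key and the resulting TorchHashMap maps
+    # keys back to row indices in `coords`, e.g.
+    #     hashmap = self._build_or_get_hashmap(coords, grid_size)
+    #     rows = hashmap.lookup_flat(keys)  # rows == 0xffffffff where a key is missing
+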
+ def forward(self, x: SparseTensor, gt_intersected: SparseTensor = None, **kwargs): + decoded = super().forward(x, **kwargs) + out_list = list(decoded) if isinstance(decoded, tuple) else [decoded] + h = out_list[0] + vertices = h.replace((1 + 2 * self.voxel_margin) * F.sigmoid(h.feats[..., 0:3]) - self.voxel_margin) + intersected = h.replace(h.feats[..., 3:6] > 0) + quad_lerp = h.replace(F.softplus(h.feats[..., 6:7])) + mesh = [Mesh(*flexible_dual_grid_to_mesh( + v.coords[:, 1:], v.feats, i.feats, q.feats, + aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]], + grid_size=self.resolution, + train=False, + hashmap_builder=self._build_or_get_hashmap, + )) for v, i, q in zip(vertices, intersected, quad_lerp)] + out_list[0] = mesh + return out_list[0] if len(out_list) == 1 else tuple(out_list) + +def flexible_dual_grid_to_mesh( + coords: torch.Tensor, + dual_vertices: torch.Tensor, + intersected_flag: torch.Tensor, + split_weight: Union[torch.Tensor, None], + aabb: Union[list, tuple, np.ndarray, torch.Tensor], + voxel_size: Union[float, list, tuple, np.ndarray, torch.Tensor] = None, + grid_size: Union[int, list, tuple, np.ndarray, torch.Tensor] = None, + train: bool = False, + hashmap_builder=None, # optional callable for building/caching a TorchHashMap +): + + if not hasattr(flexible_dual_grid_to_mesh, "edge_neighbor_voxel_offset"): + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset = torch.tensor([ + [[0, 0, 0], [0, 0, 1], [0, 1, 1], [0, 1, 0]], # x-axis + [[0, 0, 0], [1, 0, 0], [1, 0, 1], [0, 0, 1]], # y-axis + [[0, 0, 0], [0, 1, 0], [1, 1, 0], [1, 0, 0]], # z-axis + ], dtype=torch.int, device=coords.device).unsqueeze(0) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_1"): + flexible_dual_grid_to_mesh.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=coords.device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_2"): + flexible_dual_grid_to_mesh.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=coords.device, requires_grad=False) + if not hasattr(flexible_dual_grid_to_mesh, "quad_split_train"): + flexible_dual_grid_to_mesh.quad_split_train = torch.tensor([0, 1, 4, 1, 2, 4, 2, 3, 4, 3, 0, 4], dtype=torch.long, device=coords.device, requires_grad=False) + + # AABB + if isinstance(aabb, (list, tuple)): + aabb = np.array(aabb) + if isinstance(aabb, np.ndarray): + aabb = torch.tensor(aabb, dtype=torch.float32, device=coords.device) + + # Voxel size + if voxel_size is not None: + if isinstance(voxel_size, float): + voxel_size = [voxel_size, voxel_size, voxel_size] + if isinstance(voxel_size, (list, tuple)): + voxel_size = np.array(voxel_size) + if isinstance(voxel_size, np.ndarray): + voxel_size = torch.tensor(voxel_size, dtype=torch.float32, device=coords.device) + grid_size = ((aabb[1] - aabb[0]) / voxel_size).round().int() + else: + if isinstance(grid_size, int): + grid_size = [grid_size, grid_size, grid_size] + if isinstance(grid_size, (list, tuple)): + grid_size = np.array(grid_size) + if isinstance(grid_size, np.ndarray): + grid_size = torch.tensor(grid_size, dtype=torch.int32, device=coords.device) + voxel_size = (aabb[1] - aabb[0]) / grid_size + + # Extract mesh + N = dual_vertices.shape[0] + mesh_vertices = (coords.float() + dual_vertices) / (2 * N) - 0.5 + + if hashmap_builder is None: + # build local TorchHashMap + device = coords.device + b = torch.zeros((N,), dtype=torch.long, device=device) + x, y, z = coords[:, 0].long(), coords[:, 1].long(), coords[:, 2].long() + W, H, D = int(grid_size[0].item()), 
int(grid_size[1].item()), int(grid_size[2].item()) + flat_keys = b * (W * H * D) + x * (H * D) + y * D + z + values = torch.arange(N, dtype=torch.long, device=device) + DEFAULT_VAL = 0xffffffff + torch_hashmap = TorchHashMap(flat_keys, values, DEFAULT_VAL) + else: + torch_hashmap = hashmap_builder(coords, grid_size) + + # Find connected voxels + edge_neighbor_voxel = coords.reshape(N, 1, 1, 3) + flexible_dual_grid_to_mesh.edge_neighbor_voxel_offset # (N, 3, 4, 3) + connected_voxel = edge_neighbor_voxel[intersected_flag] # (M, 4, 3) + M = connected_voxel.shape[0] + # flatten connected voxel coords and lookup + conn_flat_b = torch.zeros((M * 4,), dtype=torch.long, device=coords.device) + conn_x = connected_voxel.reshape(-1, 3)[:, 0].long() + conn_y = connected_voxel.reshape(-1, 3)[:, 1].long() + conn_z = connected_voxel.reshape(-1, 3)[:, 2].long() + W, H, D = int(grid_size[0].item()), int(grid_size[1].item()), int(grid_size[2].item()) + conn_flat = conn_flat_b * (W * H * D) + conn_x * (H * D) + conn_y * D + conn_z + + conn_indices = torch_hashmap.lookup_flat(conn_flat).reshape(M, 4).int() + connected_voxel_valid = (conn_indices != 0xffffffff).all(dim=1) + quad_indices = conn_indices[connected_voxel_valid].int() # (L, 4) + + mesh_vertices = (coords.float() + dual_vertices) * voxel_size + aabb[0].reshape(1, 3) + if split_weight is None: + # if split 1 + atempt_triangles_0 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1] + normals0 = torch.cross(mesh_vertices[atempt_triangles_0[:, 1]] - mesh_vertices[atempt_triangles_0[:, 0]], mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 0]]) + normals1 = torch.cross(mesh_vertices[atempt_triangles_0[:, 2]] - mesh_vertices[atempt_triangles_0[:, 1]], mesh_vertices[atempt_triangles_0[:, 3]] - mesh_vertices[atempt_triangles_0[:, 1]]) + align0 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() + # if split 2 + atempt_triangles_1 = quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + normals0 = torch.cross(mesh_vertices[atempt_triangles_1[:, 1]] - mesh_vertices[atempt_triangles_1[:, 0]], mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 0]]) + normals1 = torch.cross(mesh_vertices[atempt_triangles_1[:, 2]] - mesh_vertices[atempt_triangles_1[:, 1]], mesh_vertices[atempt_triangles_1[:, 3]] - mesh_vertices[atempt_triangles_1[:, 1]]) + align1 = (normals0 * normals1).sum(dim=1, keepdim=True).abs() + # select split + mesh_triangles = torch.where(align0 > align1, atempt_triangles_0, atempt_triangles_1).reshape(-1, 3) + else: + split_weight_ws = split_weight[quad_indices] + split_weight_ws_02 = split_weight_ws[:, 0] * split_weight_ws[:, 2] + split_weight_ws_13 = split_weight_ws[:, 1] * split_weight_ws[:, 3] + mesh_triangles = torch.where( + split_weight_ws_02 > split_weight_ws_13, + quad_indices[:, flexible_dual_grid_to_mesh.quad_split_1], + quad_indices[:, flexible_dual_grid_to_mesh.quad_split_2] + ).reshape(-1, 3) + + return mesh_vertices, mesh_triangles + +class Vae(nn.Module): + def __init__(self, config, operations=None): + operations = operations or torch.nn + + self.txt_dec = SparseUnetVaeDecoder( + out_channels=6, + model_channels=[1024, 512, 256, 128, 64], + latent_channels=32, + num_blocks=[4, 16, 8, 4, 0], + block_type=["SparseConvNeXtBlock3d"] * 5, + up_block_type=["SparseResBlockS2C3d"] * 4, + pred_subdiv=False + ) + + self.shape_dec = FlexiDualGridVaeDecoder( + resolution=256, + model_channels=[1024, 512, 256, 128, 64], + latent_channels=32, + num_blocks=[4, 16, 8, 4, 0], + 
block_type=["SparseConvNeXtBlock3d"] * 5, + up_block_type=["SparseResBlockS2C3d"] * 4, + ) + + def decode_shape_slat(self, slat, resolution: int): + self.shape_dec.set_resolution(resolution) + return self.shape_dec(slat, return_subs=True) + + def decode_tex_slat(self, slat, subs): + return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 + + @torch.no_grad() + def decode( + self, + shape_slat: SparseTensor, + tex_slat: SparseTensor, + resolution: int, + ): + meshes, subs = self.decode_shape_slat(shape_slat, resolution) + tex_voxels = self.decode_tex_slat(tex_slat, subs) + out_mesh = [] + for m, v in zip(meshes, tex_voxels): + m.fill_holes() # TODO + out_mesh.append( + MeshWithVoxel( + m.vertices, m.faces, + origin = [-0.5, -0.5, -0.5], + voxel_size = 1 / resolution, + coords = v.coords[:, 1:], + attrs = v.feats, + voxel_shape = torch.Size([*v.shape, *v.spatial_shape]), + layout=self.pbr_attr_layout + ) + ) + return out_mesh diff --git a/comfy_extras/trellis2.py b/comfy_extras/trellis2.py new file mode 100644 index 000000000..c3ad56007 --- /dev/null +++ b/comfy_extras/trellis2.py @@ -0,0 +1,240 @@ +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO +import torch +from comfy.ldm.trellis2.model import SparseTensor +import comfy.model_management +from PIL import Image +import PIL +import numpy as np + +shape_slat_normalization = { + "mean": torch.tensor([ + 0.781296, 0.018091, -0.495192, -0.558457, 1.060530, 0.093252, 1.518149, -0.933218, + -0.732996, 2.604095, -0.118341, -2.143904, 0.495076, -2.179512, -2.130751, -0.996944, + 0.261421, -2.217463, 1.260067, -0.150213, 3.790713, 1.481266, -1.046058, -1.523667, + -0.059621, 2.220780, 1.621212, 0.877230, 0.567247, -3.175944, -3.186688, 1.578665 + ])[None], + "std": torch.tensor([ + 5.972266, 4.706852, 5.445010, 5.209927, 5.320220, 4.547237, 5.020802, 5.444004, + 5.226681, 5.683095, 4.831436, 5.286469, 5.652043, 5.367606, 5.525084, 4.730578, + 4.805265, 5.124013, 5.530808, 5.619001, 5.103930, 5.417670, 5.269677, 5.547194, + 5.634698, 5.235274, 6.110351, 5.511298, 6.237273, 4.879207, 5.347008, 5.405691 + ])[None] +} + +tex_slat_normalization = { + "mean": torch.tensor([ + 3.501659, 2.212398, 2.226094, 0.251093, -0.026248, -0.687364, 0.439898, -0.928075, + 0.029398, -0.339596, -0.869527, 1.038479, -0.972385, 0.126042, -1.129303, 0.455149, + -1.209521, 2.069067, 0.544735, 2.569128, -0.323407, 2.293000, -1.925608, -1.217717, + 1.213905, 0.971588, -0.023631, 0.106750, 2.021786, 0.250524, -0.662387, -0.768862 + ])[None], + "std": torch.tensor([ + 2.665652, 2.743913, 2.765121, 2.595319, 3.037293, 2.291316, 2.144656, 2.911822, + 2.969419, 2.501689, 2.154811, 3.163343, 2.621215, 2.381943, 3.186697, 3.021588, + 2.295916, 3.234985, 3.233086, 2.260140, 2.874801, 2.810596, 3.292720, 2.674999, + 2.680878, 2.372054, 2.451546, 2.353556, 2.995195, 2.379849, 2.786195, 2.775190 + ])[None] +} + +def smart_crop_square( + image: torch.Tensor, + background_color=(128, 128, 128), +): + C, H, W = image.shape + size = max(H, W) + canvas = torch.empty( + (C, size, size), + dtype=image.dtype, + device=image.device + ) + for c in range(C): + canvas[c].fill_(background_color[c]) + top = (size - H) // 2 + left = (size - W) // 2 + canvas[:, top:top + H, left:left + W] = image + + return canvas + +def run_conditioning( + model, + image: torch.Tensor, + include_1024: bool = True, + background_color: str = "black", +): + # TODO: should check if normalization was applied in these steps + model = model.model + device = 
comfy.model_management.intermediate_device() # replaces .cpu()
+    torch_device = comfy.model_management.get_torch_device() # replaces .cuda()
+    bg_colors = {
+        "black": (0, 0, 0),
+        "gray": (128, 128, 128),
+        "white": (255, 255, 255),
+    }
+    bg_color = bg_colors.get(background_color, (128, 128, 128))
+
+    # ComfyUI images are assumed to be [B, H, W, C] floats in [0, 1];
+    # convert to a channel-first uint8 tensor for smart_crop_square.
+    if image.dim() == 4:
+        img = (image[0] * 255).clip(0, 255).to(torch.uint8).movedim(-1, 0)
+    else:
+        img = (image * 255).clip(0, 255).to(torch.uint8).movedim(-1, 0)
+
+    img = smart_crop_square(img, background_color=bg_color)
+
+    def set_image_size(img, image_size=512):
+        # Resize the (C, H, W) uint8 crop with LANCZOS via PIL, then return a
+        # normalized [1, 3, image_size, image_size] float tensor on the compute device.
+        pil = Image.fromarray(img.movedim(0, -1).cpu().numpy())
+        pil = pil.resize((image_size, image_size), Image.LANCZOS)
+        arr = np.array(pil.convert('RGB')).astype(np.float32) / 255
+        return torch.from_numpy(arr).permute(2, 0, 1).float().unsqueeze(0).to(torch_device)
+
+    model.image_size = 512
+    image_512 = set_image_size(img, 512)
+    cond_512 = model([image_512])
+
+    cond_1024 = None
+    if include_1024:
+        model.image_size = 1024
+        image_1024 = set_image_size(img, 1024)
+        cond_1024 = model([image_1024])
+
+    neg_cond = torch.zeros_like(cond_512)
+
+    conditioning = {
+        'cond_512': cond_512.to(device),
+        'neg_cond': neg_cond.to(device),
+    }
+    if cond_1024 is not None:
+        conditioning['cond_1024'] = cond_1024.to(device)
+
+    # Return the cropped image in ComfyUI layout ([1, H, W, C], float 0..1) for optional preview.
+    preprocessed_tensor = (img.float() / 255.0).movedim(0, -1).unsqueeze(0)
+
+    return conditioning, preprocessed_tensor
+
+class VaeDecodeShapeTrellis(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="VaeDecodeShapeTrellis",
+            category="latent/3d",
+            inputs=[
+                IO.Latent.Input("samples"),
+                IO.Vae.Input("vae"),
+                IO.Int.Input("resolution", tooltip="Shape Generation Resolution"),
+            ],
+            outputs=[
+                IO.Mesh.Output("mesh"),
+                IO.AnyType.Output("shape_subs"),
+            ]
+        )
+
+    @classmethod
+    def execute(cls, samples, vae, resolution) -> IO.NodeOutput:
+        latent = samples["samples"]
+        std = shape_slat_normalization["std"].to(latent.device)
+        mean = shape_slat_normalization["mean"].to(latent.device)
+        latent = latent * std + mean
+
+        mesh, subs = vae.decode_shape_slat(latent, resolution)
+        return IO.NodeOutput(mesh, subs)
+
+class VaeDecodeTextureTrellis(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="VaeDecodeTextureTrellis",
+            category="latent/3d",
+            inputs=[
+                IO.Latent.Input("samples"),
+                IO.Vae.Input("vae"),
+                IO.AnyType.Input("shape_subs"),
+            ],
+            outputs=[
+                IO.Mesh.Output("mesh"),
+            ]
+        )
+
+    @classmethod
+    def execute(cls, samples, vae, shape_subs) -> IO.NodeOutput:
+        if shape_subs is None:
+            raise ValueError("Shape subs must be provided for texture generation")
+
+        latent = samples["samples"]
+        std = tex_slat_normalization["std"].to(latent.device)
+        mean = tex_slat_normalization["mean"].to(latent.device)
+        latent = latent * std + mean
+
+        mesh = vae.decode_tex_slat(latent, shape_subs)
+        return IO.NodeOutput(mesh)
+
+class Trellis2Conditioning(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="Trellis2Conditioning",
+            category="conditioning/video_models",
+            inputs=[
+                IO.ClipVision.Input("clip_vision_model"),
+                IO.Image.Input("image"),
+                IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black")
+            ],
+            outputs=[
+                IO.Conditioning.Output(display_name="positive"),
+                IO.Conditioning.Output(display_name="negative"),
+            ]
+        )
+
+    @classmethod
+    def execute(cls, clip_vision_model, image, background_color) -> IO.NodeOutput:
+        # could make 1024 an option
+        conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color)
+        embeds = conditioning["cond_1024"]
+        # Attach the 1024-res embedding alongside each conditioning entry (the "cond_1024" key name is assumed).
+        positive = [[conditioning["cond_512"], {"cond_1024": embeds}]]
+        negative = [[conditioning["neg_cond"], {"cond_1024": embeds}]]
+        return IO.NodeOutput(positive, negative)
+
+class EmptyLatentTrellis2(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="EmptyLatentTrellis2",
+            category="latent/3d",
+            inputs=[
+                IO.Int.Input("resolution", default=3072, min=1, max=8192),
+                IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."),
+                IO.Vae.Input("vae"),
+                IO.Boolean.Input("shape_generation", tooltip="Setting to false will generate texture."),
+                IO.Combo.Input("generation_type", options=["structure_generation", "shape_generation", "texture_generation"])
+            ],
+            outputs=[
+                IO.Latent.Output(),
+            ]
+        )
+
+    @classmethod
+    def execute(cls, batch_size, coords, vae, generation_type) -> IO.NodeOutput:
+        # TODO: I will probably update how shape/texture is generated
+        # could split this too
+        in_channels = 32
+        shape_generation = generation_type == "shape_generation"
+        device = comfy.model_management.intermediate_device()
+        if shape_generation:
+            latent = SparseTensor(feats=torch.randn(batch_size, in_channels).to(device), coords=coords)
+        else:
+            # coords = shape_slat in txt gen case
+            latent = coords.replace(feats=torch.randn(coords.coords.shape[0], in_channels - coords.feats.shape[1]).to(device))
+        return IO.NodeOutput({"samples": latent, "type": "trellis2"})
+
+class Trellis2Extension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
+        return [
+            Trellis2Conditioning,
+            EmptyLatentTrellis2,
+            VaeDecodeTextureTrellis,
+            VaeDecodeShapeTrellis
+        ]
+
+
+async def comfy_entrypoint() -> Trellis2Extension:
+    return Trellis2Extension()
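+
+# Hedged wiring sketch for the nodes above (assumed graph, not an executed pipeline):
+#   Trellis2Conditioning:     image + CLIP vision model -> positive / negative conditioning
+#   EmptyLatentTrellis2:      empty "trellis2" sparse latent for the sampler
+#   VaeDecodeShapeTrellis:    shape latent + vae + resolution -> mesh + shape_subs
+#   VaeDecodeTextureTrellis:  texture latent + vae + shape_subs -> textured mesh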