mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-11 17:52:46 +08:00

commit 2e6958b113: Merge 81ed835ffb into 41d73ad180
@@ -9,6 +9,7 @@ import comfy.model_management
 import comfy.utils
 import comfy.clip_model
 import comfy.image_encoders.dino2
+import comfy.image_encoders.dino3
 
 class Output:
     def __getitem__(self, key):
@@ -23,6 +24,7 @@ IMAGE_ENCODERS = {
     "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
     "siglip2_vision_model": comfy.clip_model.CLIPVisionModelProjection,
     "dinov2": comfy.image_encoders.dino2.Dinov2Model,
+    "dinov3": comfy.image_encoders.dino3.DINOv3ViTModel
 }
 
 class ClipVisionModel():
@@ -134,6 +136,8 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
         json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
     elif 'encoder.layer.23.layer_scale2.lambda1' in sd:
         json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_large.json")
+    elif 'layer.9.attention.o_proj.bias' in sd: # dinov3
+        json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino3_large.json")
     else:
         return None
 
285  comfy/image_encoders/dino3.py  (new file)
@@ -0,0 +1,285 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention_for_device
from comfy.image_encoders.dino2 import LayerScale as DINOv3ViTLayerScale


class DINOv3ViTMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
        self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
        self.act_fn = torch.nn.GELU()

    def forward(self, x):
        return self.down_proj(self.act_fn(self.up_proj(x)))


def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, **kwargs):
    num_tokens = q.shape[-2]
    num_patches = sin.shape[-2]
    num_prefix_tokens = num_tokens - num_patches

    q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
    k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)

    q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
    k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)

    q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
    k = torch.cat((k_prefix_tokens, k_patches), dim=-2)

    return q, k


class DINOv3ViTAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, device, dtype, operations):
        super().__init__()
        self.embed_dim = hidden_size
        self.num_heads = num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype)  # key_bias = False
        self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
        self.q_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
        self.o_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        batch_size, patches, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is not None:
            cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attn = optimized_attention_for_device(query_states.device, mask=False)

        attn_output = attn(
            query_states, key_states, value_states, self.num_heads, attention_mask, skip_reshape=True, skip_output_reshape=True
        )

        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
        attn_output = self.o_proj(attn_output)

        return attn_output


class DINOv3ViTGatedMLP(nn.Module):
    def __init__(self, hidden_size, intermediate_size, mlp_bias, device, dtype, operations):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
        self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=mlp_bias, device=device, dtype=dtype)
        self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=mlp_bias, device=device, dtype=dtype)
        self.act_fn = torch.nn.GELU()

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj


def get_patches_center_coordinates(
    num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
) -> torch.Tensor:
    coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
    coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
    coords_h = coords_h / num_patches_h
    coords_w = coords_w / num_patches_w
    coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
    coords = coords.flatten(0, 1)
    coords = 2.0 * coords - 1.0
    return coords


class DINOv3ViTRopePositionEmbedding(nn.Module):
    inv_freq: torch.Tensor

    def __init__(self, rope_theta, hidden_size, num_attention_heads, image_size, patch_size, device, dtype):
        super().__init__()
        self.base = rope_theta
        self.head_dim = hidden_size // num_attention_heads
        self.num_patches_h = image_size // patch_size
        self.num_patches_w = image_size // patch_size
        self.patch_size = patch_size

        inv_freq = 1 / self.base ** torch.arange(0, 1, 4 / self.head_dim, dtype=torch.float32, device=device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        _, _, height, width = pixel_values.shape
        num_patches_h = height // self.patch_size
        num_patches_w = width // self.patch_size

        device = pixel_values.device
        device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
        with torch.amp.autocast(device_type=device_type, enabled=False):
            patch_coords = get_patches_center_coordinates(
                num_patches_h, num_patches_w, dtype=torch.float32, device=device
            )

            self.inv_freq = self.inv_freq.to(device)
            angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
            angles = angles.flatten(1, 2)
            angles = angles.tile(2)

            cos = torch.cos(angles)
            sin = torch.sin(angles)

        dtype = pixel_values.dtype
        return cos.to(dtype=dtype), sin.to(dtype=dtype)


class DINOv3ViTEmbeddings(nn.Module):
    def __init__(self, hidden_size, num_register_tokens, num_channels, patch_size, dtype, device, operations):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_size, device=device, dtype=dtype))
        self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size, device=device, dtype=dtype))
        self.register_tokens = nn.Parameter(torch.empty(1, num_register_tokens, hidden_size, device=device, dtype=dtype))
        self.patch_embeddings = operations.Conv2d(
            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size, device=device, dtype=dtype
        )

    def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None):
        batch_size = pixel_values.shape[0]
        target_dtype = self.patch_embeddings.weight.dtype

        patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
        patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)

        if bool_masked_pos is not None:
            mask_token = self.mask_token.to(patch_embeddings.dtype)
            patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)

        cls_token = self.cls_token.expand(batch_size, -1, -1)
        register_tokens = self.register_tokens.expand(batch_size, -1, -1)
        device = patch_embeddings.device
        cls_token = cls_token.to(device)
        register_tokens = register_tokens.to(device)
        embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)

        return embeddings


class DINOv3ViTLayer(nn.Module):
    def __init__(self, hidden_size, layer_norm_eps, use_gated_mlp, mlp_bias, intermediate_size, num_attention_heads,
                 device, dtype, operations):
        super().__init__()

        self.norm1 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)
        self.attention = DINOv3ViTAttention(hidden_size, num_attention_heads, device=device, dtype=dtype, operations=operations)
        self.layer_scale1 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)

        self.norm2 = operations.LayerNorm(hidden_size, eps=layer_norm_eps, device=device, dtype=dtype)

        if use_gated_mlp:
            self.mlp = DINOv3ViTGatedMLP(hidden_size, intermediate_size, mlp_bias, device=device, dtype=dtype, operations=operations)
        else:
            self.mlp = DINOv3ViTMLP(hidden_size, intermediate_size=intermediate_size, mlp_bias=mlp_bias, device=device, dtype=dtype, operations=operations)
        self.layer_scale2 = DINOv3ViTLayerScale(hidden_size, device=device, dtype=dtype, operations=None)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
    ) -> torch.Tensor:
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings,
        )
        hidden_states = self.layer_scale1(hidden_states)
        hidden_states = hidden_states + residual

        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = self.layer_scale2(hidden_states)
        hidden_states = hidden_states + residual

        return hidden_states


class DINOv3ViTModel(nn.Module):
    def __init__(self, config, dtype, device, operations):
        super().__init__()
        use_bf16 = comfy.model_management.should_use_bf16(device, prioritize_performance=True)
        if dtype == torch.float16 and use_bf16:
            dtype = torch.bfloat16
        elif dtype == torch.float16 and not use_bf16:
            dtype = torch.float32
        num_hidden_layers = config["num_hidden_layers"]
        hidden_size = config["hidden_size"]
        num_attention_heads = config["num_attention_heads"]
        num_register_tokens = config["num_register_tokens"]
        intermediate_size = config["intermediate_size"]
        layer_norm_eps = config["layer_norm_eps"]
        num_channels = config["num_channels"]
        patch_size = config["patch_size"]
        rope_theta = config["rope_theta"]

        self.embeddings = DINOv3ViTEmbeddings(
            hidden_size, num_register_tokens, num_channels=num_channels, patch_size=patch_size, dtype=dtype, device=device, operations=operations
        )
        self.rope_embeddings = DINOv3ViTRopePositionEmbedding(
            rope_theta, hidden_size, num_attention_heads, image_size=512, patch_size=patch_size, dtype=dtype, device=device
        )
        self.layer = nn.ModuleList(
            [DINOv3ViTLayer(hidden_size, layer_norm_eps, use_gated_mlp=False, mlp_bias=True,
                            intermediate_size=intermediate_size, num_attention_heads=num_attention_heads,
                            dtype=dtype, device=device, operations=operations)
             for _ in range(num_hidden_layers)])
        self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)

    def get_input_embeddings(self):
        return self.embeddings.patch_embeddings

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: torch.Tensor | None = None,
        **kwargs,
    ):
        pixel_values = pixel_values.to(self.embeddings.patch_embeddings.weight.dtype)
        hidden_states = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos)
        position_embeddings = self.rope_embeddings(pixel_values)

        for i, layer_module in enumerate(self.layer):
            hidden_states = layer_module(
                hidden_states,
                position_embeddings=position_embeddings,
            )

        if kwargs.get("skip_norm_elementwise", False):
            sequence_output = F.layer_norm(hidden_states, hidden_states.shape[-1:])
        else:
            norm = self.norm.to(hidden_states.device)
            sequence_output = norm(hidden_states)
        pooled_output = sequence_output[:, 0, :]

        return sequence_output, None, pooled_output, None
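
A minimal sketch (not part of the commit) of what apply_rotary_pos_emb guarantees: only the patch tokens are rotated, while the prefix tokens (class + registers) pass through unchanged, so identity angles (cos=1, sin=0) leave q and k untouched. The import path assumes the module added above.

import torch
from comfy.image_encoders.dino3 import apply_rotary_pos_emb

heads, head_dim = 2, 8
num_prefix, num_patches = 5, 16  # 1 cls + 4 register tokens, then a 4x4 patch grid
q = torch.randn(1, heads, num_prefix + num_patches, head_dim)
k = torch.randn(1, heads, num_prefix + num_patches, head_dim)
cos = torch.ones(num_patches, head_dim)   # identity rotation
sin = torch.zeros(num_patches, head_dim)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
assert torch.equal(q_rot, q) and torch.equal(k_rot, k)  # zero angle moves nothing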
23  comfy/image_encoders/dino3_large.json  (new file)
@@ -0,0 +1,23 @@
{
    "model_type": "dinov3",
    "hidden_size": 1024,
    "image_size": 224,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "key_bias": false,
    "layer_norm_eps": 1e-05,
    "mlp_bias": true,
    "num_attention_heads": 16,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "num_register_tokens": 4,
    "patch_size": 16,
    "pos_embed_rescale": 2.0,
    "proj_bias": true,
    "query_bias": true,
    "rope_theta": 100.0,
    "use_gated_mlp": false,
    "value_bias": true,
    "image_mean": [0.485, 0.456, 0.406],
    "image_std": [0.229, 0.224, 0.225]
}
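
As a sketch of how the pieces fit together (not part of the commit), the config above can be fed straight into DINOv3ViTModel; comfy.ops.disable_weight_init is the stock ComfyUI operations set, assumed here. With patch_size 16 and a 224x224 input, the output sequence is 1 cls + 4 register + 196 patch tokens.

import json
import torch
import comfy.ops
from comfy.image_encoders.dino3 import DINOv3ViTModel

with open("comfy/image_encoders/dino3_large.json") as f:
    config = json.load(f)

model = DINOv3ViTModel(config, dtype=torch.float32, device=torch.device("cpu"),
                       operations=comfy.ops.disable_weight_init)
seq, _, pooled, _ = model(torch.zeros(1, 3, 224, 224))
print(seq.shape, pooled.shape)  # (1, 201, 1024) and (1, 1024)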
@@ -747,6 +747,8 @@ class Hunyuan3Dv2_1(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
 
+class Trellis2(LatentFormat): # TODO
+    latent_channels = 32
 class Hunyuan3Dv2mini(LatentFormat):
     latent_channels = 64
     latent_dimensions = 1
282  comfy/ldm/trellis2/attention.py  (new file)
@@ -0,0 +1,282 @@
import torch
import math
from comfy.ldm.modules.attention import optimized_attention
from typing import Tuple, Union, List
from comfy.ldm.trellis2.vae import VarLenTensor
import comfy.ops


# replica of the seedvr2 code
def var_attn_arg(kwargs):
    cu_seqlens_q = kwargs.get("cu_seqlens_q", None)
    max_seqlen_q = kwargs.get("max_seqlen_q", None)
    cu_seqlens_k = kwargs.get("cu_seqlens_kv", cu_seqlens_q)
    max_seqlen_k = kwargs.get("max_kv_seqlen", max_seqlen_q)
    assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
    return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k


def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
    var_length = True
    if var_length:
        cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
        if not skip_reshape:
            # assumes 2D q, k, v [total_tokens, embed_dim]
            total_tokens, embed_dim = q.shape
            head_dim = embed_dim // heads
            q = q.view(total_tokens, heads, head_dim)
            k = k.view(k.shape[0], heads, head_dim)
            v = v.view(v.shape[0], heads, head_dim)

        b = q.size(0)
        dim_head = q.shape[-1]
        # pack the ragged sequences into jagged nested tensors
        q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
        k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
        v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())

        mask = None
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)

    if mask is not None:
        if mask.ndim == 2:
            mask = mask.unsqueeze(0)
        if mask.ndim == 3:
            mask = mask.unsqueeze(1)

    out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
    if var_length:
        return out.transpose(1, 2).values()
    if not skip_output_reshape:
        out = (
            out.transpose(1, 2).reshape(b, -1, heads * dim_head)
        )
    return out


def scaled_dot_product_attention(*args, **kwargs):
    num_all_args = len(args) + len(kwargs)

    q = None
    if num_all_args == 1:
        qkv = args[0] if len(args) > 0 else kwargs.get('qkv')
    elif num_all_args == 2:
        q = args[0] if len(args) > 0 else kwargs.get('q')
        kv = args[1] if len(args) > 1 else kwargs.get('kv')
    elif num_all_args == 3:
        q = args[0] if len(args) > 0 else kwargs.get('q')
        k = args[1] if len(args) > 1 else kwargs.get('k')
        v = args[2] if len(args) > 2 else kwargs.get('v')

    if q is not None:
        heads = q.shape[2]
    else:
        heads = qkv.shape[3]

    if num_all_args == 1:
        q, k, v = qkv.unbind(dim=2)
    elif num_all_args == 2:
        k, v = kv.unbind(dim=2)

    q = q.permute(0, 2, 1, 3)
    k = k.permute(0, 2, 1, 3)
    v = v.permute(0, 2, 1, 3)

    out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True, **kwargs)

    out = out.permute(0, 2, 1, 3)

    return out


def sparse_windowed_scaled_dot_product_self_attention(
    qkv,
    window_size: int,
    shift_window: Tuple[int, int, int] = (0, 0, 0)
):
    serialization_spatial_cache_name = f'windowed_attention_{window_size}_{shift_window}'
    serialization_spatial_cache = qkv.get_spatial_cache(serialization_spatial_cache_name)
    if serialization_spatial_cache is None:
        fwd_indices, bwd_indices, seq_lens, attn_func_args = calc_window_partition(qkv, window_size, shift_window)
        qkv.register_spatial_cache(serialization_spatial_cache_name, (fwd_indices, bwd_indices, seq_lens, attn_func_args))
    else:
        fwd_indices, bwd_indices, seq_lens, attn_func_args = serialization_spatial_cache

    qkv_feats = qkv.feats[fwd_indices]  # [M, 3, H, C]
    heads = qkv_feats.shape[2]
    q, k, v = qkv_feats.unbind(dim=1)  # unbind up front so every backend branch has q, k, v

    if optimized_attention.__name__ == 'attention_xformers':
        q = q.unsqueeze(0)  # [1, M, H, C]
        k = k.unsqueeze(0)  # [1, M, H, C]
        v = v.unsqueeze(0)  # [1, M, H, C]
        #out = xops.memory_efficient_attention(q, k, v, **attn_func_args)[0]  # [M, H, C]
        out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)
    elif optimized_attention.__name__ == 'attention_flash':
        if 'flash_attn' not in globals():
            import flash_attn
        out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, **attn_func_args)  # [M, H, C]
    else:
        out = optimized_attention(q, k, v, heads, skip_output_reshape=True, skip_reshape=True)

    out = out[bwd_indices]  # [T, H, C]

    return qkv.replace(out)


def calc_window_partition(
    tensor,
    window_size: Union[int, Tuple[int, ...]],
    shift_window: Union[int, Tuple[int, ...]] = 0,
) -> Tuple[torch.Tensor, torch.Tensor, List[int], List[int]]:
    DIM = tensor.coords.shape[1] - 1
    shift_window = (shift_window,) * DIM if isinstance(shift_window, int) else shift_window
    window_size = (window_size,) * DIM if isinstance(window_size, int) else window_size
    shifted_coords = tensor.coords.clone().detach()
    shifted_coords[:, 1:] += torch.tensor(shift_window, device=tensor.device, dtype=torch.int32).unsqueeze(0)

    MAX_COORDS = [i + j for i, j in zip(tensor.spatial_shape, shift_window)]
    NUM_WINDOWS = [math.ceil((mc + 1) / ws) for mc, ws in zip(MAX_COORDS, window_size)]
    OFFSET = torch.cumprod(torch.tensor([1] + NUM_WINDOWS[::-1]), dim=0).tolist()[::-1]

    shifted_coords[:, 1:] //= torch.tensor(window_size, device=tensor.device, dtype=torch.int32).unsqueeze(0)
    shifted_indices = (shifted_coords * torch.tensor(OFFSET, device=tensor.device, dtype=torch.int32).unsqueeze(0)).sum(dim=1)
    fwd_indices = torch.argsort(shifted_indices)
    bwd_indices = torch.empty_like(fwd_indices)
    bwd_indices[fwd_indices] = torch.arange(fwd_indices.shape[0], device=tensor.device)
    seq_lens = torch.bincount(shifted_indices)
    mask = seq_lens != 0
    seq_lens = seq_lens[mask]

    if optimized_attention.__name__ == 'attention_xformers':
        if 'xops' not in globals():
            import xformers.ops as xops
        attn_func_args = {
            'attn_bias': xops.fmha.BlockDiagonalMask.from_seqlens(seq_lens)
        }
    elif optimized_attention.__name__ == 'attention_flash':
        attn_func_args = {
            'cu_seqlens': torch.cat([torch.tensor([0], device=tensor.device), torch.cumsum(seq_lens, dim=0)], dim=0).int(),
            'max_seqlen': torch.max(seq_lens)
        }
    else:
        attn_func_args = {}  # pytorch fallback takes no extra arguments

    return fwd_indices, bwd_indices, seq_lens, attn_func_args


def sparse_scaled_dot_product_attention(*args, **kwargs):
    q = None
    arg_names_dict = {
        1: ['qkv'],
        2: ['q', 'kv'],
        3: ['q', 'k', 'v']
    }
    num_all_args = len(args) + len(kwargs)
    for key in arg_names_dict[num_all_args][len(args):]:
        assert key in kwargs, f"Missing argument {key}"

    if num_all_args == 1:
        qkv = args[0] if len(args) > 0 else kwargs['qkv']
        device = qkv.device

        s = qkv
        q_seqlen = [qkv.layout[i].stop - qkv.layout[i].start for i in range(qkv.shape[0])]
        kv_seqlen = q_seqlen
        qkv = qkv.feats  # [T, 3, H, C]

    elif num_all_args == 2:
        q = args[0] if len(args) > 0 else kwargs['q']
        kv = args[1] if len(args) > 1 else kwargs['kv']
        device = q.device

        if isinstance(q, VarLenTensor):
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats  # [T_Q, H, C]
        else:
            s = None
            N, L, H, C = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, C)  # [T_Q, H, C]

        if isinstance(kv, VarLenTensor):
            kv_seqlen = [kv.layout[i].stop - kv.layout[i].start for i in range(kv.shape[0])]
            kv = kv.feats  # [T_KV, 2, H, C]
        else:
            N, L, _, H, C = kv.shape
            kv_seqlen = [L] * N
            kv = kv.reshape(N * L, 2, H, C)  # [T_KV, 2, H, C]

    elif num_all_args == 3:
        q = args[0] if len(args) > 0 else kwargs['q']
        k = args[1] if len(args) > 1 else kwargs['k']
        v = args[2] if len(args) > 2 else kwargs['v']
        device = q.device

        if isinstance(q, VarLenTensor):
            s = q
            q_seqlen = [q.layout[i].stop - q.layout[i].start for i in range(q.shape[0])]
            q = q.feats  # [T_Q, H, Ci]
        else:
            s = None
            N, L, H, CI = q.shape
            q_seqlen = [L] * N
            q = q.reshape(N * L, H, CI)  # [T_Q, H, Ci]

        if isinstance(k, VarLenTensor):
            kv_seqlen = [k.layout[i].stop - k.layout[i].start for i in range(k.shape[0])]
            k = k.feats  # [T_KV, H, Ci]
            v = v.feats  # [T_KV, H, Co]
        else:
            N, L, H, CI, CO = *k.shape, v.shape[-1]
            kv_seqlen = [L] * N
            k = k.reshape(N * L, H, CI)  # [T_KV, H, Ci]
            v = v.reshape(N * L, H, CO)  # [T_KV, H, Co]

    # TODO: change
    if q is not None:
        heads = q
    else:
        heads = qkv
    heads = heads.shape[2]
    if optimized_attention.__name__ == 'attention_xformers':
        if 'xops' not in globals():
            import xformers.ops as xops
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        mask = xops.fmha.BlockDiagonalMask.from_seqlens(q_seqlen, kv_seqlen)
        out = xops.memory_efficient_attention(q, k, v, mask)[0]
    elif optimized_attention.__name__ == 'attention_flash':
        if 'flash_attn' not in globals():
            import flash_attn
        cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
        if num_all_args in [2, 3]:
            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
        if num_all_args == 1:
            out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens_q, max(q_seqlen))
        elif num_all_args == 2:
            out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
        elif num_all_args == 3:
            out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))

    elif optimized_attention.__name__ == "attention_pytorch":
        cu_seqlens_q = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(q_seqlen), dim=0)]).int().to(device)
        if num_all_args in [2, 3]:
            cu_seqlens_kv = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(kv_seqlen), dim=0)]).int().to(device)
        else:
            cu_seqlens_kv = cu_seqlens_q
        if num_all_args == 1:
            q, k, v = qkv.unbind(dim=1)
        elif num_all_args == 2:
            k, v = kv.unbind(dim=1)
        out = attention_pytorch(q, k, v, heads=heads, cu_seqlens_q=cu_seqlens_q,
                                cu_seqlens_kv=cu_seqlens_kv, max_seqlen_q=max(q_seqlen), max_kv_seqlen=max(kv_seqlen),
                                skip_reshape=True, skip_output_reshape=True)

    if s is not None:
        return s.replace(out)
    else:
        return out.reshape(N, L, H, -1)
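
A toy driver for the var-length path above (not in the commit): two sequences of lengths 3 and 5 are packed into 8 tokens, and cu_seqlens_q holds the running offsets [0, 3, 8], flash-attention style. This relies on torch.nested's jagged layout, so it assumes a PyTorch build whose scaled_dot_product_attention accepts jagged nested tensors; shapes are illustrative.

import torch
from comfy.ldm.trellis2.attention import attention_pytorch

heads, head_dim = 4, 16
q = torch.randn(8, heads * head_dim)  # 3 + 5 tokens packed along dim 0
k = torch.randn(8, heads * head_dim)
v = torch.randn(8, heads * head_dim)
cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)

out = attention_pytorch(q, k, v, heads, cu_seqlens_q=cu_seqlens)
print(out.shape)  # (8, 4, 16): per-token outputs, still packed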
433  comfy/ldm/trellis2/cumesh.py  (new file)
@@ -0,0 +1,433 @@
# will contain every cuda -> pytorch operation

import math
import torch
from typing import Callable
import logging

NO_TRITON = False
try:
    allow_tf32 = torch.cuda.is_tf32_supported()
except Exception:
    allow_tf32 = False
try:
    import triton
    import triton.language as tl
    heuristics = {
        'valid_kernel': lambda args: args['valid_kernel'](args['B1']),
        'valid_kernel_seg': lambda args: args['valid_kernel_seg'](args['B1']),
    }

    #@triton_autotune(
    #    configs=config.autotune_config,
    #    key=['LOGN', 'Ci', 'Co', 'V', 'allow_tf32'],
    #)
    @triton.heuristics(heuristics)
    @triton.jit
    def sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel(
        input,
        weight,
        bias,
        neighbor,
        sorted_idx,
        output,
        # Tensor dimensions
        N, LOGN, Ci, Co, V: tl.constexpr,
        # Meta-parameters
        B1: tl.constexpr,  # Block size for N dimension
        B2: tl.constexpr,  # Block size for Co dimension
        BK: tl.constexpr,  # Block size for K dimension (V * Ci)
        allow_tf32: tl.constexpr,  # Allow TF32 precision for matmuls
        # Heuristic parameters
        valid_kernel,
        valid_kernel_seg,
    ):
        block_id = tl.program_id(axis=0)
        block_dim_co = tl.cdiv(Co, B2)
        block_id_co = block_id % block_dim_co
        block_id_n = block_id // block_dim_co

        # Create pointers for submatrices of A and B.
        num_k = tl.cdiv(Ci, BK)  # Number of blocks in K dimension
        valid_kernel_start = tl.load(valid_kernel_seg + block_id_n)
        valid_kernel_seglen = tl.load(valid_kernel_seg + block_id_n + 1) - valid_kernel_start
        offset_n = block_id_n * B1 + tl.arange(0, B1)
        n_mask = offset_n < N
        offset_sorted_n = tl.load(sorted_idx + offset_n, mask=n_mask, other=0)  # (B1,)
        offset_co = (block_id_co * B2 + tl.arange(0, B2)) % Co  # (B2,)
        offset_k = tl.arange(0, BK)  # (BK,)

        # Create a block of the output matrix C.
        accumulator = tl.zeros((B1, B2), dtype=tl.float32)

        # Iterate along V*Ci dimension.
        for k in range(num_k * valid_kernel_seglen):
            v = k // num_k
            bk = k % num_k
            v = tl.load(valid_kernel + valid_kernel_start + v)
            # Calculate pointers to input matrix.
            neighbor_offset_n = tl.load(neighbor + offset_sorted_n * V + v)  # (B1,)
            input_ptr = input + bk * BK + (neighbor_offset_n[:, None].to(tl.int64) * Ci + offset_k[None, :])  # (B1, BK)
            # Calculate pointers to weight matrix.
            weight_ptr = weight + v * Ci + bk * BK + (offset_co[None, :] * V * Ci + offset_k[:, None])  # (BK, B2)
            # Load the next block of input and weight.
            neigh_mask = neighbor_offset_n != 0xffffffff
            k_mask = offset_k < Ci - bk * BK
            input_block = tl.load(input_ptr, mask=neigh_mask[:, None] & k_mask[None, :], other=0.0)
            weight_block = tl.load(weight_ptr, mask=k_mask[:, None], other=0.0)
            # Accumulate along the K dimension.
            accumulator = tl.dot(input_block, weight_block, accumulator,
                                 input_precision='tf32' if allow_tf32 else 'ieee')  # (B1, B2)
        c = accumulator.to(input.type.element_ty)

        # add bias
        if bias is not None:
            bias_block = tl.load(bias + offset_co)
            c += bias_block[None, :]

        # Write back the block of the output matrix with masks.
        out_offset_n = offset_sorted_n
        out_offset_co = block_id_co * B2 + tl.arange(0, B2)
        out_ptr = output + (out_offset_n[:, None] * Co + out_offset_co[None, :])
        out_mask = n_mask[:, None] & (out_offset_co[None, :] < Co)
        tl.store(out_ptr, c, mask=out_mask)

    def sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: torch.Tensor,
        neighbor: torch.Tensor,
        sorted_idx: torch.Tensor,
        valid_kernel: Callable[[int], torch.Tensor],
        valid_kernel_seg: Callable[[int], torch.Tensor],
    ) -> torch.Tensor:
        N, Ci, Co, V = neighbor.shape[0], input.shape[1], weight.shape[0], weight.shape[1]
        LOGN = int(math.log2(N))
        output = torch.empty((N, Co), device=input.device, dtype=input.dtype)
        grid = lambda META: (triton.cdiv(Co, META['B2']) * triton.cdiv(N, META['B1']),)
        sparse_submanifold_conv_fwd_masked_implicit_gemm_kernel[grid](
            input, weight, bias, neighbor, sorted_idx, output,
            N, LOGN, Ci, Co, V,
            B1=128,
            B2=64,
            BK=32,
            valid_kernel=valid_kernel,
            valid_kernel_seg=valid_kernel_seg,
            allow_tf32=allow_tf32,
        )
        return output
except Exception:
    NO_TRITON = True


def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
    # offsets in same order as CUDA kernel
    offsets = []
    for vx in range(Kw):
        for vy in range(Kh):
            for vz in range(Kd):
                offsets.append((
                    vx * Dw,
                    vy * Dh,
                    vz * Dd
                ))
    return torch.tensor(offsets, device=device)


def build_submanifold_neighbor_map(
    hashmap,
    coords: torch.Tensor,
    W, H, D,
    Kw, Kh, Kd,
    Dw, Dh, Dd,
):
    device = coords.device
    M = coords.shape[0]
    V = Kw * Kh * Kd
    half_V = V // 2 + 1

    INVALID = hashmap.default_value

    neighbor = torch.full((M, V), INVALID, device=device, dtype=torch.long)

    b = coords[:, 0].long()
    x = coords[:, 1].long()
    y = coords[:, 2].long()
    z = coords[:, 3].long()

    offsets = compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device)

    ox = x - (Kw // 2) * Dw
    oy = y - (Kh // 2) * Dh
    oz = z - (Kd // 2) * Dd

    for v in range(half_V):
        if v == half_V - 1:
            neighbor[:, v] = torch.arange(M, device=device)
            continue

        dx, dy, dz = offsets[v]

        kx = ox + dx
        ky = oy + dy
        kz = oz + dz

        # Check spatial bounds
        valid = (
            (kx >= 0) & (kx < W) &
            (ky >= 0) & (ky < H) &
            (kz >= 0) & (kz < D)
        )

        flat = (
            b[valid] * (W * H * D) +
            kx[valid] * (H * D) +
            ky[valid] * D +
            kz[valid]
        )

        if flat.numel() > 0:
            found = hashmap.lookup_flat(flat)
            idx_in_M = torch.where(valid)[0]
            neighbor[idx_in_M, v] = found

            valid_found_mask = (found != INVALID)
            if valid_found_mask.any():
                src_points = idx_in_M[valid_found_mask]
                dst_points = found[valid_found_mask]
                neighbor[dst_points, V - 1 - v] = src_points

    return neighbor


class TorchHashMap:
    def __init__(self, keys: torch.Tensor, values: torch.Tensor, default_value: int):
        device = keys.device
        # use long for searchsorted
        self.sorted_keys, order = torch.sort(keys.to(torch.long))
        self.sorted_vals = values.to(torch.long)[order]
        self.default_value = torch.tensor(default_value, dtype=torch.long, device=device)
        self._n = self.sorted_keys.numel()

    def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
        flat = flat_keys.to(torch.long)
        if self._n == 0:
            return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
        idx = torch.searchsorted(self.sorted_keys, flat)
        idx_safe = torch.clamp(idx, max=self._n - 1)
        found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
        out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
        if found.any():
            out[found] = self.sorted_vals[idx_safe[found]]
        return out


UINT32_SENTINEL = 0xFFFFFFFF

def neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor_map):
    device = neighbor_map.device
    N, V = neighbor_map.shape

    sentinel = UINT32_SENTINEL

    neigh_map_T = neighbor_map.t().reshape(-1)
    neigh_mask_T = (neigh_map_T != sentinel).to(torch.int32)

    mask = (neighbor_map != sentinel).to(torch.long)
    gray_code = torch.zeros(N, dtype=torch.long, device=device)

    for v in range(V):
        gray_code |= (mask[:, v] << v)

    binary_code = gray_code.clone()
    for v in range(1, V):
        binary_code ^= (gray_code >> v)

    sorted_idx = torch.argsort(binary_code)

    prefix_sum_neighbor_mask = torch.cumsum(neigh_mask_T, dim=0)

    total_valid_signal = int(prefix_sum_neighbor_mask[-1].item()) if prefix_sum_neighbor_mask.numel() > 0 else 0

    if total_valid_signal > 0:
        pos = torch.nonzero(neigh_mask_T, as_tuple=True)[0]
        to = (prefix_sum_neighbor_mask[pos] - 1).long()

        valid_signal_i = torch.empty((total_valid_signal,), dtype=torch.long, device=device)
        valid_signal_o = torch.empty((total_valid_signal,), dtype=torch.long, device=device)

        valid_signal_i[to] = (pos % N).to(torch.long)
        valid_signal_o[to] = neigh_map_T[pos].to(torch.long)
    else:
        valid_signal_i = torch.empty((0,), dtype=torch.long, device=device)
        valid_signal_o = torch.empty((0,), dtype=torch.long, device=device)

    seg = torch.empty((V + 1,), dtype=torch.long, device=device)
    seg[0] = 0
    if V > 0:
        idxs = (torch.arange(1, V + 1, device=device, dtype=torch.long) * N) - 1
        seg[1:] = prefix_sum_neighbor_mask[idxs]

    return gray_code, sorted_idx, valid_signal_i, valid_signal_o, seg


def _popcount_int32_tensor(x: torch.Tensor) -> torch.Tensor:
    x = x.to(torch.int64)

    m1 = torch.tensor(0x5555555555555555, dtype=torch.int64, device=x.device)
    m2 = torch.tensor(0x3333333333333333, dtype=torch.int64, device=x.device)
    m4 = torch.tensor(0x0F0F0F0F0F0F0F0F, dtype=torch.int64, device=x.device)
    h01 = torch.tensor(0x0101010101010101, dtype=torch.int64, device=x.device)

    x = x - ((x >> 1) & m1)
    x = (x & m2) + ((x >> 2) & m2)
    x = (x + (x >> 4)) & m4
    x = (x * h01) >> 56
    return x.to(torch.int32)


def neighbor_map_post_process_for_masked_implicit_gemm_2(
    gray_code: torch.Tensor,
    sorted_idx: torch.Tensor,
    block_size: int
):
    device = gray_code.device
    N = gray_code.numel()
    num_blocks = (N + block_size - 1) // block_size

    pad = num_blocks * block_size - N
    if pad > 0:
        pad_vals = torch.zeros((pad,), dtype=torch.int32, device=device)
        gray_padded = torch.cat([gray_code[sorted_idx], pad_vals], dim=0)
    else:
        gray_padded = gray_code[sorted_idx]

    gray_blocks = gray_padded.view(num_blocks, block_size)

    reduced_code = gray_blocks
    while reduced_code.shape[1] > 1:
        half = reduced_code.shape[1] // 2
        remainder = reduced_code.shape[1] % 2

        left = reduced_code[:, :half * 2:2]
        right = reduced_code[:, 1:half * 2:2]
        merged = left | right

        if remainder:
            reduced_code = torch.cat([merged, reduced_code[:, -1:]], dim=1)
        else:
            reduced_code = merged

    reduced_code = reduced_code.squeeze(1)

    seglen_counts = _popcount_int32_tensor(reduced_code).to(torch.int32)

    seg = torch.empty((num_blocks + 1,), dtype=torch.int32, device=device)
    seg[0] = 0
    if num_blocks > 0:
        seg[1:] = torch.cumsum(seglen_counts, dim=0)

    total = int(seg[-1].item())

    if total == 0:
        return torch.empty((0,), dtype=torch.int32, device=device), seg

    V = int(reduced_code.max().item()).bit_length() if reduced_code.max() > 0 else 0

    if V == 0:
        return torch.empty((0,), dtype=torch.int32, device=device), seg

    bit_pos = torch.arange(0, V, dtype=torch.int32, device=device)
    shifted = reduced_code.unsqueeze(1) >> bit_pos.unsqueeze(0)
    bits = (shifted & 1).to(torch.bool)

    positions = bit_pos.unsqueeze(0).expand(num_blocks, V)
    valid_kernel_idx = positions[bits].to(torch.int32).contiguous()

    return valid_kernel_idx, seg


def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
    if NO_TRITON:  # TODO
        raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
    if feats.shape[0] == 0:
        logging.warning("Found feats to be empty!")
        Co = weight.shape[0]
        return torch.empty((0, Co), device=feats.device, dtype=feats.dtype), None
    if len(shape) == 5:
        N, C, W, H, D = shape
    else:
        W, H, D = shape

    Co, Kw, Kh, Kd, Ci = weight.shape

    b_stride = W * H * D
    x_stride = H * D
    y_stride = D
    z_stride = 1

    flat_keys = (coords[:, 0].long() * b_stride +
                 coords[:, 1].long() * x_stride +
                 coords[:, 2].long() * y_stride +
                 coords[:, 3].long() * z_stride)

    vals = torch.arange(coords.shape[0], dtype=torch.int32, device=coords.device)

    hashmap = TorchHashMap(flat_keys, vals, 0xFFFFFFFF)

    if neighbor_cache is None:
        neighbor = build_submanifold_neighbor_map(
            hashmap, coords, W, H, D, Kw, Kh, Kd,
            dilation[0], dilation[1], dilation[2]
        )
    else:
        neighbor = neighbor_cache

    block_size = 128

    gray_code, sorted_idx, valid_signal_i, valid_signal_o, valid_signal_seg = \
        neighbor_map_post_process_for_masked_implicit_gemm_1(neighbor)

    valid_kernel, valid_kernel_seg = \
        neighbor_map_post_process_for_masked_implicit_gemm_2(gray_code, sorted_idx, block_size)

    valid_kernel_fn = lambda b_size: valid_kernel
    valid_kernel_seg_fn = lambda b_size: valid_kernel_seg

    weight_flat = weight.contiguous().view(Co, -1, Ci)

    out = sparse_submanifold_conv_fwd_masked_implicit_gemm_splitk(
        feats,
        weight_flat,
        bias,
        neighbor,
        sorted_idx,
        valid_kernel_fn,
        valid_kernel_seg_fn
    )

    return out, neighbor


class Mesh:
    def __init__(self,
                 vertices,
                 faces,
                 vertex_attrs=None
                 ):
        self.vertices = vertices.float()
        self.faces = faces.int()
        self.vertex_attrs = vertex_attrs

    @property
    def device(self):
        return self.vertices.device

    def to(self, device, non_blocking=False):
        return Mesh(
            self.vertices.to(device, non_blocking=non_blocking),
            self.faces.to(device, non_blocking=non_blocking),
            self.vertex_attrs.to(device, non_blocking=non_blocking) if self.vertex_attrs is not None else None,
        )

    def cuda(self, non_blocking=False):
        return self.to('cuda', non_blocking=non_blocking)

    def cpu(self):
        return self.to('cpu')
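
TorchHashMap above is less a hashmap than a sorted-key lookup table: sort once, then answer each query with a searchsorted plus an equality check, returning the sentinel on misses. A small check (not in the commit):

import torch
from comfy.ldm.trellis2.cumesh import TorchHashMap

keys = torch.tensor([7, 3, 11])  # flattened voxel coordinates
vals = torch.tensor([0, 1, 2])   # row index of each voxel
hm = TorchHashMap(keys, vals, default_value=0xFFFFFFFF)

print(hm.lookup_flat(torch.tensor([3, 5, 11])))
# tensor([1, 4294967295, 2]) -- key 5 is absent, so it maps to the sentinel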
1054  comfy/ldm/trellis2/model.py  (new file; diff suppressed because it is too large)
1444  comfy/ldm/trellis2/vae.py  (new file; diff suppressed because it is too large)
@@ -52,6 +52,7 @@ import comfy.ldm.omnigen.omnigen2
 import comfy.ldm.qwen_image.model
 import comfy.ldm.kandinsky5.model
 import comfy.ldm.anima.model
+import comfy.ldm.trellis2.model
 import comfy.ldm.ace.ace_step15
 import comfy.ldm.cogvideo.model
 import comfy.ldm.rt_detr.rtdetr_v4
@@ -1555,6 +1556,16 @@ class WAN22(WAN21):
     def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
         return latent_image
+
+
+class Trellis2(BaseModel):
+    def __init__(self, model_config, model_type=ModelType.FLOW, device=None, unet_model=comfy.ldm.trellis2.model.Trellis2):
+        super().__init__(model_config, model_type, device, unet_model)
+
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        embeds = kwargs.get("embeds")
+        out["embeds"] = comfy.conds.CONDRegular(embeds)
+        return out
 
 class WAN21_FlowRVS(WAN21):
     def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG_FLOW, image_to_video=False, device=None):
         model_config.unet_config["model_type"] = "t2v"
@@ -1596,7 +1607,6 @@ class WAN21_SCAIL(WAN21):
         pose_latents = kwargs.get("pose_video_latent", None)
         if pose_latents is not None:
             out['pose_latents'] = [pose_latents.shape[0], 20, *pose_latents.shape[2:]]
-
         return out
 
 class Hunyuan3Dv2(BaseModel):
@@ -113,6 +113,22 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
         unet_config['block_repeat'] = [[1, 1, 1, 1], [2, 2, 2, 2]]
         return unet_config
 
+    if '{}img2shape.blocks.1.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
+        unet_config = {}
+        unet_config["image_model"] = "trellis2"
+
+        unet_config["init_txt_model"] = False
+        if '{}shape2txt.blocks.29.cross_attn.k_rms_norm.gamma'.format(key_prefix) in state_dict_keys:
+            unet_config["init_txt_model"] = True
+
+        unet_config["resolution"] = 64
+        if metadata is not None:
+            if "is_512" in metadata:
+                unet_config["resolution"] = 32
+
+        unet_config["num_heads"] = 12
+        return unet_config
+
     if '{}transformer.rotary_pos_emb.inv_freq'.format(key_prefix) in state_dict_keys: #stable audio dit
         unet_config = {}
         unet_config["audio_model"] = "dit1.0"
@@ -7,6 +7,50 @@ import logging
 import comfy.nested_tensor
 
 def prepare_noise_inner(latent_image, generator, noise_inds=None):
+    coord_counts = getattr(latent_image, "trellis_coord_counts", None)
+    if coord_counts is not None:
+        if coord_counts.ndim != 1:
+            raise ValueError(f"Trellis2 coord_counts must be 1D, got shape {tuple(coord_counts.shape)}")
+        if coord_counts.shape[0] != latent_image.size(0):
+            raise ValueError(
+                f"Trellis2 coord_counts length {coord_counts.shape[0]} does not match latent batch {latent_image.size(0)}"
+            )
+        if (coord_counts < 0).any() or (coord_counts > latent_image.size(2)).any():
+            raise ValueError(
+                f"Trellis2 coord_counts must be within [0, {latent_image.size(2)}], got {coord_counts.tolist()}"
+            )
+        noise = torch.zeros(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, device="cpu")
+        if noise_inds is None:
+            noise_inds = np.arange(latent_image.size(0), dtype=np.int64)
+        else:
+            noise_inds = np.asarray(noise_inds, dtype=np.int64)
+            if noise_inds.shape[0] != latent_image.size(0):
+                raise ValueError(
+                    f"Trellis2 noise_inds length {noise_inds.shape[0]} does not match latent batch {latent_image.size(0)}"
+                )
+
+        base_seed = int(generator.initial_seed())
+        unique_inds = np.unique(noise_inds)
+        sample_noises = {}
+        for noise_index in unique_inds.tolist():
+            rows = np.flatnonzero(noise_inds == noise_index)
+            max_count = max(int(coord_counts[row].item()) for row in rows.tolist())
+            local_generator = torch.Generator(device="cpu")
+            local_generator.manual_seed(base_seed + int(noise_index))
+            sample_noises[int(noise_index)] = torch.randn(
+                [1, latent_image.size(1), max_count, latent_image.size(3)],
+                dtype=torch.float32,
+                layout=latent_image.layout,
+                generator=local_generator,
+                device="cpu",
+            )
+
+        for batch_index, noise_index in enumerate(noise_inds.tolist()):
+            count = int(coord_counts[batch_index].item())
+            sample_noise = sample_noises[int(noise_index)]
+            noise[batch_index:batch_index + 1, :, :count, :] = sample_noise[:, :, :count, :]
+        return noise.to(dtype=latent_image.dtype)
+
     if noise_inds is None:
         return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
 
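A toy run of the Trellis2 branch above (not in the commit; it assumes prepare_noise_inner and its module's torch/numpy imports are in scope). Rows that share a noise_inds entry draw from the same per-index generator, so their overlapping coordinate ranges match exactly and everything past coord_counts stays zero:

import numpy as np
import torch

g = torch.Generator(device="cpu").manual_seed(1234)
latent = torch.zeros(2, 8, 6, 4)                    # [batch, channels, max_coords, dim]
latent.trellis_coord_counts = torch.tensor([6, 3])  # valid coords per sample

noise = prepare_noise_inner(latent, g, noise_inds=[0, 0])
assert torch.equal(noise[0, :, :3, :], noise[1, :, :3, :])  # same index, same noise
assert torch.count_nonzero(noise[1, :, 3:, :]) == 0         # padding stays zero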
10  comfy/sd.py
@@ -15,6 +15,7 @@ import comfy.ldm.lightricks.vae.causal_video_autoencoder
 import comfy.ldm.lightricks.vae.audio_vae
 import comfy.ldm.cosmos.vae
 import comfy.ldm.wan.vae
+import comfy.ldm.trellis2.vae
 import comfy.ldm.wan.vae2_2
 import comfy.ldm.hunyuan3d.vae
 import comfy.ldm.ace.vae.music_dcae_pipeline
@@ -513,6 +514,15 @@ class VAE:
             self.first_stage_model = StageC_coder()
             self.downscale_ratio = 32
             self.latent_channels = 16
+        elif "shape_dec.blocks.1.16.to_subdiv.weight" in sd:  # trellis2
+            init_txt_model = False
+            if "txt_dec.blocks.1.16.norm1.weight" in sd:
+                init_txt_model = True
+            self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+            # TODO
+            self.memory_used_decode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (2500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+            self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
         elif "decoder.conv_in.weight" in sd:
             if sd['decoder.conv_in.weight'].shape[1] == 64:
                 ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}
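The new elif follows ComfyUI's usual VAE-detection idiom: the loader probes the state dict for keys that only one architecture produces, and the first match wins, so specific fingerprints like the trellis2 shape-decoder key must sit above generic ones like decoder.conv_in.weight. A reduced sketch of that dispatch shape (sniff_vae is a hypothetical helper, not code from this PR):

def sniff_vae(sd: dict) -> str:
    # Order matters: most specific fingerprints first.
    if "shape_dec.blocks.1.16.to_subdiv.weight" in sd:
        return "trellis2"
    if "decoder.conv_in.weight" in sd:
        return "generic-conv-vae"
    raise ValueError("unrecognized VAE state dict")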
comfy/supported_models.py

@@ -1293,6 +1293,29 @@ class WAN22_T2V(WAN21_T2V):
         out = model_base.WAN22(self, image_to_video=True, device=device)
         return out
 
+class Trellis2(supported_models_base.BASE):
+    unet_config = {
+        "image_model": "trellis2"
+    }
+
+    sampling_settings = {
+        "shift": 3.0,
+    }
+
+    memory_usage_factor = 3.5
+
+    latent_format = latent_formats.Trellis2
+    vae_key_prefix = ["vae."]
+    clip_vision_prefix = "conditioner.main_image_encoder.model."
+    # this is only needed for the texture model
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    def get_model(self, state_dict, prefix="", device=None):
+        return model_base.Trellis2(self, device=device)
+
+    def clip_target(self, state_dict={}):
+        return None
+
 class WAN21_FlowRVS(WAN21_T2V):
     unet_config = {
         "image_model": "wan2.1",
@@ -1313,6 +1336,7 @@ class WAN21_SCAIL(WAN21_T2V):
         out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
         return out
 
+
 class Hunyuan3Dv2(supported_models_base.BASE):
     unet_config = {
         "image_model": "hunyuan3d2",
@@ -1684,6 +1708,7 @@ class Kandinsky5Image(Kandinsky5):
         return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
 
 
+
 class ACEStep15(supported_models_base.BASE):
     unet_config = {
         "audio_model": "ace1.5",
@@ -1723,7 +1748,6 @@ class ACEStep15(supported_models_base.BASE):
         return supported_models_base.ClipTarget(comfy.text_encoders.ace15.ACE15Tokenizer, comfy.text_encoders.ace15.te(**detect))
 
 
-
 class LongCatImage(supported_models_base.BASE):
     unet_config = {
         "image_model": "flux",
@@ -1801,6 +1825,7 @@ class ErnieImage(supported_models_base.BASE):
         return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
 
 
+
 class SAM3(supported_models_base.BASE):
     unet_config = {"image_model": "SAM3"}
     supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
@@ -1898,7 +1923,6 @@ class CogVideoX_I2V(CogVideoX_T2V):
         out = model_base.CogVideoX(self, image_to_video=True, device=device)
         return out
 
-
 models = [
     LotusD,
     Stable_Zero123,
@@ -1981,4 +2005,5 @@ models = [
     CogVideoX_I2V,
     CogVideoX_T2V,
     SVD_img2vid,
+    Trellis2
 ]
comfy_extras/nodes_hunyuan3d.py

@@ -443,7 +443,9 @@ class VoxelToMeshBasic(IO.ComfyNode):
             vertices.append(v)
             faces.append(f)
 
-        return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
+        if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces):
+            return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
+        return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces))
 
     decode = execute # TODO: remove
@@ -479,12 +481,14 @@ class VoxelToMesh(IO.ComfyNode):
             vertices.append(v)
             faces.append(f)
 
-        return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
+        if vertices and all(v.shape == vertices[0].shape for v in vertices) and all(f.shape == faces[0].shape for f in faces):
+            return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces)))
+        return IO.NodeOutput(pack_variable_mesh_batch(vertices, faces))
 
     decode = execute # TODO: remove
 
 
-def save_glb(vertices, faces, filepath, metadata=None):
+def save_glb(vertices, faces, filepath, metadata=None, colors=None):
     """
     Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
 
@@ -515,6 +519,13 @@ def save_glb(vertices, faces, filepath, metadata=None):
     indices_byte_length = len(indices_buffer)
     indices_byte_offset = len(vertices_buffer_padded)
 
+    if colors is not None:
+        colors_np = colors.cpu().numpy().astype(np.float32)
+        colors_buffer = colors_np.tobytes()
+        colors_byte_length = len(colors_buffer)
+        colors_byte_offset = len(buffer_data)
+        buffer_data += pad_to_4_bytes(colors_buffer)
+
     gltf = {
         "asset": {"version": "2.0", "generator": "ComfyUI"},
         "buffers": [
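pad_to_4_bytes is called here but defined elsewhere in the file; glTF requires buffer views holding float data to start at 4-byte-aligned offsets, which is why each sub-buffer is padded before concatenation. A plausible implementation, assuming the helper simply zero-pads (the real one may differ):

def pad_to_4_bytes(data: bytes) -> bytes:
    # Append NUL bytes until the length is a multiple of 4.
    return data + b"\x00" * (-len(data) % 4)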
@@ -580,6 +591,14 @@ def save_glb(vertices, faces, filepath, metadata=None):
         "scene": 0
     }
 
+    if colors is not None:
+        gltf["bufferViews"].append({"buffer": 0, "byteOffset": colors_byte_offset, "byteLength": colors_byte_length, "target": 34962})
+        gltf["accessors"].append({"bufferView": 2, "byteOffset": 0, "componentType": 5126, "count": len(colors_np), "type": "VEC3"})
+        gltf["meshes"][0]["primitives"][0]["attributes"]["COLOR_0"] = 2
+        # Define a base material so Three.js actually activates vertex coloring
+        gltf["materials"] = [{"pbrMetallicRoughness": {"baseColorFactor": [1.0, 1.0, 1.0, 1.0]}}]
+        gltf["meshes"][0]["primitives"][0]["material"] = 0
+
     if metadata is not None:
         gltf["asset"]["extras"] = metadata
 
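The magic numbers in that hunk come straight from the glTF 2.0 / OpenGL enums rather than from ComfyUI, which is worth spelling out:

# glTF 2.0 constants used above (spec values, not project-specific):
ARRAY_BUFFER = 34962  # bufferView "target" for vertex attribute data
FLOAT = 5126          # accessor "componentType" for float32
# accessor type "VEC3" + FLOAT -> three float32 per element (RGB per vertex);
# COLOR_0 = 2 points the color attribute at accessor index 2.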
@@ -613,6 +632,56 @@ def save_glb(vertices, faces, filepath, metadata=None):
     return filepath
 
 
+def pack_variable_mesh_batch(vertices, faces, colors=None):
+    batch_size = len(vertices)
+    max_vertices = max(v.shape[0] for v in vertices)
+    max_faces = max(f.shape[0] for f in faces)
+
+    packed_vertices = vertices[0].new_zeros((batch_size, max_vertices, vertices[0].shape[1]))
+    packed_faces = faces[0].new_zeros((batch_size, max_faces, faces[0].shape[1]))
+    vertex_counts = torch.tensor([v.shape[0] for v in vertices], device=vertices[0].device, dtype=torch.int64)
+    face_counts = torch.tensor([f.shape[0] for f in faces], device=faces[0].device, dtype=torch.int64)
+
+    for i, (v, f) in enumerate(zip(vertices, faces)):
+        packed_vertices[i, :v.shape[0]] = v
+        packed_faces[i, :f.shape[0]] = f
+
+    mesh = Types.MESH(packed_vertices, packed_faces)
+    mesh.vertex_counts = vertex_counts
+    mesh.face_counts = face_counts
+
+    if colors is not None:
+        max_colors = max(c.shape[0] for c in colors)
+        packed_colors = colors[0].new_zeros((batch_size, max_colors, colors[0].shape[1]))
+        color_counts = torch.tensor([c.shape[0] for c in colors], device=colors[0].device, dtype=torch.int64)
+        for i, c in enumerate(colors):
+            packed_colors[i, :c.shape[0]] = c
+        mesh.colors = packed_colors
+        mesh.color_counts = color_counts
+
+    return mesh
+
+
+def get_mesh_batch_item(mesh, index):
+    if hasattr(mesh, "vertex_counts"):
+        vertex_count = int(mesh.vertex_counts[index].item())
+        face_count = int(mesh.face_counts[index].item())
+        vertices = mesh.vertices[index, :vertex_count]
+        faces = mesh.faces[index, :face_count]
+        colors = None
+        if hasattr(mesh, "colors") and mesh.colors is not None:
+            if hasattr(mesh, "color_counts"):
+                color_count = int(mesh.color_counts[index].item())
+                colors = mesh.colors[index, :color_count]
+            else:
+                colors = mesh.colors[index, :vertex_count]
+        return vertices, faces, colors
+
+    colors = None
+    if hasattr(mesh, "colors") and mesh.colors is not None:
+        colors = mesh.colors[index]
+    return mesh.vertices[index], mesh.faces[index], colors
+
+
 class SaveGLB(IO.ComfyNode):
     @classmethod
     def define_schema(cls):
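The two helpers form a pack/unpack pair: meshes with differing vertex or face counts are zero-padded into one batch tensor, and the per-item counts recorded on the MESH object let readers slice the padding back off. A self-contained sketch of the padding idea, reduced to a single attribute (pad_stack is illustrative, not from the PR):

import torch

def pad_stack(tensors):
    # Zero-pad variable-length tensors into one batch, remembering true lengths.
    m = max(t.shape[0] for t in tensors)
    out = tensors[0].new_zeros((len(tensors), m, tensors[0].shape[1]))
    counts = torch.tensor([t.shape[0] for t in tensors])
    for i, t in enumerate(tensors):
        out[i, :t.shape[0]] = t
    return out, counts

verts = [torch.rand(5, 3), torch.rand(2, 3)]
packed, counts = pad_stack(verts)
assert packed.shape == (2, 5, 3)
assert torch.equal(packed[1, :int(counts[1])], verts[1])  # slicing recovers the original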
@@ -667,9 +736,11 @@ class SaveGLB(IO.ComfyNode):
                 })
         else:
             # Handle Mesh input - save vertices and faces as GLB
-            for i in range(mesh.vertices.shape[0]):
+            bsz = mesh.vertices.shape[0]
+            for i in range(bsz):
                 f = f"{filename}_{counter:05}_.glb"
-                save_glb(mesh.vertices[i], mesh.faces[i], os.path.join(full_output_folder, f), metadata)
+                vertices, faces, v_colors = get_mesh_batch_item(mesh, i)
+                save_glb(vertices, faces, os.path.join(full_output_folder, f), metadata, v_colors)
                 results.append({
                     "filename": f,
                     "subfolder": subfolder,
1153
comfy_extras/nodes_trellis2.py
Normal file
File diff suppressed because it is too large
1
nodes.py
@@ -2422,6 +2422,7 @@ async def init_builtin_extra_nodes():
         "nodes_toolkit.py",
         "nodes_replacements.py",
         "nodes_nag.py",
+        "nodes_trellis2.py",
         "nodes_sdpose.py",
         "nodes_math.py",
         "nodes_number_convert.py",