This commit is contained in:
Yousef Rafat 2026-02-22 01:25:10 +02:00
parent 1fde60b2bc
commit 253ee4c02c
5 changed files with 26 additions and 17 deletions

View File

@ -44,9 +44,6 @@ class DINOv3ViTAttention(nn.Module):
self.num_heads = num_attention_heads self.num_heads = num_attention_heads
self.head_dim = self.embed_dim // self.num_heads self.head_dim = self.embed_dim // self.num_heads
self.scaling = self.head_dim**-0.5
self.is_causal = False
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype) self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
@ -251,7 +248,7 @@ class DINOv3ViTModel(nn.Module):
intermediate_size=intermediate_size,num_attention_heads = num_attention_heads, intermediate_size=intermediate_size,num_attention_heads = num_attention_heads,
dtype=dtype, device=device, operations=operations) dtype=dtype, device=device, operations=operations)
for _ in range(num_hidden_layers)]) for _ in range(num_hidden_layers)])
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device) self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
def get_input_embeddings(self): def get_input_embeddings(self):
return self.embeddings.patch_embeddings return self.embeddings.patch_embeddings

View File

@ -4,7 +4,7 @@ import math
import torch import torch
from typing import Dict, Callable from typing import Dict, Callable
NO_TRITION = False NO_TRITON = False
try: try:
allow_tf32 = torch.cuda.is_tf32_supported() allow_tf32 = torch.cuda.is_tf32_supported()
except Exception: except Exception:
@ -115,8 +115,8 @@ try:
allow_tf32=allow_tf32, allow_tf32=allow_tf32,
) )
return output return output
except: except Exception:
NO_TRITION = True NO_TRITON = True
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device): def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
# offsets in same order as CUDA kernel # offsets in same order as CUDA kernel
@ -364,6 +364,8 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2(
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation): def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
if NO_TRITON: # TODO
raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
if len(shape) == 5: if len(shape) == 5:
N, C, W, H, D = shape N, C, W, H, D = shape
else: else:

View File

@ -697,8 +697,6 @@ class SparseStructureFlowModel(nn.Module):
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3) x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous() h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
@ -746,7 +744,8 @@ class Trellis2(nn.Module):
super().__init__() super().__init__()
self.dtype = dtype self.dtype = dtype
# for some reason it passes num_heads = -1 # for some reason it passes num_heads = -1
num_heads = 12 if num_heads == -1:
num_heads = 12
args = { args = {
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels, "out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod, "model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
@ -763,8 +762,10 @@ class Trellis2(nn.Module):
def forward(self, x, timestep, context, **kwargs): def forward(self, x, timestep, context, **kwargs):
# FIXME: should find a way to distinguish between 512/1024 models # FIXME: should find a way to distinguish between 512/1024 models
# currently assumes 1024 # currently assumes 1024
transformer_options = kwargs.get("transformer_options") transformer_options = kwargs.get("transformer_options", {})
embeds = kwargs.get("embeds") embeds = kwargs.get("embeds")
if embeds is None:
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
#_, cond = context.chunk(2) # TODO #_, cond = context.chunk(2) # TODO
cond = embeds.chunk(2)[0] cond = embeds.chunk(2)[0]
context = torch.cat([torch.zeros_like(cond), cond]) context = torch.cat([torch.zeros_like(cond), cond])
@ -807,6 +808,8 @@ class Trellis2(nn.Module):
# TODO # TODO
out = self.img2shape(x, timestep, context) out = self.img2shape(x, timestep, context)
elif mode == "texture_generation": elif mode == "texture_generation":
if self.shape2txt is None:
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
out = self.shape2txt(x, timestep, context if not txt_rule else cond) out = self.shape2txt(x, timestep, context if not txt_rule else cond)
else: # structure else: # structure
timestep = timestep_reshift(timestep) timestep = timestep_reshift(timestep)

View File

@ -1522,6 +1522,8 @@ class Vae(nn.Module):
return self.shape_dec(slat, return_subs=True) return self.shape_dec(slat, return_subs=True)
def decode_tex_slat(self, slat, subs): def decode_tex_slat(self, slat, subs):
if self.txt_dec is None:
raise ValueError("Checkpoint doesn't include texture model")
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5 return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
@torch.no_grad() @torch.no_grad()

View File

@ -1,9 +1,11 @@
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types from comfy_api.latest import ComfyExtension, IO, Types
import torch import torch.nn.functional as TF
import comfy.model_management import comfy.model_management
from PIL import Image from PIL import Image
import numpy as np import numpy as np
import torch
import copy
shape_slat_normalization = { shape_slat_normalization = {
"mean": torch.tensor([ "mean": torch.tensor([
@ -145,11 +147,11 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
def execute(cls, samples, vae, resolution): def execute(cls, samples, vae, resolution):
vae = vae.first_stage_model vae = vae.first_stage_model
samples = samples["samples"] samples = samples["samples"]
std = shape_slat_normalization["std"] std = shape_slat_normalization["std"].to(samples)
mean = shape_slat_normalization["mean"] mean = shape_slat_normalization["mean"].to(samples)
samples = samples * std + mean samples = samples * std + mean
mesh, subs = vae.decode_shape_slat(resolution, samples) mesh, subs = vae.decode_shape_slat(samples, resolution)
return IO.NodeOutput(mesh, subs) return IO.NodeOutput(mesh, subs)
class VaeDecodeTextureTrellis(IO.ComfyNode): class VaeDecodeTextureTrellis(IO.ComfyNode):
@ -172,8 +174,8 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
def execute(cls, samples, vae, shape_subs): def execute(cls, samples, vae, shape_subs):
vae = vae.first_stage_model vae = vae.first_stage_model
samples = samples["samples"] samples = samples["samples"]
std = tex_slat_normalization["std"] std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"] mean = tex_slat_normalization["mean"].to(samples)
samples = samples * std + mean samples = samples * std + mean
mesh = vae.decode_tex_slat(samples, shape_subs) mesh = vae.decode_tex_slat(samples, shape_subs)
@ -239,6 +241,8 @@ class Trellis2Conditioning(IO.ComfyNode):
scale = min(1, 1024 / max_size) scale = min(1, 1024 / max_size)
if scale < 1: if scale < 1:
image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS) image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS)
new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale)
mask = TF.interpolate(mask.unsqueeze(0).float(), size=(new_h, new_w), mode='nearest').squeeze(0)
image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255 image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255
@ -510,6 +514,7 @@ class PostProcessMesh(IO.ComfyNode):
) )
@classmethod @classmethod
def execute(cls, mesh, simplify, fill_holes_perimeter): def execute(cls, mesh, simplify, fill_holes_perimeter):
mesh = copy.deepcopy(mesh)
verts, faces = mesh.vertices, mesh.faces verts, faces = mesh.vertices, mesh.faces
if fill_holes_perimeter != 0.0: if fill_holes_perimeter != 0.0: