This commit is contained in:
Yousef Rafat 2026-02-22 01:25:10 +02:00
parent 1fde60b2bc
commit 253ee4c02c
5 changed files with 26 additions and 17 deletions

View File

@ -44,9 +44,6 @@ class DINOv3ViTAttention(nn.Module):
self.num_heads = num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.scaling = self.head_dim**-0.5
self.is_causal = False
self.k_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=False, device=device, dtype=dtype) # key_bias = False
self.v_proj = operations.Linear(self.embed_dim, self.embed_dim, bias=True, device=device, dtype=dtype)
@ -251,7 +248,7 @@ class DINOv3ViTModel(nn.Module):
intermediate_size=intermediate_size,num_attention_heads = num_attention_heads,
dtype=dtype, device=device, operations=operations)
for _ in range(num_hidden_layers)])
self.norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
self.norm = operations.LayerNorm(hidden_size, eps=layer_norm_eps, dtype=dtype, device=device)
def get_input_embeddings(self):
return self.embeddings.patch_embeddings

View File

@ -4,7 +4,7 @@ import math
import torch
from typing import Dict, Callable
NO_TRITION = False
NO_TRITON = False
try:
allow_tf32 = torch.cuda.is_tf32_supported()
except Exception:
@ -115,8 +115,8 @@ try:
allow_tf32=allow_tf32,
)
return output
except:
NO_TRITION = True
except Exception:
NO_TRITON = True
def compute_kernel_offsets(Kw, Kh, Kd, Dw, Dh, Dd, device):
# offsets in same order as CUDA kernel
@ -364,6 +364,8 @@ def neighbor_map_post_process_for_masked_implicit_gemm_2(
def sparse_submanifold_conv3d(feats, coords, shape, weight, bias, neighbor_cache, dilation):
if NO_TRITON: # TODO
raise RuntimeError("sparse_submanifold_conv3d requires Triton, which is not available.")
if len(shape) == 5:
N, C, W, H, D = shape
else:

View File

@ -697,8 +697,6 @@ class SparseStructureFlowModel(nn.Module):
def forward(self, x: torch.Tensor, t: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
x = x.view(x.shape[0], self.in_channels, *[self.resolution] * 3)
assert [*x.shape] == [x.shape[0], self.in_channels, *[self.resolution] * 3], \
f"Input shape mismatch, got {x.shape}, expected {[x.shape[0], self.in_channels, *[self.resolution] * 3]}"
h = x.view(*x.shape[:2], -1).permute(0, 2, 1).contiguous()
@ -746,7 +744,8 @@ class Trellis2(nn.Module):
super().__init__()
self.dtype = dtype
# for some reason it passes num_heads = -1
num_heads = 12
if num_heads == -1:
num_heads = 12
args = {
"out_channels":out_channels, "num_blocks":num_blocks, "cond_channels" :cond_channels,
"model_channels":model_channels, "num_heads":num_heads, "mlp_ratio": mlp_ratio, "share_mod": share_mod,
@ -763,8 +762,10 @@ class Trellis2(nn.Module):
def forward(self, x, timestep, context, **kwargs):
# FIXME: should find a way to distinguish between 512/1024 models
# currently assumes 1024
transformer_options = kwargs.get("transformer_options")
transformer_options = kwargs.get("transformer_options", {})
embeds = kwargs.get("embeds")
if embeds is None:
raise ValueError("Trellis2.forward requires 'embeds' in kwargs")
#_, cond = context.chunk(2) # TODO
cond = embeds.chunk(2)[0]
context = torch.cat([torch.zeros_like(cond), cond])
@ -807,6 +808,8 @@ class Trellis2(nn.Module):
# TODO
out = self.img2shape(x, timestep, context)
elif mode == "texture_generation":
if self.shape2txt is None:
raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!")
out = self.shape2txt(x, timestep, context if not txt_rule else cond)
else: # structure
timestep = timestep_reshift(timestep)

View File

@ -1522,6 +1522,8 @@ class Vae(nn.Module):
return self.shape_dec(slat, return_subs=True)
def decode_tex_slat(self, slat, subs):
if self.txt_dec is None:
raise ValueError("Checkpoint doesn't include texture model")
return self.txt_dec(slat, guide_subs=subs) * 0.5 + 0.5
@torch.no_grad()

View File

@ -1,9 +1,11 @@
from typing_extensions import override
from comfy_api.latest import ComfyExtension, IO, Types
import torch
import torch.nn.functional as TF
import comfy.model_management
from PIL import Image
import numpy as np
import torch
import copy
shape_slat_normalization = {
"mean": torch.tensor([
@ -145,11 +147,11 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
def execute(cls, samples, vae, resolution):
vae = vae.first_stage_model
samples = samples["samples"]
std = shape_slat_normalization["std"]
mean = shape_slat_normalization["mean"]
std = shape_slat_normalization["std"].to(samples)
mean = shape_slat_normalization["mean"].to(samples)
samples = samples * std + mean
mesh, subs = vae.decode_shape_slat(resolution, samples)
mesh, subs = vae.decode_shape_slat(samples, resolution)
return IO.NodeOutput(mesh, subs)
class VaeDecodeTextureTrellis(IO.ComfyNode):
@ -172,8 +174,8 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
def execute(cls, samples, vae, shape_subs):
vae = vae.first_stage_model
samples = samples["samples"]
std = tex_slat_normalization["std"]
mean = tex_slat_normalization["mean"]
std = tex_slat_normalization["std"].to(samples)
mean = tex_slat_normalization["mean"].to(samples)
samples = samples * std + mean
mesh = vae.decode_tex_slat(samples, shape_subs)
@ -239,6 +241,8 @@ class Trellis2Conditioning(IO.ComfyNode):
scale = min(1, 1024 / max_size)
if scale < 1:
image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS)
new_h, new_w = int(mask.shape[-2] * scale), int(mask.shape[-1] * scale)
mask = TF.interpolate(mask.unsqueeze(0).float(), size=(new_h, new_w), mode='nearest').squeeze(0)
image = torch.tensor(np.array(image)).unsqueeze(0).float() / 255
@ -510,6 +514,7 @@ class PostProcessMesh(IO.ComfyNode):
)
@classmethod
def execute(cls, mesh, simplify, fill_holes_perimeter):
mesh = copy.deepcopy(mesh)
verts, faces = mesh.vertices, mesh.faces
if fill_holes_perimeter != 0.0: