debugging

This commit is contained in:
Yousef Rafat 2026-02-11 20:33:59 +02:00
parent f4059c189e
commit b7764479c2
5 changed files with 65 additions and 25 deletions

View File

@ -14,6 +14,7 @@ except:
def scaled_dot_product_attention(*args, **kwargs):
num_all_args = len(args) + len(kwargs)
q = None
if num_all_args == 1:
qkv = args[0] if len(args) > 0 else kwargs['qkv']
@ -26,8 +27,10 @@ def scaled_dot_product_attention(*args, **kwargs):
k = args[1] if len(args) > 1 else kwargs['k']
v = args[2] if len(args) > 2 else kwargs['v']
# TODO verify
heads = q or qkv
if q is not None:
heads = q
else:
heads = qkv
heads = heads.shape[2]
if optimized_attention.__name__ == 'attention_xformers':

View File

@ -7,7 +7,6 @@ from comfy.ldm.trellis2.attention import (
sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention
)
from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
from comfy.nested_tensor import NestedTensor
from comfy.ldm.flux.math import apply_rope, apply_rope1
class SparseGELU(nn.GELU):
@ -586,6 +585,7 @@ class MultiHeadAttention(nn.Module):
else:
Lkv = context.shape[1]
q = self.to_q(x)
context = context.to(next(self.to_kv.parameters()).dtype)
kv = self.to_kv(context)
q = q.reshape(B, L, self.num_heads, -1)
kv = kv.reshape(B, Lkv, 2, self.num_heads, -1)
@ -782,6 +782,7 @@ class SparseStructureFlowModel(nn.Module):
h = block(h, t_emb, cond, self.rope_phases)
h = manual_cast(h, x.dtype)
h = F.layer_norm(h, h.shape[-1:])
h = h.to(next(self.out_layer.parameters()).dtype)
h = self.out_layer(h)
h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous()
@ -823,9 +824,7 @@ class Trellis2(nn.Module):
args.pop("out_channels")
self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args)
def forward(self, x: NestedTensor, timestep, context, **kwargs):
if isinstance(x, NestedTensor):
x = x.tensors[0]
def forward(self, x, timestep, context, **kwargs):
embeds = kwargs.get("embeds")
if not hasattr(x, "feats"):
mode = "structure_generation"
@ -843,6 +842,5 @@ class Trellis2(nn.Module):
timestep = timestep_reshift(timestep)
out = self.structure_model(x, timestep, context)
out = NestedTensor([out])
out.generation_mode = mode
return out

View File

@ -10,9 +10,6 @@ from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_
def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor:
"""
3D pixel shuffle.
"""
B, C, H, W, D = x.shape
C_ = C // scale_factor**3
x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D)
@ -967,6 +964,25 @@ class SparseLinear(nn.Linear):
return input.replace(super().forward(input.feats))
MIX_PRECISION_MODULES = (
nn.Conv1d,
nn.Conv2d,
nn.Conv3d,
nn.ConvTranspose1d,
nn.ConvTranspose2d,
nn.ConvTranspose3d,
nn.Linear,
SparseConv3d,
SparseLinear,
)
def convert_module_to_f16(l):
if isinstance(l, MIX_PRECISION_MODULES):
for p in l.parameters():
p.data = p.data.half()
class SparseUnetVaeEncoder(nn.Module):
"""
@ -1381,8 +1397,12 @@ class ResBlock3d(nn.Module):
self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
def forward(self, x: torch.Tensor) -> torch.Tensor:
self.norm1 = self.norm1.to(torch.float32)
self.norm2 = self.norm2.to(torch.float32)
h = self.norm1(x)
h = F.silu(h)
dtype = next(self.conv1.parameters()).dtype
h = h.to(dtype)
h = self.conv1(h)
h = self.norm2(h)
h = F.silu(h)
@ -1400,7 +1420,7 @@ class SparseStructureDecoder(nn.Module):
channels: List[int],
num_res_blocks_middle: int = 2,
norm_type = "layer",
use_fp16: bool = False,
use_fp16: bool = True,
):
super().__init__()
self.out_channels = out_channels
@ -1439,20 +1459,27 @@ class SparseStructureDecoder(nn.Module):
if use_fp16:
self.convert_to_fp16()
@property
def device(self) -> torch.device:
return next(self.parameters()).device
def convert_to_fp16(self) -> None:
self.use_fp16 = True
self.dtype = torch.float16
self.blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = next(self.input_layer.parameters()).dtype
x = x.to(dtype)
h = self.input_layer(x)
h = h.type(self.dtype)
h = self.middle_block(h)
for block in self.blocks:
h = block(h)
h = h.type(x.dtype)
h = h.to(torch.float32)
self.out_layer = self.out_layer.to(torch.float32)
h = self.out_layer(h)
return h

View File

@ -497,6 +497,10 @@ class VAE:
init_txt_model = False
if "txt_dec.blocks.1.16.norm1.weight" in sd:
init_txt_model = True
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
# TODO
self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model)
elif "decoder.conv_in.weight" in sd:
if sd['decoder.conv_in.weight'].shape[1] == 64:

View File

@ -3,7 +3,7 @@ from comfy_api.latest import ComfyExtension, IO
import torch
from comfy.ldm.trellis2.model import SparseTensor
import comfy.model_management
from comfy.nested_tensor import NestedTensor
import comfy.model_patcher
from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode
shape_slat_normalization = {
@ -137,14 +137,15 @@ class VaeDecodeShapeTrellis(IO.ComfyNode):
)
@classmethod
def execute(cls, samples: NestedTensor, vae, resolution):
samples = samples.tensors[0]
def execute(cls, samples, vae, resolution):
vae = vae.first_stage_model
samples = samples["samples"]
std = shape_slat_normalization["std"]
mean = shape_slat_normalization["mean"]
samples = samples * std + mean
mesh, subs = vae.decode_shape_slat(resolution, samples)
return mesh, subs
return IO.NodeOutput(mesh, subs)
class VaeDecodeTextureTrellis(IO.ComfyNode):
@classmethod
@ -164,13 +165,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode):
@classmethod
def execute(cls, samples, vae, shape_subs):
samples = samples.tensors[0]
vae = vae.first_stage_model
samples = samples["samples"]
std = tex_slat_normalization["std"]
mean = tex_slat_normalization["mean"]
samples = samples * std + mean
mesh = vae.decode_tex_slat(samples, shape_subs)
return mesh
return IO.NodeOutput(mesh)
class VaeDecodeStructureTrellis2(IO.ComfyNode):
@classmethod
@ -189,10 +191,19 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode):
@classmethod
def execute(cls, samples, vae):
vae = vae.first_stage_model
decoder = vae.struct_dec
load_device = comfy.model_management.get_torch_device()
decoder = comfy.model_patcher.ModelPatcher(
decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device()
)
comfy.model_management.load_model_gpu(decoder)
decoder = decoder.model
samples = samples["samples"]
samples = samples.to(load_device)
decoded = decoder(samples)>0
coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int()
return coords
return IO.NodeOutput(coords)
class Trellis2Conditioning(IO.ComfyNode):
@classmethod
@ -240,7 +251,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
coords = structure_output # or structure_output.coords
in_channels = 32
latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords)
latent = NestedTensor([latent])
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class EmptyTextureLatentTrellis2(IO.ComfyNode):
@ -262,7 +272,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
# TODO
in_channels = 32
latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1]))
latent = NestedTensor([latent])
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
class EmptyStructureLatentTrellis2(IO.ComfyNode):
@ -283,7 +292,6 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode):
in_channels = 8
resolution = 16
latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution)
latent = NestedTensor([latent])
return IO.NodeOutput({"samples": latent, "type": "trellis2"})
def simplify_fn(vertices, faces, target=100000):
@ -469,7 +477,7 @@ class PostProcessMesh(IO.ComfyNode):
mesh.vertices = verts
mesh.faces = faces
return mesh
return IO.NodeOutput(mesh)
class Trellis2Extension(ComfyExtension):
@override