From b7764479c263c9d41bd077c7453b8d4c15551a34 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Wed, 11 Feb 2026 20:33:59 +0200 Subject: [PATCH] debugging --- comfy/ldm/trellis2/attention.py | 7 ++++-- comfy/ldm/trellis2/model.py | 8 +++---- comfy/ldm/trellis2/vae.py | 41 +++++++++++++++++++++++++++------ comfy/sd.py | 4 ++++ comfy_extras/nodes_trellis2.py | 30 +++++++++++++++--------- 5 files changed, 65 insertions(+), 25 deletions(-) diff --git a/comfy/ldm/trellis2/attention.py b/comfy/ldm/trellis2/attention.py index 3038f4023..e6aa50842 100644 --- a/comfy/ldm/trellis2/attention.py +++ b/comfy/ldm/trellis2/attention.py @@ -14,6 +14,7 @@ except: def scaled_dot_product_attention(*args, **kwargs): num_all_args = len(args) + len(kwargs) + q = None if num_all_args == 1: qkv = args[0] if len(args) > 0 else kwargs['qkv'] @@ -26,8 +27,10 @@ def scaled_dot_product_attention(*args, **kwargs): k = args[1] if len(args) > 1 else kwargs['k'] v = args[2] if len(args) > 2 else kwargs['v'] - # TODO verify - heads = q or qkv + if q is not None: + heads = q + else: + heads = qkv heads = heads.shape[2] if optimized_attention.__name__ == 'attention_xformers': diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 8bc8e8f7a..17286a553 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -7,7 +7,6 @@ from comfy.ldm.trellis2.attention import ( sparse_windowed_scaled_dot_product_self_attention, sparse_scaled_dot_product_attention, scaled_dot_product_attention ) from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder -from comfy.nested_tensor import NestedTensor from comfy.ldm.flux.math import apply_rope, apply_rope1 class SparseGELU(nn.GELU): @@ -586,6 +585,7 @@ class MultiHeadAttention(nn.Module): else: Lkv = context.shape[1] q = self.to_q(x) + context = context.to(next(self.to_kv.parameters()).dtype) kv = self.to_kv(context) q = q.reshape(B, L, self.num_heads, -1) kv = kv.reshape(B, Lkv, 2, self.num_heads, -1) @@ -782,6 +782,7 @@ class SparseStructureFlowModel(nn.Module): h = block(h, t_emb, cond, self.rope_phases) h = manual_cast(h, x.dtype) h = F.layer_norm(h, h.shape[-1:]) + h = h.to(next(self.out_layer.parameters()).dtype) h = self.out_layer(h) h = h.permute(0, 2, 1).view(h.shape[0], h.shape[2], *[self.resolution] * 3).contiguous() @@ -823,9 +824,7 @@ class Trellis2(nn.Module): args.pop("out_channels") self.structure_model = SparseStructureFlowModel(resolution=16, in_channels=8, out_channels=8, **args) - def forward(self, x: NestedTensor, timestep, context, **kwargs): - if isinstance(x, NestedTensor): - x = x.tensors[0] + def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") if not hasattr(x, "feats"): mode = "structure_generation" @@ -843,6 +842,5 @@ class Trellis2(nn.Module): timestep = timestep_reshift(timestep) out = self.structure_model(x, timestep, context) - out = NestedTensor([out]) out.generation_mode = mode return out diff --git a/comfy/ldm/trellis2/vae.py b/comfy/ldm/trellis2/vae.py index 6e13afd8d..57bf78346 100644 --- a/comfy/ldm/trellis2/vae.py +++ b/comfy/ldm/trellis2/vae.py @@ -10,9 +10,6 @@ from comfy.ldm.trellis2.cumesh import TorchHashMap, Mesh, MeshWithVoxel, sparse_ def pixel_shuffle_3d(x: torch.Tensor, scale_factor: int) -> torch.Tensor: - """ - 3D pixel shuffle. - """ B, C, H, W, D = x.shape C_ = C // scale_factor**3 x = x.reshape(B, C_, scale_factor, scale_factor, scale_factor, H, W, D) @@ -967,6 +964,25 @@ class SparseLinear(nn.Linear): return input.replace(super().forward(input.feats)) +MIX_PRECISION_MODULES = ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.ConvTranspose1d, + nn.ConvTranspose2d, + nn.ConvTranspose3d, + nn.Linear, + SparseConv3d, + SparseLinear, +) + + +def convert_module_to_f16(l): + if isinstance(l, MIX_PRECISION_MODULES): + for p in l.parameters(): + p.data = p.data.half() + + class SparseUnetVaeEncoder(nn.Module): """ @@ -1381,8 +1397,12 @@ class ResBlock3d(nn.Module): self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity() def forward(self, x: torch.Tensor) -> torch.Tensor: + self.norm1 = self.norm1.to(torch.float32) + self.norm2 = self.norm2.to(torch.float32) h = self.norm1(x) h = F.silu(h) + dtype = next(self.conv1.parameters()).dtype + h = h.to(dtype) h = self.conv1(h) h = self.norm2(h) h = F.silu(h) @@ -1400,7 +1420,7 @@ class SparseStructureDecoder(nn.Module): channels: List[int], num_res_blocks_middle: int = 2, norm_type = "layer", - use_fp16: bool = False, + use_fp16: bool = True, ): super().__init__() self.out_channels = out_channels @@ -1439,20 +1459,27 @@ class SparseStructureDecoder(nn.Module): if use_fp16: self.convert_to_fp16() - @property def device(self) -> torch.device: return next(self.parameters()).device + def convert_to_fp16(self) -> None: + self.use_fp16 = True + self.dtype = torch.float16 + self.blocks.apply(convert_module_to_f16) + self.middle_block.apply(convert_module_to_f16) + def forward(self, x: torch.Tensor) -> torch.Tensor: + dtype = next(self.input_layer.parameters()).dtype + x = x.to(dtype) h = self.input_layer(x) h = h.type(self.dtype) - h = self.middle_block(h) for block in self.blocks: h = block(h) - h = h.type(x.dtype) + h = h.to(torch.float32) + self.out_layer = self.out_layer.to(torch.float32) h = self.out_layer(h) return h diff --git a/comfy/sd.py b/comfy/sd.py index 25fd3ba7b..276e87d2a 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -497,6 +497,10 @@ class VAE: init_txt_model = False if "txt_dec.blocks.1.16.norm1.weight" in sd: init_txt_model = True + self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] + # TODO + self.memory_used_decode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (6500 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.first_stage_model = comfy.ldm.trellis2.vae.Vae(init_txt_model) elif "decoder.conv_in.weight" in sd: if sd['decoder.conv_in.weight'].shape[1] == 64: diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 4eff2dbc3..c735469be 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -3,7 +3,7 @@ from comfy_api.latest import ComfyExtension, IO import torch from comfy.ldm.trellis2.model import SparseTensor import comfy.model_management -from comfy.nested_tensor import NestedTensor +import comfy.model_patcher from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode shape_slat_normalization = { @@ -137,14 +137,15 @@ class VaeDecodeShapeTrellis(IO.ComfyNode): ) @classmethod - def execute(cls, samples: NestedTensor, vae, resolution): - samples = samples.tensors[0] + def execute(cls, samples, vae, resolution): + vae = vae.first_stage_model + samples = samples["samples"] std = shape_slat_normalization["std"] mean = shape_slat_normalization["mean"] samples = samples * std + mean mesh, subs = vae.decode_shape_slat(resolution, samples) - return mesh, subs + return IO.NodeOutput(mesh, subs) class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod @@ -164,13 +165,14 @@ class VaeDecodeTextureTrellis(IO.ComfyNode): @classmethod def execute(cls, samples, vae, shape_subs): - samples = samples.tensors[0] + vae = vae.first_stage_model + samples = samples["samples"] std = tex_slat_normalization["std"] mean = tex_slat_normalization["mean"] samples = samples * std + mean mesh = vae.decode_tex_slat(samples, shape_subs) - return mesh + return IO.NodeOutput(mesh) class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod @@ -189,10 +191,19 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): @classmethod def execute(cls, samples, vae): + vae = vae.first_stage_model decoder = vae.struct_dec + load_device = comfy.model_management.get_torch_device() + decoder = comfy.model_patcher.ModelPatcher( + decoder, load_device=load_device, offload_device=comfy.model_management.vae_offload_device() + ) + comfy.model_management.load_model_gpu(decoder) + decoder = decoder.model + samples = samples["samples"] + samples = samples.to(load_device) decoded = decoder(samples)>0 coords = torch.argwhere(decoded)[:, [0, 2, 3, 4]].int() - return coords + return IO.NodeOutput(coords) class Trellis2Conditioning(IO.ComfyNode): @classmethod @@ -240,7 +251,6 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): coords = structure_output # or structure_output.coords in_channels = 32 latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) - latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @@ -262,7 +272,6 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): # TODO in_channels = 32 latent = structure_output.replace(feats=torch.randn(structure_output.coords.shape[0], in_channels - structure_output.feats.shape[1])) - latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @@ -283,7 +292,6 @@ class EmptyStructureLatentTrellis2(IO.ComfyNode): in_channels = 8 resolution = 16 latent = torch.randn(batch_size, in_channels, resolution, resolution, resolution) - latent = NestedTensor([latent]) return IO.NodeOutput({"samples": latent, "type": "trellis2"}) def simplify_fn(vertices, faces, target=100000): @@ -469,7 +477,7 @@ class PostProcessMesh(IO.ComfyNode): mesh.vertices = verts mesh.faces = faces - return mesh + return IO.NodeOutput(mesh) class Trellis2Extension(ComfyExtension): @override