Trellis2: pass coords / generation mode through transformer_options

* cumesh.py: TorchHashMap.lookup_flat returns the default value for every key
  when the map is empty, and clamps the searchsorted indices before using
  them to index sorted_keys / sorted_vals.
* model.py: Trellis2.forward reads "coords" and "generation_mode" from
  transformer_options instead of attributes set on the builtins module.
* nodes_trellis2.py: the empty shape/texture latent nodes take a model, clone
  it, store coords and the generation mode in its transformer_options, and
  return the patched clone alongside the latent.

Review notes:
* The texture node now records "texture_generation" (it had been copy-pasted
  as "shape_generation"; the removed builtins line used "texture_generation").
* The option copy no longer raises KeyError when the model has no
  "transformer_options" yet, and no longer drops the other model options.
* No hunk here adds a model input to EmptyTextureLatentTrellis2's schema even
  though its execute() now takes `model` -- TODO confirm it is declared.
* The post-image blob hash on the nodes_trellis2.py index line predates the
  fixes above and is stale.

diff --git a/comfy/ldm/trellis2/cumesh.py b/comfy/ldm/trellis2/cumesh.py
index cb067a32f..1be8408c6 100644
--- a/comfy/ldm/trellis2/cumesh.py
+++ b/comfy/ldm/trellis2/cumesh.py
@@ -207,11 +207,14 @@ class TorchHashMap:
 
     def lookup_flat(self, flat_keys: torch.Tensor) -> torch.Tensor:
         flat = flat_keys.long()
+        if self._n == 0:
+            return torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
         idx = torch.searchsorted(self.sorted_keys, flat)
-        found = (idx < self._n) & (self.sorted_keys[idx] == flat)
+        idx_safe = torch.clamp(idx, max=self._n - 1)
+        found = (idx < self._n) & (self.sorted_keys[idx_safe] == flat)
         out = torch.full((flat.shape[0],), self.default_value, device=flat.device, dtype=self.sorted_vals.dtype)
         if found.any():
-            out[found] = self.sorted_vals[idx[found]]
+            out[found] = self.sorted_vals[idx_safe[found]]
         return out
 
 
diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py
index 07cf86d30..ef1c25d33 100644
--- a/comfy/ldm/trellis2/model.py
+++ b/comfy/ldm/trellis2/model.py
@@ -8,7 +8,6 @@ from comfy.ldm.trellis2.attention import (
 )
 from comfy.ldm.genmo.joint_model.layers import TimestepEmbedder
 from comfy.ldm.flux.math import apply_rope, apply_rope1
-import builtins
 
 class SparseGELU(nn.GELU):
     def forward(self, input: VarLenTensor) -> VarLenTensor:
@@ -829,19 +828,19 @@ class Trellis2(nn.Module):
     def forward(self, x, timestep, context, **kwargs):
         # FIXME: should find a way to distinguish between 512/1024 models
         # currently assumes 1024
+        transformer_options = kwargs.get("transformer_options")
         embeds = kwargs.get("embeds")
-        _, cond = context.chunk(2)
+        #_, cond = context.chunk(2) # TODO
         cond = embeds.chunk(2)[0]
         context = torch.cat([torch.zeros_like(cond), cond])
-        mode = getattr(builtins, "TRELLIS_MODE", "structure_generation")
-        coords = getattr(builtins, "TRELLIS_COORDS", None)
+        coords = transformer_options.get("coords", None)
+        mode = transformer_options.get("generation_mode", "structure_generation")
         if coords is not None:
             x = x.squeeze(0)
             not_struct_mode = True
         else:
             mode = "structure_generation"
             not_struct_mode = False
-        transformer_options = kwargs.get("transformer_options")
         sigmas = transformer_options.get("sigmas")[0].item()
         if sigmas < 1.00001:
             timestep *= 1000.0
diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py
index c8d84fd23..14f5484d6 100644
--- a/comfy_extras/nodes_trellis2.py
+++ b/comfy_extras/nodes_trellis2.py
@@ -4,7 +4,6 @@ import torch
 import comfy.model_management
 from PIL import Image
 import numpy as np
-import builtins
 
 shape_slat_normalization = {
     "mean": torch.tensor([
@@ -258,21 +257,29 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode):
             category="latent/3d",
             inputs=[
                 IO.Voxel.Input("structure_output"),
+                IO.Model.Input("model")
             ],
             outputs=[
                 IO.Latent.Output(),
+                IO.Model.Output()
             ]
         )
 
     @classmethod
-    def execute(cls, structure_output):
+    def execute(cls, structure_output, model):
         decoded = structure_output.data.unsqueeze(1)
         coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
         in_channels = 32
         latent = torch.randn(1, coords.shape[0], in_channels)
-        builtins.TRELLIS_MODE = "shape_generation"
-        builtins.TRELLIS_COORDS = coords
-        return IO.NodeOutput({"samples": latent, "type": "trellis2"})
+        model = model.clone()
+        model.model_options = model.model_options.copy()
+        # Copy the nested dict as well so the incoming model's options are never
+        # mutated; start from an empty dict when none have been set yet.
+        transformer_options = model.model_options.get("transformer_options", {}).copy()
+        transformer_options["coords"] = coords
+        transformer_options["generation_mode"] = "shape_generation"
+        model.model_options["transformer_options"] = transformer_options
+        return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
 
 class EmptyTextureLatentTrellis2(IO.ComfyNode):
     @classmethod
@@ -285,19 +292,27 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode):
             ],
             outputs=[
                 IO.Latent.Output(),
+                IO.Model.Output()
             ]
         )
 
     @classmethod
-    def execute(cls, structure_output):
+    def execute(cls, structure_output, model):
         # TODO
         decoded = structure_output.data.unsqueeze(1)
         coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int()
         in_channels = 32
         latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])
-        builtins.TRELLIS_MODE = "texture_generation"
-        builtins.TRELLIS_COORDS = coords
-        return IO.NodeOutput({"samples": latent, "type": "trellis2"})
+        model = model.clone()
+        model.model_options = model.model_options.copy()
+        # Copy the nested dict as well so the incoming model's options are never
+        # mutated; start from an empty dict when none have been set yet.
+        transformer_options = model.model_options.get("transformer_options", {}).copy()
+        transformer_options["coords"] = coords
+        transformer_options["generation_mode"] = "texture_generation"
+        model.model_options["transformer_options"] = transformer_options
+        return IO.NodeOutput({"samples": latent, "type": "trellis2"}, model)
+
 
 class EmptyStructureLatentTrellis2(IO.ComfyNode):
     @classmethod