diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index eb410fe8b..b4fc15abc 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -828,6 +828,7 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") mode = kwargs.get("generation_mode") + coords = kwargs.get("coords") transformer_options = kwargs.get("transformer_options") sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: @@ -836,6 +837,9 @@ class Trellis2(nn.Module): shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] + if mode in ["shape_generation", "texture_generation"]: + x = SparseTensor(feats=x, coords=coords) + if mode == "shape_generation": # TODO out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 817769d08..bd250e5f5 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -1,9 +1,7 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types import torch -from comfy.ldm.trellis2.model import SparseTensor import comfy.model_management -import comfy.model_patcher shape_slat_normalization = { "mean": torch.tensor([ @@ -205,7 +203,6 @@ class VaeDecodeStructureTrellis2(IO.ComfyNode): samples = samples.to(load_device) decoded = decoder(samples)>0 decoder.to(offload_device) - comfy.model_management.get_offload_stream out = Types.VOXEL(decoded.squeeze(1).float()) return IO.NodeOutput(out) @@ -253,10 +250,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): decoded = structure_output.data - coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int().unsqueeze(1) + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = SparseTensor(feats=torch.randn(coords.shape[0], in_channels), coords=coords) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation"}) + latent = torch.randn(coords.shape[0], in_channels) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "shape_generation", "coords": coords}) class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod @@ -275,9 +272,11 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): # TODO + decoded = structure_output.data + coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 - latent = structure_output.replace(feats=torch.randn(structure_output.data.shape[0], in_channels - structure_output.feats.shape[1])) - return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation"}) + latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1]) + return IO.NodeOutput({"samples": latent, "type": "trellis2", "generation_mode": "texture_generation", "coords": coords}) class EmptyStructureLatentTrellis2(IO.ComfyNode): @classmethod