From def8947e75b406fa87df78f9c6f4b9d0929eb6d8 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 20 Mar 2026 18:37:11 +0200 Subject: [PATCH] shape working --- comfy/ldm/trellis2/model.py | 50 +++++++++++++++++++++++++--------- comfy_extras/nodes_trellis2.py | 22 +++++++++++---- 2 files changed, 53 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7a3e387c3..8a0c6d8b6 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -770,14 +770,20 @@ class Trellis2(nn.Module): is_1024 = self.img2shape.resolution == 1024 coords = transformer_options.get("coords", None) mode = transformer_options.get("generation_mode", "structure_generation") + is_512_run = False + if mode == "shape_generation_512": + is_512_run = True + mode = "shape_generation" if coords is not None: x = x.squeeze(-1).transpose(1, 2) not_struct_mode = True else: mode = "structure_generation" not_struct_mode = False - if is_1024 and mode == "shape_generation": + + if is_1024 and mode == "shape_generation" and not is_512_run: context = embeds + sigmas = transformer_options.get("sigmas")[0].item() if sigmas < 1.00001: timestep *= 1000.0 @@ -786,12 +792,24 @@ class Trellis2(nn.Module): txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] if not_struct_mode: - B, N, C = x.shape + orig_bsz = x.shape[0] + rule = txt_rule if mode == "texture_generation" else shape_rule - if mode == "shape_generation": - feats_flat = x.reshape(-1, C) + if rule and orig_bsz > 1: + x_eval = x[1].unsqueeze(0) + t_eval = timestep[1].unsqueeze(0) if timestep.shape[0] > 1 else timestep + c_eval = cond + else: + x_eval = x + t_eval = timestep + c_eval = context - # 3. inflate coords [N, 4] -> [B*N, 4] + B, N, C = x_eval.shape + + if mode in ["shape_generation", "texture_generation"]: + feats_flat = x_eval.reshape(-1, C) + + # inflate coords [N, 4] -> [B*N, 4] coords_list = [] for i in range(B): c = coords.clone() @@ -799,23 +817,27 @@ class Trellis2(nn.Module): coords_list.append(c) batched_coords = torch.cat(coords_list, dim=0) - else: # TODO: texture - # may remove the else if texture doesn't require special handling + else: batched_coords = coords - feats_flat = x - x = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) + feats_flat = x_eval + + x_st = SparseTensor(feats=feats_flat, coords=batched_coords.to(torch.int32)) if mode == "shape_generation": # TODO - out = self.img2shape(x, timestep, context) + if is_512_run: + out = self.img2shape_512(x_st, t_eval, c_eval) + else: + out = self.img2shape(x_st, t_eval, c_eval) elif mode == "texture_generation": if self.shape2txt is None: raise ValueError("Checkpoint for Trellis2 doesn't include texture generation!") slat = transformer_options.get("shape_slat") if slat is None: raise ValueError("shape_slat can't be None") - x = sparse_cat([x, slat]) - out = self.shape2txt(x, timestep, context if not txt_rule else cond) + slat.feats = slat.feats.repeat(B, 1) + x_st = sparse_cat([x_st, slat]) + out = self.shape2txt(x_st, t_eval, c_eval) else: # structure #timestep = timestep_reshift(timestep) orig_bsz = x.shape[0] @@ -828,6 +850,8 @@ class Trellis2(nn.Module): if not_struct_mode: out = out.feats - if mode == "shape_generation": + if not_struct_mode: out = out.view(B, N, -1).transpose(1, 2).unsqueeze(-1) + if rule and orig_bsz > 1: + out = out.repeat(orig_bsz, 1, 1, 1) return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 409d2d23c..cba6b3241 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -178,6 +178,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): @classmethod def execute(cls, shape_latent_512, vae, target_resolution, max_tokens): device = comfy.model_management.get_torch_device() + comfy.model_management.load_model_gpu(vae.patcher) feats = shape_latent_512["samples"].squeeze(-1).transpose(1, 2).reshape(-1, 32).to(device) coords_512 = shape_latent_512["coords"].to(device) @@ -185,11 +186,9 @@ class Trellis2UpsampleCascade(IO.ComfyNode): slat = shape_norm(feats, coords_512) decoder = vae.first_stage_model.shape_dec - decoder.to(device) slat.feats = slat.feats.to(next(decoder.parameters()).dtype) hr_coords = decoder.upsample(slat, upsample_times=4) - decoder.cpu() lr_resolution = 512 hr_resolution = int(target_resolution) @@ -206,7 +205,7 @@ class Trellis2UpsampleCascade(IO.ComfyNode): break hr_resolution -= 128 - return IO.NodeOutput(final_coords.cpu()) + return IO.NodeOutput(final_coords,) dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) @@ -341,11 +340,19 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_or_coords, model): # to accept the upscaled coords - if hasattr(structure_or_coords, "data"): + is_512_pass = False + + if hasattr(structure_or_coords, "data") and structure_or_coords.data.ndim == 4: decoded = structure_or_coords.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() + is_512_pass = True + + elif isinstance(structure_or_coords, torch.Tensor) and structure_or_coords.ndim == 2: + coords = structure_or_coords.int() + is_512_pass = False + else: - coords = structure_or_coords + raise ValueError(f"Invalid input to EmptyShapeLatent: {type(structure_or_coords)}") in_channels = 32 # image like format latent = torch.randn(1, in_channels, coords.shape[0], 1) @@ -357,7 +364,10 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): model.model_options["transformer_options"] = {} model.model_options["transformer_options"]["coords"] = coords - model.model_options["transformer_options"]["generation_mode"] = "shape_generation" + if is_512_pass: + model.model_options["transformer_options"]["generation_mode"] = "shape_generation_512" + else: + model.model_options["transformer_options"]["generation_mode"] = "shape_generation" return IO.NodeOutput({"samples": latent, "coords": coords, "type": "trellis2"}, model) class EmptyTextureLatentTrellis2(IO.ComfyNode):