From 6191cd86bfcaa1c18df68d2f0c932212dbc0a64d Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Thu, 19 Feb 2026 22:05:33 +0200 Subject: [PATCH] trellis2conditioning and a hidden bug --- comfy/ldm/trellis2/model.py | 13 ++++++++++--- comfy_extras/nodes_trellis2.py | 21 +++++---------------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index b4fc15abc..8579b0580 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -826,7 +826,12 @@ class Trellis2(nn.Module): self.guidance_interval_txt = [0.6, 0.9] def forward(self, x, timestep, context, **kwargs): + # FIXME: should find a way to distinguish between 512/1024 models + # currently assumes 1024 embeds = kwargs.get("embeds") + _, cond = context.chunk(2) + cond = embeds.chunk(2)[0] + context = torch.cat([torch.zeros_like(cond), cond]) mode = kwargs.get("generation_mode") coords = kwargs.get("coords") transformer_options = kwargs.get("transformer_options") @@ -837,12 +842,13 @@ class Trellis2(nn.Module): shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] - if mode in ["shape_generation", "texture_generation"]: + not_struct_mode = mode in ["shape_generation", "texture_generation"] + if not_struct_mode: x = SparseTensor(feats=x, coords=coords) if mode == "shape_generation": # TODO - out = self.img2shape(x, timestep, torch.cat([embeds, torch.empty_like(embeds)])) + out = self.img2shape(x, timestep, context) elif mode == "texture_generation": out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure @@ -855,5 +861,6 @@ class Trellis2(nn.Module): if shape_rule: out = out.repeat(orig_bsz, 1, 1, 1, 1) - out.generation_mode = mode + if not_struct_mode: + out = out.feats return out diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index fc9a15cfa..1683949a3 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -233,25 +233,14 @@ class Trellis2Conditioning(IO.ComfyNode): image = image[0] # TODO - image = Image.fromarray(image.numpy()) + image = (image.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + image = Image.fromarray(image) max_size = max(image.size) scale = min(1, 1024 / max_size) if scale < 1: image = image.resize((int(image.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS) - output_np = np.array(image) - alpha = output_np[:, :, 3] - bbox = np.argwhere(alpha > 0.8 * 255) - bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) - center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 - size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) - size = int(size * 1) - bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2 - output = image.crop(bbox) # type: ignore - output = np.array(output).astype(np.float32) / 255 - output = output[:, :, :3] * output[:, :, 3:4] - - image = torch.tensor(output) + image = torch.tensor(np.array(image)).unsqueeze(0) # could make 1024 an option conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) @@ -276,7 +265,7 @@ class EmptyShapeLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): - decoded = structure_output.data + decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(coords.shape[0], in_channels) @@ -299,7 +288,7 @@ class EmptyTextureLatentTrellis2(IO.ComfyNode): @classmethod def execute(cls, structure_output): # TODO - decoded = structure_output.data + decoded = structure_output.data.unsqueeze(1) coords = torch.argwhere(decoded.bool())[:, [0, 2, 3, 4]].int() in_channels = 32 latent = torch.randn(coords.shape[0], in_channels - structure_output.feats.shape[1])