From 2cb06431e8e481b41c2c4da2aa92ba30ea07d66c Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Tue, 7 Apr 2026 23:02:33 +0200 Subject: [PATCH] fix for conditioning --- comfy/ldm/trellis2/model.py | 2 +- comfy_extras/nodes_trellis2.py | 29 ++++++++++++++--------------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 7c6ffdd69..ea7ada9f8 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -671,7 +671,7 @@ class SparseStructureFlowModel(nn.Module): coords = torch.meshgrid(*[torch.arange(res, device=self.device) for res in [resolution] * 3], indexing='ij') coords = torch.stack(coords, dim=-1).reshape(-1, 3) rope_phases = pos_embedder(coords) - self.register_buffer("rope_phases", rope_phases) + self.register_buffer("rope_phases", rope_phases, persistent=False) if pe_mode != "rope": self.rope_phases = None diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 088cdd3f1..d3f5e4940 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -2,7 +2,6 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, IO, Types from comfy.ldm.trellis2.vae import SparseTensor import comfy.model_management -import logging from PIL import Image import numpy as np import torch @@ -250,28 +249,28 @@ class Trellis2UpsampleCascade(IO.ComfyNode): return IO.NodeOutput(final_coords,) -dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) -dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) +dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) +dino_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) def run_conditioning(model, cropped_img_tensor, include_1024=True): model_internal = model.model device = comfy.model_management.intermediate_device() torch_device = comfy.model_management.get_torch_device() - img_t = cropped_img_tensor.to(torch_device) - - def prepare_tensor(img, size): - resized = torch.nn.functional.interpolate(img, size=(size, size), mode='bicubic', align_corners=False).clamp(0.0, 1.0) - return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) + def prepare_tensor(pil_img, size): + resized_pil = pil_img.resize((size, size), Image.Resampling.LANCZOS) + img_np = np.array(resized_pil).astype(np.float32) / 255.0 + img_t = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).to(torch_device) + return (img_t - dino_mean.to(torch_device)) / dino_std.to(torch_device) model_internal.image_size = 512 - input_512 = prepare_tensor(img_t, 512) + input_512 = prepare_tensor(cropped_img_tensor, 512) cond_512 = model_internal(input_512)[0] cond_1024 = None if include_1024: model_internal.image_size = 1024 - input_1024 = prepare_tensor(img_t, 1024) + input_1024 = prepare_tensor(cropped_img_tensor, 1024) cond_1024 = model_internal(input_1024)[0] conditioning = { @@ -341,14 +340,15 @@ class Trellis2Conditioning(IO.ComfyNode): crop_x2 = int(center_x + size // 2) crop_y2 = int(center_y + size // 2) - rgba_pil = Image.fromarray(rgba_np, 'RGBA') + rgba_pil = Image.fromarray(rgba_np) cropped_rgba = rgba_pil.crop((crop_x1, crop_y1, crop_x2, crop_y2)) cropped_np = np.array(cropped_rgba).astype(np.float32) / 255.0 else: + import logging logging.warning("Mask for the image is empty. Trellis2 requires an image with a mask for the best mesh quality.") cropped_np = rgba_np.astype(np.float32) / 255.0 - bg_colors = {"black": [0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} + bg_colors = {"black":[0.0, 0.0, 0.0], "gray":[0.5, 0.5, 0.5], "white":[1.0, 1.0, 1.0]} bg_rgb = np.array(bg_colors.get(background_color, [0.0, 0.0, 0.0]), dtype=np.float32) fg = cropped_np[:, :, :3] @@ -358,10 +358,9 @@ class Trellis2Conditioning(IO.ComfyNode): # to match trellis2 code (quantize -> dequantize) composite_uint8 = (composite_np * 255.0).round().clip(0, 255).astype(np.uint8) - cropped_img_tensor = torch.from_numpy(composite_uint8).float() / 255.0 - cropped_img_tensor = cropped_img_tensor.movedim(-1, 0).unsqueeze(0) + cropped_pil = Image.fromarray(composite_uint8) - conditioning = run_conditioning(clip_vision_model, cropped_img_tensor, include_1024=True) + conditioning = run_conditioning(clip_vision_model, cropped_pil, include_1024=True) embeds = conditioning["cond_1024"] positive = [[conditioning["cond_512"], {"embeds": embeds}]]