From 91fa563b21a745041c50bb7f7e5038330e01ae38 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Mon, 16 Feb 2026 01:53:53 +0200 Subject: [PATCH] rewriting conditioning logic + model code addition --- comfy/ldm/trellis2/model.py | 10 ++- comfy_extras/nodes_trellis2.py | 125 ++++++++++++++++----------------- 2 files changed, 70 insertions(+), 65 deletions(-) diff --git a/comfy/ldm/trellis2/model.py b/comfy/ldm/trellis2/model.py index 76fe8ad19..4c398294a 100644 --- a/comfy/ldm/trellis2/model.py +++ b/comfy/ldm/trellis2/model.py @@ -826,8 +826,11 @@ class Trellis2(nn.Module): def forward(self, x, timestep, context, **kwargs): embeds = kwargs.get("embeds") mode = kwargs.get("generation_mode") - sigmas = kwargs.get("sigmas")[0].item() - cond = context.chunk(2) + transformer_options = kwargs.get("transformer_options") + sigmas = transformer_options.get("sigmas")[0].item() + if sigmas < 1.00001: + timestep *= 1000.0 + cond = context.chunk(2)[1] shape_rule = sigmas < self.guidance_interval[0] or sigmas > self.guidance_interval[1] txt_rule = sigmas < self.guidance_interval_txt[0] or sigmas > self.guidance_interval_txt[1] @@ -838,6 +841,9 @@ class Trellis2(nn.Module): out = self.shape2txt(x, timestep, context if not txt_rule else cond) else: # structure timestep = timestep_reshift(timestep) + if shape_rule: + x = x[0].unsqueeze(0) + timestep = timestep[0] out = self.structure_model(x, timestep, context if not shape_rule else cond) out.generation_mode = mode diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index 560751091..4d97129eb 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -4,7 +4,6 @@ import torch from comfy.ldm.trellis2.model import SparseTensor import comfy.model_management import comfy.model_patcher -from torchvision.transforms import ToPILImage, ToTensor, Resize, InterpolationMode shape_slat_normalization = { "mean": torch.tensor([ @@ -36,86 +35,85 @@ tex_slat_normalization = { ])[None] } -def smart_crop_square( - image: torch.Tensor, - background_color=(128, 128, 128), -): +dino_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1) +dino_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1) + +def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)): + nz = torch.nonzero(mask[0] > 0.5) + if nz.shape[0] == 0: + C, H, W = image.shape + side = max(H, W) + canvas = torch.full((C, side, side), 0.5, device=image.device) # Gray + canvas[:, (side-H)//2:(side-H)//2+H, (side-W)//2:(side-W)//2+W] = image + return canvas + + y_min, x_min = nz.min(dim=0)[0] + y_max, x_max = nz.max(dim=0)[0] + + obj_w, obj_h = x_max - x_min, y_max - y_min + center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2 + + side = int(max(obj_w, obj_h) * (1 + margin_ratio * 2)) + half_side = side / 2 + + x1, y1 = int(center_x - half_side), int(center_y - half_side) + x2, y2 = x1 + side, y1 + side + C, H, W = image.shape - size = max(H, W) - canvas = torch.empty( - (C, size, size), - dtype=image.dtype, - device=image.device - ) + canvas = torch.ones((C, side, side), device=image.device) for c in range(C): - canvas[c].fill_(background_color[c]) - top = (size - H) // 2 - left = (size - W) // 2 - canvas[:, top:top + H, left:left + W] = image + canvas[c] *= (bg_color[c] / 255.0) + + src_x1, src_y1 = max(0, x1), max(0, y1) + src_x2, src_y2 = min(W, x2), min(H, y2) + + dst_x1, dst_y1 = max(0, -x1), max(0, -y1) + dst_x2 = dst_x1 + (src_x2 - src_x1) + dst_y2 = dst_y1 + (src_y2 - src_y1) + + canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = image[:, src_y1:src_y2, src_x1:src_x2] return canvas -def run_conditioning( - model, - image: torch.Tensor, - include_1024: bool = True, - background_color: str = "black", -): - # TODO: should check if normalization was applied in these steps - model = model.model - device = comfy.model_management.intermediate_device() # replaces .cpu() - torch_device = comfy.model_management.get_torch_device() # replaces .cuda() - bg_colors = { - "black": (0, 0, 0), - "gray": (128, 128, 128), - "white": (255, 255, 255), - } - bg_color = bg_colors.get(background_color, (128, 128, 128)) +def run_conditioning(model, image, mask, include_1024 = True, background_color = "black"): + model_internal = model.model + device = comfy.model_management.intermediate_device() + torch_device = comfy.model_management.get_torch_device() - # Convert image to PIL - if image.dim() == 4: - pil_image = (image[0] * 255).clip(0, 255).to(torch.uint8) - else: - pil_image = (image * 255).clip(0, 255).to(torch.uint8) + bg_colors = {"black": (0, 0, 0), "gray": (128, 128, 128), "white": (255, 255, 255)} + bg_rgb = bg_colors.get(background_color, (128, 128, 128)) - pil_image = pil_image.movedim(-1, 0) - pil_image = smart_crop_square(pil_image, background_color=bg_color) + img_t = image[0].movedim(-1, 0).to(torch_device).float() + mask_t = mask[0].to(torch_device).float() + if mask_t.ndim == 2: + mask_t = mask_t.unsqueeze(0) - model.image_size = 512 - def set_image_size(image, image_size=512): - if image.ndim == 3: - image = image.unsqueeze(0) + cropped_img = smart_crop_square(img_t, mask_t, bg_color=bg_rgb) - to_pil = ToPILImage() - to_tensor = ToTensor() - resizer = Resize((image_size, image_size), interpolation=InterpolationMode.LANCZOS) + def prepare_tensor(img, size): + resized = torch.nn.functional.interpolate( + img.unsqueeze(0), size=(size, size), mode='bicubic', align_corners=False + ) + return (resized - dino_mean.to(torch_device)) / dino_std.to(torch_device) - pil_img = to_pil(image.squeeze(0)) - resized_pil = resizer(pil_img) - image = to_tensor(resized_pil).unsqueeze(0) - - return image.to(torch_device).float() - - pil_image = set_image_size(pil_image, 512) - cond_512 = model(pil_image)[0] + model_internal.image_size = 512 + input_512 = prepare_tensor(cropped_img, 512) + cond_512 = model_internal(input_512)[0] cond_1024 = None if include_1024: - model.image_size = 1024 - pil_image = set_image_size(pil_image, 1024) - cond_1024 = model(pil_image)[0] - - neg_cond = torch.zeros_like(cond_512) + model_internal.image_size = 1024 + input_1024 = prepare_tensor(cropped_img, 1024) + cond_1024 = model_internal(input_1024)[0] conditioning = { 'cond_512': cond_512.to(device), - 'neg_cond': neg_cond.to(device), + 'neg_cond': torch.zeros_like(cond_512).to(device), } if cond_1024 is not None: conditioning['cond_1024'] = cond_1024.to(device) - preprocessed_tensor = pil_image.to(torch.float32) / 255.0 - preprocessed_tensor = preprocessed_tensor.unsqueeze(0) + preprocessed_tensor = cropped_img.movedim(0, -1).unsqueeze(0).cpu() return conditioning, preprocessed_tensor @@ -213,6 +211,7 @@ class Trellis2Conditioning(IO.ComfyNode): inputs=[ IO.ClipVision.Input("clip_vision_model"), IO.Image.Input("image"), + IO.Mask.Input("mask"), IO.Combo.Input("background_color", options=["black", "gray", "white"], default="black") ], outputs=[ @@ -222,9 +221,9 @@ class Trellis2Conditioning(IO.ComfyNode): ) @classmethod - def execute(cls, clip_vision_model, image, background_color) -> IO.NodeOutput: + def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput: # could make 1024 an option - conditioning, _ = run_conditioning(clip_vision_model, image, include_1024=True, background_color=background_color) + conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color) embeds = conditioning["cond_1024"] # should add that positive = [[conditioning["cond_512"], {"embeds": embeds}]] negative = [[conditioning["neg_cond"], {"embeds": embeds}]]