diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py
index bd250e5f5..fc9a15cfa 100644
--- a/comfy_extras/nodes_trellis2.py
+++ b/comfy_extras/nodes_trellis2.py
@@ -2,6 +2,8 @@ from typing_extensions import override
 from comfy_api.latest import ComfyExtension, IO, Types
 import torch
 import comfy.model_management
+from PIL import Image
+import numpy as np
 
 shape_slat_normalization = {
     "mean": torch.tensor([
@@ -226,6 +228,34 @@ class Trellis2Conditioning(IO.ComfyNode):
 
     @classmethod
     def execute(cls, clip_vision_model, image, mask, background_color) -> IO.NodeOutput:
+
+        if image.ndim == 4:
+            image = image[0]
+
+        # ComfyUI IMAGE tensors are float32 in [0, 1]; PIL's fromarray cannot
+        # handle multi-channel float arrays, so convert to uint8 RGBA first
+        # (RGBA also guarantees the alpha-channel indexing below is valid).
+        image = Image.fromarray((image.numpy() * 255.0).clip(0, 255).astype(np.uint8)).convert("RGBA")
+        # Downscale so the longest side is at most 1024 px (could make 1024 an option).
+        scale = min(1, 1024 / max(image.size))
+        if scale < 1:
+            image = image.resize((int(image.width * scale), int(image.height * scale)), Image.Resampling.LANCZOS)
+
+        # Square crop centered on the pixels whose alpha exceeds 80%.
+        output_np = np.array(image)
+        alpha = output_np[:, :, 3]
+        bbox = np.argwhere(alpha > 0.8 * 255)
+        bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
+        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
+        size = int(max(bbox[2] - bbox[0], bbox[3] - bbox[1]))
+        bbox = (int(center[0] - size // 2), int(center[1] - size // 2), int(center[0] + size // 2), int(center[1] + size // 2))
+        output = image.crop(bbox)
+        # Premultiply RGB by alpha so the transparent background goes to black.
+        output = np.array(output).astype(np.float32) / 255
+        output = output[:, :, :3] * output[:, :, 3:4]
+
+        # NOTE(review): image is now a 3-D (H, W, C) tensor; confirm that
+        # run_conditioning accepts unbatched input, else add a batch dim here.
+        image = torch.tensor(output)
         conditioning, _ = run_conditioning(clip_vision_model, image, mask, include_1024=True, background_color=background_color)
         embeds = conditioning["cond_1024"] # should add that