diff --git a/comfy_extras/nodes_trellis2.py b/comfy_extras/nodes_trellis2.py index ad9881db7..bf457135c 100644 --- a/comfy_extras/nodes_trellis2.py +++ b/comfy_extras/nodes_trellis2.py @@ -71,7 +71,14 @@ def smart_crop_square(image, mask, margin_ratio=0.1, bg_color=(128, 128, 128)): dst_x2 = dst_x1 + (src_x2 - src_x1) dst_y2 = dst_y1 + (src_y2 - src_y1) - canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = image[:, src_y1:src_y2, src_x1:src_x2] + img_crop = image[:, src_y1:src_y2, src_x1:src_x2] + mask_crop = mask[0, src_y1:src_y2, src_x1:src_x2] + + bg_val = torch.tensor(bg_color, device=image.device, dtype=image.dtype).view(-1, 1, 1) / 255.0 + + masked_crop = img_crop * mask_crop + bg_val * (1.0 - mask_crop) + + canvas[:, dst_y1:dst_y2, dst_x1:dst_x2] = masked_crop return canvas