Make scale-to-multiple shape handling explicit

nomadoor 2026-01-14 09:01:26 +09:00
parent e5810f7af8
commit 0b8e03a1e7


@@ -254,7 +254,7 @@ class ResizeType(str, Enum):
     SCALE_HEIGHT = "scale height"
     SCALE_TOTAL_PIXELS = "scale total pixels"
     MATCH_SIZE = "match size"
-    CROP_TO_MULTIPLE = "crop to multiple"
+    SCALE_TO_MULTIPLE = "scale to multiple"
 
 def is_image(input: torch.Tensor) -> bool:
     # images have 4 dimensions: [batch, height, width, channels]
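A quick illustration of the shape convention noted in that comment (a sketch, assuming is_image simply checks tensor rank, which the comment implies but the hunk does not show):

    import torch

    assert is_image(torch.zeros(1, 64, 64, 3))   # 4 dims -> image [batch, height, width, channels]
    assert not is_image(torch.zeros(1, 64, 64))  # 3 dims -> mask [batch, height, width]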
@@ -364,26 +364,40 @@ def scale_match_size(input: torch.Tensor, match: torch.Tensor, scale_method: str
     input = finalize_image_mask_input(input, is_type_image)
     return input
 
-def crop_to_multiple(input: torch.Tensor, multiple: int, crop: str="center") -> torch.Tensor:
+def scale_to_multiple_cover(input: torch.Tensor, multiple: int, scale_method: str) -> torch.Tensor:
     if multiple <= 1:
         return input
-    width = input.shape[2]
-    height = input.shape[1]
-    new_w = (width // multiple) * multiple
-    new_h = (height // multiple) * multiple
-    if new_w == 0 or new_h == 0:
-        return input
-    if new_w == width and new_h == height:
-        return input
-    if crop == "center":
-        x0 = (width - new_w) // 2
-        y0 = (height - new_h) // 2
+    is_type_image = is_image(input)
+    if is_type_image:
+        _, height, width, _ = input.shape
     else:
-        x0 = 0
-        y0 = 0
-    x1 = x0 + new_w
-    y1 = y0 + new_h
-    if is_image(input):
+        _, height, width = input.shape
+    target_w = (width // multiple) * multiple
+    target_h = (height // multiple) * multiple
+    if target_w == 0 or target_h == 0:
+        return input
+    if target_w == width and target_h == height:
+        return input
+    s_w = target_w / width
+    s_h = target_h / height
+    if s_w >= s_h:
+        scaled_w = target_w
+        scaled_h = int(math.ceil(height * s_w))
+        if scaled_h < target_h:
+            scaled_h = target_h
+    else:
+        scaled_h = target_h
+        scaled_w = int(math.ceil(width * s_h))
+        if scaled_w < target_w:
+            scaled_w = target_w
+    input = init_image_mask_input(input, is_type_image)
+    input = comfy.utils.common_upscale(input, scaled_w, scaled_h, scale_method, "disabled")
+    input = finalize_image_mask_input(input, is_type_image)
+    x0 = (scaled_w - target_w) // 2
+    y0 = (scaled_h - target_h) // 2
+    x1 = x0 + target_w
+    y1 = y0 + target_h
+    if is_type_image:
         return input[:, y0:y1, x0:x1, :]
     return input[:, y0:y1, x0:x1]
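In words: the new function floors each dimension to the nearest multiple, resizes with cover semantics (the larger of the two scale factors, so neither dimension falls short of the target), then center-crops the overshoot. A standalone sketch of just the size math, using only the stdlib; cover_dims is a name introduced here for illustration, not part of the commit:

    import math

    def cover_dims(width: int, height: int, multiple: int):
        # Mirrors the arithmetic in scale_to_multiple_cover above:
        # floor targets to the multiple, scale by the larger factor,
        # leave the center-crop of the overshoot to the caller.
        target_w = (width // multiple) * multiple
        target_h = (height // multiple) * multiple
        s_w = target_w / width
        s_h = target_h / height
        if s_w >= s_h:
            scaled_w = target_w
            scaled_h = max(int(math.ceil(height * s_w)), target_h)
        else:
            scaled_h = target_h
            scaled_w = max(int(math.ceil(width * s_h)), target_w)
        return scaled_w, scaled_h, target_w, target_h

    # 518x389 at multiple=8: targets are 512x384; s_w ~ 0.9884 >= s_h ~ 0.9871,
    # so the input is scaled to 512x385 and one row is cropped off to reach 384.
    print(cover_dims(518, 389, 8))  # (512, 385, 512, 384)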
@@ -442,7 +456,7 @@ class ResizeImageMaskNode(io.ComfyNode):
                     io.MultiType.Input("match", [io.Image, io.Mask]),
                     crop_combo,
                 ]),
-                io.DynamicCombo.Option(ResizeType.CROP_TO_MULTIPLE, [
+                io.DynamicCombo.Option(ResizeType.SCALE_TO_MULTIPLE, [
                     io.Int.Input("multiple", default=8, min=1, max=MAX_RESOLUTION, step=1),
                 ]),
             ]),
@@ -470,8 +484,8 @@ class ResizeImageMaskNode(io.ComfyNode):
             return io.NodeOutput(scale_total_pixels(input, resize_type["megapixels"], scale_method))
         elif selected_type == ResizeType.MATCH_SIZE:
             return io.NodeOutput(scale_match_size(input, resize_type["match"], scale_method, resize_type["crop"]))
-        elif selected_type == ResizeType.CROP_TO_MULTIPLE:
-            return io.NodeOutput(crop_to_multiple(input, resize_type["multiple"]))
+        elif selected_type == ResizeType.SCALE_TO_MULTIPLE:
+            return io.NodeOutput(scale_to_multiple_cover(input, resize_type["multiple"], scale_method))
         raise ValueError(f"Unsupported resize type: {selected_type}")
 
 def batch_images(images: list[torch.Tensor]) -> torch.Tensor | None:
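A hedged usage sketch of the rewired path (assumes it runs inside this module, so torch, math, comfy.utils, and scale_to_multiple_cover are in scope; "bilinear" is one of common_upscale's standard methods):

    import torch

    # Shapes follow the is_image convention above; sizes match the worked
    # example: 518x389 floors to 512x384 at multiple=8.
    image = torch.zeros(1, 389, 518, 3)  # [batch, height, width, channels]
    mask = torch.zeros(1, 389, 518)      # [batch, height, width]

    # Both come out scaled-and-center-cropped to the multiple, not plainly cropped.
    assert scale_to_multiple_cover(image, 8, "bilinear").shape == (1, 384, 512, 3)
    assert scale_to_multiple_cover(mask, 8, "bilinear").shape == (1, 384, 512)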