From d04288ce8dc7af0a64fc0833b4abc4f7c596fbbb Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Sun, 16 Feb 2025 15:39:36 -0800 Subject: [PATCH] ImagePadForOutpaint now correctly returns a MaskBatch --- comfy/nodes/base_nodes.py | 32 ++++++++++++++++---------------- tests/unit/test_base_nodes.py | 10 +++++++++- 2 files changed, 25 insertions(+), 17 deletions(-) diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 3f22f0778..bedc36af3 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -27,7 +27,7 @@ from ..cli_args import args from ..cmd import folder_paths, latent_preview from ..comfy_types import IO, ComfyNodeABC, InputTypeDict from ..component_model.deprecation import _deprecate_method -from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch +from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch from ..execution_context import current_execution_context from ..images import open_image from ..interruption import interrupt_current_processing @@ -1917,35 +1917,35 @@ class ImagePadForOutpaint: CATEGORY = "image" - def expand_image(self, image, left, top, right, bottom, feathering): - d1, d2, d3, d4 = image.size() + def expand_image(self, image: RGBImageBatch | RGBAImageBatch, left, top, right, bottom, feathering) -> tuple[RGBImageBatch | RGBAImageBatch, MaskBatch]: + batch, height, width, channels = image.size() new_image = torch.ones( - (d1, d2 + top + bottom, d3 + left + right, d4), + (batch, height + top + bottom, width + left + right, channels), dtype=torch.float32, ) * 0.5 - new_image[:, top:top + d2, left:left + d3, :] = image + new_image[:, top:top + height, left:left + width, :] = image mask = torch.ones( - (d2 + top + bottom, d3 + left + right), + (batch, height + top + bottom, width + left + right), dtype=torch.float32, ) t = torch.zeros( - (d2, d3), + (height, width), dtype=torch.float32 ) - if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3: + if feathering > 0 and feathering * 2 < height and feathering * 2 < width: - for i in range(d2): - for j in range(d3): - dt = i if top != 0 else d2 - db = d2 - i if bottom != 0 else d2 + for i in range(height): + for j in range(width): + dt = i if top != 0 else height + db = height - i if bottom != 0 else height - dl = j if left != 0 else d3 - dr = d3 - j if right != 0 else d3 + dl = j if left != 0 else width + dr = width - j if right != 0 else width d = min(dt, db, dl, dr) @@ -1956,9 +1956,9 @@ class ImagePadForOutpaint: t[i, j] = v * v - mask[top:top + d2, left:left + d3] = t + mask[:, top:top + height, left:left + width] = t - return (new_image, mask) + return new_image, mask NODE_CLASS_MAPPINGS = { diff --git a/tests/unit/test_base_nodes.py b/tests/unit/test_base_nodes.py index 7b4c218ee..060f17755 100644 --- a/tests/unit/test_base_nodes.py +++ b/tests/unit/test_base_nodes.py @@ -142,7 +142,15 @@ def test_image_batch(): def test_image_pad_for_outpaint(): padded, mask = ImagePadForOutpaint().expand_image(_image_1x1, 1, 1, 1, 1, 0) assert padded.shape == (1, 3, 3, 3) - assert mask.shape == (3, 3) + # the mask should now be batched + assert mask.shape == (1, 3, 3) + + +def test_image_pad_for_outpaint_batched(): + padded, mask = ImagePadForOutpaint().expand_image(_image_1x1.expand(2, -1, -1, -1), 1, 1, 1, 1, 0) + assert padded.shape == (2, 3, 3, 3) + # the mask should now be batched + assert mask.shape == (2, 3, 3) def test_empty_image():