ImagePadForOutpaint now correctly returns a MaskBatch

This commit is contained in:
doctorpangloss 2025-02-16 15:39:36 -08:00
parent d404ab3185
commit d04288ce8d
2 changed files with 25 additions and 17 deletions

View File

@ -27,7 +27,7 @@ from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..comfy_types import IO, ComfyNodeABC, InputTypeDict
from ..component_model.deprecation import _deprecate_method
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch
from ..execution_context import current_execution_context
from ..images import open_image
from ..interruption import interrupt_current_processing
@ -1917,35 +1917,35 @@ class ImagePadForOutpaint:
CATEGORY = "image"
def expand_image(self, image, left, top, right, bottom, feathering):
d1, d2, d3, d4 = image.size()
def expand_image(self, image: RGBImageBatch | RGBAImageBatch, left, top, right, bottom, feathering) -> tuple[RGBImageBatch | RGBAImageBatch, MaskBatch]:
batch, height, width, channels = image.size()
new_image = torch.ones(
(d1, d2 + top + bottom, d3 + left + right, d4),
(batch, height + top + bottom, width + left + right, channels),
dtype=torch.float32,
) * 0.5
new_image[:, top:top + d2, left:left + d3, :] = image
new_image[:, top:top + height, left:left + width, :] = image
mask = torch.ones(
(d2 + top + bottom, d3 + left + right),
(batch, height + top + bottom, width + left + right),
dtype=torch.float32,
)
t = torch.zeros(
(d2, d3),
(height, width),
dtype=torch.float32
)
if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3:
if feathering > 0 and feathering * 2 < height and feathering * 2 < width:
for i in range(d2):
for j in range(d3):
dt = i if top != 0 else d2
db = d2 - i if bottom != 0 else d2
for i in range(height):
for j in range(width):
dt = i if top != 0 else height
db = height - i if bottom != 0 else height
dl = j if left != 0 else d3
dr = d3 - j if right != 0 else d3
dl = j if left != 0 else width
dr = width - j if right != 0 else width
d = min(dt, db, dl, dr)
@ -1956,9 +1956,9 @@ class ImagePadForOutpaint:
t[i, j] = v * v
mask[top:top + d2, left:left + d3] = t
mask[:, top:top + height, left:left + width] = t
return (new_image, mask)
return new_image, mask
NODE_CLASS_MAPPINGS = {

View File

@ -142,7 +142,15 @@ def test_image_batch():
def test_image_pad_for_outpaint():
padded, mask = ImagePadForOutpaint().expand_image(_image_1x1, 1, 1, 1, 1, 0)
assert padded.shape == (1, 3, 3, 3)
assert mask.shape == (3, 3)
# the mask should now be batched
assert mask.shape == (1, 3, 3)
def test_image_pad_for_outpaint_batched():
padded, mask = ImagePadForOutpaint().expand_image(_image_1x1.expand(2, -1, -1, -1), 1, 1, 1, 1, 0)
assert padded.shape == (2, 3, 3, 3)
# the mask should now be batched
assert mask.shape == (2, 3, 3)
def test_empty_image():