mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
ImagePadForOutpaint now correctly returns a MaskBatch
This commit is contained in:
parent
d404ab3185
commit
d04288ce8d
@ -27,7 +27,7 @@ from ..cli_args import args
|
||||
from ..cmd import folder_paths, latent_preview
|
||||
from ..comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||
from ..component_model.deprecation import _deprecate_method
|
||||
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch
|
||||
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch
|
||||
from ..execution_context import current_execution_context
|
||||
from ..images import open_image
|
||||
from ..interruption import interrupt_current_processing
|
||||
@ -1917,35 +1917,35 @@ class ImagePadForOutpaint:
|
||||
|
||||
CATEGORY = "image"
|
||||
|
||||
def expand_image(self, image, left, top, right, bottom, feathering):
|
||||
d1, d2, d3, d4 = image.size()
|
||||
def expand_image(self, image: RGBImageBatch | RGBAImageBatch, left, top, right, bottom, feathering) -> tuple[RGBImageBatch | RGBAImageBatch, MaskBatch]:
|
||||
batch, height, width, channels = image.size()
|
||||
|
||||
new_image = torch.ones(
|
||||
(d1, d2 + top + bottom, d3 + left + right, d4),
|
||||
(batch, height + top + bottom, width + left + right, channels),
|
||||
dtype=torch.float32,
|
||||
) * 0.5
|
||||
|
||||
new_image[:, top:top + d2, left:left + d3, :] = image
|
||||
new_image[:, top:top + height, left:left + width, :] = image
|
||||
|
||||
mask = torch.ones(
|
||||
(d2 + top + bottom, d3 + left + right),
|
||||
(batch, height + top + bottom, width + left + right),
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
t = torch.zeros(
|
||||
(d2, d3),
|
||||
(height, width),
|
||||
dtype=torch.float32
|
||||
)
|
||||
|
||||
if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3:
|
||||
if feathering > 0 and feathering * 2 < height and feathering * 2 < width:
|
||||
|
||||
for i in range(d2):
|
||||
for j in range(d3):
|
||||
dt = i if top != 0 else d2
|
||||
db = d2 - i if bottom != 0 else d2
|
||||
for i in range(height):
|
||||
for j in range(width):
|
||||
dt = i if top != 0 else height
|
||||
db = height - i if bottom != 0 else height
|
||||
|
||||
dl = j if left != 0 else d3
|
||||
dr = d3 - j if right != 0 else d3
|
||||
dl = j if left != 0 else width
|
||||
dr = width - j if right != 0 else width
|
||||
|
||||
d = min(dt, db, dl, dr)
|
||||
|
||||
@ -1956,9 +1956,9 @@ class ImagePadForOutpaint:
|
||||
|
||||
t[i, j] = v * v
|
||||
|
||||
mask[top:top + d2, left:left + d3] = t
|
||||
mask[:, top:top + height, left:left + width] = t
|
||||
|
||||
return (new_image, mask)
|
||||
return new_image, mask
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
@ -142,7 +142,15 @@ def test_image_batch():
|
||||
def test_image_pad_for_outpaint():
|
||||
padded, mask = ImagePadForOutpaint().expand_image(_image_1x1, 1, 1, 1, 1, 0)
|
||||
assert padded.shape == (1, 3, 3, 3)
|
||||
assert mask.shape == (3, 3)
|
||||
# the mask should now be batched
|
||||
assert mask.shape == (1, 3, 3)
|
||||
|
||||
|
||||
def test_image_pad_for_outpaint_batched():
|
||||
padded, mask = ImagePadForOutpaint().expand_image(_image_1x1.expand(2, -1, -1, -1), 1, 1, 1, 1, 0)
|
||||
assert padded.shape == (2, 3, 3, 3)
|
||||
# the mask should now be batched
|
||||
assert mask.shape == (2, 3, 3)
|
||||
|
||||
|
||||
def test_empty_image():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user