mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
ImagePadForOutpaint now correctly returns a MaskBatch
This commit is contained in:
parent
d404ab3185
commit
d04288ce8d
@ -27,7 +27,7 @@ from ..cli_args import args
|
|||||||
from ..cmd import folder_paths, latent_preview
|
from ..cmd import folder_paths, latent_preview
|
||||||
from ..comfy_types import IO, ComfyNodeABC, InputTypeDict
|
from ..comfy_types import IO, ComfyNodeABC, InputTypeDict
|
||||||
from ..component_model.deprecation import _deprecate_method
|
from ..component_model.deprecation import _deprecate_method
|
||||||
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch
|
from ..component_model.tensor_types import RGBImage, RGBImageBatch, MaskBatch, RGBAImageBatch
|
||||||
from ..execution_context import current_execution_context
|
from ..execution_context import current_execution_context
|
||||||
from ..images import open_image
|
from ..images import open_image
|
||||||
from ..interruption import interrupt_current_processing
|
from ..interruption import interrupt_current_processing
|
||||||
@ -1917,35 +1917,35 @@ class ImagePadForOutpaint:
|
|||||||
|
|
||||||
CATEGORY = "image"
|
CATEGORY = "image"
|
||||||
|
|
||||||
def expand_image(self, image, left, top, right, bottom, feathering):
|
def expand_image(self, image: RGBImageBatch | RGBAImageBatch, left, top, right, bottom, feathering) -> tuple[RGBImageBatch | RGBAImageBatch, MaskBatch]:
|
||||||
d1, d2, d3, d4 = image.size()
|
batch, height, width, channels = image.size()
|
||||||
|
|
||||||
new_image = torch.ones(
|
new_image = torch.ones(
|
||||||
(d1, d2 + top + bottom, d3 + left + right, d4),
|
(batch, height + top + bottom, width + left + right, channels),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
) * 0.5
|
) * 0.5
|
||||||
|
|
||||||
new_image[:, top:top + d2, left:left + d3, :] = image
|
new_image[:, top:top + height, left:left + width, :] = image
|
||||||
|
|
||||||
mask = torch.ones(
|
mask = torch.ones(
|
||||||
(d2 + top + bottom, d3 + left + right),
|
(batch, height + top + bottom, width + left + right),
|
||||||
dtype=torch.float32,
|
dtype=torch.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
t = torch.zeros(
|
t = torch.zeros(
|
||||||
(d2, d3),
|
(height, width),
|
||||||
dtype=torch.float32
|
dtype=torch.float32
|
||||||
)
|
)
|
||||||
|
|
||||||
if feathering > 0 and feathering * 2 < d2 and feathering * 2 < d3:
|
if feathering > 0 and feathering * 2 < height and feathering * 2 < width:
|
||||||
|
|
||||||
for i in range(d2):
|
for i in range(height):
|
||||||
for j in range(d3):
|
for j in range(width):
|
||||||
dt = i if top != 0 else d2
|
dt = i if top != 0 else height
|
||||||
db = d2 - i if bottom != 0 else d2
|
db = height - i if bottom != 0 else height
|
||||||
|
|
||||||
dl = j if left != 0 else d3
|
dl = j if left != 0 else width
|
||||||
dr = d3 - j if right != 0 else d3
|
dr = width - j if right != 0 else width
|
||||||
|
|
||||||
d = min(dt, db, dl, dr)
|
d = min(dt, db, dl, dr)
|
||||||
|
|
||||||
@ -1956,9 +1956,9 @@ class ImagePadForOutpaint:
|
|||||||
|
|
||||||
t[i, j] = v * v
|
t[i, j] = v * v
|
||||||
|
|
||||||
mask[top:top + d2, left:left + d3] = t
|
mask[:, top:top + height, left:left + width] = t
|
||||||
|
|
||||||
return (new_image, mask)
|
return new_image, mask
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
|||||||
@ -142,7 +142,15 @@ def test_image_batch():
|
|||||||
def test_image_pad_for_outpaint():
|
def test_image_pad_for_outpaint():
|
||||||
padded, mask = ImagePadForOutpaint().expand_image(_image_1x1, 1, 1, 1, 1, 0)
|
padded, mask = ImagePadForOutpaint().expand_image(_image_1x1, 1, 1, 1, 1, 0)
|
||||||
assert padded.shape == (1, 3, 3, 3)
|
assert padded.shape == (1, 3, 3, 3)
|
||||||
assert mask.shape == (3, 3)
|
# the mask should now be batched
|
||||||
|
assert mask.shape == (1, 3, 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_image_pad_for_outpaint_batched():
|
||||||
|
padded, mask = ImagePadForOutpaint().expand_image(_image_1x1.expand(2, -1, -1, -1), 1, 1, 1, 1, 0)
|
||||||
|
assert padded.shape == (2, 3, 3, 3)
|
||||||
|
# the mask should now be batched
|
||||||
|
assert mask.shape == (2, 3, 3)
|
||||||
|
|
||||||
|
|
||||||
def test_empty_image():
|
def test_empty_image():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user