diff --git a/comfy_extras/nodes_morphology.py b/comfy_extras/nodes_morphology.py index 4ab2fb7e8..c4f22d042 100644 --- a/comfy_extras/nodes_morphology.py +++ b/comfy_extras/nodes_morphology.py @@ -1,12 +1,22 @@ import torch +import torch.nn.functional as F import comfy.model_management from typing_extensions import override from comfy_api.latest import ComfyExtension, io -from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat import kornia.color +def _max_pool_dilate(tensor, kernel_size): + pad = kernel_size // 2 + return F.max_pool2d(tensor, kernel_size, stride=1, padding=pad) + + +def _max_pool_erode(tensor, kernel_size): + pad = kernel_size // 2 + return -F.max_pool2d(-tensor, kernel_size, stride=1, padding=pad) + + class Morphology(io.ComfyNode): @classmethod def define_schema(cls): @@ -31,22 +41,21 @@ class Morphology(io.ComfyNode): @classmethod def execute(cls, image, operation, kernel_size) -> io.NodeOutput: device = comfy.model_management.get_torch_device() - kernel = torch.ones(kernel_size, kernel_size, device=device) image_k = image.to(device).movedim(-1, 1) if operation == "erode": - output = erosion(image_k, kernel) + output = _max_pool_erode(image_k, kernel_size) elif operation == "dilate": - output = dilation(image_k, kernel) + output = _max_pool_dilate(image_k, kernel_size) elif operation == "open": - output = opening(image_k, kernel) + output = _max_pool_dilate(_max_pool_erode(image_k, kernel_size), kernel_size) elif operation == "close": - output = closing(image_k, kernel) + output = _max_pool_erode(_max_pool_dilate(image_k, kernel_size), kernel_size) elif operation == "gradient": - output = gradient(image_k, kernel) + output = _max_pool_dilate(image_k, kernel_size) - _max_pool_erode(image_k, kernel_size) elif operation == "top_hat": - output = top_hat(image_k, kernel) + output = image_k - _max_pool_dilate(_max_pool_erode(image_k, kernel_size), kernel_size) elif operation == "bottom_hat": - output = bottom_hat(image_k, kernel) + output = _max_pool_erode(_max_pool_dilate(image_k, kernel_size), kernel_size) - image_k else: raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)