nodes_morphology: implement with max-pooling.

This commit is contained in:
omarom 2026-04-13 17:20:52 +00:00
parent acd718598e
commit 1ab72b1c0f

View File

@ -1,12 +1,22 @@
import torch
import torch.nn.functional as F
import comfy.model_management
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
import kornia.color
def _max_pool_dilate(tensor, kernel_size):
pad = kernel_size // 2
return F.max_pool2d(tensor, kernel_size, stride=1, padding=pad)
def _max_pool_erode(tensor, kernel_size):
pad = kernel_size // 2
return -F.max_pool2d(-tensor, kernel_size, stride=1, padding=pad)
class Morphology(io.ComfyNode):
@classmethod
def define_schema(cls):
@ -31,22 +41,21 @@ class Morphology(io.ComfyNode):
@classmethod
def execute(cls, image, operation, kernel_size) -> io.NodeOutput:
device = comfy.model_management.get_torch_device()
kernel = torch.ones(kernel_size, kernel_size, device=device)
image_k = image.to(device).movedim(-1, 1)
if operation == "erode":
output = erosion(image_k, kernel)
output = _max_pool_erode(image_k, kernel_size)
elif operation == "dilate":
output = dilation(image_k, kernel)
output = _max_pool_dilate(image_k, kernel_size)
elif operation == "open":
output = opening(image_k, kernel)
output = _max_pool_dilate(_max_pool_erode(image_k, kernel_size), kernel_size)
elif operation == "close":
output = closing(image_k, kernel)
output = _max_pool_erode(_max_pool_dilate(image_k, kernel_size), kernel_size)
elif operation == "gradient":
output = gradient(image_k, kernel)
output = _max_pool_dilate(image_k, kernel_size) - _max_pool_erode(image_k, kernel_size)
elif operation == "top_hat":
output = top_hat(image_k, kernel)
output = image_k - _max_pool_dilate(_max_pool_erode(image_k, kernel_size), kernel_size)
elif operation == "bottom_hat":
output = bottom_hat(image_k, kernel)
output = _max_pool_erode(_max_pool_dilate(image_k, kernel_size), kernel_size) - image_k
else:
raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)