This commit is contained in:
Omri Marom 2026-04-19 14:57:19 +03:00 committed by GitHub
commit 41119df442
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,12 +1,22 @@
import torch import torch
import torch.nn.functional as F
import comfy.model_management import comfy.model_management
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
from kornia.morphology import dilation, erosion, opening, closing, gradient, top_hat, bottom_hat
import kornia.color import kornia.color
def _max_pool_dilate(tensor, kernel_size):
pad = kernel_size // 2
return F.max_pool2d(tensor, kernel_size, stride=1, padding=pad)
def _max_pool_erode(tensor, kernel_size):
pad = kernel_size // 2
return -F.max_pool2d(-tensor, kernel_size, stride=1, padding=pad)
class Morphology(io.ComfyNode): class Morphology(io.ComfyNode):
@classmethod @classmethod
def define_schema(cls): def define_schema(cls):
@ -31,22 +41,21 @@ class Morphology(io.ComfyNode):
@classmethod @classmethod
def execute(cls, image, operation, kernel_size) -> io.NodeOutput: def execute(cls, image, operation, kernel_size) -> io.NodeOutput:
device = comfy.model_management.get_torch_device() device = comfy.model_management.get_torch_device()
kernel = torch.ones(kernel_size, kernel_size, device=device)
image_k = image.to(device).movedim(-1, 1) image_k = image.to(device).movedim(-1, 1)
if operation == "erode": if operation == "erode":
output = erosion(image_k, kernel) output = _max_pool_erode(image_k, kernel_size)
elif operation == "dilate": elif operation == "dilate":
output = dilation(image_k, kernel) output = _max_pool_dilate(image_k, kernel_size)
elif operation == "open": elif operation == "open":
output = opening(image_k, kernel) output = _max_pool_dilate(_max_pool_erode(image_k, kernel_size), kernel_size)
elif operation == "close": elif operation == "close":
output = closing(image_k, kernel) output = _max_pool_erode(_max_pool_dilate(image_k, kernel_size), kernel_size)
elif operation == "gradient": elif operation == "gradient":
output = gradient(image_k, kernel) output = _max_pool_dilate(image_k, kernel_size) - _max_pool_erode(image_k, kernel_size)
elif operation == "top_hat": elif operation == "top_hat":
output = top_hat(image_k, kernel) output = image_k - _max_pool_dilate(_max_pool_erode(image_k, kernel_size), kernel_size)
elif operation == "bottom_hat": elif operation == "bottom_hat":
output = bottom_hat(image_k, kernel) output = _max_pool_erode(_max_pool_dilate(image_k, kernel_size), kernel_size) - image_k
else: else:
raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'") raise ValueError(f"Invalid operation {operation} for morphology. Must be one of 'erode', 'dilate', 'open', 'close', 'gradient', 'tophat', 'bottomhat'")
img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1) img_out = output.to(comfy.model_management.intermediate_device()).movedim(1, -1)