diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py
index b9d800bfc..0c6009f47 100644
--- a/comfy_extras/nodes_post_processing.py
+++ b/comfy_extras/nodes_post_processing.py
@@ -1,6 +1,7 @@
 import numpy as np
 import cv2
 import torch
+import torch.nn.functional as F
 from PIL import Image, ImageEnhance
 
 class Dither:
@@ -31,7 +32,7 @@ class Dither:
         result = torch.zeros_like(image)
 
         for b in range(batch_size):
-            tensor_image = image[b].numpy()
+            tensor_image = image[b]
             img = (tensor_image * 255)
             height, width, _ = img.shape
 
@@ -39,8 +40,8 @@ class Dither:
 
             for y in range(height):
                 for x in range(width):
-                    old_pixel = img[y, x].copy()
-                    new_pixel = np.round(old_pixel / scale) * scale
+                    old_pixel = img[y, x].clone()
+                    new_pixel = torch.round(old_pixel / scale) * scale
                     img[y, x] = new_pixel
 
                     quant_error = old_pixel - new_pixel
@@ -55,7 +56,7 @@ class Dither:
                     img[y + 1, x + 1] += quant_error * 1 / 16
 
             dithered = img / 255
-            tensor = torch.from_numpy(dithered).unsqueeze(0)
+            tensor = dithered.unsqueeze(0)
             result[b] = tensor
 
         return (result,)
@@ -145,17 +146,22 @@ class GaussianBlur:
 
     CATEGORY = "postprocessing"
 
+    def gaussian_kernel(self, kernel_size: int, sigma: float):
+        x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size))
+        d = torch.sqrt(x * x + y * y)
+        g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
+        return g / g.sum()
+
     def blur(self, image: torch.Tensor, kernel_size: int, sigma: float):
-        batch_size, height, width, _ = image.shape
-        result = torch.zeros_like(image)
+        batch_size, height, width, channels = image.shape
 
-        for b in range(batch_size):
-            tensor_image = image[b].numpy()
-            blurred = cv2.GaussianBlur(tensor_image, (kernel_size, kernel_size), sigma)
-            tensor = torch.from_numpy(blurred).unsqueeze(0)
-            result[b] = tensor
+        kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
 
-        return (result,)
+        image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
+        blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
+        blurred = blurred.permute(0, 2, 3, 1)
+
+        return (blurred,)
 
 class Sharpen:
     def __init__(self):
@@ -187,22 +193,19 @@ class Sharpen:
 
     CATEGORY = "postprocessing"
 
     def sharpen(self, image: torch.Tensor, kernel_size: int, alpha: float):
-        batch_size, height, width, _ = image.shape
-        result = torch.zeros_like(image)
+        batch_size, height, width, channels = image.shape
 
-        for b in range(batch_size):
-            tensor_image = image[b].numpy()
+        kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
+        center = kernel_size // 2
+        kernel[center, center] = kernel_size**2
+        kernel *= alpha
+        kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
 
-            kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) * -1
-            center = kernel_size // 2
-            kernel[center, center] = kernel_size**2
-            kernel *= alpha
+        tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
+        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
+        sharpened = sharpened.permute(0, 2, 3, 1)
 
-            sharpened = cv2.filter2D(tensor_image, -1, kernel)
-
-            tensor = torch.from_numpy(sharpened).unsqueeze(0)
-            tensor = torch.clamp(tensor, 0, 1)
-            result[b] = tensor
+        result = torch.clamp(sharpened, 0, 1)
 
         return (result,)
@@ -305,18 +308,18 @@ class ColorCorrect:
         batch_size, height, width, _ = image.shape
         result = torch.zeros_like(image)
 
+        brightness /= 100
+        contrast /= 100
+        saturation /= 100
+        temperature /= 100
+
+        brightness = 1 + brightness
+        contrast = 1 + contrast
+        saturation = 1 + saturation
+
         for b in range(batch_size):
             tensor_image = image[b].numpy()
 
-            brightness /= 100
-            contrast /= 100
-            saturation /= 100
-            temperature /= 100
-
-            brightness = 1 + brightness
-            contrast = 1 + contrast
-            saturation = 1 + saturation
-
             modified_image = Image.fromarray((tensor_image * 255).astype(np.uint8))
 
             # brightness
@@ -380,21 +383,10 @@ class Blend:
     CATEGORY = "postprocessing"
 
     def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
-        batch_size, height, width, _ = image1.shape
-        result = torch.zeros_like(image1)
-
-        for b in range(batch_size):
-            img1 = image1[b].numpy()
-            img2 = image2[b].numpy()
-
-            blended_image = self.blend_mode(img1, img2, blend_mode)
-            blended_image = img1 * (1 - blend_factor) + blended_image * blend_factor
-            blended_image = np.clip(blended_image, 0, 1)
-
-            tensor = torch.from_numpy(blended_image).unsqueeze(0)
-            result[b] = tensor
-
-        return (result,)
+        blended_image = self.blend_mode(image1, image2, blend_mode)
+        blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
+        blended_image = torch.clamp(blended_image, 0, 1)
+        return (blended_image,)
 
     def blend_mode(self, img1, img2, mode):
         if mode == "normal":
@@ -404,14 +396,14 @@ class Blend:
         elif mode == "screen":
             return 1 - (1 - img1) * (1 - img2)
         elif mode == "overlay":
-            return np.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
+            return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
         elif mode == "soft_light":
-            return np.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
+            return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
         else:
             raise ValueError(f"Unsupported blend mode: {mode}")
 
     def g(self, x):
-        return np.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, np.sqrt(x))
+        return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
 
 NODE_CLASS_MAPPINGS = {
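
For review context, a minimal standalone sketch (not part of the patch) of the depthwise-convolution pattern that both the new blur and sharpen rely on: a single 2D filter is replicated once per channel into a (C, 1, K, K) weight and applied with groups=C, so each channel is filtered independently in one F.conv2d call. The helper name apply_per_channel_filter and the box filter are illustrative stand-ins, not code from this PR.

# Standalone sketch of the groups=C depthwise-convolution pattern used above.
# Assumes ComfyUI's IMAGE layout: float tensors of shape (B, H, W, C) in [0, 1].
import torch
import torch.nn.functional as F

def apply_per_channel_filter(image: torch.Tensor, kernel2d: torch.Tensor) -> torch.Tensor:
    b, h, w, c = image.shape
    k = kernel2d.shape[-1]
    # (K, K) -> (C, 1, K, K): one copy of the same filter per input channel.
    weight = kernel2d.repeat(c, 1, 1).unsqueeze(1)
    x = image.permute(0, 3, 1, 2)                      # (B, H, W, C) -> (B, C, H, W)
    y = F.conv2d(x, weight, padding=k // 2, groups=c)  # depthwise conv, zero padding
    return y.permute(0, 2, 3, 1)                       # back to (B, H, W, C)

# Illustrative 3x3 box filter standing in for the Gaussian/sharpen kernels.
box = torch.full((3, 3), 1.0 / 9.0)
img = torch.rand(1, 8, 8, 3)
out = apply_per_channel_filter(img, box)
assert out.shape == img.shape                          # odd K preserves spatial size

One behavioral note on the same pattern: padding=K // 2 zero-pads, so border pixels are convolved against black, whereas cv2.GaussianBlur's default border mode reflects the image; edge pixels will therefore differ slightly between the old and new implementations.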