mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 11:07:24 +08:00
convert dither, sharpen, blur, blend to pure torch
This commit is contained in:
parent
a99978b722
commit
1094d816a6
@ -1,6 +1,7 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image, ImageEnhance
|
||||
|
||||
class Dither:
|
||||
@ -31,7 +32,7 @@ class Dither:
|
||||
result = torch.zeros_like(image)
|
||||
|
||||
for b in range(batch_size):
|
||||
tensor_image = image[b].numpy()
|
||||
tensor_image = image[b]
|
||||
img = (tensor_image * 255)
|
||||
height, width, _ = img.shape
|
||||
|
||||
@ -39,8 +40,8 @@ class Dither:
|
||||
|
||||
for y in range(height):
|
||||
for x in range(width):
|
||||
old_pixel = img[y, x].copy()
|
||||
new_pixel = np.round(old_pixel / scale) * scale
|
||||
old_pixel = img[y, x].clone()
|
||||
new_pixel = torch.round(old_pixel / scale) * scale
|
||||
img[y, x] = new_pixel
|
||||
|
||||
quant_error = old_pixel - new_pixel
|
||||
@ -55,7 +56,7 @@ class Dither:
|
||||
img[y + 1, x + 1] += quant_error * 1 / 16
|
||||
|
||||
dithered = img / 255
|
||||
tensor = torch.from_numpy(dithered).unsqueeze(0)
|
||||
tensor = dithered.unsqueeze(0)
|
||||
result[b] = tensor
|
||||
|
||||
return (result,)
|
||||
@ -145,17 +146,22 @@ class GaussianBlur:
|
||||
|
||||
CATEGORY = "postprocessing"
|
||||
|
||||
def gaussian_kernel(self, kernel_size: int, sigma: float):
|
||||
x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size))
|
||||
d = torch.sqrt(x * x + y * y)
|
||||
g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
|
||||
return g / g.sum()
|
||||
|
||||
def blur(self, image: torch.Tensor, kernel_size: int, sigma: float):
|
||||
batch_size, height, width, _ = image.shape
|
||||
result = torch.zeros_like(image)
|
||||
batch_size, height, width, channels = image.shape
|
||||
|
||||
for b in range(batch_size):
|
||||
tensor_image = image[b].numpy()
|
||||
blurred = cv2.GaussianBlur(tensor_image, (kernel_size, kernel_size), sigma)
|
||||
tensor = torch.from_numpy(blurred).unsqueeze(0)
|
||||
result[b] = tensor
|
||||
kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
|
||||
|
||||
return (result,)
|
||||
image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||
blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
|
||||
blurred = blurred.permute(0, 2, 3, 1)
|
||||
|
||||
return (blurred,)
|
||||
|
||||
class Sharpen:
|
||||
def __init__(self):
|
||||
@ -187,22 +193,19 @@ class Sharpen:
|
||||
CATEGORY = "postprocessing"
|
||||
|
||||
def sharpen(self, image: torch.Tensor, kernel_size: int, alpha: float):
|
||||
batch_size, height, width, _ = image.shape
|
||||
result = torch.zeros_like(image)
|
||||
batch_size, height, width, channels = image.shape
|
||||
|
||||
for b in range(batch_size):
|
||||
tensor_image = image[b].numpy()
|
||||
kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
|
||||
center = kernel_size // 2
|
||||
kernel[center, center] = kernel_size**2
|
||||
kernel *= alpha
|
||||
kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
|
||||
|
||||
kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) * -1
|
||||
center = kernel_size // 2
|
||||
kernel[center, center] = kernel_size**2
|
||||
kernel *= alpha
|
||||
tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
|
||||
sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
|
||||
sharpened = sharpened.permute(0, 2, 3, 1)
|
||||
|
||||
sharpened = cv2.filter2D(tensor_image, -1, kernel)
|
||||
|
||||
tensor = torch.from_numpy(sharpened).unsqueeze(0)
|
||||
tensor = torch.clamp(tensor, 0, 1)
|
||||
result[b] = tensor
|
||||
result = torch.clamp(sharpened, 0, 1)
|
||||
|
||||
return (result,)
|
||||
|
||||
@ -305,18 +308,18 @@ class ColorCorrect:
|
||||
batch_size, height, width, _ = image.shape
|
||||
result = torch.zeros_like(image)
|
||||
|
||||
brightness /= 100
|
||||
contrast /= 100
|
||||
saturation /= 100
|
||||
temperature /= 100
|
||||
|
||||
brightness = 1 + brightness
|
||||
contrast = 1 + contrast
|
||||
saturation = 1 + saturation
|
||||
|
||||
for b in range(batch_size):
|
||||
tensor_image = image[b].numpy()
|
||||
|
||||
brightness /= 100
|
||||
contrast /= 100
|
||||
saturation /= 100
|
||||
temperature /= 100
|
||||
|
||||
brightness = 1 + brightness
|
||||
contrast = 1 + contrast
|
||||
saturation = 1 + saturation
|
||||
|
||||
modified_image = Image.fromarray((tensor_image * 255).astype(np.uint8))
|
||||
|
||||
# brightness
|
||||
@ -380,21 +383,10 @@ class Blend:
|
||||
CATEGORY = "postprocessing"
|
||||
|
||||
def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
|
||||
batch_size, height, width, _ = image1.shape
|
||||
result = torch.zeros_like(image1)
|
||||
|
||||
for b in range(batch_size):
|
||||
img1 = image1[b].numpy()
|
||||
img2 = image2[b].numpy()
|
||||
|
||||
blended_image = self.blend_mode(img1, img2, blend_mode)
|
||||
blended_image = img1 * (1 - blend_factor) + blended_image * blend_factor
|
||||
blended_image = np.clip(blended_image, 0, 1)
|
||||
|
||||
tensor = torch.from_numpy(blended_image).unsqueeze(0)
|
||||
result[b] = tensor
|
||||
|
||||
return (result,)
|
||||
blended_image = self.blend_mode(image1, image2, blend_mode)
|
||||
blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
|
||||
blended_image = torch.clamp(blended_image, 0, 1)
|
||||
return (blended_image,)
|
||||
|
||||
def blend_mode(self, img1, img2, mode):
|
||||
if mode == "normal":
|
||||
@ -404,14 +396,14 @@ class Blend:
|
||||
elif mode == "screen":
|
||||
return 1 - (1 - img1) * (1 - img2)
|
||||
elif mode == "overlay":
|
||||
return np.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
|
||||
return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
|
||||
elif mode == "soft_light":
|
||||
return np.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
|
||||
return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
|
||||
else:
|
||||
raise ValueError(f"Unsupported blend mode: {mode}")
|
||||
|
||||
def g(self, x):
|
||||
return np.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, np.sqrt(x))
|
||||
return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user