convert dither, sharpen, blur, blend to pure torch

EllangoK 2023-03-31 18:00:26 -04:00
parent a99978b722
commit 1094d816a6


@@ -1,6 +1,7 @@
 import numpy as np
 import cv2
 import torch
+import torch.nn.functional as F
 from PIL import Image, ImageEnhance

 class Dither:
@@ -31,7 +32,7 @@ class Dither:
         result = torch.zeros_like(image)

         for b in range(batch_size):
-            tensor_image = image[b].numpy()
+            tensor_image = image[b]
             img = (tensor_image * 255)
             height, width, _ = img.shape

@@ -39,8 +40,8 @@ class Dither:

             for y in range(height):
                 for x in range(width):
-                    old_pixel = img[y, x].copy()
-                    new_pixel = np.round(old_pixel / scale) * scale
+                    old_pixel = img[y, x].clone()
+                    new_pixel = torch.round(old_pixel / scale) * scale
                     img[y, x] = new_pixel
                     quant_error = old_pixel - new_pixel
@@ -55,7 +56,7 @@ class Dither:
                         img[y + 1, x + 1] += quant_error * 1 / 16

             dithered = img / 255
-            tensor = torch.from_numpy(dithered).unsqueeze(0)
+            tensor = dithered.unsqueeze(0)
             result[b] = tensor

         return (result,)
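Note on the dither port: Floyd-Steinberg error diffusion is inherently sequential, so the per-pixel Python loop stays; the commit only swaps numpy's .copy()/np.round for torch's .clone()/torch.round so the data never leaves the tensor. A self-contained sketch of the loop these hunks describe (the bits parameter and the scale formula are assumptions for illustration; the node computes scale outside the shown context):

import torch

def floyd_steinberg_dither(image: torch.Tensor, bits: int = 1) -> torch.Tensor:
    # image: (H, W, C) float tensor in [0, 1]
    img = image.clone() * 255
    height, width, _ = img.shape
    scale = 255 / (2 ** bits - 1)  # assumed quantization step
    for y in range(height):
        for x in range(width):
            old_pixel = img[y, x].clone()  # .clone() replaces numpy's .copy()
            new_pixel = torch.round(old_pixel / scale) * scale
            img[y, x] = new_pixel
            quant_error = old_pixel - new_pixel
            # diffuse the quantization error to not-yet-visited neighbors
            if x + 1 < width:
                img[y, x + 1] += quant_error * 7 / 16
            if y + 1 < height:
                if x > 0:
                    img[y + 1, x - 1] += quant_error * 3 / 16
                img[y + 1, x] += quant_error * 5 / 16
                if x + 1 < width:
                    img[y + 1, x + 1] += quant_error * 1 / 16
    return img / 255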
@@ -145,17 +146,22 @@ class GaussianBlur:

     CATEGORY = "postprocessing"

+    def gaussian_kernel(self, kernel_size: int, sigma: float):
+        x, y = torch.meshgrid(torch.linspace(-1, 1, kernel_size), torch.linspace(-1, 1, kernel_size))
+        d = torch.sqrt(x * x + y * y)
+        g = torch.exp(-(d * d) / (2.0 * sigma * sigma))
+        return g / g.sum()
+
     def blur(self, image: torch.Tensor, kernel_size: int, sigma: float):
-        batch_size, height, width, _ = image.shape
-        result = torch.zeros_like(image)
-
-        for b in range(batch_size):
-            tensor_image = image[b].numpy()
-            blurred = cv2.GaussianBlur(tensor_image, (kernel_size, kernel_size), sigma)
-            tensor = torch.from_numpy(blurred).unsqueeze(0)
-            result[b] = tensor
-
-        return (result,)
+        batch_size, height, width, channels = image.shape
+
+        kernel = self.gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)
+
+        image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
+        blurred = F.conv2d(image, kernel, padding=kernel_size // 2, groups=channels)
+        blurred = blurred.permute(0, 2, 3, 1)
+
+        return (blurred,)

 class Sharpen:
     def __init__(self):
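The rewritten blur replaces the per-image cv2.GaussianBlur loop with a single depthwise convolution: one (kernel_size, kernel_size) gaussian is normalized, repeated per channel into a (C, 1, k, k) weight, and applied with groups=channels so channels never mix. A standalone sketch of the same idea (the indexing="ij" argument is an addition for newer torch versions, not part of the diff):

import torch
import torch.nn.functional as F

def gaussian_kernel(kernel_size: int, sigma: float) -> torch.Tensor:
    x, y = torch.meshgrid(
        torch.linspace(-1, 1, kernel_size),
        torch.linspace(-1, 1, kernel_size),
        indexing="ij",
    )
    g = torch.exp(-(x * x + y * y) / (2.0 * sigma * sigma))
    return g / g.sum()  # normalize so the blur preserves overall brightness

def blur(image: torch.Tensor, kernel_size: int = 5, sigma: float = 1.0) -> torch.Tensor:
    channels = image.shape[-1]
    kernel = gaussian_kernel(kernel_size, sigma).repeat(channels, 1, 1).unsqueeze(1)  # (C, 1, k, k)
    x = image.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
    out = F.conv2d(x, kernel, padding=kernel_size // 2, groups=channels)
    return out.permute(0, 2, 3, 1)  # back to (B, H, W, C)

blurred = blur(torch.rand(1, 64, 64, 3))

Two behavioral notes: sigma now acts on the kernel's normalized [-1, 1] coordinates rather than cv2's pixel units, so old and new sigma values are not interchangeable, and padding=kernel_size // 2 preserves spatial size only for odd kernel sizes.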
@@ -187,22 +193,19 @@ class Sharpen:

     CATEGORY = "postprocessing"

     def sharpen(self, image: torch.Tensor, kernel_size: int, alpha: float):
-        batch_size, height, width, _ = image.shape
-        result = torch.zeros_like(image)
-
-        for b in range(batch_size):
-            tensor_image = image[b].numpy()
-
-            kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) * -1
-            center = kernel_size // 2
-            kernel[center, center] = kernel_size**2
-            kernel *= alpha
-
-            sharpened = cv2.filter2D(tensor_image, -1, kernel)
-
-            tensor = torch.from_numpy(sharpened).unsqueeze(0)
-            tensor = torch.clamp(tensor, 0, 1)
-            result[b] = tensor
+        batch_size, height, width, channels = image.shape
+
+        kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
+        center = kernel_size // 2
+        kernel[center, center] = kernel_size**2
+        kernel *= alpha
+        kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)
+
+        tensor_image = image.permute(0, 3, 1, 2) # Torch wants (B, C, H, W) we use (B, H, W, C)
+        sharpened = F.conv2d(tensor_image, kernel, padding=center, groups=channels)
+        sharpened = sharpened.permute(0, 2, 3, 1)
+
+        result = torch.clamp(sharpened, 0, 1)

         return (result,)
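Sharpen follows the same depthwise-convolution pattern: a ring of -1 weights around a large positive center, scaled by alpha, repeated per channel, and applied with groups=channels, which removes the batch loop entirely. A minimal sketch (function name and defaults are illustrative):

import torch
import torch.nn.functional as F

def sharpen(image: torch.Tensor, kernel_size: int = 3, alpha: float = 1.0) -> torch.Tensor:
    channels = image.shape[-1]
    kernel = torch.ones((kernel_size, kernel_size), dtype=torch.float32) * -1
    center = kernel_size // 2
    kernel[center, center] = kernel_size ** 2  # weights now sum to 1
    kernel *= alpha
    kernel = kernel.repeat(channels, 1, 1).unsqueeze(1)  # (C, 1, k, k)
    x = image.permute(0, 3, 1, 2)  # (B, H, W, C) -> (B, C, H, W)
    out = F.conv2d(x, kernel, padding=center, groups=channels)
    return torch.clamp(out.permute(0, 2, 3, 1), 0, 1)

sharpened = sharpen(torch.rand(1, 64, 64, 3))

Since the kernel_size**2 center cancels all but one of the -1 weights, the kernel sums to alpha; alpha = 1 preserves mean brightness, and larger values exaggerate edges.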
@@ -305,18 +308,18 @@ class ColorCorrect:
         batch_size, height, width, _ = image.shape
         result = torch.zeros_like(image)

+        brightness /= 100
+        contrast /= 100
+        saturation /= 100
+        temperature /= 100
+
+        brightness = 1 + brightness
+        contrast = 1 + contrast
+        saturation = 1 + saturation
+
         for b in range(batch_size):
             tensor_image = image[b].numpy()

-            brightness /= 100
-            contrast /= 100
-            saturation /= 100
-            temperature /= 100
-
-            brightness = 1 + brightness
-            contrast = 1 + contrast
-            saturation = 1 + saturation
-
             modified_image = Image.fromarray((tensor_image * 255).astype(np.uint8))

             # brightness
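Hoisting the slider normalization above the batch loop is a bug fix as much as a cleanup: the statements are not idempotent, so in the old placement every batch item after the first re-normalized the already-normalized factors. A toy repro of the old behavior (the value 20.0 is illustrative):

brightness = 20.0  # slider value in percent
for _ in range(2):  # two batch items under the old, in-loop placement
    brightness /= 100
    brightness = 1 + brightness
# after item 1: 1.2 (intended); after item 2: 1.012 (nearly a no-op)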
@@ -380,21 +383,10 @@ class Blend:
     CATEGORY = "postprocessing"

     def blend_images(self, image1: torch.Tensor, image2: torch.Tensor, blend_factor: float, blend_mode: str):
-        batch_size, height, width, _ = image1.shape
-        result = torch.zeros_like(image1)
-
-        for b in range(batch_size):
-            img1 = image1[b].numpy()
-            img2 = image2[b].numpy()
-
-            blended_image = self.blend_mode(img1, img2, blend_mode)
-            blended_image = img1 * (1 - blend_factor) + blended_image * blend_factor
-            blended_image = np.clip(blended_image, 0, 1)
-
-            tensor = torch.from_numpy(blended_image).unsqueeze(0)
-            result[b] = tensor
-
-        return (result,)
+        blended_image = self.blend_mode(image1, image2, blend_mode)
+        blended_image = image1 * (1 - blend_factor) + blended_image * blend_factor
+        blended_image = torch.clamp(blended_image, 0, 1)
+        return (blended_image,)

     def blend_mode(self, img1, img2, mode):
         if mode == "normal":
@@ -404,14 +396,14 @@ class Blend:
         elif mode == "screen":
             return 1 - (1 - img1) * (1 - img2)
         elif mode == "overlay":
-            return np.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
+            return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
         elif mode == "soft_light":
-            return np.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
+            return torch.where(img2 <= 0.5, img1 - (1 - 2 * img2) * img1 * (1 - img1), img1 + (2 * img2 - 1) * (self.g(img1) - img1))
         else:
             raise ValueError(f"Unsupported blend mode: {mode}")

     def g(self, x):
-        return np.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, np.sqrt(x))
+        return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))

 NODE_CLASS_MAPPINGS = {
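Because every blend mode is a purely elementwise formula, dropping the batch loop just means passing whole (B, H, W, C) tensors through blend_mode. A self-contained sketch of the torch blend path (the normal and multiply bodies are inferred from the mode names; these hunks only show screen, overlay, and soft_light):

import torch

def g(x: torch.Tensor) -> torch.Tensor:
    # piecewise helper from the W3C soft-light definition
    return torch.where(x <= 0.25, ((16 * x - 12) * x + 4) * x, torch.sqrt(x))

def blend_mode(img1: torch.Tensor, img2: torch.Tensor, mode: str) -> torch.Tensor:
    if mode == "normal":
        return img2
    if mode == "multiply":
        return img1 * img2
    if mode == "screen":
        return 1 - (1 - img1) * (1 - img2)
    if mode == "overlay":
        return torch.where(img1 <= 0.5, 2 * img1 * img2, 1 - 2 * (1 - img1) * (1 - img2))
    if mode == "soft_light":
        return torch.where(img2 <= 0.5,
                           img1 - (1 - 2 * img2) * img1 * (1 - img1),
                           img1 + (2 * img2 - 1) * (g(img1) - img1))
    raise ValueError(f"Unsupported blend mode: {mode}")

img1, img2 = torch.rand(2, 1, 8, 8, 3).unbind(0)
blend_factor = 0.5
out = torch.clamp(img1 * (1 - blend_factor) + blend_mode(img1, img2, "overlay") * blend_factor, 0, 1)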