post processing now loops over batch dim

This commit is contained in:
EllangoK 2023-03-30 19:45:43 -04:00
parent 406f2872db
commit a305691d48

View File

@ -27,7 +27,11 @@ class Dither:
CATEGORY = "postprocessing" CATEGORY = "postprocessing"
def dither(self, image: torch.Tensor, bits: int): def dither(self, image: torch.Tensor, bits: int):
tensor_image = image.numpy()[0] batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
for b in range(batch_size):
tensor_image = image[b].numpy()
img = (tensor_image * 255) img = (tensor_image * 255)
height, width, _ = img.shape height, width, _ = img.shape
@ -52,7 +56,9 @@ class Dither:
dithered = img / 255 dithered = img / 255
tensor = torch.from_numpy(dithered).unsqueeze(0) tensor = torch.from_numpy(dithered).unsqueeze(0)
return (tensor,) result[b] = tensor
return (result,)
class KMeansQuantize: class KMeansQuantize:
def __init__(self): def __init__(self):
@ -84,7 +90,11 @@ class KMeansQuantize:
CATEGORY = "postprocessing" CATEGORY = "postprocessing"
def kmeans_quantize(self, image: torch.Tensor, colors: int, precision: int): def kmeans_quantize(self, image: torch.Tensor, colors: int, precision: int):
tensor_image = image.numpy()[0].astype(np.float32) batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
for b in range(batch_size):
tensor_image = image[b].numpy().astype(np.float32)
img = tensor_image img = tensor_image
height, width, c = img.shape height, width, c = img.shape
@ -100,9 +110,11 @@ class KMeansQuantize:
criteria, 1, cv2.KMEANS_PP_CENTERS criteria, 1, cv2.KMEANS_PP_CENTERS
) )
result = center[label.flatten()].reshape(*img.shape) img = center[label.flatten()].reshape(*img.shape)
tensor = torch.from_numpy(result).unsqueeze(0) tensor = torch.from_numpy(img).unsqueeze(0)
return (tensor,) result[b] = tensor
return (result,)
class GaussianBlur: class GaussianBlur:
def __init__(self): def __init__(self):
@ -134,10 +146,16 @@ class GaussianBlur:
CATEGORY = "postprocessing" CATEGORY = "postprocessing"
def blur(self, image: torch.Tensor, kernel_size: int, sigma: float): def blur(self, image: torch.Tensor, kernel_size: int, sigma: float):
tensor_image = image.numpy()[0] batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
for b in range(batch_size):
tensor_image = image[b].numpy()
blurred = cv2.GaussianBlur(tensor_image, (kernel_size, kernel_size), sigma) blurred = cv2.GaussianBlur(tensor_image, (kernel_size, kernel_size), sigma)
tensor = torch.from_numpy(blurred).unsqueeze(0) tensor = torch.from_numpy(blurred).unsqueeze(0)
return (tensor,) result[b] = tensor
return (result,)
class Sharpen: class Sharpen:
def __init__(self): def __init__(self):
@ -169,7 +187,11 @@ class Sharpen:
CATEGORY = "postprocessing" CATEGORY = "postprocessing"
def sharpen(self, image: torch.Tensor, kernel_size: int, alpha: float): def sharpen(self, image: torch.Tensor, kernel_size: int, alpha: float):
tensor_image = image.numpy()[0] batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
for b in range(batch_size):
tensor_image = image[b].numpy()
kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) * -1 kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) * -1
center = kernel_size // 2 center = kernel_size // 2
@ -180,7 +202,9 @@ class Sharpen:
tensor = torch.from_numpy(sharpened).unsqueeze(0) tensor = torch.from_numpy(sharpened).unsqueeze(0)
tensor = torch.clamp(tensor, 0, 1) tensor = torch.clamp(tensor, 0, 1)
return (tensor,) result[b] = tensor
return (result,)
class CannyEdgeDetection: class CannyEdgeDetection:
def __init__(self): def __init__(self):
@ -212,11 +236,17 @@ class CannyEdgeDetection:
CATEGORY = "postprocessing" CATEGORY = "postprocessing"
def canny(self, image: torch.Tensor, lower_threshold: int, upper_threshold: int): def canny(self, image: torch.Tensor, lower_threshold: int, upper_threshold: int):
tensor_image = image.numpy()[0] batch_size, height, width, _ = image.shape
gray_image = (cv2.cvtColor(tensor_image, cv2.COLOR_BGR2GRAY) * 255).astype(np.uint8) result = torch.zeros(batch_size, height, width)
for b in range(batch_size):
tensor_image = image[b].numpy().copy()
gray_image = (cv2.cvtColor(tensor_image, cv2.COLOR_RGB2GRAY) * 255).astype(np.uint8)
canny = cv2.Canny(gray_image, lower_threshold, upper_threshold) canny = cv2.Canny(gray_image, lower_threshold, upper_threshold)
tensor = torch.from_numpy(canny).unsqueeze(0) tensor = torch.from_numpy(canny)
return (tensor,) result[b] = tensor
return (result,)
class ColorCorrect: class ColorCorrect:
def __init__(self): def __init__(self):
@ -272,7 +302,11 @@ class ColorCorrect:
CATEGORY = "postprocessing" CATEGORY = "postprocessing"
def color_correct(self, image: torch.Tensor, temperature: float, hue: float, brightness: float, contrast: float, saturation: float, gamma: float): def color_correct(self, image: torch.Tensor, temperature: float, hue: float, brightness: float, contrast: float, saturation: float, gamma: float):
tensor_image = image.numpy()[0] batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
for b in range(batch_size):
tensor_image = image[b].numpy()
brightness /= 100 brightness /= 100
contrast /= 100 contrast /= 100
@ -316,8 +350,9 @@ class ColorCorrect:
modified_image = modified_image.astype(np.uint8) modified_image = modified_image.astype(np.uint8)
modified_image = modified_image / 255 modified_image = modified_image / 255
modified_image = torch.from_numpy(modified_image).unsqueeze(0) modified_image = torch.from_numpy(modified_image).unsqueeze(0)
result[b] = modified_image
return (modified_image, ) return (result, )
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {