mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-14 19:17:32 +08:00
post processing now loops over batch dim
This commit is contained in:
parent
406f2872db
commit
a305691d48
@ -27,7 +27,11 @@ class Dither:
|
|||||||
CATEGORY = "postprocessing"
|
CATEGORY = "postprocessing"
|
||||||
|
|
||||||
def dither(self, image: torch.Tensor, bits: int):
|
def dither(self, image: torch.Tensor, bits: int):
|
||||||
tensor_image = image.numpy()[0]
|
batch_size, height, width, _ = image.shape
|
||||||
|
result = torch.zeros_like(image)
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
tensor_image = image[b].numpy()
|
||||||
img = (tensor_image * 255)
|
img = (tensor_image * 255)
|
||||||
height, width, _ = img.shape
|
height, width, _ = img.shape
|
||||||
|
|
||||||
@ -52,7 +56,9 @@ class Dither:
|
|||||||
|
|
||||||
dithered = img / 255
|
dithered = img / 255
|
||||||
tensor = torch.from_numpy(dithered).unsqueeze(0)
|
tensor = torch.from_numpy(dithered).unsqueeze(0)
|
||||||
return (tensor,)
|
result[b] = tensor
|
||||||
|
|
||||||
|
return (result,)
|
||||||
|
|
||||||
class KMeansQuantize:
|
class KMeansQuantize:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -84,7 +90,11 @@ class KMeansQuantize:
|
|||||||
CATEGORY = "postprocessing"
|
CATEGORY = "postprocessing"
|
||||||
|
|
||||||
def kmeans_quantize(self, image: torch.Tensor, colors: int, precision: int):
|
def kmeans_quantize(self, image: torch.Tensor, colors: int, precision: int):
|
||||||
tensor_image = image.numpy()[0].astype(np.float32)
|
batch_size, height, width, _ = image.shape
|
||||||
|
result = torch.zeros_like(image)
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
tensor_image = image[b].numpy().astype(np.float32)
|
||||||
img = tensor_image
|
img = tensor_image
|
||||||
|
|
||||||
height, width, c = img.shape
|
height, width, c = img.shape
|
||||||
@ -100,9 +110,11 @@ class KMeansQuantize:
|
|||||||
criteria, 1, cv2.KMEANS_PP_CENTERS
|
criteria, 1, cv2.KMEANS_PP_CENTERS
|
||||||
)
|
)
|
||||||
|
|
||||||
result = center[label.flatten()].reshape(*img.shape)
|
img = center[label.flatten()].reshape(*img.shape)
|
||||||
tensor = torch.from_numpy(result).unsqueeze(0)
|
tensor = torch.from_numpy(img).unsqueeze(0)
|
||||||
return (tensor,)
|
result[b] = tensor
|
||||||
|
|
||||||
|
return (result,)
|
||||||
|
|
||||||
class GaussianBlur:
|
class GaussianBlur:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -134,10 +146,16 @@ class GaussianBlur:
|
|||||||
CATEGORY = "postprocessing"
|
CATEGORY = "postprocessing"
|
||||||
|
|
||||||
def blur(self, image: torch.Tensor, kernel_size: int, sigma: float):
|
def blur(self, image: torch.Tensor, kernel_size: int, sigma: float):
|
||||||
tensor_image = image.numpy()[0]
|
batch_size, height, width, _ = image.shape
|
||||||
|
result = torch.zeros_like(image)
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
tensor_image = image[b].numpy()
|
||||||
blurred = cv2.GaussianBlur(tensor_image, (kernel_size, kernel_size), sigma)
|
blurred = cv2.GaussianBlur(tensor_image, (kernel_size, kernel_size), sigma)
|
||||||
tensor = torch.from_numpy(blurred).unsqueeze(0)
|
tensor = torch.from_numpy(blurred).unsqueeze(0)
|
||||||
return (tensor,)
|
result[b] = tensor
|
||||||
|
|
||||||
|
return (result,)
|
||||||
|
|
||||||
class Sharpen:
|
class Sharpen:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -169,7 +187,11 @@ class Sharpen:
|
|||||||
CATEGORY = "postprocessing"
|
CATEGORY = "postprocessing"
|
||||||
|
|
||||||
def sharpen(self, image: torch.Tensor, kernel_size: int, alpha: float):
|
def sharpen(self, image: torch.Tensor, kernel_size: int, alpha: float):
|
||||||
tensor_image = image.numpy()[0]
|
batch_size, height, width, _ = image.shape
|
||||||
|
result = torch.zeros_like(image)
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
tensor_image = image[b].numpy()
|
||||||
|
|
||||||
kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) * -1
|
kernel = np.ones((kernel_size, kernel_size), dtype=np.float32) * -1
|
||||||
center = kernel_size // 2
|
center = kernel_size // 2
|
||||||
@ -180,7 +202,9 @@ class Sharpen:
|
|||||||
|
|
||||||
tensor = torch.from_numpy(sharpened).unsqueeze(0)
|
tensor = torch.from_numpy(sharpened).unsqueeze(0)
|
||||||
tensor = torch.clamp(tensor, 0, 1)
|
tensor = torch.clamp(tensor, 0, 1)
|
||||||
return (tensor,)
|
result[b] = tensor
|
||||||
|
|
||||||
|
return (result,)
|
||||||
|
|
||||||
class CannyEdgeDetection:
|
class CannyEdgeDetection:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -212,11 +236,17 @@ class CannyEdgeDetection:
|
|||||||
CATEGORY = "postprocessing"
|
CATEGORY = "postprocessing"
|
||||||
|
|
||||||
def canny(self, image: torch.Tensor, lower_threshold: int, upper_threshold: int):
|
def canny(self, image: torch.Tensor, lower_threshold: int, upper_threshold: int):
|
||||||
tensor_image = image.numpy()[0]
|
batch_size, height, width, _ = image.shape
|
||||||
gray_image = (cv2.cvtColor(tensor_image, cv2.COLOR_BGR2GRAY) * 255).astype(np.uint8)
|
result = torch.zeros(batch_size, height, width)
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
tensor_image = image[b].numpy().copy()
|
||||||
|
gray_image = (cv2.cvtColor(tensor_image, cv2.COLOR_RGB2GRAY) * 255).astype(np.uint8)
|
||||||
canny = cv2.Canny(gray_image, lower_threshold, upper_threshold)
|
canny = cv2.Canny(gray_image, lower_threshold, upper_threshold)
|
||||||
tensor = torch.from_numpy(canny).unsqueeze(0)
|
tensor = torch.from_numpy(canny)
|
||||||
return (tensor,)
|
result[b] = tensor
|
||||||
|
|
||||||
|
return (result,)
|
||||||
|
|
||||||
class ColorCorrect:
|
class ColorCorrect:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -272,7 +302,11 @@ class ColorCorrect:
|
|||||||
CATEGORY = "postprocessing"
|
CATEGORY = "postprocessing"
|
||||||
|
|
||||||
def color_correct(self, image: torch.Tensor, temperature: float, hue: float, brightness: float, contrast: float, saturation: float, gamma: float):
|
def color_correct(self, image: torch.Tensor, temperature: float, hue: float, brightness: float, contrast: float, saturation: float, gamma: float):
|
||||||
tensor_image = image.numpy()[0]
|
batch_size, height, width, _ = image.shape
|
||||||
|
result = torch.zeros_like(image)
|
||||||
|
|
||||||
|
for b in range(batch_size):
|
||||||
|
tensor_image = image[b].numpy()
|
||||||
|
|
||||||
brightness /= 100
|
brightness /= 100
|
||||||
contrast /= 100
|
contrast /= 100
|
||||||
@ -316,8 +350,9 @@ class ColorCorrect:
|
|||||||
modified_image = modified_image.astype(np.uint8)
|
modified_image = modified_image.astype(np.uint8)
|
||||||
modified_image = modified_image / 255
|
modified_image = modified_image / 255
|
||||||
modified_image = torch.from_numpy(modified_image).unsqueeze(0)
|
modified_image = torch.from_numpy(modified_image).unsqueeze(0)
|
||||||
|
result[b] = modified_image
|
||||||
|
|
||||||
return (modified_image, )
|
return (result, )
|
||||||
|
|
||||||
|
|
||||||
NODE_CLASS_MAPPINGS = {
|
NODE_CLASS_MAPPINGS = {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user