Image split, getchannel nodes

This commit is contained in:
missionfloyd 2023-04-09 18:48:05 -06:00 committed by GitHub
parent d36ad5d958
commit 6f7abd9497
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -320,6 +320,82 @@ class Rotate:
return (result,)
class GetChannel:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
"channel": ([
"Red",
"Green",
"Blue",
],),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "getchannel"
CATEGORY = "image/postprocessing"
def getchannel(self, image: torch.Tensor, channel: str):
batch_size, height, width, _ = image.shape
result = torch.zeros_like(image)
for b in range(batch_size):
tensor_image = image[b]
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
output_image = pil_image.getchannel(channel[0])
output_array = torch.tensor(np.array(output_image.convert("RGB"))).float() / 255
result[b] = output_array
return (result,)
class Split:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"image": ("IMAGE",),
},
}
RETURN_TYPES = ("IMAGE", "IMAGE", "IMAGE")
FUNCTION = "split"
CATEGORY = "image/postprocessing"
def split(self, image: torch.Tensor):
batch_size, height, width, _ = image.shape
result_r = torch.zeros_like(image)
result_g = torch.zeros_like(image)
result_b = torch.zeros_like(image)
for b in range(batch_size):
tensor_image = image[b]
img = (tensor_image * 255).to(torch.uint8).numpy()
pil_image = Image.fromarray(img, mode='RGB')
output_r, output_g, output_b = pil_image.split()
output_array_r = torch.tensor(np.array(output_r.convert("RGB"))).float() / 255
output_array_g = torch.tensor(np.array(output_g.convert("RGB"))).float() / 255
output_array_b = torch.tensor(np.array(output_b.convert("RGB"))).float() / 255
result_r[b], result_g[b], result_b[b] = output_array_r, output_array_g, output_array_b
return (result_r, result_g, result_b)
NODE_CLASS_MAPPINGS = {
"ImageBlend": Blend,
"ImageBlur": Blur,
@ -327,4 +403,6 @@ NODE_CLASS_MAPPINGS = {
"ImageSharpen": Sharpen,
"ImageTranspose": Transpose,
"ImageRotate": Rotate,
"ImageGetChannel": GetChannel,
"ImageSplit": Split,
}