mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-13 02:33:02 +08:00
Use torch for extract, split, merge
This commit is contained in:
parent
9b40cd3f89
commit
fcc561261d
@ -326,9 +326,9 @@ class GetChannel:
|
|||||||
"required": {
|
"required": {
|
||||||
"image": ("IMAGE",),
|
"image": ("IMAGE",),
|
||||||
"channel": ([
|
"channel": ([
|
||||||
"Red",
|
"red",
|
||||||
"Green",
|
"green",
|
||||||
"Blue",
|
"blue",
|
||||||
],),
|
],),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -341,16 +341,11 @@ class GetChannel:
|
|||||||
def getchannel(self, image: torch.Tensor, channel: str):
|
def getchannel(self, image: torch.Tensor, channel: str):
|
||||||
batch_size, height, width, _ = image.shape
|
batch_size, height, width, _ = image.shape
|
||||||
result = torch.zeros_like(image)
|
result = torch.zeros_like(image)
|
||||||
|
channels = ["red", "green", "blue"]
|
||||||
|
|
||||||
for b in range(batch_size):
|
for b in range(batch_size):
|
||||||
tensor_image = image[b]
|
channel_out = image[b, :, :, channels.index(channel)]
|
||||||
img = (tensor_image * 255).to(torch.uint8).numpy()
|
result[b] = channel_out[:, :, None].expand(-1, -1, 3)
|
||||||
pil_image = Image.fromarray(img, mode='RGB')
|
|
||||||
|
|
||||||
output_image = pil_image.getchannel(channel[0])
|
|
||||||
|
|
||||||
output_array = torch.tensor(np.array(output_image.convert("RGB"))).float() / 255
|
|
||||||
result[b] = output_array
|
|
||||||
|
|
||||||
return (result,)
|
return (result,)
|
||||||
|
|
||||||
@ -380,16 +375,8 @@ class Split:
|
|||||||
result_b = torch.zeros_like(image)
|
result_b = torch.zeros_like(image)
|
||||||
|
|
||||||
for b in range(batch_size):
|
for b in range(batch_size):
|
||||||
tensor_image = image[b]
|
channels = torch.chunk(image[b], 3, 2)
|
||||||
img = (tensor_image * 255).to(torch.uint8).numpy()
|
result_r[b], result_g[b], result_b[b] = [x.expand(-1, -1, 3) for x in channels]
|
||||||
pil_image = Image.fromarray(img, mode='RGB')
|
|
||||||
|
|
||||||
output_r, output_g, output_b = pil_image.split()
|
|
||||||
|
|
||||||
output_array_r = torch.tensor(np.array(output_r.convert("RGB"))).float() / 255
|
|
||||||
output_array_g = torch.tensor(np.array(output_g.convert("RGB"))).float() / 255
|
|
||||||
output_array_b = torch.tensor(np.array(output_b.convert("RGB"))).float() / 255
|
|
||||||
result_r[b], result_g[b], result_b[b] = output_array_r, output_array_g, output_array_b
|
|
||||||
|
|
||||||
return (result_r, result_g, result_b)
|
return (result_r, result_g, result_b)
|
||||||
|
|
||||||
@ -415,19 +402,10 @@ class Merge:
|
|||||||
def merge(self, red: torch.Tensor, green: torch.Tensor, blue: torch.Tensor):
|
def merge(self, red: torch.Tensor, green: torch.Tensor, blue: torch.Tensor):
|
||||||
batch_size, height, width, _ = red.shape
|
batch_size, height, width, _ = red.shape
|
||||||
result = torch.zeros_like(red)
|
result = torch.zeros_like(red)
|
||||||
|
images = [red, green, blue]
|
||||||
for b in range(batch_size):
|
for b in range(batch_size):
|
||||||
img_r = (red[b] * 255).to(torch.uint8).numpy()
|
for i in range(3):
|
||||||
img_g = (green[b] * 255).to(torch.uint8).numpy()
|
result[b, :, :, i] = images[i][b, :, :, 0]
|
||||||
img_b = (blue[b] * 255).to(torch.uint8).numpy()
|
|
||||||
pil_image_r = Image.fromarray(img_r, mode='RGB').convert("L")
|
|
||||||
pil_image_g = Image.fromarray(img_g, mode='RGB').convert("L")
|
|
||||||
pil_image_b = Image.fromarray(img_b, mode='RGB').convert("L")
|
|
||||||
|
|
||||||
output_image = Image.merge("RGB", (pil_image_r, pil_image_g, pil_image_b))
|
|
||||||
|
|
||||||
output_array = torch.tensor(np.array(output_image.convert("RGB"))).float() / 255
|
|
||||||
result[b] = output_array
|
|
||||||
|
|
||||||
return (result,)
|
return (result,)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user