mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-11 01:32:31 +08:00
Use torch for extract, split, merge
This commit is contained in:
parent
9b40cd3f89
commit
fcc561261d
@ -326,9 +326,9 @@ class GetChannel:
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"channel": ([
|
||||
"Red",
|
||||
"Green",
|
||||
"Blue",
|
||||
"red",
|
||||
"green",
|
||||
"blue",
|
||||
],),
|
||||
},
|
||||
}
|
||||
@ -341,16 +341,11 @@ class GetChannel:
|
||||
def getchannel(self, image: torch.Tensor, channel: str):
|
||||
batch_size, height, width, _ = image.shape
|
||||
result = torch.zeros_like(image)
|
||||
channels = ["red", "green", "blue"]
|
||||
|
||||
for b in range(batch_size):
|
||||
tensor_image = image[b]
|
||||
img = (tensor_image * 255).to(torch.uint8).numpy()
|
||||
pil_image = Image.fromarray(img, mode='RGB')
|
||||
|
||||
output_image = pil_image.getchannel(channel[0])
|
||||
|
||||
output_array = torch.tensor(np.array(output_image.convert("RGB"))).float() / 255
|
||||
result[b] = output_array
|
||||
channel_out = image[b, :, :, channels.index(channel)]
|
||||
result[b] = channel_out[:, :, None].expand(-1, -1, 3)
|
||||
|
||||
return (result,)
|
||||
|
||||
@ -380,16 +375,8 @@ class Split:
|
||||
result_b = torch.zeros_like(image)
|
||||
|
||||
for b in range(batch_size):
|
||||
tensor_image = image[b]
|
||||
img = (tensor_image * 255).to(torch.uint8).numpy()
|
||||
pil_image = Image.fromarray(img, mode='RGB')
|
||||
|
||||
output_r, output_g, output_b = pil_image.split()
|
||||
|
||||
output_array_r = torch.tensor(np.array(output_r.convert("RGB"))).float() / 255
|
||||
output_array_g = torch.tensor(np.array(output_g.convert("RGB"))).float() / 255
|
||||
output_array_b = torch.tensor(np.array(output_b.convert("RGB"))).float() / 255
|
||||
result_r[b], result_g[b], result_b[b] = output_array_r, output_array_g, output_array_b
|
||||
channels = torch.chunk(image[b], 3, 2)
|
||||
result_r[b], result_g[b], result_b[b] = [x.expand(-1, -1, 3) for x in channels]
|
||||
|
||||
return (result_r, result_g, result_b)
|
||||
|
||||
@ -415,19 +402,10 @@ class Merge:
|
||||
def merge(self, red: torch.Tensor, green: torch.Tensor, blue: torch.Tensor):
|
||||
batch_size, height, width, _ = red.shape
|
||||
result = torch.zeros_like(red)
|
||||
|
||||
images = [red, green, blue]
|
||||
for b in range(batch_size):
|
||||
img_r = (red[b] * 255).to(torch.uint8).numpy()
|
||||
img_g = (green[b] * 255).to(torch.uint8).numpy()
|
||||
img_b = (blue[b] * 255).to(torch.uint8).numpy()
|
||||
pil_image_r = Image.fromarray(img_r, mode='RGB').convert("L")
|
||||
pil_image_g = Image.fromarray(img_g, mode='RGB').convert("L")
|
||||
pil_image_b = Image.fromarray(img_b, mode='RGB').convert("L")
|
||||
|
||||
output_image = Image.merge("RGB", (pil_image_r, pil_image_g, pil_image_b))
|
||||
|
||||
output_array = torch.tensor(np.array(output_image.convert("RGB"))).float() / 255
|
||||
result[b] = output_array
|
||||
for i in range(3):
|
||||
result[b, :, :, i] = images[i][b, :, :, 0]
|
||||
|
||||
return (result,)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user