mirror of https://github.com/comfyanonymous/ComfyUI.git

Add transpose and rotate nodes
parent 90581684b4
commit d36ad5d958
@@ -1,7 +1,8 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-from PIL import Image
+from PIL import Image, ImageColor
+import re
 
 import comfy.utils
 
@@ -202,9 +203,128 @@ class Sharpen:
 
         return (result,)
 
+class Transpose:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "method": ([
+                    "Flip horizontal",
+                    "Flip vertical",
+                    "Rotate 90°",
+                    "Rotate 180°",
+                    "Rotate 270°",
+                    "Transpose",
+                    "Transverse",
+                ],),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "transpose"
+
+    CATEGORY = "image/postprocessing"
+
+    def transpose(self, image: torch.Tensor, method: str):
+        batch_size, height, width, _ = image.shape
+        result = torch.zeros_like(image)
+
+        methods = {
+            "Flip horizontal": Image.Transpose.FLIP_LEFT_RIGHT,
+            "Flip vertical": Image.Transpose.FLIP_TOP_BOTTOM,
+            "Rotate 90°": Image.Transpose.ROTATE_90,
+            "Rotate 180°": Image.Transpose.ROTATE_180,
+            "Rotate 270°": Image.Transpose.ROTATE_270,
+            "Transpose": Image.Transpose.TRANSPOSE,
+            "Transverse": Image.Transpose.TRANSVERSE,
+        }
+
+        for b in range(batch_size):
+            tensor_image = image[b]
+            img = (tensor_image * 255).to(torch.uint8).numpy()
+            pil_image = Image.fromarray(img, mode='RGB')
+
+            transposed_image = pil_image.transpose(methods[method])
+
+            transposed_array = torch.tensor(np.array(transposed_image.convert("RGB"))).float() / 255
+            result[b] = transposed_array
+
+        return (result,)
+
+class Rotate:
+    def __init__(self):
+        pass
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "image": ("IMAGE",),
+                "angle": ("FLOAT", {
+                    "default": 0,
+                    "min": 0,
+                    "max": 360,
+                    "step": 0.1
+                }),
+                "resample": ([
+                    "Nearest Neighbor",
+                    "Bilinear",
+                    "Bicubic",
+                ],),
+                "fill_color": ("STRING", {"default": "#000000"}),
+            },
+        }
+
+    RETURN_TYPES = ("IMAGE",)
+    FUNCTION = "rotate"
+
+    CATEGORY = "image/postprocessing"
+
+    def rotate(self, image: torch.Tensor, angle: int, resample: str, fill_color: str):
+        batch_size, height, width, _ = image.shape
+        result = torch.zeros_like(image)
+
+        resamplers = {
+            "Nearest Neighbor": Image.Resampling.NEAREST,
+            "Bilinear": Image.Resampling.BILINEAR,
+            "Bicubic": Image.Resampling.BICUBIC,
+        }
+
+        for b in range(batch_size):
+            tensor_image = image[b]
+            img = (tensor_image * 255).to(torch.uint8).numpy()
+            pil_image = Image.fromarray(img, mode='RGB')
+
+            fill_color = fill_color or "#000000"
+
+            def parse_palette(color_str):
+                if re.match(r'^#[a-fA-F0-9]{6}$', color_str) or color_str.lower() in ImageColor.colormap:
+                    return ImageColor.getrgb(color_str)
+
+                color_rgb = re.match(r'^\(?(\d{1,3}),(\d{1,3}),(\d{1,3})\)?$', color_str)
+                if color_rgb and int(color_rgb.group(1)) <= 255 and int(color_rgb.group(2)) <= 255 and int(color_rgb.group(3)) <= 255:
+                    return tuple(map(int, re.findall(r'\d{1,3}', color_str)))
+                else:
+                    raise ValueError(f"Invalid color format: {color_str}")
+
+            color = fill_color.replace(" ", "")
+            color = parse_palette(color)
+            rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=False, fillcolor=color)
+
+            rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255
+            result[b] = rotated_array
+
+        return (result,)
+
 NODE_CLASS_MAPPINGS = {
     "ImageBlend": Blend,
     "ImageBlur": Blur,
     "ImageQuantize": Quantize,
     "ImageSharpen": Sharpen,
+    "ImageTranspose": Transpose,
+    "ImageRotate": Rotate,
 }
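For reference, a minimal standalone sketch of exercising the new Transpose node outside the ComfyUI graph executor. The import path and the test image filename are assumptions (the diff does not show which file is being edited); the IMAGE layout of (batch, height, width, 3) float32 values in [0, 1] follows from how transpose() unpacks image.shape and scales by 255. Note that the node pre-allocates result with torch.zeros_like(image), so the 90°/270°/Transpose/Transverse methods only fit back into that buffer for square inputs; the flip methods preserve the shape for any size.

import numpy as np
import torch
from PIL import Image

# Assumed module path for the file shown in this diff.
from comfy_extras.nodes_post_processing import Transpose, Rotate

# Build a ComfyUI-style IMAGE tensor: (batch, height, width, 3), float32 in [0, 1].
pil = Image.open("test.png").convert("RGB")               # hypothetical test image
image = torch.from_numpy(np.array(pil)).float() / 255.0   # (H, W, 3)
image = image.unsqueeze(0)                                # add the batch dimension

# Mirror left-to-right; the output shape matches the input, so any size works.
(flipped,) = Transpose().transpose(image, "Flip horizontal")
print(flipped.shape)  # torch.Size([1, H, W, 3])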
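A similar sketch for the Rotate node, reusing the image tensor and the Rotate import from the snippet above. The node calls PIL's Image.rotate() with expand=False, so the output keeps the input height and width at any angle, and the corners uncovered by the rotation are filled with fill_color (the angle input is a FLOAT widget, stepped in 0.1°). The fill_color string accepts "#RRGGBB" hex, any PIL named color from ImageColor.colormap, or an R,G,B triple; spaces are stripped before parsing.

# Rotate counter-clockwise by 45° with bilinear resampling; expand=False keeps H x W,
# and the exposed corners are filled with the given color.
(rotated,) = Rotate().rotate(image, angle=45.0, resample="Bilinear", fill_color="#336699")

# The fill color may also be a named color or an (R,G,B) triple:
(r1,) = Rotate().rotate(image, angle=30.0, resample="Bicubic", fill_color="white")
(r2,) = Rotate().rotate(image, angle=10.5, resample="Nearest Neighbor", fill_color="(255, 0, 0)")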