mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-30 00:00:26 +08:00
Add expand option to rotate
This commit is contained in:
parent
bc54b69c59
commit
9b40cd3f89
@ -270,6 +270,7 @@ class Rotate:
|
||||
"Bilinear",
|
||||
"Bicubic",
|
||||
],),
|
||||
"expand": (["disabled", "enabled"],),
|
||||
"fill_color": ("STRING", {"default": "#000000"}),
|
||||
},
|
||||
}
|
||||
@ -279,9 +280,8 @@ class Rotate:
|
||||
|
||||
CATEGORY = "image/postprocessing"
|
||||
|
||||
def rotate(self, image: torch.Tensor, angle: int, resample: str, fill_color: str):
|
||||
batch_size, height, width, _ = image.shape
|
||||
result = torch.zeros_like(image)
|
||||
def rotate(self, image: torch.Tensor, angle: int, resample: str, expand: str, fill_color: str):
|
||||
batch_size, _, _, _ = image.shape
|
||||
|
||||
resamplers = {
|
||||
"Nearest Neighbor": Image.Resampling.NEAREST,
|
||||
@ -289,29 +289,30 @@ class Rotate:
|
||||
"Bicubic": Image.Resampling.BICUBIC,
|
||||
}
|
||||
|
||||
for b in range(batch_size):
|
||||
tensor_image = image[b]
|
||||
img = (tensor_image * 255).to(torch.uint8).numpy()
|
||||
pil_image = Image.fromarray(img, mode='RGB')
|
||||
|
||||
fill_color = fill_color or "#000000"
|
||||
tensor_image = image[0]
|
||||
img = (tensor_image * 255).to(torch.uint8).numpy()
|
||||
pil_image = Image.fromarray(img, mode='RGB')
|
||||
|
||||
expand = True if expand == "enabled" else False
|
||||
fill_color = fill_color or "#000000"
|
||||
|
||||
def parse_palette(color_str):
|
||||
if re.match(r'^#[a-fA-F0-9]{6}$', color_str) or color_str.lower() in ImageColor.colormap:
|
||||
return ImageColor.getrgb(color_str)
|
||||
def parse_palette(color_str):
|
||||
if re.match(r'^#[a-fA-F0-9]{6}$', color_str) or color_str.lower() in ImageColor.colormap:
|
||||
return ImageColor.getrgb(color_str)
|
||||
|
||||
color_rgb = re.match(r'^\(?(\d{1,3}),(\d{1,3}),(\d{1,3})\)?$', color_str)
|
||||
if color_rgb and int(color_rgb.group(1)) <= 255 and int(color_rgb.group(2)) <= 255 and int(color_rgb.group(3)) <= 255:
|
||||
return tuple(map(int, re.findall(r'\d{1,3}', color_str)))
|
||||
else:
|
||||
raise ValueError(f"Invalid color format: {color_str}")
|
||||
color_rgb = re.match(r'^\(?(\d{1,3}),(\d{1,3}),(\d{1,3})\)?$', color_str)
|
||||
if color_rgb and int(color_rgb.group(1)) <= 255 and int(color_rgb.group(2)) <= 255 and int(color_rgb.group(3)) <= 255:
|
||||
return tuple(map(int, re.findall(r'\d{1,3}', color_str)))
|
||||
else:
|
||||
raise ValueError(f"Invalid color format: {color_str}")
|
||||
|
||||
color = fill_color.replace(" ", "")
|
||||
color = parse_palette(color)
|
||||
rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=False, fillcolor=color)
|
||||
|
||||
rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255
|
||||
result[b] = rotated_array
|
||||
color = fill_color.replace(" ", "")
|
||||
color = parse_palette(color)
|
||||
rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=expand, fillcolor=color)
|
||||
height, width = rotated_image.size
|
||||
result = torch.zeros(batch_size, height, width, 3)
|
||||
rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255
|
||||
result[0] = rotated_array
|
||||
|
||||
return (result,)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user