diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 90090b104..10fbd77a6 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -290,13 +290,6 @@ class Rotate: "bilinear": Image.Resampling.BILINEAR, "bicubic": Image.Resampling.BICUBIC, } - - tensor_image = image[0] - img = (tensor_image * 255).to(torch.uint8).numpy() - pil_image = Image.fromarray(img, mode='RGB') - - expand = True if expand == "enabled" else False - fill_color = fill_color or "#000000" def parse_palette(color_str): if re.match(r'^#[a-fA-F0-9]{6}$', color_str) or color_str.lower() in ImageColor.colormap: @@ -308,16 +301,24 @@ class Rotate: else: raise ValueError(f"Invalid color format: {color_str}") + expand = True if expand == "enabled" else False + fill_color = fill_color or "#000000" center = (width / 2, height / 2) if center_of_image == "enabled" else (center_x, center_y) translate = (translate_x, translate_y) color = fill_color.replace(" ", "") color = parse_palette(color) - rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=expand, center=center, translate=translate, fillcolor=color) - result_width, result_height = rotated_image.size + + result_width, result_height = Image.new("RGB", (width, height)).rotate(angle=angle, expand=expand, center=center, translate=translate).size result = torch.zeros(batch_size, result_height, result_width, 3) - rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255 - result[0] = rotated_array + + for b in range(batch_size): + tensor_image = image[b] + img = (tensor_image * 255).to(torch.uint8).numpy() + pil_image = Image.fromarray(img, mode='RGB') + rotated_image = pil_image.rotate(angle=angle, resample=resamplers[resample], expand=expand, center=center, translate=translate, fillcolor=color) + rotated_array = torch.tensor(np.array(rotated_image.convert("RGB"))).float() / 255 + result[b] = rotated_array return (result,)