diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 469a7be55..89f6ae33c 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -1240,6 +1240,54 @@ class SaveImageAdvanced(IO.ComfyNode): return IO.NodeOutput(ui={"images": results}) +class ImageGridSlice(IO.ComfyNode): + MAX_ROWS = 16 + MAX_COLUMNS = 16 + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="ImageGridSlice", + display_name="Image Grid Slice", + category="image/transform", + search_aliases=["grid", "slice image", "split image", "crop grid"], + description="Slices an image into a grid of rows x columns tiles, returned as a single batch (row-major).", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("rows", default=2, min=1, max=cls.MAX_ROWS), + IO.Int.Input("columns", default=2, min=1, max=cls.MAX_COLUMNS), + ], + outputs=[ + IO.Image.Output(display_name="images"), + ], + ) + + @classmethod + def execute(cls, image, rows, columns) -> IO.NodeOutput: + rows = max(1, min(rows, cls.MAX_ROWS)) + columns = max(1, min(columns, cls.MAX_COLUMNS)) + + _, height, width, _ = image.shape + tile_height = height // rows + tile_width = width // columns + + tiles = [] + for row in range(rows): + y_start = row * tile_height + for col in range(columns): + x_start = col * tile_width + tiles.append( + image[ + :, + y_start:y_start + tile_height, + x_start:x_start + tile_width, + :, + ] + ) + + return IO.NodeOutput(torch.cat(tiles, dim=0)) + + class ImagesExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: @@ -1262,6 +1310,7 @@ class ImagesExtension(ComfyExtension): ImageScaleToMaxDimension, SplitImageToTileList, ImageMergeTileList, + ImageGridSlice, ]