From 6a2dccf81b351a1eb6c17bc560b3d373267a45b0 Mon Sep 17 00:00:00 2001 From: missionfloyd Date: Fri, 7 Apr 2023 14:37:49 -0600 Subject: [PATCH] Fix updating colors --- comfy_extras/nodes_post_processing.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index f77bf7151..8d51cab21 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -134,17 +134,8 @@ class Quantize: FUNCTION = "quantize" CATEGORY = "image/postprocessing" - - def flatten_list(self, list_of_lists, flat_list=[]): - for item in list_of_lists: - if type(item) == tuple: - self.flatten_list(item, flat_list) - else: - flat_list.append(item) - return flat_list - - def quantize(self, palette, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"): + def quantize(self, palette: str, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"): batch_size, height, width, _ = image.shape result = torch.zeros_like(image) @@ -156,10 +147,18 @@ class Quantize: pil_image = Image.fromarray(img, mode='RGB') if palette: + def flatten_list(list_of_lists, flat_list=[]): + for item in list_of_lists: + if type(item) == tuple: + flatten_list(item, flat_list) + else: + flat_list.append(item) + return flat_list + pal_img = Image.new('P', (1, 1)) pal_colors = palette.replace(" ", "").split(",") pal_colors = map(lambda i: ImageColor.getrgb(i) if re.search("#[a-fA-F0-9]{6}", i) else int(i), pal_colors) - pal_img.putpalette(self.flatten_list(pal_colors)) + pal_img.putpalette(flatten_list(pal_colors)) else: pal_img = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836 quantized_image = pil_image.quantize(colors=colors, palette=pal_img, dither=dither_option)