Add custom palette option to quantize

missionfloyd 2023-04-07 00:37:11 -06:00 committed by GitHub
parent 44fea05064
commit f73c67cb7b


@@ -1,7 +1,8 @@
 import numpy as np
 import torch
 import torch.nn.functional as F
-from PIL import Image
+from PIL import Image, ImageColor
+import re
 
 import comfy.utils
@@ -124,6 +125,7 @@ class Quantize:
                     "max": 256,
                     "step": 1
                 }),
+                "palette": ("STRING", {"default": ""}),
                 "dither": (["none", "floyd-steinberg"],),
             },
         }
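The new palette input is a plain string. Based on the parsing added further down in this diff, it holds a comma-separated list of colors: hex codes are converted to RGB triples with ImageColor.getrgb, bare numbers are kept as raw palette integers, and whitespace is stripped before splitting. A hypothetical value for the field (not taken from the commit):

    palette = "#ff0000, #00ff00, #0000ff"  # spaces are removed, then the string is split on ","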
@@ -133,7 +135,16 @@ class Quantize:
     CATEGORY = "image/postprocessing"
 
-    def quantize(self, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
+    def flatten_list(self, list_of_lists, flat_list=[]):
+        for item in list_of_lists:
+            if type(item) == tuple:
+                self.flatten_list(item, flat_list)
+            else:
+                flat_list.append(item)
+        return flat_list
+
+    def quantize(self, palette, image: torch.Tensor, colors: int = 256, dither: str = "FLOYDSTEINBERG"):
         batch_size, height, width, _ = image.shape
         result = torch.zeros_like(image)
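flatten_list converts the mix of RGB tuples and bare integers produced by the palette parsing in the next hunk into the flat R, G, B, R, G, B, ... sequence that PIL's Image.putpalette expects. A minimal standalone sketch of the same idea (the name is mine, and it handles the single level of nesting the node actually produces; note that the committed method's mutable default argument flat_list=[] means the accumulator persists across calls unless a fresh list is passed in):

    def flatten_colors(colors):
        # e.g. [(255, 0, 0), (0, 255, 0), 128] -> [255, 0, 0, 0, 255, 0, 128]
        flat = []
        for item in colors:
            if isinstance(item, tuple):
                flat.extend(item)
            else:
                flat.append(item)
        return flat

    flatten_colors([(255, 0, 0), (0, 255, 0), 128])  # [255, 0, 0, 0, 255, 0, 128]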
@@ -144,8 +155,14 @@ class Quantize:
             img = (tensor_image * 255).to(torch.uint8).numpy()
             pil_image = Image.fromarray(img, mode='RGB')
 
-            palette = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
-            quantized_image = pil_image.quantize(colors=colors, palette=palette, dither=dither_option)
+            if palette:
+                pal_img = Image.new('P', (1, 1))
+                pal_colors = palette.replace(" ", "").split(",")
+                pal_colors = map(lambda i: ImageColor.getrgb(i) if re.search("#[a-fA-F0-9]{6}", i) else int(i), pal_colors)
+                pal_img.putpalette(self.flatten_list(pal_colors))
+            else:
+                pal_img = pil_image.quantize(colors=colors) # Required as described in https://github.com/python-pillow/Pillow/issues/5836
+            quantized_image = pil_image.quantize(colors=colors, palette=pal_img, dither=dither_option)
 
             quantized_array = torch.tensor(np.array(quantized_image.convert("RGB"))).float() / 255
             result[b] = quantized_array
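Outside the node, the custom-palette path amounts to parsing the string, loading the flattened colors into a throwaway 1x1 'P'-mode image with putpalette, and handing that image to Image.quantize. A rough self-contained sketch under those assumptions (the function name, example palette, and file names are made up; Image.Dither requires a reasonably recent Pillow):

    import re
    from PIL import Image, ImageColor

    def quantize_with_custom_palette(img, palette, dither=Image.Dither.FLOYDSTEINBERG):
        # Hex entries become RGB tuples, bare numbers stay as raw palette values (mirrors the diff).
        entries = palette.replace(" ", "").split(",")
        parsed = [ImageColor.getrgb(e) if re.search("#[a-fA-F0-9]{6}", e) else int(e) for e in entries]

        # Flatten into the R, G, B, R, G, B, ... sequence putpalette() expects.
        flat = [c for item in parsed for c in (item if isinstance(item, tuple) else (item,))]

        pal_img = Image.new('P', (1, 1))  # dummy image; only its palette is used
        pal_img.putpalette(flat)
        return img.quantize(palette=pal_img, dither=dither).convert("RGB")

    # Hypothetical usage:
    # out = quantize_with_custom_palette(Image.open("input.png").convert("RGB"), "#ff0000, #00ff00, #0000ff")
    # out.save("quantized.png")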