PNG cICP chunk support

This commit is contained in:
catboxanon 2025-01-30 15:44:41 -05:00
parent ef85058e97
commit a4aba18d29
2 changed files with 39 additions and 8 deletions

View File

@ -246,7 +246,7 @@ class CLIP:
return self.patcher.get_key_patches() return self.patcher.get_key_patches()
class VAE: class VAE:
def __init__(self, sd=None, device=None, config=None, dtype=None): def __init__(self, sd=None, device=None, config=None, dtype=None, meta=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd) sd = diffusers_convert.convert_vae_state_dict(sd)
@ -416,6 +416,15 @@ class VAE:
self.first_stage_model.to(self.vae_dtype) self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device() self.output_device = model_management.intermediate_device()
self.png_chunks = {}
if meta is not None:
meta_color_space = meta.get("modelspec.color_space")
if str(meta_color_space).lower().startswith("cicp:"):
cicp_chunk = meta_color_space.split("cicp:")[-1].split(",")
cicp_chunk = bytes([1 if b.lower() == 'true' else 0 if b.lower() == 'false' else int(b) for b in cicp_chunk])
self.png_chunks[b"cICP"] = cicp_chunk
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))

View File

@ -11,7 +11,7 @@ import time
import random import random
import logging import logging
from PIL import Image, ImageOps, ImageSequence from PIL import Image, ImageOps, ImageSequence, PngImagePlugin
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
import numpy as np import numpy as np
@ -283,10 +283,12 @@ class VAEDecode:
CATEGORY = "latent" CATEGORY = "latent"
DESCRIPTION = "Decodes latent images back into pixel space images." DESCRIPTION = "Decodes latent images back into pixel space images."
def decode(self, vae, samples): def decode(self, vae: comfy.sd.VAE, samples):
images = vae.decode(samples["samples"]) images = vae.decode(samples["samples"])
if len(images.shape) == 5: #Combine batches if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
if vae.png_chunks is not None:
images.png_chunks = vae.png_chunks
return (images, ) return (images, )
class VAEDecodeTiled: class VAEDecodeTiled:
@ -769,7 +771,8 @@ class VAELoader:
else: else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name) vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path) sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd) meta = json.loads(comfy.utils.safetensors_header(vae_path, max_size=1024*1024) or "{}").get("__metadata__")
vae = comfy.sd.VAE(sd=sd, meta=meta)
return (vae,) return (vae,)
class ControlNetLoader: class ControlNetLoader:
@ -1576,6 +1579,7 @@ class SaveImage:
self.type = "output" self.type = "output"
self.prefix_append = "" self.prefix_append = ""
self.compress_level = 4 self.compress_level = 4
self.extra_chunks = [b"cICP"]
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -1597,6 +1601,13 @@ class SaveImage:
CATEGORY = "image" CATEGORY = "image"
DESCRIPTION = "Saves the input images to your ComfyUI output directory." DESCRIPTION = "Saves the input images to your ComfyUI output directory."
def putchunk_patched(self, fp, cid, *data):
for chunk in self.extra_chunks:
if cid == chunk.lower():
cid = chunk
break
return PngImagePlugin.putchunk(fp, cid, *data)
def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): def save_images(self, images, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
filename_prefix += self.prefix_append filename_prefix += self.prefix_append
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0]) full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir, images[0].shape[1], images[0].shape[0])
@ -1604,18 +1615,28 @@ class SaveImage:
for (batch_number, image) in enumerate(images): for (batch_number, image) in enumerate(images):
i = 255. * image.cpu().numpy() i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
metadata = None metadata = PngInfo()
if not args.disable_metadata: if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None: if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt)) metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None: if extra_pnginfo is not None:
for x in extra_pnginfo: for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x])) metadata.add_text(x, json.dumps(extra_pnginfo[x]))
if hasattr(images, "png_chunks"):
for name, data in images.png_chunks.items():
if name in self.extra_chunks:
metadata.add(name.lower(), data)
else:
metadata.add(name, data)
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png" file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)
#TODO: revert to using img.save once Pillow supports cICP chunk
img.encoderinfo = {"pnginfo": metadata, "compress_level": self.compress_level}
with open(os.path.join(full_output_folder, file), 'wb') as fp:
PngImagePlugin._save(img, fp, None, chunk=self.putchunk_patched)
results.append({ results.append({
"filename": file, "filename": file,
"subfolder": subfolder, "subfolder": subfolder,
@ -1627,6 +1648,7 @@ class SaveImage:
class PreviewImage(SaveImage): class PreviewImage(SaveImage):
def __init__(self): def __init__(self):
super().__init__()
self.output_dir = folder_paths.get_temp_directory() self.output_dir = folder_paths.get_temp_directory()
self.type = "temp" self.type = "temp"
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))