This commit is contained in:
catboxanon 2025-12-30 18:00:42 +08:00 committed by GitHub
commit 60745581ae
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 3 deletions

View File

@ -679,6 +679,15 @@ class VAE:
self.first_stage_model.to(self.vae_dtype)
self.output_device = model_management.intermediate_device()
self.png_chunks = {}
if metadata is not None:
meta_color_space = metadata.get("modelspec.color_space")
if str(meta_color_space).lower().startswith("cicp:"):
cicp_chunk = meta_color_space.split("cicp:")[-1].split(",")
cicp_chunk = bytes([1 if b.lower() == 'true' else 0 if b.lower() == 'false' else int(b) for b in cicp_chunk])
self.png_chunks[b"cICP"] = cicp_chunk
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
self.model_size()

View File

@ -294,10 +294,12 @@ class VAEDecode:
CATEGORY = "latent"
DESCRIPTION = "Decodes latent images back into pixel space images."
def decode(self, vae, samples):
def decode(self, vae: comfy.sd.VAE, samples):
images = vae.decode(samples["samples"])
if len(images.shape) == 5: #Combine batches
images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
if vae.png_chunks is not None:
images.png_chunks = vae.png_chunks
return (images, )
class VAEDecodeTiled:
@ -794,7 +796,8 @@ class VAELoader:
else:
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
sd = comfy.utils.load_torch_file(vae_path)
vae = comfy.sd.VAE(sd=sd)
metadata = json.loads(comfy.utils.safetensors_header(vae_path, max_size=1024*1024) or "{}").get("__metadata__")
vae = comfy.sd.VAE(sd=sd, metadata=metadata)
vae.throw_exception_if_invalid()
return (vae,)
@ -1613,7 +1616,9 @@ class SaveImage:
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
if hasattr(images, "png_chunks"):
for name, data in images.png_chunks.items():
metadata.add(name, data)
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
file = f"{filename_with_batch_num}_{counter:05}_.png"
img.save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=self.compress_level)