From 3564ee85a60fb556062acb63d06a0efe1505bca2 Mon Sep 17 00:00:00 2001 From: ltdrdata Date: Thu, 18 May 2023 23:49:09 +0900 Subject: [PATCH] support preview latent --- nodes.py | 99 ++++++++++++++++++++++++++++++++++++++++++++++-- requirements.txt | 1 + 2 files changed, 96 insertions(+), 4 deletions(-) diff --git a/nodes.py b/nodes.py index 3c61cd2ec..14501e74c 100644 --- a/nodes.py +++ b/nodes.py @@ -12,6 +12,9 @@ from PIL import Image from PIL.PngImagePlugin import PngInfo import numpy as np import safetensors.torch +from io import BytesIO +import piexif +import zipfile sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -290,11 +293,79 @@ class SaveLatent: return {} +class SavePreviewLatent(SaveLatent): + def __init__(self): + self.output_dir = folder_paths.get_output_directory() + + @classmethod + def INPUT_TYPES(s): + return {"required": {"samples": ("LATENT", ), + "filename_prefix": ("STRING", {"default": "latents/ComfyUI"}), }, + "optional": {"image": ("IMAGE", ), }, + "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, + } + RETURN_TYPES = () + FUNCTION = "save_preview_latent" + + OUTPUT_NODE = True + + CATEGORY = "_for_testing" + + @staticmethod + def save_to_file(tensor_bytes, prompt, extra_pnginfo, image, image_path): + compressed_data = BytesIO() + with zipfile.ZipFile(compressed_data, mode='w') as archive: + archive.writestr("latent", tensor_bytes) + image = image.copy() + exif_data = {"Exif": {piexif.ExifIFD.UserComment: compressed_data.getvalue()}} + + metadata = PngInfo() + if prompt is not None: + metadata.add_text("prompt", json.dumps(prompt)) + if extra_pnginfo is not None: + for x in extra_pnginfo: + metadata.add_text(x, json.dumps(extra_pnginfo[x])) + + exif_bytes = piexif.dump(exif_data) + image.save(image_path, format='png', exif=exif_bytes, pnginfo=metadata) + + @staticmethod + def load_preview(image): + if image is None: + comfy_path = os.path.dirname(__file__) + image_path = 
os.path.join(comfy_path, "logo.png") + return Image.open(image_path) + else: + i = 255. * image[0].cpu().numpy() + image = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8)) + return image + + def save_preview_latent(self, samples, filename_prefix="ComfyUI", image=None, prompt=None, extra_pnginfo=None): + full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) + + # load preview + preview = SavePreviewLatent.load_preview(image) + + # support save metadata for latent sharing + file = f"{filename}_{counter:05}_.latent.png" + file = os.path.join(full_output_folder, file) + + output = {"latent_tensor": samples["samples"]} + + tensor_bytes = safetensors.torch.save(output) + SavePreviewLatent.save_to_file(tensor_bytes, prompt, extra_pnginfo, preview, file) + + return {} + + class LoadLatent: @classmethod def INPUT_TYPES(s): + def check_file_extension(x): + return x.endswith(".latent") or x.endswith(".latent.png") + input_dir = folder_paths.get_input_directory() - files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and check_file_extension(f)] return {"required": {"latent": [sorted(files), ]}, } CATEGORY = "_for_testing" @@ -302,10 +373,29 @@ class LoadLatent: RETURN_TYPES = ("LATENT", ) FUNCTION = "load" + @staticmethod + def load_preview_latent(image_path): + image = Image.open(image_path) + exif_data = piexif.load(image.info["exif"]) + + if piexif.ExifIFD.UserComment in exif_data["Exif"]: + compressed_data = exif_data["Exif"][piexif.ExifIFD.UserComment] + compressed_data_io = BytesIO(compressed_data) + with zipfile.ZipFile(compressed_data_io, mode='r') as archive: + tensor_bytes = archive.read("latent") + tensor = safetensors.torch.load(tensor_bytes) + return {"samples": tensor['latent_tensor']} + return None + + def load(self, latent): 
latent_path = folder_paths.get_annotated_filepath(latent) - latent = safetensors.torch.load_file(latent_path, device="cpu") - samples = {"samples": latent["latent_tensor"]} + + if latent.endswith(".latent"): + latent = safetensors.torch.load_file(latent_path, device="cpu") + samples = {"samples": latent["latent_tensor"]} + else: + samples = LoadLatent.load_preview_latent(latent_path) + return (samples, ) @classmethod @@ -1282,7 +1372,8 @@ NODE_CLASS_MAPPINGS = { "DiffusersLoader": DiffusersLoader, "LoadLatent": LoadLatent, - "SaveLatent": SaveLatent + "SaveLatent": SaveLatent, + "SavePreviewLatent": SavePreviewLatent } NODE_DISPLAY_NAME_MAPPINGS = { diff --git a/requirements.txt b/requirements.txt index 0527b31df..ccd8863cd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ pytorch_lightning aiohttp accelerate pyyaml +piexif \ No newline at end of file