support preview latent

This commit is contained in:
ltdrdata 2023-05-18 23:49:09 +09:00
parent 62a371e12b
commit 3564ee85a6
2 changed files with 96 additions and 4 deletions

View File

@ -12,6 +12,9 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
import numpy as np import numpy as np
import safetensors.torch import safetensors.torch
from io import BytesIO
import piexif
import zipfile
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
@ -290,11 +293,79 @@ class SaveLatent:
return {} return {}
class SavePreviewLatent(SaveLatent):
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT", ),
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"}), },
"optional": {"image": ("IMAGE", ), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_preview_latent"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
@staticmethod
def save_to_file(tensor_bytes, prompt, extra_pnginfo, image, image_path):
compressed_data = BytesIO()
with zipfile.ZipFile(compressed_data, mode='w') as archive:
archive.writestr("latent", tensor_bytes)
image = image.copy()
exif_data = {"Exif": {piexif.ExifIFD.UserComment: compressed_data.getvalue()}}
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
exif_bytes = piexif.dump(exif_data)
image.save(image_path, format='png', exif=exif_bytes, pnginfo=metadata)
@staticmethod
def load_preview(image):
if image is None:
comfy_path = os.path.dirname(__file__)
image_path = os.path.join(comfy_path, "logo.png")
return Image.open(image_path)
else:
i = 255. * image[0].cpu().numpy()
image = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
return image
def save_preview_latent(self, samples, filename_prefix="ComfyUI", image=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
# load preview
preview = SavePreviewLatent.load_preview(image)
# support save metadata for latent sharing
file = f"{filename}_{counter:05}_.latent.png"
file = os.path.join(full_output_folder, file)
output = {"latent_tensor": samples["samples"]}
tensor_bytes = safetensors.torch.save(output)
SavePreviewLatent.save_to_file(tensor_bytes, prompt, extra_pnginfo, preview, file)
return {}
class LoadLatent: class LoadLatent:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
def check_file_extension(x):
return x.endswith(".latent") or x.endswith(".latent.png")
input_dir = folder_paths.get_input_directory() input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")] files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and check_file_extension(f)]
return {"required": {"latent": [sorted(files), ]}, } return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "_for_testing" CATEGORY = "_for_testing"
@ -302,10 +373,29 @@ class LoadLatent:
RETURN_TYPES = ("LATENT", ) RETURN_TYPES = ("LATENT", )
FUNCTION = "load" FUNCTION = "load"
@staticmethod
def load_preview_latent(image_path):
image = Image.open(image_path)
exif_data = piexif.load(image.info["exif"])
if piexif.ExifIFD.UserComment in exif_data["Exif"]:
compressed_data = exif_data["Exif"][piexif.ExifIFD.UserComment]
compressed_data_io = BytesIO(compressed_data)
with zipfile.ZipFile(compressed_data_io, mode='r') as archive:
tensor_bytes = archive.read("latent")
tensor = safetensors.torch.load(tensor_bytes)
return {"samples": tensor['latent_tensor']}
return None
def load(self, latent): def load(self, latent):
latent_path = folder_paths.get_annotated_filepath(latent) latent_path = folder_paths.get_annotated_filepath(latent)
latent = safetensors.torch.load_file(latent_path, device="cpu")
samples = {"samples": latent["latent_tensor"]} if latent.endswith(".latent"):
latent = safetensors.torch.load_file(latent_path, device="cpu")
samples = {"samples": latent["latent_tensor"]}
else:
samples = LoadLatent.load_preview_latent(latent_path)
return (samples, ) return (samples, )
@classmethod @classmethod
@ -1282,7 +1372,8 @@ NODE_CLASS_MAPPINGS = {
"DiffusersLoader": DiffusersLoader, "DiffusersLoader": DiffusersLoader,
"LoadLatent": LoadLatent, "LoadLatent": LoadLatent,
"SaveLatent": SaveLatent "SaveLatent": SaveLatent,
"SavePreviewLatent": SavePreviewLatent
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {

View File

@ -9,3 +9,4 @@ pytorch_lightning
aiohttp aiohttp
accelerate accelerate
pyyaml pyyaml
piexif