support preview latent

This commit is contained in:
ltdrdata 2023-05-18 23:49:09 +09:00
parent 62a371e12b
commit 3564ee85a6
2 changed files with 96 additions and 4 deletions

View File

@ -12,6 +12,9 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import safetensors.torch
from io import BytesIO
import piexif
import zipfile
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
@ -290,11 +293,79 @@ class SaveLatent:
return {}
class SavePreviewLatent(SaveLatent):
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": {"samples": ("LATENT", ),
"filename_prefix": ("STRING", {"default": "latents/ComfyUI"}), },
"optional": {"image": ("IMAGE", ), },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
FUNCTION = "save_preview_latent"
OUTPUT_NODE = True
CATEGORY = "_for_testing"
@staticmethod
def save_to_file(tensor_bytes, prompt, extra_pnginfo, image, image_path):
compressed_data = BytesIO()
with zipfile.ZipFile(compressed_data, mode='w') as archive:
archive.writestr("latent", tensor_bytes)
image = image.copy()
exif_data = {"Exif": {piexif.ExifIFD.UserComment: compressed_data.getvalue()}}
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))
exif_bytes = piexif.dump(exif_data)
image.save(image_path, format='png', exif=exif_bytes, pnginfo=metadata)
@staticmethod
def load_preview(image):
if image is None:
comfy_path = os.path.dirname(__file__)
image_path = os.path.join(comfy_path, "logo.png")
return Image.open(image_path)
else:
i = 255. * image[0].cpu().numpy()
image = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
return image
def save_preview_latent(self, samples, filename_prefix="ComfyUI", image=None, prompt=None, extra_pnginfo=None):
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
# load preview
preview = SavePreviewLatent.load_preview(image)
# support save metadata for latent sharing
file = f"{filename}_{counter:05}_.latent.png"
file = os.path.join(full_output_folder, file)
output = {"latent_tensor": samples["samples"]}
tensor_bytes = safetensors.torch.save(output)
SavePreviewLatent.save_to_file(tensor_bytes, prompt, extra_pnginfo, preview, file)
return {}
class LoadLatent:
@classmethod
def INPUT_TYPES(s):
def check_file_extension(x):
return x.endswith(".latent") or x.endswith(".latent.png")
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(".latent")]
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and check_file_extension(f)]
return {"required": {"latent": [sorted(files), ]}, }
CATEGORY = "_for_testing"
@ -302,10 +373,29 @@ class LoadLatent:
RETURN_TYPES = ("LATENT", )
FUNCTION = "load"
@staticmethod
def load_preview_latent(image_path):
image = Image.open(image_path)
exif_data = piexif.load(image.info["exif"])
if piexif.ExifIFD.UserComment in exif_data["Exif"]:
compressed_data = exif_data["Exif"][piexif.ExifIFD.UserComment]
compressed_data_io = BytesIO(compressed_data)
with zipfile.ZipFile(compressed_data_io, mode='r') as archive:
tensor_bytes = archive.read("latent")
tensor = safetensors.torch.load(tensor_bytes)
return {"samples": tensor['latent_tensor']}
return None
def load(self, latent):
latent_path = folder_paths.get_annotated_filepath(latent)
latent = safetensors.torch.load_file(latent_path, device="cpu")
samples = {"samples": latent["latent_tensor"]}
if latent.endswith(".latent"):
latent = safetensors.torch.load_file(latent_path, device="cpu")
samples = {"samples": latent["latent_tensor"]}
else:
samples = LoadLatent.load_preview_latent(latent_path)
return (samples, )
@classmethod
@ -1282,7 +1372,8 @@ NODE_CLASS_MAPPINGS = {
"DiffusersLoader": DiffusersLoader,
"LoadLatent": LoadLatent,
"SaveLatent": SaveLatent
"SaveLatent": SaveLatent,
"SavePreviewLatent": SavePreviewLatent
}
NODE_DISPLAY_NAME_MAPPINGS = {

View File

@ -9,3 +9,4 @@ pytorch_lightning
aiohttp
accelerate
pyyaml
piexif