diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 5fedd901a..6704b5131 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -111,7 +111,7 @@ def create_parser() -> argparse.ArgumentParser: parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.") - parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, + parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto, help="Default preview method for sampler nodes.", action=EnumAction) attn_group = parser.add_mutually_exclusive_group() diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index 5d1c8f8cf..4275f0b7a 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -1,6 +1,6 @@ # Define a class for your command-line arguments import enum -from typing import Optional, List, Callable +from typing import Optional, List, Callable, Literal import configargparse as argparse ConfigurationExtender = Callable[[argparse.ArgParser], Optional[argparse.ArgParser]] @@ -50,7 +50,7 @@ class Configuration(dict): fp32_text_enc (bool): Use FP32 precision for the text encoder. directml (Optional[int]): Use DirectML. -1 for auto-selection. disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs. - preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "none". + preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "auto". use_split_cross_attention (bool): Use split cross-attention optimization. use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization. use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function. 
@@ -116,7 +116,7 @@ class Configuration(dict): self.fp32_text_enc: bool = False self.directml: Optional[int] = None self.disable_ipex_optimize: bool = False - self.preview_method: str = "none" + self.preview_method: LatentPreviewMethod = LatentPreviewMethod.Auto self.use_split_cross_attention: bool = False self.use_quad_cross_attention: bool = False self.use_pytorch_cross_attention: bool = False diff --git a/comfy/cmd/latent_preview.py b/comfy/cmd/latent_preview.py index 50affe00a..814ca2864 100644 --- a/comfy/cmd/latent_preview.py +++ b/comfy/cmd/latent_preview.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import torch from PIL import Image import numpy as np from ..cli_args import args from ..cli_args_types import LatentPreviewMethod +from ..model_downloader import get_or_download, KNOWN_APPROX_VAES from ..taesd.taesd import TAESD from ..cmd import folder_paths from .. import utils @@ -59,7 +62,7 @@ def get_previewer(device, latent_format): if fn.startswith(latent_format.taesd_decoder_name)), "" ) - taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path) + taesd_decoder_path = get_or_download("vae_approx", taesd_decoder_path, KNOWN_APPROX_VAES) if method == LatentPreviewMethod.Auto: method = LatentPreviewMethod.Latent2RGB diff --git a/comfy/cmd/latent_preview_image_encoding.py b/comfy/cmd/latent_preview_image_encoding.py new file mode 100644 index 000000000..041f17d8b --- /dev/null +++ b/comfy/cmd/latent_preview_image_encoding.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +import struct +from io import BytesIO +from typing import Literal + +import PIL.Image +from PIL import Image, ImageOps + + +def encode_preview_image(image: PIL.Image.Image, image_type: Literal["JPEG", "PNG"], max_size: int | None): + """Encode a PIL image as preview bytes: a 4-byte big-endian type header (1 for JPEG, 2 for PNG) followed by the encoded image. + + If max_size is not None, the image is first passed through ImageOps.contain with a (max_size, max_size) box. + """ + if max_size is not None: + if hasattr(Image, 'Resampling'): + resampling = Image.Resampling.BILINEAR + else: + # Image.Resampling does not exist on this branch, so read the filter constant from the Image module itself. + resampling = Image.LANCZOS + + image = ImageOps.contain(image, (max_size, max_size), resampling) + 
type_num = 1 + if image_type == "JPEG": + type_num = 1 + elif image_type == "PNG": + type_num = 2 + bytesIO = BytesIO() + header = struct.pack(">I", type_num) + bytesIO.write(header) + image.save(bytesIO, format=image_type, quality=95, compress_level=1) + preview_bytes = bytesIO.getvalue() + return preview_bytes diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 6beddd324..79b89bd4a 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -11,21 +11,24 @@ import sys import traceback import uuid from asyncio import Future, AbstractEventLoop +from enum import Enum from io import BytesIO from typing import List, Optional, Dict from urllib.parse import quote, urlencode from posixpath import join as urljoin + from can_ada import URL, parse as urlparse import aiofiles import aiohttp -from PIL import Image, ImageOps +from PIL import Image from PIL.PngImagePlugin import PngInfo from aiohttp import web from pkg_resources import resource_filename from typing_extensions import NamedTuple import comfy.interruption +from .latent_preview_image_encoding import encode_preview_image from .. import model_management from .. 
import utils from ..app.user_manager import UserManager @@ -713,8 +716,15 @@ class PromptServer(ExecutorToClientProgress): else: await self.send_json(event, data, sid) - def encode_bytes(self, event, data): - if not isinstance(event, int): + def encode_bytes(self, event: int | Enum | str, data): + # todo: investigate what is propagating these spurious, string-repr'd previews + if event == repr(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE): + event = BinaryEventTypes.UNENCODED_PREVIEW_IMAGE.value + elif event == repr(BinaryEventTypes.PREVIEW_IMAGE): + event = BinaryEventTypes.PREVIEW_IMAGE.value + elif isinstance(event, Enum): + event: int = event.value + elif not isinstance(event, int): raise RuntimeError(f"Binary event types must be integers, got {event}") packed = struct.pack(">I", event) @@ -726,24 +736,7 @@ class PromptServer(ExecutorToClientProgress): image_type = image_data[0] image = image_data[1] max_size = image_data[2] - if max_size is not None: - if hasattr(Image, 'Resampling'): - resampling = Image.Resampling.BILINEAR - else: - resampling = Image.Resampling.LANCZOS - - image = ImageOps.contain(image, (max_size, max_size), resampling) - type_num = 1 - if image_type == "JPEG": - type_num = 1 - elif image_type == "PNG": - type_num = 2 - - bytesIO = BytesIO() - header = struct.pack(">I", type_num) - bytesIO.write(header) - image.save(bytesIO, format=image_type, quality=95, compress_level=1) - preview_bytes = bytesIO.getvalue() + preview_bytes = encode_preview_image(image, image_type, max_size) await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid) async def send_bytes(self, event, data, sid=None): diff --git a/comfy/component_model/executor_types.py b/comfy/component_model/executor_types.py index 38523aac1..6dfc1e4be 100644 --- a/comfy/component_model/executor_types.py +++ b/comfy/component_model/executor_types.py @@ -1,7 +1,8 @@ from __future__ import annotations # for Python 3.7-3.9 +import PIL.Image from typing_extensions import 
NotRequired, TypedDict -from typing import Optional, Literal, Protocol, TypeAlias, Union +from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple from .queue_types import BinaryEventTypes @@ -34,11 +35,17 @@ class ProgressMessage(TypedDict): sid: NotRequired[str] +class UnencodedPreviewImageMessage(NamedTuple): + format: Literal["JPEG", "PNG"] + pil_image: PIL.Image.Image + max_size: int = 512 + + ExecutedMessage: TypeAlias = ExecutingMessage SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None] -SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, bytes, bytearray, None] +SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None] class ExecutorToClientProgress(Protocol): diff --git a/comfy/distributed/distributed_progress.py b/comfy/distributed/distributed_progress.py index 6e971994d..264f385cd 100644 --- a/comfy/distributed/distributed_progress.py +++ b/comfy/distributed/distributed_progress.py @@ -1,14 +1,16 @@ from __future__ import annotations import asyncio +import base64 from asyncio import AbstractEventLoop +from enum import Enum from functools import partial - from typing import Optional, Dict, Any from aio_pika.patterns import RPC -from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress +from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \ + UnencodedPreviewImageMessage from ..component_model.queue_types import BinaryEventTypes from ..utils import hijack_progress @@ -17,6 +19,8 @@ async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[ caller_server: Optional[ExecutorToClientProgress] = None) -> None: assert caller_server is not None assert user_id is not None + if event == BinaryEventTypes.PREVIEW_IMAGE or event == 
BinaryEventTypes.UNENCODED_PREVIEW_IMAGE or isinstance(data, str): + data: bytes = base64.b64decode(data) caller_server.send_sync(event, data, sid=user_id) @@ -39,11 +43,19 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress): async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None: # for now, do not send binary data this way, since it cannot be json serialized / it's impractical - if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: - return + if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: + from ..cmd.latent_preview_image_encoding import encode_preview_image + + # encode preview image + event = BinaryEventTypes.PREVIEW_IMAGE.value + data: UnencodedPreviewImageMessage + format, pil_image, max_size = data + data: bytes = encode_preview_image(pil_image, format, max_size) if isinstance(data, bytes) or isinstance(data, bytearray): - return + if isinstance(event, Enum): + event: int = event.value + data: str = base64.b64encode(data).decode() if user_id is None: # todo: user_id should never be none here diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py index 82b5504d1..613c8d867 100644 --- a/comfy/distributed/distributed_prompt_queue.py +++ b/comfy/distributed/distributed_prompt_queue.py @@ -15,13 +15,13 @@ from aio_pika.patterns import JsonRPC from .distributed_progress import ProgressHandlers from .distributed_types import RpcRequest, RpcReply +from .history import History from .server_stub import ServerStub from ..auth.permissions import jwt_decode +from ..cmd.server import PromptServer from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation -from .history import History -from 
..cmd.server import PromptServer class DistributedPromptQueue(AbstractPromptQueue): diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 23e2d29a7..8b76ab920 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -281,6 +281,11 @@ KNOWN_DIFF_CONTROLNETS = [ HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"), ] +KNOWN_APPROX_VAES = [ + HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"), + HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors") +] + def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]: if args.disable_known_models: