mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Enable previews by default and over distributed channels
This commit is contained in:
parent
37cca051b6
commit
e49c662c7f
@ -111,7 +111,7 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
parser.add_argument("--disable-ipex-optimize", action="store_true",
|
||||
help="Disables ipex.optimize when loading models with Intel GPUs.")
|
||||
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews,
|
||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto,
|
||||
help="Default preview method for sampler nodes.", action=EnumAction)
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
# Define a class for your command-line arguments
|
||||
import enum
|
||||
from typing import Optional, List, Callable
|
||||
from typing import Optional, List, Callable, Literal
|
||||
import configargparse as argparse
|
||||
|
||||
ConfigurationExtender = Callable[[argparse.ArgParser], Optional[argparse.ArgParser]]
|
||||
@ -50,7 +50,7 @@ class Configuration(dict):
|
||||
fp32_text_enc (bool): Use FP32 precision for the text encoder.
|
||||
directml (Optional[int]): Use DirectML. -1 for auto-selection.
|
||||
disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs.
|
||||
preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "none".
|
||||
preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "auto".
|
||||
use_split_cross_attention (bool): Use split cross-attention optimization.
|
||||
use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization.
|
||||
use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function.
|
||||
@ -116,7 +116,7 @@ class Configuration(dict):
|
||||
self.fp32_text_enc: bool = False
|
||||
self.directml: Optional[int] = None
|
||||
self.disable_ipex_optimize: bool = False
|
||||
self.preview_method: str = "none"
|
||||
self.preview_method: LatentPreviewMethod = LatentPreviewMethod.Auto
|
||||
self.use_split_cross_attention: bool = False
|
||||
self.use_quad_cross_attention: bool = False
|
||||
self.use_pytorch_cross_attention: bool = False
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
from ..cli_args import args
|
||||
from ..cli_args_types import LatentPreviewMethod
|
||||
from ..model_downloader import get_or_download, KNOWN_APPROX_VAES
|
||||
from ..taesd.taesd import TAESD
|
||||
from ..cmd import folder_paths
|
||||
from .. import utils
|
||||
@ -59,7 +62,7 @@ def get_previewer(device, latent_format):
|
||||
if fn.startswith(latent_format.taesd_decoder_name)),
|
||||
""
|
||||
)
|
||||
taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
|
||||
taesd_decoder_path = get_or_download("vae_approx", taesd_decoder_path, KNOWN_APPROX_VAES)
|
||||
|
||||
if method == LatentPreviewMethod.Auto:
|
||||
method = LatentPreviewMethod.Latent2RGB
|
||||
|
||||
29
comfy/cmd/latent_preview_image_encoding.py
Normal file
29
comfy/cmd/latent_preview_image_encoding.py
Normal file
@ -0,0 +1,29 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import struct
|
||||
from io import BytesIO
|
||||
from typing import Literal
|
||||
|
||||
import PIL.Image
|
||||
from PIL import Image, ImageOps
|
||||
|
||||
|
||||
def encode_preview_image(image: PIL.Image.Image, image_type: Literal["JPEG", "PNG"], max_size: int):
|
||||
if max_size is not None:
|
||||
if hasattr(Image, 'Resampling'):
|
||||
resampling = Image.Resampling.BILINEAR
|
||||
else:
|
||||
resampling = Image.Resampling.LANCZOS
|
||||
|
||||
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||
type_num = 1
|
||||
if image_type == "JPEG":
|
||||
type_num = 1
|
||||
elif image_type == "PNG":
|
||||
type_num = 2
|
||||
bytesIO = BytesIO()
|
||||
header = struct.pack(">I", type_num)
|
||||
bytesIO.write(header)
|
||||
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
||||
preview_bytes = bytesIO.getvalue()
|
||||
return preview_bytes
|
||||
@ -11,21 +11,24 @@ import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from asyncio import Future, AbstractEventLoop
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import List, Optional, Dict
|
||||
from urllib.parse import quote, urlencode
|
||||
from posixpath import join as urljoin
|
||||
|
||||
from can_ada import URL, parse as urlparse
|
||||
|
||||
import aiofiles
|
||||
import aiohttp
|
||||
from PIL import Image, ImageOps
|
||||
from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from aiohttp import web
|
||||
from pkg_resources import resource_filename
|
||||
from typing_extensions import NamedTuple
|
||||
|
||||
import comfy.interruption
|
||||
from .latent_preview_image_encoding import encode_preview_image
|
||||
from .. import model_management
|
||||
from .. import utils
|
||||
from ..app.user_manager import UserManager
|
||||
@ -713,8 +716,15 @@ class PromptServer(ExecutorToClientProgress):
|
||||
else:
|
||||
await self.send_json(event, data, sid)
|
||||
|
||||
def encode_bytes(self, event, data):
|
||||
if not isinstance(event, int):
|
||||
def encode_bytes(self, event: int | Enum | str, data):
|
||||
# todo: investigate what is propagating these spurious, string-repr'd previews
|
||||
if event == repr(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE):
|
||||
event = BinaryEventTypes.UNENCODED_PREVIEW_IMAGE.value
|
||||
elif event == repr(BinaryEventTypes.PREVIEW_IMAGE):
|
||||
event = BinaryEventTypes.PREVIEW_IMAGE.value
|
||||
elif isinstance(event, Enum):
|
||||
event: int = event.value
|
||||
elif not isinstance(event, int):
|
||||
raise RuntimeError(f"Binary event types must be integers, got {event}")
|
||||
|
||||
packed = struct.pack(">I", event)
|
||||
@ -726,24 +736,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
image_type = image_data[0]
|
||||
image = image_data[1]
|
||||
max_size = image_data[2]
|
||||
if max_size is not None:
|
||||
if hasattr(Image, 'Resampling'):
|
||||
resampling = Image.Resampling.BILINEAR
|
||||
else:
|
||||
resampling = Image.Resampling.LANCZOS
|
||||
|
||||
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||
type_num = 1
|
||||
if image_type == "JPEG":
|
||||
type_num = 1
|
||||
elif image_type == "PNG":
|
||||
type_num = 2
|
||||
|
||||
bytesIO = BytesIO()
|
||||
header = struct.pack(">I", type_num)
|
||||
bytesIO.write(header)
|
||||
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
||||
preview_bytes = bytesIO.getvalue()
|
||||
preview_bytes = encode_preview_image(image, image_type, max_size)
|
||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||
|
||||
async def send_bytes(self, event, data, sid=None):
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
from __future__ import annotations # for Python 3.7-3.9
|
||||
|
||||
import PIL.Image
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
from typing import Optional, Literal, Protocol, TypeAlias, Union
|
||||
from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
|
||||
|
||||
from .queue_types import BinaryEventTypes
|
||||
|
||||
@ -34,11 +35,17 @@ class ProgressMessage(TypedDict):
|
||||
sid: NotRequired[str]
|
||||
|
||||
|
||||
class UnencodedPreviewImageMessage(NamedTuple):
|
||||
format: Literal["JPEG", "PNG"]
|
||||
pil_image: PIL.Image.Image
|
||||
max_size: int = 512
|
||||
|
||||
|
||||
ExecutedMessage: TypeAlias = ExecutingMessage
|
||||
|
||||
SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None]
|
||||
|
||||
SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, bytes, bytearray, None]
|
||||
SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
|
||||
|
||||
|
||||
class ExecutorToClientProgress(Protocol):
|
||||
|
||||
@ -1,14 +1,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from asyncio import AbstractEventLoop
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from aio_pika.patterns import RPC
|
||||
|
||||
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress
|
||||
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
|
||||
UnencodedPreviewImageMessage
|
||||
from ..component_model.queue_types import BinaryEventTypes
|
||||
from ..utils import hijack_progress
|
||||
|
||||
@ -17,6 +19,8 @@ async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[
|
||||
caller_server: Optional[ExecutorToClientProgress] = None) -> None:
|
||||
assert caller_server is not None
|
||||
assert user_id is not None
|
||||
if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE or isinstance(data, str):
|
||||
data: bytes = base64.b64decode(data)
|
||||
caller_server.send_sync(event, data, sid=user_id)
|
||||
|
||||
|
||||
@ -39,11 +43,19 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
||||
|
||||
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
|
||||
# for now, do not send binary data this way, since it cannot be json serialized / it's impractical
|
||||
if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||
return
|
||||
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||
from ..cmd.latent_preview_image_encoding import encode_preview_image
|
||||
|
||||
# encode preview image
|
||||
event = BinaryEventTypes.PREVIEW_IMAGE.value
|
||||
data: UnencodedPreviewImageMessage
|
||||
format, pil_image, max_size = data
|
||||
data: bytes = encode_preview_image(pil_image, format, max_size)
|
||||
|
||||
if isinstance(data, bytes) or isinstance(data, bytearray):
|
||||
return
|
||||
if isinstance(event, Enum):
|
||||
event: int = event.value
|
||||
data: str = base64.b64encode(data).decode()
|
||||
|
||||
if user_id is None:
|
||||
# todo: user_id should never be none here
|
||||
|
||||
@ -15,13 +15,13 @@ from aio_pika.patterns import JsonRPC
|
||||
|
||||
from .distributed_progress import ProgressHandlers
|
||||
from .distributed_types import RpcRequest, RpcReply
|
||||
from .history import History
|
||||
from .server_stub import ServerStub
|
||||
from ..auth.permissions import jwt_decode
|
||||
from ..cmd.server import PromptServer
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
|
||||
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
|
||||
from .history import History
|
||||
from ..cmd.server import PromptServer
|
||||
|
||||
|
||||
class DistributedPromptQueue(AbstractPromptQueue):
|
||||
|
||||
@ -281,6 +281,11 @@ KNOWN_DIFF_CONTROLNETS = [
|
||||
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"),
|
||||
]
|
||||
|
||||
KNOWN_APPROX_VAES = [
|
||||
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
|
||||
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors")
|
||||
]
|
||||
|
||||
|
||||
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
|
||||
if args.disable_known_models:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user