Enable previews by default and over distributed channels

doctorpangloss 2024-04-09 13:15:05 -07:00
parent 37cca051b6
commit e49c662c7f
9 changed files with 84 additions and 35 deletions

View File

@@ -111,7 +111,7 @@ def create_parser() -> argparse.ArgumentParser:
    parser.add_argument("--disable-ipex-optimize", action="store_true",
                        help="Disables ipex.optimize when loading models with Intel GPUs.")
    parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews,
    parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto,
                        help="Default preview method for sampler nodes.", action=EnumAction)

    attn_group = parser.add_mutually_exclusive_group()

View File

@@ -1,6 +1,6 @@
# Define a class for your command-line arguments
import enum
from typing import Optional, List, Callable
from typing import Optional, List, Callable, Literal
import configargparse as argparse
ConfigurationExtender = Callable[[argparse.ArgParser], Optional[argparse.ArgParser]]
@@ -50,7 +50,7 @@ class Configuration(dict):
        fp32_text_enc (bool): Use FP32 precision for the text encoder.
        directml (Optional[int]): Use DirectML. -1 for auto-selection.
        disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs.
        preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "none".
        preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "auto".
        use_split_cross_attention (bool): Use split cross-attention optimization.
        use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization.
        use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function.
@@ -116,7 +116,7 @@ class Configuration(dict):
        self.fp32_text_enc: bool = False
        self.directml: Optional[int] = None
        self.disable_ipex_optimize: bool = False
        self.preview_method: str = "none"
        self.preview_method: LatentPreviewMethod = LatentPreviewMethod.Auto
        self.use_split_cross_attention: bool = False
        self.use_quad_cross_attention: bool = False
        self.use_pytorch_cross_attention: bool = False
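Editor's note: with this change both the argument parser and the Configuration object default to LatentPreviewMethod.Auto instead of "none". A minimal sketch of opting back out of previews programmatically, assuming the package root is comfy and that Configuration() can be constructed without arguments (both are inferences from this diff, not confirmed by it):

    # Sketch: restoring the old no-previews behaviour (import path assumed)
    from comfy.cli_args_types import Configuration, LatentPreviewMethod

    config = Configuration()
    assert config.preview_method == LatentPreviewMethod.Auto   # new default
    config.preview_method = LatentPreviewMethod.NoPreviews     # old default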

View File

@@ -1,8 +1,11 @@
from __future__ import annotations
import torch
from PIL import Image
import numpy as np
from ..cli_args import args
from ..cli_args_types import LatentPreviewMethod
from ..model_downloader import get_or_download, KNOWN_APPROX_VAES
from ..taesd.taesd import TAESD
from ..cmd import folder_paths
from .. import utils
@@ -59,7 +62,7 @@ def get_previewer(device, latent_format):
                if fn.startswith(latent_format.taesd_decoder_name)),
            ""
        )
        taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
        taesd_decoder_path = get_or_download("vae_approx", taesd_decoder_path, KNOWN_APPROX_VAES)

        if method == LatentPreviewMethod.Auto:
            method = LatentPreviewMethod.Latent2RGB

View File

@@ -0,0 +1,29 @@
from __future__ import annotations

import struct
from io import BytesIO
from typing import Literal

import PIL.Image
from PIL import Image, ImageOps


def encode_preview_image(image: PIL.Image.Image, image_type: Literal["JPEG", "PNG"], max_size: int):
    if max_size is not None:
        if hasattr(Image, 'Resampling'):
            resampling = Image.Resampling.BILINEAR
        else:
            # older Pillow has no Image.Resampling namespace, so use the legacy constant here
            resampling = Image.ANTIALIAS
        image = ImageOps.contain(image, (max_size, max_size), resampling)
    type_num = 1
    if image_type == "JPEG":
        type_num = 1
    elif image_type == "PNG":
        type_num = 2

    bytesIO = BytesIO()
    header = struct.pack(">I", type_num)
    bytesIO.write(header)
    image.save(bytesIO, format=image_type, quality=95, compress_level=1)
    preview_bytes = bytesIO.getvalue()
    return preview_bytes
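Editor's note: the payload produced by encode_preview_image is a 4-byte big-endian type tag (1 for JPEG, 2 for PNG) followed by the encoded image bytes. A minimal decoding sketch for that payload, assuming Pillow is available; the helper name is illustrative, not part of this commit:

    # Sketch: reverse the payload layout written by encode_preview_image above
    import struct
    from io import BytesIO
    from PIL import Image

    def decode_preview_payload(payload: bytes) -> tuple[str, Image.Image]:
        (type_num,) = struct.unpack(">I", payload[:4])          # 4-byte big-endian tag
        image_format = {1: "JPEG", 2: "PNG"}[type_num]          # mirrors type_num above
        return image_format, Image.open(BytesIO(payload[4:]))   # remainder is the image

    # payload = encode_preview_image(Image.new("RGB", (64, 64)), "PNG", max_size=512)
    # fmt, img = decode_preview_payload(payload)   # fmt == "PNG"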

View File

@@ -11,21 +11,24 @@ import sys
import traceback
import uuid
from asyncio import Future, AbstractEventLoop
from enum import Enum
from io import BytesIO
from typing import List, Optional, Dict
from urllib.parse import quote, urlencode
from posixpath import join as urljoin
from can_ada import URL, parse as urlparse
import aiofiles
import aiohttp
from PIL import Image, ImageOps
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from aiohttp import web
from pkg_resources import resource_filename
from typing_extensions import NamedTuple
import comfy.interruption
from .latent_preview_image_encoding import encode_preview_image
from .. import model_management
from .. import utils
from ..app.user_manager import UserManager
@@ -713,8 +716,15 @@ class PromptServer(ExecutorToClientProgress):
        else:
            await self.send_json(event, data, sid)

    def encode_bytes(self, event, data):
        if not isinstance(event, int):
    def encode_bytes(self, event: int | Enum | str, data):
        # todo: investigate what is propagating these spurious, string-repr'd previews
        if event == repr(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE):
            event = BinaryEventTypes.UNENCODED_PREVIEW_IMAGE.value
        elif event == repr(BinaryEventTypes.PREVIEW_IMAGE):
            event = BinaryEventTypes.PREVIEW_IMAGE.value
        elif isinstance(event, Enum):
            event: int = event.value
        elif not isinstance(event, int):
            raise RuntimeError(f"Binary event types must be integers, got {event}")

        packed = struct.pack(">I", event)
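Editor's note: encode_bytes prefixes the payload with a further 4-byte big-endian integer, the event id, so a preview frame on the wire is [event:4][image_type:4][image bytes]. A small client-side sketch of splitting such a frame; the field widths follow the struct.pack(">I", ...) calls in this file, while the helper name is illustrative:

    # Sketch: split a binary frame as assembled by encode_bytes + send_image
    import struct

    def split_preview_frame(frame: bytes) -> tuple[int, int, bytes]:
        event_id, image_type = struct.unpack(">II", frame[:8])   # two big-endian uint32 headers
        return event_id, image_type, frame[8:]                    # trailing bytes are the encoded image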
@@ -726,24 +736,7 @@ class PromptServer(ExecutorToClientProgress):
        image_type = image_data[0]
        image = image_data[1]
        max_size = image_data[2]
        if max_size is not None:
            if hasattr(Image, 'Resampling'):
                resampling = Image.Resampling.BILINEAR
            else:
                resampling = Image.Resampling.LANCZOS
            image = ImageOps.contain(image, (max_size, max_size), resampling)
        type_num = 1
        if image_type == "JPEG":
            type_num = 1
        elif image_type == "PNG":
            type_num = 2

        bytesIO = BytesIO()
        header = struct.pack(">I", type_num)
        bytesIO.write(header)
        image.save(bytesIO, format=image_type, quality=95, compress_level=1)
        preview_bytes = bytesIO.getvalue()
        preview_bytes = encode_preview_image(image, image_type, max_size)
        await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

    async def send_bytes(self, event, data, sid=None):

View File

@@ -1,7 +1,8 @@
from __future__ import annotations # for Python 3.7-3.9
import PIL.Image
from typing_extensions import NotRequired, TypedDict
from typing import Optional, Literal, Protocol, TypeAlias, Union
from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
from .queue_types import BinaryEventTypes
@@ -34,11 +35,17 @@ class ProgressMessage(TypedDict):
    sid: NotRequired[str]


class UnencodedPreviewImageMessage(NamedTuple):
    format: Literal["JPEG", "PNG"]
    pil_image: PIL.Image.Image
    max_size: int = 512


ExecutedMessage: TypeAlias = ExecutingMessage

SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None]
SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, bytes, bytearray, None]
SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]


class ExecutorToClientProgress(Protocol):
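Editor's note: because UnencodedPreviewImageMessage is a NamedTuple, it unpacks positionally in exactly the order the distributed sender later relies on (format, pil_image, max_size). A short illustrative sketch, assuming Pillow is installed and that the module resolves to comfy.component_model.executor_types (inferred from the relative imports elsewhere in this diff):

    # Sketch: building and unpacking the preview message defined above
    from PIL import Image
    from comfy.component_model.executor_types import UnencodedPreviewImageMessage  # path assumed

    msg = UnencodedPreviewImageMessage("JPEG", Image.new("RGB", (512, 512)))
    image_format, pil_image, max_size = msg   # positional unpacking; max_size defaults to 512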

View File

@@ -1,14 +1,16 @@
from __future__ import annotations
import asyncio
import base64
from asyncio import AbstractEventLoop
from enum import Enum
from functools import partial
from typing import Optional, Dict, Any
from aio_pika.patterns import RPC
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
    UnencodedPreviewImageMessage
from ..component_model.queue_types import BinaryEventTypes
from ..utils import hijack_progress
@@ -17,6 +19,8 @@ async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[
                    caller_server: Optional[ExecutorToClientProgress] = None) -> None:
    assert caller_server is not None
    assert user_id is not None
    if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE or isinstance(data, str):
        data: bytes = base64.b64decode(data)
    caller_server.send_sync(event, data, sid=user_id)
@@ -39,11 +43,19 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
    async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
        # for now, do not send binary data this way, since it cannot be json serialized / it's impractical
        if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            return
        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
            from ..cmd.latent_preview_image_encoding import encode_preview_image
            # encode preview image
            event = BinaryEventTypes.PREVIEW_IMAGE.value
            data: UnencodedPreviewImageMessage
            format, pil_image, max_size = data
            data: bytes = encode_preview_image(pil_image, format, max_size)
        if isinstance(data, bytes) or isinstance(data, bytearray):
            return
            if isinstance(event, Enum):
                event: int = event.value
            data: str = base64.b64encode(data).decode()
        if user_id is None:
            # todo: user_id should never be none here
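Editor's note: the net effect for distributed channels is that unencoded previews are first rendered to PREVIEW_IMAGE bytes, then base64-encoded so they survive JSON serialization, and _progress on the receiving side reverses the base64 step before calling send_sync. A minimal, standard-library-only sketch of that encode/decode symmetry in isolation; the variable names are illustrative:

    # Sketch: the base64 round trip used to move preview bytes over a JSON transport
    import base64

    preview_bytes = b"\x00\x00\x00\x02" + b"...png bytes..."      # type tag + image, as built above
    wire_payload: str = base64.b64encode(preview_bytes).decode()  # what send() publishes
    assert base64.b64decode(wire_payload) == preview_bytes        # what _progress recovers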

View File

@@ -15,13 +15,13 @@ from aio_pika.patterns import JsonRPC
from .distributed_progress import ProgressHandlers
from .distributed_types import RpcRequest, RpcReply
from .history import History
from .server_stub import ServerStub
from ..auth.permissions import jwt_decode
from ..cmd.server import PromptServer
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
from .history import History
from ..cmd.server import PromptServer
class DistributedPromptQueue(AbstractPromptQueue):

View File

@@ -281,6 +281,11 @@ KNOWN_DIFF_CONTROLNETS = [
    HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"),
]

KNOWN_APPROX_VAES = [
    HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
    HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors")
]


def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
    if args.disable_known_models: