Enable previews by default and over distributed channels

doctorpangloss 2024-04-09 13:15:05 -07:00
parent 37cca051b6
commit e49c662c7f
9 changed files with 84 additions and 35 deletions

View File

@@ -111,7 +111,7 @@ def create_parser() -> argparse.ArgumentParser:
     parser.add_argument("--disable-ipex-optimize", action="store_true",
                         help="Disables ipex.optimize when loading models with Intel GPUs.")
-    parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews,
+    parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto,
                         help="Default preview method for sampler nodes.", action=EnumAction)
     attn_group = parser.add_mutually_exclusive_group()
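With this default change, a server launched without --preview-method now resolves to LatentPreviewMethod.Auto instead of NoPreviews. A minimal sketch of the resulting parsing behavior, using plain argparse in place of the repository's EnumAction and an illustrative enum (the member names and values below are assumptions, not copied from the repository):

import argparse
import enum


class LatentPreviewMethod(enum.Enum):
    # illustrative members; the real enum lives next to the Configuration class
    NoPreviews = "none"
    Auto = "auto"
    Latent2RGB = "latent2rgb"
    TAESD = "taesd"


parser = argparse.ArgumentParser()
# stand-in for the changed add_argument call above
parser.add_argument("--preview-method", type=LatentPreviewMethod,
                    default=LatentPreviewMethod.Auto)

assert parser.parse_args([]).preview_method is LatentPreviewMethod.Auto
assert parser.parse_args(["--preview-method", "auto"]).preview_method is LatentPreviewMethod.Auto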

View File

@@ -1,6 +1,6 @@
 # Define a class for your command-line arguments
 import enum
-from typing import Optional, List, Callable
+from typing import Optional, List, Callable, Literal

 import configargparse as argparse

 ConfigurationExtender = Callable[[argparse.ArgParser], Optional[argparse.ArgParser]]
@@ -50,7 +50,7 @@ class Configuration(dict):
         fp32_text_enc (bool): Use FP32 precision for the text encoder.
         directml (Optional[int]): Use DirectML. -1 for auto-selection.
         disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs.
-        preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "none".
+        preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "auto".
         use_split_cross_attention (bool): Use split cross-attention optimization.
         use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization.
         use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function.
@@ -116,7 +116,7 @@ class Configuration(dict):
         self.fp32_text_enc: bool = False
         self.directml: Optional[int] = None
         self.disable_ipex_optimize: bool = False
-        self.preview_method: str = "none"
+        self.preview_method: LatentPreviewMethod = LatentPreviewMethod.Auto
         self.use_split_cross_attention: bool = False
         self.use_quad_cross_attention: bool = False
         self.use_pytorch_cross_attention: bool = False

View File

@@ -1,8 +1,11 @@
+from __future__ import annotations
+
 import torch
 from PIL import Image
 import numpy as np
 from ..cli_args import args
 from ..cli_args_types import LatentPreviewMethod
+from ..model_downloader import get_or_download, KNOWN_APPROX_VAES
 from ..taesd.taesd import TAESD
 from ..cmd import folder_paths
 from .. import utils
@@ -59,7 +62,7 @@ def get_previewer(device, latent_format):
                  if fn.startswith(latent_format.taesd_decoder_name)),
                 ""
             )
-            taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
+            taesd_decoder_path = get_or_download("vae_approx", taesd_decoder_path, KNOWN_APPROX_VAES)

         if method == LatentPreviewMethod.Auto:
             method = LatentPreviewMethod.Latent2RGB

View File

@@ -0,0 +1,29 @@
+from __future__ import annotations
+
+import struct
+from io import BytesIO
+from typing import Literal
+
+import PIL.Image
+from PIL import Image, ImageOps
+
+
+def encode_preview_image(image: PIL.Image.Image, image_type: Literal["JPEG", "PNG"], max_size: int):
+    if max_size is not None:
+        if hasattr(Image, 'Resampling'):
+            resampling = Image.Resampling.BILINEAR
+        else:
+            resampling = Image.Resampling.LANCZOS
+        image = ImageOps.contain(image, (max_size, max_size), resampling)
+    type_num = 1
+    if image_type == "JPEG":
+        type_num = 1
+    elif image_type == "PNG":
+        type_num = 2
+
+    bytesIO = BytesIO()
+    header = struct.pack(">I", type_num)
+    bytesIO.write(header)
+    image.save(bytesIO, format=image_type, quality=95, compress_level=1)
+    preview_bytes = bytesIO.getvalue()
+    return preview_bytes
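The payload produced by this new helper is a 4-byte big-endian image-type code (1 for JPEG, 2 for PNG) followed by the encoded image. A small usage sketch, imagining the function imported as encode_preview_image (the gray test image is made up for illustration):

import struct

from PIL import Image

preview = Image.new("RGB", (1024, 768), color=(127, 127, 127))
payload = encode_preview_image(preview, "JPEG", 512)

image_type_code, = struct.unpack(">I", payload[:4])
assert image_type_code == 1          # 1 == JPEG, 2 == PNG
jpeg_bytes = payload[4:]             # resized to fit within 512x512 by ImageOps.contain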

View File

@@ -11,21 +11,24 @@ import sys
 import traceback
 import uuid
 from asyncio import Future, AbstractEventLoop
+from enum import Enum
 from io import BytesIO
 from typing import List, Optional, Dict
 from urllib.parse import quote, urlencode
 from posixpath import join as urljoin

 from can_ada import URL, parse as urlparse
 import aiofiles
 import aiohttp
-from PIL import Image, ImageOps
+from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 from aiohttp import web
 from pkg_resources import resource_filename
 from typing_extensions import NamedTuple

 import comfy.interruption
+from .latent_preview_image_encoding import encode_preview_image
 from .. import model_management
 from .. import utils
 from ..app.user_manager import UserManager
@@ -713,8 +716,15 @@ class PromptServer(ExecutorToClientProgress):
         else:
             await self.send_json(event, data, sid)

-    def encode_bytes(self, event, data):
-        if not isinstance(event, int):
+    def encode_bytes(self, event: int | Enum | str, data):
+        # todo: investigate what is propagating these spurious, string-repr'd previews
+        if event == repr(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE):
+            event = BinaryEventTypes.UNENCODED_PREVIEW_IMAGE.value
+        elif event == repr(BinaryEventTypes.PREVIEW_IMAGE):
+            event = BinaryEventTypes.PREVIEW_IMAGE.value
+        elif isinstance(event, Enum):
+            event: int = event.value
+        elif not isinstance(event, int):
             raise RuntimeError(f"Binary event types must be integers, got {event}")

         packed = struct.pack(">I", event)
@@ -726,24 +736,7 @@ class PromptServer(ExecutorToClientProgress):
         image_type = image_data[0]
         image = image_data[1]
         max_size = image_data[2]
-        if max_size is not None:
-            if hasattr(Image, 'Resampling'):
-                resampling = Image.Resampling.BILINEAR
-            else:
-                resampling = Image.Resampling.LANCZOS
-            image = ImageOps.contain(image, (max_size, max_size), resampling)
-        type_num = 1
-        if image_type == "JPEG":
-            type_num = 1
-        elif image_type == "PNG":
-            type_num = 2
-        bytesIO = BytesIO()
-        header = struct.pack(">I", type_num)
-        bytesIO.write(header)
-        image.save(bytesIO, format=image_type, quality=95, compress_level=1)
-        preview_bytes = bytesIO.getvalue()
+        preview_bytes = encode_preview_image(image, image_type, max_size)
         await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)

     async def send_bytes(self, event, data, sid=None):
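Taken together with encode_preview_image, a preview frame on the wire is [4-byte big-endian event id][4-byte image-type code][encoded image], since encode_bytes prepends the event id to the already-prefixed preview payload. A client-side decoding sketch (decode_binary_frame is a hypothetical helper, not part of this commit):

import struct

def decode_binary_frame(frame: bytes):
    # event id packed by encode_bytes, then the image-type code written by encode_preview_image
    event_id, = struct.unpack(">I", frame[:4])
    image_type_code, = struct.unpack(">I", frame[4:8])
    image_format = {1: "JPEG", 2: "PNG"}.get(image_type_code, "JPEG")
    return event_id, image_format, frame[8:]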

View File

@ -1,7 +1,8 @@
from __future__ import annotations # for Python 3.7-3.9 from __future__ import annotations # for Python 3.7-3.9
import PIL.Image
from typing_extensions import NotRequired, TypedDict from typing_extensions import NotRequired, TypedDict
from typing import Optional, Literal, Protocol, TypeAlias, Union from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
from .queue_types import BinaryEventTypes from .queue_types import BinaryEventTypes
@@ -34,11 +35,17 @@ class ProgressMessage(TypedDict):
     sid: NotRequired[str]

+
+class UnencodedPreviewImageMessage(NamedTuple):
+    format: Literal["JPEG", "PNG"]
+    pil_image: PIL.Image.Image
+    max_size: int = 512
+

 ExecutedMessage: TypeAlias = ExecutingMessage

 SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None]
-SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, bytes, bytearray, None]
+SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]


 class ExecutorToClientProgress(Protocol):
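Because UnencodedPreviewImageMessage is a NamedTuple, it unpacks positionally exactly the way the distributed sender later in this commit consumes it. A small illustrative construction (the placeholder image stands in for a latent preview):

from PIL import Image

msg = UnencodedPreviewImageMessage("PNG", Image.new("RGB", (64, 64)))

# max_size falls back to the declared default
assert msg.max_size == 512

# positional unpacking, as in DistributedExecutorToClientProgress.send
format, pil_image, max_size = msg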

View File

@ -1,14 +1,16 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import base64
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from enum import Enum
from functools import partial from functools import partial
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from aio_pika.patterns import RPC from aio_pika.patterns import RPC
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
UnencodedPreviewImageMessage
from ..component_model.queue_types import BinaryEventTypes from ..component_model.queue_types import BinaryEventTypes
from ..utils import hijack_progress from ..utils import hijack_progress
@@ -17,6 +19,8 @@ async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[
                     caller_server: Optional[ExecutorToClientProgress] = None) -> None:
     assert caller_server is not None
     assert user_id is not None
+    if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE or isinstance(data, str):
+        data: bytes = base64.b64decode(data)
     caller_server.send_sync(event, data, sid=user_id)
@@ -39,11 +43,19 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
     async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
         # for now, do not send binary data this way, since it cannot be json serialized / it's impractical
-        if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
-            return
+        if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
+            from ..cmd.latent_preview_image_encoding import encode_preview_image
+
+            # encode preview image
+            event = BinaryEventTypes.PREVIEW_IMAGE.value
+            data: UnencodedPreviewImageMessage
+            format, pil_image, max_size = data
+            data: bytes = encode_preview_image(pil_image, format, max_size)

         if isinstance(data, bytes) or isinstance(data, bytearray):
-            return
+            if isinstance(event, Enum):
+                event: int = event.value
+            data: str = base64.b64encode(data).decode()

         if user_id is None:
             # todo: user_id should never be none here
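The round trip is: send() encodes any unencoded preview with encode_preview_image, base64-encodes the resulting bytes so they survive JSON-RPC transport, and _progress() on the receiving worker base64-decodes them before handing off to send_sync. A standalone sketch of just the transport encoding (the payload is a placeholder, no aio_pika involved):

import base64

preview_bytes = b"\x00\x00\x00\x01" + b"...jpeg bytes..."   # placeholder preview payload

# sender side (send above): bytes become a JSON-serializable base64 string
wire_data = base64.b64encode(preview_bytes).decode()

# receiver side (_progress above): the string is decoded back to bytes before send_sync
assert base64.b64decode(wire_data) == preview_bytes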

View File

@@ -15,13 +15,13 @@ from aio_pika.patterns import JsonRPC
 from .distributed_progress import ProgressHandlers
 from .distributed_types import RpcRequest, RpcReply
+from .history import History
 from .server_stub import ServerStub
 from ..auth.permissions import jwt_decode
+from ..cmd.server import PromptServer
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
 from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
-from .history import History
-from ..cmd.server import PromptServer


 class DistributedPromptQueue(AbstractPromptQueue):

View File

@@ -281,6 +281,11 @@ KNOWN_DIFF_CONTROLNETS = [
     HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"),
 ]

+KNOWN_APPROX_VAES = [
+    HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
+    HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors")
+]
+

 def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
     if args.disable_known_models:
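With KNOWN_APPROX_VAES registered, get_previewer (earlier in this commit) can fetch the TAESD decoders on demand rather than requiring them to be pre-placed in the vae_approx folder. A hedged usage sketch; the absolute import path and the download-when-missing behavior of get_or_download are assumptions inferred from its call site, not shown in this commit:

from comfy.model_downloader import KNOWN_APPROX_VAES, get_or_download

# assumed behavior: returns a local path under the "vae_approx" folder, downloading
# the file from the Hugging Face entries above if it is not already present
decoder_path = get_or_download("vae_approx", "taesd_decoder.safetensors", KNOWN_APPROX_VAES)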