mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-25 13:50:15 +08:00
Enable previews by default and over distributed channels
This commit is contained in:
parent
37cca051b6
commit
e49c662c7f
@ -111,7 +111,7 @@ def create_parser() -> argparse.ArgumentParser:
|
|||||||
parser.add_argument("--disable-ipex-optimize", action="store_true",
|
parser.add_argument("--disable-ipex-optimize", action="store_true",
|
||||||
help="Disables ipex.optimize when loading models with Intel GPUs.")
|
help="Disables ipex.optimize when loading models with Intel GPUs.")
|
||||||
|
|
||||||
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews,
|
parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto,
|
||||||
help="Default preview method for sampler nodes.", action=EnumAction)
|
help="Default preview method for sampler nodes.", action=EnumAction)
|
||||||
|
|
||||||
attn_group = parser.add_mutually_exclusive_group()
|
attn_group = parser.add_mutually_exclusive_group()
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
# Define a class for your command-line arguments
|
# Define a class for your command-line arguments
|
||||||
import enum
|
import enum
|
||||||
from typing import Optional, List, Callable
|
from typing import Optional, List, Callable, Literal
|
||||||
import configargparse as argparse
|
import configargparse as argparse
|
||||||
|
|
||||||
ConfigurationExtender = Callable[[argparse.ArgParser], Optional[argparse.ArgParser]]
|
ConfigurationExtender = Callable[[argparse.ArgParser], Optional[argparse.ArgParser]]
|
||||||
@ -50,7 +50,7 @@ class Configuration(dict):
|
|||||||
fp32_text_enc (bool): Use FP32 precision for the text encoder.
|
fp32_text_enc (bool): Use FP32 precision for the text encoder.
|
||||||
directml (Optional[int]): Use DirectML. -1 for auto-selection.
|
directml (Optional[int]): Use DirectML. -1 for auto-selection.
|
||||||
disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs.
|
disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs.
|
||||||
preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "none".
|
preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "auto".
|
||||||
use_split_cross_attention (bool): Use split cross-attention optimization.
|
use_split_cross_attention (bool): Use split cross-attention optimization.
|
||||||
use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization.
|
use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization.
|
||||||
use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function.
|
use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function.
|
||||||
@ -116,7 +116,7 @@ class Configuration(dict):
|
|||||||
self.fp32_text_enc: bool = False
|
self.fp32_text_enc: bool = False
|
||||||
self.directml: Optional[int] = None
|
self.directml: Optional[int] = None
|
||||||
self.disable_ipex_optimize: bool = False
|
self.disable_ipex_optimize: bool = False
|
||||||
self.preview_method: str = "none"
|
self.preview_method: LatentPreviewMethod = LatentPreviewMethod.Auto
|
||||||
self.use_split_cross_attention: bool = False
|
self.use_split_cross_attention: bool = False
|
||||||
self.use_quad_cross_attention: bool = False
|
self.use_quad_cross_attention: bool = False
|
||||||
self.use_pytorch_cross_attention: bool = False
|
self.use_pytorch_cross_attention: bool = False
|
||||||
|
|||||||
@ -1,8 +1,11 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from ..cli_args import args
|
from ..cli_args import args
|
||||||
from ..cli_args_types import LatentPreviewMethod
|
from ..cli_args_types import LatentPreviewMethod
|
||||||
|
from ..model_downloader import get_or_download, KNOWN_APPROX_VAES
|
||||||
from ..taesd.taesd import TAESD
|
from ..taesd.taesd import TAESD
|
||||||
from ..cmd import folder_paths
|
from ..cmd import folder_paths
|
||||||
from .. import utils
|
from .. import utils
|
||||||
@ -59,7 +62,7 @@ def get_previewer(device, latent_format):
|
|||||||
if fn.startswith(latent_format.taesd_decoder_name)),
|
if fn.startswith(latent_format.taesd_decoder_name)),
|
||||||
""
|
""
|
||||||
)
|
)
|
||||||
taesd_decoder_path = folder_paths.get_full_path("vae_approx", taesd_decoder_path)
|
taesd_decoder_path = get_or_download("vae_approx", taesd_decoder_path, KNOWN_APPROX_VAES)
|
||||||
|
|
||||||
if method == LatentPreviewMethod.Auto:
|
if method == LatentPreviewMethod.Auto:
|
||||||
method = LatentPreviewMethod.Latent2RGB
|
method = LatentPreviewMethod.Latent2RGB
|
||||||
|
|||||||
29
comfy/cmd/latent_preview_image_encoding.py
Normal file
29
comfy/cmd/latent_preview_image_encoding.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import struct
|
||||||
|
from io import BytesIO
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
|
||||||
|
|
||||||
|
def encode_preview_image(image: PIL.Image.Image, image_type: Literal["JPEG", "PNG"], max_size: int):
|
||||||
|
if max_size is not None:
|
||||||
|
if hasattr(Image, 'Resampling'):
|
||||||
|
resampling = Image.Resampling.BILINEAR
|
||||||
|
else:
|
||||||
|
resampling = Image.Resampling.LANCZOS
|
||||||
|
|
||||||
|
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
||||||
|
type_num = 1
|
||||||
|
if image_type == "JPEG":
|
||||||
|
type_num = 1
|
||||||
|
elif image_type == "PNG":
|
||||||
|
type_num = 2
|
||||||
|
bytesIO = BytesIO()
|
||||||
|
header = struct.pack(">I", type_num)
|
||||||
|
bytesIO.write(header)
|
||||||
|
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
||||||
|
preview_bytes = bytesIO.getvalue()
|
||||||
|
return preview_bytes
|
||||||
@ -11,21 +11,24 @@ import sys
|
|||||||
import traceback
|
import traceback
|
||||||
import uuid
|
import uuid
|
||||||
from asyncio import Future, AbstractEventLoop
|
from asyncio import Future, AbstractEventLoop
|
||||||
|
from enum import Enum
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import List, Optional, Dict
|
from typing import List, Optional, Dict
|
||||||
from urllib.parse import quote, urlencode
|
from urllib.parse import quote, urlencode
|
||||||
from posixpath import join as urljoin
|
from posixpath import join as urljoin
|
||||||
|
|
||||||
from can_ada import URL, parse as urlparse
|
from can_ada import URL, parse as urlparse
|
||||||
|
|
||||||
import aiofiles
|
import aiofiles
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from PIL import Image, ImageOps
|
from PIL import Image
|
||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from pkg_resources import resource_filename
|
from pkg_resources import resource_filename
|
||||||
from typing_extensions import NamedTuple
|
from typing_extensions import NamedTuple
|
||||||
|
|
||||||
import comfy.interruption
|
import comfy.interruption
|
||||||
|
from .latent_preview_image_encoding import encode_preview_image
|
||||||
from .. import model_management
|
from .. import model_management
|
||||||
from .. import utils
|
from .. import utils
|
||||||
from ..app.user_manager import UserManager
|
from ..app.user_manager import UserManager
|
||||||
@ -713,8 +716,15 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
else:
|
else:
|
||||||
await self.send_json(event, data, sid)
|
await self.send_json(event, data, sid)
|
||||||
|
|
||||||
def encode_bytes(self, event, data):
|
def encode_bytes(self, event: int | Enum | str, data):
|
||||||
if not isinstance(event, int):
|
# todo: investigate what is propagating these spurious, string-repr'd previews
|
||||||
|
if event == repr(BinaryEventTypes.UNENCODED_PREVIEW_IMAGE):
|
||||||
|
event = BinaryEventTypes.UNENCODED_PREVIEW_IMAGE.value
|
||||||
|
elif event == repr(BinaryEventTypes.PREVIEW_IMAGE):
|
||||||
|
event = BinaryEventTypes.PREVIEW_IMAGE.value
|
||||||
|
elif isinstance(event, Enum):
|
||||||
|
event: int = event.value
|
||||||
|
elif not isinstance(event, int):
|
||||||
raise RuntimeError(f"Binary event types must be integers, got {event}")
|
raise RuntimeError(f"Binary event types must be integers, got {event}")
|
||||||
|
|
||||||
packed = struct.pack(">I", event)
|
packed = struct.pack(">I", event)
|
||||||
@ -726,24 +736,7 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
image_type = image_data[0]
|
image_type = image_data[0]
|
||||||
image = image_data[1]
|
image = image_data[1]
|
||||||
max_size = image_data[2]
|
max_size = image_data[2]
|
||||||
if max_size is not None:
|
preview_bytes = encode_preview_image(image, image_type, max_size)
|
||||||
if hasattr(Image, 'Resampling'):
|
|
||||||
resampling = Image.Resampling.BILINEAR
|
|
||||||
else:
|
|
||||||
resampling = Image.Resampling.LANCZOS
|
|
||||||
|
|
||||||
image = ImageOps.contain(image, (max_size, max_size), resampling)
|
|
||||||
type_num = 1
|
|
||||||
if image_type == "JPEG":
|
|
||||||
type_num = 1
|
|
||||||
elif image_type == "PNG":
|
|
||||||
type_num = 2
|
|
||||||
|
|
||||||
bytesIO = BytesIO()
|
|
||||||
header = struct.pack(">I", type_num)
|
|
||||||
bytesIO.write(header)
|
|
||||||
image.save(bytesIO, format=image_type, quality=95, compress_level=1)
|
|
||||||
preview_bytes = bytesIO.getvalue()
|
|
||||||
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
await self.send_bytes(BinaryEventTypes.PREVIEW_IMAGE, preview_bytes, sid=sid)
|
||||||
|
|
||||||
async def send_bytes(self, event, data, sid=None):
|
async def send_bytes(self, event, data, sid=None):
|
||||||
|
|||||||
@ -1,7 +1,8 @@
|
|||||||
from __future__ import annotations # for Python 3.7-3.9
|
from __future__ import annotations # for Python 3.7-3.9
|
||||||
|
|
||||||
|
import PIL.Image
|
||||||
from typing_extensions import NotRequired, TypedDict
|
from typing_extensions import NotRequired, TypedDict
|
||||||
from typing import Optional, Literal, Protocol, TypeAlias, Union
|
from typing import Optional, Literal, Protocol, TypeAlias, Union, NamedTuple
|
||||||
|
|
||||||
from .queue_types import BinaryEventTypes
|
from .queue_types import BinaryEventTypes
|
||||||
|
|
||||||
@ -34,11 +35,17 @@ class ProgressMessage(TypedDict):
|
|||||||
sid: NotRequired[str]
|
sid: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
class UnencodedPreviewImageMessage(NamedTuple):
|
||||||
|
format: Literal["JPEG", "PNG"]
|
||||||
|
pil_image: PIL.Image.Image
|
||||||
|
max_size: int = 512
|
||||||
|
|
||||||
|
|
||||||
ExecutedMessage: TypeAlias = ExecutingMessage
|
ExecutedMessage: TypeAlias = ExecutingMessage
|
||||||
|
|
||||||
SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None]
|
SendSyncEvent: TypeAlias = Union[Literal["status", "executing", "progress", "executed"], BinaryEventTypes, None]
|
||||||
|
|
||||||
SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, bytes, bytearray, None]
|
SendSyncData: TypeAlias = Union[StatusMessage, ExecutingMessage, ProgressMessage, UnencodedPreviewImageMessage, bytes, bytearray, str, None]
|
||||||
|
|
||||||
|
|
||||||
class ExecutorToClientProgress(Protocol):
|
class ExecutorToClientProgress(Protocol):
|
||||||
|
|||||||
@ -1,14 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
from asyncio import AbstractEventLoop
|
from asyncio import AbstractEventLoop
|
||||||
|
from enum import Enum
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from typing import Optional, Dict, Any
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
from aio_pika.patterns import RPC
|
from aio_pika.patterns import RPC
|
||||||
|
|
||||||
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress
|
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
|
||||||
|
UnencodedPreviewImageMessage
|
||||||
from ..component_model.queue_types import BinaryEventTypes
|
from ..component_model.queue_types import BinaryEventTypes
|
||||||
from ..utils import hijack_progress
|
from ..utils import hijack_progress
|
||||||
|
|
||||||
@ -17,6 +19,8 @@ async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[
|
|||||||
caller_server: Optional[ExecutorToClientProgress] = None) -> None:
|
caller_server: Optional[ExecutorToClientProgress] = None) -> None:
|
||||||
assert caller_server is not None
|
assert caller_server is not None
|
||||||
assert user_id is not None
|
assert user_id is not None
|
||||||
|
if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE or isinstance(data, str):
|
||||||
|
data: bytes = base64.b64decode(data)
|
||||||
caller_server.send_sync(event, data, sid=user_id)
|
caller_server.send_sync(event, data, sid=user_id)
|
||||||
|
|
||||||
|
|
||||||
@ -39,11 +43,19 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
|
|||||||
|
|
||||||
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
|
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
|
||||||
# for now, do not send binary data this way, since it cannot be json serialized / it's impractical
|
# for now, do not send binary data this way, since it cannot be json serialized / it's impractical
|
||||||
if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE:
|
||||||
return
|
from ..cmd.latent_preview_image_encoding import encode_preview_image
|
||||||
|
|
||||||
|
# encode preview image
|
||||||
|
event = BinaryEventTypes.PREVIEW_IMAGE.value
|
||||||
|
data: UnencodedPreviewImageMessage
|
||||||
|
format, pil_image, max_size = data
|
||||||
|
data: bytes = encode_preview_image(pil_image, format, max_size)
|
||||||
|
|
||||||
if isinstance(data, bytes) or isinstance(data, bytearray):
|
if isinstance(data, bytes) or isinstance(data, bytearray):
|
||||||
return
|
if isinstance(event, Enum):
|
||||||
|
event: int = event.value
|
||||||
|
data: str = base64.b64encode(data).decode()
|
||||||
|
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
# todo: user_id should never be none here
|
# todo: user_id should never be none here
|
||||||
|
|||||||
@ -15,13 +15,13 @@ from aio_pika.patterns import JsonRPC
|
|||||||
|
|
||||||
from .distributed_progress import ProgressHandlers
|
from .distributed_progress import ProgressHandlers
|
||||||
from .distributed_types import RpcRequest, RpcReply
|
from .distributed_types import RpcRequest, RpcReply
|
||||||
|
from .history import History
|
||||||
from .server_stub import ServerStub
|
from .server_stub import ServerStub
|
||||||
from ..auth.permissions import jwt_decode
|
from ..auth.permissions import jwt_decode
|
||||||
|
from ..cmd.server import PromptServer
|
||||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
|
||||||
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
|
from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData
|
||||||
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
|
from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation
|
||||||
from .history import History
|
|
||||||
from ..cmd.server import PromptServer
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedPromptQueue(AbstractPromptQueue):
|
class DistributedPromptQueue(AbstractPromptQueue):
|
||||||
|
|||||||
@ -281,6 +281,11 @@ KNOWN_DIFF_CONTROLNETS = [
|
|||||||
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"),
|
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
KNOWN_APPROX_VAES = [
|
||||||
|
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
|
||||||
|
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors")
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
|
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
|
||||||
if args.disable_known_models:
|
if args.disable_known_models:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user