Fix distributed previews, add sageattention to the docker container

This commit is contained in:
doctorpangloss 2025-09-12 11:48:25 -07:00
parent 421c9b88ae
commit 7cd6383110
6 changed files with 86 additions and 46 deletions

View File

@ -2,3 +2,4 @@
!comfy* !comfy*
!pyproject.toml !pyproject.toml
!README.md !README.md
!pkg/*

View File

@ -10,6 +10,7 @@ ENV UV_BREAK_SYSTEM_PACKAGES=1
ENV PIP_DISABLE_PIP_VERSION_CHECK=1 ENV PIP_DISABLE_PIP_VERSION_CHECK=1
ENV PIP_NO_CACHE_DIR=1 ENV PIP_NO_CACHE_DIR=1
ENV DEBIAN_FRONTEND=noninteractive ENV DEBIAN_FRONTEND=noninteractive
ENV UV_OVERRIDE=/workspace/overrides.txt
ENV LANG=C.UTF-8 ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8 ENV LC_ALL=C.UTF-8
@ -17,20 +18,22 @@ ENV LC_ALL=C.UTF-8
# mitigates # mitigates
# RuntimeError: Failed to import transformers.generation.utils because of the following error (look up to see its traceback): # RuntimeError: Failed to import transformers.generation.utils because of the following error (look up to see its traceback):
# numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject # numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject
RUN pip freeze | grep numpy > numpy-override.txt RUN echo "onnxruntime-gpu==1.22.0" >> /workspace/overrides.txt; pip freeze | grep nvidia >> /workspace/overrides.txt; echo "torch==2.7.0a0+7c8ec84dab.nv25.3" >> /workspace/overrides.txt; pip freeze | grep numpy >> /workspace/overrides.txt; echo "opencv-python; python_version < '0'" >> /workspace/overrides.txt; echo "opencv-contrib-python; python_version < '0'" >> /workspace/overrides.txt; echo "opencv-python-headless; python_version < '0'" >> /workspace/overrides.txt; echo "opencv-contrib-python-headless!=4.11.0.86" >> /workspace/overrides.txt; echo "sentry-sdk; python_version < '0'" >> /workspace/overrides.txt
# mitigates https://stackoverflow.com/questions/55313610/importerror-libgl-so-1-cannot-open-shared-object-file-no-such-file-or-directo
# mitigates AttributeError: module 'cv2.dnn' has no attribute 'DictValue' \ # mitigates AttributeError: module 'cv2.dnn' has no attribute 'DictValue' \
# see https://github.com/facebookresearch/nougat/issues/40 # see https://github.com/facebookresearch/nougat/issues/40
RUN apt-get update && \ RUN pip install uv && uv --version && \
apt-get install --no-install-recommends -y ffmpeg libsm6 libxext6 && \ apt-get update && apt-get install --no-install-recommends ffmpeg libsm6 libxext6 -y && \
pip install uv && uv --version && \ uv pip uninstall --system $(pip list --format=freeze | grep opencv) && \
apt-get purge -y && \ rm -rf /usr/local/lib/python3.12/dist-packages/cv2/ && \
uv pip install wheel && \
uv pip install --no-build-isolation "opencv-contrib-python-headless!=4.11.0.86" && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
RUN uv pip uninstall --system $(pip list --format=freeze | grep opencv) && \ # install sageattention
rm -rf /usr/local/lib/python3.12/dist-packages/cv2/ && \ ADD pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl /workspace/pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl
uv pip install --no-build-isolation opencv-python-headless RUN uv pip install --no-deps --no-build-isolation spandrel>=0.3.4 timm>=1.0.19 tensorboard>=2.17.0 poetry flash-attn xformers==0.0.31.post1 file:./pkg/sageattention-2.2.0-cp312-cp312-linux_x86_64.whl
# this exotic command will determine the correct torchaudio to install for the image # this exotic command will determine the correct torchaudio to install for the image
RUN <<-EOF RUN <<-EOF
python -c 'import torch, re, subprocess python -c 'import torch, re, subprocess
@ -41,7 +44,7 @@ if not torch_ver_match:
torch_ver = torch_ver_match.group(1) torch_ver = torch_ver_match.group(1)
cuda_ver_tag = f"cu{torch.version.cuda.replace(".", "")}" cuda_ver_tag = f"cu{torch.version.cuda.replace(".", "")}"
command = [ command = [
"uv", "pip", "install", "--no-deps", "--overrides=numpy-override.txt", "uv", "pip", "install", "--no-deps",
f"torchaudio=={torch_ver}+{cuda_ver_tag}", f"torchaudio=={torch_ver}+{cuda_ver_tag}",
"--extra-index-url", f"https://download.pytorch.org/whl/{cuda_ver_tag}", "--extra-index-url", f"https://download.pytorch.org/whl/{cuda_ver_tag}",
] ]
@ -50,17 +53,18 @@ EOF
# sources for building this dockerfile # sources for building this dockerfile
# use these lines to build from the local fs # use these lines to build from the local fs
# ADD . /src ADD . /workspace/src
# ARG SOURCES=/src ARG SOURCES="comfyui[attention,comfyui_manager]@./src"
# this builds from github # this builds from github
ARG SOURCES="comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git" #ARG SOURCES="comfyui[attention,comfyui_manager]@git+https://github.com/hiddenswitch/ComfyUI.git"
ENV SOURCES=$SOURCES ENV SOURCES=$SOURCES
RUN uv pip install --overrides=numpy-override.txt $SOURCES RUN uv pip install $SOURCES
WORKDIR /workspace WORKDIR /workspace
# addresses https://github.com/pytorch/pytorch/issues/104801 # addresses https://github.com/pytorch/pytorch/issues/104801
# and issues reported by importing nodes_canny # and issues reported by importing nodes_canny
RUN python -c "import torch; import xformers; import sageattention; import cv2"
RUN comfyui --quick-test-for-ci --cpu --cwd /workspace RUN comfyui --quick-test-for-ci --cpu --cwd /workspace
EXPOSE 8188 EXPOSE 8188

View File

@ -27,7 +27,7 @@ class LogInterceptor(io.TextIOWrapper):
# Simple handling for cr to overwrite the last output if it isnt a full line # Simple handling for cr to overwrite the last output if it isnt a full line
# else logs just get full of progress messages # else logs just get full of progress messages
if isinstance(data, str) and data.startswith("\r") and not logs[-1]["m"].endswith("\n"): if isinstance(data, str) and data.startswith("\r") and len(logs) > 0 and not logs[-1]["m"].endswith("\n"):
logs.pop() logs.pop()
logs.append(entry) logs.append(entry)
super().write(data) super().write(data)

View File

@ -2,25 +2,35 @@ from __future__ import annotations
import asyncio import asyncio
import base64 import base64
import pickle
from asyncio import AbstractEventLoop from asyncio import AbstractEventLoop
from enum import Enum
from functools import partial from functools import partial
from typing import Optional, Dict, Any, Union from typing import Optional, Dict, Any, TypeVar, NewType
from aio_pika import DeliveryMode
from aio_pika.patterns import RPC from aio_pika.patterns import RPC
from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \ from ..component_model.executor_types import SendSyncEvent, SendSyncData, ExecutorToClientProgress, \
UnencodedPreviewImageMessage, StatusMessage, QueueInfo, ExecInfo StatusMessage, QueueInfo, ExecInfo
from ..component_model.queue_types import BinaryEventTypes
T = TypeVar('T')
Base64Pickled = NewType('Base64Pickled', str)
async def _progress(event: SendSyncEvent, data: SendSyncData, user_id: Optional[str] = None, def obj2base64(obj: T) -> Base64Pickled:
return Base64Pickled(base64.b64encode(pickle.dumps(obj)).decode())
def base642obj(data: Base64Pickled) -> T:
return pickle.loads(base64.b64decode(data))
async def _progress(event: Base64Pickled, data: Base64Pickled, user_id: Optional[str] = None,
caller_server: Optional[ExecutorToClientProgress] = None) -> None: caller_server: Optional[ExecutorToClientProgress] = None) -> None:
assert caller_server is not None assert caller_server is not None
assert user_id is not None assert user_id is not None
if event == BinaryEventTypes.PREVIEW_IMAGE or event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE or isinstance(data, str):
data: bytes = base64.b64decode(data) caller_server.send_sync(base642obj(event), base642obj(data), sid=user_id)
caller_server.send_sync(event, data, sid=user_id)
def _get_name(queue_name: str, user_id: str) -> str: def _get_name(queue_name: str, user_id: str) -> str:
@ -43,27 +53,10 @@ class DistributedExecutorToClientProgress(ExecutorToClientProgress):
return True return True
async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None: async def send(self, event: SendSyncEvent, data: SendSyncData, user_id: Optional[str]) -> None:
if event == BinaryEventTypes.UNENCODED_PREVIEW_IMAGE: assert user_id is not None, f"event={event} data={data}"
from ..cmd.latent_preview_image_encoding import encode_preview_image
# encode preview image
event = BinaryEventTypes.PREVIEW_IMAGE.value
data: UnencodedPreviewImageMessage
format, pil_image, max_size, node_id, task_id = data
data: bytes = encode_preview_image(pil_image, format, max_size, node_id, task_id)
if isinstance(data, bytes) or isinstance(data, bytearray):
if isinstance(event, Enum):
event: int = event.value
data: str = base64.b64encode(data).decode()
if user_id is None:
# todo: user_id should never be none here
return
try: try:
# we don't need to await this coroutine # we don't need to await this coroutine
_ = asyncio.create_task(self._rpc.call(_get_name(self._queue_name, user_id), {"event": event, "data": data}, expiration=1000)) _ = asyncio.create_task(self._rpc.call(_get_name(self._queue_name, user_id), {"event": obj2base64(event), "data": obj2base64(data)}, expiration=1000, delivery_mode=DeliveryMode.NOT_PERSISTENT))
except asyncio.TimeoutError: except asyncio.TimeoutError:
# these can gracefully expire # these can gracefully expire
pass pass

View File

@ -5,7 +5,14 @@ services:
context: . context: .
dockerfile: Dockerfile dockerfile: Dockerfile
volumes: volumes:
# USING DOCKER MANAGED VOLUMES
- workspace_data:/workspace - workspace_data:/workspace
# OR: USE LOCAL DIRECTORIES
# Comment out the `- workspace_data...` line, then uncomment:
# - models:/workspace/models
# - custom_nodes:/workspace/custom_nodes
# - output:/workspace/output
# - input:/workspace/input
deploy: deploy:
replicas: 1 replicas: 1
resources: resources:
@ -16,7 +23,6 @@ services:
capabilities: [ gpu ] capabilities: [ gpu ]
environment: environment:
- COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI=amqp://guest:guest@rabbitmq:5672 - COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI=amqp://guest:guest@rabbitmq:5672
- COMFYUI_EXECUTOR_FACTORY=ProcessPoolExecutor
- COMFYUI_PANIC_WHEN=torch.cuda.OutOfMemoryError - COMFYUI_PANIC_WHEN=torch.cuda.OutOfMemoryError
- COMFYUI_LOGGING_LEVEL=ERROR - COMFYUI_LOGGING_LEVEL=ERROR
command: command:
@ -25,7 +31,7 @@ services:
test: curl -f http://localhost:9090/health test: curl -f http://localhost:9090/health
interval: 10s interval: 10s
timeout: 5s timeout: 5s
retries: 3 retries: 1
start_period: 10s start_period: 10s
restart: unless-stopped restart: unless-stopped
frontend: frontend:
@ -35,7 +41,14 @@ services:
deploy: deploy:
replicas: 1 replicas: 1
volumes: volumes:
# USING DOCKER MANAGED VOLUMES
- workspace_data:/workspace - workspace_data:/workspace
# OR: USE LOCAL DIRECTORIES
# Comment out the `- workspace_data...` line, then uncomment:
# - models:/workspace/models
# - custom_nodes:/workspace/custom_nodes
# - output:/workspace/output
# - input:/workspace/input
environment: environment:
- COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI=amqp://guest:guest@rabbitmq:5672 - COMFYUI_DISTRIBUTED_QUEUE_CONNECTION_URI=amqp://guest:guest@rabbitmq:5672
- COMFYUI_DISTRIBUTED_QUEUE_FRONTEND=1 - COMFYUI_DISTRIBUTED_QUEUE_FRONTEND=1
@ -54,5 +67,34 @@ services:
restart: unless-stopped restart: unless-stopped
rabbitmq: rabbitmq:
image: rabbitmq:3 image: rabbitmq:3
command: >
sh -c "echo 'log.default.level = error' > /etc/rabbitmq/rabbitmq.conf &&
docker-entrypoint.sh rabbitmq-server"
volumes: volumes:
# USING DOCKER MANAGED VOLUMES
workspace_data: {} workspace_data: {}
# OR: USE LOCAL DIRECTORIES
# models:
# driver: local
# driver_opts:
# type: 'none'
# o: 'bind'
# device: './models'
# custom_nodes:
# driver: local
# driver_opts:
# type: 'none'
# o: 'bind'
# device: './custom_nodes'
# output:
# driver: local
# driver_opts:
# type: 'none'
# o: 'bind'
# device: './output'
# input:
# driver: local
# driver_opts:
# type: 'none'
# o: 'bind'
# device: './input'

Binary file not shown.