Fix AuraFlow

doctorpangloss 2024-07-15 15:29:49 -07:00
parent 3d1d833e6f
commit a20bf8134d
11 changed files with 222 additions and 60 deletions

View File

@@ -64,7 +64,7 @@ ignore-patterns=^\.#
# manipulated during runtime and thus existing member attributes cannot be
# deduced by static analysis). It supports qualified module names, as well as
# Unix pattern matching.
ignored-modules=
ignored-modules=sentencepiece.*
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
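Review note: sentencepiece is a compiled extension whose members pylint cannot infer statically, so the module-level ignore above silences those false positives wholesale. A sketch of the two suppression levels this commit uses (the inline form appears later in llama_tokenizer.py):

# .pylintrc: ignore member checks for the whole package
ignored-modules=sentencepiece.*

# inline, per-call fallback:
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)  # pylint: disable=unexpected-keyword-arg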

View File

@@ -85,9 +85,10 @@ def create_cors_middleware(allowed_origin: str):
class PromptServer(ExecutorToClientProgress):
instance: 'PromptServer'
instance: Optional['PromptServer'] = None
def __init__(self, loop):
# todo: this really needs to be set up differently, because sometimes the prompt server will not be initialized
PromptServer.instance = self
mimetypes.init()
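Review note: with instance now defaulting to None, any call site that previously dereferenced PromptServer.instance unconditionally can see None in worker processes that never construct a server. A minimal guard, sketched on the assumption that headless callers simply skip server-dependent work:

from typing import Optional

def get_prompt_server_or_none() -> Optional["PromptServer"]:
    # deferred import, mirroring the routes wrapper added in this commit
    from comfy.cmd.server import PromptServer
    return PromptServer.instance  # None in headless/worker processes

server = get_prompt_server_or_none()
if server is not None:
    ...  # safe to touch server.routes and friends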

View File

@@ -0,0 +1,43 @@
class _RoutesWrapper:
    """Collects route registrations so they work whether or not a PromptServer exists yet."""

    def __init__(self):
        self.routes = []

    def _decorator_factory(self, method):
        def decorator(path):
            def wrapper(func):
                # bind immediately if a server is already running; always record
                # the route so a server started later can replay it
                from comfy.cmd.server import PromptServer
                if PromptServer.instance is not None:
                    getattr(PromptServer.instance.routes, method)(path)(func)
                self.routes.append((method, path, func))
                return func

            return wrapper

        return decorator

    def get(self, path):
        return self._decorator_factory('get')(path)

    def post(self, path):
        return self._decorator_factory('post')(path)

    def put(self, path):
        return self._decorator_factory('put')(path)

    def delete(self, path):
        return self._decorator_factory('delete')(path)

    def patch(self, path):
        return self._decorator_factory('patch')(path)

    def head(self, path):
        return self._decorator_factory('head')(path)

    def options(self, path):
        return self._decorator_factory('options')(path)

    def route(self, method, path):
        return self._decorator_factory(method.lower())(path)


prompt_server_instance_routes = _RoutesWrapper()
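Review note: this lets an extension declare routes at import time. If PromptServer.instance already exists the route is bound immediately; either way it is recorded in .routes for later replay. A usage sketch with a hypothetical path and handler:

from aiohttp import web

@prompt_server_instance_routes.get("/my_extension/health")  # hypothetical endpoint
async def my_extension_health(request: web.Request) -> web.Response:
    return web.Response(text="OK", content_type="text/plain")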

View File

@@ -35,7 +35,7 @@ class DistributedPromptWorker:
self._loop = loop or asyncio.get_event_loop()
self._embedded_comfy_client = embedded_comfy_client
self._health_check_port = health_check_port
self._health_check_site = None
self._health_check_site: Optional[web.TCPSite] = None
async def _health_check(self, request):
return web.Response(text="OK", content_type="text/plain")
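Review note: the new Optional[web.TCPSite] annotation documents that the site is created lazily and torn down on shutdown. The usual aiohttp lifecycle looks like the following sketch; the start/stop method names are illustrative, not necessarily this class's actual API:

from typing import Optional
from aiohttp import web

class HealthCheckMixin:
    # illustrative mixin; DistributedPromptWorker's real methods may differ
    _health_check_port: int
    _health_check_site: Optional[web.TCPSite]

    async def _health_check(self, request: web.Request) -> web.Response:
        return web.Response(text="OK", content_type="text/plain")

    async def _start_health_check(self) -> None:
        app = web.Application()
        app.router.add_get("/health", self._health_check)
        runner = web.AppRunner(app)
        await runner.setup()
        self._health_check_site = web.TCPSite(runner, port=self._health_check_port)
        await self._health_check_site.start()

    async def _stop_health_check(self) -> None:
        if self._health_check_site is not None:
            await self._health_check_site.stop()
            self._health_check_site = None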

View File

@@ -7,7 +7,7 @@ from functools import reduce
from itertools import chain
from os.path import join
from pathlib import Path
from typing import List, Any, Optional, Union, Sequence
from typing import List, Any, Optional, Sequence, Final, Set
import tqdm
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
@@ -20,7 +20,7 @@ from .cli_args import args
from .cmd import folder_paths
from .component_model.deprecation import _deprecate_method
from .interruption import InterruptProcessingException
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_, Downloadable
from .utils import ProgressBar, comfy_tqdm
_session = Session()
@@ -159,7 +159,7 @@ Visit the repository, accept the terms, and then do one of the following:
return path
KNOWN_CHECKPOINTS = [
KNOWN_CHECKPOINTS: Final[List[Downloadable]] = [
HuggingFile("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"),
HuggingFile("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"),
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors"),
@@ -193,32 +193,33 @@ KNOWN_CHECKPOINTS = [
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"),
HuggingFile("fal/AuraFlow", filename="aura_flow_0.1.safetensors"),
]
KNOWN_UNCLIP_CHECKPOINTS = [
KNOWN_UNCLIP_CHECKPOINTS: Final[List[Downloadable]] = [
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-h.ckpt"),
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-l.ckpt"),
]
KNOWN_IMAGE_ONLY_CHECKPOINTS = [
KNOWN_IMAGE_ONLY_CHECKPOINTS: Final[List[Downloadable]] = [
HuggingFile("stabilityai/stable-zero123", "stable_zero123.ckpt")
]
KNOWN_UPSCALERS = [
KNOWN_UPSCALERS: Final[List[Downloadable]] = [
HuggingFile("lllyasviel/Annotators", "RealESRGAN_x4plus.pth")
]
KNOWN_GLIGEN_MODELS = [
KNOWN_GLIGEN_MODELS: Final[List[Downloadable]] = [
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned.safetensors", show_in_ui=False),
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned_fp16.safetensors"),
]
KNOWN_CLIP_VISION_MODELS = [
KNOWN_CLIP_VISION_MODELS: Final[List[Downloadable]] = [
HuggingFile("comfyanonymous/clip_vision_g", "clip_vision_g.safetensors")
]
KNOWN_LORAS = [
KNOWN_LORAS: Final[List[Downloadable]] = [
CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"),
CivitFile(model_id=324815, model_version_id=364137, filename="blur_control_xl_v1.safetensors"),
CivitFile(model_id=47085, model_version_id=55199, filename="GoodHands-beta2.safetensors"),
@@ -226,7 +227,7 @@ KNOWN_LORAS = [
HuggingFile("ByteDance/Hyper-SD", "Hyper-SD15-12steps-CFG-lora.safetensors"),
]
KNOWN_CONTROLNETS = [
KNOWN_CONTROLNETS: Final[List[Downloadable]] = [
HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "OpenPoseXL2.safetensors", convert_to_16_bit=True, size=2502139104),
HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "control-lora-openposeXL2-rank256.safetensors"),
HuggingFile("comfyanonymous/ControlNet-v1-1_fp16_safetensors", "control_lora_rank128_v11e_sd15_ip2p_fp16.safetensors"),
@@ -316,7 +317,7 @@ KNOWN_CONTROLNETS = [
HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"),
]
KNOWN_DIFF_CONTROLNETS = [
KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_canny_fp16.safetensors"),
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_depth_fp16.safetensors"),
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_hed_fp16.safetensors"),
@@ -327,28 +328,28 @@ KNOWN_DIFF_CONTROLNETS = [
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"),
]
KNOWN_APPROX_VAES = [
KNOWN_APPROX_VAES: Final[List[Downloadable]] = [
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"),
]
KNOWN_VAES = [
KNOWN_VAES: Final[List[Downloadable]] = [
HuggingFile("stabilityai/sdxl-vae", "sdxl_vae.safetensors"),
HuggingFile("stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors"),
]
KNOWN_HUGGINGFACE_MODEL_REPOS = {
KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
"JingyeChen22/textdiffuser2_layout_planner",
'JingyeChen22/textdiffuser2-full-ft',
"microsoft/Phi-3-mini-4k-instruct",
"llava-hf/llava-v1.6-mistral-7b-hf"
}
KNOWN_UNET_MODELS: List[Union[CivitFile | HuggingFile]] = [
KNOWN_UNET_MODELS: Final[List[Downloadable]] = [
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors")
]
KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = [
KNOWN_CLIP_MODELS: Final[List[Downloadable]] = [
# todo: is this correct?
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp16.safetensors", save_with_filename="t5xxl_fp16.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp8_e4m3fn.safetensors", save_with_filename="t5xxl_fp8_e4m3fn.safetensors"),
@@ -357,7 +358,7 @@ KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = [
]
def add_known_models(folder_name: str, known_models: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
def add_known_models(folder_name: str, known_models: List[Downloadable], *models: Downloadable) -> List[Downloadable]:
if len(models) < 1:
return known_models
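Review note: Final on these registries only forbids rebinding the names; the underlying lists stay mutable, which is exactly what add_known_models relies on when extensions append their own entries. Usage as exercised by the workflow test at the bottom of this commit:

from comfy.model_downloader import KNOWN_LORAS, add_known_models
from comfy.model_downloader_types import CivitFile

add_known_models("loras", KNOWN_LORAS,
                 CivitFile(13941, 16576, "epi_noiseoffset2.safetensors"))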

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import dataclasses
from os.path import split
from typing import Optional, List, Sequence
from typing import Optional, List, Sequence, Union
from typing_extensions import TypedDict, NotRequired
@@ -152,3 +152,6 @@ class CivitModelsGetResponse(TypedDict):
creator: CivitCreator
tags: List[str]
modelVersions: List[CivitModelVersion]
Downloadable = Union[CivitFile, HuggingFile]

View File

@@ -8,7 +8,7 @@ import os
import traceback
import zipfile
from importlib.abc import Traversable
from typing import Tuple, Sequence, TypeVar
from typing import Tuple, Sequence, TypeVar, Callable
import torch
from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMixin
@@ -18,6 +18,7 @@ from . import model_management
from . import ops
from .component_model import files
from .component_model.files import get_path_as_dict, get_package_as_path
from .text_encoders.llama_tokenizer import LLAMATokenizer
def gen_empty_tokens(special_tokens, length):
@@ -32,6 +33,7 @@ def gen_empty_tokens(special_tokens, length):
output += [pad_token] * (length - len(output))
return output
class ClipTokenWeightEncoder:
def encode_token_weights(self, token_weight_pairs):
to_encode = list()
@@ -44,10 +46,12 @@ class ClipTokenWeightEncoder:
to_encode.append(tokens)
sections = len(to_encode)
if has_weights or sections == 0:
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
if (has_weights or sections == 0) and hasattr(self, "special_tokens"):
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) # pylint: disable=no-member
o = self.encode(to_encode)
assert hasattr(self, "encode")
assert isinstance(self.encode, Callable) # pylint: disable=no-member
o = self.encode(to_encode) # pylint: disable=no-member
out, pooled = o[:2]
if pooled is not None:
@@ -83,6 +87,7 @@ class ClipTokenWeightEncoder:
r = r + (extra,)
return r
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
@@ -446,12 +451,13 @@ class SDTokenizer:
if isinstance(tokenizer_path, Traversable):
contextlib_path = importlib.resources.as_file(tokenizer_path)
tokenizer_path = contextlib_path.__enter__()
if not tokenizer_path.endswith(".model") and not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
# package based
tokenizer_path = str(tokenizer_path)
if issubclass(tokenizer_class, CLIPTokenizer) and not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
# assumes sd1_tokenizer
tokenizer_path = get_package_as_path('comfy.sd1_tokenizer')
self.tokenizer_class = tokenizer_class
self.tokenizer_path = tokenizer_path
self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path)
self.tokenizer: PreTrainedTokenizerBase | LLAMATokenizer = tokenizer_class.from_pretrained(tokenizer_path)
self.max_length = max_length
self.min_length = min_length
@@ -545,7 +551,7 @@ class SDTokenizer:
continue
# parse word
exact_word = f"{word}</w>"
if word == self.tokenizer.eos_token:
if hasattr(self.tokenizer, "eos_token") and word == self.tokenizer.eos_token:
tokenizer_result = [self.tokenizer.eos_token_id]
elif exact_word in vocab:
tokenizer_result = [vocab[exact_word]]
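Review note: the widened annotation PreTrainedTokenizerBase | LLAMATokenizer works because SDTokenizer only relies on a small duck-typed surface, now hasattr-guarded where LLAMATokenizer may lack an attribute. A hedged Protocol inferred from the calls in this file (not part of the commit):

from typing import Any, Dict, Protocol

class _TokenizerLike(Protocol):
    # inferred interface; real tokenizers may expose much more
    eos_token_id: int

    def get_vocab(self) -> Dict[str, int]: ...

    def __call__(self, text: str) -> Dict[str, Any]: ...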

View File

@@ -1,22 +1,24 @@
import os
class LLAMATokenizer:
# todo: not sure why we're not using the tokenizer from transformers for this
@staticmethod
def from_pretrained(path):
return LLAMATokenizer(path)
def __init__(self, tokenizer_path):
import sentencepiece
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path)
self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path) # pylint: disable=unexpected-keyword-arg
self.end = self.tokenizer.eos_id()
self.eos_token_id = self.end
self.eos_token = self.tokenizer.id_to_piece(self.eos_token_id)
self._vocab = {
self.tokenizer.id_to_piece(i): i for i in range(self.tokenizer.get_piece_size())
}
def get_vocab(self):
out = {}
for i in range(self.tokenizer.get_piece_size()):
out[self.tokenizer.id_to_piece(i)] = i
return out
return self._vocab
def __call__(self, string):
out = self.tokenizer.encode(string)
out = self.tokenizer.encode(string) # pylint: disable=no-member
out += [self.end]
return {"input_ids": out}

View File

@@ -135,21 +135,4 @@ async def test_basic_queue_worker_with_health_check():
health_check_url = f"http://localhost:{health_check_port}/health"
health_check_ok = await check_health(health_check_url)
assert health_check_ok, "Health check server did not start properly"
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
await distributed_queue.init()
queue_item = create_test_prompt()
res = await distributed_queue.put_async(queue_item)
assert res.item_id == queue_item.prompt_id
assert len(res.outputs) == 1
assert res.status is not None
assert res.status.status_str == "success"
await distributed_queue.close()
health_check_stopped = not await check_health(health_check_url, max_retries=1)
assert health_check_stopped, "Health check server did not stop properly"

View File

@@ -4,10 +4,7 @@ import shutil
import pytest
from comfy.cli_args import args
from comfy.cmd import folder_paths
from comfy.cmd.folder_paths import FolderPathsTuple
from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS, get_huggingface_repo_list, \
get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache
_gitattributes = """*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
@@ -49,6 +46,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
@pytest.mark.asyncio
def test_known_repos(tmp_path_factory):
from comfy.cmd import folder_paths
from comfy.cmd.folder_paths import FolderPathsTuple
from comfy.model_downloader import get_huggingface_repo_list, \
get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache
from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS
test_cache_dir = tmp_path_factory.mktemp("huggingface_cache")
test_local_dir = tmp_path_factory.mktemp("huggingface_locals")
test_repo_id = "doctorpangloss/comfyui_downloader_test"
@@ -69,6 +72,7 @@ def test_known_repos(tmp_path_factory):
existing_repos = get_huggingface_repo_list()
assert test_repo_id not in existing_repos
# best to import this at the time that it is run, not when the test is initialized
KNOWN_HUGGINGFACE_MODEL_REPOS.add(test_repo_id)
existing_repos = get_huggingface_repo_list()
assert test_repo_id in existing_repos

View File

@@ -6,6 +6,126 @@ from comfy.model_downloader import add_known_models, KNOWN_LORAS
from comfy.model_downloader_types import CivitFile
_workflows = {
"auraflow_1": {
"1": {
"inputs": {
"ckpt_name": "aura_flow_0.1.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"2": {
"inputs": {
"shift": 1.73,
"model": [
"1",
0
]
},
"class_type": "ModelSamplingAuraFlow",
"_meta": {
"title": "ModelSamplingAuraFlow"
}
},
"3": {
"inputs": {
"seed": 232240565010917,
"steps": 25,
"cfg": 3.5,
"sampler_name": "uni_pc",
"scheduler": "normal",
"denoise": 1,
"model": [
"2",
0
],
"positive": [
"4",
0
],
"negative": [
"5",
0
],
"latent_image": [
"6",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"text": "close-up portrait of cat",
"clip": [
"1",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"5": {
"inputs": {
"text": "",
"clip": [
"1",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"6": {
"inputs": {
"width": 1024,
"height": 1024,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"7": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"1",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"8": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"7",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
},
"lora_1": {
"3": {
"inputs": {
@@ -148,7 +268,6 @@ KNOWN_LORAS = [
if not has_gpu:
pytest.skip("requires gpu")
prompt = Prompt.validate(workflow)
add_known_models("loras", KNOWN_LORAS, CivitFile(13941, 16576, "epi_noiseoffset2.safetensors"))
# todo: add all the models we want to test a bit more elegantly
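Review note: the auraflow_1 workflow above is in API (prompt) format, so outside the test harness it can also be queued against a running server. A sketch assuming a ComfyUI-compatible server on the default port:

import json
import urllib.request

def queue_prompt(workflow: dict, server: str = "http://127.0.0.1:8188") -> dict:
    # the /prompt endpoint expects {"prompt": <api-format workflow>}
    payload = json.dumps({"prompt": workflow}).encode("utf-8")
    request = urllib.request.Request(f"{server}/prompt", data=payload,
                                     headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())

queue_prompt(_workflows["auraflow_1"])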