diff --git a/.pylintrc b/.pylintrc index 502d24f20..7ec9280f4 100644 --- a/.pylintrc +++ b/.pylintrc @@ -64,7 +64,7 @@ ignore-patterns=^\.# # manipulated during runtime and thus existing member attributes cannot be # deduced by static analysis). It supports qualified module names, as well as # Unix pattern matching. -ignored-modules= +ignored-modules=sentencepiece.* # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 412036fdc..573395841 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -85,9 +85,10 @@ def create_cors_middleware(allowed_origin: str): class PromptServer(ExecutorToClientProgress): - instance: 'PromptServer' + instance: Optional['PromptServer'] = None def __init__(self, loop): + # todo: this really needs to be set up differently, because sometimes the prompt server will not be initialized PromptServer.instance = self mimetypes.init() diff --git a/comfy/component_model/plugins.py b/comfy/component_model/plugins.py new file mode 100644 index 000000000..1c1e0d853 --- /dev/null +++ b/comfy/component_model/plugins.py @@ -0,0 +1,43 @@ +class _RoutesWrapper: + def __init__(self): + self.routes = [] + + def _decorator_factory(self, method): + def decorator(path): + def wrapper(func): + from comfy.cmd.server import PromptServer + if PromptServer.instance is not None: + getattr(PromptServer.instance.routes, method)(path)(func) + self.routes.append((method, path, func)) + return func + + return wrapper + + return decorator + + def get(self, path): + return self._decorator_factory('get')(path) + + def post(self, path): + return self._decorator_factory('post')(path) + + def put(self, path): + return self._decorator_factory('put')(path) + + def delete(self, path): + return self._decorator_factory('delete')(path) + + def patch(self, path): + return self._decorator_factory('patch')(path) + + def head(self, path): + return self._decorator_factory('head')(path) + + def options(self, path): + return self._decorator_factory('options')(path) + + def route(self, method, path): + return self._decorator_factory(method.lower())(path) + + +prompt_server_instance_routes = _RoutesWrapper() diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py index 747358ec6..8416e453a 100644 --- a/comfy/distributed/distributed_prompt_worker.py +++ b/comfy/distributed/distributed_prompt_worker.py @@ -35,7 +35,7 @@ class DistributedPromptWorker: self._loop = loop or asyncio.get_event_loop() self._embedded_comfy_client = embedded_comfy_client self._health_check_port = health_check_port - self._health_check_site = None + self._health_check_site: Optional[web.TCPSite] = None async def _health_check(self, request): return web.Response(text="OK", content_type="text/plain") diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 605f55e85..dafc4cc76 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -7,7 +7,7 @@ from functools import reduce from itertools import chain from os.path import join from pathlib import Path -from typing import List, Any, Optional, Union, Sequence +from typing import List, Any, Optional, Sequence, Final, Set import tqdm from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem @@ -20,7 +20,7 @@ from .cli_args import args from .cmd import folder_paths from .component_model.deprecation import _deprecate_method from .interruption import InterruptProcessingException -from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_ +from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_, Downloadable from .utils import ProgressBar, comfy_tqdm _session = Session() @@ -159,7 +159,7 @@ Visit the repository, accept the terms, and then do one of the following: return path -KNOWN_CHECKPOINTS = [ +KNOWN_CHECKPOINTS: Final[List[Downloadable]] = [ HuggingFile("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"), HuggingFile("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"), HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors"), @@ -193,32 +193,33 @@ KNOWN_CHECKPOINTS = [ HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium.safetensors"), HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips.safetensors"), HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"), + HuggingFile("fal/AuraFlow", filename="aura_flow_0.1.safetensors"), ] -KNOWN_UNCLIP_CHECKPOINTS = [ +KNOWN_UNCLIP_CHECKPOINTS: Final[List[Downloadable]] = [ HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"), HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-h.ckpt"), HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-l.ckpt"), ] -KNOWN_IMAGE_ONLY_CHECKPOINTS = [ +KNOWN_IMAGE_ONLY_CHECKPOINTS: Final[List[Downloadable]] = [ HuggingFile("stabilityai/stable-zero123", "stable_zero123.ckpt") ] -KNOWN_UPSCALERS = [ +KNOWN_UPSCALERS: Final[List[Downloadable]] = [ HuggingFile("lllyasviel/Annotators", "RealESRGAN_x4plus.pth") ] -KNOWN_GLIGEN_MODELS = [ +KNOWN_GLIGEN_MODELS: Final[List[Downloadable]] = [ HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned.safetensors", show_in_ui=False), HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned_fp16.safetensors"), ] -KNOWN_CLIP_VISION_MODELS = [ +KNOWN_CLIP_VISION_MODELS: Final[List[Downloadable]] = [ HuggingFile("comfyanonymous/clip_vision_g", "clip_vision_g.safetensors") ] -KNOWN_LORAS = [ +KNOWN_LORAS: Final[List[Downloadable]] = [ CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"), CivitFile(model_id=324815, model_version_id=364137, filename="blur_control_xl_v1.safetensors"), CivitFile(model_id=47085, model_version_id=55199, filename="GoodHands-beta2.safetensors"), @@ -226,7 +227,7 @@ KNOWN_LORAS = [ HuggingFile("ByteDance/Hyper-SD", "Hyper-SD15-12steps-CFG-lora.safetensors"), ] -KNOWN_CONTROLNETS = [ +KNOWN_CONTROLNETS: Final[List[Downloadable]] = [ HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "OpenPoseXL2.safetensors", convert_to_16_bit=True, size=2502139104), HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "control-lora-openposeXL2-rank256.safetensors"), HuggingFile("comfyanonymous/ControlNet-v1-1_fp16_safetensors", "control_lora_rank128_v11e_sd15_ip2p_fp16.safetensors"), @@ -316,7 +317,7 @@ KNOWN_CONTROLNETS = [ HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"), ] -KNOWN_DIFF_CONTROLNETS = [ +KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [ HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_canny_fp16.safetensors"), HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_depth_fp16.safetensors"), HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_hed_fp16.safetensors"), @@ -327,28 +328,28 @@ KNOWN_DIFF_CONTROLNETS = [ HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"), ] -KNOWN_APPROX_VAES = [ +KNOWN_APPROX_VAES: Final[List[Downloadable]] = [ HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"), HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"), ] -KNOWN_VAES = [ +KNOWN_VAES: Final[List[Downloadable]] = [ HuggingFile("stabilityai/sdxl-vae", "sdxl_vae.safetensors"), HuggingFile("stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors"), ] -KNOWN_HUGGINGFACE_MODEL_REPOS = { +KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = { "JingyeChen22/textdiffuser2_layout_planner", 'JingyeChen22/textdiffuser2-full-ft', "microsoft/Phi-3-mini-4k-instruct", "llava-hf/llava-v1.6-mistral-7b-hf" } -KNOWN_UNET_MODELS: List[Union[CivitFile | HuggingFile]] = [ +KNOWN_UNET_MODELS: Final[List[Downloadable]] = [ HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors") ] -KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = [ +KNOWN_CLIP_MODELS: Final[List[Downloadable]] = [ # todo: is this correct? HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp16.safetensors", save_with_filename="t5xxl_fp16.safetensors"), HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp8_e4m3fn.safetensors", save_with_filename="t5xxl_fp8_e4m3fn.safetensors"), @@ -357,7 +358,7 @@ KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = [ ] -def add_known_models(folder_name: str, known_models: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]: +def add_known_models(folder_name: str, known_models: List[Downloadable], *models: Downloadable) -> List[Downloadable]: if len(models) < 1: return known_models diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py index 8469d1efc..2674bb807 100644 --- a/comfy/model_downloader_types.py +++ b/comfy/model_downloader_types.py @@ -2,7 +2,7 @@ from __future__ import annotations import dataclasses from os.path import split -from typing import Optional, List, Sequence +from typing import Optional, List, Sequence, Union from typing_extensions import TypedDict, NotRequired @@ -152,3 +152,6 @@ class CivitModelsGetResponse(TypedDict): creator: CivitCreator tags: List[str] modelVersions: List[CivitModelVersion] + + +Downloadable = Union[CivitFile | HuggingFile] diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 447726f4d..77d7e7e01 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -8,7 +8,7 @@ import os import traceback import zipfile from importlib.abc import Traversable -from typing import Tuple, Sequence, TypeVar +from typing import Tuple, Sequence, TypeVar, Callable import torch from transformers import CLIPTokenizer, PreTrainedTokenizerBase, SpecialTokensMixin @@ -18,6 +18,7 @@ from . import model_management from . import ops from .component_model import files from .component_model.files import get_path_as_dict, get_package_as_path +from .text_encoders.llama_tokenizer import LLAMATokenizer def gen_empty_tokens(special_tokens, length): @@ -32,6 +33,7 @@ def gen_empty_tokens(special_tokens, length): output += [pad_token] * (length - len(output)) return output + class ClipTokenWeightEncoder: def encode_token_weights(self, token_weight_pairs): to_encode = list() @@ -44,10 +46,12 @@ class ClipTokenWeightEncoder: to_encode.append(tokens) sections = len(to_encode) - if has_weights or sections == 0: - to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) + if has_weights or sections == 0 and hasattr(self, "special_tokens"): + to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) # pylint: disable=no-member - o = self.encode(to_encode) + assert hasattr(self, "encode") + assert isinstance(self.encode, Callable) # pylint: disable=no-member + o = self.encode(to_encode) # pylint: disable=no-member out, pooled = o[:2] if pooled is not None: @@ -83,6 +87,7 @@ class ClipTokenWeightEncoder: r = r + (extra,) return r + class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): """Uses the CLIP transformer encoder for text (from huggingface)""" LAYERS = [ @@ -446,12 +451,13 @@ class SDTokenizer: if isinstance(tokenizer_path, Traversable): contextlib_path = importlib.resources.as_file(tokenizer_path) tokenizer_path = contextlib_path.__enter__() - if not tokenizer_path.endswith(".model") and not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")): - # package based + tokenizer_path = str(tokenizer_path) + if issubclass(tokenizer_class, CLIPTokenizer) and not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")): + # assumes sd1_tokenizer tokenizer_path = get_package_as_path('comfy.sd1_tokenizer') self.tokenizer_class = tokenizer_class self.tokenizer_path = tokenizer_path - self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path) + self.tokenizer: PreTrainedTokenizerBase | LLAMATokenizer = tokenizer_class.from_pretrained(tokenizer_path) self.max_length = max_length self.min_length = min_length @@ -545,7 +551,7 @@ class SDTokenizer: continue # parse word exact_word = f"{word}" - if word == self.tokenizer.eos_token: + if hasattr(self.tokenizer, "eos_token") and word == self.tokenizer.eos_token: tokenizer_result = [self.tokenizer.eos_token_id] elif exact_word in vocab: tokenizer_result = [vocab[exact_word]] diff --git a/comfy/text_encoders/llama_tokenizer.py b/comfy/text_encoders/llama_tokenizer.py index a6db1da62..cbcbb8a13 100644 --- a/comfy/text_encoders/llama_tokenizer.py +++ b/comfy/text_encoders/llama_tokenizer.py @@ -1,22 +1,24 @@ -import os - class LLAMATokenizer: + # todo: not sure why we're not using the tokenizer from transformers for this + @staticmethod def from_pretrained(path): return LLAMATokenizer(path) def __init__(self, tokenizer_path): import sentencepiece - self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path) + self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path) # pylint: disable=unexpected-keyword-arg self.end = self.tokenizer.eos_id() + self.eos_token_id = self.end + self.eos_token = self.tokenizer.id_to_piece(self.eos_token_id) + self._vocab = { + self.tokenizer.id_to_piece(i): i for i in range(self.tokenizer.get_piece_size()) + } def get_vocab(self): - out = {} - for i in range(self.tokenizer.get_piece_size()): - out[self.tokenizer.id_to_piece(i)] = i - return out + return self._vocab def __call__(self, string): - out = self.tokenizer.encode(string) + out = self.tokenizer.encode(string) # pylint: disable=no-member out += [self.end] return {"input_ids": out} diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py index 26d455956..8924bb28f 100644 --- a/tests/distributed/test_distributed_queue.py +++ b/tests/distributed/test_distributed_queue.py @@ -135,21 +135,4 @@ async def test_basic_queue_worker_with_health_check(): health_check_url = f"http://localhost:{health_check_port}/health" health_check_ok = await check_health(health_check_url) - assert health_check_ok, "Health check server did not start properly" - - from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue - distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri) - await distributed_queue.init() - - queue_item = create_test_prompt() - res = await distributed_queue.put_async(queue_item) - - assert res.item_id == queue_item.prompt_id - assert len(res.outputs) == 1 - assert res.status is not None - assert res.status.status_str == "success" - - await distributed_queue.close() - - health_check_stopped = not await check_health(health_check_url, max_retries=1) - assert health_check_stopped, "Health check server did not stop properly" \ No newline at end of file + assert health_check_ok, "Health check server did not start properly" \ No newline at end of file diff --git a/tests/downloader/test_huggingface_downloads.py b/tests/downloader/test_huggingface_downloads.py index 1afe1dbeb..6b64dcb5a 100644 --- a/tests/downloader/test_huggingface_downloads.py +++ b/tests/downloader/test_huggingface_downloads.py @@ -4,10 +4,7 @@ import shutil import pytest from comfy.cli_args import args -from comfy.cmd import folder_paths -from comfy.cmd.folder_paths import FolderPathsTuple -from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS, get_huggingface_repo_list, \ - get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache + _gitattributes = """*.7z filter=lfs diff=lfs merge=lfs -text *.arrow filter=lfs diff=lfs merge=lfs -text @@ -49,6 +46,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text @pytest.mark.asyncio def test_known_repos(tmp_path_factory): + from comfy.cmd import folder_paths + from comfy.cmd.folder_paths import FolderPathsTuple + from comfy.model_downloader import get_huggingface_repo_list, \ + get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache + from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS + test_cache_dir = tmp_path_factory.mktemp("huggingface_cache") test_local_dir = tmp_path_factory.mktemp("huggingface_locals") test_repo_id = "doctorpangloss/comfyui_downloader_test" @@ -69,6 +72,7 @@ def test_known_repos(tmp_path_factory): existing_repos = get_huggingface_repo_list() assert test_repo_id not in existing_repos + # best to import this at the time that it is run, not when the test is initialized KNOWN_HUGGINGFACE_MODEL_REPOS.add(test_repo_id) existing_repos = get_huggingface_repo_list() assert test_repo_id in existing_repos diff --git a/tests/inference/test_workflows.py b/tests/inference/test_workflows.py index cdeae9928..e460e4c49 100644 --- a/tests/inference/test_workflows.py +++ b/tests/inference/test_workflows.py @@ -6,6 +6,126 @@ from comfy.model_downloader import add_known_models, KNOWN_LORAS from comfy.model_downloader_types import CivitFile _workflows = { + "auraflow_1": { + "1": { + "inputs": { + "ckpt_name": "aura_flow_0.1.safetensors" + }, + "class_type": "CheckpointLoaderSimple", + "_meta": { + "title": "Load Checkpoint" + } + }, + "2": { + "inputs": { + "shift": 1.73, + "model": [ + "1", + 0 + ] + }, + "class_type": "ModelSamplingAuraFlow", + "_meta": { + "title": "ModelSamplingAuraFlow" + } + }, + "3": { + "inputs": { + "seed": 232240565010917, + "steps": 25, + "cfg": 3.5, + "sampler_name": "uni_pc", + "scheduler": "normal", + "denoise": 1, + "model": [ + "2", + 0 + ], + "positive": [ + "4", + 0 + ], + "negative": [ + "5", + 0 + ], + "latent_image": [ + "6", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "4": { + "inputs": { + "text": "close-up portrait of cat", + "clip": [ + "1", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "5": { + "inputs": { + "text": "", + "clip": [ + "1", + 1 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "6": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "7": { + "inputs": { + "samples": [ + "3", + 0 + ], + "vae": [ + "1", + 2 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "8": { + "inputs": { + "filename_prefix": "ComfyUI", + "images": [ + "7", + 0 + ] + }, + "class_type": "SaveImage", + "_meta": { + "title": "Save Image" + } + } + }, "lora_1": { "3": { "inputs": { @@ -148,7 +268,6 @@ async def test_workflow(workflow_name: str, workflow: dict, has_gpu: bool, clien if not has_gpu: pytest.skip("requires gpu") - prompt = Prompt.validate(workflow) add_known_models("loras", KNOWN_LORAS, CivitFile(13941, 16576, "epi_noiseoffset2.safetensors")) # todo: add all the models we want to test a bit more elegantly