diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml new file mode 100644 index 000000000..5effbea35 --- /dev/null +++ b/.github/workflows/pylint.yml @@ -0,0 +1,23 @@ +name: Python Linting + +on: [push, pull_request] + +jobs: + pylint: + name: Run Pylint + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.x + + - name: Install Pylint + run: pip install pylint + + - name: Run Pylint + run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py") diff --git a/.github/workflows/test-ui.yaml b/.github/workflows/test-ui.yaml index 95de59dfb..b56c85730 100644 --- a/.github/workflows/test-ui.yaml +++ b/.github/workflows/test-ui.yaml @@ -23,3 +23,7 @@ jobs: npm run test:generate npm test -- --verbose working-directory: ./tests-ui + - name: Run Unit Tests + run: | + pip install -r tests-unit/requirements.txt + python -m pytest tests-unit \ No newline at end of file diff --git a/.gitignore b/.gitignore index 2a246b3f4..c734b0d48 100644 --- a/.gitignore +++ b/.gitignore @@ -175,4 +175,5 @@ cython_debug/ /tests-ui/data/object_info.json /user/ -*.log \ No newline at end of file +*.log +web_custom_versions/ \ No newline at end of file diff --git a/comfy/app/frontend_management.py b/comfy/app/frontend_management.py new file mode 100644 index 000000000..5c1f649fa --- /dev/null +++ b/comfy/app/frontend_management.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import argparse +import logging +import os +import re +import tempfile +import zipfile +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path +from typing import TypedDict + +import requests +from typing_extensions import NotRequired + +from comfy.cli_args import DEFAULT_VERSION_STRING +from comfy.cmd.folder_paths import add_model_folder_path +from comfy.component_model.files import get_package_as_path + +REQUEST_TIMEOUT = 10 # seconds + + +class Asset(TypedDict): + url: str + + +class Release(TypedDict): + id: int + tag_name: str + name: str + prerelease: bool + created_at: str + published_at: str + body: str + assets: NotRequired[list[Asset]] + + +@dataclass +class FrontEndProvider: + owner: str + repo: str + + @property + def folder_name(self) -> str: + return f"{self.owner}_{self.repo}" + + @property + def release_url(self) -> str: + return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases" + + @cached_property + def all_releases(self) -> list[Release]: + releases = [] + api_url = self.release_url + while api_url: + response = requests.get(api_url, timeout=REQUEST_TIMEOUT) + response.raise_for_status() # Raises an HTTPError if the response was an error + releases.extend(response.json()) + # GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly. + if "next" in response.links: + api_url = response.links["next"]["url"] + else: + api_url = None + return releases + + @cached_property + def latest_release(self) -> Release: + latest_release_url = f"{self.release_url}/latest" + response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT) + response.raise_for_status() # Raises an HTTPError if the response was an error + return response.json() + + def get_release(self, version: str) -> Release: + if version == "latest": + return self.latest_release + else: + for release in self.all_releases: + if release["tag_name"] in [version, f"v{version}"]: + return release + raise ValueError(f"Version {version} not found in releases") + + +def download_release_asset_zip(release: Release, destination_path: str) -> None: + """Download dist.zip from github release.""" + asset_url = None + for asset in release.get("assets", []): + if asset["name"] == "dist.zip": + asset_url = asset["url"] + break + + if not asset_url: + raise ValueError("dist.zip not found in the release assets") + + # Use a temporary file to download the zip content + with tempfile.TemporaryFile() as tmp_file: + headers = {"Accept": "application/octet-stream"} + response = requests.get( + asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT + ) + response.raise_for_status() # Ensure we got a successful response + + # Write the content to the temporary file + tmp_file.write(response.content) + + # Go back to the beginning of the temporary file + tmp_file.seek(0) + + # Extract the zip file content to the destination path + with zipfile.ZipFile(tmp_file, "r") as zip_ref: + zip_ref.extractall(destination_path) + + +class FrontendManager: + DEFAULT_FRONTEND_PATH = get_package_as_path('comfy', 'web/') + CUSTOM_FRONTENDS_ROOT = add_model_folder_path("web_custom_versions", extensions=set()) + + @classmethod + def parse_version_string(cls, value: str) -> tuple[str, str, str]: + """ + Args: + value (str): The version string to parse. + + Returns: + tuple[str, str]: A tuple containing provider name and version. + + Raises: + argparse.ArgumentTypeError: If the version string is invalid. + """ + VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$" + match_result = re.match(VERSION_PATTERN, value) + if match_result is None: + raise argparse.ArgumentTypeError(f"Invalid version string: {value}") + + return match_result.group(1), match_result.group(2), match_result.group(3) + + @classmethod + def init_frontend_unsafe(cls, version_string: str) -> str: + """ + Initializes the frontend for the specified version. + + Args: + version_string (str): The version string. + + Returns: + str: The path to the initialized frontend. + + Raises: + Exception: If there is an error during the initialization process. + main error source might be request timeout or invalid URL. + """ + if version_string == DEFAULT_VERSION_STRING: + return cls.DEFAULT_FRONTEND_PATH + + repo_owner, repo_name, version = cls.parse_version_string(version_string) + provider = FrontEndProvider(repo_owner, repo_name) + release = provider.get_release(version) + + semantic_version = release["tag_name"].lstrip("v") + web_root = str( + Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version + ) + if not os.path.exists(web_root): + os.makedirs(web_root, exist_ok=True) + logging.info( + "Downloading frontend(%s) version(%s) to (%s)", + provider.folder_name, + semantic_version, + web_root, + ) + logging.debug(release) + download_release_asset_zip(release, destination_path=web_root) + return web_root + + @classmethod + def init_frontend(cls, version_string: str) -> str: + """ + Initializes the frontend with the specified version string. + + Args: + version_string (str): The version string to initialize the frontend with. + + Returns: + str: The path of the initialized frontend. + """ + try: + return cls.init_frontend_unsafe(version_string) + except Exception as e: + logging.error("Failed to initialize frontend: %s", e) + logging.info("Falling back to the default frontend.") + return cls.DEFAULT_FRONTEND_PATH diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 064dfce06..78490c548 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -363,7 +363,7 @@ class ControlNet(nn.Module): controlnet_cond = self.input_hint_block(hint[idx], emb, context) feat_seq = torch.mean(controlnet_cond, dim=(2, 3)) if idx < len(control_type): - feat_seq += self.task_embedding[control_type[idx]] + feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device) inputs.append(feat_seq.unsqueeze(1)) condition_list.append(controlnet_cond) diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 903ed9408..70506f24d 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -15,6 +15,9 @@ from . import options from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, ConfigChangeHandler, EnumAction, \ EnhancedConfigArgParser +# todo: move this +DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest" + def _create_parser() -> EnhancedConfigArgParser: parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json'], @@ -108,6 +111,7 @@ def _create_parser() -> EnhancedConfigArgParser: vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.") vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") + parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.") parser.add_argument("--deterministic", action="store_true", @@ -160,13 +164,43 @@ def _create_parser() -> EnhancedConfigArgParser: parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.") parser.add_argument("--force-hf-local-dir-mode", action="store_true", help="Download repos from huggingface.co to the models/huggingface directory with the \"local_dir\" argument instead of models/huggingface_cache with the \"cache_dir\" argument, recreating the traditional file structure.") + parser.add_argument( + "--front-end-version", + type=str, + default=DEFAULT_VERSION_STRING, + help=""" + Specifies the version of the frontend to be used. This command needs internet connectivity to query and + download available frontend implementations from GitHub releases. + + The version string should be in the format of: + [repoOwner]/[repoName]@[version] + where version is one of: "latest" or a valid version number (e.g. "1.0.0") + """, + ) + + def is_valid_directory(path: Optional[str]) -> Optional[str]: + """Validate if the given path is a directory.""" + if path is None: + return None + + if not os.path.isdir(path): + raise argparse.ArgumentTypeError(f"{path} is not a valid directory.") + return path + + parser.add_argument( + "--front-end-root", + type=is_valid_directory, + default=None, + help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.", + ) + # now give plugins a chance to add configuration for entry_point in entry_points().select(group='comfyui.custom_config'): try: plugin_callable: ConfigurationExtender | ModuleType = entry_point.load() if isinstance(plugin_callable, ModuleType): # todo: find the configuration extender in the module - plugin_callable = ... + raise ValueError("unexpected or unsupported plugin configuration type") else: parser_result = plugin_callable(parser) if parser_result is not None: diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py index 69e3cbb82..216a864fe 100644 --- a/comfy/clip_vision.py +++ b/comfy/clip_vision.py @@ -1,6 +1,5 @@ from .component_model import files from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace -import os import torch import json import logging @@ -43,6 +42,7 @@ class ClipVisionModel(): else: raise ValueError(f"json_config had invalid value={json_config}") + self.image_size = config.get("image_size", 224) self.load_device = model_management.text_encoder_device() offload_device = model_management.text_encoder_offload_device() self.dtype = model_management.text_encoder_dtype(self.load_device) @@ -58,7 +58,7 @@ class ClipVisionModel(): def encode_image(self, image): model_management.load_model_gpu(self.patcher) - pixel_values = clip_preprocess(image.to(self.load_device)).float() + pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float() out = self.model(pixel_values=pixel_values, intermediate_output=-2) outputs = Output() @@ -101,7 +101,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False): elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd: json_config = files.get_path_as_dict(None, "clip_vision_config_h.json") elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd: - json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json") + if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577: + json_config = files.get_path_as_dict(None, "clip_vision_config_vitl_336.json") + else: + json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json") else: return None diff --git a/comfy/clip_vision_config_vitl_336.json b/comfy/clip_vision_config_vitl_336.json new file mode 100644 index 000000000..f26945273 --- /dev/null +++ b/comfy/clip_vision_config_vitl_336.json @@ -0,0 +1,18 @@ +{ + "attention_dropout": 0.0, + "dropout": 0.0, + "hidden_act": "quick_gelu", + "hidden_size": 1024, + "image_size": 336, + "initializer_factor": 1.0, + "initializer_range": 0.02, + "intermediate_size": 4096, + "layer_norm_eps": 1e-5, + "model_type": "clip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 24, + "patch_size": 14, + "projection_dim": 768, + "torch_dtype": "float32" +} diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index c7744ba25..7e55bdf3c 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -19,7 +19,7 @@ from .. import interruption from .. import model_management from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \ - ValidationErrorDict, NodeErrorsDictValue + ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus from ..execution_context import new_execution_context, ExecutionContext from ..nodes.package import import_all_nodes_in_workspace @@ -318,6 +318,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item to_delete = True elif unique_id not in old_prompt: to_delete = True + elif class_type != old_prompt[unique_id]['class_type']: + to_delete = True elif inputs == old_prompt[unique_id]['inputs']: for x in inputs: input_data = inputs[x] @@ -731,13 +733,18 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple: span.set_status(Status(StatusCode.ERROR)) if res.error is not None and len(res.error) > 0: span.set_attributes({ - f"error.{k}": v for k, v in res.error.items() + f"error.{k}": v for k, v in res.error.items() if isinstance(v, (bool, str, bytes, int, float, list[str], list[int], list[float])) }) + if "extra_info" in res.error and isinstance(res.error["extra_info"], dict): + extra_info: ValidationErrorExtraInfoDict = res.error["extra_info"] + span.set_attributes({ + f"error.extra_info.{k}": v for k, v in extra_info.items() if isinstance(v, (str, list[str])) + }) if len(res.node_errors) > 0: for node_id, node_error in res.node_errors.items(): for node_error_field, node_error_value in node_error.items(): if isinstance(node_error_value, (str, bool, int, float)): - span.set_attribute("node_errors.{node_id}.{node_error_field}", node_error_value) + span.set_attribute(f"node_errors.{node_id}.{node_error_field}", node_error_value) return res diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 3287eb55a..40905cfb2 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -7,9 +7,9 @@ import logging import mimetypes import os import struct +import sys import traceback import uuid -import hashlib from asyncio import Future, AbstractEventLoop from enum import Enum from io import BytesIO @@ -19,7 +19,6 @@ from urllib.parse import quote, urlencode import aiofiles import aiohttp -import sys from PIL import Image from PIL.PngImagePlugin import PngInfo from aiohttp import web @@ -30,6 +29,7 @@ from .latent_preview_image_encoding import encode_preview_image from .. import interruption from .. import model_management from .. import utils +from ..app.frontend_management import FrontendManager from ..app.user_manager import UserManager from ..cli_args import args from ..client.client_types import FileOutput @@ -115,10 +115,11 @@ class PromptServer(ExecutorToClientProgress): handler_args={'max_field_size': 16380}, middlewares=middlewares) self.sockets = dict() - web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../web") - if not os.path.exists(web_root_path): - web_root_path = get_package_as_path('comfy', 'web/') - self.web_root = web_root_path + self.web_root = ( + FrontendManager.init_frontend(args.front_end_version) + if args.front_end_root is None + else args.front_end_root + ) routes = web.RouteTableDef() self.routes: web.RouteTableDef = routes self.last_node_id = None @@ -191,10 +192,12 @@ class PromptServer(ExecutorToClientProgress): return type_dir, dir_type def compare_image_hash(filepath, image): + hasher = node_helpers.hasher() + # function to compare hashes of two images to see if it already exists, fix to #3465 if os.path.exists(filepath): - a = hashlib.sha256() - b = hashlib.sha256() + a = hasher() + b = hasher() with open(filepath, "rb") as f: a.update(f.read()) b.update(image.file.read()) @@ -233,7 +236,7 @@ class PromptServer(ExecutorToClientProgress): else: i = 1 while os.path.exists(filepath): - if compare_image_hash(filepath, image): #compare hash to prevent saving of duplicates with same name, fix for #3465 + if compare_image_hash(filepath, image): # compare hash to prevent saving of duplicates with same name, fix for #3465 image_is_duplicate = True break filename = f"{split[0]} ({i}){split[1]}" @@ -719,6 +722,7 @@ class PromptServer(ExecutorToClientProgress): @external_address.setter def external_address(self, value): self._external_address = value + def add_routes(self): self.user_manager.add_routes(self.routes) diff --git a/comfy/controlnet.py b/comfy/controlnet.py index 8ae92e108..252308b24 100644 --- a/comfy/controlnet.py +++ b/comfy/controlnet.py @@ -45,6 +45,7 @@ class ControlBase: self.timestep_range = None self.compression_ratio = 8 self.upscale_algorithm = 'nearest-exact' + self.extra_args = {} if device is None: device = model_management.get_torch_device() @@ -90,6 +91,7 @@ class ControlBase: c.compression_ratio = self.compression_ratio c.upscale_algorithm = self.upscale_algorithm c.latent_format = self.latent_format + c.extra_args = self.extra_args.copy() c.vae = self.vae def inference_memory_requirements(self, dtype): @@ -135,6 +137,10 @@ class ControlBase: o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue return out + def set_extra_arg(self, argument, value=None): + self.extra_args[argument] = value + + class ControlNet(ControlBase): def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None): super().__init__(device) @@ -190,7 +196,7 @@ class ControlNet(ControlBase): timestep = self.model_sampling_current.timestep(t) x_noisy = self.model_sampling_current.calculate_input(t, x_noisy) - control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y) + control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args) return self.control_merge(control, control_prev, output_dtype) def copy(self): diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py index 019b32e12..4b4c090d8 100644 --- a/comfy/model_sampling.py +++ b/comfy/model_sampling.py @@ -61,8 +61,9 @@ class ModelSamplingDiscrete(torch.nn.Module): beta_schedule = sampling_settings.get("beta_schedule", "linear") linear_start = sampling_settings.get("linear_start", 0.00085) linear_end = sampling_settings.get("linear_end", 0.012) + timesteps = sampling_settings.get("timesteps", 1000) - self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3) + self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3) self.sigma_data = 1.0 def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000, diff --git a/comfy/node_helpers.py b/comfy/node_helpers.py index 43b9e829f..fee628790 100644 --- a/comfy/node_helpers.py +++ b/comfy/node_helpers.py @@ -1,3 +1,7 @@ +import hashlib + +from comfy.cli_args import args + from PIL import ImageFile, UnidentifiedImageError def conditioning_set_values(conditioning, values={}): @@ -22,3 +26,12 @@ def pillow(fn, arg): if prev_value is not None: ImageFile.LOAD_TRUNCATED_IMAGES = prev_value return x + +def hasher(): + hashfuncs = { + "md5": hashlib.md5, + "sha1": hashlib.sha1, + "sha256": hashlib.sha256, + "sha512": hashlib.sha512 + } + return hashfuncs[args.default_hashing_function] diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index ab0e8ca53..3dbed2909 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -745,7 +745,7 @@ class ControlNetApply: RETURN_TYPES = ("CONDITIONING",) FUNCTION = "apply_controlnet" - CATEGORY = "conditioning" + CATEGORY = "conditioning/controlnet" def apply_controlnet(self, conditioning, control_net, image, strength): if strength == 0: @@ -780,7 +780,7 @@ class ControlNetApplyAdvanced: RETURN_NAMES = ("positive", "negative") FUNCTION = "apply_controlnet" - CATEGORY = "conditioning" + CATEGORY = "conditioning/controlnet" def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None): if strength == 0: diff --git a/comfy/sd.py b/comfy/sd.py index f9f0288a3..618d4b76c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -17,10 +17,10 @@ from . import model_detection from . import model_management from . import model_patcher from . import model_sampling -from . import sa_t5 +from .text_encoders import sa_t5 from . import sd1_clip from . import sd2_clip -from . import sd3_clip +from .text_encoders import sd3_clip from . import sdxl_clip from . import utils from .ldm.audio.autoencoder import AudioOobleckVAE diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 0c2994026..3ba0b72f7 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -5,8 +5,8 @@ from . import utils from . import sd1_clip from . import sd2_clip from . import sdxl_clip -from . import sd3_clip -from . import sa_t5 +from .text_encoders import sd3_clip +from .text_encoders import sa_t5 from .text_encoders import aura_t5 from . import supported_models_base diff --git a/comfy/text_encoders/aura_t5.py b/comfy/text_encoders/aura_t5.py index 94ebd868b..4684fe282 100644 --- a/comfy/text_encoders/aura_t5.py +++ b/comfy/text_encoders/aura_t5.py @@ -2,7 +2,7 @@ from importlib import resources from comfy import sd1_clip from .llama_tokenizer import LLAMATokenizer -from .. import t5 +from ..text_encoders import t5 from ..component_model.files import get_path_as_dict diff --git a/comfy/sa_t5.py b/comfy/text_encoders/sa_t5.py similarity index 88% rename from comfy/sa_t5.py rename to comfy/text_encoders/sa_t5.py index 4521c364e..abe5d373f 100644 --- a/comfy/sa_t5.py +++ b/comfy/text_encoders/sa_t5.py @@ -1,14 +1,13 @@ from transformers import T5TokenizerFast -import comfy.t5 +import comfy.text_encoders.t5 from comfy import sd1_clip from comfy.component_model import files - class T5BaseModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None): textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_base.json") - super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True) + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True) class T5BaseTokenizer(sd1_clip.SDTokenizer): diff --git a/comfy/sd3_clip.py b/comfy/text_encoders/sd3_clip.py similarity index 95% rename from comfy/sd3_clip.py rename to comfy/text_encoders/sd3_clip.py index 5990ec3b9..10cdfb4a2 100644 --- a/comfy/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -1,11 +1,10 @@ import logging -import os import torch from transformers import T5TokenizerFast import comfy.model_management -import comfy.t5 +import comfy.text_encoders.t5 from comfy import sd1_clip from comfy import sdxl_clip from comfy.component_model import files @@ -13,13 +12,13 @@ from comfy.component_model import files class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None): - textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json") - super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5) + textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json", package="comfy.text_encoders") + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5) class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None): - tokenizer_path = files.get_package_as_path("comfy.t5_tokenizer") + tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer") super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77) diff --git a/comfy/t5.py b/comfy/text_encoders/t5.py similarity index 100% rename from comfy/t5.py rename to comfy/text_encoders/t5.py diff --git a/comfy/t5_config_base.json b/comfy/text_encoders/t5_config_base.json similarity index 100% rename from comfy/t5_config_base.json rename to comfy/text_encoders/t5_config_base.json diff --git a/comfy/t5_config_xxl.json b/comfy/text_encoders/t5_config_xxl.json similarity index 100% rename from comfy/t5_config_xxl.json rename to comfy/text_encoders/t5_config_xxl.json diff --git a/comfy/t5_tokenizer/__init__.py b/comfy/text_encoders/t5_tokenizer/__init__.py similarity index 100% rename from comfy/t5_tokenizer/__init__.py rename to comfy/text_encoders/t5_tokenizer/__init__.py diff --git a/comfy/t5_tokenizer/special_tokens_map.json b/comfy/text_encoders/t5_tokenizer/special_tokens_map.json similarity index 100% rename from comfy/t5_tokenizer/special_tokens_map.json rename to comfy/text_encoders/t5_tokenizer/special_tokens_map.json diff --git a/comfy/t5_tokenizer/tokenizer.json b/comfy/text_encoders/t5_tokenizer/tokenizer.json similarity index 100% rename from comfy/t5_tokenizer/tokenizer.json rename to comfy/text_encoders/t5_tokenizer/tokenizer.json diff --git a/comfy/t5_tokenizer/tokenizer_config.json b/comfy/text_encoders/t5_tokenizer/tokenizer_config.json similarity index 100% rename from comfy/t5_tokenizer/tokenizer_config.json rename to comfy/text_encoders/t5_tokenizer/tokenizer_config.json diff --git a/comfy/web/extensions/core/uploadAudio.js b/comfy/web/extensions/core/uploadAudio.js index 0ac9cb807..6cc3863a1 100644 --- a/comfy/web/extensions/core/uploadAudio.js +++ b/comfy/web/extensions/core/uploadAudio.js @@ -17,7 +17,6 @@ function getResourceURL(subfolder, filename, type = "input") { "filename=" + encodeURIComponent(filename), "type=" + type, "subfolder=" + subfolder, - app.getPreviewFormatParam().substring(1), app.getRandParam().substring(1) ].join("&") diff --git a/comfy_extras/nodes/nodes_sd3.py b/comfy_extras/nodes/nodes_sd3.py index beeae9885..4d944a959 100644 --- a/comfy_extras/nodes/nodes_sd3.py +++ b/comfy_extras/nodes/nodes_sd3.py @@ -101,7 +101,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced): "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}), "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001}) }} - CATEGORY = "_for_testing/sd3" + CATEGORY = "conditioning/controlnet" NODE_CLASS_MAPPINGS = { "TripleCLIPLoader": TripleCLIPLoader, diff --git a/comfy_extras/nodes_controlnet.py b/comfy_extras/nodes_controlnet.py new file mode 100644 index 000000000..ef7cfc6ab --- /dev/null +++ b/comfy_extras/nodes_controlnet.py @@ -0,0 +1,37 @@ + +UNION_CONTROLNET_TYPES = {"auto": -1, + "openpose": 0, + "depth": 1, + "hed/pidi/scribble/ted": 2, + "canny/lineart/anime_lineart/mlsd": 3, + "normal": 4, + "segment": 5, + "tile": 6, + "repaint": 7, + } + +class SetUnionControlNetType: + @classmethod + def INPUT_TYPES(s): + return {"required": {"control_net": ("CONTROL_NET", ), + "type": (list(UNION_CONTROLNET_TYPES.keys()),) + }} + + CATEGORY = "conditioning/controlnet" + RETURN_TYPES = ("CONTROL_NET",) + + FUNCTION = "set_controlnet_type" + + def set_controlnet_type(self, control_net, type): + control_net = control_net.copy() + type_number = UNION_CONTROLNET_TYPES[type] + if type_number >= 0: + control_net.set_extra_arg("control_type", [type_number]) + else: + control_net.set_extra_arg("control_type", []) + + return (control_net,) + +NODE_CLASS_MAPPINGS = { + "SetUnionControlNetType": SetUnionControlNetType, +} diff --git a/pytest.ini b/pytest.ini index 208837e45..891354251 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,6 +1,8 @@ [pytest] markers = inference: mark as inference test (deselect with '-m "not inference"') -testpaths = tests +testpaths = + tests + tests-unit addopts = -s asyncio_mode = auto \ No newline at end of file diff --git a/tests-unit/README.md b/tests-unit/README.md new file mode 100644 index 000000000..94abd9853 --- /dev/null +++ b/tests-unit/README.md @@ -0,0 +1,8 @@ +# Pytest Unit Tests + +## Install test dependencies + +`pip install -r tests-units/requirements.txt` + +## Run tests +`pytest tests-units/` diff --git a/tests-unit/app_test/__init__.py b/tests-unit/app_test/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests-unit/app_test/frontend_manager_test.py b/tests-unit/app_test/frontend_manager_test.py new file mode 100644 index 000000000..637869cfb --- /dev/null +++ b/tests-unit/app_test/frontend_manager_test.py @@ -0,0 +1,100 @@ +import argparse +import pytest +from requests.exceptions import HTTPError + +from app.frontend_management import ( + FrontendManager, + FrontEndProvider, + Release, +) +from comfy.cli_args import DEFAULT_VERSION_STRING + + +@pytest.fixture +def mock_releases(): + return [ + Release( + id=1, + tag_name="1.0.0", + name="Release 1.0.0", + prerelease=False, + created_at="2022-01-01T00:00:00Z", + published_at="2022-01-01T00:00:00Z", + body="Release notes for 1.0.0", + assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}], + ), + Release( + id=2, + tag_name="2.0.0", + name="Release 2.0.0", + prerelease=False, + created_at="2022-02-01T00:00:00Z", + published_at="2022-02-01T00:00:00Z", + body="Release notes for 2.0.0", + assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}], + ), + ] + + +@pytest.fixture +def mock_provider(mock_releases): + provider = FrontEndProvider( + owner="test-owner", + repo="test-repo", + ) + provider.all_releases = mock_releases + provider.latest_release = mock_releases[1] + FrontendManager.PROVIDERS = [provider] + return provider + + +def test_get_release(mock_provider, mock_releases): + version = "1.0.0" + release = mock_provider.get_release(version) + assert release == mock_releases[0] + + +def test_get_release_latest(mock_provider, mock_releases): + version = "latest" + release = mock_provider.get_release(version) + assert release == mock_releases[1] + + +def test_get_release_invalid_version(mock_provider): + version = "invalid" + with pytest.raises(ValueError): + mock_provider.get_release(version) + + +def test_init_frontend_default(): + version_string = DEFAULT_VERSION_STRING + frontend_path = FrontendManager.init_frontend(version_string) + assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH + + +def test_init_frontend_invalid_version(): + version_string = "test-owner/test-repo@1.100.99" + with pytest.raises(HTTPError): + FrontendManager.init_frontend_unsafe(version_string) + + +def test_init_frontend_invalid_provider(): + version_string = "invalid/invalid@latest" + with pytest.raises(HTTPError): + FrontendManager.init_frontend_unsafe(version_string) + + +def test_parse_version_string(): + version_string = "owner/repo@1.0.0" + repo_owner, repo_name, version = FrontendManager.parse_version_string( + version_string + ) + assert repo_owner == "owner" + assert repo_name == "repo" + assert version == "1.0.0" + + +def test_parse_version_string_invalid(): + version_string = "invalid" + with pytest.raises(argparse.ArgumentTypeError): + FrontendManager.parse_version_string(version_string) diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt new file mode 100644 index 000000000..0587502f8 --- /dev/null +++ b/tests-unit/requirements.txt @@ -0,0 +1 @@ +pytest>=7.8.0