Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2024-07-18 16:31:21 -07:00
commit cc99d89ac6
34 changed files with 488 additions and 38 deletions

23
.github/workflows/pylint.yml vendored Normal file
View File

@ -0,0 +1,23 @@
name: Python Linting
on: [push, pull_request]
jobs:
pylint:
name: Run Pylint
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.x
- name: Install Pylint
run: pip install pylint
- name: Run Pylint
run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")

View File

@ -23,3 +23,7 @@ jobs:
npm run test:generate
npm test -- --verbose
working-directory: ./tests-ui
- name: Run Unit Tests
run: |
pip install -r tests-unit/requirements.txt
python -m pytest tests-unit

3
.gitignore vendored
View File

@ -175,4 +175,5 @@ cython_debug/
/tests-ui/data/object_info.json
/user/
*.log
*.log
web_custom_versions/

View File

@ -0,0 +1,191 @@
from __future__ import annotations
import argparse
import logging
import os
import re
import tempfile
import zipfile
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict
import requests
from typing_extensions import NotRequired
from comfy.cli_args import DEFAULT_VERSION_STRING
from comfy.cmd.folder_paths import add_model_folder_path
from comfy.component_model.files import get_package_as_path
REQUEST_TIMEOUT = 10 # seconds
class Asset(TypedDict):
url: str
class Release(TypedDict):
id: int
tag_name: str
name: str
prerelease: bool
created_at: str
published_at: str
body: str
assets: NotRequired[list[Asset]]
@dataclass
class FrontEndProvider:
owner: str
repo: str
@property
def folder_name(self) -> str:
return f"{self.owner}_{self.repo}"
@property
def release_url(self) -> str:
return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
@cached_property
def all_releases(self) -> list[Release]:
releases = []
api_url = self.release_url
while api_url:
response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
response.raise_for_status() # Raises an HTTPError if the response was an error
releases.extend(response.json())
# GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
if "next" in response.links:
api_url = response.links["next"]["url"]
else:
api_url = None
return releases
@cached_property
def latest_release(self) -> Release:
latest_release_url = f"{self.release_url}/latest"
response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
response.raise_for_status() # Raises an HTTPError if the response was an error
return response.json()
def get_release(self, version: str) -> Release:
if version == "latest":
return self.latest_release
else:
for release in self.all_releases:
if release["tag_name"] in [version, f"v{version}"]:
return release
raise ValueError(f"Version {version} not found in releases")
def download_release_asset_zip(release: Release, destination_path: str) -> None:
"""Download dist.zip from github release."""
asset_url = None
for asset in release.get("assets", []):
if asset["name"] == "dist.zip":
asset_url = asset["url"]
break
if not asset_url:
raise ValueError("dist.zip not found in the release assets")
# Use a temporary file to download the zip content
with tempfile.TemporaryFile() as tmp_file:
headers = {"Accept": "application/octet-stream"}
response = requests.get(
asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
)
response.raise_for_status() # Ensure we got a successful response
# Write the content to the temporary file
tmp_file.write(response.content)
# Go back to the beginning of the temporary file
tmp_file.seek(0)
# Extract the zip file content to the destination path
with zipfile.ZipFile(tmp_file, "r") as zip_ref:
zip_ref.extractall(destination_path)
class FrontendManager:
DEFAULT_FRONTEND_PATH = get_package_as_path('comfy', 'web/')
CUSTOM_FRONTENDS_ROOT = add_model_folder_path("web_custom_versions", extensions=set())
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
Args:
value (str): The version string to parse.
Returns:
tuple[str, str]: A tuple containing provider name and version.
Raises:
argparse.ArgumentTypeError: If the version string is invalid.
"""
VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
match_result = re.match(VERSION_PATTERN, value)
if match_result is None:
raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod
def init_frontend_unsafe(cls, version_string: str) -> str:
"""
Initializes the frontend for the specified version.
Args:
version_string (str): The version string.
Returns:
str: The path to the initialized frontend.
Raises:
Exception: If there is an error during the initialization process.
main error source might be request timeout or invalid URL.
"""
if version_string == DEFAULT_VERSION_STRING:
return cls.DEFAULT_FRONTEND_PATH
repo_owner, repo_name, version = cls.parse_version_string(version_string)
provider = FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)
semantic_version = release["tag_name"].lstrip("v")
web_root = str(
Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
)
if not os.path.exists(web_root):
os.makedirs(web_root, exist_ok=True)
logging.info(
"Downloading frontend(%s) version(%s) to (%s)",
provider.folder_name,
semantic_version,
web_root,
)
logging.debug(release)
download_release_asset_zip(release, destination_path=web_root)
return web_root
@classmethod
def init_frontend(cls, version_string: str) -> str:
"""
Initializes the frontend with the specified version string.
Args:
version_string (str): The version string to initialize the frontend with.
Returns:
str: The path of the initialized frontend.
"""
try:
return cls.init_frontend_unsafe(version_string)
except Exception as e:
logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.")
return cls.DEFAULT_FRONTEND_PATH

View File

@ -363,7 +363,7 @@ class ControlNet(nn.Module):
controlnet_cond = self.input_hint_block(hint[idx], emb, context)
feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
if idx < len(control_type):
feat_seq += self.task_embedding[control_type[idx]]
feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
inputs.append(feat_seq.unsqueeze(1))
condition_list.append(controlnet_cond)

View File

@ -15,6 +15,9 @@ from . import options
from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, ConfigChangeHandler, EnumAction, \
EnhancedConfigArgParser
# todo: move this
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
def _create_parser() -> EnhancedConfigArgParser:
parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json'],
@ -108,6 +111,7 @@ def _create_parser() -> EnhancedConfigArgParser:
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true",
help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true",
@ -160,13 +164,43 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--force-hf-local-dir-mode", action="store_true", help="Download repos from huggingface.co to the models/huggingface directory with the \"local_dir\" argument instead of models/huggingface_cache with the \"cache_dir\" argument, recreating the traditional file structure.")
parser.add_argument(
"--front-end-version",
type=str,
default=DEFAULT_VERSION_STRING,
help="""
Specifies the version of the frontend to be used. This command needs internet connectivity to query and
download available frontend implementations from GitHub releases.
The version string should be in the format of:
[repoOwner]/[repoName]@[version]
where version is one of: "latest" or a valid version number (e.g. "1.0.0")
""",
)
def is_valid_directory(path: Optional[str]) -> Optional[str]:
"""Validate if the given path is a directory."""
if path is None:
return None
if not os.path.isdir(path):
raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
return path
parser.add_argument(
"--front-end-root",
type=is_valid_directory,
default=None,
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)
# now give plugins a chance to add configuration
for entry_point in entry_points().select(group='comfyui.custom_config'):
try:
plugin_callable: ConfigurationExtender | ModuleType = entry_point.load()
if isinstance(plugin_callable, ModuleType):
# todo: find the configuration extender in the module
plugin_callable = ...
raise ValueError("unexpected or unsupported plugin configuration type")
else:
parser_result = plugin_callable(parser)
if parser_result is not None:

View File

@ -1,6 +1,5 @@
from .component_model import files
from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
import os
import torch
import json
import logging
@ -43,6 +42,7 @@ class ClipVisionModel():
else:
raise ValueError(f"json_config had invalid value={json_config}")
self.image_size = config.get("image_size", 224)
self.load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
self.dtype = model_management.text_encoder_dtype(self.load_device)
@ -58,7 +58,7 @@ class ClipVisionModel():
def encode_image(self, image):
model_management.load_model_gpu(self.patcher)
pixel_values = clip_preprocess(image.to(self.load_device)).float()
pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
out = self.model(pixel_values=pixel_values, intermediate_output=-2)
outputs = Output()
@ -101,7 +101,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = files.get_path_as_dict(None, "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json")
if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
json_config = files.get_path_as_dict(None, "clip_vision_config_vitl_336.json")
else:
json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json")
else:
return None

View File

@ -0,0 +1,18 @@
{
"attention_dropout": 0.0,
"dropout": 0.0,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"image_size": 336,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"layer_norm_eps": 1e-5,
"model_type": "clip_vision_model",
"num_attention_heads": 16,
"num_channels": 3,
"num_hidden_layers": 24,
"patch_size": 14,
"projection_dim": 768,
"torch_dtype": "float32"
}

View File

@ -19,7 +19,7 @@ from .. import interruption
from .. import model_management
from ..component_model.abstract_prompt_queue import AbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
ValidationErrorDict, NodeErrorsDictValue
ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict
from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
from ..execution_context import new_execution_context, ExecutionContext
from ..nodes.package import import_all_nodes_in_workspace
@ -318,6 +318,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
to_delete = True
elif unique_id not in old_prompt:
to_delete = True
elif class_type != old_prompt[unique_id]['class_type']:
to_delete = True
elif inputs == old_prompt[unique_id]['inputs']:
for x in inputs:
input_data = inputs[x]
@ -731,13 +733,18 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
span.set_status(Status(StatusCode.ERROR))
if res.error is not None and len(res.error) > 0:
span.set_attributes({
f"error.{k}": v for k, v in res.error.items()
f"error.{k}": v for k, v in res.error.items() if isinstance(v, (bool, str, bytes, int, float, list[str], list[int], list[float]))
})
if "extra_info" in res.error and isinstance(res.error["extra_info"], dict):
extra_info: ValidationErrorExtraInfoDict = res.error["extra_info"]
span.set_attributes({
f"error.extra_info.{k}": v for k, v in extra_info.items() if isinstance(v, (str, list[str]))
})
if len(res.node_errors) > 0:
for node_id, node_error in res.node_errors.items():
for node_error_field, node_error_value in node_error.items():
if isinstance(node_error_value, (str, bool, int, float)):
span.set_attribute("node_errors.{node_id}.{node_error_field}", node_error_value)
span.set_attribute(f"node_errors.{node_id}.{node_error_field}", node_error_value)
return res

View File

@ -7,9 +7,9 @@ import logging
import mimetypes
import os
import struct
import sys
import traceback
import uuid
import hashlib
from asyncio import Future, AbstractEventLoop
from enum import Enum
from io import BytesIO
@ -19,7 +19,6 @@ from urllib.parse import quote, urlencode
import aiofiles
import aiohttp
import sys
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from aiohttp import web
@ -30,6 +29,7 @@ from .latent_preview_image_encoding import encode_preview_image
from .. import interruption
from .. import model_management
from .. import utils
from ..app.frontend_management import FrontendManager
from ..app.user_manager import UserManager
from ..cli_args import args
from ..client.client_types import FileOutput
@ -115,10 +115,11 @@ class PromptServer(ExecutorToClientProgress):
handler_args={'max_field_size': 16380},
middlewares=middlewares)
self.sockets = dict()
web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../web")
if not os.path.exists(web_root_path):
web_root_path = get_package_as_path('comfy', 'web/')
self.web_root = web_root_path
self.web_root = (
FrontendManager.init_frontend(args.front_end_version)
if args.front_end_root is None
else args.front_end_root
)
routes = web.RouteTableDef()
self.routes: web.RouteTableDef = routes
self.last_node_id = None
@ -191,10 +192,12 @@ class PromptServer(ExecutorToClientProgress):
return type_dir, dir_type
def compare_image_hash(filepath, image):
hasher = node_helpers.hasher()
# function to compare hashes of two images to see if it already exists, fix to #3465
if os.path.exists(filepath):
a = hashlib.sha256()
b = hashlib.sha256()
a = hasher()
b = hasher()
with open(filepath, "rb") as f:
a.update(f.read())
b.update(image.file.read())
@ -233,7 +236,7 @@ class PromptServer(ExecutorToClientProgress):
else:
i = 1
while os.path.exists(filepath):
if compare_image_hash(filepath, image): #compare hash to prevent saving of duplicates with same name, fix for #3465
if compare_image_hash(filepath, image): # compare hash to prevent saving of duplicates with same name, fix for #3465
image_is_duplicate = True
break
filename = f"{split[0]} ({i}){split[1]}"
@ -719,6 +722,7 @@ class PromptServer(ExecutorToClientProgress):
@external_address.setter
def external_address(self, value):
self._external_address = value
def add_routes(self):
self.user_manager.add_routes(self.routes)

View File

@ -45,6 +45,7 @@ class ControlBase:
self.timestep_range = None
self.compression_ratio = 8
self.upscale_algorithm = 'nearest-exact'
self.extra_args = {}
if device is None:
device = model_management.get_torch_device()
@ -90,6 +91,7 @@ class ControlBase:
c.compression_ratio = self.compression_ratio
c.upscale_algorithm = self.upscale_algorithm
c.latent_format = self.latent_format
c.extra_args = self.extra_args.copy()
c.vae = self.vae
def inference_memory_requirements(self, dtype):
@ -135,6 +137,10 @@ class ControlBase:
o[i] = prev_val + o[i] #TODO: change back to inplace add if shared tensors stop being an issue
return out
def set_extra_arg(self, argument, value=None):
self.extra_args[argument] = value
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
super().__init__(device)
@ -190,7 +196,7 @@ class ControlNet(ControlBase):
timestep = self.model_sampling_current.timestep(t)
x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
return self.control_merge(control, control_prev, output_dtype)
def copy(self):

View File

@ -61,8 +61,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
beta_schedule = sampling_settings.get("beta_schedule", "linear")
linear_start = sampling_settings.get("linear_start", 0.00085)
linear_end = sampling_settings.get("linear_end", 0.012)
timesteps = sampling_settings.get("timesteps", 1000)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
self.sigma_data = 1.0
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,

View File

@ -1,3 +1,7 @@
import hashlib
from comfy.cli_args import args
from PIL import ImageFile, UnidentifiedImageError
def conditioning_set_values(conditioning, values={}):
@ -22,3 +26,12 @@ def pillow(fn, arg):
if prev_value is not None:
ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
return x
def hasher():
hashfuncs = {
"md5": hashlib.md5,
"sha1": hashlib.sha1,
"sha256": hashlib.sha256,
"sha512": hashlib.sha512
}
return hashfuncs[args.default_hashing_function]

View File

@ -745,7 +745,7 @@ class ControlNetApply:
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_controlnet"
CATEGORY = "conditioning"
CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, conditioning, control_net, image, strength):
if strength == 0:
@ -780,7 +780,7 @@ class ControlNetApplyAdvanced:
RETURN_NAMES = ("positive", "negative")
FUNCTION = "apply_controlnet"
CATEGORY = "conditioning"
CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
if strength == 0:

View File

@ -17,10 +17,10 @@ from . import model_detection
from . import model_management
from . import model_patcher
from . import model_sampling
from . import sa_t5
from .text_encoders import sa_t5
from . import sd1_clip
from . import sd2_clip
from . import sd3_clip
from .text_encoders import sd3_clip
from . import sdxl_clip
from . import utils
from .ldm.audio.autoencoder import AudioOobleckVAE

View File

@ -5,8 +5,8 @@ from . import utils
from . import sd1_clip
from . import sd2_clip
from . import sdxl_clip
from . import sd3_clip
from . import sa_t5
from .text_encoders import sd3_clip
from .text_encoders import sa_t5
from .text_encoders import aura_t5
from . import supported_models_base

View File

@ -2,7 +2,7 @@ from importlib import resources
from comfy import sd1_clip
from .llama_tokenizer import LLAMATokenizer
from .. import t5
from ..text_encoders import t5
from ..component_model.files import get_path_as_dict

View File

@ -1,14 +1,13 @@
from transformers import T5TokenizerFast
import comfy.t5
import comfy.text_encoders.t5
from comfy import sd1_clip
from comfy.component_model import files
class T5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
class T5BaseTokenizer(sd1_clip.SDTokenizer):

View File

@ -1,11 +1,10 @@
import logging
import os
import torch
from transformers import T5TokenizerFast
import comfy.model_management
import comfy.t5
import comfy.text_encoders.t5
from comfy import sd1_clip
from comfy import sdxl_clip
from comfy.component_model import files
@ -13,13 +12,13 @@ from comfy.component_model import files
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json", package="comfy.text_encoders")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None):
tokenizer_path = files.get_package_as_path("comfy.t5_tokenizer")
tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)

View File

@ -17,7 +17,6 @@ function getResourceURL(subfolder, filename, type = "input") {
"filename=" + encodeURIComponent(filename),
"type=" + type,
"subfolder=" + subfolder,
app.getPreviewFormatParam().substring(1),
app.getRandParam().substring(1)
].join("&")

View File

@ -101,7 +101,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
CATEGORY = "_for_testing/sd3"
CATEGORY = "conditioning/controlnet"
NODE_CLASS_MAPPINGS = {
"TripleCLIPLoader": TripleCLIPLoader,

View File

@ -0,0 +1,37 @@
UNION_CONTROLNET_TYPES = {"auto": -1,
"openpose": 0,
"depth": 1,
"hed/pidi/scribble/ted": 2,
"canny/lineart/anime_lineart/mlsd": 3,
"normal": 4,
"segment": 5,
"tile": 6,
"repaint": 7,
}
class SetUnionControlNetType:
@classmethod
def INPUT_TYPES(s):
return {"required": {"control_net": ("CONTROL_NET", ),
"type": (list(UNION_CONTROLNET_TYPES.keys()),)
}}
CATEGORY = "conditioning/controlnet"
RETURN_TYPES = ("CONTROL_NET",)
FUNCTION = "set_controlnet_type"
def set_controlnet_type(self, control_net, type):
control_net = control_net.copy()
type_number = UNION_CONTROLNET_TYPES[type]
if type_number >= 0:
control_net.set_extra_arg("control_type", [type_number])
else:
control_net.set_extra_arg("control_type", [])
return (control_net,)
NODE_CLASS_MAPPINGS = {
"SetUnionControlNetType": SetUnionControlNetType,
}

View File

@ -1,6 +1,8 @@
[pytest]
markers =
inference: mark as inference test (deselect with '-m "not inference"')
testpaths = tests
testpaths =
tests
tests-unit
addopts = -s
asyncio_mode = auto

8
tests-unit/README.md Normal file
View File

@ -0,0 +1,8 @@
# Pytest Unit Tests
## Install test dependencies
`pip install -r tests-units/requirements.txt`
## Run tests
`pytest tests-units/`

View File

View File

@ -0,0 +1,100 @@
import argparse
import pytest
from requests.exceptions import HTTPError
from app.frontend_management import (
FrontendManager,
FrontEndProvider,
Release,
)
from comfy.cli_args import DEFAULT_VERSION_STRING
@pytest.fixture
def mock_releases():
return [
Release(
id=1,
tag_name="1.0.0",
name="Release 1.0.0",
prerelease=False,
created_at="2022-01-01T00:00:00Z",
published_at="2022-01-01T00:00:00Z",
body="Release notes for 1.0.0",
assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
),
Release(
id=2,
tag_name="2.0.0",
name="Release 2.0.0",
prerelease=False,
created_at="2022-02-01T00:00:00Z",
published_at="2022-02-01T00:00:00Z",
body="Release notes for 2.0.0",
assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
),
]
@pytest.fixture
def mock_provider(mock_releases):
provider = FrontEndProvider(
owner="test-owner",
repo="test-repo",
)
provider.all_releases = mock_releases
provider.latest_release = mock_releases[1]
FrontendManager.PROVIDERS = [provider]
return provider
def test_get_release(mock_provider, mock_releases):
version = "1.0.0"
release = mock_provider.get_release(version)
assert release == mock_releases[0]
def test_get_release_latest(mock_provider, mock_releases):
version = "latest"
release = mock_provider.get_release(version)
assert release == mock_releases[1]
def test_get_release_invalid_version(mock_provider):
version = "invalid"
with pytest.raises(ValueError):
mock_provider.get_release(version)
def test_init_frontend_default():
version_string = DEFAULT_VERSION_STRING
frontend_path = FrontendManager.init_frontend(version_string)
assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
def test_init_frontend_invalid_version():
version_string = "test-owner/test-repo@1.100.99"
with pytest.raises(HTTPError):
FrontendManager.init_frontend_unsafe(version_string)
def test_init_frontend_invalid_provider():
version_string = "invalid/invalid@latest"
with pytest.raises(HTTPError):
FrontendManager.init_frontend_unsafe(version_string)
def test_parse_version_string():
version_string = "owner/repo@1.0.0"
repo_owner, repo_name, version = FrontendManager.parse_version_string(
version_string
)
assert repo_owner == "owner"
assert repo_name == "repo"
assert version == "1.0.0"
def test_parse_version_string_invalid():
version_string = "invalid"
with pytest.raises(argparse.ArgumentTypeError):
FrontendManager.parse_version_string(version_string)

View File

@ -0,0 +1 @@
pytest>=7.8.0