Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-11 23:00:51 +08:00)

Merge branch 'master' of github.com:comfyanonymous/ComfyUI

commit cc99d89ac6
.github/workflows/pylint.yml (vendored, new file, 23 lines)
@@ -0,0 +1,23 @@
+name: Python Linting
+
+on: [push, pull_request]
+
+jobs:
+  pylint:
+    name: Run Pylint
+    runs-on: ubuntu-latest
+
+    steps:
+    - name: Checkout repository
+      uses: actions/checkout@v4
+
+    - name: Set up Python
+      uses: actions/setup-python@v2
+      with:
+        python-version: 3.x
+
+    - name: Install Pylint
+      run: pip install pylint
+
+    - name: Run Pylint
+      run: pylint --rcfile=.pylintrc $(find . -type f -name "*.py")
.github/workflows/test-ui.yaml (vendored, 4 lines added)
@@ -23,3 +23,7 @@ jobs:
           npm run test:generate
           npm test -- --verbose
         working-directory: ./tests-ui
+    - name: Run Unit Tests
+      run: |
+        pip install -r tests-unit/requirements.txt
+        python -m pytest tests-unit
.gitignore (vendored, 3 lines changed)
@@ -175,4 +175,5 @@ cython_debug/
 
 /tests-ui/data/object_info.json
 /user/
-*.log
+*.log
+web_custom_versions/
comfy/app/frontend_management.py (new file, 191 lines)
@@ -0,0 +1,191 @@
+from __future__ import annotations
+
+import argparse
+import logging
+import os
+import re
+import tempfile
+import zipfile
+from dataclasses import dataclass
+from functools import cached_property
+from pathlib import Path
+from typing import TypedDict
+
+import requests
+from typing_extensions import NotRequired
+
+from comfy.cli_args import DEFAULT_VERSION_STRING
+from comfy.cmd.folder_paths import add_model_folder_path
+from comfy.component_model.files import get_package_as_path
+
+REQUEST_TIMEOUT = 10  # seconds
+
+
+class Asset(TypedDict):
+    url: str
+
+
+class Release(TypedDict):
+    id: int
+    tag_name: str
+    name: str
+    prerelease: bool
+    created_at: str
+    published_at: str
+    body: str
+    assets: NotRequired[list[Asset]]
+
+
+@dataclass
+class FrontEndProvider:
+    owner: str
+    repo: str
+
+    @property
+    def folder_name(self) -> str:
+        return f"{self.owner}_{self.repo}"
+
+    @property
+    def release_url(self) -> str:
+        return f"https://api.github.com/repos/{self.owner}/{self.repo}/releases"
+
+    @cached_property
+    def all_releases(self) -> list[Release]:
+        releases = []
+        api_url = self.release_url
+        while api_url:
+            response = requests.get(api_url, timeout=REQUEST_TIMEOUT)
+            response.raise_for_status()  # raises an HTTPError if the response was an error
+            releases.extend(response.json())
+            # GitHub uses the Link header to provide pagination links. Check if it exists and update api_url accordingly.
+            if "next" in response.links:
+                api_url = response.links["next"]["url"]
+            else:
+                api_url = None
+        return releases
+
+    @cached_property
+    def latest_release(self) -> Release:
+        latest_release_url = f"{self.release_url}/latest"
+        response = requests.get(latest_release_url, timeout=REQUEST_TIMEOUT)
+        response.raise_for_status()  # raises an HTTPError if the response was an error
+        return response.json()
+
+    def get_release(self, version: str) -> Release:
+        if version == "latest":
+            return self.latest_release
+        else:
+            for release in self.all_releases:
+                if release["tag_name"] in [version, f"v{version}"]:
+                    return release
+            raise ValueError(f"Version {version} not found in releases")
+
+
+def download_release_asset_zip(release: Release, destination_path: str) -> None:
+    """Download dist.zip from a GitHub release and extract it to destination_path."""
+    asset_url = None
+    for asset in release.get("assets", []):
+        if asset["name"] == "dist.zip":
+            asset_url = asset["url"]
+            break
+
+    if not asset_url:
+        raise ValueError("dist.zip not found in the release assets")
+
+    # Use a temporary file to download the zip content
+    with tempfile.TemporaryFile() as tmp_file:
+        headers = {"Accept": "application/octet-stream"}
+        response = requests.get(
+            asset_url, headers=headers, allow_redirects=True, timeout=REQUEST_TIMEOUT
+        )
+        response.raise_for_status()  # ensure we got a successful response
+
+        # Write the content to the temporary file
+        tmp_file.write(response.content)
+
+        # Go back to the beginning of the temporary file
+        tmp_file.seek(0)
+
+        # Extract the zip file content to the destination path
+        with zipfile.ZipFile(tmp_file, "r") as zip_ref:
+            zip_ref.extractall(destination_path)
+
+
+class FrontendManager:
+    DEFAULT_FRONTEND_PATH = get_package_as_path('comfy', 'web/')
+    CUSTOM_FRONTENDS_ROOT = add_model_folder_path("web_custom_versions", extensions=set())
+
+    @classmethod
+    def parse_version_string(cls, value: str) -> tuple[str, str, str]:
+        """
+        Args:
+            value (str): The version string to parse.
+
+        Returns:
+            tuple[str, str, str]: A tuple containing the repo owner, repo name, and version.
+
+        Raises:
+            argparse.ArgumentTypeError: If the version string is invalid.
+        """
+        VERSION_PATTERN = r"^([a-zA-Z0-9][a-zA-Z0-9-]{0,38})/([a-zA-Z0-9_.-]+)@(v?\d+\.\d+\.\d+|latest)$"
+        match_result = re.match(VERSION_PATTERN, value)
+        if match_result is None:
+            raise argparse.ArgumentTypeError(f"Invalid version string: {value}")
+
+        return match_result.group(1), match_result.group(2), match_result.group(3)
+
+    @classmethod
+    def init_frontend_unsafe(cls, version_string: str) -> str:
+        """
+        Initializes the frontend for the specified version.
+
+        Args:
+            version_string (str): The version string.
+
+        Returns:
+            str: The path to the initialized frontend.
+
+        Raises:
+            Exception: If there is an error during the initialization process.
+                The main error sources are request timeouts and invalid URLs.
+        """
+        if version_string == DEFAULT_VERSION_STRING:
+            return cls.DEFAULT_FRONTEND_PATH
+
+        repo_owner, repo_name, version = cls.parse_version_string(version_string)
+        provider = FrontEndProvider(repo_owner, repo_name)
+        release = provider.get_release(version)
+
+        semantic_version = release["tag_name"].lstrip("v")
+        web_root = str(
+            Path(cls.CUSTOM_FRONTENDS_ROOT) / provider.folder_name / semantic_version
+        )
+        if not os.path.exists(web_root):
+            os.makedirs(web_root, exist_ok=True)
+            logging.info(
+                "Downloading frontend(%s) version(%s) to (%s)",
+                provider.folder_name,
+                semantic_version,
+                web_root,
+            )
+            logging.debug(release)
+            download_release_asset_zip(release, destination_path=web_root)
+        return web_root
+
+    @classmethod
+    def init_frontend(cls, version_string: str) -> str:
+        """
+        Initializes the frontend with the specified version string.
+
+        Args:
+            version_string (str): The version string to initialize the frontend with.
+
+        Returns:
+            str: The path of the initialized frontend.
+        """
+        try:
+            return cls.init_frontend_unsafe(version_string)
+        except Exception as e:
+            logging.error("Failed to initialize frontend: %s", e)
+            logging.info("Falling back to the default frontend.")
+            return cls.DEFAULT_FRONTEND_PATH
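A quick orientation on how these pieces compose, as a minimal sketch (the version string below is the repository default; network access to the GitHub API is assumed):

    from comfy.app.frontend_management import FrontendManager, FrontEndProvider

    # "owner/repo@version" -> ("owner", "repo", "version"); raises argparse.ArgumentTypeError on bad input
    owner, repo, version = FrontendManager.parse_version_string("comfyanonymous/ComfyUI@latest")

    # resolve release metadata: "latest" hits /releases/latest, anything else walks the paginated /releases list
    provider = FrontEndProvider(owner, repo)
    release = provider.get_release(version)
    print(release["tag_name"])

init_frontend() wraps the same flow and falls back to DEFAULT_FRONTEND_PATH on any failure, so a bad version string degrades to the bundled frontend rather than aborting startup.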
@@ -363,7 +363,7 @@ class ControlNet(nn.Module):
             controlnet_cond = self.input_hint_block(hint[idx], emb, context)
             feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
             if idx < len(control_type):
-                feat_seq += self.task_embedding[control_type[idx]]
+                feat_seq += self.task_embedding[control_type[idx]].to(dtype=feat_seq.dtype, device=feat_seq.device)
 
             inputs.append(feat_seq.unsqueeze(1))
             condition_list.append(controlnet_cond)
@@ -15,6 +15,9 @@ from . import options
 from .cli_args_types import LatentPreviewMethod, Configuration, ConfigurationExtender, ConfigChangeHandler, EnumAction, \
     EnhancedConfigArgParser
 
+# todo: move this
+DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
+
 
 def _create_parser() -> EnhancedConfigArgParser:
     parser = EnhancedConfigArgParser(default_config_files=['config.yaml', 'config.json'],

@@ -108,6 +111,7 @@ def _create_parser() -> EnhancedConfigArgParser:
     vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
     vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
 
+    parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
     parser.add_argument("--disable-smart-memory", action="store_true",
                         help="Force ComfyUI to aggressively offload to regular ram instead of keeping models in vram when it can.")
     parser.add_argument("--deterministic", action="store_true",

@@ -160,13 +164,43 @@ def _create_parser() -> EnhancedConfigArgParser:
     parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
     parser.add_argument("--force-hf-local-dir-mode", action="store_true", help="Download repos from huggingface.co to the models/huggingface directory with the \"local_dir\" argument instead of models/huggingface_cache with the \"cache_dir\" argument, recreating the traditional file structure.")
 
+    parser.add_argument(
+        "--front-end-version",
+        type=str,
+        default=DEFAULT_VERSION_STRING,
+        help="""
+        Specifies the version of the frontend to be used. This command needs internet connectivity to query and
+        download available frontend implementations from GitHub releases.
+
+        The version string should be in the format of:
+        [repoOwner]/[repoName]@[version]
+        where version is one of: "latest" or a valid version number (e.g. "1.0.0")
+        """,
+    )
+
+    def is_valid_directory(path: Optional[str]) -> Optional[str]:
+        """Validate if the given path is a directory."""
+        if path is None:
+            return None
+
+        if not os.path.isdir(path):
+            raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
+        return path
+
+    parser.add_argument(
+        "--front-end-root",
+        type=is_valid_directory,
+        default=None,
+        help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
+    )
+
     # now give plugins a chance to add configuration
     for entry_point in entry_points().select(group='comfyui.custom_config'):
         try:
             plugin_callable: ConfigurationExtender | ModuleType = entry_point.load()
             if isinstance(plugin_callable, ModuleType):
                 # todo: find the configuration extender in the module
                 plugin_callable = ...
                 raise ValueError("unexpected or unsupported plugin configuration type")
             else:
                 parser_result = plugin_callable(parser)
                 if parser_result is not None:
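Usage sketch for the two new flags (assuming the usual python entry point for this repository; the pinned version is illustrative):

    # pin the frontend to a tagged GitHub release, or track "latest"
    python main.py --front-end-version comfyanonymous/ComfyUI@1.0.0

    # serve a locally checked-out frontend instead; this overrides --front-end-version
    python main.py --front-end-root /path/to/frontend/dist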
@@ -1,6 +1,5 @@
 from .component_model import files
 from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
-import os
 import torch
 import json
 import logging

@@ -43,6 +42,7 @@ class ClipVisionModel():
         else:
             raise ValueError(f"json_config had invalid value={json_config}")
 
+        self.image_size = config.get("image_size", 224)
         self.load_device = model_management.text_encoder_device()
         offload_device = model_management.text_encoder_offload_device()
         self.dtype = model_management.text_encoder_dtype(self.load_device)

@@ -58,7 +58,7 @@ class ClipVisionModel():
 
     def encode_image(self, image):
         model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device)).float()
+        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size).float()
         out = self.model(pixel_values=pixel_values, intermediate_output=-2)
 
         outputs = Output()

@@ -101,7 +101,10 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
     elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
         json_config = files.get_path_as_dict(None, "clip_vision_config_h.json")
     elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
-        json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json")
+        if sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
+            json_config = files.get_path_as_dict(None, "clip_vision_config_vitl_336.json")
+        else:
+            json_config = files.get_path_as_dict(None, "clip_vision_config_vitl.json")
     else:
         return None
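The 577 check above identifies the 336-pixel ViT-L variant from the weights alone: a ViT with patch size 14 produces (image_size / 14)**2 patch tokens plus one class token, and the position-embedding table has one row per token. A quick sanity check:

    # position embedding rows = patch tokens + 1 class token
    assert (336 // 14) ** 2 + 1 == 577  # ViT-L/14 @ 336px
    assert (224 // 14) ** 2 + 1 == 257  # ViT-L/14 @ 224px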
comfy/clip_vision_config_vitl_336.json (new file, 18 lines)
@@ -0,0 +1,18 @@
+{
+    "attention_dropout": 0.0,
+    "dropout": 0.0,
+    "hidden_act": "quick_gelu",
+    "hidden_size": 1024,
+    "image_size": 336,
+    "initializer_factor": 1.0,
+    "initializer_range": 0.02,
+    "intermediate_size": 4096,
+    "layer_norm_eps": 1e-5,
+    "model_type": "clip_vision_model",
+    "num_attention_heads": 16,
+    "num_channels": 3,
+    "num_hidden_layers": 24,
+    "patch_size": 14,
+    "projection_dim": 768,
+    "torch_dtype": "float32"
+}
@@ -19,7 +19,7 @@ from .. import interruption
 from .. import model_management
 from ..component_model.abstract_prompt_queue import AbstractPromptQueue
 from ..component_model.executor_types import ExecutorToClientProgress, ValidationTuple, ValidateInputsTuple, \
-    ValidationErrorDict, NodeErrorsDictValue
+    ValidationErrorDict, NodeErrorsDictValue, ValidationErrorExtraInfoDict
 from ..component_model.queue_types import QueueTuple, HistoryEntry, QueueItem, MAXIMUM_HISTORY_SIZE, ExecutionStatus
 from ..execution_context import new_execution_context, ExecutionContext
 from ..nodes.package import import_all_nodes_in_workspace

@@ -318,6 +318,8 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
         to_delete = True
     elif unique_id not in old_prompt:
         to_delete = True
+    elif class_type != old_prompt[unique_id]['class_type']:
+        to_delete = True
     elif inputs == old_prompt[unique_id]['inputs']:
         for x in inputs:
             input_data = inputs[x]

@@ -731,13 +733,18 @@ def validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple:
             span.set_status(Status(StatusCode.ERROR))
             if res.error is not None and len(res.error) > 0:
                 span.set_attributes({
-                    f"error.{k}": v for k, v in res.error.items()
+                    f"error.{k}": v for k, v in res.error.items() if isinstance(v, (bool, str, bytes, int, float, list))
                 })
+                if "extra_info" in res.error and isinstance(res.error["extra_info"], dict):
+                    extra_info: ValidationErrorExtraInfoDict = res.error["extra_info"]
+                    span.set_attributes({
+                        f"error.extra_info.{k}": v for k, v in extra_info.items() if isinstance(v, (str, list))
+                    })
             if len(res.node_errors) > 0:
                 for node_id, node_error in res.node_errors.items():
                     for node_error_field, node_error_value in node_error.items():
                         if isinstance(node_error_value, (str, bool, int, float)):
-                            span.set_attribute("node_errors.{node_id}.{node_error_field}", node_error_value)
+                            span.set_attribute(f"node_errors.{node_id}.{node_error_field}", node_error_value)
     return res
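The isinstance filters matter because OpenTelemetry span attributes only accept primitive values (and homogeneous sequences of them); nested dicts like extra_info must be flattened or dropped. Note that isinstance() rejects parameterized generics such as list[str] at runtime, so the filters test against plain list. A standalone illustration with hypothetical values:

    error = {"type": "invalid_prompt", "message": "boom", "extra_info": {"exception_message": "boom"}}
    flat = {f"error.{k}": v for k, v in error.items() if isinstance(v, (bool, str, bytes, int, float, list))}
    assert "error.type" in flat and "error.extra_info" not in flat  # the nested dict is handled separately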
@@ -7,9 +7,9 @@ import logging
 import mimetypes
 import os
 import struct
+import sys
 import traceback
 import uuid
-import hashlib
 from asyncio import Future, AbstractEventLoop
 from enum import Enum
 from io import BytesIO

@@ -19,7 +19,6 @@ from urllib.parse import quote, urlencode
 
 import aiofiles
 import aiohttp
-import sys
 from PIL import Image
 from PIL.PngImagePlugin import PngInfo
 from aiohttp import web

@@ -30,6 +29,7 @@ from .latent_preview_image_encoding import encode_preview_image
 from .. import interruption
 from .. import model_management
 from .. import utils
+from ..app.frontend_management import FrontendManager
 from ..app.user_manager import UserManager
 from ..cli_args import args
 from ..client.client_types import FileOutput

@@ -115,10 +115,11 @@ class PromptServer(ExecutorToClientProgress):
                                        handler_args={'max_field_size': 16380},
                                        middlewares=middlewares)
         self.sockets = dict()
-        web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../web")
-        if not os.path.exists(web_root_path):
-            web_root_path = get_package_as_path('comfy', 'web/')
-        self.web_root = web_root_path
+        self.web_root = (
+            FrontendManager.init_frontend(args.front_end_version)
+            if args.front_end_root is None
+            else args.front_end_root
+        )
         routes = web.RouteTableDef()
         self.routes: web.RouteTableDef = routes
         self.last_node_id = None

@@ -191,10 +192,12 @@ class PromptServer(ExecutorToClientProgress):
             return type_dir, dir_type
 
         def compare_image_hash(filepath, image):
+            hasher = node_helpers.hasher()
+
            # function to compare hashes of two images to see if it already exists, fix for #3465
            if os.path.exists(filepath):
-                a = hashlib.sha256()
-                b = hashlib.sha256()
+                a = hasher()
+                b = hasher()
                with open(filepath, "rb") as f:
                    a.update(f.read())
                    b.update(image.file.read())

@@ -233,7 +236,7 @@ class PromptServer(ExecutorToClientProgress):
             else:
                 i = 1
                 while os.path.exists(filepath):
-                    if compare_image_hash(filepath, image): #compare hash to prevent saving of duplicates with same name, fix for #3465
+                    if compare_image_hash(filepath, image):  # compare hash to prevent saving of duplicates with same name, fix for #3465
                         image_is_duplicate = True
                         break
                     filename = f"{split[0]} ({i}){split[1]}"

@@ -719,6 +722,7 @@ class PromptServer(ExecutorToClientProgress):
     @external_address.setter
     def external_address(self, value):
         self._external_address = value
 
     def add_routes(self):
+        self.user_manager.add_routes(self.routes)
 
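A standalone sketch of the duplicate-upload check above, simplified to two on-disk files (the real code compares an uploaded file object against an existing file, with the algorithm chosen by node_helpers.hasher()):

    import hashlib

    def files_identical(path_a: str, path_b: str) -> bool:
        a = hashlib.sha256()  # stand-in for node_helpers.hasher()()
        b = hashlib.sha256()
        with open(path_a, "rb") as f:
            a.update(f.read())
        with open(path_b, "rb") as f:
            b.update(f.read())
        return a.digest() == b.digest()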
@@ -45,6 +45,7 @@ class ControlBase:
         self.timestep_range = None
         self.compression_ratio = 8
         self.upscale_algorithm = 'nearest-exact'
+        self.extra_args = {}
 
         if device is None:
             device = model_management.get_torch_device()

@@ -90,6 +91,7 @@ class ControlBase:
         c.compression_ratio = self.compression_ratio
         c.upscale_algorithm = self.upscale_algorithm
         c.latent_format = self.latent_format
+        c.extra_args = self.extra_args.copy()
         c.vae = self.vae
 
     def inference_memory_requirements(self, dtype):

@@ -135,6 +137,10 @@ class ControlBase:
                     o[i] = prev_val + o[i]  # TODO: change back to inplace add if shared tensors stop being an issue
         return out
 
+    def set_extra_arg(self, argument, value=None):
+        self.extra_args[argument] = value
+
+
 class ControlNet(ControlBase):
     def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None):
         super().__init__(device)

@@ -190,7 +196,7 @@ class ControlNet(ControlBase):
         timestep = self.model_sampling_current.timestep(t)
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
 
-        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
+        control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y, **self.extra_args)
         return self.control_merge(control, control_prev, output_dtype)
 
     def copy(self):
@@ -61,8 +61,9 @@ class ModelSamplingDiscrete(torch.nn.Module):
         beta_schedule = sampling_settings.get("beta_schedule", "linear")
         linear_start = sampling_settings.get("linear_start", 0.00085)
         linear_end = sampling_settings.get("linear_end", 0.012)
+        timesteps = sampling_settings.get("timesteps", 1000)
 
-        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
+        self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=timesteps, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
         self.sigma_data = 1.0
 
     def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
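With this change the schedule length is read from the model config rather than hard-coded to 1000. A sketch of how a model config would override it (the class name is hypothetical; only the sampling_settings dict shape follows the code above):

    class HypotheticalModelConfig:
        sampling_settings = {
            "beta_schedule": "linear",
            "linear_start": 0.00085,
            "linear_end": 0.012,
            "timesteps": 1100,  # picked up by ModelSamplingDiscrete instead of the old fixed 1000
        }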
@@ -1,3 +1,7 @@
+import hashlib
+
+from comfy.cli_args import args
+
 from PIL import ImageFile, UnidentifiedImageError
 
 def conditioning_set_values(conditioning, values={}):

@@ -22,3 +26,12 @@ def pillow(fn, arg):
         if prev_value is not None:
             ImageFile.LOAD_TRUNCATED_IMAGES = prev_value
     return x
+
+def hasher():
+    hashfuncs = {
+        "md5": hashlib.md5,
+        "sha1": hashlib.sha1,
+        "sha256": hashlib.sha256,
+        "sha512": hashlib.sha512
+    }
+    return hashfuncs[args.default_hashing_function]
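Note that hasher() returns a constructor, not a digest object, so call sites instantiate it once per hash; the algorithm follows --default-hashing-function. A usage sketch (assuming this module is importable as comfy.node_helpers):

    from comfy import node_helpers

    h = node_helpers.hasher()()  # e.g. hashlib.sha256() under the default setting
    h.update(b"some file contents")
    print(h.hexdigest())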
@@ -745,7 +745,7 @@ class ControlNetApply:
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "apply_controlnet"
 
-    CATEGORY = "conditioning"
+    CATEGORY = "conditioning/controlnet"
 
     def apply_controlnet(self, conditioning, control_net, image, strength):
         if strength == 0:

@@ -780,7 +780,7 @@ class ControlNetApplyAdvanced:
     RETURN_NAMES = ("positive", "negative")
     FUNCTION = "apply_controlnet"
 
-    CATEGORY = "conditioning"
+    CATEGORY = "conditioning/controlnet"
 
     def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
         if strength == 0:
@@ -17,10 +17,10 @@ from . import model_detection
 from . import model_management
 from . import model_patcher
 from . import model_sampling
-from . import sa_t5
+from .text_encoders import sa_t5
 from . import sd1_clip
 from . import sd2_clip
-from . import sd3_clip
+from .text_encoders import sd3_clip
 from . import sdxl_clip
 from . import utils
 from .ldm.audio.autoencoder import AudioOobleckVAE

@@ -5,8 +5,8 @@ from . import utils
 from . import sd1_clip
 from . import sd2_clip
 from . import sdxl_clip
-from . import sd3_clip
-from . import sa_t5
+from .text_encoders import sd3_clip
+from .text_encoders import sa_t5
 from .text_encoders import aura_t5
 
 from . import supported_models_base

@@ -2,7 +2,7 @@ from importlib import resources
 
 from comfy import sd1_clip
 from .llama_tokenizer import LLAMATokenizer
-from .. import t5
+from ..text_encoders import t5
 from ..component_model.files import get_path_as_dict
 
 

@@ -1,14 +1,13 @@
 from transformers import T5TokenizerFast
 
-import comfy.t5
+import comfy.text_encoders.t5
 from comfy import sd1_clip
 from comfy.component_model import files
 
 
 class T5BaseModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
         textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_base.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5, enable_attention_masks=True, zero_out_masked=True)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
 
 
 class T5BaseTokenizer(sd1_clip.SDTokenizer):
@@ -1,11 +1,10 @@
 import logging
-import os
 
 import torch
 from transformers import T5TokenizerFast
 
 import comfy.model_management
-import comfy.t5
+import comfy.text_encoders.t5
 from comfy import sd1_clip
 from comfy import sdxl_clip
 from comfy.component_model import files

@@ -13,13 +12,13 @@ from comfy.component_model import files
 
 class T5XXLModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
-        textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.t5.T5)
+        textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json", package="comfy.text_encoders")
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5)
 
 
 class T5XXLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None):
-        tokenizer_path = files.get_package_as_path("comfy.t5_tokenizer")
+        tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
         super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
 
@@ -17,7 +17,6 @@ function getResourceURL(subfolder, filename, type = "input") {
		"filename=" + encodeURIComponent(filename),
		"type=" + type,
		"subfolder=" + subfolder,
-		app.getPreviewFormatParam().substring(1),
		app.getRandParam().substring(1)
	].join("&")
 
@@ -101,7 +101,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
                 "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                 "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
                 }}
-    CATEGORY = "_for_testing/sd3"
+    CATEGORY = "conditioning/controlnet"
 
 NODE_CLASS_MAPPINGS = {
     "TripleCLIPLoader": TripleCLIPLoader,
comfy_extras/nodes_controlnet.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+
+UNION_CONTROLNET_TYPES = {"auto": -1,
+                          "openpose": 0,
+                          "depth": 1,
+                          "hed/pidi/scribble/ted": 2,
+                          "canny/lineart/anime_lineart/mlsd": 3,
+                          "normal": 4,
+                          "segment": 5,
+                          "tile": 6,
+                          "repaint": 7,
+                          }
+
+class SetUnionControlNetType:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {"required": {"control_net": ("CONTROL_NET", ),
+                             "type": (list(UNION_CONTROLNET_TYPES.keys()),)
+                             }}
+
+    CATEGORY = "conditioning/controlnet"
+    RETURN_TYPES = ("CONTROL_NET",)
+
+    FUNCTION = "set_controlnet_type"
+
+    def set_controlnet_type(self, control_net, type):
+        control_net = control_net.copy()
+        type_number = UNION_CONTROLNET_TYPES[type]
+        if type_number >= 0:
+            control_net.set_extra_arg("control_type", [type_number])
+        else:
+            control_net.set_extra_arg("control_type", [])
+
+        return (control_net,)
+
+NODE_CLASS_MAPPINGS = {
+    "SetUnionControlNetType": SetUnionControlNetType,
+}
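The node is a thin wrapper over the new ControlBase.set_extra_arg API: it copies the control net and stores a type index that ControlNet.get_control later splats into the model call as control_type. Outside the node graph, the same effect looks like this (control_net is assumed to be a loaded Union ControlNet instance):

    cn = control_net.copy()
    cn.set_extra_arg("control_type", [UNION_CONTROLNET_TYPES["depth"]])  # [1]
    # get_control() now invokes self.control_model(..., control_type=[1])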
pytest.ini (2 lines added)
@@ -1,6 +1,8 @@
 [pytest]
 markers =
     inference: mark as inference test (deselect with '-m "not inference"')
-testpaths = tests
+testpaths =
+    tests
+    tests-unit
 addopts = -s
 asyncio_mode = auto
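With both paths registered, a bare pytest run collects the new unit tests alongside the existing suite, and the inference marker keeps the heavyweight tests deselectable:

    pytest -m "not inference"   # skip inference-marked tests
    pytest tests-unit           # run only the new unit tests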
tests-unit/README.md (new file, 8 lines)
@@ -0,0 +1,8 @@
+# Pytest Unit Tests
+
+## Install test dependencies
+
+`pip install -r tests-unit/requirements.txt`
+
+## Run tests
+`pytest tests-unit/`
tests-unit/app_test/__init__.py (new file, empty)
tests-unit/app_test/frontend_manager_test.py (new file, 100 lines)
@@ -0,0 +1,100 @@
+import argparse
+import pytest
+from requests.exceptions import HTTPError
+
+from comfy.app.frontend_management import (
+    FrontendManager,
+    FrontEndProvider,
+    Release,
+)
+from comfy.cli_args import DEFAULT_VERSION_STRING
+
+
+@pytest.fixture
+def mock_releases():
+    return [
+        Release(
+            id=1,
+            tag_name="1.0.0",
+            name="Release 1.0.0",
+            prerelease=False,
+            created_at="2022-01-01T00:00:00Z",
+            published_at="2022-01-01T00:00:00Z",
+            body="Release notes for 1.0.0",
+            assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
+        ),
+        Release(
+            id=2,
+            tag_name="2.0.0",
+            name="Release 2.0.0",
+            prerelease=False,
+            created_at="2022-02-01T00:00:00Z",
+            published_at="2022-02-01T00:00:00Z",
+            body="Release notes for 2.0.0",
+            assets=[{"name": "dist.zip", "url": "https://example.com/dist.zip"}],
+        ),
+    ]
+
+
+@pytest.fixture
+def mock_provider(mock_releases):
+    provider = FrontEndProvider(
+        owner="test-owner",
+        repo="test-repo",
+    )
+    provider.all_releases = mock_releases
+    provider.latest_release = mock_releases[1]
+    FrontendManager.PROVIDERS = [provider]
+    return provider
+
+
+def test_get_release(mock_provider, mock_releases):
+    version = "1.0.0"
+    release = mock_provider.get_release(version)
+    assert release == mock_releases[0]
+
+
+def test_get_release_latest(mock_provider, mock_releases):
+    version = "latest"
+    release = mock_provider.get_release(version)
+    assert release == mock_releases[1]
+
+
+def test_get_release_invalid_version(mock_provider):
+    version = "invalid"
+    with pytest.raises(ValueError):
+        mock_provider.get_release(version)
+
+
+def test_init_frontend_default():
+    version_string = DEFAULT_VERSION_STRING
+    frontend_path = FrontendManager.init_frontend(version_string)
+    assert frontend_path == FrontendManager.DEFAULT_FRONTEND_PATH
+
+
+def test_init_frontend_invalid_version():
+    version_string = "test-owner/test-repo@1.100.99"
+    with pytest.raises(HTTPError):
+        FrontendManager.init_frontend_unsafe(version_string)
+
+
+def test_init_frontend_invalid_provider():
+    version_string = "invalid/invalid@latest"
+    with pytest.raises(HTTPError):
+        FrontendManager.init_frontend_unsafe(version_string)
+
+
+def test_parse_version_string():
+    version_string = "owner/repo@1.0.0"
+    repo_owner, repo_name, version = FrontendManager.parse_version_string(
+        version_string
+    )
+    assert repo_owner == "owner"
+    assert repo_name == "repo"
+    assert version == "1.0.0"
+
+
+def test_parse_version_string_invalid():
+    version_string = "invalid"
+    with pytest.raises(argparse.ArgumentTypeError):
+        FrontendManager.parse_version_string(version_string)
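Note that the two invalid-version tests are not mocked: they query the real GitHub API for a nonexistent owner/repo and rely on raise_for_status() surfacing the 404 as an HTTPError, so they need network access. To run just this module:

    pytest tests-unit/app_test/frontend_manager_test.py -v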
tests-unit/requirements.txt (new file, 1 line)
@@ -0,0 +1 @@
+pytest>=7.8.0