Merge branch 'master' of github.com:comfyanonymous/ComfyUI

commit fa3176f96f by doctorpangloss, 2024-09-23 12:50:31 -07:00
51 changed files with 51259 additions and 44715 deletions

View File

@@ -488,6 +488,7 @@ disable=raw-checker-failed,
         too-many-return-statements,
         too-many-branches,
         too-many-arguments,
+        too-many-positional-arguments,
         too-many-locals,
         too-many-statements,
         too-many-boolean-expressions,

View File

@@ -1,10 +1,9 @@
-import logging
 from typing import Optional
 from aiohttp import web
 from ...services.file_service import FileService
-from ....cmd.folder_paths import models_dir, user_directory, output_directory
+from ....cmd.folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths

 class InternalRoutes:
@@ -41,6 +40,13 @@ class InternalRoutes:
             # todo: applications really shouldn't serve logs like this
             return web.json_response({})

+        @self.routes.get('/folder_paths')
+        async def get_folder_paths(request):
+            response = {}
+            for key in folder_names_and_paths:
+                response[key] = folder_names_and_paths[key][0]
+            return web.json_response(response)
+
     def get_app(self):
        if self._app is None:
            self._app = web.Application()
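
The new route returns, for each registered folder name, the first element of its entry (the list of search paths). A quick smoke test, assuming the default port and that this InternalRoutes app is mounted under the usual /internal prefix (an assumption; adjust if this fork mounts it elsewhere):

    import requests

    # expected shape: {"checkpoints": ["/path/to/models/checkpoints", ...], ...}
    print(requests.get("http://127.0.0.1:8188/internal/folder_paths").json())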

View File

@@ -119,6 +119,29 @@ class UserManager():
         @routes.get("/userdata")
         async def listuserdata(request):
+            """
+            List user data files in a specified directory.
+
+            This endpoint allows listing files in a user's data directory, with options for recursion,
+            full file information, and path splitting.
+
+            Query Parameters:
+            - dir (required): The directory to list files from.
+            - recurse (optional): If "true", recursively list files in subdirectories.
+            - full_info (optional): If "true", return detailed file information (path, size, modified time).
+            - split (optional): If "true", split file paths into components (only applies when full_info is false).
+
+            Returns:
+            - 400: If 'dir' parameter is missing.
+            - 403: If the requested path is not allowed.
+            - 404: If the requested directory does not exist.
+            - 200: JSON response with the list of files or file information.
+
+            The response format depends on the query parameters:
+            - Default: List of relative file paths.
+            - full_info=true: List of dictionaries with file details.
+            - split=true (and full_info=false): List of lists, each containing path components.
+            """
             directory = request.rel_url.query.get('dir', '')
             if not directory:
                 return web.Response(status=400, text="Directory not provided")
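
The docstring above is the contract. A hedged example request against it (the "workflows" directory name is illustrative, not from the commit):

    import requests

    r = requests.get(
        "http://127.0.0.1:8188/userdata",
        params={"dir": "workflows", "recurse": "true", "full_info": "true"},
    )
    # 400 if dir is missing, 403 if disallowed, 404 if absent, 200 with a JSON list otherwise
    print(r.status_code, r.json() if r.ok else r.reason)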

View File

@@ -1,10 +1,20 @@
 import itertools
-from typing import Sequence, Mapping
+from typing import Sequence, Mapping, Dict

 from .cmd.execution import nodes
 from .graph import DynamicPrompt
 from .graph_utils import is_link

+NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
+
+def include_unique_id_in_input(class_type: str) -> bool:
+    if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
+        return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
+    class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
+    NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
+    return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
+
 class CacheKeySet:
     def __init__(self, dynprompt, node_ids, is_changed_cache):
@@ -102,7 +112,7 @@ class CacheKeySetInputSignature(CacheKeySet):
         class_type = node["class_type"]
         class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
         signature = [class_type, self.is_changed_cache.get(node_id)]
-        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values():
+        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
             signature.append(node_id)
         inputs = node["inputs"]
         for key in sorted(inputs.keys()):

View File

@@ -16,6 +16,7 @@ class ControlNet(MMDiT):
     def __init__(
         self,
         num_blocks = None,
+        control_latent_channels = None,
         dtype = None,
         device = None,
         operations = None,
@@ -27,10 +28,13 @@ class ControlNet(MMDiT):
         for _ in range(len(self.joint_blocks)):
             self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))

+        if control_latent_channels is None:
+            control_latent_channels = self.in_channels
+
         self.pos_embed_input = PatchEmbed(
             None,
             self.patch_size,
-            self.in_channels,
+            control_latent_channels,
             self.hidden_size,
             bias=True,
             strict_img_size=False,

View File

@@ -28,8 +28,8 @@ def _create_parser() -> EnhancedConfigArgParser:
     parser.add_argument('-w', "--cwd", type=str, default=None,
                        help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.")
-    parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0",
-                        help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
+    parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::",
+                        help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
     parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
     parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*",
                         help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
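
With the new const value and comma-separated parsing, one flag now drives several bind addresses. Illustrative invocations, assuming the usual main.py entry point (substitute however this fork is launched):

    # bind IPv4 and IPv6 loopback explicitly
    python main.py --listen 127.0.0.1,::1 --port 8188

    # a bare --listen now expands to "0.0.0.0,::" (all IPv4 and IPv6 interfaces)
    python main.py --listen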

View File

@@ -51,7 +51,46 @@ temp_directory = os.path.join(get_base_path(), "temp")
 input_directory = os.path.join(get_base_path(), "input")
 user_directory = os.path.join(get_base_path(), "user")

-_filename_list_cache = {}
+filename_list_cache = {}
+
+class CacheHelper:
+    """
+    Helper class for managing file list cache data.
+    """
+    def __init__(self):
+        self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
+        self.active = False
+
+    def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
+        if not self.active:
+            return default
+        return self.cache.get(key, default)
+
+    def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
+        if self.active:
+            self.cache[key] = value
+
+    def clear(self):
+        self.cache.clear()
+
+    def __enter__(self):
+        self.active = True
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.active = False
+        self.clear()
+
+cache_helper = CacheHelper()
+
+def map_legacy(folder_name: str) -> str:
+    legacy = {"unet": "diffusion_models"}
+    return legacy.get(folder_name, folder_name)

 if not os.path.exists(input_directory):
     try:
@@ -150,7 +189,7 @@ def exists_annotated_filepath(name):
     return os.path.exists(filepath)

-def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, extensions: Optional[set[str]] = None) -> str:
+def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, extensions: Optional[set[str]] = None, is_default: bool = False) -> str:
     """
     Registers a model path for the given canonical name.

     :param folder_name: the folder name
@@ -165,7 +204,10 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, e
     folder_path = folder_names_and_paths[folder_name]
     if full_folder_path not in folder_path.paths:
-        folder_path.paths.append(full_folder_path)
+        if is_default:
+            folder_path.paths.insert(0, full_folder_path)
+        else:
+            folder_path.paths.append(full_folder_path)
     if extensions is not None:
         folder_path.supported_extensions |= extensions
@@ -244,7 +286,15 @@ def get_full_path(folder_name, filename) -> Optional[str | bytes | os.PathLike]:
     return None

-def get_filename_list_(folder_name):
+def get_full_path_or_raise(folder_name: str, filename: str) -> str:
+    full_path = get_full_path(folder_name, filename)
+    if full_path is None:
+        raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
+    return full_path
+
+def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
+    folder_name = map_legacy(folder_name)
     global folder_names_and_paths
     output_list = set()
     folders = folder_names_and_paths[folder_name]
@@ -257,12 +307,17 @@ def get_filename_list_(folder_name):
     return sorted(list(output_list)), output_folders, time.perf_counter()

-def cached_filename_list_(folder_name):
-    global _filename_list_cache
+def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
+    strong_cache = cache_helper.get(folder_name)
+    if strong_cache is not None:
+        return strong_cache
+
+    global filename_list_cache
     global folder_names_and_paths
-    if folder_name not in _filename_list_cache:
+    folder_name = map_legacy(folder_name)
+    if folder_name not in filename_list_cache:
         return None
-    out = _filename_list_cache[folder_name]
+    out = filename_list_cache[folder_name]

     for x in out[1]:
         time_modified = out[1][x]
@@ -279,12 +334,14 @@ def cached_filename_list_(folder_name):
     return out

-def get_filename_list(folder_name):
+def get_filename_list(folder_name: str) -> list[str]:
+    folder_name = map_legacy(folder_name)
     out = cached_filename_list_(folder_name)
     if out is None:
         out = get_filename_list_(folder_name)
-        global _filename_list_cache
-        _filename_list_cache[folder_name] = out
+        global filename_list_cache
+        filename_list_cache[folder_name] = out
+    cache_helper.set(folder_name, out)
     return list(out[0])
@@ -345,8 +402,8 @@ def create_directories():

 def invalidate_cache(folder_name):
-    global _filename_list_cache
-    _filename_list_cache.pop(folder_name, None)
+    global filename_list_cache
+    filename_list_cache.pop(folder_name, None)

 def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]:
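
CacheHelper adds a strong, scope-bound cache in front of the mtime-validated filename_list_cache: while the context manager is active, repeated lookups return the first result verbatim; on exit the strong cache is dropped. A minimal sketch (the module import path is an assumption based on this fork's layout):

    from comfy.cmd import folder_paths  # assumed import path

    with folder_paths.cache_helper:
        a = folder_paths.get_filename_list("checkpoints")  # scans disk and caches
        b = folder_paths.get_filename_list("checkpoints")  # served from the strong cache
    # after the block, lookups fall back to the mtime-checked cache again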

View File

@@ -82,7 +82,10 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):

 async def run(server, address='', port=8188, verbose=True, call_on_start=None):
-    await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
+    addresses = []
+    for addr in address.split(","):
+        addresses.append((addr, port))
+    await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())

 def cleanup_temp():
@@ -223,12 +226,15 @@ async def main(from_script_dir: Optional[Path] = None):
         import webbrowser
         if os.name == 'nt' and address == '0.0.0.0' or address == '':
             address = '127.0.0.1'
+        if ':' in address:
+            address = "[{}]".format(address)
+
         webbrowser.open(f"http://{address}:{port}")
     call_on_start = startup_server

     server.address = args.listen
     server.port = args.port

     try:
         await server.setup()
         await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,

View File

@@ -29,7 +29,7 @@ from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
 from typing_extensions import NamedTuple

 from .latent_preview_image_encoding import encode_preview_image
-from .. import interruption
+from .. import interruption, model_management
 from .. import node_helpers
 from .. import utils
 from ..api_server.routes.internal.internal_routes import InternalRoutes
@@ -256,6 +256,12 @@ class PromptServer(ExecutorToClientProgress):
             embeddings = folder_paths.get_filename_list("embeddings")
             return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))

+        @routes.get("/models")
+        def list_model_types(request):
+            model_types = list(folder_paths.folder_names_and_paths.keys())
+
+            return web.json_response(model_types)
+
         @routes.get("/models/{folder}")
         async def get_models(request):
             folder = request.match_info.get("folder", None)
@@ -505,12 +511,17 @@ class PromptServer(ExecutorToClientProgress):
         async def system_stats(request):
             device = get_torch_device()
             device_name = get_torch_device_name(device)
+            cpu_device = model_management.torch.device("cpu")
+            ram_total = model_management.get_total_memory(cpu_device)
+            ram_free = model_management.get_free_memory(cpu_device)
             vram_total, torch_vram_total = get_total_memory(device, torch_total_too=True)
             vram_free, torch_vram_free = get_free_memory(device, torch_free_too=True)
+
             system_stats = {
                 "system": {
                     "os": os.name,
+                    "ram_total": ram_total,
+                    "ram_free": ram_free,
                     "comfyui_version": get_comfyui_version(),
                     "python_version": sys.version,
                     "pytorch_version": torch_version,
@@ -568,14 +579,15 @@ class PromptServer(ExecutorToClientProgress):
         @routes.get("/object_info")
         async def get_object_info(request):
-            out = {}
-            for x in self.nodes.NODE_CLASS_MAPPINGS:
-                try:
-                    out[x] = node_info(x)
-                except Exception as e:
-                    logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
-                    logging.error(traceback.format_exc())
-            return web.json_response(out)
+            with folder_paths.cache_helper:
+                out = {}
+                for x in self.nodes.NODE_CLASS_MAPPINGS:
+                    try:
+                        out[x] = node_info(x)
+                    except Exception as e:
+                        logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
+                        logging.error(traceback.format_exc())
+                return web.json_response(out)

         @routes.get("/object_info/{node_class}")
         async def get_object_info_node(request):
@@ -969,17 +981,29 @@ class PromptServer(ExecutorToClientProgress):
         await self.send(*msg)

     async def start(self, address: str | None, port: int | None, verbose=True, call_on_start=None):
+        await self.start_multi_address([(address, port)], call_on_start=call_on_start, verbose=verbose)
+
+    async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
         runner = web.AppRunner(self.app, access_log=None)
         await runner.setup()
-        site = web.TCPSite(runner, host=address, port=port)
-        await site.start()
-
-        self.address = address
-        self.port = port
+        for addr in addresses:
+            address = addr[0]
+            port = addr[1]
+            site = web.TCPSite(runner, address, port)
+            await site.start()
+
+            if not hasattr(self, 'address'):
+                self.address = address #TODO: remove this
+                self.port = port
+
+            if ':' in address:
+                address_print = "[{}]".format(address)
+            else:
+                address_print = address

-        if verbose:
-            logging.info("Starting server")
-            logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
+            if verbose:
+                logging.info("Starting server")
+                logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address_print == "0.0.0.0" else address, port))

         if call_on_start is not None:
             call_on_start(address, port)
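
Once the server is running, the new surface is easy to exercise. A hedged sketch with the requests library against the default host and port (the /system_stats route path is assumed from the handler name):

    import requests

    base = "http://127.0.0.1:8188"
    print(requests.get(f"{base}/models").json())  # list of registered model folder names

    stats = requests.get(f"{base}/system_stats").json()
    print(stats["system"]["ram_total"], stats["system"]["ram_free"])  # new host-RAM fields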

View File

@@ -30,10 +30,9 @@ from . import model_patcher
 from . import ops
 from . import utils
 from .cldm import cldm, mmdit
-from .ldm import hydit
 from .ldm.cascade import controlnet as cascade_controlnet
 from .ldm.flux import controlnet as controlnet_flux
-from .ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
+from .ldm.hydit.controlnet import HunYuanControlNet
 from .t2i_adapter import adapter
@@ -81,13 +80,21 @@ class ControlBase:
         self.previous_controlnet = None
         self.extra_conds = []
         self.strength_type = StrengthType.CONSTANT
+        self.concat_mask = False
+        self.extra_concat_orig = []
+        self.extra_concat = None

-    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
+    def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
         self.cond_hint_original = cond_hint
         self.strength = strength
         self.timestep_percent_range = timestep_percent_range
         if self.latent_format is not None:
+            if vae is None:
+                logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
         self.vae = vae
+        self.extra_concat_orig = extra_concat.copy()
+        if self.concat_mask and len(self.extra_concat_orig) == 0:
+            self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
         return self

     def pre_run(self, model, percent_to_timestep_function):
@@ -102,9 +109,9 @@ class ControlBase:
     def cleanup(self):
         if self.previous_controlnet is not None:
             self.previous_controlnet.cleanup()
-        if self.cond_hint is not None:
-            del self.cond_hint
-            self.cond_hint = None
+
+        self.cond_hint = None
+        self.extra_concat = None
         self.timestep_range = None

     def get_models(self):
@@ -125,6 +132,8 @@ class ControlBase:
         c.vae = self.vae
         c.extra_conds = self.extra_conds.copy()
         c.strength_type = self.strength_type
+        c.concat_mask = self.concat_mask
+        c.extra_concat_orig = self.extra_concat_orig.copy()

     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:
@@ -177,7 +186,7 @@ class ControlBase:

 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, ckpt_name: str = None):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, ckpt_name: str = None):
         super().__init__(device)
         self.control_model = control_model
         self.load_device = load_device
@@ -192,6 +201,7 @@ class ControlNet(ControlBase):
         self.latent_format = latent_format
         self.extra_conds += extra_conds
         self.strength_type = strength_type
+        self.concat_mask = concat_mask

     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None
@@ -216,6 +226,9 @@ class ControlNet(ControlBase):
                 compression_ratio = self.compression_ratio
                 if self.vae is not None:
                     compression_ratio *= self.vae.downscale_ratio
+                else:
+                    if self.latent_format is not None:
+                        raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
                 self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
                 if self.vae is not None:
                     loaded_models = model_management.loaded_models(only_currently_used=True)
@@ -223,6 +236,13 @@ class ControlNet(ControlBase):
                     model_management.load_models_gpu(loaded_models)
                 if self.latent_format is not None:
                     self.cond_hint = self.latent_format.process_in(self.cond_hint)
+                if len(self.extra_concat_orig) > 0:
+                    to_concat = []
+                    for c in self.extra_concat_orig:
+                        c = utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
+                        to_concat.append(utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
+                    self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
+
                 self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
             if x_noisy.shape[0] != self.cond_hint.shape[0]:
                 self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@@ -322,7 +342,7 @@ class ControlLoraOps:

 class ControlLora(ControlNet):
-    def __init__(self, control_weights, global_average_pooling=False, device=None):
+    def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): # TODO? model_options
         ControlBase.__init__(self, device)
         self.control_weights = control_weights
         self.global_average_pooling = global_average_pooling
@@ -381,19 +401,27 @@ class ControlLora(ControlNet):
         return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)

-def controlnet_config(sd):
+def controlnet_config(sd, model_options=None):
+    if model_options is None:
+        model_options = {}
     model_config = model_detection.model_config_from_unet(sd, "", True)
-    supported_inference_dtypes = model_config.supported_inference_dtypes

-    controlnet_config = model_config.unet_config
-    unet_dtype = model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
+    unet_dtype = model_options.get("dtype", None)
+    if unet_dtype is None:
+        weight_dtype = utils.weight_dtype(sd)
+
+        supported_inference_dtypes = list(model_config.supported_inference_dtypes)
+        if weight_dtype is not None:
+            supported_inference_dtypes.append(weight_dtype)
+
+        unet_dtype = model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
+
     load_device = model_management.get_torch_device()
     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
-    if manual_cast_dtype is not None:
-        operations = ops.manual_cast
-    else:
-        operations = ops.disable_weight_init
+
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
     offload_device = model_management.unet_offload_device()
     return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
@@ -410,26 +438,35 @@ def controlnet_load_state_dict(control_model, sd):
     return control_model

-def load_controlnet_mmdit(sd):
+def load_controlnet_mmdit(sd, model_options=None):
+    if model_options is None:
+        model_options = {}
     new_sd = model_detection.convert_diffusers_mmdit(sd, "")
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
     num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
     for k in sd:
         new_sd[k] = sd[k]

-    control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    concat_mask = False
+    control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
+    if control_latent_channels == 17: # inpaint controlnet
+        concat_mask = True
+
+    control_model = mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, new_sd)

     latent_format = latent_formats.SD3()
     latent_format.shift_factor = 0 # SD3 controlnet weirdness
-    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
     return control

-def load_controlnet_hunyuandit(controlnet_data):
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
+def load_controlnet_hunyuandit(controlnet_data, model_options=None):
+    if model_options is None:
+        model_options = {}
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)

-    control_model = hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
+    control_model = HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
     control_model = controlnet_load_state_dict(control_model, controlnet_data)

     latent_format = latent_formats.SDXL()
@@ -473,10 +510,10 @@ def load_controlnet_flux_instantx_union(sd, controlnet_class, weight_dtype, full
     device = model_management.get_torch_device()

-    if weight_dtype == "fp8_e4m3fn":
+    if weight_dtype == torch.float8_e4m3fn or weight_dtype == "fp8_e4m3fn":
         dtype = torch.float8_e4m3fn
         operations = ops.manual_cast
-    elif weight_dtype == "fp8_e5m2":
+    elif weight_dtype == torch.float8_e5m2 or weight_dtype == "float8_e5m2":
         dtype = torch.float8_e5m2
         operations = ops.manual_cast
     else:
@@ -493,8 +530,10 @@ def load_controlnet_flux_instantx_union(sd, controlnet_class, weight_dtype, full
     return control

-def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False):
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
+def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options=None):
+    if model_options is None:
+        model_options = {}
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
     control_model = controlnet_flux.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
@@ -502,9 +541,11 @@ def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False):
     return control

-def load_controlnet_flux_instantx(sd):
+def load_controlnet_flux_instantx(sd, model_options=None):
+    if model_options is None:
+        model_options = {}
     new_sd = model_detection.convert_diffusers_mmdit(sd, "")
-    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
     for k in sd:
         new_sd[k] = sd[k]

@@ -514,13 +555,16 @@ def load_controlnet_flux_instantx(sd):
     num_union_modes = new_sd[union_cnet].shape[0]

     control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
+    concat_mask = False
+    if control_latent_channels == 17:
+        concat_mask = True

     control_model = controlnet_flux.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, new_sd)

     latent_format = latent_formats.Flux()
     extra_conds = ['y', 'guidance']
-    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control
@@ -528,12 +572,15 @@ def convert_mistoline(sd):
     return utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})

-def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
-    controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
+def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_name: str = None):
+    if model_options is None:
+        model_options = {}
+    controlnet_data = state_dict
     if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT
-        return load_controlnet_hunyuandit(controlnet_data)
+        return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)

     if "lora_controlnet" in controlnet_data:
-        return ControlLora(controlnet_data)
+        return ControlLora(controlnet_data, model_options=model_options)

     controlnet_config = None
     supported_inference_dtypes = None
@@ -590,13 +637,13 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
         controlnet_data = new_sd
     elif "controlnet_blocks.0.weight" in controlnet_data:
         if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
-            return load_controlnet_flux_xlabs_mistoline(controlnet_data)
+            return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
         elif "pos_embed_input.proj.weight" in controlnet_data:
-            return load_controlnet_mmdit(controlnet_data) # SD3 diffusers controlnet
+            return load_controlnet_mmdit(controlnet_data, model_options=model_options) # SD3 diffusers controlnet
         elif "controlnet_x_embedder.weight" in controlnet_data:
-            return load_controlnet_flux_instantx(controlnet_data)
+            return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
         elif "controlnet_blocks.0.linear.weight" in controlnet_data: # mistoline flux
-            return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True)
+            return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)

     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
@@ -608,25 +655,36 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
         elif key in controlnet_data:
             prefix = ""
         else:
-            net = load_t2i_adapter(controlnet_data)
+            net = load_t2i_adapter(controlnet_data, model_options=model_options)
             if net is None:
-                logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
+                logging.error("error could not detect control model type.")
             return net

     if controlnet_config is None:
         model_config = model_detection.model_config_from_unet(controlnet_data, prefix, True)
-        supported_inference_dtypes = model_config.supported_inference_dtypes
+        supported_inference_dtypes = list(model_config.supported_inference_dtypes)
         controlnet_config = model_config.unet_config

+    unet_dtype = model_options.get("dtype", None)
+    if unet_dtype is None:
+        weight_dtype = utils.weight_dtype(controlnet_data)
+
+        if supported_inference_dtypes is None:
+            supported_inference_dtypes = [model_management.unet_dtype()]
+
+        if weight_dtype is not None:
+            supported_inference_dtypes.append(weight_dtype)
+
+        unet_dtype = model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
+
     load_device = model_management.get_torch_device()
-    if supported_inference_dtypes is None:
-        unet_dtype = model_management.unet_dtype()
-    else:
-        unet_dtype = model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)

     manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
-    if manual_cast_dtype is not None:
-        controlnet_config["operations"] = ops.manual_cast
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = ops.pick_operations(unet_dtype, manual_cast_dtype)
+
+    controlnet_config["operations"] = operations
     controlnet_config["dtype"] = unet_dtype
     controlnet_config["device"] = model_management.unet_offload_device()
     controlnet_config.pop("out_channels")
@@ -663,15 +721,26 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
     if len(unexpected) > 0:
         logging.debug("unexpected controlnet keys: {}".format(unexpected))

-    global_average_pooling = False
-    filename = os.path.splitext(ckpt_path)[0]
-    if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): # TODO: smarter way of enabling global_average_pooling
-        global_average_pooling = True
+    filename = os.path.splitext(ckpt_name)[0]
+    global_average_pooling = model_options.get("global_average_pooling", False)

     control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype, ckpt_name=filename)
     return control

+def load_controlnet(ckpt_path, model=None, model_options=None):
+    if model_options is None:
+        model_options = {}
+    if "global_average_pooling" not in model_options:
+        filename = os.path.splitext(ckpt_path)[0]
+        if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): # TODO: smarter way of enabling global_average_pooling
+            model_options["global_average_pooling"] = True
+
+    cnet = load_controlnet_state_dict(utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options, ckpt_name=ckpt_path)
+    if cnet is None:
+        logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
+    return cnet
+
 class T2IAdapter(ControlBase):
     def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
         super().__init__(device)
@@ -728,7 +797,7 @@ class T2IAdapter(ControlBase):
         return c

-def load_t2i_adapter(t2i_data):
+def load_t2i_adapter(t2i_data, model_options={}): # TODO: model_options
     compression_ratio = 8
     upscale_algorithm = 'nearest-exact'

View File

@@ -17,6 +17,9 @@ def load_extra_path_config(yaml_path):
         if "base_path" in conf:
             base_path = conf.pop("base_path")
             base_path = os.path.expandvars(os.path.expanduser(base_path))
+        is_default = False
+        if "is_default" in conf:
+            is_default = conf.pop("is_default")
         for x in conf:
             for y in conf[x].split("\n"):
                 if len(y) == 0:
@@ -25,4 +28,4 @@ def load_extra_path_config(yaml_path):
             if base_path is not None:
                 full_path = os.path.join(base_path, full_path)
             logging.info("Adding extra search path {} {}".format(x, full_path))
-            folder_paths.add_model_folder_path(x, full_path)
+            folder_paths.add_model_folder_path(x, full_path, is_default=is_default)
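
A hypothetical extra_model_paths.yaml entry exercising the new flag; with is_default: true the resolved paths are inserted at the front of the search list instead of appended:

    my_models:                        # arbitrary section name
        base_path: /mnt/storage/comfyui
        is_default: true              # new: promote these paths to the front
        checkpoints: checkpoints
        loras: loras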

View File

@@ -45,6 +45,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
     return append_zero(sigmas)

+def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
+    """Constructs the noise schedule proposed by Tiankai et al. (2024). """
+    epsilon = 1e-5 # avoid log(0)
+    x = torch.linspace(0, 1, n, device=device)
+    clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
+    lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
+    sigmas = clamp(torch.exp(lmb))
+    return sigmas
+
 def to_d(x, sigma, denoised):
     """Converts a denoiser output to a Karras ODE derivative."""
     return (x - denoised) / utils.append_dims(sigma, x.ndim)
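
A minimal sketch of calling the new schedule directly; the import path and the sigma range are assumptions, not taken from the commit:

    from comfy.k_diffusion.sampling import get_sigmas_laplace  # assumed import path

    sigmas = get_sigmas_laplace(20, sigma_min=0.03, sigma_max=14.6)
    # every sigma is clamped into [sigma_min, sigma_max]; unlike the vp schedule, no zero is appended
    assert sigmas.min() >= 0.03 and sigmas.max() <= 14.6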

View File

@@ -414,7 +414,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
             weight *= strength_model

         if isinstance(v, list):
-            v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype),)
+            v = (calculate_weight(v[1:], model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype),)

         patch_type = ""
         if len(v) == 1:

View File

@@ -199,6 +199,8 @@ Visit the repository, accept the terms, and then do one of the following:
     # a path was found for any reason, so we should invalidate the cache
     if path is not None:
         folder_paths.invalidate_cache(folder_name)
+    if path is None:
+        raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found, and no download candidates matched for the filename.")
     return path

View File

@@ -367,7 +367,7 @@ class LoadedModel:
                 self.model_unload()
                 raise e

-        if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
+        if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
             with torch.no_grad():
                 self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
@@ -716,6 +716,8 @@ def maximum_vram_for_weights(device=None) -> int:

 def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
+    if model_params < 0:
+        model_params = 1000000000000000000000
     if args.bf16_unet:
         return torch.bfloat16
     if args.fp16_unet:

View File

@@ -287,17 +287,21 @@ class ModelPatcher(ModelManageable):
         return list(p)

     def get_key_patches(self, filter_prefix=None):
-        model_management.unload_model_clones(self)
         model_sd = self.model_state_dict()
         p = {}
         for k in model_sd:
             if filter_prefix is not None:
                 if not k.startswith(filter_prefix):
                     continue
-            if k in self.patches:
-                p[k] = [model_sd[k]] + self.patches[k]
+            bk = self.backup.get(k, None)
+            if bk is not None:
+                weight = bk.weight
             else:
-                p[k] = (model_sd[k],)
+                weight = model_sd[k]
+            if k in self.patches:
+                p[k] = [weight] + self.patches[k]
+            else:
+                p[k] = (weight,)
         return p

     def model_state_dict(self, filter_prefix=None):

View File

@@ -696,11 +696,11 @@ class VAELoader:
             encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
             decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))

-            enc = utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
+            enc = utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
             for k in enc:
                 sd_["taesd_encoder.{}".format(k)] = enc[k]

-            dec = utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
+            dec = utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
             for k in dec:
                 sd_["taesd_decoder.{}".format(k)] = dec[k]
@@ -769,7 +769,13 @@ class ControlNetLoaderWeights:
     def load_controlnet(self, control_net_name, weight_dtype):
         controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
-        controlnet_ = controlnet.load_controlnet(controlnet_path, weight_dtype=weight_dtype)
+        model_options = {}
+        if weight_dtype == "float8_e5m2":
+            model_options["dtype"] = torch.float8_e5m2
+        elif weight_dtype == "float8_e4m3fn":
+            model_options["dtype"] = torch.float8_e4m3fn
+
+        controlnet_ = controlnet.load_controlnet(controlnet_path, model_options=model_options)
         return (controlnet_,)

 class DiffControlNetLoader:
@@ -800,6 +806,7 @@ class ControlNetApply:
     RETURN_TYPES = ("CONDITIONING",)
     FUNCTION = "apply_controlnet"
+    DEPRECATED = True
     CATEGORY = "conditioning/controlnet"

     def apply_controlnet(self, conditioning, control_net, image: RGBImageBatch, strength):
@@ -829,7 +836,10 @@ class ControlNetApplyAdvanced:
                              "strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              "start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
                              "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
-                             }}
+                             },
+                "optional": {"vae": ("VAE", ),
+                             }
+                }

     RETURN_TYPES = ("CONDITIONING","CONDITIONING")
     RETURN_NAMES = ("positive", "negative")
@@ -837,7 +847,7 @@ class ControlNetApplyAdvanced:
     CATEGORY = "conditioning/controlnet"

-    def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
+    def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
         if strength == 0:
             return (positive, negative)
@@ -854,7 +864,7 @@ class ControlNetApplyAdvanced:
                 if prev_cnet in cnets:
                     c_net = cnets[prev_cnet]
                 else:
-                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
+                    c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
                     c_net.set_previous_controlnet(prev_cnet)
                     cnets[prev_cnet] = c_net
@@ -1932,8 +1942,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
     "ConditioningSetArea": "Conditioning (Set Area)",
     "ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
     "ConditioningSetMask": "Conditioning (Set Mask)",
-    "ControlNetApply": "Apply ControlNet",
-    "ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
+    "ControlNetApply": "Apply ControlNet (OLD)",
+    "ControlNetApplyAdvanced": "Apply ControlNet",
     # Latent
     "VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
     "SetLatentNoiseMask": "Set Latent Noise Mask",

View File

@@ -264,7 +264,6 @@ def fp8_linear(self, input):
         if len(input.shape) == 3:
             inn = input.reshape(-1, input.shape[2]).to(dtype)

-            non_blocking = model_management.device_supports_non_blocking(input.device)
             w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
             w = w.t()
@@ -304,10 +303,10 @@ class fp8_ops(manual_cast):
             return torch.nn.functional.linear(input, weight, bias)

-def pick_operations(weight_dtype, compute_dtype, load_device=None):
+def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False):
     if compute_dtype is None or weight_dtype == compute_dtype:
         return disable_weight_init
-    if args.fast:
+    if args.fast and not disable_fast_fp8:
         if model_management.supports_fp8_compute(load_device):
             return fp8_ops
     return manual_cast

View File

@@ -76,14 +76,14 @@ class CLIP:
             clip = target.clip
             tokenizer = target.tokenizer

-            load_device = model_management.text_encoder_device()
-            offload_device = model_management.text_encoder_offload_device()
+            load_device = model_options.get("load_device", model_management.text_encoder_device())
+            offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
             dtype = model_options.get("dtype", None)
             if dtype is None:
                 dtype = model_management.text_encoder_dtype(load_device)

             params['dtype'] = dtype
-            params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
+            params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
             if "textmodel_json_config" not in params and textmodel_json_config is not None:
                 params['textmodel_json_config'] = textmodel_json_config
@@ -469,12 +469,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
                 clip_target.tokenizer = sa_t5.SAT5Tokenizer
             else:
                 w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
-                if w is not None and w.shape[0] == 248:
-                    clip_target.clip = long_clipl.LongClipModel
-                    clip_target.tokenizer = long_clipl.LongClipTokenizer
-                else:
-                    clip_target.clip = sd1_clip.SD1ClipModel
-                    clip_target.tokenizer = sd1_clip.SD1Tokenizer
+                clip_target.clip = sd1_clip.SD1ClipModel
+                clip_target.tokenizer = sd1_clip.SD1Tokenizer
         elif len(clip_data) == 2:
             if clip_type == CLIPType.SD3:
                 clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
@@ -499,10 +495,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
             clip_target.tokenizer = sd3_clip.SD3Tokenizer

     parameters = 0
+    tokenizer_data = {}
     for c in clip_data:
         parameters += utils.calculate_parameters(c)
+        tokenizer_data, model_options = long_clipl.model_options_long_clip(c, tokenizer_data, model_options)

-    clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters, model_options=model_options)
+    clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
     for c in clip_data:
         m, u = clip.load_sd(c)
         if len(m) > 0:
@ -548,14 +546,22 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae) return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}): def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None):
if te_model_options is None:
te_model_options = {}
if model_options is None:
model_options = {}
sd = utils.load_torch_file(ckpt_path) sd = utils.load_torch_file(ckpt_path)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options) out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, ckpt_path=ckpt_path)
if out is None: if out is None:
raise RuntimeError("Could not detect model type of: {}".format(ckpt_path)) raise RuntimeError("Could not detect model type of: {}".format(ckpt_path))
return out return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, ckpt_path=""): def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, ckpt_path=""):
if te_model_options is None:
te_model_options = {}
if model_options is None:
model_options = {}
clip = None clip = None
clipvision = None clipvision = None
vae = None vae = None
@ -632,7 +638,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (_model_patcher, clip, vae, clipvision) return (_model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options: dict = None): # load unet in diffusers or regular format def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str]=""): # load unet in diffusers or regular format
if model_options is None: if model_options is None:
model_options = {} model_options = {}
dtype = model_options.get("dtype", None) dtype = model_options.get("dtype", None)
@ -677,21 +683,21 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None): # load une
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", None) model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
model = model_config.get_model(new_sd, "") model = model_config.get_model(new_sd, "")
model = model.to(offload_device) model = model.to(offload_device)
model.load_model_weights(new_sd, "") model.load_model_weights(new_sd, "")
left_over = sd.keys() left_over = sd.keys()
if len(left_over) > 0: if len(left_over) > 0:
logging.info("left over keys in unet: {}".format(left_over)) logging.info("left over keys in unet: {}".format(left_over))
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device) return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device, ckpt_name=os.path.basename(ckpt_path))
def load_diffusion_model(unet_path, model_options: dict = None): def load_diffusion_model(unet_path, model_options: dict = None):
if model_options is None: if model_options is None:
model_options = {} model_options = {}
sd = utils.load_torch_file(unet_path) sd = utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options) model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path)
if model is None: if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path)) logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path)) raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))

View File

@@ -657,6 +657,7 @@ class SD1Tokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
        self.clip_name = clip_name
        self.clip = "clip_{}".format(self.clip_name)
+        tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
        self.sd_tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

    def tokenize_with_weights(self, text: str, return_word_ids=False):
@@ -698,6 +699,7 @@ class SD1ClipModel(torch.nn.Module):
        self.clip_name = clip_name
        self.clip = "clip_{}".format(self.clip_name)
+        clip_model = model_options.get("{}_class".format(self.clip), clip_model)
        setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config, **kwargs))
        self.dtypes = set()
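The two one-line additions make the tokenizer and text-encoder classes pluggable: callers can override them per clip name by seeding `tokenizer_data` or `model_options` with a `clip_<name>_tokenizer_class` / `clip_<name>_class` key. A minimal sketch of that dictionary-driven class selection, with illustrative stand-in classes:

    class DefaultTokenizer:
        pass

    class LongTokenizer(DefaultTokenizer):
        pass

    def make_tokenizer(clip_name="l", tokenizer_data=None, tokenizer=DefaultTokenizer):
        tokenizer_data = tokenizer_data or {}
        # same pattern as the diff: keep the passed class unless an override key exists
        tokenizer = tokenizer_data.get("clip_{}_tokenizer_class".format(clip_name), tokenizer)
        return tokenizer()

    assert isinstance(make_tokenizer(), DefaultTokenizer)
    assert isinstance(make_tokenizer(tokenizer_data={"clip_l_tokenizer_class": LongTokenizer}), LongTokenizer)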

View File

@@ -26,8 +26,11 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):

class SDXLTokenizer:
-    def __init__(self, embedding_directory=None, **kwargs):
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+    def __init__(self, embedding_directory=None, tokenizer_data=None, **kwargs):
+        if tokenizer_data is None:
+            tokenizer_data = {}
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
        self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text: str, return_word_ids=False):
@@ -50,9 +53,12 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None, model_options={}):
+    def __init__(self, device="cpu", dtype=None, model_options=None):
        super().__init__()
-        self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
+        if model_options is None:
+            model_options = {}
+        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+        self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
        self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
        self.dtypes = {dtype}
@@ -69,7 +75,8 @@ class SDXLClipModel(torch.nn.Module):
        token_weight_pairs_l = token_weight_pairs["l"]
        g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
-        return torch.cat([l_out, g_out], dim=-1), g_pooled
+        cut_to = min(l_out.shape[1], g_out.shape[1])
+        return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled

    def load_sd(self, sd):
        if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -79,20 +86,29 @@ class SDXLClipModel(torch.nn.Module):
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, model_options={}):
+    def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, model_options=None):
+        if model_options is None:
+            model_options = {}
        super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options, textmodel_json_config=textmodel_json_config)

class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
        super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')

class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)

class StableCascadeClipG(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None, model_options={}):
        textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")

View File

@@ -28,13 +28,15 @@ class FluxTokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data=None):
        if tokenizer_data is None:
            tokenizer_data = dict()
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)

    def tokenize_with_weights(self, text: str, return_word_ids=False):
-        out = {}
-        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
-        out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        out = {
+            "l": self.clip_l.tokenize_with_weights(text, return_word_ids),
+            "t5xxl": self.t5xxl.tokenize_with_weights(text, return_word_ids)
+        }
        return out

    def untokenize(self, token_weight_pair):
@@ -48,12 +50,15 @@ class FluxTokenizer:
class FluxClipModel(torch.nn.Module):
-    def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
+    def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options=None):
        super().__init__()
+        if model_options is None:
+            model_options = {}
        dtype_t5 = model_management.pick_weight_dtype(dtype_t5, dtype, device)
-        self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+        self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
        self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
-        self.dtypes = set([dtype, dtype_t5])
+        self.dtypes = {dtype, dtype_t5}

    def set_clip_options(self, options):
        self.clip_l.set_clip_options(options)
@@ -80,6 +85,9 @@ class FluxClipModel(torch.nn.Module):
def flux_clip(dtype_t5=None):
    class FluxClipModel_(FluxClipModel):
-        def __init__(self, device="cpu", dtype=None, model_options={}):
+        def __init__(self, device="cpu", dtype=None, model_options=None):
+            if model_options is None:
+                model_options = {}
            super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
    return FluxClipModel_

View File

@@ -4,18 +4,41 @@ from ..component_model.files import get_path_as_dict

class LongClipTokenizer_(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
        super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)

class LongClipModel_(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}, textmodel_json_config=None):
+    def __init__(self, *args, **kwargs):
+        kwargs = kwargs or {}
+        textmodel_json_config = kwargs.get("textmodel_json_config", None)
        textmodel_json_config = get_path_as_dict(textmodel_json_config, "long_clipl.json", package=__package__)
-        super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
+        super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)

class LongClipTokenizer(sd1_clip.SD1Tokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)

class LongClipModel(sd1_clip.SD1ClipModel):
-    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
+    def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
+        if model_options is None:
+            model_options = {}
        super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)

+def model_options_long_clip(sd, tokenizer_data, model_options):
+    w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
+    if w is None:
+        w = sd.get("text_model.embeddings.position_embedding.weight", None)
+    if w is not None and w.shape[0] == 248:
+        tokenizer_data = tokenizer_data.copy()
+        model_options = model_options.copy()
+        tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
+        model_options["clip_l_class"] = LongClipModel_
+    return tokenizer_data, model_options

View File

@@ -28,7 +28,8 @@ class SD3Tokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data=None):
        if tokenizer_data is None:
            tokenizer_data = dict()
-        self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
+        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
+        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
        self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
        self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@@ -54,7 +55,8 @@ class SD3ClipModel(torch.nn.Module):
        super().__init__()
        self.dtypes = set()
        if clip_l:
-            self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
+            clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
+            self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
            self.dtypes.add(dtype)
        else:
            self.clip_l = None
@@ -107,7 +109,8 @@ class SD3ClipModel(torch.nn.Module):
        if self.clip_g is not None:
            g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
            if lg_out is not None:
-                lg_out = torch.cat([lg_out, g_out], dim=-1)
+                cut_to = min(lg_out.shape[1], g_out.shape[1])
+                lg_out = torch.cat([lg_out[:, :cut_to], g_out[:, :cut_to]], dim=-1)
            else:
                lg_out = torch.nn.functional.pad(g_out, (768, 0))
        else:
@@ -145,6 +148,9 @@ class SD3ClipModel(torch.nn.Module):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
    class SD3ClipModel_(SD3ClipModel):
-        def __init__(self, device="cpu", dtype=None, model_options={}):
+        def __init__(self, device="cpu", dtype=None, model_options=None):
+            if model_options is None:
+                model_options = {}
            super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
    return SD3ClipModel_

View File

@@ -95,6 +95,7 @@ def calculate_parameters(sd, prefix=""):
            params += w.nelement()
    return params

+
def weight_dtype(sd, prefix=""):
    dtypes = {}
    for k in sd.keys():
@@ -107,6 +108,7 @@ def weight_dtype(sd, prefix=""):
    return max(dtypes, key=dtypes.get)

+
def state_dict_key_replace(state_dict, keys_to_replace):
    for x in keys_to_replace:
        if x in state_dict:
@@ -472,6 +474,7 @@ def auraflow_to_diffusers(mmdit_config, output_prefix=""):
    return key_map

+
def flux_to_diffusers(mmdit_config, output_prefix=""):
    n_double_layers = mmdit_config.get("depth", 0)
    n_single_layers = mmdit_config.get("depth_single_blocks", 0)
@@ -496,27 +499,27 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
            key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))

        block_map = {
            "attn.to_out.0.weight": "img_attn.proj.weight",
            "attn.to_out.0.bias": "img_attn.proj.bias",
            "norm1.linear.weight": "img_mod.lin.weight",
            "norm1.linear.bias": "img_mod.lin.bias",
            "norm1_context.linear.weight": "txt_mod.lin.weight",
            "norm1_context.linear.bias": "txt_mod.lin.bias",
            "attn.to_add_out.weight": "txt_attn.proj.weight",
            "attn.to_add_out.bias": "txt_attn.proj.bias",
            "ff.net.0.proj.weight": "img_mlp.0.weight",
            "ff.net.0.proj.bias": "img_mlp.0.bias",
            "ff.net.2.weight": "img_mlp.2.weight",
            "ff.net.2.bias": "img_mlp.2.bias",
            "ff_context.net.0.proj.weight": "txt_mlp.0.weight",
            "ff_context.net.0.proj.bias": "txt_mlp.0.bias",
            "ff_context.net.2.weight": "txt_mlp.2.weight",
            "ff_context.net.2.bias": "txt_mlp.2.bias",
            "attn.norm_q.weight": "img_attn.norm.query_norm.scale",
            "attn.norm_k.weight": "img_attn.norm.key_norm.scale",
            "attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
            "attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
        }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
@@ -534,13 +537,13 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
            key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))

        block_map = {
            "norm.linear.weight": "modulation.lin.weight",
            "norm.linear.bias": "modulation.lin.bias",
            "proj_out.weight": "linear2.weight",
            "proj_out.bias": "linear2.bias",
            "attn.norm_q.weight": "norm.query_norm.scale",
            "attn.norm_k.weight": "norm.key_norm.scale",
        }

        for k in block_map:
            key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
@@ -763,7 +766,9 @@ def common_upscale(samples, width, height, upscale_method, crop):

def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
-    return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
+    rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
+    cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
+    return rows * cols

@torch.inference_mode()
@@ -773,10 +778,19 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
    for b in range(samples.shape[0]):
        s = samples[b:b + 1]

+        # handle entire input fitting in a single tile
+        if all(s.shape[d + 2] <= tile[d] for d in range(dims)):
+            output[b:b + 1] = function(s).to(output_device)
+            if pbar is not None:
+                pbar.update(1)
+            continue
+
        out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
        out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)

-        for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
+        positions = [range(0, s.shape[d + 2], tile[d] - overlap) if s.shape[d + 2] > tile[d] else [0] for d in range(dims)]
+        for it in itertools.product(*positions):
            s_in = s
            upscaled = []
@@ -785,15 +799,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
                upscaled.append(round(pos * upscale_amount))
+
            ps = function(s_in).to(output_device)
            mask = torch.ones_like(ps)
            feather = round(overlap * upscale_amount)
+
            for t in range(feather):
                for d in range(2, dims + 2):
-                    m = mask.narrow(d, t, 1)
-                    m *= ((1.0 / feather) * (t + 1))
-                    m = mask.narrow(d, mask.shape[d] - 1 - t, 1)
-                    m *= ((1.0 / feather) * (t + 1))
+                    a = (t + 1) / feather
+                    mask.narrow(d, t, 1).mul_(a)
+                    mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)

            o = out
            o_d = out_div
@@ -801,8 +816,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
                o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
                o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
-            o += ps * mask
-            o_d += mask
+            o.add_(ps * mask)
+            o_d.add_(mask)
            if pbar is not None:
                pbar.update(1)
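The tiled-scale changes keep the blending scheme intact: each tile is multiplied by a mask that ramps from `1/feather` up to 1 across the overlap band, accumulated into `out`, while the mask alone accumulates into `out_div`; the final division yields a weighted average wherever tiles overlap, so seams blend instead of hard-switching. A 1-D sketch of that overlap-add normalization (NumPy, illustrative sizes):

    import numpy as np

    signal = np.arange(16, dtype=np.float32)
    tile, overlap = 8, 2

    out = np.zeros_like(signal)
    out_div = np.zeros_like(signal)

    for pos in range(0, len(signal), tile - overlap):
        chunk = signal[pos:pos + tile]
        mask = np.ones_like(chunk)
        feather = overlap
        for t in range(feather):
            a = (t + 1) / feather          # same ramp the diff computes
            mask[t] *= a                   # leading edge of the tile
            mask[len(chunk) - 1 - t] *= a  # trailing edge of the tile
        out[pos:pos + tile] += chunk * mask  # weighted tile contribution
        out_div[pos:pos + tile] += mask      # total weight per sample

    result = out / out_div  # weighted average; reproduces the input exactly here
    assert np.allclose(result, signal)

The new single-tile fast path simply bypasses all of this accumulation when the input already fits within one tile.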

comfy/web/assets/CREDIT.txt generated vendored Normal file
View File

@@ -0,0 +1 @@
+Thanks to OpenArt (https://openart.ai) for providing the sorted-custom-node-map data, captured in September 2024.

comfy/web/assets/GraphView-DN9xGvF3.js generated vendored Normal file

File diff suppressed because one or more lines are too long

comfy/web/assets/GraphView-DN9xGvF3.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

comfy/web/assets/GraphView-DXU9yRen.css generated vendored Normal file
View File

@@ -0,0 +1,158 @@
.group-title-editor.node-title-editor[data-v-fc3f26e3] {
z-index: 9999;
padding: 0.25rem;
}
[data-v-fc3f26e3] .editable-text {
width: 100%;
height: 100%;
}
[data-v-fc3f26e3] .editable-text input {
width: 100%;
height: 100%;
/* Override the default font size */
font-size: inherit;
}
.side-bar-button-icon {
font-size: var(--sidebar-icon-size) !important;
}
.side-bar-button-selected .side-bar-button-icon {
font-size: var(--sidebar-icon-size) !important;
font-weight: bold;
}
.side-bar-button[data-v-caa3ee9c] {
width: var(--sidebar-width);
height: var(--sidebar-width);
border-radius: 0;
}
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
border-left: 4px solid var(--p-button-text-primary-color);
}
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
border-right: 4px solid var(--p-button-text-primary-color);
}
:root {
--sidebar-width: 64px;
--sidebar-icon-size: 1.5rem;
}
:root .small-sidebar {
--sidebar-width: 40px;
--sidebar-icon-size: 1rem;
}
.side-tool-bar-container[data-v-ed7a1148] {
display: flex;
flex-direction: column;
align-items: center;
pointer-events: auto;
width: var(--sidebar-width);
height: 100%;
background-color: var(--comfy-menu-bg);
color: var(--fg-color);
}
.side-tool-bar-end[data-v-ed7a1148] {
align-self: flex-end;
margin-top: auto;
}
.sidebar-content-container[data-v-ed7a1148] {
height: 100%;
overflow-y: auto;
}
.p-splitter-gutter {
pointer-events: auto;
}
.gutter-hidden {
display: none !important;
}
.side-bar-panel[data-v-edca8328] {
background-color: var(--bg-color);
pointer-events: auto;
}
.splitter-overlay[data-v-edca8328] {
width: 100%;
height: 100%;
position: absolute;
top: 0;
left: 0;
background-color: transparent;
pointer-events: none;
/* Set it the same as the ComfyUI menu */
/* Note: Lite-graph DOM widgets have the same z-index as the node id, so
999 should be sufficient to make sure splitter overlays on node's DOM
widgets */
z-index: 999;
border: none;
}
[data-v-37f672ab] .highlight {
background-color: var(--p-primary-color);
color: var(--p-primary-contrast-color);
font-weight: bold;
border-radius: 0.25rem;
padding: 0rem 0.125rem;
margin: -0.125rem 0.125rem;
}
.comfy-vue-node-search-container[data-v-2d409367] {
display: flex;
width: 100%;
min-width: 26rem;
align-items: center;
justify-content: center;
}
.comfy-vue-node-search-container[data-v-2d409367] * {
pointer-events: auto;
}
.comfy-vue-node-preview-container[data-v-2d409367] {
position: absolute;
left: -350px;
top: 50px;
}
.comfy-vue-node-search-box[data-v-2d409367] {
z-index: 10;
flex-grow: 1;
}
._filter-button[data-v-2d409367] {
z-index: 10;
}
._dialog[data-v-2d409367] {
min-width: 26rem;
}
.invisible-dialog-root {
width: 30%;
min-width: 24rem;
max-width: 48rem;
border: 0 !important;
background-color: transparent !important;
margin-top: 25vh;
}
.node-search-box-dialog-mask {
align-items: flex-start !important;
}
.node-tooltip[data-v-e0597bf9] {
background: var(--comfy-input-bg);
border-radius: 5px;
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
color: var(--input-text);
font-family: sans-serif;
left: 0;
max-width: 30vw;
padding: 4px 8px;
position: absolute;
top: 0;
transform: translate(5px, calc(-100% - 5px));
white-space: pre-wrap;
z-index: 99999;
}

File diff suppressed because it is too large

View File

@@ -1,6 +1,6 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
-import { C as ComfyDialog, $ as $el, a as ComfyApp, b as app, L as LGraphCanvas, c as LiteGraph, d as LGraphNode, e as applyTextReplacements, f as ComfyWidgets, g as addValueControlWidgets, D as DraggableList, h as api, i as LGraphGroup, u as useToastStore } from "./index-Dfv2aLsq.js";
+import { aM as ComfyDialog, aN as $el, aO as ComfyApp, c as app, aH as LGraphCanvas, k as LiteGraph, e as LGraphNode, aP as applyTextReplacements, aQ as ComfyWidgets, aR as addValueControlWidgets, aS as DraggableList, av as useNodeDefStore, aT as api, L as LGraphGroup, aU as useToastStore, at as NodeSourceType, aV as NodeBadgeMode, u as useSettingStore, q as computed, w as watch, aW as BadgePosition, aJ as LGraphBadge$1, aX as _ } from "./index-Drc_oD2f.js";
class ClipspaceDialog extends ComfyDialog {
  static {
    __name(this, "ClipspaceDialog");
@@ -213,7 +213,9 @@ const colorPalettes = {
      WIDGET_SECONDARY_TEXT_COLOR: "#999",
      LINK_COLOR: "#9A9",
      EVENT_LINK_COLOR: "#A86",
-      CONNECTING_LINK_COLOR: "#AFA"
+      CONNECTING_LINK_COLOR: "#AFA",
+      BADGE_FG_COLOR: "#FFF",
+      BADGE_BG_COLOR: "#0F1F0F"
    },
    comfy_base: {
      "fg-color": "#fff",
@@ -283,7 +285,9 @@ const colorPalettes = {
      WIDGET_SECONDARY_TEXT_COLOR: "#555",
      LINK_COLOR: "#4CAF50",
      EVENT_LINK_COLOR: "#FF9800",
-      CONNECTING_LINK_COLOR: "#2196F3"
+      CONNECTING_LINK_COLOR: "#2196F3",
+      BADGE_FG_COLOR: "#000",
+      BADGE_BG_COLOR: "#FFF"
    },
    comfy_base: {
      "fg-color": "#222",
@@ -621,6 +625,32 @@ const defaultColorPaletteId = "dark";
const els = {
  select: null
};
const getCustomColorPalettes = /* @__PURE__ */ __name(() => {
return app.ui.settings.getSettingValue(idCustomColorPalettes, {});
}, "getCustomColorPalettes");
const setCustomColorPalettes = /* @__PURE__ */ __name((customColorPalettes) => {
return app.ui.settings.setSettingValue(
idCustomColorPalettes,
customColorPalettes
);
}, "setCustomColorPalettes");
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
const getColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
if (!colorPaletteId) {
colorPaletteId = app.ui.settings.getSettingValue(id$4, defaultColorPaletteId);
}
if (colorPaletteId.startsWith("custom_")) {
colorPaletteId = colorPaletteId.substr(7);
let customColorPalettes = getCustomColorPalettes();
if (customColorPalettes[colorPaletteId]) {
return customColorPalettes[colorPaletteId];
}
}
return colorPalettes[colorPaletteId];
}, "getColorPalette");
const setColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
app.ui.settings.setSettingValue(id$4, colorPaletteId);
}, "setColorPalette");
app.registerExtension({
  name: id$4,
  init() {
@@ -695,28 +725,19 @@ app.registerExtension({
          comfy_base: {}
        }
      };
-      const defaultColorPalette = colorPalettes[defaultColorPaletteId];
-      for (const key in defaultColorPalette.colors.litegraph_base) {
+      const defaultColorPalette2 = colorPalettes[defaultColorPaletteId];
+      for (const key in defaultColorPalette2.colors.litegraph_base) {
        if (!colorPalette.colors.litegraph_base[key]) {
          colorPalette.colors.litegraph_base[key] = "";
        }
      }
-      for (const key in defaultColorPalette.colors.comfy_base) {
+      for (const key in defaultColorPalette2.colors.comfy_base) {
        if (!colorPalette.colors.comfy_base[key]) {
          colorPalette.colors.comfy_base[key] = "";
        }
      }
      return completeColorPalette(colorPalette);
    }, "getColorPaletteTemplate");
-    const getCustomColorPalettes = /* @__PURE__ */ __name(() => {
-      return app.ui.settings.getSettingValue(idCustomColorPalettes, {});
-    }, "getCustomColorPalettes");
-    const setCustomColorPalettes = /* @__PURE__ */ __name((customColorPalettes) => {
-      return app.ui.settings.setSettingValue(
-        idCustomColorPalettes,
-        customColorPalettes
-      );
-    }, "setCustomColorPalettes");
    const addCustomColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
      if (typeof colorPalette !== "object") {
        alert("Invalid color palette.");
@@ -807,25 +828,6 @@ app.registerExtension({
        app.canvas.draw(true, true);
      }
    }, "loadColorPalette");
-    const getColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
-      if (!colorPaletteId) {
-        colorPaletteId = app.ui.settings.getSettingValue(
-          id$4,
-          defaultColorPaletteId
-        );
-      }
-      if (colorPaletteId.startsWith("custom_")) {
-        colorPaletteId = colorPaletteId.substr(7);
-        let customColorPalettes = getCustomColorPalettes();
-        if (customColorPalettes[colorPaletteId]) {
-          return customColorPalettes[colorPaletteId];
-        }
-      }
-      return colorPalettes[colorPaletteId];
-    }, "getColorPalette");
-    const setColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
-      app.ui.settings.setSettingValue(id$4, colorPaletteId);
-    }, "setColorPalette");
    const fileInput = $el("input", {
      type: "file",
      accept: ".json",
@@ -994,6 +996,10 @@ app.registerExtension({
    });
  }
});
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.colorPalette = window.comfyAPI.colorPalette || {};
window.comfyAPI.colorPalette.defaultColorPalette = defaultColorPalette;
window.comfyAPI.colorPalette.getColorPalette = getColorPalette;
const ext$2 = {
  name: "Comfy.ContextMenuFilter",
  init() {
@@ -1360,7 +1366,7 @@ class PrimitiveNode extends LGraphNode {
      this.#mergeWidgetConfig();
    }
  }
-  onConnectionsChange(_, index, connected) {
+  onConnectionsChange(_2, index, connected) {
    if (app.configuringGraph) {
      return;
    }
@@ -1806,7 +1812,7 @@ app.registerExtension({
      convertToInput(this, widget, config);
      return true;
    };
-    nodeType.prototype.getExtraMenuOptions = function(_, options) {
+    nodeType.prototype.getExtraMenuOptions = function(_2, options) {
      const r = origGetExtraMenuOptions ? origGetExtraMenuOptions.apply(this, arguments) : void 0;
      if (this.widgets) {
        let toInput = [];
@@ -1862,6 +1868,7 @@ app.registerExtension({
    };
    nodeType.prototype.onGraphConfigured = function() {
      if (!this.inputs) return;
+      this.widgets ??= [];
      for (const input of this.inputs) {
        if (input.widget) {
          if (!input.widget[GET_CONFIG]) {
@@ -1919,7 +1926,7 @@ app.registerExtension({
      return r;
    };
    function isNodeAtPos(pos) {
-      for (const n of app2.graph._nodes) {
+      for (const n of app2.graph.nodes) {
        if (n.pos[0] === pos[0] && n.pos[1] === pos[1]) {
          return true;
        }
@@ -2308,7 +2315,7 @@ class ManageGroupDialog extends ComfyDialog {
        "button.comfy-btn",
        {
          onclick: /* @__PURE__ */ __name((e) => {
-            const node = app.graph._nodes.find(
+            const node = app.graph.nodes.find(
              (n) => n.type === "workflow/" + this.selectedGroup
            );
            if (node) {
@@ -2374,7 +2381,7 @@ class ManageGroupDialog extends ComfyDialog {
      }
      types[g] = type2;
      if (!nodesByType) {
-        nodesByType = app.graph._nodes.reduce((p, n) => {
+        nodesByType = app.graph.nodes.reduce((p, n) => {
          p[n.type] ??= [];
          p[n.type].push(n);
          return p;
@@ -2424,7 +2431,7 @@ const Workflow = {
  isInUseGroupNode(name) {
    const id2 = `workflow/${name}`;
    if (app.graph.extra?.groupNodes?.[name]) {
-      if (app.graph._nodes.find((n) => n.type === id2)) {
+      if (app.graph.nodes.find((n) => n.type === id2)) {
        return Workflow.InUse.InWorkflow;
      } else {
        return Workflow.InUse.Registered;
@@ -2576,6 +2583,8 @@ class GroupNodeConfig {
      display_name: this.name,
      category: "group nodes" + ("/" + source),
      input: { required: {} },
+      description: `Group node combining ${this.nodeData.nodes.map((n) => n.type).join(", ")}`,
+      python_module: "custom_nodes." + this.name,
      [GROUP]: this
    };
    this.inputs = [];
@@ -2591,6 +2600,7 @@ class GroupNodeConfig {
    }
    this.#convertedToProcess = null;
    await app.registerNodeDef("workflow/" + this.name, this.nodeDef);
+    useNodeDefStore().addNodeDef(this.nodeDef);
  }
  getLinks() {
    this.linksFrom = {};
@@ -2775,7 +2785,7 @@ class GroupNodeConfig {
  checkPrimitiveConnection(link, inputName, inputs) {
    const sourceNode = this.nodeData.nodes[link[0]];
    if (sourceNode.type === "PrimitiveNode") {
-      const [sourceNodeId, _, targetNodeId, __] = link;
+      const [sourceNodeId, _2, targetNodeId, __] = link;
      const primitiveDef = this.primitiveDefs[sourceNodeId];
      const targetWidget = inputs[inputName];
      const primitiveConfig = primitiveDef.input.required.value;
@@ -3177,7 +3187,7 @@ class GroupNodeHandler {
      return newNodes;
    };
    const getExtraMenuOptions = this.node.getExtraMenuOptions;
-    this.node.getExtraMenuOptions = function(_, options) {
+    this.node.getExtraMenuOptions = function(_2, options) {
      getExtraMenuOptions?.apply(this, arguments);
      let optionIndex = options.findIndex((o) => o.content === "Outputs");
      if (optionIndex === -1) optionIndex = options.length;
@@ -3353,7 +3363,7 @@ class GroupNodeHandler {
      } else if (innerNode.type === "Reroute") {
        const rerouteLinks = this.groupData.linksFrom[old.node.index];
        if (rerouteLinks) {
-          for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
+          for (const [_2, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
            const node = this.innerNodes[targetNodeId];
            const input = node.inputs[targetSlot];
            if (input.widget) {
@@ -3599,7 +3609,7 @@ function addNodesToGroup(group, nodes = []) {
  var node;
  x1 = y1 = x2 = y2 = -1;
  nx1 = ny1 = nx2 = ny2 = -1;
-  for (var n of [group._nodes, nodes]) {
+  for (var n of [group.nodes, nodes]) {
    for (var i in n) {
      node = n[i];
      nx1 = node.pos[0];
@@ -3659,7 +3669,7 @@ app.registerExtension({
      return options;
    }
    group.recomputeInsideNodes();
-    const nodesInGroup = group._nodes;
+    const nodesInGroup = group.nodes;
    options.push({
      content: "Add Selected Nodes To Group",
      disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
@@ -4002,6 +4012,16 @@ function prepare_mask(image, maskCanvas, maskCtx, maskColor) {
  maskCtx.putImageData(maskData, 0, 0);
}
__name(prepare_mask, "prepare_mask");
var PointerType = /* @__PURE__ */ ((PointerType2) => {
PointerType2["Arc"] = "arc";
PointerType2["Rect"] = "rect";
return PointerType2;
})(PointerType || {});
var CompositionOperation = /* @__PURE__ */ ((CompositionOperation2) => {
CompositionOperation2["SourceOver"] = "source-over";
CompositionOperation2["DestinationOut"] = "destination-out";
return CompositionOperation2;
})(CompositionOperation || {});
class MaskEditorDialog extends ComfyDialog {
  static {
    __name(this, "MaskEditorDialog");
@@ -4030,6 +4050,8 @@ class MaskEditorDialog extends ComfyDialog {
  mousedown_pan_x;
  mousedown_pan_y;
  last_pressure;
+  pointer_type;
+  brush_pointer_type_select;
  static getInstance() {
    if (!MaskEditorDialog.instance) {
      MaskEditorDialog.instance = new MaskEditorDialog();
@@ -4077,7 +4099,7 @@ class MaskEditorDialog extends ComfyDialog {
    divElement.style.borderColor = "var(--border-color)";
    divElement.style.borderStyle = "solid";
    divElement.style.fontSize = "15px";
-    divElement.style.height = "21px";
+    divElement.style.height = "25px";
    divElement.style.padding = "1px 6px";
    divElement.style.display = "flex";
    divElement.style.position = "relative";
@@ -4107,7 +4129,7 @@ class MaskEditorDialog extends ComfyDialog {
    divElement.style.borderColor = "var(--border-color)";
    divElement.style.borderStyle = "solid";
    divElement.style.fontSize = "15px";
-    divElement.style.height = "21px";
+    divElement.style.height = "25px";
    divElement.style.padding = "1px 6px";
    divElement.style.display = "flex";
    divElement.style.position = "relative";
@@ -4126,8 +4148,63 @@ class MaskEditorDialog extends ComfyDialog {
    self.opacity_slider_input.addEventListener("input", callback);
    return divElement;
  }
createPointerTypeSelect(self) {
const divElement = document.createElement("div");
divElement.id = "maskeditor-pointer-type";
divElement.style.cssFloat = "left";
divElement.style.fontFamily = "sans-serif";
divElement.style.marginRight = "4px";
divElement.style.color = "var(--input-text)";
divElement.style.backgroundColor = "var(--comfy-input-bg)";
divElement.style.borderRadius = "8px";
divElement.style.borderColor = "var(--border-color)";
divElement.style.borderStyle = "solid";
divElement.style.fontSize = "15px";
divElement.style.height = "25px";
divElement.style.padding = "1px 6px";
divElement.style.display = "flex";
divElement.style.position = "relative";
divElement.style.top = "2px";
divElement.style.pointerEvents = "auto";
const labelElement = document.createElement("label");
labelElement.textContent = "Pointer Type:";
const selectElement = document.createElement("select");
selectElement.style.borderRadius = "0";
selectElement.style.borderColor = "transparent";
selectElement.style.borderStyle = "unset";
selectElement.style.fontSize = "0.9em";
const optionArc = document.createElement("option");
optionArc.value = "arc";
optionArc.text = "Circle";
optionArc.selected = true;
const optionRect = document.createElement("option");
optionRect.value = "rect";
optionRect.text = "Square";
selectElement.appendChild(optionArc);
selectElement.appendChild(optionRect);
selectElement.addEventListener("change", (event) => {
const target = event.target;
self.pointer_type = target.value;
this.setBrushBorderRadius(self);
});
divElement.appendChild(labelElement);
divElement.appendChild(selectElement);
return divElement;
}
setBrushBorderRadius(self) {
if (self.pointer_type === "rect") {
this.brush.style.borderRadius = "0%";
this.brush.style.MozBorderRadius = "0%";
this.brush.style.WebkitBorderRadius = "0%";
} else {
this.brush.style.borderRadius = "50%";
this.brush.style.MozBorderRadius = "50%";
this.brush.style.WebkitBorderRadius = "50%";
}
}
  setlayout(imgCanvas, maskCanvas) {
    const self = this;
+    self.pointer_type = "arc";
    var bottom_panel = document.createElement("div");
    bottom_panel.style.position = "absolute";
    bottom_panel.style.bottom = "0px";
@@ -4140,13 +4217,11 @@ class MaskEditorDialog extends ComfyDialog {
    brush.style.backgroundColor = "transparent";
    brush.style.outline = "1px dashed black";
    brush.style.boxShadow = "0 0 0 1px white";
-    brush.style.borderRadius = "50%";
-    brush.style.MozBorderRadius = "50%";
-    brush.style.WebkitBorderRadius = "50%";
    brush.style.position = "absolute";
    brush.style.zIndex = "8889";
    brush.style.pointerEvents = "none";
    this.brush = brush;
+    this.setBrushBorderRadius(self);
    this.element.appendChild(imgCanvas);
    this.element.appendChild(maskCanvas);
    this.element.appendChild(bottom_panel);
@@ -4177,6 +4252,7 @@ class MaskEditorDialog extends ComfyDialog {
        }
      }
    );
+    this.brush_pointer_type_select = this.createPointerTypeSelect(self);
    this.colorButton = this.createLeftButton(this.getColorButtonText(), () => {
      if (self.brush_color_mode === "black") {
        self.brush_color_mode = "white";
@@ -4203,6 +4279,7 @@ class MaskEditorDialog extends ComfyDialog {
    bottom_panel.appendChild(cancelButton);
    bottom_panel.appendChild(this.brush_size_slider);
    bottom_panel.appendChild(this.brush_opacity_slider);
+    bottom_panel.appendChild(this.brush_pointer_type_select);
    bottom_panel.appendChild(this.colorButton);
    imgCanvas.style.position = "absolute";
    maskCanvas.style.position = "absolute";
@@ -4568,19 +4645,22 @@ class MaskEditorDialog extends ComfyDialog {
    }
    if (diff > 20 && !this.drawing_mode)
      requestAnimationFrame(() => {
-        self.maskCtx.beginPath();
-        self.maskCtx.fillStyle = this.getMaskFillStyle();
-        self.maskCtx.globalCompositeOperation = "source-over";
-        self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
-        self.maskCtx.fill();
+        self.init_shape(
+          self,
+          "source-over"
+          /* SourceOver */
+        );
+        self.draw_shape(self, x, y, brush_size);
        self.lastx = x;
        self.lasty = y;
      });
    else
      requestAnimationFrame(() => {
-        self.maskCtx.beginPath();
-        self.maskCtx.fillStyle = this.getMaskFillStyle();
-        self.maskCtx.globalCompositeOperation = "source-over";
+        self.init_shape(
+          self,
+          "source-over"
+          /* SourceOver */
+        );
        var dx = x - self.lastx;
        var dy = y - self.lasty;
        var distance = Math.sqrt(dx * dx + dy * dy);
@@ -4589,8 +4669,7 @@ class MaskEditorDialog extends ComfyDialog {
        for (var i = 0; i < distance; i += 5) {
          var px = self.lastx + directionX * i;
          var py = self.lasty + directionY * i;
-          self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
-          self.maskCtx.fill();
+          self.draw_shape(self, px, py, brush_size);
        }
        self.lastx = x;
        self.lasty = y;
@@ -4611,17 +4690,22 @@ class MaskEditorDialog extends ComfyDialog {
    }
    if (diff > 20 && !this.drawing_mode)
      requestAnimationFrame(() => {
-        self.maskCtx.beginPath();
-        self.maskCtx.globalCompositeOperation = "destination-out";
-        self.maskCtx.arc(x2, y2, brush_size, 0, Math.PI * 2, false);
-        self.maskCtx.fill();
+        self.init_shape(
+          self,
+          "destination-out"
+          /* DestinationOut */
+        );
+        self.draw_shape(self, x2, y2, brush_size);
        self.lastx = x2;
        self.lasty = y2;
      });
    else
      requestAnimationFrame(() => {
-        self.maskCtx.beginPath();
-        self.maskCtx.globalCompositeOperation = "destination-out";
+        self.init_shape(
+          self,
+          "destination-out"
+          /* DestinationOut */
+        );
        var dx = x2 - self.lastx;
        var dy = y2 - self.lasty;
        var distance = Math.sqrt(dx * dx + dy * dy);
@@ -4630,8 +4714,7 @@ class MaskEditorDialog extends ComfyDialog {
        for (var i = 0; i < distance; i += 5) {
          var px = self.lastx + directionX * i;
          var py = self.lasty + directionY * i;
-          self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
-          self.maskCtx.fill();
+          self.draw_shape(self, px, py, brush_size);
        }
        self.lastx = x2;
        self.lasty = y2;
@@ -4665,20 +4748,47 @@ class MaskEditorDialog extends ComfyDialog {
      const maskRect = self.maskCanvas.getBoundingClientRect();
      const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
      const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
-      self.maskCtx.beginPath();
      if (!event.altKey && event.button == 0) {
-        self.maskCtx.fillStyle = this.getMaskFillStyle();
-        self.maskCtx.globalCompositeOperation = "source-over";
+        self.init_shape(
+          self,
+          "source-over"
+          /* SourceOver */
+        );
      } else {
-        self.maskCtx.globalCompositeOperation = "destination-out";
+        self.init_shape(
+          self,
+          "destination-out"
+          /* DestinationOut */
+        );
      }
-      self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
-      self.maskCtx.fill();
+      self.draw_shape(self, x, y, brush_size);
      self.lastx = x;
      self.lasty = y;
      self.lasttime = performance.now();
    }
  }
init_shape(self, compositionOperation) {
self.maskCtx.beginPath();
if (compositionOperation == "source-over") {
self.maskCtx.fillStyle = this.getMaskFillStyle();
self.maskCtx.globalCompositeOperation = "source-over";
} else if (compositionOperation == "destination-out") {
self.maskCtx.globalCompositeOperation = "destination-out";
}
}
draw_shape(self, x, y, brush_size) {
if (self.pointer_type === "rect") {
self.maskCtx.rect(
x - brush_size,
y - brush_size,
brush_size * 2,
brush_size * 2
);
} else {
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
}
self.maskCtx.fill();
}
async save() { async save() {
const backupCanvas = document.createElement("canvas"); const backupCanvas = document.createElement("canvas");
const backupCtx = backupCanvas.getContext("2d", { const backupCtx = backupCanvas.getContext("2d", {
@ -5264,7 +5374,7 @@ app.registerExtension({
updateNodes.push(node); updateNodes.push(node);
} else { } else {
const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null; const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
if (inputType && inputType !== "*" && nodeOutType !== inputType) { if (inputType && !LiteGraph.isValidConnection(inputType, nodeOutType)) {
node.disconnectInput(link.target_slot); node.disconnectInput(link.target_slot);
} else { } else {
outputType = nodeOutType; outputType = nodeOutType;
@ -5300,6 +5410,7 @@ app.registerExtension({
} }
if (!targetWidget) { if (!targetWidget) {
targetWidget = targetNode.widgets?.find( targetWidget = targetNode.widgets?.find(
// @ts-expect-error fix widget types
(w) => w.name === targetInput.widget.name (w) => w.name === targetInput.widget.name
); );
} }
@ -5342,7 +5453,7 @@ app.registerExtension({
}; };
this.isVirtualNode = true; this.isVirtualNode = true;
} }
getExtraMenuOptions(_, options) { getExtraMenuOptions(_2, options) {
options.unshift( options.unshift(
{ {
content: (this.properties.showOutputText ? "Hide" : "Show") + " Type", content: (this.properties.showOutputText ? "Hide" : "Show") + " Type",
@ -5564,8 +5675,7 @@ app.registerExtension({
slot_types_default_in: {}, slot_types_default_in: {},
async beforeRegisterNodeDef(nodeType, nodeData, app2) { async beforeRegisterNodeDef(nodeType, nodeData, app2) {
var nodeId = nodeData.name; var nodeId = nodeData.name;
var inputs = []; const inputs = nodeData["input"]["required"];
inputs = nodeData["input"]["required"];
for (const inputKey in inputs) { for (const inputKey in inputs) {
var input = inputs[inputKey]; var input = inputs[inputKey];
if (typeof input[0] !== "string") continue; if (typeof input[0] !== "string") continue;
@@ -5637,7 +5747,7 @@ app.registerExtension({
        tooltip: "When dragging and resizing nodes while holding shift they will be aligned to the grid, this controls the size of that grid.",
        defaultValue: LiteGraph.CANVAS_GRID_SIZE,
        onChange(value) {
-         LiteGraph.CANVAS_GRID_SIZE = +value;
+         LiteGraph.CANVAS_GRID_SIZE = +value || 10;
        }
      });
      const onNodeMoved = app.canvas.onNodeMoved;
@@ -5697,7 +5807,7 @@ app.registerExtension({
      }
      if (app.canvas.last_mouse_dragging === false && app.shiftDown) {
        this.recomputeInsideNodes();
-       for (const node of this._nodes) {
+       for (const node of this.nodes) {
          node.alignToGrid();
        }
        LGraphNode.prototype.alignToGrid.apply(this);
@@ -5730,7 +5840,7 @@ app.registerExtension({
    LGraphCanvas.onGroupAdd = function() {
      const v = onGroupAdd.apply(app.canvas, arguments);
      if (app.shiftDown) {
-       const lastGroup = app.graph._groups[app.graph._groups.length - 1];
+       const lastGroup = app.graph.groups[app.graph.groups.length - 1];
        if (lastGroup) {
          roundVectorToGrid(lastGroup.pos);
          roundVectorToGrid(lastGroup.size);
@@ -6026,4 +6136,108 @@ app.registerExtension({
    };
  }
});
-//# sourceMappingURL=index-CrROdkG4.js.map
function getNodeSource(node) {
const nodeDef = node.constructor.nodeData;
if (!nodeDef) {
return null;
}
const nodeDefStore = useNodeDefStore();
return nodeDefStore.nodeDefsByName[nodeDef.name]?.nodeSource ?? null;
}
__name(getNodeSource, "getNodeSource");
function isCoreNode(node) {
return getNodeSource(node)?.type === NodeSourceType.Core;
}
__name(isCoreNode, "isCoreNode");
function badgeTextVisible(node, badgeMode) {
return badgeMode === NodeBadgeMode.None || isCoreNode(node) && badgeMode === NodeBadgeMode.HideBuiltIn;
}
__name(badgeTextVisible, "badgeTextVisible");
function getNodeIdBadgeText(node, nodeIdBadgeMode) {
return badgeTextVisible(node, nodeIdBadgeMode) ? "" : `#${node.id}`;
}
__name(getNodeIdBadgeText, "getNodeIdBadgeText");
function getNodeSourceBadgeText(node, nodeSourceBadgeMode) {
const nodeSource = getNodeSource(node);
return badgeTextVisible(node, nodeSourceBadgeMode) ? "" : nodeSource?.badgeText ?? "";
}
__name(getNodeSourceBadgeText, "getNodeSourceBadgeText");
function getNodeLifeCycleBadgeText(node, nodeLifeCycleBadgeMode) {
let text = "";
const nodeDef = node.constructor.nodeData;
if (!nodeDef) {
return "";
}
if (nodeDef.deprecated) {
text = "[DEPR]";
}
if (nodeDef.experimental) {
text = "[BETA]";
}
return badgeTextVisible(node, nodeLifeCycleBadgeMode) ? "" : text;
}
__name(getNodeLifeCycleBadgeText, "getNodeLifeCycleBadgeText");
class NodeBadgeExtension {
static {
__name(this, "NodeBadgeExtension");
}
constructor(nodeIdBadgeMode = null, nodeSourceBadgeMode = null, nodeLifeCycleBadgeMode = null, colorPalette = null) {
this.nodeIdBadgeMode = nodeIdBadgeMode;
this.nodeSourceBadgeMode = nodeSourceBadgeMode;
this.nodeLifeCycleBadgeMode = nodeLifeCycleBadgeMode;
this.colorPalette = colorPalette;
}
name = "Comfy.NodeBadge";
init(app2) {
const settingStore = useSettingStore();
this.nodeSourceBadgeMode = computed(
() => settingStore.get("Comfy.NodeBadge.NodeSourceBadgeMode")
);
this.nodeIdBadgeMode = computed(
() => settingStore.get("Comfy.NodeBadge.NodeIdBadgeMode")
);
this.nodeLifeCycleBadgeMode = computed(
() => settingStore.get(
"Comfy.NodeBadge.NodeLifeCycleBadgeMode"
)
);
this.colorPalette = computed(
() => getColorPalette(settingStore.get("Comfy.ColorPalette"))
);
watch(this.nodeSourceBadgeMode, () => {
app2.graph.setDirtyCanvas(true, true);
});
watch(this.nodeIdBadgeMode, () => {
app2.graph.setDirtyCanvas(true, true);
});
watch(this.nodeLifeCycleBadgeMode, () => {
app2.graph.setDirtyCanvas(true, true);
});
}
nodeCreated(node, app2) {
node.badgePosition = BadgePosition.TopRight;
node.badge_enabled = true;
const badge = computed(
() => new LGraphBadge$1({
text: _.truncate(
[
getNodeIdBadgeText(node, this.nodeIdBadgeMode.value),
getNodeLifeCycleBadgeText(
node,
this.nodeLifeCycleBadgeMode.value
),
getNodeSourceBadgeText(node, this.nodeSourceBadgeMode.value)
].filter((s) => s.length > 0).join(" "),
{
length: 31
}
),
fgColor: this.colorPalette.value.colors.litegraph_base?.BADGE_FG_COLOR || defaultColorPalette.colors.litegraph_base.BADGE_FG_COLOR,
bgColor: this.colorPalette.value.colors.litegraph_base?.BADGE_BG_COLOR || defaultColorPalette.colors.litegraph_base.BADGE_BG_COLOR
})
);
node.badges.push(() => badge.value);
}
}
app.registerExtension(new NodeBadgeExtension());
//# sourceMappingURL=index-BDBCRrlL.js.map

comfy/web/assets/index-BDBCRrlL.js.map (generated, vendored, new file)

File diff suppressed because one or more lines are too long


comfy/web/assets/index-Drc_oD2f.js.map (generated, vendored, new file)

File diff suppressed because one or more lines are too long

comfy/web/assets/sorted-custom-node-map.json (generated, vendored, new file)

File diff suppressed because it is too large


@@ -1,6 +1,6 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
-import { j as createSpinner, h as api, $ as $el } from "./index-Dfv2aLsq.js";
+import { aY as createSpinner, aT as api, aN as $el } from "./index-Drc_oD2f.js";
class UserSelectionScreen {
  static {
    __name(this, "UserSelectionScreen");
@@ -117,4 +117,4 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
export {
  UserSelectionScreen
};
-//# sourceMappingURL=userSelection-DSpF-zVD.js.map
+//# sourceMappingURL=userSelection-BM5u5JIA.js.map



@@ -0,0 +1,3 @@
// Shim for extensions/core/colorPalette.ts
export const defaultColorPalette = window.comfyAPI.colorPalette.defaultColorPalette;
export const getColorPalette = window.comfyAPI.colorPalette.getColorPalette;

comfy/web/index.html (vendored)

@@ -1,50 +1,50 @@
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <title>ComfyUI</title>
  <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
  <!-- Browser Test Fonts -->
  <!-- <link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
  <link href="https://fonts.googleapis.com/css2?family=Noto+Color+Emoji&family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
  <style>
    * {
      font-family: 'Roboto Mono', 'Noto Color Emoji';
    }
  </style> -->
  <link rel="stylesheet" type="text/css" href="user.css" />
  <link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
- <script type="module" crossorigin src="./assets/index-Dfv2aLsq.js"></script>
- <link rel="stylesheet" crossorigin href="./assets/index-W4jP-SrU.css">
+ <script type="module" crossorigin src="./assets/index-Drc_oD2f.js"></script>
+ <link rel="stylesheet" crossorigin href="./assets/index-8NH3XvqK.css">
</head>
<body class="litegraph">
  <div id="vue-app"></div>
  <div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
    <main class="comfy-user-selection-inner">
      <h1>ComfyUI</h1>
      <form>
        <section>
          <label>New user:
            <input placeholder="Enter a username" />
          </label>
        </section>
        <div class="comfy-user-existing">
          <span class="or-separator">OR</span>
          <section>
            <label>
              Existing user:
              <select>
                <option hidden disabled selected value> Select a user </option>
              </select>
            </label>
          </section>
        </div>
        <footer>
          <span class="comfy-user-error">&nbsp;</span>
          <button class="comfy-btn comfy-user-button-next">Next</button>
        </footer>
      </form>
    </main>
  </div>
</body>
</html>


@@ -1,4 +1,3 @@
// Shim for scripts/workflows.ts
-export const trimJsonExt = window.comfyAPI.workflows.trimJsonExt;
export const ComfyWorkflowManager = window.comfyAPI.workflows.ComfyWorkflowManager;
export const ComfyWorkflow = window.comfyAPI.workflows.ComfyWorkflow;


@@ -73,6 +73,9 @@ class VAEDecodeAudio:
    def decode(self, vae, samples):
        audio = vae.decode(samples["samples"]).movedim(-1, 1)
+       std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
+       std[std < 1.0] = 1.0
+       audio /= std
        return ({"waveform": audio, "sample_rate": 44100},)


@@ -1,9 +1,12 @@
+import comfy.utils
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
+from comfy.nodes.base_nodes import ControlNetApplyAdvanced

class SetUnionControlNetType:
    @classmethod
    def INPUT_TYPES(s):
-       return {"required": {"control_net": ("CONTROL_NET", ),
+       return {"required": {"control_net": ("CONTROL_NET",),
                             "type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
                             }}
@@ -22,6 +25,37 @@ class SetUnionControlNetType:
        return (control_net,)
class ControlNetInpaintingAliMamaApply(ControlNetApplyAdvanced):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"control_net": ("CONTROL_NET",),
"vae": ("VAE",),
"image": ("IMAGE",),
"mask": ("MASK",),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
FUNCTION = "apply_inpaint_controlnet"
CATEGORY = "conditioning/controlnet"
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
extra_concat = []
if control_net.concat_mask:
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
extra_concat = [mask]
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
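The mask handling in apply_inpaint_controlnet is easy to misread, so here is a toy sketch of just that preprocessing, with made-up tensors: the mask is inverted (so 1 means keep), reshaped to NCHW, resized to the image, and the image is zeroed where inpainting should happen. comfy.utils.common_upscale is the same helper the node calls; everything else is illustration.

import torch
import comfy.utils

image = torch.rand(1, 64, 64, 3)  # IMAGE tensors are NHWC
mask = torch.zeros(64, 64)
mask[16:48, 16:48] = 1.0          # 1 marks the region to inpaint

inverted = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))  # NCHW, 1 = keep
mask_apply = comfy.utils.common_upscale(inverted, image.shape[2], image.shape[1], "bilinear", "center").round()
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])  # blank the hole
# `inverted` then rides along as extra_concat for the ControlNet's concat input.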
NODE_CLASS_MAPPINGS = {
    "SetUnionControlNetType": SetUnionControlNetType,
+   "ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
}


@@ -94,6 +94,27 @@ class PolyexponentialScheduler:
        sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
        return (sigmas, )
class LaplaceScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
return (sigmas, )
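For intuition about what the new scheduler asks k_diffusion for: a Laplace schedule draws log-sigmas from the Laplace quantile function, so noise levels cluster around exp(mu) with spread set by beta. A rough standalone sketch of that idea (my reading of the math, not the k_diffusion source):

import torch

def laplace_sigmas_sketch(n, sigma_min, sigma_max, mu=0.0, beta=0.5):
    eps = 1e-5  # keeps log() finite at the endpoints
    p = torch.linspace(0, 1, n)
    # Laplace quantile: mu - beta * sign(0.5 - p) * log(1 - 2|0.5 - p|)
    log_sigma = mu - beta * torch.sign(0.5 - p) * torch.log(1 - 2 * torch.abs(0.5 - p) + eps)
    return torch.exp(log_sigma).clamp(sigma_min, sigma_max)

print(laplace_sigmas_sketch(5, 0.0291675, 14.614642))  # descends from sigma_max toward sigma_min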
class SDTurboScheduler:
    @classmethod
    def INPUT_TYPES(s):
@@ -677,6 +698,7 @@ NODE_CLASS_MAPPINGS = {
    "KarrasScheduler": KarrasScheduler,
    "ExponentialScheduler": ExponentialScheduler,
    "PolyexponentialScheduler": PolyexponentialScheduler,
+   "LaplaceScheduler": LaplaceScheduler,
    "VPScheduler": VPScheduler,
    "BetaSamplingScheduler": BetaSamplingScheduler,
    "SDTurboScheduler": SDTurboScheduler,


@@ -107,7 +107,7 @@ class HypernetworkLoader:
    CATEGORY = "loaders"

    def load_hypernetwork(self, model, hypernetwork_name, strength):
-       hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
+       hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
        model_hypernetwork = model.clone()
        patch = load_hypernetwork_patch(hypernetwork_path, strength)
        if patch is not None:
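The get_full_path to get_full_path_or_raise swap here (and in PhotoMakerLoader below) changes the failure mode rather than the happy path: as I understand it, the old call returns None for a missing file and the loader crashes later with an unrelated-looking error, while the new one raises immediately, naming the missing model. A hedged sketch of the difference; the error message is my illustration, not the library's exact text:

from comfy.cmd import folder_paths

def resolve_model(kind: str, name: str) -> str:
    # Old pattern: get_full_path may return None, deferring the failure.
    path = folder_paths.get_full_path(kind, name)
    if path is None:
        # get_full_path_or_raise(kind, name) folds this check into the call.
        raise FileNotFoundError(f"{kind} model not found: {name}")
    return path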


@@ -1,24 +1,26 @@
-from comfy.nodes.common import MAX_RESOLUTION
-from comfy.cmd import folder_paths
-from comfy.cli_args import args
+import json
+import os
+
+import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
-import numpy as np
-import json
-import os
+from comfy.cli_args import args
+from comfy.cmd import folder_paths
+from comfy.nodes.common import MAX_RESOLUTION
+from comfy.utils import tensor2pil
class ImageCrop:
    @classmethod
    def INPUT_TYPES(s):
-       return {"required": { "image": ("IMAGE",),
+       return {"required": {"image": ("IMAGE",),
                             "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                             "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
                             "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                             "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
                             }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "crop"
@@ -29,31 +31,35 @@ class ImageCrop:
        y = min(y, image.shape[1] - 1)
        to_x = width + x
        to_y = height + y
-       img = image[:,y:to_y, x:to_x, :]
+       img = image[:, y:to_y, x:to_x, :]
        return (img,)

class RepeatImageBatch:
    @classmethod
    def INPUT_TYPES(s):
-       return {"required": { "image": ("IMAGE",),
+       return {"required": {"image": ("IMAGE",),
                             "amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "repeat"

    CATEGORY = "image/batch"

    def repeat(self, image, amount):
-       s = image.repeat((amount, 1,1,1))
+       s = image.repeat((amount, 1, 1, 1))
        return (s,)

class ImageFromBatch:
    @classmethod
    def INPUT_TYPES(s):
-       return {"required": { "image": ("IMAGE",),
+       return {"required": {"image": ("IMAGE",),
                             "batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
                             "length": ("INT", {"default": 1, "min": 1, "max": 4096}),
                             }}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "frombatch"
@@ -66,6 +72,7 @@ class ImageFromBatch:
        s = s_in[batch_index:batch_index + length].clone()
        return (s,)

class SaveAnimatedWEBP:
    def __init__(self):
        self.output_dir = folder_paths.get_output_directory()
@@ -73,10 +80,11 @@ class SaveAnimatedWEBP:
        self.prefix_append = ""

    methods = {"default": 4, "fastest": 0, "slowest": 6}

    @classmethod
    def INPUT_TYPES(s):
        return {"required":
-                   {"images": ("IMAGE", ),
+                   {"images": ("IMAGE",),
                     "filename_prefix": ("STRING", {"default": "ComfyUI"}),
                     "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
                     "lossless": ("BOOLEAN", {"default": True}),
@@ -121,7 +129,7 @@ class SaveAnimatedWEBP:
        c = len(pil_images)
        for i in range(0, c, num_frames):
            file = f"{filename}_{counter:05}_.webp"
-           pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
+           pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0 / fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
            results.append({
                "filename": file,
                "subfolder": subfolder,
@@ -130,7 +138,8 @@ class SaveAnimatedWEBP:
            counter += 1

        animated = num_frames != 1
-       return { "ui": { "images": results, "animated": (animated,) } }
+       return {"ui": {"images": results, "animated": (animated,)}}
class SaveAnimatedPNG:
    def __init__(self):
@@ -141,7 +150,7 @@ class SaveAnimatedPNG:
    @classmethod
    def INPUT_TYPES(s):
        return {"required":
-                   {"images": ("IMAGE", ),
+                   {"images": ("IMAGE",),
                     "filename_prefix": ("STRING", {"default": "ComfyUI"}),
                     "fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
                     "compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
@@ -176,16 +185,63 @@ class SaveAnimatedPNG:
                metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
        file = f"{filename}_{counter:05}_.png"
-       pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
+       pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0 / fps), append_images=pil_images[1:])
        results.append({
            "filename": file,
            "subfolder": subfolder,
            "type": self.type
        })

-       return { "ui": { "images": results, "animated": (True,)} }
+       return {"ui": {"images": results, "animated": (True,)}}
class ImageSizeToNumber:
"""
By WASasquatch (Discord: WAS#0263)
Copyright 2023 Jordan Thompson (WASasquatch)
    Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to
    deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
    and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

    The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    THE SOFTWARE.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
}
}
RETURN_TYPES = ("*", "*", "FLOAT", "FLOAT", "INT", "INT")
RETURN_NAMES = ("width_num", "height_num", "width_float", "height_float", "width_int", "height_int")
FUNCTION = "image_width_height"
CATEGORY = "image/operations"
def image_width_height(self, image):
image = tensor2pil(image)
if image.size:
return (
image.size[0], image.size[1], float(image.size[0]), float(image.size[1]), image.size[0], image.size[1])
return 0, 0, 0, 0, 0, 0
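One subtlety in the node above: IMAGE tensors are NHWC, so after tensor2pil the PIL size tuple comes back as (width, height), i.e. shape[2] then shape[1]. A quick check of that equivalence with a toy tensor, assuming tensor2pil accepts a batch-of-one the way the node uses it:

import torch
from comfy.utils import tensor2pil

image = torch.rand(1, 480, 640, 3)  # N, H, W, C
pil = tensor2pil(image)
assert pil.size == (640, 480)       # PIL reports (width, height)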
NODE_CLASS_MAPPINGS = {
+   # From WAS Node Suite
+   # Class mapping is kept for compatibility
+   "Image Size to Number": ImageSizeToNumber,
    "ImageCrop": ImageCrop,
    "RepeatImageBatch": RepeatImageBatch,
    "ImageFromBatch": ImageFromBatch,


@@ -126,7 +126,7 @@ class PhotoMakerLoader:
    CATEGORY = "_for_testing/photomaker"

    def load_photomaker_model(self, photomaker_model_name):
-       photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
+       photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
        photomaker_model = PhotoMakerIDEncoder()
        data = utils.load_torch_file(photomaker_model_path, safe_load=True)
        if "id_encoder" in data:


@@ -42,7 +42,7 @@ class EmptySD3LatentImage:
    CATEGORY = "latent/sd3"

    def generate(self, width, height, batch_size=1):
-       latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
+       latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
        return ({"samples": latent},)
@@ -101,6 +101,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
                             "end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
                             }}

    CATEGORY = "conditioning/controlnet"
+   DEPRECATED = True

NODE_CLASS_MAPPINGS = {
    "TripleCLIPLoader": TripleCLIPLoader,
@@ -111,5 +112,5 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
    # Sampling
-   "ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
+   "ControlNetApplySD3": "Apply Controlnet with VAE",
}


@@ -9,9 +9,9 @@ from comfy.cmd.folder_paths import filter_files_content_types
@pytest.fixture(scope="module")
def file_extensions():
    return {
-       'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
-       'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
-       'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
+       'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
+       'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
+       'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
    }
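For context, this fixture feeds tests of filter_files_content_types, which buckets filenames by coarse content type via the platform's mimetype database; that dependence on the platform is presumably why the extension lists above were trimmed to a portable subset. A hedged usage sketch (the filenames are made up; the function and import are the ones the test already uses):

from comfy.cmd.folder_paths import filter_files_content_types

files = ["photo.png", "clip.mp4", "song.mp3", "notes.txt"]   # hypothetical names
print(filter_files_content_types(files, ["image"]))          # expected: ['photo.png']
print(filter_files_content_types(files, ["audio", "video"])) # expected: ['clip.mp4', 'song.mp3']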


@@ -82,14 +82,16 @@ def test_load_extra_model_paths_expands_userpath(
    load_extra_path_config(dummy_yaml_file_name)

    expected_calls = [
-       ('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1')),
+       ('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1'), False),
    ]
    assert mock_add_model_folder_path.call_count == len(expected_calls)

    # Check if add_model_folder_path was called with the correct arguments
    for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
-       assert actual_call.args == expected_call
+       assert actual_call.args[0] == expected_call[0]
+       # Normalize the path so the comparison holds across operating systems.
+       assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1])
+       assert actual_call.args[2] == expected_call[2]

    # Check if yaml.safe_load was called
    mock_yaml_safe_load.assert_called_once()
@@ -123,7 +125,7 @@ def test_load_extra_model_paths_expands_appdata(
    expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
    expected_calls = [
-       ('checkpoints', os.path.join(expected_base_path, 'models/checkpoints')),
+       ('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
    ]
    assert mock_add_model_folder_path.call_count == len(expected_calls)