Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2024-09-23 12:50:31 -07:00
commit fa3176f96f
51 changed files with 51259 additions and 44715 deletions

View File

@ -488,6 +488,7 @@ disable=raw-checker-failed,
too-many-return-statements,
too-many-branches,
too-many-arguments,
too-many-positional-arguments,
too-many-locals,
too-many-statements,
too-many-boolean-expressions,

View File

@ -1,10 +1,9 @@
import logging
from typing import Optional
from aiohttp import web
from ...services.file_service import FileService
from ....cmd.folder_paths import models_dir, user_directory, output_directory
from ....cmd.folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
class InternalRoutes:
@ -41,6 +40,13 @@ class InternalRoutes:
# todo: applications really shouldn't serve logs like this
return web.json_response({})
@self.routes.get('/folder_paths')
async def get_folder_paths(request):
response = {}
for key in folder_names_and_paths:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
def get_app(self):
if self._app is None:
self._app = web.Application()

View File

@ -119,6 +119,29 @@ class UserManager():
@routes.get("/userdata")
async def listuserdata(request):
"""
List user data files in a specified directory.
This endpoint allows listing files in a user's data directory, with options for recursion,
full file information, and path splitting.
Query Parameters:
- dir (required): The directory to list files from.
- recurse (optional): If "true", recursively list files in subdirectories.
- full_info (optional): If "true", return detailed file information (path, size, modified time).
- split (optional): If "true", split file paths into components (only applies when full_info is false).
Returns:
- 400: If 'dir' parameter is missing.
- 403: If the requested path is not allowed.
- 404: If the requested directory does not exist.
- 200: JSON response with the list of files or file information.
The response format depends on the query parameters:
- Default: List of relative file paths.
- full_info=true: List of dictionaries with file details.
- split=true (and full_info=false): List of lists, each containing path components.
"""
directory = request.rel_url.query.get('dir', '')
if not directory:
return web.Response(status=400, text="Directory not provided")

View File

@ -1,10 +1,20 @@
import itertools
from typing import Sequence, Mapping
from typing import Sequence, Mapping, Dict
from .cmd.execution import nodes
from .graph import DynamicPrompt
from .graph_utils import is_link
NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
def include_unique_id_in_input(class_type: str) -> bool:
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
NODE_CLASS_CONTAINS_UNIQUE_ID[class_type] = "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values()
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
class CacheKeySet:
def __init__(self, dynprompt, node_ids, is_changed_cache):
@ -102,7 +112,7 @@ class CacheKeySetInputSignature(CacheKeySet):
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
signature = [class_type, self.is_changed_cache.get(node_id)]
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or "UNIQUE_ID" in class_def.INPUT_TYPES().get("hidden", {}).values():
if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT) or include_unique_id_in_input(class_type):
signature.append(node_id)
inputs = node["inputs"]
for key in sorted(inputs.keys()):

View File

@ -16,6 +16,7 @@ class ControlNet(MMDiT):
def __init__(
self,
num_blocks = None,
control_latent_channels = None,
dtype = None,
device = None,
operations = None,
@ -27,10 +28,13 @@ class ControlNet(MMDiT):
for _ in range(len(self.joint_blocks)):
self.controlnet_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, device=device, dtype=dtype))
if control_latent_channels is None:
control_latent_channels = self.in_channels
self.pos_embed_input = PatchEmbed(
None,
self.patch_size,
self.in_channels,
control_latent_channels,
self.hidden_size,
bias=True,
strict_img_size=False,

View File

@ -28,8 +28,8 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument('-w', "--cwd", type=str, default=None,
help="Specify the working directory. If not set, this is the current working directory. models/, input/, output/ and other directories will be located here by default.")
parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0",
help="Specify the IP address to listen on (default: 127.0.0.1). If --listen is provided without an argument, it defaults to 0.0.0.0. (listens on all)")
parser.add_argument('-H', "--listen", type=str, default="127.0.0.1", metavar="IP", nargs="?", const="0.0.0.0,::",
help="Specify the IP address to listen on (default: 127.0.0.1). You can give a list of ip addresses by separating them with a comma like: 127.2.2.2,127.3.3.3 If --listen is provided without an argument, it defaults to 0.0.0.0,:: (listens on all ipv4 and ipv6)")
parser.add_argument("--port", type=int, default=8188, help="Set the listen port.")
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*",
help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")

View File

@ -51,7 +51,46 @@ temp_directory = os.path.join(get_base_path(), "temp")
input_directory = os.path.join(get_base_path(), "input")
user_directory = os.path.join(get_base_path(), "user")
_filename_list_cache = {}
filename_list_cache = {}
class CacheHelper:
"""
Helper class for managing file list cache data.
"""
def __init__(self):
self.cache: dict[str, tuple[list[str], dict[str, float], float]] = {}
self.active = False
def get(self, key: str, default=None) -> tuple[list[str], dict[str, float], float]:
if not self.active:
return default
return self.cache.get(key, default)
def set(self, key: str, value: tuple[list[str], dict[str, float], float]) -> None:
if self.active:
self.cache[key] = value
def clear(self):
self.cache.clear()
def __enter__(self):
self.active = True
return self
def __exit__(self, exc_type, exc_value, traceback):
self.active = False
self.clear()
cache_helper = CacheHelper()
def map_legacy(folder_name: str) -> str:
legacy = {"unet": "diffusion_models"}
return legacy.get(folder_name, folder_name)
if not os.path.exists(input_directory):
try:
@ -150,7 +189,7 @@ def exists_annotated_filepath(name):
return os.path.exists(filepath)
def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, extensions: Optional[set[str]] = None) -> str:
def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, extensions: Optional[set[str]] = None, is_default: bool = False) -> str:
"""
Registers a model path for the given canonical name.
:param folder_name: the folder name
@ -165,7 +204,10 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, e
folder_path = folder_names_and_paths[folder_name]
if full_folder_path not in folder_path.paths:
folder_path.paths.append(full_folder_path)
if is_default:
folder_path.paths.insert(0, full_folder_path)
else:
folder_path.paths.append(full_folder_path)
if extensions is not None:
folder_path.supported_extensions |= extensions
@ -244,7 +286,15 @@ def get_full_path(folder_name, filename) -> Optional[str | bytes | os.PathLike]:
return None
def get_filename_list_(folder_name):
def get_full_path_or_raise(folder_name: str, filename: str) -> str:
full_path = get_full_path(folder_name, filename)
if full_path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found.")
return full_path
def get_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float]:
folder_name = map_legacy(folder_name)
global folder_names_and_paths
output_list = set()
folders = folder_names_and_paths[folder_name]
@ -257,12 +307,17 @@ def get_filename_list_(folder_name):
return sorted(list(output_list)), output_folders, time.perf_counter()
def cached_filename_list_(folder_name):
global _filename_list_cache
def cached_filename_list_(folder_name: str) -> tuple[list[str], dict[str, float], float] | None:
strong_cache = cache_helper.get(folder_name)
if strong_cache is not None:
return strong_cache
global filename_list_cache
global folder_names_and_paths
if folder_name not in _filename_list_cache:
folder_name = map_legacy(folder_name)
if folder_name not in filename_list_cache:
return None
out = _filename_list_cache[folder_name]
out = filename_list_cache[folder_name]
for x in out[1]:
time_modified = out[1][x]
@ -279,12 +334,14 @@ def cached_filename_list_(folder_name):
return out
def get_filename_list(folder_name):
def get_filename_list(folder_name: str) -> list[str]:
folder_name = map_legacy(folder_name)
out = cached_filename_list_(folder_name)
if out is None:
out = get_filename_list_(folder_name)
global _filename_list_cache
_filename_list_cache[folder_name] = out
global filename_list_cache
filename_list_cache[folder_name] = out
cache_helper.set(folder_name, out)
return list(out[0])
@ -345,8 +402,8 @@ def create_directories():
def invalidate_cache(folder_name):
global _filename_list_cache
_filename_list_cache.pop(folder_name, None)
global filename_list_cache
filename_list_cache.pop(folder_name, None)
def filter_files_content_types(files: list[str], content_types: list[Literal["image", "video", "audio"]]) -> list[str]:

View File

@ -82,7 +82,10 @@ def prompt_worker(q: AbstractPromptQueue, _server: server_module.PromptServer):
async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
addresses = []
for addr in address.split(","):
addresses.append((addr, port))
await asyncio.gather(server.start_multi_address(addresses, call_on_start), server.publish_loop())
def cleanup_temp():
@ -223,12 +226,15 @@ async def main(from_script_dir: Optional[Path] = None):
import webbrowser
if os.name == 'nt' and address == '0.0.0.0' or address == '':
address = '127.0.0.1'
if ':' in address:
address = "[{}]".format(address)
webbrowser.open(f"http://{address}:{port}")
call_on_start = startup_server
server.address = args.listen
server.port = args.port
try:
await server.setup()
await run(server, address=args.listen, port=args.port, verbose=not args.dont_print_server,

View File

@ -29,7 +29,7 @@ from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
from typing_extensions import NamedTuple
from .latent_preview_image_encoding import encode_preview_image
from .. import interruption
from .. import interruption, model_management
from .. import node_helpers
from .. import utils
from ..api_server.routes.internal.internal_routes import InternalRoutes
@ -256,6 +256,12 @@ class PromptServer(ExecutorToClientProgress):
embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@routes.get("/models")
def list_model_types(request):
model_types = list(folder_paths.folder_names_and_paths.keys())
return web.json_response(model_types)
@routes.get("/models/{folder}")
async def get_models(request):
folder = request.match_info.get("folder", None)
@ -505,12 +511,17 @@ class PromptServer(ExecutorToClientProgress):
async def system_stats(request):
device = get_torch_device()
device_name = get_torch_device_name(device)
cpu_device = model_management.torch.device("cpu")
ram_total = model_management.get_total_memory(cpu_device)
ram_free = model_management.get_free_memory(cpu_device)
vram_total, torch_vram_total = get_total_memory(device, torch_total_too=True)
vram_free, torch_vram_free = get_free_memory(device, torch_free_too=True)
system_stats = {
"system": {
"os": os.name,
"ram_total": ram_total,
"ram_free": ram_free,
"comfyui_version": get_comfyui_version(),
"python_version": sys.version,
"pytorch_version": torch_version,
@ -568,14 +579,15 @@ class PromptServer(ExecutorToClientProgress):
@routes.get("/object_info")
async def get_object_info(request):
out = {}
for x in self.nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
with folder_paths.cache_helper:
out = {}
for x in self.nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
except Exception as e:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
@routes.get("/object_info/{node_class}")
async def get_object_info_node(request):
@ -969,17 +981,29 @@ class PromptServer(ExecutorToClientProgress):
await self.send(*msg)
async def start(self, address: str | None, port: int | None, verbose=True, call_on_start=None):
await self.start_multi_address([(address, port)], call_on_start=call_on_start, verbose=verbose)
async def start_multi_address(self, addresses, call_on_start=None, verbose=True):
runner = web.AppRunner(self.app, access_log=None)
await runner.setup()
site = web.TCPSite(runner, host=address, port=port)
await site.start()
for addr in addresses:
address = addr[0]
port = addr[1]
site = web.TCPSite(runner, address, port)
await site.start()
self.address = address
self.port = port
if not hasattr(self, 'address'):
self.address = address #TODO: remove this
self.port = port
if ':' in address:
address_print = "[{}]".format(address)
else:
address_print = address
if verbose:
logging.info("Starting server")
logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address == "0.0.0.0" else address, port))
logging.info("To see the GUI go to: http://{}:{}".format("localhost" if address_print == "0.0.0.0" else address, port))
if call_on_start is not None:
call_on_start(address, port)

View File

@ -30,10 +30,9 @@ from . import model_patcher
from . import ops
from . import utils
from .cldm import cldm, mmdit
from .ldm import hydit
from .ldm.cascade import controlnet as cascade_controlnet
from .ldm.flux import controlnet as controlnet_flux
from .ldm.flux.weight_dtypes import FLUX_WEIGHT_DTYPES
from .ldm.hydit.controlnet import HunYuanControlNet
from .t2i_adapter import adapter
@ -81,13 +80,21 @@ class ControlBase:
self.previous_controlnet = None
self.extra_conds = []
self.strength_type = StrengthType.CONSTANT
self.concat_mask = False
self.extra_concat_orig = []
self.extra_concat = None
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None):
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
self.cond_hint_original = cond_hint
self.strength = strength
self.timestep_percent_range = timestep_percent_range
if self.latent_format is not None:
if vae is None:
logging.warning("WARNING: no VAE provided to the controlnet apply node when this controlnet requires one.")
self.vae = vae
self.extra_concat_orig = extra_concat.copy()
if self.concat_mask and len(self.extra_concat_orig) == 0:
self.extra_concat_orig.append(torch.tensor([[[[1.0]]]]))
return self
def pre_run(self, model, percent_to_timestep_function):
@ -102,9 +109,9 @@ class ControlBase:
def cleanup(self):
if self.previous_controlnet is not None:
self.previous_controlnet.cleanup()
if self.cond_hint is not None:
del self.cond_hint
self.cond_hint = None
self.cond_hint = None
self.extra_concat = None
self.timestep_range = None
def get_models(self):
@ -125,6 +132,8 @@ class ControlBase:
c.vae = self.vae
c.extra_conds = self.extra_conds.copy()
c.strength_type = self.strength_type
c.concat_mask = self.concat_mask
c.extra_concat_orig = self.extra_concat_orig.copy()
def inference_memory_requirements(self, dtype):
if self.previous_controlnet is not None:
@ -177,7 +186,7 @@ class ControlBase:
class ControlNet(ControlBase):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, ckpt_name: str = None):
def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, device=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, ckpt_name: str = None):
super().__init__(device)
self.control_model = control_model
self.load_device = load_device
@ -192,6 +201,7 @@ class ControlNet(ControlBase):
self.latent_format = latent_format
self.extra_conds += extra_conds
self.strength_type = strength_type
self.concat_mask = concat_mask
def get_control(self, x_noisy, t, cond, batched_number):
control_prev = None
@ -216,6 +226,9 @@ class ControlNet(ControlBase):
compression_ratio = self.compression_ratio
if self.vae is not None:
compression_ratio *= self.vae.downscale_ratio
else:
if self.latent_format is not None:
raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
if self.vae is not None:
loaded_models = model_management.loaded_models(only_currently_used=True)
@ -223,6 +236,13 @@ class ControlNet(ControlBase):
model_management.load_models_gpu(loaded_models)
if self.latent_format is not None:
self.cond_hint = self.latent_format.process_in(self.cond_hint)
if len(self.extra_concat_orig) > 0:
to_concat = []
for c in self.extra_concat_orig:
c = utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
to_concat.append(utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
self.cond_hint = self.cond_hint.to(device=self.device, dtype=dtype)
if x_noisy.shape[0] != self.cond_hint.shape[0]:
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
@ -322,7 +342,7 @@ class ControlLoraOps:
class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, device=None):
def __init__(self, control_weights, global_average_pooling=False, device=None, model_options={}): # TODO? model_options
ControlBase.__init__(self, device)
self.control_weights = control_weights
self.global_average_pooling = global_average_pooling
@ -381,19 +401,27 @@ class ControlLora(ControlNet):
return utils.calculate_parameters(self.control_weights) * model_management.dtype_size(dtype) + ControlBase.inference_memory_requirements(self, dtype)
def controlnet_config(sd):
def controlnet_config(sd, model_options=None):
if model_options is None:
model_options = {}
model_config = model_detection.model_config_from_unet(sd, "", True)
supported_inference_dtypes = model_config.supported_inference_dtypes
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
load_device = model_management.get_torch_device()
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
operations = ops.manual_cast
else:
operations = ops.disable_weight_init
operations = model_options.get("custom_operations", None)
if operations is None:
operations = ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
offload_device = model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
@ -410,26 +438,35 @@ def controlnet_load_state_dict(control_model, sd):
return control_model
def load_controlnet_mmdit(sd):
def load_controlnet_mmdit(sd, model_options=None):
if model_options is None:
model_options = {}
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
concat_mask = False
control_latent_channels = new_sd.get("pos_embed_input.proj.weight").shape[1]
if control_latent_channels == 17: # inpaint controlnet
concat_mask = True
control_model = mmdit.ControlNet(num_blocks=num_blocks, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = latent_formats.SD3()
latent_format.shift_factor = 0 # SD3 controlnet weirdness
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype)
return control
def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
def load_controlnet_hunyuandit(controlnet_data, model_options=None):
if model_options is None:
model_options = {}
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)
control_model = hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)
latent_format = latent_formats.SDXL()
@ -473,10 +510,10 @@ def load_controlnet_flux_instantx_union(sd, controlnet_class, weight_dtype, full
device = model_management.get_torch_device()
if weight_dtype == "fp8_e4m3fn":
if weight_dtype == torch.float8_e4m3fn or weight_dtype == "fp8_e4m3fn":
dtype = torch.float8_e4m3fn
operations = ops.manual_cast
elif weight_dtype == "fp8_e5m2":
elif weight_dtype == torch.float8_e5m2 or weight_dtype == "float8_e5m2":
dtype = torch.float8_e5m2
operations = ops.manual_cast
else:
@ -493,8 +530,10 @@ def load_controlnet_flux_instantx_union(sd, controlnet_class, weight_dtype, full
return control
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False):
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False, model_options=None):
if model_options is None:
model_options = {}
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
control_model = controlnet_flux.ControlNetFlux(mistoline=mistoline, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
@ -502,9 +541,11 @@ def load_controlnet_flux_xlabs_mistoline(sd, mistoline=False):
return control
def load_controlnet_flux_instantx(sd):
def load_controlnet_flux_instantx(sd, model_options=None):
if model_options is None:
model_options = {}
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)
for k in sd:
new_sd[k] = sd[k]
@ -514,13 +555,16 @@ def load_controlnet_flux_instantx(sd):
num_union_modes = new_sd[union_cnet].shape[0]
control_latent_channels = new_sd.get("pos_embed_input.weight").shape[1] // 4
concat_mask = False
if control_latent_channels == 17:
concat_mask = True
control_model = controlnet_flux.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, control_latent_channels=control_latent_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = latent_formats.Flux()
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
return control
@ -528,12 +572,15 @@ def convert_mistoline(sd):
return utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
controlnet_data = utils.load_torch_file(ckpt_path, safe_load=True)
def load_controlnet_state_dict(state_dict, model=None, model_options=None, ckpt_name: str = None):
if model_options is None:
model_options = {}
controlnet_data = state_dict
if 'after_proj_list.18.bias' in controlnet_data.keys(): # Hunyuan DiT
return load_controlnet_hunyuandit(controlnet_data)
return load_controlnet_hunyuandit(controlnet_data, model_options=model_options)
if "lora_controlnet" in controlnet_data:
return ControlLora(controlnet_data)
return ControlLora(controlnet_data, model_options=model_options)
controlnet_config = None
supported_inference_dtypes = None
@ -590,13 +637,13 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
controlnet_data = new_sd
elif "controlnet_blocks.0.weight" in controlnet_data:
if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
return load_controlnet_flux_xlabs_mistoline(controlnet_data)
return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
elif "pos_embed_input.proj.weight" in controlnet_data:
return load_controlnet_mmdit(controlnet_data) # SD3 diffusers controlnet
return load_controlnet_mmdit(controlnet_data, model_options=model_options) # SD3 diffusers controlnet
elif "controlnet_x_embedder.weight" in controlnet_data:
return load_controlnet_flux_instantx(controlnet_data)
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
elif "controlnet_blocks.0.linear.weight" in controlnet_data: # mistoline flux
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True)
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
pth_key = 'control_model.zero_convs.0.0.weight'
pth = False
@ -608,25 +655,36 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
elif key in controlnet_data:
prefix = ""
else:
net = load_t2i_adapter(controlnet_data)
net = load_t2i_adapter(controlnet_data, model_options=model_options)
if net is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
logging.error("error could not detect control model type.")
return net
if controlnet_config is None:
model_config = model_detection.model_config_from_unet(controlnet_data, prefix, True)
supported_inference_dtypes = model_config.supported_inference_dtypes
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
controlnet_config = model_config.unet_config
unet_dtype = model_options.get("dtype", None)
if unet_dtype is None:
weight_dtype = utils.weight_dtype(controlnet_data)
if supported_inference_dtypes is None:
supported_inference_dtypes = [model_management.unet_dtype()]
if weight_dtype is not None:
supported_inference_dtypes.append(weight_dtype)
unet_dtype = model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
load_device = model_management.get_torch_device()
if supported_inference_dtypes is None:
unet_dtype = model_management.unet_dtype()
else:
unet_dtype = model_management.unet_dtype(supported_dtypes=supported_inference_dtypes)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device)
if manual_cast_dtype is not None:
controlnet_config["operations"] = ops.manual_cast
operations = model_options.get("custom_operations", None)
if operations is None:
operations = ops.pick_operations(unet_dtype, manual_cast_dtype)
controlnet_config["operations"] = operations
controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = model_management.unet_offload_device()
controlnet_config.pop("out_channels")
@ -663,15 +721,26 @@ def load_controlnet(ckpt_path, model=None, weight_dtype=FLUX_WEIGHT_DTYPES[0]):
if len(unexpected) > 0:
logging.debug("unexpected controlnet keys: {}".format(unexpected))
global_average_pooling = False
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): # TODO: smarter way of enabling global_average_pooling
global_average_pooling = True
filename = os.path.splitext(ckpt_name)[0]
global_average_pooling = model_options.get("global_average_pooling", False)
control = ControlNet(control_model, global_average_pooling=global_average_pooling, load_device=load_device, manual_cast_dtype=manual_cast_dtype, ckpt_name=filename)
return control
def load_controlnet(ckpt_path, model=None, model_options=None):
if model_options is None:
model_options = {}
if "global_average_pooling" not in model_options:
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): # TODO: smarter way of enabling global_average_pooling
model_options["global_average_pooling"] = True
cnet = load_controlnet_state_dict(utils.load_torch_file(ckpt_path, safe_load=True), model=model, model_options=model_options, ckpt_name=ckpt_path)
if cnet is None:
logging.error("error checkpoint does not contain controlnet or t2i adapter data {}".format(ckpt_path))
return cnet
class T2IAdapter(ControlBase):
def __init__(self, t2i_model, channels_in, compression_ratio, upscale_algorithm, device=None):
super().__init__(device)
@ -728,7 +797,7 @@ class T2IAdapter(ControlBase):
return c
def load_t2i_adapter(t2i_data):
def load_t2i_adapter(t2i_data, model_options={}): # TODO: model_options
compression_ratio = 8
upscale_algorithm = 'nearest-exact'

View File

@ -17,6 +17,9 @@ def load_extra_path_config(yaml_path):
if "base_path" in conf:
base_path = conf.pop("base_path")
base_path = os.path.expandvars(os.path.expanduser(base_path))
is_default = False
if "is_default" in conf:
is_default = conf.pop("is_default")
for x in conf:
for y in conf[x].split("\n"):
if len(y) == 0:
@ -25,4 +28,4 @@ def load_extra_path_config(yaml_path):
if base_path is not None:
full_path = os.path.join(base_path, full_path)
logging.info("Adding extra search path {} {}".format(x, full_path))
folder_paths.add_model_folder_path(x, full_path)
folder_paths.add_model_folder_path(x, full_path, is_default)

View File

@ -45,6 +45,17 @@ def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
return append_zero(sigmas)
def get_sigmas_laplace(n, sigma_min, sigma_max, mu=0., beta=0.5, device='cpu'):
"""Constructs the noise schedule proposed by Tiankai et al. (2024). """
epsilon = 1e-5 # avoid log(0)
x = torch.linspace(0, 1, n, device=device)
clamp = lambda x: torch.clamp(x, min=sigma_min, max=sigma_max)
lmb = mu - beta * torch.sign(0.5-x) * torch.log(1 - 2 * torch.abs(0.5-x) + epsilon)
sigmas = clamp(torch.exp(lmb))
return sigmas
def to_d(x, sigma, denoised):
"""Converts a denoiser output to a Karras ODE derivative."""
return (x - denoised) / utils.append_dims(sigma, x.ndim)

View File

@ -414,7 +414,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
weight *= strength_model
if isinstance(v, list):
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype),)
v = (calculate_weight(v[1:], model_management.cast_to_device(v[0], weight.device, intermediate_dtype, copy=True), key, intermediate_dtype=intermediate_dtype),)
patch_type = ""
if len(v) == 1:

View File

@ -199,6 +199,8 @@ Visit the repository, accept the terms, and then do one of the following:
# a path was found for any reason, so we should invalidate the cache
if path is not None:
folder_paths.invalidate_cache(folder_name)
if path is None:
raise FileNotFoundError(f"Model in folder '{folder_name}' with filename '{filename}' not found, and no download candidates matched for the filename.")
return path

View File

@ -367,7 +367,7 @@ class LoadedModel:
self.model_unload()
raise e
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and self.real_model is not None:
with torch.no_grad():
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
@ -716,6 +716,8 @@ def maximum_vram_for_weights(device=None) -> int:
def unet_dtype(device=None, model_params=0, supported_dtypes=(torch.float16, torch.bfloat16, torch.float32)):
if model_params < 0:
model_params = 1000000000000000000000
if args.bf16_unet:
return torch.bfloat16
if args.fp16_unet:

View File

@ -287,17 +287,21 @@ class ModelPatcher(ModelManageable):
return list(p)
def get_key_patches(self, filter_prefix=None):
model_management.unload_model_clones(self)
model_sd = self.model_state_dict()
p = {}
for k in model_sd:
if filter_prefix is not None:
if not k.startswith(filter_prefix):
continue
if k in self.patches:
p[k] = [model_sd[k]] + self.patches[k]
bk = self.backup.get(k, None)
if bk is not None:
weight = bk.weight
else:
p[k] = (model_sd[k],)
weight = model_sd[k]
if k in self.patches:
p[k] = [weight] + self.patches[k]
else:
p[k] = (weight,)
return p
def model_state_dict(self, filter_prefix=None):

View File

@ -696,11 +696,11 @@ class VAELoader:
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
decoder = next(filter(lambda a: a.startswith("{}_decoder.".format(name)), approx_vaes))
enc = utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
enc = utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", encoder))
for k in enc:
sd_["taesd_encoder.{}".format(k)] = enc[k]
dec = utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
dec = utils.load_torch_file(folder_paths.get_full_path_or_raise("vae_approx", decoder))
for k in dec:
sd_["taesd_decoder.{}".format(k)] = dec[k]
@ -769,7 +769,13 @@ class ControlNetLoaderWeights:
def load_controlnet(self, control_net_name, weight_dtype):
controlnet_path = get_or_download("controlnet", control_net_name, KNOWN_CONTROLNETS)
controlnet_ = controlnet.load_controlnet(controlnet_path, weight_dtype=weight_dtype)
model_options = {}
if weight_dtype == "float8_e5m2":
model_options["dtype"] = torch.float8_e5m2
elif weight_dtype == "float8_e4m3fn":
model_options["dtype"] = torch.float8_e4m3fn
controlnet_ = controlnet.load_controlnet(controlnet_path, model_options=model_options)
return (controlnet_,)
class DiffControlNetLoader:
@ -800,6 +806,7 @@ class ControlNetApply:
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "apply_controlnet"
DEPRECATED = True
CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, conditioning, control_net, image: RGBImageBatch, strength):
@ -829,7 +836,10 @@ class ControlNetApplyAdvanced:
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
},
"optional": {"vae": ("VAE", ),
}
}
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
RETURN_NAMES = ("positive", "negative")
@ -837,7 +847,7 @@ class ControlNetApplyAdvanced:
CATEGORY = "conditioning/controlnet"
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None):
def apply_controlnet(self, positive, negative, control_net, image, strength, start_percent, end_percent, vae=None, extra_concat=[]):
if strength == 0:
return (positive, negative)
@ -854,7 +864,7 @@ class ControlNetApplyAdvanced:
if prev_cnet in cnets:
c_net = cnets[prev_cnet]
else:
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae)
c_net = control_net.copy().set_cond_hint(control_hint, strength, (start_percent, end_percent), vae=vae, extra_concat=extra_concat)
c_net.set_previous_controlnet(prev_cnet)
cnets[prev_cnet] = c_net
@ -1932,8 +1942,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"ConditioningSetArea": "Conditioning (Set Area)",
"ConditioningSetAreaPercentage": "Conditioning (Set Area with Percentage)",
"ConditioningSetMask": "Conditioning (Set Mask)",
"ControlNetApply": "Apply ControlNet",
"ControlNetApplyAdvanced": "Apply ControlNet (Advanced)",
"ControlNetApply": "Apply ControlNet (OLD)",
"ControlNetApplyAdvanced": "Apply ControlNet",
# Latent
"VAEEncodeForInpaint": "VAE Encode (for Inpainting)",
"SetLatentNoiseMask": "Set Latent Noise Mask",

View File

@ -264,7 +264,6 @@ def fp8_linear(self, input):
if len(input.shape) == 3:
inn = input.reshape(-1, input.shape[2]).to(dtype)
non_blocking = model_management.device_supports_non_blocking(input.device)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
@ -304,10 +303,10 @@ class fp8_ops(manual_cast):
return torch.nn.functional.linear(input, weight, bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None):
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False):
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
if args.fast:
if args.fast and not disable_fast_fp8:
if model_management.supports_fp8_compute(load_device):
return fp8_ops
return manual_cast

View File

@ -76,14 +76,14 @@ class CLIP:
clip = target.clip
tokenizer = target.tokenizer
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
load_device = model_options.get("load_device", model_management.text_encoder_device())
offload_device = model_options.get("offload_device", model_management.text_encoder_offload_device())
dtype = model_options.get("dtype", None)
if dtype is None:
dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
params['device'] = model_options.get("initial_device", model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype)))
if "textmodel_json_config" not in params and textmodel_json_config is not None:
params['textmodel_json_config'] = textmodel_json_config
@ -469,12 +469,8 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = sa_t5.SAT5Tokenizer
else:
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
clip_target.clip = long_clipl.LongClipModel
clip_target.tokenizer = long_clipl.LongClipTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2:
if clip_type == CLIPType.SD3:
clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
@ -499,10 +495,12 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.tokenizer = sd3_clip.SD3Tokenizer
parameters = 0
tokenizer_data = {}
for c in clip_data:
parameters += utils.calculate_parameters(c)
tokenizer_data, model_options = long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters, model_options=model_options)
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
@ -548,14 +546,22 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None):
if te_model_options is None:
te_model_options = {}
if model_options is None:
model_options = {}
sd = utils.load_torch_file(ckpt_path)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, ckpt_path=ckpt_path)
if out is None:
raise RuntimeError("Could not detect model type of: {}".format(ckpt_path))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, ckpt_path=""):
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options=None, te_model_options=None, ckpt_path=""):
if te_model_options is None:
te_model_options = {}
if model_options is None:
model_options = {}
clip = None
clipvision = None
vae = None
@ -632,7 +638,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
return (_model_patcher, clip, vae, clipvision)
def load_diffusion_model_state_dict(sd, model_options: dict = None): # load unet in diffusers or regular format
def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: Optional[str]=""): # load unet in diffusers or regular format
if model_options is None:
model_options = {}
dtype = model_options.get("dtype", None)
@ -677,21 +683,21 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None): # load une
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
model_config.custom_operations = model_options.get("custom_operations", None)
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
model = model_config.get_model(new_sd, "")
model = model.to(offload_device)
model.load_model_weights(new_sd, "")
left_over = sd.keys()
if len(left_over) > 0:
logging.info("left over keys in unet: {}".format(left_over))
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device)
return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device, ckpt_name=os.path.basename(ckpt_path))
def load_diffusion_model(unet_path, model_options: dict = None):
if model_options is None:
model_options = {}
sd = utils.load_torch_file(unet_path)
model = load_diffusion_model_state_dict(sd, model_options=model_options)
model = load_diffusion_model_state_dict(sd, model_options=model_options, ckpt_path=unet_path)
if model is None:
logging.error("ERROR UNSUPPORTED UNET {}".format(unet_path))
raise RuntimeError("ERROR: Could not detect model type of: {}".format(unet_path))

View File

@ -657,6 +657,7 @@ class SD1Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
self.sd_tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text: str, return_word_ids=False):
@ -698,6 +699,7 @@ class SD1ClipModel(torch.nn.Module):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config, **kwargs))
self.dtypes = set()

View File

@ -26,8 +26,11 @@ class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
class SDXLTokenizer:
def __init__(self, embedding_directory=None, **kwargs):
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
def __init__(self, embedding_directory=None, tokenizer_data=None, **kwargs):
if tokenizer_data is None:
tokenizer_data = {}
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text: str, return_word_ids=False):
@ -50,9 +53,12 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
if model_options is None:
model_options = {}
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = {dtype}
@ -69,7 +75,8 @@ class SDXLClipModel(torch.nn.Module):
token_weight_pairs_l = token_weight_pairs["l"]
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
return torch.cat([l_out, g_out], dim=-1), g_pooled
cut_to = min(l_out.shape[1], g_out.shape[1])
return torch.cat([l_out[:,:cut_to], g_out[:,:cut_to]], dim=-1), g_pooled
def load_sd(self, sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@ -79,20 +86,29 @@ class SDXLClipModel(torch.nn.Module):
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, model_options={}):
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, model_options=None):
if model_options is None:
model_options = {}
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options, textmodel_json_config=textmodel_json_config)
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="g", tokenizer=StableCascadeClipGTokenizer)
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None, model_options={}):
textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")

View File

@ -28,13 +28,15 @@ class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
def tokenize_with_weights(self, text: str, return_word_ids=False):
out = {}
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
out = {
"l": self.clip_l.tokenize_with_weights(text, return_word_ids),
"t5xxl": self.t5xxl.tokenize_with_weights(text, return_word_ids)
}
return out
def untokenize(self, token_weight_pair):
@ -48,12 +50,15 @@ class FluxTokenizer:
class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options=None):
super().__init__()
if model_options is None:
model_options = {}
dtype_t5 = model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5])
self.dtypes = {dtype, dtype_t5}
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
@ -80,6 +85,9 @@ class FluxClipModel(torch.nn.Module):
def flux_clip(dtype_t5=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_

View File

@ -4,18 +4,41 @@ from ..component_model.files import get_path_as_dict
class LongClipTokenizer_(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, textmodel_json_config=None):
def __init__(self, *args, **kwargs):
kwargs = kwargs or {}
textmodel_json_config = kwargs.get("textmodel_json_config", None)
textmodel_json_config = get_path_as_dict(textmodel_json_config, "long_clipl.json", package=__package__)
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
if model_options is None:
model_options = {}
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
def model_options_long_clip(sd, tokenizer_data, model_options):
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
if w is None:
w = sd.get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
tokenizer_data = tokenizer_data.copy()
model_options = model_options.copy()
tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
model_options["clip_l_class"] = LongClipModel_
return tokenizer_data, model_options

View File

@ -28,7 +28,8 @@ class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = dict()
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory)
clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
@ -54,7 +55,8 @@ class SD3ClipModel(torch.nn.Module):
super().__init__()
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None
@ -107,7 +109,8 @@ class SD3ClipModel(torch.nn.Module):
if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
if lg_out is not None:
lg_out = torch.cat([lg_out, g_out], dim=-1)
cut_to = min(lg_out.shape[1], g_out.shape[1])
lg_out = torch.cat([lg_out[:, :cut_to], g_out[:, :cut_to]], dim=-1)
else:
lg_out = torch.nn.functional.pad(g_out, (768, 0))
else:
@ -145,6 +148,9 @@ class SD3ClipModel(torch.nn.Module):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_

View File

@ -95,6 +95,7 @@ def calculate_parameters(sd, prefix=""):
params += w.nelement()
return params
def weight_dtype(sd, prefix=""):
dtypes = {}
for k in sd.keys():
@ -107,6 +108,7 @@ def weight_dtype(sd, prefix=""):
return max(dtypes, key=dtypes.get)
def state_dict_key_replace(state_dict, keys_to_replace):
for x in keys_to_replace:
if x in state_dict:
@ -472,6 +474,7 @@ def auraflow_to_diffusers(mmdit_config, output_prefix=""):
return key_map
def flux_to_diffusers(mmdit_config, output_prefix=""):
n_double_layers = mmdit_config.get("depth", 0)
n_single_layers = mmdit_config.get("depth_single_blocks", 0)
@ -496,27 +499,27 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
key_map["{}add_v_proj.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
block_map = {
"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
"norm1.linear.weight": "img_mod.lin.weight",
"norm1.linear.bias": "img_mod.lin.bias",
"norm1_context.linear.weight": "txt_mod.lin.weight",
"norm1_context.linear.bias": "txt_mod.lin.bias",
"attn.to_add_out.weight": "txt_attn.proj.weight",
"attn.to_add_out.bias": "txt_attn.proj.bias",
"ff.net.0.proj.weight": "img_mlp.0.weight",
"ff.net.0.proj.bias": "img_mlp.0.bias",
"ff.net.2.weight": "img_mlp.2.weight",
"ff.net.2.bias": "img_mlp.2.bias",
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
}
"attn.to_out.0.weight": "img_attn.proj.weight",
"attn.to_out.0.bias": "img_attn.proj.bias",
"norm1.linear.weight": "img_mod.lin.weight",
"norm1.linear.bias": "img_mod.lin.bias",
"norm1_context.linear.weight": "txt_mod.lin.weight",
"norm1_context.linear.bias": "txt_mod.lin.bias",
"attn.to_add_out.weight": "txt_attn.proj.weight",
"attn.to_add_out.bias": "txt_attn.proj.bias",
"ff.net.0.proj.weight": "img_mlp.0.weight",
"ff.net.0.proj.bias": "img_mlp.0.bias",
"ff.net.2.weight": "img_mlp.2.weight",
"ff.net.2.bias": "img_mlp.2.bias",
"ff_context.net.0.proj.weight": "txt_mlp.0.weight",
"ff_context.net.0.proj.bias": "txt_mlp.0.bias",
"ff_context.net.2.weight": "txt_mlp.2.weight",
"ff_context.net.2.bias": "txt_mlp.2.bias",
"attn.norm_q.weight": "img_attn.norm.query_norm.scale",
"attn.norm_k.weight": "img_attn.norm.key_norm.scale",
"attn.norm_added_q.weight": "txt_attn.norm.query_norm.scale",
"attn.norm_added_k.weight": "txt_attn.norm.key_norm.scale",
}
for k in block_map:
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
@ -534,13 +537,13 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
key_map["{}.proj_mlp.{}".format(prefix_from, end)] = (qkv, (0, hidden_size * 3, hidden_size * 4))
block_map = {
"norm.linear.weight": "modulation.lin.weight",
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
}
"norm.linear.weight": "modulation.lin.weight",
"norm.linear.bias": "modulation.lin.bias",
"proj_out.weight": "linear2.weight",
"proj_out.bias": "linear2.bias",
"attn.norm_q.weight": "norm.query_norm.scale",
"attn.norm_k.weight": "norm.key_norm.scale",
}
for k in block_map:
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, block_map[k])
@ -763,7 +766,9 @@ def common_upscale(samples, width, height, upscale_method, crop):
def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
return math.ceil((height / (tile_y - overlap))) * math.ceil((width / (tile_x - overlap)))
rows = 1 if height <= tile_y else math.ceil((height - overlap) / (tile_y - overlap))
cols = 1 if width <= tile_x else math.ceil((width - overlap) / (tile_x - overlap))
return rows * cols
@torch.inference_mode()
@ -773,10 +778,19 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
for b in range(samples.shape[0]):
s = samples[b:b + 1]
# handle entire input fitting in a single tile
if all(s.shape[d + 2] <= tile[d] for d in range(dims)):
output[b:b + 1] = function(s).to(output_device)
if pbar is not None:
pbar.update(1)
continue
out = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
out_div = torch.zeros([s.shape[0], out_channels] + list(map(lambda a: round(a * upscale_amount), s.shape[2:])), device=output_device)
for it in itertools.product(*map(lambda a: range(0, a[0], a[1] - overlap), zip(s.shape[2:], tile))):
positions = [range(0, s.shape[d + 2], tile[d] - overlap) if s.shape[d + 2] > tile[d] else [0] for d in range(dims)]
for it in itertools.product(*positions):
s_in = s
upscaled = []
@ -785,15 +799,16 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
l = min(tile[d], s.shape[d + 2] - pos)
s_in = s_in.narrow(d + 2, pos, l)
upscaled.append(round(pos * upscale_amount))
ps = function(s_in).to(output_device)
mask = torch.ones_like(ps)
feather = round(overlap * upscale_amount)
for t in range(feather):
for d in range(2, dims + 2):
m = mask.narrow(d, t, 1)
m *= ((1.0 / feather) * (t + 1))
m = mask.narrow(d, mask.shape[d] - 1 - t, 1)
m *= ((1.0 / feather) * (t + 1))
a = (t + 1) / feather
mask.narrow(d, t, 1).mul_(a)
mask.narrow(d, mask.shape[d] - 1 - t, 1).mul_(a)
o = out
o_d = out_div
@ -801,8 +816,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
o = o.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o_d = o_d.narrow(d + 2, upscaled[d], mask.shape[d + 2])
o += ps * mask
o_d += mask
o.add_(ps * mask)
o_d.add_(mask)
if pbar is not None:
pbar.update(1)

1
comfy/web/assets/CREDIT.txt generated vendored Normal file
View File

@ -0,0 +1 @@
Thanks to OpenArt (https://openart.ai) for providing the sorted-custom-node-map data, captured in September 2024.

3142
comfy/web/assets/GraphView-DN9xGvF3.js generated vendored Normal file

File diff suppressed because one or more lines are too long

1
comfy/web/assets/GraphView-DN9xGvF3.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

158
comfy/web/assets/GraphView-DXU9yRen.css generated vendored Normal file
View File

@ -0,0 +1,158 @@
.group-title-editor.node-title-editor[data-v-fc3f26e3] {
z-index: 9999;
padding: 0.25rem;
}
[data-v-fc3f26e3] .editable-text {
width: 100%;
height: 100%;
}
[data-v-fc3f26e3] .editable-text input {
width: 100%;
height: 100%;
/* Override the default font size */
font-size: inherit;
}
.side-bar-button-icon {
font-size: var(--sidebar-icon-size) !important;
}
.side-bar-button-selected .side-bar-button-icon {
font-size: var(--sidebar-icon-size) !important;
font-weight: bold;
}
.side-bar-button[data-v-caa3ee9c] {
width: var(--sidebar-width);
height: var(--sidebar-width);
border-radius: 0;
}
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
.comfyui-body-left .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
border-left: 4px solid var(--p-button-text-primary-color);
}
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c],
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-caa3ee9c]:hover {
border-right: 4px solid var(--p-button-text-primary-color);
}
:root {
--sidebar-width: 64px;
--sidebar-icon-size: 1.5rem;
}
:root .small-sidebar {
--sidebar-width: 40px;
--sidebar-icon-size: 1rem;
}
.side-tool-bar-container[data-v-ed7a1148] {
display: flex;
flex-direction: column;
align-items: center;
pointer-events: auto;
width: var(--sidebar-width);
height: 100%;
background-color: var(--comfy-menu-bg);
color: var(--fg-color);
}
.side-tool-bar-end[data-v-ed7a1148] {
align-self: flex-end;
margin-top: auto;
}
.sidebar-content-container[data-v-ed7a1148] {
height: 100%;
overflow-y: auto;
}
.p-splitter-gutter {
pointer-events: auto;
}
.gutter-hidden {
display: none !important;
}
.side-bar-panel[data-v-edca8328] {
background-color: var(--bg-color);
pointer-events: auto;
}
.splitter-overlay[data-v-edca8328] {
width: 100%;
height: 100%;
position: absolute;
top: 0;
left: 0;
background-color: transparent;
pointer-events: none;
/* Set it the same as the ComfyUI menu */
/* Note: Lite-graph DOM widgets have the same z-index as the node id, so
999 should be sufficient to make sure splitter overlays on node's DOM
widgets */
z-index: 999;
border: none;
}
[data-v-37f672ab] .highlight {
background-color: var(--p-primary-color);
color: var(--p-primary-contrast-color);
font-weight: bold;
border-radius: 0.25rem;
padding: 0rem 0.125rem;
margin: -0.125rem 0.125rem;
}
.comfy-vue-node-search-container[data-v-2d409367] {
display: flex;
width: 100%;
min-width: 26rem;
align-items: center;
justify-content: center;
}
.comfy-vue-node-search-container[data-v-2d409367] * {
pointer-events: auto;
}
.comfy-vue-node-preview-container[data-v-2d409367] {
position: absolute;
left: -350px;
top: 50px;
}
.comfy-vue-node-search-box[data-v-2d409367] {
z-index: 10;
flex-grow: 1;
}
._filter-button[data-v-2d409367] {
z-index: 10;
}
._dialog[data-v-2d409367] {
min-width: 26rem;
}
.invisible-dialog-root {
width: 30%;
min-width: 24rem;
max-width: 48rem;
border: 0 !important;
background-color: transparent !important;
margin-top: 25vh;
}
.node-search-box-dialog-mask {
align-items: flex-start !important;
}
.node-tooltip[data-v-e0597bf9] {
background: var(--comfy-input-bg);
border-radius: 5px;
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
color: var(--input-text);
font-family: sans-serif;
left: 0;
max-width: 30vw;
padding: 4px 8px;
position: absolute;
top: 0;
transform: translate(5px, calc(-100% - 5px));
white-space: pre-wrap;
z-index: 99999;
}

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { C as ComfyDialog, $ as $el, a as ComfyApp, b as app, L as LGraphCanvas, c as LiteGraph, d as LGraphNode, e as applyTextReplacements, f as ComfyWidgets, g as addValueControlWidgets, D as DraggableList, h as api, i as LGraphGroup, u as useToastStore } from "./index-Dfv2aLsq.js";
import { aM as ComfyDialog, aN as $el, aO as ComfyApp, c as app, aH as LGraphCanvas, k as LiteGraph, e as LGraphNode, aP as applyTextReplacements, aQ as ComfyWidgets, aR as addValueControlWidgets, aS as DraggableList, av as useNodeDefStore, aT as api, L as LGraphGroup, aU as useToastStore, at as NodeSourceType, aV as NodeBadgeMode, u as useSettingStore, q as computed, w as watch, aW as BadgePosition, aJ as LGraphBadge$1, aX as _ } from "./index-Drc_oD2f.js";
class ClipspaceDialog extends ComfyDialog {
static {
__name(this, "ClipspaceDialog");
@ -213,7 +213,9 @@ const colorPalettes = {
WIDGET_SECONDARY_TEXT_COLOR: "#999",
LINK_COLOR: "#9A9",
EVENT_LINK_COLOR: "#A86",
CONNECTING_LINK_COLOR: "#AFA"
CONNECTING_LINK_COLOR: "#AFA",
BADGE_FG_COLOR: "#FFF",
BADGE_BG_COLOR: "#0F1F0F"
},
comfy_base: {
"fg-color": "#fff",
@ -283,7 +285,9 @@ const colorPalettes = {
WIDGET_SECONDARY_TEXT_COLOR: "#555",
LINK_COLOR: "#4CAF50",
EVENT_LINK_COLOR: "#FF9800",
CONNECTING_LINK_COLOR: "#2196F3"
CONNECTING_LINK_COLOR: "#2196F3",
BADGE_FG_COLOR: "#000",
BADGE_BG_COLOR: "#FFF"
},
comfy_base: {
"fg-color": "#222",
@ -621,6 +625,32 @@ const defaultColorPaletteId = "dark";
const els = {
select: null
};
const getCustomColorPalettes = /* @__PURE__ */ __name(() => {
return app.ui.settings.getSettingValue(idCustomColorPalettes, {});
}, "getCustomColorPalettes");
const setCustomColorPalettes = /* @__PURE__ */ __name((customColorPalettes) => {
return app.ui.settings.setSettingValue(
idCustomColorPalettes,
customColorPalettes
);
}, "setCustomColorPalettes");
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
const getColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
if (!colorPaletteId) {
colorPaletteId = app.ui.settings.getSettingValue(id$4, defaultColorPaletteId);
}
if (colorPaletteId.startsWith("custom_")) {
colorPaletteId = colorPaletteId.substr(7);
let customColorPalettes = getCustomColorPalettes();
if (customColorPalettes[colorPaletteId]) {
return customColorPalettes[colorPaletteId];
}
}
return colorPalettes[colorPaletteId];
}, "getColorPalette");
const setColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
app.ui.settings.setSettingValue(id$4, colorPaletteId);
}, "setColorPalette");
app.registerExtension({
name: id$4,
init() {
@ -695,28 +725,19 @@ app.registerExtension({
comfy_base: {}
}
};
const defaultColorPalette = colorPalettes[defaultColorPaletteId];
for (const key in defaultColorPalette.colors.litegraph_base) {
const defaultColorPalette2 = colorPalettes[defaultColorPaletteId];
for (const key in defaultColorPalette2.colors.litegraph_base) {
if (!colorPalette.colors.litegraph_base[key]) {
colorPalette.colors.litegraph_base[key] = "";
}
}
for (const key in defaultColorPalette.colors.comfy_base) {
for (const key in defaultColorPalette2.colors.comfy_base) {
if (!colorPalette.colors.comfy_base[key]) {
colorPalette.colors.comfy_base[key] = "";
}
}
return completeColorPalette(colorPalette);
}, "getColorPaletteTemplate");
const getCustomColorPalettes = /* @__PURE__ */ __name(() => {
return app.ui.settings.getSettingValue(idCustomColorPalettes, {});
}, "getCustomColorPalettes");
const setCustomColorPalettes = /* @__PURE__ */ __name((customColorPalettes) => {
return app.ui.settings.setSettingValue(
idCustomColorPalettes,
customColorPalettes
);
}, "setCustomColorPalettes");
const addCustomColorPalette = /* @__PURE__ */ __name(async (colorPalette) => {
if (typeof colorPalette !== "object") {
alert("Invalid color palette.");
@ -807,25 +828,6 @@ app.registerExtension({
app.canvas.draw(true, true);
}
}, "loadColorPalette");
const getColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
if (!colorPaletteId) {
colorPaletteId = app.ui.settings.getSettingValue(
id$4,
defaultColorPaletteId
);
}
if (colorPaletteId.startsWith("custom_")) {
colorPaletteId = colorPaletteId.substr(7);
let customColorPalettes = getCustomColorPalettes();
if (customColorPalettes[colorPaletteId]) {
return customColorPalettes[colorPaletteId];
}
}
return colorPalettes[colorPaletteId];
}, "getColorPalette");
const setColorPalette = /* @__PURE__ */ __name((colorPaletteId) => {
app.ui.settings.setSettingValue(id$4, colorPaletteId);
}, "setColorPalette");
const fileInput = $el("input", {
type: "file",
accept: ".json",
@ -994,6 +996,10 @@ app.registerExtension({
});
}
});
window.comfyAPI = window.comfyAPI || {};
window.comfyAPI.colorPalette = window.comfyAPI.colorPalette || {};
window.comfyAPI.colorPalette.defaultColorPalette = defaultColorPalette;
window.comfyAPI.colorPalette.getColorPalette = getColorPalette;
const ext$2 = {
name: "Comfy.ContextMenuFilter",
init() {
@ -1360,7 +1366,7 @@ class PrimitiveNode extends LGraphNode {
this.#mergeWidgetConfig();
}
}
onConnectionsChange(_, index, connected) {
onConnectionsChange(_2, index, connected) {
if (app.configuringGraph) {
return;
}
@ -1806,7 +1812,7 @@ app.registerExtension({
convertToInput(this, widget, config);
return true;
};
nodeType.prototype.getExtraMenuOptions = function(_, options) {
nodeType.prototype.getExtraMenuOptions = function(_2, options) {
const r = origGetExtraMenuOptions ? origGetExtraMenuOptions.apply(this, arguments) : void 0;
if (this.widgets) {
let toInput = [];
@ -1862,6 +1868,7 @@ app.registerExtension({
};
nodeType.prototype.onGraphConfigured = function() {
if (!this.inputs) return;
this.widgets ??= [];
for (const input of this.inputs) {
if (input.widget) {
if (!input.widget[GET_CONFIG]) {
@ -1919,7 +1926,7 @@ app.registerExtension({
return r;
};
function isNodeAtPos(pos) {
for (const n of app2.graph._nodes) {
for (const n of app2.graph.nodes) {
if (n.pos[0] === pos[0] && n.pos[1] === pos[1]) {
return true;
}
@ -2308,7 +2315,7 @@ class ManageGroupDialog extends ComfyDialog {
"button.comfy-btn",
{
onclick: /* @__PURE__ */ __name((e) => {
const node = app.graph._nodes.find(
const node = app.graph.nodes.find(
(n) => n.type === "workflow/" + this.selectedGroup
);
if (node) {
@ -2374,7 +2381,7 @@ class ManageGroupDialog extends ComfyDialog {
}
types[g] = type2;
if (!nodesByType) {
nodesByType = app.graph._nodes.reduce((p, n) => {
nodesByType = app.graph.nodes.reduce((p, n) => {
p[n.type] ??= [];
p[n.type].push(n);
return p;
@ -2424,7 +2431,7 @@ const Workflow = {
isInUseGroupNode(name) {
const id2 = `workflow/${name}`;
if (app.graph.extra?.groupNodes?.[name]) {
if (app.graph._nodes.find((n) => n.type === id2)) {
if (app.graph.nodes.find((n) => n.type === id2)) {
return Workflow.InUse.InWorkflow;
} else {
return Workflow.InUse.Registered;
@ -2576,6 +2583,8 @@ class GroupNodeConfig {
display_name: this.name,
category: "group nodes" + ("/" + source),
input: { required: {} },
description: `Group node combining ${this.nodeData.nodes.map((n) => n.type).join(", ")}`,
python_module: "custom_nodes." + this.name,
[GROUP]: this
};
this.inputs = [];
@ -2591,6 +2600,7 @@ class GroupNodeConfig {
}
this.#convertedToProcess = null;
await app.registerNodeDef("workflow/" + this.name, this.nodeDef);
useNodeDefStore().addNodeDef(this.nodeDef);
}
getLinks() {
this.linksFrom = {};
@ -2775,7 +2785,7 @@ class GroupNodeConfig {
checkPrimitiveConnection(link, inputName, inputs) {
const sourceNode = this.nodeData.nodes[link[0]];
if (sourceNode.type === "PrimitiveNode") {
const [sourceNodeId, _, targetNodeId, __] = link;
const [sourceNodeId, _2, targetNodeId, __] = link;
const primitiveDef = this.primitiveDefs[sourceNodeId];
const targetWidget = inputs[inputName];
const primitiveConfig = primitiveDef.input.required.value;
@ -3177,7 +3187,7 @@ class GroupNodeHandler {
return newNodes;
};
const getExtraMenuOptions = this.node.getExtraMenuOptions;
this.node.getExtraMenuOptions = function(_, options) {
this.node.getExtraMenuOptions = function(_2, options) {
getExtraMenuOptions?.apply(this, arguments);
let optionIndex = options.findIndex((o) => o.content === "Outputs");
if (optionIndex === -1) optionIndex = options.length;
@ -3353,7 +3363,7 @@ class GroupNodeHandler {
} else if (innerNode.type === "Reroute") {
const rerouteLinks = this.groupData.linksFrom[old.node.index];
if (rerouteLinks) {
for (const [_, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
for (const [_2, , targetNodeId, targetSlot] of rerouteLinks["0"]) {
const node = this.innerNodes[targetNodeId];
const input = node.inputs[targetSlot];
if (input.widget) {
@ -3599,7 +3609,7 @@ function addNodesToGroup(group, nodes = []) {
var node;
x1 = y1 = x2 = y2 = -1;
nx1 = ny1 = nx2 = ny2 = -1;
for (var n of [group._nodes, nodes]) {
for (var n of [group.nodes, nodes]) {
for (var i in n) {
node = n[i];
nx1 = node.pos[0];
@ -3659,7 +3669,7 @@ app.registerExtension({
return options;
}
group.recomputeInsideNodes();
const nodesInGroup = group._nodes;
const nodesInGroup = group.nodes;
options.push({
content: "Add Selected Nodes To Group",
disabled: !Object.keys(app.canvas.selected_nodes || {}).length,
@ -4002,6 +4012,16 @@ function prepare_mask(image, maskCanvas, maskCtx, maskColor) {
maskCtx.putImageData(maskData, 0, 0);
}
__name(prepare_mask, "prepare_mask");
var PointerType = /* @__PURE__ */ ((PointerType2) => {
PointerType2["Arc"] = "arc";
PointerType2["Rect"] = "rect";
return PointerType2;
})(PointerType || {});
var CompositionOperation = /* @__PURE__ */ ((CompositionOperation2) => {
CompositionOperation2["SourceOver"] = "source-over";
CompositionOperation2["DestinationOut"] = "destination-out";
return CompositionOperation2;
})(CompositionOperation || {});
class MaskEditorDialog extends ComfyDialog {
static {
__name(this, "MaskEditorDialog");
@ -4030,6 +4050,8 @@ class MaskEditorDialog extends ComfyDialog {
mousedown_pan_x;
mousedown_pan_y;
last_pressure;
pointer_type;
brush_pointer_type_select;
static getInstance() {
if (!MaskEditorDialog.instance) {
MaskEditorDialog.instance = new MaskEditorDialog();
@ -4077,7 +4099,7 @@ class MaskEditorDialog extends ComfyDialog {
divElement.style.borderColor = "var(--border-color)";
divElement.style.borderStyle = "solid";
divElement.style.fontSize = "15px";
divElement.style.height = "21px";
divElement.style.height = "25px";
divElement.style.padding = "1px 6px";
divElement.style.display = "flex";
divElement.style.position = "relative";
@ -4107,7 +4129,7 @@ class MaskEditorDialog extends ComfyDialog {
divElement.style.borderColor = "var(--border-color)";
divElement.style.borderStyle = "solid";
divElement.style.fontSize = "15px";
divElement.style.height = "21px";
divElement.style.height = "25px";
divElement.style.padding = "1px 6px";
divElement.style.display = "flex";
divElement.style.position = "relative";
@ -4126,8 +4148,63 @@ class MaskEditorDialog extends ComfyDialog {
self.opacity_slider_input.addEventListener("input", callback);
return divElement;
}
createPointerTypeSelect(self) {
const divElement = document.createElement("div");
divElement.id = "maskeditor-pointer-type";
divElement.style.cssFloat = "left";
divElement.style.fontFamily = "sans-serif";
divElement.style.marginRight = "4px";
divElement.style.color = "var(--input-text)";
divElement.style.backgroundColor = "var(--comfy-input-bg)";
divElement.style.borderRadius = "8px";
divElement.style.borderColor = "var(--border-color)";
divElement.style.borderStyle = "solid";
divElement.style.fontSize = "15px";
divElement.style.height = "25px";
divElement.style.padding = "1px 6px";
divElement.style.display = "flex";
divElement.style.position = "relative";
divElement.style.top = "2px";
divElement.style.pointerEvents = "auto";
const labelElement = document.createElement("label");
labelElement.textContent = "Pointer Type:";
const selectElement = document.createElement("select");
selectElement.style.borderRadius = "0";
selectElement.style.borderColor = "transparent";
selectElement.style.borderStyle = "unset";
selectElement.style.fontSize = "0.9em";
const optionArc = document.createElement("option");
optionArc.value = "arc";
optionArc.text = "Circle";
optionArc.selected = true;
const optionRect = document.createElement("option");
optionRect.value = "rect";
optionRect.text = "Square";
selectElement.appendChild(optionArc);
selectElement.appendChild(optionRect);
selectElement.addEventListener("change", (event) => {
const target = event.target;
self.pointer_type = target.value;
this.setBrushBorderRadius(self);
});
divElement.appendChild(labelElement);
divElement.appendChild(selectElement);
return divElement;
}
setBrushBorderRadius(self) {
if (self.pointer_type === "rect") {
this.brush.style.borderRadius = "0%";
this.brush.style.MozBorderRadius = "0%";
this.brush.style.WebkitBorderRadius = "0%";
} else {
this.brush.style.borderRadius = "50%";
this.brush.style.MozBorderRadius = "50%";
this.brush.style.WebkitBorderRadius = "50%";
}
}
setlayout(imgCanvas, maskCanvas) {
const self = this;
self.pointer_type = "arc";
var bottom_panel = document.createElement("div");
bottom_panel.style.position = "absolute";
bottom_panel.style.bottom = "0px";
@ -4140,13 +4217,11 @@ class MaskEditorDialog extends ComfyDialog {
brush.style.backgroundColor = "transparent";
brush.style.outline = "1px dashed black";
brush.style.boxShadow = "0 0 0 1px white";
brush.style.borderRadius = "50%";
brush.style.MozBorderRadius = "50%";
brush.style.WebkitBorderRadius = "50%";
brush.style.position = "absolute";
brush.style.zIndex = "8889";
brush.style.pointerEvents = "none";
this.brush = brush;
this.setBrushBorderRadius(self);
this.element.appendChild(imgCanvas);
this.element.appendChild(maskCanvas);
this.element.appendChild(bottom_panel);
@ -4177,6 +4252,7 @@ class MaskEditorDialog extends ComfyDialog {
}
}
);
this.brush_pointer_type_select = this.createPointerTypeSelect(self);
this.colorButton = this.createLeftButton(this.getColorButtonText(), () => {
if (self.brush_color_mode === "black") {
self.brush_color_mode = "white";
@ -4203,6 +4279,7 @@ class MaskEditorDialog extends ComfyDialog {
bottom_panel.appendChild(cancelButton);
bottom_panel.appendChild(this.brush_size_slider);
bottom_panel.appendChild(this.brush_opacity_slider);
bottom_panel.appendChild(this.brush_pointer_type_select);
bottom_panel.appendChild(this.colorButton);
imgCanvas.style.position = "absolute";
maskCanvas.style.position = "absolute";
@ -4568,19 +4645,22 @@ class MaskEditorDialog extends ComfyDialog {
}
if (diff > 20 && !this.drawing_mode)
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.fillStyle = this.getMaskFillStyle();
self.maskCtx.globalCompositeOperation = "source-over";
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
self.init_shape(
self,
"source-over"
/* SourceOver */
);
self.draw_shape(self, x, y, brush_size);
self.lastx = x;
self.lasty = y;
});
else
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.fillStyle = this.getMaskFillStyle();
self.maskCtx.globalCompositeOperation = "source-over";
self.init_shape(
self,
"source-over"
/* SourceOver */
);
var dx = x - self.lastx;
var dy = y - self.lasty;
var distance = Math.sqrt(dx * dx + dy * dy);
@ -4589,8 +4669,7 @@ class MaskEditorDialog extends ComfyDialog {
for (var i = 0; i < distance; i += 5) {
var px = self.lastx + directionX * i;
var py = self.lasty + directionY * i;
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
self.draw_shape(self, px, py, brush_size);
}
self.lastx = x;
self.lasty = y;
@ -4611,17 +4690,22 @@ class MaskEditorDialog extends ComfyDialog {
}
if (diff > 20 && !this.drawing_mode)
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.globalCompositeOperation = "destination-out";
self.maskCtx.arc(x2, y2, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
self.init_shape(
self,
"destination-out"
/* DestinationOut */
);
self.draw_shape(self, x2, y2, brush_size);
self.lastx = x2;
self.lasty = y2;
});
else
requestAnimationFrame(() => {
self.maskCtx.beginPath();
self.maskCtx.globalCompositeOperation = "destination-out";
self.init_shape(
self,
"destination-out"
/* DestinationOut */
);
var dx = x2 - self.lastx;
var dy = y2 - self.lasty;
var distance = Math.sqrt(dx * dx + dy * dy);
@ -4630,8 +4714,7 @@ class MaskEditorDialog extends ComfyDialog {
for (var i = 0; i < distance; i += 5) {
var px = self.lastx + directionX * i;
var py = self.lasty + directionY * i;
self.maskCtx.arc(px, py, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
self.draw_shape(self, px, py, brush_size);
}
self.lastx = x2;
self.lasty = y2;
@ -4665,20 +4748,47 @@ class MaskEditorDialog extends ComfyDialog {
const maskRect = self.maskCanvas.getBoundingClientRect();
const x = (event.offsetX || event.targetTouches[0].clientX - maskRect.left) / self.zoom_ratio;
const y = (event.offsetY || event.targetTouches[0].clientY - maskRect.top) / self.zoom_ratio;
self.maskCtx.beginPath();
if (!event.altKey && event.button == 0) {
self.maskCtx.fillStyle = this.getMaskFillStyle();
self.maskCtx.globalCompositeOperation = "source-over";
self.init_shape(
self,
"source-over"
/* SourceOver */
);
} else {
self.maskCtx.globalCompositeOperation = "destination-out";
self.init_shape(
self,
"destination-out"
/* DestinationOut */
);
}
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
self.maskCtx.fill();
self.draw_shape(self, x, y, brush_size);
self.lastx = x;
self.lasty = y;
self.lasttime = performance.now();
}
}
init_shape(self, compositionOperation) {
self.maskCtx.beginPath();
if (compositionOperation == "source-over") {
self.maskCtx.fillStyle = this.getMaskFillStyle();
self.maskCtx.globalCompositeOperation = "source-over";
} else if (compositionOperation == "destination-out") {
self.maskCtx.globalCompositeOperation = "destination-out";
}
}
draw_shape(self, x, y, brush_size) {
if (self.pointer_type === "rect") {
self.maskCtx.rect(
x - brush_size,
y - brush_size,
brush_size * 2,
brush_size * 2
);
} else {
self.maskCtx.arc(x, y, brush_size, 0, Math.PI * 2, false);
}
self.maskCtx.fill();
}
async save() {
const backupCanvas = document.createElement("canvas");
const backupCtx = backupCanvas.getContext("2d", {
@ -5264,7 +5374,7 @@ app.registerExtension({
updateNodes.push(node);
} else {
const nodeOutType = node.inputs && node.inputs[link?.target_slot] && node.inputs[link.target_slot].type ? node.inputs[link.target_slot].type : null;
if (inputType && inputType !== "*" && nodeOutType !== inputType) {
if (inputType && !LiteGraph.isValidConnection(inputType, nodeOutType)) {
node.disconnectInput(link.target_slot);
} else {
outputType = nodeOutType;
@ -5300,6 +5410,7 @@ app.registerExtension({
}
if (!targetWidget) {
targetWidget = targetNode.widgets?.find(
// @ts-expect-error fix widget types
(w) => w.name === targetInput.widget.name
);
}
@ -5342,7 +5453,7 @@ app.registerExtension({
};
this.isVirtualNode = true;
}
getExtraMenuOptions(_, options) {
getExtraMenuOptions(_2, options) {
options.unshift(
{
content: (this.properties.showOutputText ? "Hide" : "Show") + " Type",
@ -5564,8 +5675,7 @@ app.registerExtension({
slot_types_default_in: {},
async beforeRegisterNodeDef(nodeType, nodeData, app2) {
var nodeId = nodeData.name;
var inputs = [];
inputs = nodeData["input"]["required"];
const inputs = nodeData["input"]["required"];
for (const inputKey in inputs) {
var input = inputs[inputKey];
if (typeof input[0] !== "string") continue;
@ -5637,7 +5747,7 @@ app.registerExtension({
tooltip: "When dragging and resizing nodes while holding shift they will be aligned to the grid, this controls the size of that grid.",
defaultValue: LiteGraph.CANVAS_GRID_SIZE,
onChange(value) {
LiteGraph.CANVAS_GRID_SIZE = +value;
LiteGraph.CANVAS_GRID_SIZE = +value || 10;
}
});
const onNodeMoved = app.canvas.onNodeMoved;
@ -5697,7 +5807,7 @@ app.registerExtension({
}
if (app.canvas.last_mouse_dragging === false && app.shiftDown) {
this.recomputeInsideNodes();
for (const node of this._nodes) {
for (const node of this.nodes) {
node.alignToGrid();
}
LGraphNode.prototype.alignToGrid.apply(this);
@ -5730,7 +5840,7 @@ app.registerExtension({
LGraphCanvas.onGroupAdd = function() {
const v = onGroupAdd.apply(app.canvas, arguments);
if (app.shiftDown) {
const lastGroup = app.graph._groups[app.graph._groups.length - 1];
const lastGroup = app.graph.groups[app.graph.groups.length - 1];
if (lastGroup) {
roundVectorToGrid(lastGroup.pos);
roundVectorToGrid(lastGroup.size);
@ -6026,4 +6136,108 @@ app.registerExtension({
};
}
});
//# sourceMappingURL=index-CrROdkG4.js.map
function getNodeSource(node) {
const nodeDef = node.constructor.nodeData;
if (!nodeDef) {
return null;
}
const nodeDefStore = useNodeDefStore();
return nodeDefStore.nodeDefsByName[nodeDef.name]?.nodeSource ?? null;
}
__name(getNodeSource, "getNodeSource");
function isCoreNode(node) {
return getNodeSource(node)?.type === NodeSourceType.Core;
}
__name(isCoreNode, "isCoreNode");
function badgeTextVisible(node, badgeMode) {
return badgeMode === NodeBadgeMode.None || isCoreNode(node) && badgeMode === NodeBadgeMode.HideBuiltIn;
}
__name(badgeTextVisible, "badgeTextVisible");
function getNodeIdBadgeText(node, nodeIdBadgeMode) {
return badgeTextVisible(node, nodeIdBadgeMode) ? "" : `#${node.id}`;
}
__name(getNodeIdBadgeText, "getNodeIdBadgeText");
function getNodeSourceBadgeText(node, nodeSourceBadgeMode) {
const nodeSource = getNodeSource(node);
return badgeTextVisible(node, nodeSourceBadgeMode) ? "" : nodeSource?.badgeText ?? "";
}
__name(getNodeSourceBadgeText, "getNodeSourceBadgeText");
function getNodeLifeCycleBadgeText(node, nodeLifeCycleBadgeMode) {
let text = "";
const nodeDef = node.constructor.nodeData;
if (!nodeDef) {
return "";
}
if (nodeDef.deprecated) {
text = "[DEPR]";
}
if (nodeDef.experimental) {
text = "[BETA]";
}
return badgeTextVisible(node, nodeLifeCycleBadgeMode) ? "" : text;
}
__name(getNodeLifeCycleBadgeText, "getNodeLifeCycleBadgeText");
class NodeBadgeExtension {
static {
__name(this, "NodeBadgeExtension");
}
constructor(nodeIdBadgeMode = null, nodeSourceBadgeMode = null, nodeLifeCycleBadgeMode = null, colorPalette = null) {
this.nodeIdBadgeMode = nodeIdBadgeMode;
this.nodeSourceBadgeMode = nodeSourceBadgeMode;
this.nodeLifeCycleBadgeMode = nodeLifeCycleBadgeMode;
this.colorPalette = colorPalette;
}
name = "Comfy.NodeBadge";
init(app2) {
const settingStore = useSettingStore();
this.nodeSourceBadgeMode = computed(
() => settingStore.get("Comfy.NodeBadge.NodeSourceBadgeMode")
);
this.nodeIdBadgeMode = computed(
() => settingStore.get("Comfy.NodeBadge.NodeIdBadgeMode")
);
this.nodeLifeCycleBadgeMode = computed(
() => settingStore.get(
"Comfy.NodeBadge.NodeLifeCycleBadgeMode"
)
);
this.colorPalette = computed(
() => getColorPalette(settingStore.get("Comfy.ColorPalette"))
);
watch(this.nodeSourceBadgeMode, () => {
app2.graph.setDirtyCanvas(true, true);
});
watch(this.nodeIdBadgeMode, () => {
app2.graph.setDirtyCanvas(true, true);
});
watch(this.nodeLifeCycleBadgeMode, () => {
app2.graph.setDirtyCanvas(true, true);
});
}
nodeCreated(node, app2) {
node.badgePosition = BadgePosition.TopRight;
node.badge_enabled = true;
const badge = computed(
() => new LGraphBadge$1({
text: _.truncate(
[
getNodeIdBadgeText(node, this.nodeIdBadgeMode.value),
getNodeLifeCycleBadgeText(
node,
this.nodeLifeCycleBadgeMode.value
),
getNodeSourceBadgeText(node, this.nodeSourceBadgeMode.value)
].filter((s) => s.length > 0).join(" "),
{
length: 31
}
),
fgColor: this.colorPalette.value.colors.litegraph_base?.BADGE_FG_COLOR || defaultColorPalette.colors.litegraph_base.BADGE_FG_COLOR,
bgColor: this.colorPalette.value.colors.litegraph_base?.BADGE_BG_COLOR || defaultColorPalette.colors.litegraph_base.BADGE_BG_COLOR
})
);
node.badges.push(() => badge.value);
}
}
app.registerExtension(new NodeBadgeExtension());
//# sourceMappingURL=index-BDBCRrlL.js.map

1
comfy/web/assets/index-BDBCRrlL.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

1
comfy/web/assets/index-Drc_oD2f.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

2602
comfy/web/assets/sorted-custom-node-map.json generated vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
var __defProp = Object.defineProperty;
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
import { j as createSpinner, h as api, $ as $el } from "./index-Dfv2aLsq.js";
import { aY as createSpinner, aT as api, aN as $el } from "./index-Drc_oD2f.js";
class UserSelectionScreen {
static {
__name(this, "UserSelectionScreen");
@ -117,4 +117,4 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
export {
UserSelectionScreen
};
//# sourceMappingURL=userSelection-DSpF-zVD.js.map
//# sourceMappingURL=userSelection-BM5u5JIA.js.map

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,3 @@
// Shim for extensions/core/colorPalette.ts
export const defaultColorPalette = window.comfyAPI.colorPalette.defaultColorPalette;
export const getColorPalette = window.comfyAPI.colorPalette.getColorPalette;

94
comfy/web/index.html vendored
View File

@ -1,50 +1,50 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>ComfyUI</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
<!-- Browser Test Fonts -->
<!-- <link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
<link href="https://fonts.googleapis.com/css2?family=Noto+Color+Emoji&family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
<style>
* {
font-family: 'Roboto Mono', 'Noto Color Emoji';
}
</style> -->
<link rel="stylesheet" type="text/css" href="user.css" />
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
<script type="module" crossorigin src="./assets/index-Dfv2aLsq.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-W4jP-SrU.css">
</head>
<body class="litegraph">
<div id="vue-app"></div>
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
<main class="comfy-user-selection-inner">
<h1>ComfyUI</h1>
<form>
<section>
<label>New user:
<input placeholder="Enter a username" />
</label>
</section>
<div class="comfy-user-existing">
<span class="or-separator">OR</span>
<section>
<label>
Existing user:
<select>
<option hidden disabled selected value> Select a user </option>
</select>
</label>
</section>
</div>
<footer>
<span class="comfy-user-error">&nbsp;</span>
<button class="comfy-btn comfy-user-button-next">Next</button>
</footer>
</form>
</main>
</div>
</body>
<head>
<meta charset="UTF-8">
<title>ComfyUI</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
<!-- Browser Test Fonts -->
<!-- <link href="https://fonts.googleapis.com/css2?family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
<link href="https://fonts.googleapis.com/css2?family=Noto+Color+Emoji&family=Roboto+Mono:ital,wght@0,100..700;1,100..700&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap" rel="stylesheet">
<style>
* {
font-family: 'Roboto Mono', 'Noto Color Emoji';
}
</style> -->
<link rel="stylesheet" type="text/css" href="user.css" />
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
<script type="module" crossorigin src="./assets/index-Drc_oD2f.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-8NH3XvqK.css">
</head>
<body class="litegraph">
<div id="vue-app"></div>
<div id="comfy-user-selection" class="comfy-user-selection" style="display: none;">
<main class="comfy-user-selection-inner">
<h1>ComfyUI</h1>
<form>
<section>
<label>New user:
<input placeholder="Enter a username" />
</label>
</section>
<div class="comfy-user-existing">
<span class="or-separator">OR</span>
<section>
<label>
Existing user:
<select>
<option hidden disabled selected value> Select a user </option>
</select>
</label>
</section>
</div>
<footer>
<span class="comfy-user-error">&nbsp;</span>
<button class="comfy-btn comfy-user-button-next">Next</button>
</footer>
</form>
</main>
</div>
</body>
</html>

View File

@ -1,4 +1,3 @@
// Shim for scripts/workflows.ts
export const trimJsonExt = window.comfyAPI.workflows.trimJsonExt;
export const ComfyWorkflowManager = window.comfyAPI.workflows.ComfyWorkflowManager;
export const ComfyWorkflow = window.comfyAPI.workflows.ComfyWorkflow;

View File

@ -73,6 +73,9 @@ class VAEDecodeAudio:
def decode(self, vae, samples):
audio = vae.decode(samples["samples"]).movedim(-1, 1)
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
std[std < 1.0] = 1.0
audio /= std
return ({"waveform": audio, "sample_rate": 44100},)

View File

@ -1,9 +1,12 @@
import comfy.utils
from comfy.cldm.control_types import UNION_CONTROLNET_TYPES
from comfy.nodes.base_nodes import ControlNetApplyAdvanced
class SetUnionControlNetType:
@classmethod
def INPUT_TYPES(s):
return {"required": {"control_net": ("CONTROL_NET", ),
return {"required": {"control_net": ("CONTROL_NET",),
"type": (["auto"] + list(UNION_CONTROLNET_TYPES.keys()),)
}}
@ -22,6 +25,37 @@ class SetUnionControlNetType:
return (control_net,)
class ControlNetInpaintingAliMamaApply(ControlNetApplyAdvanced):
@classmethod
def INPUT_TYPES(s):
return {"required": {"positive": ("CONDITIONING",),
"negative": ("CONDITIONING",),
"control_net": ("CONTROL_NET",),
"vae": ("VAE",),
"image": ("IMAGE",),
"mask": ("MASK",),
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
FUNCTION = "apply_inpaint_controlnet"
CATEGORY = "conditioning/controlnet"
def apply_inpaint_controlnet(self, positive, negative, control_net, vae, image, mask, strength, start_percent, end_percent):
extra_concat = []
if control_net.concat_mask:
mask = 1.0 - mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
mask_apply = comfy.utils.common_upscale(mask, image.shape[2], image.shape[1], "bilinear", "center").round()
image = image * mask_apply.movedim(1, -1).repeat(1, 1, 1, image.shape[3])
extra_concat = [mask]
return self.apply_controlnet(positive, negative, control_net, image, strength, start_percent, end_percent, vae=vae, extra_concat=extra_concat)
NODE_CLASS_MAPPINGS = {
"SetUnionControlNetType": SetUnionControlNetType,
"ControlNetInpaintingAliMamaApply": ControlNetInpaintingAliMamaApply,
}

View File

@ -94,6 +94,27 @@ class PolyexponentialScheduler:
sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho)
return (sigmas, )
class LaplaceScheduler:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
"sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}),
"mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}),
"beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}),
}
}
RETURN_TYPES = ("SIGMAS",)
CATEGORY = "sampling/custom_sampling/schedulers"
FUNCTION = "get_sigmas"
def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta):
sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta)
return (sigmas, )
class SDTurboScheduler:
@classmethod
def INPUT_TYPES(s):
@ -677,6 +698,7 @@ NODE_CLASS_MAPPINGS = {
"KarrasScheduler": KarrasScheduler,
"ExponentialScheduler": ExponentialScheduler,
"PolyexponentialScheduler": PolyexponentialScheduler,
"LaplaceScheduler": LaplaceScheduler,
"VPScheduler": VPScheduler,
"BetaSamplingScheduler": BetaSamplingScheduler,
"SDTurboScheduler": SDTurboScheduler,

View File

@ -107,7 +107,7 @@ class HypernetworkLoader:
CATEGORY = "loaders"
def load_hypernetwork(self, model, hypernetwork_name, strength):
hypernetwork_path = folder_paths.get_full_path("hypernetworks", hypernetwork_name)
hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name)
model_hypernetwork = model.clone()
patch = load_hypernetwork_patch(hypernetwork_path, strength)
if patch is not None:

View File

@ -1,24 +1,26 @@
from comfy.nodes.common import MAX_RESOLUTION
from comfy.cmd import folder_paths
from comfy.cli_args import args
import json
import os
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
import numpy as np
import json
import os
from comfy.cli_args import args
from comfy.cmd import folder_paths
from comfy.nodes.common import MAX_RESOLUTION
from comfy.utils import tensor2pil
class ImageCrop:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}}
return {"required": {"image": ("IMAGE",),
"width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}),
"x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
"y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "crop"
@ -29,31 +31,35 @@ class ImageCrop:
y = min(y, image.shape[1] - 1)
to_x = width + x
to_y = height + y
img = image[:,y:to_y, x:to_x, :]
img = image[:, y:to_y, x:to_x, :]
return (img,)
class RepeatImageBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
return {"required": {"image": ("IMAGE",),
"amount": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "repeat"
CATEGORY = "image/batch"
def repeat(self, image, amount):
s = image.repeat((amount, 1,1,1))
s = image.repeat((amount, 1, 1, 1))
return (s,)
class ImageFromBatch:
@classmethod
def INPUT_TYPES(s):
return {"required": { "image": ("IMAGE",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
"length": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
return {"required": {"image": ("IMAGE",),
"batch_index": ("INT", {"default": 0, "min": 0, "max": 4095}),
"length": ("INT", {"default": 1, "min": 1, "max": 4096}),
}}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "frombatch"
@ -66,6 +72,7 @@ class ImageFromBatch:
s = s_in[batch_index:batch_index + length].clone()
return (s,)
class SaveAnimatedWEBP:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@ -73,10 +80,11 @@ class SaveAnimatedWEBP:
self.prefix_append = ""
methods = {"default": 4, "fastest": 0, "slowest": 6}
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
{"images": ("IMAGE",),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"lossless": ("BOOLEAN", {"default": True}),
@ -121,7 +129,7 @@ class SaveAnimatedWEBP:
c = len(pil_images)
for i in range(0, c, num_frames):
file = f"{filename}_{counter:05}_.webp"
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0/fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
pil_images[i].save(os.path.join(full_output_folder, file), save_all=True, duration=int(1000.0 / fps), append_images=pil_images[i + 1:i + num_frames], exif=metadata, lossless=lossless, quality=quality, method=method)
results.append({
"filename": file,
"subfolder": subfolder,
@ -130,7 +138,8 @@ class SaveAnimatedWEBP:
counter += 1
animated = num_frames != 1
return { "ui": { "images": results, "animated": (animated,) } }
return {"ui": {"images": results, "animated": (animated,)}}
class SaveAnimatedPNG:
def __init__(self):
@ -141,7 +150,7 @@ class SaveAnimatedPNG:
@classmethod
def INPUT_TYPES(s):
return {"required":
{"images": ("IMAGE", ),
{"images": ("IMAGE",),
"filename_prefix": ("STRING", {"default": "ComfyUI"}),
"fps": ("FLOAT", {"default": 6.0, "min": 0.01, "max": 1000.0, "step": 0.01}),
"compress_level": ("INT", {"default": 4, "min": 0, "max": 9})
@ -176,16 +185,63 @@ class SaveAnimatedPNG:
metadata.add(b"comf", x.encode("latin-1", "strict") + b"\0" + json.dumps(extra_pnginfo[x]).encode("latin-1", "strict"), after_idat=True)
file = f"{filename}_{counter:05}_.png"
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0/fps), append_images=pil_images[1:])
pil_images[0].save(os.path.join(full_output_folder, file), pnginfo=metadata, compress_level=compress_level, save_all=True, duration=int(1000.0 / fps), append_images=pil_images[1:])
results.append({
"filename": file,
"subfolder": subfolder,
"type": self.type
})
return { "ui": { "images": results, "animated": (True,)} }
return {"ui": {"images": results, "animated": (True,)}}
class ImageSizeToNumber:
"""
By WASasquatch (Discord: WAS#0263)
Copyright 2023 Jordan Thompson (WASasquatch)
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the Software), to
deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense,
and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED AS IS, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
}
}
RETURN_TYPES = ("*", "*", "FLOAT", "FLOAT", "INT", "INT")
RETURN_NAMES = ("width_num", "height_num", "width_float", "height_float", "width_int", "height_int")
FUNCTION = "image_width_height"
CATEGORY = "image/operations"
def image_width_height(self, image):
image = tensor2pil(image)
if image.size:
return (
image.size[0], image.size[1], float(image.size[0]), float(image.size[1]), image.size[0], image.size[1])
return 0, 0, 0, 0, 0, 0
NODE_CLASS_MAPPINGS = {
# From WAS Node Suite
# Class mapping is kept for compatibility
"Image Size to Number": ImageSizeToNumber,
"ImageCrop": ImageCrop,
"RepeatImageBatch": RepeatImageBatch,
"ImageFromBatch": ImageFromBatch,

View File

@ -126,7 +126,7 @@ class PhotoMakerLoader:
CATEGORY = "_for_testing/photomaker"
def load_photomaker_model(self, photomaker_model_name):
photomaker_model_path = folder_paths.get_full_path("photomaker", photomaker_model_name)
photomaker_model_path = folder_paths.get_full_path_or_raise("photomaker", photomaker_model_name)
photomaker_model = PhotoMakerIDEncoder()
data = utils.load_torch_file(photomaker_model_path, safe_load=True)
if "id_encoder" in data:

View File

@ -42,7 +42,7 @@ class EmptySD3LatentImage:
CATEGORY = "latent/sd3"
def generate(self, width, height, batch_size=1):
latent = torch.ones([batch_size, 16, height // 8, width // 8], device=self.device) * 0.0609
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples": latent},)
@ -101,6 +101,7 @@ class ControlNetApplySD3(nodes.ControlNetApplyAdvanced):
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.001})
}}
CATEGORY = "conditioning/controlnet"
DEPRECATED = True
NODE_CLASS_MAPPINGS = {
"TripleCLIPLoader": TripleCLIPLoader,
@ -111,5 +112,5 @@ NODE_CLASS_MAPPINGS = {
NODE_DISPLAY_NAME_MAPPINGS = {
# Sampling
"ControlNetApplySD3": "ControlNetApply SD3 and HunyuanDiT",
"ControlNetApplySD3": "Apply Controlnet with VAE",
}

View File

@ -9,9 +9,9 @@ from comfy.cmd.folder_paths import filter_files_content_types
@pytest.fixture(scope="module")
def file_extensions():
return {
'image': ['bmp', 'cdr', 'gif', 'heif', 'ico', 'jpeg', 'jpg', 'pcx', 'png', 'pnm', 'ppm', 'psd', 'sgi', 'svg', 'tiff', 'webp', 'xbm', 'xcf', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'awb', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'sd2', 'smp', 'snd', 'wav'],
'video': ['avi', 'flv', 'm2v', 'm4v', 'mj2', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
'image': ['gif', 'heif', 'ico', 'jpeg', 'jpg', 'png', 'pnm', 'ppm', 'svg', 'tiff', 'webp', 'xbm', 'xpm'],
'audio': ['aif', 'aifc', 'aiff', 'au', 'flac', 'm4a', 'mp2', 'mp3', 'ogg', 'snd', 'wav'],
'video': ['avi', 'm2v', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ogv', 'qt', 'webm', 'wmv']
}

View File

@ -82,14 +82,16 @@ def test_load_extra_model_paths_expands_userpath(
load_extra_path_config(dummy_yaml_file_name)
expected_calls = [
('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1')),
('checkpoints', os.path.join(mock_expanded_home, 'App', 'subfolder1'), False),
]
assert mock_add_model_folder_path.call_count == len(expected_calls)
# Check if add_model_folder_path was called with the correct arguments
for actual_call, expected_call in zip(mock_add_model_folder_path.call_args_list, expected_calls):
assert actual_call.args == expected_call
assert actual_call.args[0] == expected_call[0]
assert os.path.normpath(actual_call.args[1]) == os.path.normpath(expected_call[1]) # Normalize and check the path to check on multiple OS.
assert actual_call.args[2] == expected_call[2]
# Check if yaml.safe_load was called
mock_yaml_safe_load.assert_called_once()
@ -123,7 +125,7 @@ def test_load_extra_model_paths_expands_appdata(
expected_base_path = 'C:/Users/TestUser/AppData/Roaming/ComfyUI'
expected_calls = [
('checkpoints', os.path.join(expected_base_path, 'models/checkpoints')),
('checkpoints', os.path.join(expected_base_path, 'models/checkpoints'), False),
]
assert mock_add_model_folder_path.call_count == len(expected_calls)