Merge WIP

doctorpangloss 2024-08-25 18:52:29 -07:00
commit 5155a3e248
64 changed files with 33923 additions and 22018 deletions

View File

@ -0,0 +1,2 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
pause

View File

@ -1,5 +1,8 @@
blank_issues_enabled: true
contact_links:
- name: ComfyUI Frontend Issues
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs). Please file these in the frontend repository.
- name: ComfyUI Matrix Space
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).

View File

@ -598,6 +598,7 @@ The default installation includes a fast latent preview method that's low-resolu
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
| Ctrl + Enter | Queue up current graph for generation |
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
| Ctrl + Alt + Enter | Cancel current generation |
| Ctrl + Z/Ctrl + Y | Undo/Redo |
| Ctrl + S | Save workflow |
| Ctrl + O | Load workflow |
@ -620,6 +621,8 @@ The default installation includes a fast latent preview method that's low-resolu
| H | Toggle visibility of history |
| R | Refresh graph |
| Double-Click LMB | Open node quick search palette |
| Shift + Drag | Move multiple wires at once |
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
Ctrl can be replaced with Cmd for macOS users
@ -1013,6 +1016,47 @@ To run:
docker run -it -v ./output:/workspace/output -v ./models:/workspace/models --gpus=all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --rm hiddenswitch/comfyui
```
## Frontend Development
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). The main ComfyUI repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
### Reporting Issues and Requesting Features
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
### Using the Latest Frontend
The new frontend is now the default for ComfyUI. However, please note:
1. The frontend in the main ComfyUI repository is updated weekly.
2. Daily releases are available in the separate frontend repository.
To use the most up-to-date frontend version:
1. For the latest daily release, launch ComfyUI with this command line argument:
```
--front-end-version Comfy-Org/ComfyUI_frontend@latest
```
2. For a specific version, replace `latest` with the desired version number:
```
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
```
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
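For example, assuming you launch ComfyUI via `main.py`, the flag is passed alongside your other launch arguments:
```
python main.py --front-end-version Comfy-Org/ComfyUI_frontend@latest
```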
### Accessing the Legacy Frontend
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
```
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
```
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
## Community
[Chat on Matrix: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org), an alternative to Discord.

api_server/__init__.py Normal file
View File

View File

View File

@ -0,0 +1,3 @@
# ComfyUI Internal Routes
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications and may change at any time without notice.
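As an illustration only (not a supported interface), the `/internal/files` route can be queried like any other HTTP endpoint. This sketch assumes a server running locally on the default port 8188:
```
# Illustration only: the /internal API is unstable and may change without notice.
# Assumes a ComfyUI server listening on the default port 8188.
import json
from urllib.request import urlopen

with urlopen("http://127.0.0.1:8188/internal/files?directory=models") as resp:
    print(json.load(resp))  # {"files": [...]} on success; errors come back as {"error": ...} with a 400/500 status
```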

View File

View File

@ -0,0 +1,40 @@
from aiohttp import web
from typing import Optional
from folder_paths import models_dir, user_directory, output_directory
from api_server.services.file_service import FileService

class InternalRoutes:
    '''
    The top level web router for internal routes: /internal/*
    The endpoints here should NOT be depended upon. They are for ComfyUI frontend use only.
    Check README.md for more information.
    '''

    def __init__(self):
        self.routes: web.RouteTableDef = web.RouteTableDef()
        self._app: Optional[web.Application] = None
        self.file_service = FileService({
            "models": models_dir,
            "user": user_directory,
            "output": output_directory
        })

    def setup_routes(self):
        @self.routes.get('/files')
        async def list_files(request):
            directory_key = request.query.get('directory', '')
            try:
                file_list = self.file_service.list_files(directory_key)
                return web.json_response({"files": file_list})
            except ValueError as e:
                return web.json_response({"error": str(e)}, status=400)
            except Exception as e:
                return web.json_response({"error": str(e)}, status=500)

    def get_app(self):
        if self._app is None:
            self._app = web.Application()
            self.setup_routes()
            self._app.add_routes(self.routes)
        return self._app

View File

View File

@ -0,0 +1,13 @@
from typing import Dict, List, Optional
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem

class FileService:
    def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
        self.allowed_directories: Dict[str, str] = allowed_directories
        self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()

    def list_files(self, directory_key: str) -> List[FileSystemItem]:
        if directory_key not in self.allowed_directories:
            raise ValueError("Invalid directory key")
        directory_path: str = self.allowed_directories[directory_key]
        return self.file_system_ops.walk_directory(directory_path)

View File

@ -0,0 +1,42 @@
import os
from typing import List, Union, TypedDict, Literal
from typing_extensions import TypeGuard

class FileInfo(TypedDict):
    name: str
    path: str
    type: Literal["file"]
    size: int

class DirectoryInfo(TypedDict):
    name: str
    path: str
    type: Literal["directory"]

FileSystemItem = Union[FileInfo, DirectoryInfo]

def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
    return item["type"] == "file"

class FileSystemOperations:
    @staticmethod
    def walk_directory(directory: str) -> List[FileSystemItem]:
        file_list: List[FileSystemItem] = []
        for root, dirs, files in os.walk(directory):
            for name in files:
                file_path = os.path.join(root, name)
                relative_path = os.path.relpath(file_path, directory)
                file_list.append({
                    "name": name,
                    "path": relative_path,
                    "type": "file",
                    "size": os.path.getsize(file_path)
                })
            for name in dirs:
                dir_path = os.path.join(root, name)
                relative_path = os.path.relpath(dir_path, directory)
                file_list.append({
                    "name": name,
                    "path": relative_path,
                    "type": "directory"
                })
        return file_list
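A usage sketch of the helpers above (the `"models"` key mirrors the mapping wired up in `InternalRoutes`; treat the exact keys as an implementation detail):
```
# Sketch: list everything under the models directory via the new FileService.
from api_server.services.file_service import FileService
from folder_paths import models_dir

service = FileService({"models": models_dir})
for item in service.list_files("models"):
    if item["type"] == "file":
        print(item["path"], item["size"])   # paths are relative to models_dir
    else:
        print(item["path"] + "/")           # directories carry no size field
```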

View File

@ -59,6 +59,8 @@ class CacheKeySetID(CacheKeySet):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = (node_id, node["class_type"])
self.subcache_keys[node_id] = (node_id, node["class_type"])
@ -78,6 +80,8 @@ class CacheKeySetInputSignature(CacheKeySet):
for node_id in node_ids:
if node_id in self.keys:
continue
if not self.dynprompt.has_node(node_id):
continue
node = self.dynprompt.get_node(node_id)
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
self.subcache_keys[node_id] = (node_id, node["class_type"])
@ -91,6 +95,9 @@ class CacheKeySetInputSignature(CacheKeySet):
return to_hashable(signature)
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
if not dynprompt.has_node(node_id):
# This node doesn't exist -- we can't cache it.
return [float("NaN")]
node = dynprompt.get_node(node_id)
class_type = node["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
@ -116,6 +123,8 @@ class CacheKeySetInputSignature(CacheKeySet):
return ancestors, order_mapping
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
if not dynprompt.has_node(node_id):
return
inputs = dynprompt.get_node(node_id)["inputs"]
input_keys = sorted(inputs.keys())
for key in input_keys:

View File

@ -113,11 +113,13 @@ def _create_parser() -> EnhancedConfigArgParser:
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true",
help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true",
help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI. Raises an error if nodes cannot be imported,")

View File

@ -81,6 +81,8 @@ class Configuration(dict):
lowvram (bool): Reduce UNet's VRAM usage.
novram (bool): Minimize VRAM usage.
cpu (bool): Use CPU for processing.
fast (bool): Enable some untested and potentially quality deteriorating optimizations.
reserve_vram (Optional[float]): Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.
disable_smart_memory (bool): Disable smart memory management.
deterministic (bool): Use deterministic algorithms where possible.
dont_print_server (bool): Suppress server output.
@ -157,6 +159,8 @@ class Configuration(dict):
self.lowvram: bool = False
self.novram: bool = False
self.cpu: bool = False
self.fast: bool = False
self.reserve_vram: Optional[float] = None
self.disable_smart_memory: bool = False
self.deterministic: bool = False
self.dont_print_server: bool = False

View File

@ -89,10 +89,11 @@ class CLIPTextModel_(torch.nn.Module):
intermediate_size = config_dict["intermediate_size"]
intermediate_activation = config_dict["hidden_act"]
vocab_size = config_dict["vocab_size"]
num_positions = config_dict["max_position_embeddings"]
self.eos_token_id = config_dict["eos_token_id"]
super().__init__()
self.embeddings = CLIPEmbeddings(embed_dim, vocab_size=vocab_size, dtype=dtype, device=device, operations=operations)
self.embeddings = CLIPEmbeddings(embed_dim, vocab_size=vocab_size, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
@ -124,7 +125,6 @@ class CLIPTextModel(torch.nn.Module):
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
embed_dim = config_dict["hidden_size"]
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
self.text_projection.weight.copy_(torch.eye(embed_dim))
self.dtype = dtype
def get_input_embeddings(self):

View File

@ -61,7 +61,8 @@ class IsChangedCache:
self.is_changed[node_id] = node["is_changed"]
return self.is_changed[node_id]
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
try:
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
@ -567,6 +568,7 @@ class PromptExecutor:
break
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
self.success = result != ExecutionResult.FAILURE
if result == ExecutionResult.FAILURE:
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
break

View File

@ -99,7 +99,7 @@ folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(mo
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))
folder_names_and_paths["unet"] = FolderPathsTuple("unet", [os.path.join(models_dir, "unet")], set(supported_pt_extensions))
folder_names_and_paths["unet"] = folder_names_and_paths["diffusion_models"] = FolderPathsTuple("diffusion_models", [os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], set(supported_pt_extensions))
folder_names_and_paths["clip_vision"] = FolderPathsTuple("clip_vision", [os.path.join(models_dir, "clip_vision")], set(supported_pt_extensions))
folder_names_and_paths["style_models"] = FolderPathsTuple("style_models", [os.path.join(models_dir, "style_models")], set(supported_pt_extensions))
folder_names_and_paths["embeddings"] = FolderPathsTuple("embeddings", [os.path.join(models_dir, "embeddings")], set(supported_pt_extensions))

View File

@ -188,6 +188,7 @@ async def main():
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
if args.input_directory:
input_dir = os.path.abspath(args.input_directory)

View File

@ -25,6 +25,7 @@ from aiohttp import web
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
from typing_extensions import NamedTuple
from api_server.routes.internal.internal_routes import InternalRoutes
from ..model_filemanager import download_model, DownloadModelStatus
from .latent_preview_image_encoding import encode_preview_image
from .. import interruption
@ -97,6 +98,7 @@ class PromptServer(ExecutorToClientProgress):
self.address: str = "0.0.0.0"
self.user_manager = UserManager()
self.internal_routes = InternalRoutes()
# todo: this is probably read by custom nodes elsewhere
self.supports: List[str] = ["custom_nodes_from_web"]
self.prompt_queue: AbstractPromptQueue | AsyncAbstractPromptQueue | None = None
@ -170,6 +172,14 @@ class PromptServer(ExecutorToClientProgress):
embeddings = folder_paths.get_filename_list("embeddings")
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
@routes.get("/models/{folder}")
async def get_models(request):
folder = request.match_info.get("folder", None)
if not folder in folder_paths.folder_names_and_paths:
return web.Response(status=404)
files = folder_paths.get_filename_list(folder)
return web.json_response(files)
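For reference, the new `/models/{folder}` route returns a JSON array of filenames and a 404 status for unknown folder names. A minimal client sketch, assuming a local server on the default port 8188:
```
# Sketch: list checkpoint filenames known to a locally running server.
import json
from urllib.request import urlopen

with urlopen("http://127.0.0.1:8188/models/checkpoints") as resp:
    print(json.load(resp))  # a JSON list of filenames
```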
@routes.get("/extensions")
async def get_extensions(request):
files = glob.glob(os.path.join(glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
@ -461,6 +471,11 @@ class PromptServer(ExecutorToClientProgress):
if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
if getattr(obj_class, "DEPRECATED", False):
info['deprecated'] = True
if getattr(obj_class, "EXPERIMENTAL", False):
info['experimental'] = True
return info
@routes.get("/object_info")
@ -764,6 +779,7 @@ class PromptServer(ExecutorToClientProgress):
def add_routes(self):
self.user_manager.add_routes(self.routes)
self.app.add_subapp('/internal', self.internal_routes.get_app())
# Prefix every route with /api for easier matching for delegation.
# This is very useful for the frontend dev server, which needs to forward

View File

@ -391,7 +391,8 @@ def controlnet_config(sd):
else:
operations = ops.disable_weight_init
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
offload_device = model_management.unet_offload_device()
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
def controlnet_load_state_dict(control_model, sd):
missing, unexpected = control_model.load_state_dict(sd, strict=False)
@ -405,12 +406,12 @@ def controlnet_load_state_dict(control_model, sd):
def load_controlnet_mmdit(sd):
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
for k in sd:
new_sd[k] = sd[k]
control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, new_sd)
latent_format = latent_formats.SD3()
@ -420,9 +421,9 @@ def load_controlnet_mmdit(sd):
def load_controlnet_hunyuandit(controlnet_data):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
control_model = hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
control_model = hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
control_model = controlnet_load_state_dict(control_model, controlnet_data)
latent_format = latent_formats.SDXL()
@ -431,8 +432,8 @@ def load_controlnet_hunyuandit(controlnet_data):
return control
def load_controlnet_flux_xlabs(sd):
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
control_model = controlnet_load_state_dict(control_model, sd)
extra_conds = ['y', 'guidance']
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
@ -535,6 +536,7 @@ def load_controlnet(ckpt_path, model=None):
if manual_cast_dtype is not None:
controlnet_config["operations"] = ops.manual_cast
controlnet_config["dtype"] = unet_dtype
controlnet_config["device"] = model_management.unet_offload_device()
controlnet_config.pop("out_channels")
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
control_model = cldm.ControlNet(**controlnet_config)

comfy/float.py Normal file
View File

@ -0,0 +1,59 @@
import torch

# Not 100% sure about this
def manual_stochastic_round_to_float8(x, dtype):
    if dtype == torch.float8_e4m3fn:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
    elif dtype == torch.float8_e5m2:
        EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
    else:
        raise ValueError("Unsupported dtype")

    sign = torch.sign(x)
    abs_x = x.abs()

    # Combine exponent calculation and clamping
    exponent = torch.clamp(
        torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
        0, 2**EXPONENT_BITS - 1
    )

    # Combine mantissa calculation and rounding
    # min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
    # zero_mask = (abs_x == 0)
    # subnormal_mask = (exponent == 0) & (abs_x != 0)
    normal_mask = ~(exponent == 0)

    mantissa_scaled = torch.where(
        normal_mask,
        (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
        (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
    )
    mantissa_floor = mantissa_scaled.floor()
    mantissa = torch.where(
        torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
        (mantissa_floor + 1) / (2**MANTISSA_BITS),
        mantissa_floor / (2**MANTISSA_BITS)
    )

    result = torch.where(
        normal_mask,
        sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
        sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
    )
    result = torch.where(abs_x == 0, 0, result)
    return result.to(dtype=dtype)

def stochastic_rounding(value, dtype):
    if dtype == torch.float32:
        return value.to(dtype=torch.float32)
    if dtype == torch.float16:
        return value.to(dtype=torch.float16)
    if dtype == torch.bfloat16:
        return value.to(dtype=torch.bfloat16)
    if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
        return manual_stochastic_round_to_float8(value, dtype)
    return value.to(dtype=dtype)
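A quick sanity-check sketch of the helper above (assumes a PyTorch build that exposes the float8 dtypes): repeatedly rounding the same value lands on one of the two neighbouring float8 representations, with the mean staying close to the input.
```
# Sketch: stochastic rounding splits values between the two nearest float8
# representations so the expected value stays close to the original.
import torch
from comfy.float import stochastic_rounding

x = torch.full((10000,), 0.3)
rounded = stochastic_rounding(x, torch.float8_e4m3fn).to(torch.float32)
print(rounded.unique())       # the two neighbouring float8 values around 0.3
print(rounded.mean().item())  # close to 0.3 on average
```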

View File

@ -172,19 +172,41 @@ class ExecutionList(TopologicalSort):
"current_inputs": []
}
return None, error_details, ex
next_node = available[0]
self.staged_node_id = self.ux_friendly_pick_node(available)
return self.staged_node_id, None, None
def ux_friendly_pick_node(self, node_list):
# If an output node is available, do that first.
# Technically this has no effect on the overall length of execution, but it feels better as a user
# for a PreviewImage to display a result as soon as it can
# Some other heuristics could probably be used here to improve the UX further.
for node_id in available:
def is_output(node_id):
class_type = self.dynprompt.get_node(node_id)["class_type"]
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
next_node = node_id
break
self.staged_node_id = next_node
return self.staged_node_id, None, None
return True
return False
for node_id in node_list:
if is_output(node_id):
return node_id
#This should handle the VAEDecode -> preview case
for node_id in node_list:
for blocked_node_id in self.blocking[node_id]:
if is_output(blocked_node_id):
return node_id
#This should handle the VAELoader -> VAEDecode -> preview case
for node_id in node_list:
for blocked_node_id in self.blocking[node_id]:
for blocked_node_id1 in self.blocking[blocked_node_id]:
if is_output(blocked_node_id1):
return node_id
#TODO: this function should be improved
return node_list[0]
def unstage_node_execution(self):
assert self.staged_node_id is not None

View File

@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
from . import utils
from . import deis
from .. import model_patcher
from .. import model_sampling
def append_zero(x):
return torch.cat([x, x.new_zeros([1])])
@ -509,6 +510,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
@torch.no_grad()
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if isinstance(model.inner_model.inner_model.model_sampling, model_sampling.CONST):
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
@ -541,6 +545,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
return x
@torch.no_grad()
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
extra_args = {} if extra_args is None else extra_args
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
# logged_x = x.unsqueeze(0)
for i in trange(len(sigmas) - 1, disable=disable):
denoised = model(x, sigmas[i] * s_in, **extra_args)
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
sigma_down = sigmas[i+1] * downstep_ratio
alpha_ip1 = 1 - sigmas[i+1]
alpha_down = 1 - sigma_down
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
if callback is not None:
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
if sigmas[i + 1] == 0:
# Euler method
d = to_d(x, sigmas[i], denoised)
dt = sigma_down - sigmas[i]
x = x + d * dt
else:
# DPM-Solver++(2S)
if sigmas[i] == 1.0:
sigma_s = 0.9999
else:
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
r = 1 / 2
h = t_down - t_i
s = t_i + r * h
sigma_s = sigma_fn(s)
# sigma_s = sigmas[i+1]
sigma_s_i_ratio = sigma_s / sigmas[i]
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
D_i = model(u, sigma_s * s_in, **extra_args)
sigma_down_i_ratio = sigma_down / sigmas[i]
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
# Noise addition
if sigmas[i + 1] > 0 and eta > 0:
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
return x
@torch.no_grad()
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
"""DPM-Solver++ (stochastic)."""

View File

@ -178,7 +178,7 @@ class DoubleStreamBlock(nn.Module):
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
if txt.dtype == torch.float16:
txt = txt.clip(-65504, 65504)
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
return img, txt
@ -233,7 +233,7 @@ class SingleStreamBlock(nn.Module):
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
x += mod.gate * output
if x.dtype == torch.float16:
x = x.clip(-65504, 65504)
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x

View File

@ -17,8 +17,10 @@
"""
import logging
import torch
from . import utils
from . import model_base
from . import model_management
LORA_CLIP_MAP = {
"mlp.fc1": "mlp_fc1",
@ -316,7 +318,196 @@ def model_lora_keys_unet(model, key_map={}):
for k in diffusers_keys:
if k.endswith(".weight"):
to = diffusers_keys[k]
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
key_map[key_lora] = to
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
return key_map
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
dora_scale = model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
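In math terms, a sketch of what `weight_decompose` computes (assuming the intent is the usual DoRA-style update, with m = `dora_scale`, s = `strength`, and norms taken per weight column over the output dimension):
```
% Sketch (DoRA-style update); m = dora_scale, s = strength,
% norms taken per weight column (over the output dimension):
W_{\mathrm{calc}} = m \odot \frac{W + \alpha\,\Delta W}{\lVert W + \alpha\,\Delta W \rVert_{\mathrm{col}}},
\qquad
W' = W + s\,\bigl(W_{\mathrm{calc}} - W\bigr)
```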
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
patch_type = v[0]
v = v[1]
if patch_type == "diff":
w1 = v[0]
if strength != 0.0:
if w1.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += function(strength * model_management.cast_to_device(w1, weight.device, weight.dtype))
elif patch_type == "lora": #lora/locon
mat1 = model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
mat2 = model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
#locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
else:
w1 = model_management.cast_to_device(w1, weight.device, intermediate_dtype)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
model_management.cast_to_device(t2, weight.device, intermediate_dtype),
model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
else:
w2 = model_management.cast_to_device(w2, weight.device, intermediate_dtype)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: #cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
model_management.cast_to_device(t1, weight.device, intermediate_dtype),
model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
model_management.cast_to_device(t2, weight.device, intermediate_dtype),
model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
else:
m1 = torch.mm(model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight
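A minimal sketch of the relocated helper, using the patch-tuple layout visible above, `(strength, value, strength_model, offset, function)`, where a 1-tuple value is treated as a plain "diff" patch:
```
# Sketch: apply a simple additive ("diff") patch with comfy.lora.calculate_weight.
import torch
from comfy import lora

weight = torch.zeros(4, 4)
delta = torch.ones(4, 4)
patches = [(0.5, (delta,), 1.0, None, None)]  # strength, value, strength_model, offset, function

# the key is only used for log messages
patched = lora.calculate_weight(patches, weight, "example.weight")
print(patched)  # 0.5 everywhere: weight + strength * delta
```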

View File

@ -103,10 +103,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
if self.manual_cast_dtype is not None:
operations = ops.manual_cast
else:
operations = ops.disable_weight_init
operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
else:
operations = model_config.custom_operations
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)

View File

@ -471,9 +471,15 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
'use_temporal_attention': False, 'use_temporal_resblock': False}
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
for unet_config in supported_models:
matches = True

View File

@ -424,7 +424,7 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors"),
HuggingFile("black-forest-labs/FLUX.1-schnell", "flux1-schnell.safetensors"),
HuggingFile("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors"),
], folder_name="unet")
], folder_name="diffusion_models")
KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([
# todo: is this correct?

View File

@ -60,9 +60,14 @@ cpu_state = CPUState.GPU
total_vram = 0
lowvram_available = True
xpu_available = False
try:
torch_version = torch.version.__version__
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
except:
pass
lowvram_available = True
if args.deterministic:
logging.info("Using deterministic algorithms for pytorch")
torch.use_deterministic_algorithms(True, warn_only=True)
@ -83,10 +88,10 @@ if args.directml is not None:
try:
import intel_extension_for_pytorch as ipex # pylint: disable=import-error
if torch.xpu.is_available():
xpu_available = True
_ = torch.xpu.device_count()
xpu_available = torch.xpu.is_available()
except:
pass
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
try:
if torch.backends.mps.is_available():
@ -224,7 +229,6 @@ VAE_DTYPES = [torch.float32]
try:
if is_nvidia() or is_amd():
torch_version = torch.version.__version__
if int(torch_version[0]) >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
@ -356,17 +360,15 @@ class LoadedModel:
self.model_use_more_vram(use_more_vram)
else:
try:
if lowvram_model_memory > 0 and load_weights:
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
else:
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
except Exception as e:
self.model.unpatch_model(self.model.offload_device)
self.model_unload()
raise e
if is_intel_xpu() and not args.disable_ipex_optimize:
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
with torch.no_grad():
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
self.weights_loaded = True
return self.real_model
@ -421,6 +423,19 @@ def minimum_inference_memory():
return (1024 * 1024 * 1024) * 1.2
EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
if any(platform.win32_ver()):
EXTRA_RESERVED_VRAM = 500 * 1024 * 1024 # Windows is higher because of the shared vram issue
if args.reserve_vram is not None:
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
def extra_reserved_memory():
return EXTRA_RESERVED_VRAM
def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
with model_management_lock:
return _unload_model_clones(model, unload_weights_only, force_unload)
@ -519,11 +534,11 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
global vram_state
inference_memory = minimum_inference_memory()
extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
if minimum_memory_required is None:
minimum_memory_required = extra_mem
else:
minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
models = set(models)
@ -636,7 +651,9 @@ def cleanup_models(keep_clone_weights_loaded=False):
with model_management_lock:
to_delete = []
for i in range(len(current_loaded_models)):
if sys.getrefcount(current_loaded_models[i].model) <= 2:
# TODO: very fragile function needs improvement
num_refs = sys.getrefcount(current_loaded_models[i].model)
if num_refs <= 2:
if not keep_clone_weights_loaded:
to_delete = [i] + to_delete
# TODO: find a less fragile way to do this.
@ -749,6 +766,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.flo
if bf16_supported and weight_dtype == torch.bfloat16:
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported:
return torch.float16
@ -984,7 +1002,8 @@ def pytorch_attention_flash_attention():
def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
try:
if platform.mac_ver()[0] in ['14.5']: # black image bug on OSX Sonoma 14.5
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
upcast = True
except:
pass
@ -1084,7 +1103,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_amd():
return True
try:
props = torch.cuda.get_device_properties("cuda")
props = torch.cuda.get_device_properties(device)
if props.major >= 8:
return True
@ -1094,16 +1113,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
logging.warning("Torch was not compiled with cuda support")
return False
fp16_works = False
# FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
# when the model doesn't actually fit on the card
# TODO: actually test if GP106 and others have the same type of behavior
# FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
for x in nvidia_10_series:
if x in props.name.lower():
fp16_works = True
return True
if fp16_works or manual_cast:
if manual_cast:
free_model_memory = maximum_vram_for_weights(device)
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
@ -1168,6 +1184,17 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
props = torch.cuda.get_device_properties(device)
if props.major >= 9:
return True
if props.major < 8:
return False
if props.minor < 9:
return False
return True
def soft_empty_cache(force=False):
with model_management_lock:
global cpu_state

View File

@ -60,12 +60,6 @@ class ModelManageable(Protocol):
return self.model
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
"""
Loads the model to the device
:param device_to: the device to move the model weights to
:param patch_weights: True if the patch's weights should also be moved
:return:
"""
...
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:

View File

@ -25,34 +25,14 @@ from typing import Optional
import torch
import torch.nn
from . import model_management
from . import model_management, lora
from . import utils
from .float import stochastic_rounding
from .model_base import BaseModel
from .model_management_types import ModelManageable, MemoryMeasurements
from .types import UnetWrapperFunction
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
dora_scale = model_management.cast_to_device(dora_scale, weight.device, torch.float32)
lora_diff *= alpha
weight_calc = weight + lora_diff.type(weight.dtype)
weight_norm = (
weight_calc.transpose(0, 1)
.reshape(weight_calc.shape[1], -1)
.norm(dim=1, keepdim=True)
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
.transpose(0, 1)
)
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
if strength != 1.0:
weight_calc -= weight
weight += strength * (weight_calc)
else:
weight[:] = weight_calc
return weight
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
to = model_options["transformer_options"].copy()
@ -98,12 +78,12 @@ def wipe_lowvram_weight(m):
class LowVramPatch:
def __init__(self, key, model_patcher):
def __init__(self, key, patches):
self.key = key
self.model_patcher = model_patcher
self.patches = patches
def __call__(self, weight):
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
return lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
class ModelPatcher(ModelManageable):
@ -332,38 +312,18 @@ class ModelPatcher(ModelManageable):
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
else:
temp_weight = weight.to(torch.float32, copy=True)
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
out_weight = lora.calculate_weight(self.patches[key], temp_weight, key)
out_weight = stochastic_rounding(out_weight, weight.dtype)
if inplace_update:
utils.copy_to_param(self.model, key, out_weight)
else:
utils.set_attr_param(self.model, key, out_weight)
def patch_model(self, device_to=None, patch_weights=True):
for k in self.object_patches:
old = utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
if patch_weights:
model_sd = self.model_state_dict()
for key in self.patches:
if key not in model_sd:
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
continue
self.patch_weight_to_device(key, device_to)
if device_to is not None:
self.model.to(device_to)
self.model_device = device_to
self._memory_measurements.model_loaded_weight_memory = self.model_size()
return self.model
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
mem_counter = 0
patch_counter = 0
lowvram_counter = 0
load_completely = []
for n, m in self.model.named_modules():
lowvram_weight = False
@ -372,7 +332,7 @@ class ModelPatcher(ModelManageable):
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
lowvram_counter += 1
if m.comfy_cast_weights:
if hasattr(m, "prev_comfy_cast_weights"): # Already lowvramed
continue
weight_key = "{}.weight".format(n)
@ -383,13 +343,13 @@ class ModelPatcher(ModelManageable):
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
m.weight_function = LowVramPatch(weight_key, self)
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
m.bias_function = LowVramPatch(bias_key, self)
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
@ -400,206 +360,58 @@ class ModelPatcher(ModelManageable):
wipe_lowvram_weight(m)
if hasattr(m, "weight"):
mem_counter += model_management.module_size(m)
param = list(m.parameters())
if len(param) > 0:
weight = param[0]
if weight.device == device_to:
continue
mem_used = model_management.module_size(m)
mem_counter += mem_used
load_completely.append((mem_used, n, m))
weight_to = None
if full_load: # TODO
weight_to = device_to
self.patch_weight_to_device(weight_key, device_to=weight_to) # TODO: speed this up without OOM
self.patch_weight_to_device(bias_key, device_to=weight_to)
m.to(device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
m = x[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if hasattr(m, "comfy_patched_weights"):
if m.comfy_patched_weights == True:
continue
self.patch_weight_to_device(weight_key, device_to=device_to)
self.patch_weight_to_device(bias_key, device_to=device_to)
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
m.comfy_patched_weights = True
for x in load_completely:
x[2].to(device_to)
if lowvram_counter > 0:
logging.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
self._memory_measurements.model_lowvram = True
else:
logging.debug("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
self._memory_measurements.model_lowvram = False
if full_load:
self.model.to(device_to)
mem_counter = self.model_size()
self._memory_measurements.lowvram_patch_counter += patch_counter
self._memory_measurements.model_loaded_weight_memory = mem_counter
self.model_device = device_to
self._memory_measurements.model_loaded_weight_memory = mem_counter
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
self.patch_model(device_to, patch_weights=False)
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
for k in self.object_patches:
old = utils.set_attr(self.model, k, self.object_patches[k])
if k not in self.object_patches_backup:
self.object_patches_backup[k] = old
if lowvram_model_memory == 0:
full_load = True
else:
full_load = False
if load_weights:
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
return self.model
def calculate_weight(self, patches, weight, key):
for p in patches:
strength = p[0]
v = p[1]
strength_model = p[2]
offset = p[3]
function = p[4]
if function is None:
function = lambda a: a
old_weight = None
if offset is not None:
old_weight = weight
weight = weight.narrow(offset[0], offset[1], offset[2])
if strength_model != 1.0:
weight *= strength_model
if isinstance(v, list):
v = (self.calculate_weight(v[1:], v[0].clone(), key),)
patch_type = "diff"
if len(v) == 2:
patch_type = v[0]
v = v[1]
elif len(v) != 1:
logging.warning("patch {} not recognized: {}".format(key, v))
continue
if patch_type == "diff":
w1 = v[0]
if strength != 0.0:
if w1.shape != weight.shape:
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
else:
weight += function(strength * model_management.cast_to_device(w1, weight.device, weight.dtype))
elif patch_type == "lora": # lora/locon
mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
dora_scale = v[4]
if v[2] is not None:
alpha = v[2] / mat2.shape[0]
else:
alpha = 1.0
if v[3] is not None:
# locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = model_management.cast_to_device(v[3], weight.device, torch.float32)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
try:
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "lokr":
w1 = v[0]
w2 = v[1]
w1_a = v[3]
w1_b = v[4]
w2_a = v[5]
w2_b = v[6]
t2 = v[7]
dora_scale = v[8]
dim = None
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(model_management.cast_to_device(w1_a, weight.device, torch.float32),
model_management.cast_to_device(w1_b, weight.device, torch.float32))
else:
w1 = model_management.cast_to_device(w1, weight.device, torch.float32)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(model_management.cast_to_device(w2_a, weight.device, torch.float32),
model_management.cast_to_device(w2_b, weight.device, torch.float32))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
model_management.cast_to_device(t2, weight.device, torch.float32),
model_management.cast_to_device(w2_b, weight.device, torch.float32),
model_management.cast_to_device(w2_a, weight.device, torch.float32))
else:
w2 = model_management.cast_to_device(w2, weight.device, torch.float32)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
if v[2] is not None and dim is not None:
alpha = v[2] / dim
else:
alpha = 1.0
try:
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "loha":
w1a = v[0]
w1b = v[1]
if v[2] is not None:
alpha = v[2] / w1b.shape[0]
else:
alpha = 1.0
w2a = v[3]
w2b = v[4]
dora_scale = v[7]
if v[5] is not None: # cp decomposition
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
model_management.cast_to_device(t1, weight.device, torch.float32),
model_management.cast_to_device(w1b, weight.device, torch.float32),
model_management.cast_to_device(w1a, weight.device, torch.float32))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
model_management.cast_to_device(t2, weight.device, torch.float32),
model_management.cast_to_device(w2b, weight.device, torch.float32),
model_management.cast_to_device(w2a, weight.device, torch.float32))
else:
m1 = torch.mm(model_management.cast_to_device(w1a, weight.device, torch.float32),
model_management.cast_to_device(w1b, weight.device, torch.float32))
m2 = torch.mm(model_management.cast_to_device(w2a, weight.device, torch.float32),
model_management.cast_to_device(w2b, weight.device, torch.float32))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
elif patch_type == "glora":
if v[4] is not None:
alpha = v[4] / v[0].shape[0]
else:
alpha = 1.0
dora_scale = v[5]
a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
try:
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
if dora_scale is not None:
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
if old_weight is not None:
weight = old_weight
return weight
def unpatch_model(self, device_to=None, unpatch_weights=True):
if unpatch_weights:
if self._memory_measurements.model_lowvram:
@ -625,6 +437,10 @@ class ModelPatcher(ModelManageable):
self.model_device = device_to
self._memory_measurements.model_loaded_weight_memory = 0
for m in self.model.modules():
if hasattr(m, "comfy_patched_weights"):
del m.comfy_patched_weights
keys = list(self.object_patches_backup.keys())
for k in keys:
utils.set_attr(self.model, k, self.object_patches_backup[k])
@ -634,39 +450,47 @@ class ModelPatcher(ModelManageable):
def partially_unload(self, device_to, memory_to_free=0):
memory_freed = 0
patch_counter = 0
unload_list = []
for n, m in list(self.model.named_modules())[::-1]:
if memory_to_free < memory_freed:
break
for n, m in self.model.named_modules():
shift_lowvram = False
if hasattr(m, "comfy_cast_weights"):
module_mem = model_management.module_size(m)
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
unload_list.append((module_mem, n, m))
if m.weight is not None and m.weight.device != device_to:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
utils.copy_to_param(self.model, key, bk.weight)
else:
utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
unload_list.sort()
for unload in unload_list:
if memory_to_free < memory_freed:
break
module_mem = unload[0]
n = unload[1]
m = unload[2]
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self)
patch_counter += 1
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
for key in [weight_key, bias_key]:
bk = self.backup.get(key, None)
if bk is not None:
if bk.inplace_update:
utils.copy_to_param(self.model, key, bk.weight)
else:
utils.set_attr_param(self.model, key, bk.weight)
self.backup.pop(key)
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
memory_freed += module_mem
logging.debug("freed {}".format(n))
m.to(device_to)
if weight_key in self.patches:
m.weight_function = LowVramPatch(weight_key, self.patches)
patch_counter += 1
if bias_key in self.patches:
m.bias_function = LowVramPatch(bias_key, self.patches)
patch_counter += 1
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
memory_freed += module_mem
logging.debug("freed {}".format(n))
self._memory_measurements.model_lowvram = True
self._memory_measurements.lowvram_patch_counter += patch_counter
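The rewritten offload loop above first gathers `(module_size, name, module)` tuples for every castable module and only then sorts and frees them, so modules move to `device_to` in ascending size order until `memory_to_free` is exceeded. A tiny, self-contained illustration of that ordering (sizes and names are made up, and the module object is omitted from the tuples):

```
unload_list = [
    (4_000_000, "blocks.0.mlp"),
    (250_000, "blocks.0.norm"),
    (1_000_000, "blocks.0.attn"),
]
unload_list.sort()  # tuples compare by their first element, so the smallest module comes first

memory_to_free, memory_freed = 1_000_000, 0
for module_mem, name in unload_list:
    if memory_to_free < memory_freed:
        break
    memory_freed += module_mem
# frees blocks.0.norm and blocks.0.attn (1_250_000), then stops before the large mlp
```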
@ -675,14 +499,14 @@ class ModelPatcher(ModelManageable):
def partially_load(self, device_to, extra_memory=0):
self.unpatch_model(unpatch_weights=False)
self.patch_model(patch_weights=False)
self.patch_model(load_weights=False)
full_load = False
if not self._memory_measurements.model_lowvram:
return 0
if self._memory_measurements.model_loaded_weight_memory + extra_memory > self.model_size():
full_load = True
current_used = self._memory_measurements.model_loaded_weight_memory
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
return self._memory_measurements.model_loaded_weight_memory - current_used
def current_loaded_device(self):
@ -697,3 +521,7 @@ class ModelPatcher(ModelManageable):
return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__})>"
else:
return f"<ModelPatcher for {self.model.__class__.__name__}>"
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
return lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
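For reference, the deprecation shim above simply forwards to the module-level helper. A hedged migration sketch follows; the empty patch list and the dummy key are placeholders (real patch tuples are built internally when LoRAs are applied), and the import assumes the `comfy` package is on your path.

```
import torch
from comfy import lora  # assumption: running inside an environment where comfy is importable

weight = torch.zeros(4, 4)
# with no patches this is a no-op, but it shows the call shape that replaces
# ModelPatcher.calculate_weight(...)
new_weight = lora.calculate_weight([], weight, "dummy.key", intermediate_dtype=torch.float32)
```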

View File

@ -853,7 +853,7 @@ class ControlNetApplyAdvanced:
class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (get_filename_list_with_downloadable("unet", KNOWN_UNET_MODELS),),
return {"required": { "unet_name": (get_filename_list_with_downloadable("diffusion_models", KNOWN_UNET_MODELS),),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
}}
RETURN_TYPES = ("MODEL",)
@ -868,7 +868,7 @@ class UNETLoader:
elif weight_dtype == "fp8_e5m2":
model_options["dtype"] = torch.float8_e5m2
unet_path = get_or_download("unet", unet_name, KNOWN_UNET_MODELS)
unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
model = sd.load_diffusion_model(unet_path, model_options=model_options)
return (model,)
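The only functional change in this hunk is the folder key: `UNETLoader` now resolves files under `diffusion_models` instead of `unet`. A hedged sketch of the same lookup done directly; the model filename is a placeholder, and `get_or_download`, `KNOWN_UNET_MODELS` and `sd` are the names already used in this file.

```
import torch

unet_name = "some_model.safetensors"  # hypothetical entry in the diffusion_models folder
unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
model = sd.load_diffusion_model(unet_path, model_options={"dtype": torch.float8_e4m3fn})
```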

View File

@ -19,31 +19,44 @@
import torch
from . import model_management
from .cli_args import args
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
return weight.to(dtype=dtype, copy=copy)
r = torch.empty_like(weight, dtype=dtype, device=device)
r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to(weight, dtype=None, device=None, non_blocking=False):
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
def cast_to_input(weight, input, non_blocking=False, copy=True):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
def cast_to_input(weight, input, non_blocking=False):
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
def cast_bias_weight(s, input=None, dtype=None, device=None):
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if input is not None:
if dtype is None:
dtype = input.dtype
if bias_dtype is None:
bias_dtype = dtype
if device is None:
device = input.device
bias = None
non_blocking = model_management.device_should_use_non_blocking(device)
non_blocking = model_management.device_supports_non_blocking(device)
if s.bias is not None:
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
if s.bias_function is not None:
has_function = s.bias_function is not None
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
bias = s.bias_function(bias)
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
if s.weight_function is not None:
has_function = s.weight_function is not None
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
if has_function:
weight = s.weight_function(weight)
return weight, bias
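The new `cast_to` above only materializes a copy when it has to: if the tensor is already on the requested device and dtype and `copy` is false, the cached weight is returned as-is; when the device differs, an `empty_like` buffer is allocated and filled with `copy_`, optionally non-blocking. `cast_bias_weight` passes `copy=True` exactly when a `weight_function`/`bias_function` patch is attached, so the patch mutates scratch memory instead of the stored weight. A small CPU-only sketch of the two paths:

```
import torch

w = torch.randn(4, 4, dtype=torch.float16)

# fast path: matching device/dtype and copy=False returns the original tensor
same = w  # what cast_to(w, dtype=torch.float16) reduces to
assert same is w

# patched path: copy into a scratch buffer, then apply the patch function to the copy
scratch = torch.empty_like(w, dtype=torch.float32)
scratch.copy_(w, non_blocking=False)
scratch *= 2.0  # stand-in for s.weight_function(weight)
assert w.dtype == torch.float16 and scratch.dtype == torch.float32
```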
@ -242,3 +255,59 @@ class manual_cast(disable_weight_init):
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True
def fp8_linear(self, input):
dtype = self.weight.dtype
if dtype not in [torch.float8_e4m3fn]:
return None
if len(input.shape) == 3:
inn = input.reshape(-1, input.shape[2]).to(dtype)
non_blocking = model_management.device_supports_non_blocking(input.device)
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
w = w.t()
scale_weight = self.scale_weight
scale_input = self.scale_input
if scale_weight is None:
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
if scale_input is None:
scale_input = scale_weight
if scale_input is None:
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple):
o = o[0]
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
return None
class fp8_ops(manual_cast):
class Linear(manual_cast.Linear):
def reset_parameters(self):
self.scale_weight = None
self.scale_input = None
return None
def forward_comfy_cast_weights(self, input):
out = fp8_linear(self, input)
if out is not None:
return out
weight, bias = cast_bias_weight(self, input)
return torch.nn.functional.linear(input, weight, bias)
def pick_operations(weight_dtype, compute_dtype, load_device=None):
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
if args.fast:
if model_management.supports_fp8_compute(load_device):
return fp8_ops
return manual_cast
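`fp8_linear` above leans on `torch._scaled_mm`, which multiplies fp8 operands and folds per-tensor scale factors into the output. The arithmetic can be emulated with ordinary float tensors; the sketch below mirrors only the math (the 3-D reshape, the transpose, the two scales), not the fp8 storage format or the hardware fast path.

```
import torch

x = torch.randn(2, 5, 8)          # (batch, tokens, channels), matching the 3-D input check above
w = torch.randn(16, 8)            # nn.Linear-style weight: (out_features, in_features)
scale_input = torch.tensor(1.0)   # per-tensor scales; ones when the layer defines none
scale_weight = torch.tensor(1.0)

inn = x.reshape(-1, x.shape[2])                   # flatten to (batch * tokens, channels)
out = (inn * scale_input) @ (w.t() * scale_weight)
out = out.reshape(-1, x.shape[1], w.shape[0])     # back to (batch, tokens, out_features)
```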

View File

@ -33,6 +33,7 @@ from .text_encoders import hydit
from .text_encoders import sa_t5
from .text_encoders import sd2_clip
from .text_encoders import sd3_clip
from .text_encoders import long_clipl
def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
@ -66,7 +67,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
class CLIP:
def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0):
def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0, model_options={}):
if tokenizer_data is None:
tokenizer_data = dict()
if no_init:
@ -77,12 +78,17 @@ class CLIP:
load_device = model_management.text_encoder_device()
offload_device = model_management.text_encoder_offload_device()
dtype = model_management.text_encoder_dtype(load_device)
dtype = model_options.get("dtype", None)
if dtype is None:
dtype = model_management.text_encoder_dtype(load_device)
params['dtype'] = dtype
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
if "textmodel_json_config" not in params and textmodel_json_config is not None:
params['textmodel_json_config'] = textmodel_json_config
params['model_options'] = model_options
self.cond_stage_model = clip(**(params))
for dt in self.cond_stage_model.dtypes:
@ -416,10 +422,18 @@ class CLIPTarget:
tokenizer: Optional[Any] = None
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, textmodel_json_config: str | dict | None = None):
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, textmodel_json_config: str | dict | None = None, model_options=None):
if model_options is None:
model_options = dict()
clip_data = []
for p in ckpt_paths:
clip_data.append(utils.load_torch_file(p, safe_load=True))
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, textmodel_json_config=textmodel_json_config)
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, textmodel_json_config=None):
clip_data = state_dicts
class EmptyClass:
pass
for i in range(len(clip_data)):
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
@ -454,8 +468,13 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
clip_target.clip = sa_t5.SAT5Model
clip_target.tokenizer = sa_t5.SAT5Tokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
clip_target.clip = long_clipl.LongClipModel
clip_target.tokenizer = long_clipl.LongClipTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
elif len(clip_data) == 2:
if clip_type == CLIPType.SD3:
clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
@ -483,7 +502,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
for c in clip_data:
parameters += utils.calculate_parameters(c)
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters)
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters, model_options=model_options)
for c in clip_data:
m, u = clip.load_sd(c)
if len(m) > 0:
@ -529,16 +548,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
sd = utils.load_torch_file(ckpt_path)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, ckpt_path=ckpt_path)
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
if out is None:
raise RuntimeError("Could not detect model type of: {}".format(ckpt_path))
return out
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, ckpt_path: str | None = None):
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, ckpt_path=""):
clip = None
clipvision = None
vae = None
@ -589,7 +606,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
clip_sd = model_config.process_clip_state_dict(sd)
if len(clip_sd) > 0:
parameters = utils.calculate_parameters(clip_sd)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
m, u = clip.load_sd(clip_sd, full_model=True)
if len(m) > 0:
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
@ -697,10 +714,13 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
if clip is not None:
load_models.append(clip.load_model())
clip_sd = clip.get_sd()
vae_sd = None
if vae is not None:
vae_sd = vae.get_sd()
model_management.load_models_gpu(load_models, force_patch_weights=True)
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
for k in extra_keys:
sd[k] = extra_keys[k]
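The single-text-encoder branch above now tells a LongCLIP-L checkpoint apart from a standard CLIP-L checkpoint purely by the length of its position embedding (248 positions instead of 77). A self-contained sketch of that check with a stand-in state dict:

```
import torch

# stand-in state dict; a real one would come from utils.load_torch_file(path, safe_load=True)
sd = {"text_model.embeddings.position_embedding.weight": torch.zeros(248, 768)}

w = sd.get("text_model.embeddings.position_embedding.weight", None)
if w is not None and w.shape[0] == 248:
    print("route to long_clipl.LongClipModel / LongClipTokenizer")
else:
    print("route to sd1_clip.SD1ClipModel / SD1Tokenizer")
```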

View File

@ -94,7 +94,6 @@ class ClipTokenWeightEncoder:
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = [
"last",
"pooled",
@ -104,7 +103,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
super().__init__()
if special_tokens is None:
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
@ -112,7 +111,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json", package=__package__)
self.operations = ops.manual_cast
operations = model_options.get("custom_operations", None)
if operations is None:
operations = ops.manual_cast
self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
self.num_layers = self.transformer.num_layers
@ -680,9 +683,12 @@ class SD1Tokenizer:
def state_dict(self):
return {}
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, textmodel_json_config=None):
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config)
class SD1ClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, name=None, **kwargs):
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
super().__init__()
if name is not None:
@ -692,7 +698,7 @@ class SD1ClipModel(torch.nn.Module):
self.clip_name = clip_name
self.clip = "clip_{}".format(self.clip_name)
setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs))
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config, **kwargs))
self.dtypes = set()
if dtype is not None:

View File

@ -7,14 +7,14 @@ from .component_model.files import get_path_as_dict
class SDXLClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None, model_options={}):
if layer == "penultimate":
layer = "hidden"
layer_idx = -2
textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
def load_sd(self, sd):
return super().load_sd(sd)
@ -50,11 +50,11 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
self.clip_g = SDXLClipG(device=device, dtype=dtype)
self.dtypes = set([dtype])
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = {dtype}
def set_clip_options(self, options):
self.clip_l.set_clip_options(options)
@ -79,8 +79,8 @@ class SDXLClipModel(torch.nn.Module):
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, textmodel_json_config=textmodel_json_config)
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options, textmodel_json_config=textmodel_json_config)
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
@ -94,15 +94,15 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None, model_options={}):
textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
def load_sd(self, sd):
return super().load_sd(sd)
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, textmodel_json_config=textmodel_json_config)
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, model_options={}):
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, textmodel_json_config=textmodel_json_config, model_options=model_options)

View File

@ -654,6 +654,7 @@ class Flux(supported_models_base.BASE):
def clip_target(self, state_dict={}):
pref = self.text_encoder_key_prefix[0]
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
dtype_t5 = None
if t5_key in state_dict:
dtype_t5 = state_dict[t5_key].dtype
else:

View File

@ -6,7 +6,9 @@ from ..text_encoders import t5
from ..component_model.files import get_path_as_dict
class PT5XlModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = dict()
textmodel_json_config = get_path_as_dict(textmodel_json_config, "t5_pile_config_xl.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=t5.T5, enable_attention_masks=True, zero_out_masked=True)
@ -25,5 +27,5 @@ class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
class AuraT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)

View File

@ -9,7 +9,9 @@ from ..component_model import files
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = dict()
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)
@ -46,11 +48,11 @@ class FluxTokenizer:
class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_t5 = model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5])
def set_clip_options(self, options):
@ -78,7 +80,6 @@ class FluxClipModel(torch.nn.Module):
def flux_clip(dtype_t5=None):
class FluxClipModel_(FluxClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return FluxClipModel_
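`flux_clip` above (and `sd3_clip` later in this diff) uses a small factory pattern: a function closes over configuration such as `dtype_t5` and returns a subclass with that configuration baked into `__init__`, so callers can keep instantiating it through the generic `(device, dtype, model_options)` signature. A toy, self-contained version of the same pattern:

```
def make_greeter(greeting):
    class Greeter:
        def __init__(self, name):
            # the outer argument is captured by the closure, not passed by the caller
            self.message = f"{greeting}, {name}!"
    return Greeter

Hello = make_greeter("Hello")
print(Hello("world").message)  # -> Hello, world!
```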

View File

@ -11,7 +11,9 @@ from ..component_model.files import get_path_as_dict, get_package_as_path
class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = dict()
textmodel_json_config = get_path_as_dict(textmodel_json_config, "hydit_clip.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
@ -23,7 +25,9 @@ class HyditBertTokenizer(sd1_clip.SDTokenizer):
class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = dict()
textmodel_json_config = get_path_as_dict(textmodel_json_config, "mt5_config_xl.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, return_attention_masks=True)
@ -66,10 +70,12 @@ class HyditTokenizer:
class HyditModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None):
def __init__(self, device="cpu", dtype=None, model_options=None):
super().__init__()
self.hydit_clip = HyditBertModel(dtype=dtype)
self.mt5xl = MT5XLModel(dtype=dtype)
if model_options is None:
model_options = dict()
self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)
self.dtypes = set()
if dtype is not None:

View File

@ -0,0 +1,25 @@
{
"_name_or_path": "openai/clip-vit-large-patch14",
"architectures": [
"CLIPTextModel"
],
"attention_dropout": 0.0,
"bos_token_id": 0,
"dropout": 0.0,
"eos_token_id": 49407,
"hidden_act": "quick_gelu",
"hidden_size": 768,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-05,
"max_position_embeddings": 248,
"model_type": "clip_text_model",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 1,
"projection_dim": 768,
"torch_dtype": "float32",
"transformers_version": "4.24.0",
"vocab_size": 49408
}

View File

@ -0,0 +1,19 @@
from comfy import sd1_clip
import os
class LongClipTokenizer_(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
class LongClipModel_(sd1_clip.SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
class LongClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)

View File

@ -6,9 +6,11 @@ from ..component_model import files
class T5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = dict()
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_base.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, zero_out_masked=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, model_options=model_options, enable_attention_masks=True, zero_out_masked=True)
class T5BaseTokenizer(sd1_clip.SDTokenizer):
@ -25,5 +27,5 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
class SAT5Model(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, **kwargs):
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, model_options=model_options, name="t5base", clip_model=T5BaseModel, **kwargs)

View File

@ -4,13 +4,13 @@ from ..component_model.files import get_path_as_dict
class SD2ClipHModel(sd1_clip.SDClipModel):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None):
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None, model_options={}):
if layer == "penultimate":
layer = "hidden"
layer_idx = -2
textmodel_json_config = get_path_as_dict(textmodel_json_config, "sd2_clip_config.json", package=__package__)
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
@ -26,5 +26,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
class SD2ClipModel(sd1_clip.SD1ClipModel):
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, textmodel_json_config=textmodel_json_config, **kwargs)
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, model_options={}, **kwargs):
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, model_options=model_options, textmodel_json_config=textmodel_json_config, **kwargs)

View File

@ -11,9 +11,9 @@ from ..component_model import files
class T5XXLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None, model_options={}):
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, model_options=model_options)
class T5XXLTokenizer(sd1_clip.SDTokenizer):
@ -21,7 +21,7 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
if tokenizer_data is None:
tokenizer_data = dict()
tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
class SD3Tokenizer:
@ -50,24 +50,24 @@ class SD3Tokenizer:
class SD3ClipModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None
if clip_g:
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_g = None
if t5:
dtype_t5 = model_management.pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes.add(dtype_t5)
else:
self.t5xxl = None
@ -145,7 +145,6 @@ class SD3ClipModel(torch.nn.Module):
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
class SD3ClipModel_(SD3ClipModel):
def __init__(self, device="cpu", dtype=None):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
return SD3ClipModel_
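Taken together, these hunks thread a `model_options` dict from `load_clip` / `load_checkpoint_guess_config` down to every text-encoder class: `CLIP.__init__` reads `dtype` from it, and `SDClipModel` reads `custom_operations` (falling back to `ops.manual_cast`). A hedged sketch of what a caller might pass; only these two keys are read by the code in this diff, and the call itself is left commented because the checkpoint path is hypothetical.

```
import torch

te_model_options = {
    "dtype": torch.float16,      # overrides model_management.text_encoder_dtype(load_device)
    "custom_operations": None,   # None makes SDClipModel fall back to ops.manual_cast
}
# clip = load_clip(["models/clip/clip_l.safetensors"],
#                  clip_type=CLIPType.STABLE_DIFFUSION,
#                  model_options=te_model_options)
```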

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

comfy/web/assets/index-CaD4RONs.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,8 +1,8 @@
@font-face {
font-family: 'primeicons';
font-display: block;
src: url('/assets/primeicons-DMOk5skT.eot');
src: url('/assets/primeicons-DMOk5skT.eot?#iefix') format('embedded-opentype'), url('/assets/primeicons-C6QP2o4f.woff2') format('woff2'), url('/assets/primeicons-WjwUDZjB.woff') format('woff'), url('/assets/primeicons-MpK4pl85.ttf') format('truetype'), url('/assets/primeicons-Dr5RGzOO.svg?#primeicons') format('svg');
src: url('./primeicons-DMOk5skT.eot');
src: url('./primeicons-DMOk5skT.eot?#iefix') format('embedded-opentype'), url('./primeicons-C6QP2o4f.woff2') format('woff2'), url('./primeicons-WjwUDZjB.woff') format('woff'), url('./primeicons-MpK4pl85.ttf') format('truetype'), url('./primeicons-Dr5RGzOO.svg?#primeicons') format('svg');
font-weight: normal;
font-style: normal;
}
@ -1330,6 +1330,184 @@
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-7a0b94a3]:hover {
border-right: 4px solid var(--p-button-text-primary-color);
}
:root {
--red-600: #dc3545;
}
.comfy-missing-nodes[data-v-286402f2] {
font-family: monospace;
color: var(--red-600);
padding: 1.5rem;
background-color: var(--surface-ground);
border-radius: var(--border-radius);
box-shadow: var(--card-shadow);
}
.warning-title[data-v-286402f2] {
margin-top: 0;
margin-bottom: 1rem;
}
.warning-description[data-v-286402f2] {
margin-bottom: 1rem;
}
.missing-nodes-list[data-v-286402f2] {
max-height: 300px;
overflow-y: auto;
}
.missing-nodes-list.maximized[data-v-286402f2] {
max-height: unset;
}
.missing-node-item[data-v-286402f2] {
display: flex;
align-items: center;
padding: 0.5rem;
}
.node-type[data-v-286402f2] {
font-weight: 600;
color: var(--text-color);
}
.node-hint[data-v-286402f2] {
margin-left: 0.5rem;
font-style: italic;
color: var(--text-color-secondary);
}
[data-v-286402f2] .p-button {
margin-left: auto;
}
.added-nodes-warning[data-v-286402f2] {
margin-top: 1rem;
font-style: italic;
}
.input-slider[data-v-fbaf7a8c] {
display: flex;
align-items: center;
gap: 1rem;
}
.slider-part[data-v-fbaf7a8c] {
flex-grow: 1;
}
.input-part[data-v-fbaf7a8c] {
width: 5rem !important;
}
.info-chip[data-v-6361f2fb] {
background: transparent;
}
.setting-item[data-v-6361f2fb] {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 1rem;
}
.setting-label[data-v-6361f2fb] {
display: flex;
align-items: center;
flex: 1;
}
.setting-input[data-v-6361f2fb] {
flex: 1;
display: flex;
justify-content: flex-end;
margin-left: 1rem;
}
/* Ensure PrimeVue components take full width of their container */
.setting-input[data-v-6361f2fb] .p-inputtext,
.setting-input[data-v-6361f2fb] .input-slider,
.setting-input[data-v-6361f2fb] .p-select,
.setting-input[data-v-6361f2fb] .p-togglebutton {
width: 100%;
max-width: 200px;
}
.setting-input[data-v-6361f2fb] .p-inputtext {
max-width: unset;
}
/* Special case for ToggleSwitch to align it to the right */
.setting-input[data-v-6361f2fb] .p-toggleswitch {
margin-left: auto;
}
.search-box-input[data-v-8160f15b] {
width: 100%;
}
.no-results-placeholder[data-v-5a7d148a] {
display: flex;
justify-content: center;
align-items: center;
height: 100%;
padding: 2rem;
}
.no-results-placeholder[data-v-5a7d148a] .p-card {
background-color: var(--surface-ground);
text-align: center;
}
.no-results-placeholder h3[data-v-5a7d148a] {
color: var(--text-color);
margin-bottom: 0.5rem;
}
.no-results-placeholder p[data-v-5a7d148a] {
color: var(--text-color-secondary);
margin-bottom: 1rem;
}
/* Remove after we have tailwind setup */
.border-none {
border: none !important;
}
.settings-tab-panels {
padding-top: 0px !important;
}
.settings-container[data-v-29723d1f] {
display: flex;
height: 70vh;
width: 60vw;
max-width: 1000px;
overflow: hidden;
/* Prevents container from scrolling */
}
.settings-sidebar[data-v-29723d1f] {
width: 250px;
flex-shrink: 0;
/* Prevents sidebar from shrinking */
overflow-y: auto;
padding: 10px;
}
.settings-search-box[data-v-29723d1f] {
width: 100%;
margin-bottom: 10px;
}
.settings-content[data-v-29723d1f] {
flex-grow: 1;
overflow-y: auto;
/* Allows vertical scrolling */
}
/* Ensure the Listbox takes full width of the sidebar */
.settings-sidebar[data-v-29723d1f] .p-listbox {
width: 100%;
}
/* Optional: Style scrollbars for webkit browsers */
.settings-sidebar[data-v-29723d1f]::-webkit-scrollbar,
.settings-content[data-v-29723d1f]::-webkit-scrollbar {
width: 1px;
}
.settings-sidebar[data-v-29723d1f]::-webkit-scrollbar-thumb,
.settings-content[data-v-29723d1f]::-webkit-scrollbar-thumb {
background-color: transparent;
}
.pi-cog[data-v-969a1066] {
font-size: 1.25rem;
margin-right: 0.5rem;
}
.version-tag[data-v-969a1066] {
margin-left: 0.5rem;
}
.lds-ring {
display: inline-block;
position: relative;
@ -2881,6 +3059,7 @@ body {
#graph-canvas {
width: 100%;
height: 100%;
touch-action: none;
}
.comfyui-body-right {
@ -3482,184 +3661,6 @@ audio.comfy-audio.empty-audio-widget {
max-width: 25vw;
}
:root {
--red-600: #dc3545;
}
.comfy-missing-nodes[data-v-286402f2] {
font-family: monospace;
color: var(--red-600);
padding: 1.5rem;
background-color: var(--surface-ground);
border-radius: var(--border-radius);
box-shadow: var(--card-shadow);
}
.warning-title[data-v-286402f2] {
margin-top: 0;
margin-bottom: 1rem;
}
.warning-description[data-v-286402f2] {
margin-bottom: 1rem;
}
.missing-nodes-list[data-v-286402f2] {
max-height: 300px;
overflow-y: auto;
}
.missing-nodes-list.maximized[data-v-286402f2] {
max-height: unset;
}
.missing-node-item[data-v-286402f2] {
display: flex;
align-items: center;
padding: 0.5rem;
}
.node-type[data-v-286402f2] {
font-weight: 600;
color: var(--text-color);
}
.node-hint[data-v-286402f2] {
margin-left: 0.5rem;
font-style: italic;
color: var(--text-color-secondary);
}
[data-v-286402f2] .p-button {
margin-left: auto;
}
.added-nodes-warning[data-v-286402f2] {
margin-top: 1rem;
font-style: italic;
}
.input-slider[data-v-fbaf7a8c] {
display: flex;
align-items: center;
gap: 1rem;
}
.slider-part[data-v-fbaf7a8c] {
flex-grow: 1;
}
.input-part[data-v-fbaf7a8c] {
width: 5rem !important;
}
.info-chip[data-v-4feeb3d2] {
background: transparent;
}
.setting-item[data-v-4feeb3d2] {
display: flex;
justify-content: space-between;
align-items: center;
margin-bottom: 1rem;
}
.setting-label[data-v-4feeb3d2] {
display: flex;
align-items: center;
flex: 1;
}
.setting-input[data-v-4feeb3d2] {
flex: 1;
display: flex;
justify-content: flex-end;
margin-left: 1rem;
}
/* Ensure PrimeVue components take full width of their container */
.setting-input[data-v-4feeb3d2] .p-inputtext,
.setting-input[data-v-4feeb3d2] .input-slider,
.setting-input[data-v-4feeb3d2] .p-select,
.setting-input[data-v-4feeb3d2] .p-togglebutton {
width: 100%;
max-width: 200px;
}
.setting-input[data-v-4feeb3d2] .p-inputtext {
max-width: unset;
}
/* Special case for ToggleSwitch to align it to the right */
.setting-input[data-v-4feeb3d2] .p-toggleswitch {
margin-left: auto;
}
.search-box-input[data-v-3bbe5335] {
width: 100%;
}
.no-results-placeholder[data-v-5a7d148a] {
display: flex;
justify-content: center;
align-items: center;
height: 100%;
padding: 2rem;
}
.no-results-placeholder[data-v-5a7d148a] .p-card {
background-color: var(--surface-ground);
text-align: center;
}
.no-results-placeholder h3[data-v-5a7d148a] {
color: var(--text-color);
margin-bottom: 0.5rem;
}
.no-results-placeholder p[data-v-5a7d148a] {
color: var(--text-color-secondary);
margin-bottom: 1rem;
}
/* Remove after we have tailwind setup */
.border-none {
border: none !important;
}
.settings-tab-panels {
padding-top: 0px !important;
}
.settings-container[data-v-833dbfbb] {
display: flex;
height: 70vh;
width: 60vw;
max-width: 1000px;
overflow: hidden;
/* Prevents container from scrolling */
}
.settings-sidebar[data-v-833dbfbb] {
width: 250px;
flex-shrink: 0;
/* Prevents sidebar from shrinking */
overflow-y: auto;
padding: 10px;
}
.settings-search-box[data-v-833dbfbb] {
width: 100%;
margin-bottom: 10px;
}
.settings-content[data-v-833dbfbb] {
flex-grow: 1;
overflow-y: auto;
/* Allows vertical scrolling */
}
/* Ensure the Listbox takes full width of the sidebar */
.settings-sidebar[data-v-833dbfbb] .p-listbox {
width: 100%;
}
/* Optional: Style scrollbars for webkit browsers */
.settings-sidebar[data-v-833dbfbb]::-webkit-scrollbar,
.settings-content[data-v-833dbfbb]::-webkit-scrollbar {
width: 1px;
}
.settings-sidebar[data-v-833dbfbb]::-webkit-scrollbar-thumb,
.settings-content[data-v-833dbfbb]::-webkit-scrollbar-thumb {
background-color: transparent;
}
.pi-cog[data-v-969a1066] {
font-size: 1.25rem;
margin-right: 0.5rem;
}
.version-tag[data-v-969a1066] {
margin-left: 0.5rem;
}
:root {
--sidebar-width: 64px;
--sidebar-icon-size: 1.5rem;
@ -3869,77 +3870,81 @@ audio.comfy-audio.empty-audio-widget {
color: var(--error-text);
}
.comfy-vue-node-search-container[data-v-b8a4ffdc] {
.comfy-vue-node-search-container[data-v-ba2c5897] {
display: flex;
width: 100%;
min-width: 24rem;
align-items: center;
justify-content: center;
}
.comfy-vue-node-search-container[data-v-b8a4ffdc] * {
.comfy-vue-node-search-container[data-v-ba2c5897] * {
pointer-events: auto;
}
.comfy-vue-node-preview-container[data-v-b8a4ffdc] {
.comfy-vue-node-preview-container[data-v-ba2c5897] {
position: absolute;
left: -350px;
top: 50px;
}
.comfy-vue-node-search-box[data-v-b8a4ffdc] {
.comfy-vue-node-search-box[data-v-ba2c5897] {
z-index: 10;
flex-grow: 1;
}
.option-container[data-v-b8a4ffdc] {
.option-container[data-v-ba2c5897] {
display: flex;
width: 100%;
cursor: pointer;
flex-direction: column;
align-items: center;
justify-content: space-between;
overflow: hidden;
padding-left: 1rem;
padding-right: 1rem;
padding-top: 0.5rem;
padding-bottom: 0.5rem;
padding-left: 0.5rem;
padding-right: 0.5rem;
padding-top: 0px;
padding-bottom: 0px;
}
.option-display-name[data-v-b8a4ffdc] {
.option-display-name[data-v-ba2c5897] {
display: flex;
flex-direction: column;
font-weight: 600;
}
.option-category[data-v-b8a4ffdc] {
.option-category[data-v-ba2c5897] {
overflow: hidden;
text-overflow: ellipsis;
font-size: 0.875rem;
line-height: 1.25rem;
font-weight: 300;
--tw-text-opacity: 1;
color: rgb(156 163 175 / var(--tw-text-opacity));
/* Keeps the text on a single line by default */
white-space: nowrap;
}
.i-badge[data-v-b8a4ffdc] {
.i-badge[data-v-ba2c5897] {
--tw-bg-opacity: 1;
background-color: rgb(34 197 94 / var(--tw-bg-opacity));
--tw-text-opacity: 1;
color: rgb(255 255 255 / var(--tw-text-opacity));
}
.o-badge[data-v-b8a4ffdc] {
.o-badge[data-v-ba2c5897] {
--tw-bg-opacity: 1;
background-color: rgb(239 68 68 / var(--tw-bg-opacity));
--tw-text-opacity: 1;
color: rgb(255 255 255 / var(--tw-text-opacity));
}
.c-badge[data-v-b8a4ffdc] {
.c-badge[data-v-ba2c5897] {
--tw-bg-opacity: 1;
background-color: rgb(59 130 246 / var(--tw-bg-opacity));
--tw-text-opacity: 1;
color: rgb(255 255 255 / var(--tw-text-opacity));
}
.s-badge[data-v-b8a4ffdc] {
.s-badge[data-v-ba2c5897] {
--tw-bg-opacity: 1;
background-color: rgb(234 179 8 / var(--tw-bg-opacity));
}
[data-v-b8a4ffdc] .highlight {
[data-v-ba2c5897] .highlight {
background-color: var(--p-primary-color);
color: var(--p-primary-contrast-color);
font-weight: bold;
border-radius: 0.25rem;
padding: 0.125rem 0.25rem;
padding: 0rem 0.125rem;
margin: -0.125rem 0.125rem;
}
@ -3971,46 +3976,57 @@ audio.comfy-audio.empty-audio-widget {
z-index: 99999;
}
.result-container[data-v-0fac61d9] {
.broken-image[data-v-1a883642] {
display: none;
}
.broken-image-placeholder[data-v-1a883642] {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
width: 100%;
height: 100%;
margin: 2rem;
}
.broken-image-placeholder i[data-v-1a883642] {
font-size: 3rem;
margin-bottom: 0.5rem;
}
.result-container[data-v-7ceacc88] {
width: 100%;
height: 100%;
aspect-ratio: 1 / 1;
overflow: hidden;
position: relative;
display: flex;
justify-content: center;
align-items: center;
}
[data-v-0fac61d9] img {
position: absolute;
top: 50%;
left: 50%;
transform: translate(-50%, -50%);
[data-v-7ceacc88] .task-output-image {
width: 100%;
height: 100%;
-o-object-fit: cover;
object-fit: cover;
-o-object-position: center;
object-position: center;
}
.p-image-preview[data-v-0fac61d9] {
position: static;
display: contents;
}
[data-v-0fac61d9] .image-preview-mask {
.image-preview-mask[data-v-7ceacc88] {
position: absolute;
left: 50%;
top: 50%;
transform: translate(-50%, -50%);
width: auto;
height: auto;
display: flex;
align-items: center;
justify-content: center;
opacity: 0;
padding: 10px;
cursor: pointer;
background: rgba(0, 0, 0, 0.5);
color: var(--p-image-preview-mask-color);
transition:
opacity var(--p-image-transition-duration),
background var(--p-image-transition-duration);
border-radius: 50%;
transition: opacity 0.3s ease;
}
.result-container:hover .image-preview-mask[data-v-7ceacc88] {
opacity: 1;
}
.task-result-preview[data-v-6cf8179c] {
.task-result-preview[data-v-7c099cb7] {
aspect-ratio: 1 / 1;
overflow: hidden;
display: flex;
@ -4019,18 +4035,18 @@ audio.comfy-audio.empty-audio-widget {
width: 100%;
height: 100%;
}
.task-result-preview i[data-v-6cf8179c],
.task-result-preview span[data-v-6cf8179c] {
.task-result-preview i[data-v-7c099cb7],
.task-result-preview span[data-v-7c099cb7] {
font-size: 2rem;
}
.task-item[data-v-6cf8179c] {
.task-item[data-v-7c099cb7] {
display: flex;
flex-direction: column;
border-radius: 4px;
overflow: hidden;
position: relative;
}
.task-item-details[data-v-6cf8179c] {
.task-item-details[data-v-7c099cb7] {
position: absolute;
bottom: 0;
padding: 0.6rem;
@ -4041,12 +4057,23 @@ audio.comfy-audio.empty-audio-widget {
/* In dark mode, transparent background color for tags is not ideal for tags that
are floating on top of images. */
.tag-wrapper[data-v-6cf8179c] {
.tag-wrapper[data-v-7c099cb7] {
background-color: var(--p-primary-contrast-color);
border-radius: 6px;
display: inline-flex;
}
/* PrimeVue's galleria teleports the fullscreen gallery out of subtree so we
cannot use scoped style here. */
img.galleria-image {
max-width: 100vw;
max-height: 100vh;
-o-object-fit: contain;
object-fit: contain;
/* Set z-index so the close button doesn't get hidden behind the image when image is large */
z-index: -1;
}
.comfy-vue-side-bar-container[data-v-bde767d2] {
display: flex;
flex-direction: column;
@ -4078,14 +4105,33 @@ are floating on top of images. */
background-color: transparent;
}
.queue-grid[data-v-7f831ee9] {
.scroll-container[data-v-bd027c46] {
height: 100%;
overflow-y: auto;
}
.queue-grid[data-v-bd027c46] {
display: grid;
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
padding: 0.5rem;
gap: 0.5rem;
}
.spinner[data-v-40c18658] {
.node-lib-tree-node-label {
display: flex;
align-items: center;
margin-left: var(--p-tree-node-gap);
}
[data-v-11e12183] .node-lib-search-box {
margin-left: 1rem;
margin-right: 1rem;
margin-top: 1rem;
}
[data-v-11e12183] .comfy-vue-side-bar-body {
background: var(--p-tree-background);
}
.spinner[data-v-afce9bd6] {
position: absolute;
inset: 0px;
display: flex;

View File

@ -48,7 +48,7 @@ var __async = (__this, __arguments, generator) => {
});
};
var _PrimitiveNode_instances, onFirstConnection_fn, createWidget_fn, mergeWidgetConfig_fn, isValidConnection_fn, removeWidgets_fn, _convertedToProcess;
import { C as ComfyDialog, $ as $el, a as ComfyApp, b as app, L as LGraphCanvas, c as LiteGraph, d as applyTextReplacements, e as ComfyWidgets, f as addValueControlWidgets, D as DraggableList, g as api, h as LGraphGroup, i as LGraphNode } from "./index-D8Zp4vRl.js";
import { C as ComfyDialog, $ as $el, a as ComfyApp, b as app, L as LGraphCanvas, c as LiteGraph, d as applyTextReplacements, e as ComfyWidgets, f as addValueControlWidgets, D as DraggableList, g as api, u as useToastStore, h as LGraphGroup, i as LGraphNode } from "./index-CaD4RONs.js";
const _ClipspaceDialog = class _ClipspaceDialog extends ComfyDialog {
static registerButton(name, contextPredicate, callback) {
const item = $el("button", {
@ -891,6 +891,7 @@ app.registerExtension({
});
app.ui.settings.addSetting({
id: id$4,
category: ["Comfy", "ColorPalette"],
name: "Color Palette",
type: /* @__PURE__ */ __name((name, setter, value) => {
const options = [
@ -923,12 +924,6 @@ app.registerExtension({
options
);
return $el("tr", [
$el("td", [
$el("label", {
for: id$4.replaceAll(".", "-"),
textContent: "Color palette"
})
]),
$el("td", [
els.select,
$el(
@ -1221,7 +1216,6 @@ app.registerExtension({
const floatWeight = parseFloat(weight);
if (isNaN(floatWeight)) return weight;
const newWeight = floatWeight + delta;
if (newWeight < 0) return "0";
return String(Number(newWeight.toFixed(10)));
}
__name(incrementWeight, "incrementWeight");
@ -1302,7 +1296,7 @@ app.registerExtension({
selectedText = addWeightToParentheses(selectedText);
const weightDelta = event.key === "ArrowUp" ? delta : -delta;
const updatedText = selectedText.replace(
/\((.*):(\d+(?:\.\d+)?)\)/,
/\((.*):([+-]?\d+(?:\.\d+)?)\)/,
(match, text, weight) => {
weight = incrementWeight(weight, weightDelta);
if (weight == 1) {
@ -1620,9 +1614,9 @@ function getWidgetConfig(slot) {
}
__name(getWidgetConfig, "getWidgetConfig");
function getConfig(widgetName) {
var _a, _b, _c, _d;
var _a, _b, _c, _d, _e;
const { nodeData } = this.constructor;
return (_d = (_a = nodeData == null ? void 0 : nodeData.input) == null ? void 0 : _a.required[widgetName]) != null ? _d : (_c = (_b = nodeData == null ? void 0 : nodeData.input) == null ? void 0 : _b.optional) == null ? void 0 : _c[widgetName];
return (_e = (_b = (_a = nodeData == null ? void 0 : nodeData.input) == null ? void 0 : _a.required) == null ? void 0 : _b[widgetName]) != null ? _e : (_d = (_c = nodeData == null ? void 0 : nodeData.input) == null ? void 0 : _c.optional) == null ? void 0 : _d[widgetName];
}
__name(getConfig, "getConfig");
function isConvertibleWidget(widget, config) {
@ -1854,8 +1848,7 @@ app.registerExtension({
init() {
useConversionSubmenusSetting = app.ui.settings.addSetting({
id: "Comfy.NodeInputConversionSubmenus",
name: "Node widget/input conversion sub-menus",
tooltip: "In the node context menu, place the entries that convert between input/widget in sub-menus.",
name: "In the node context menu, place the entries that convert between input/widget in sub-menus.",
type: "boolean",
defaultValue: true
});
@ -3957,7 +3950,8 @@ app.registerExtension({
}, "replace");
app.ui.settings.addSetting({
id: id$2,
name: "Invert Menu Scrolling",
category: ["Comfy", "Graph", "InvertMenuScrolling"],
name: "Invert Context Menu Scrolling",
type: "boolean",
defaultValue: false,
onChange(value) {
@ -3974,58 +3968,66 @@ app.registerExtension({
name: "Comfy.Keybinds",
init() {
const keybindListener = /* @__PURE__ */ __name(function(event) {
const modifierPressed = event.ctrlKey || event.metaKey;
if (modifierPressed && event.key === "Enter") {
if (event.altKey) {
api.interrupt();
return __async(this, null, function* () {
const modifierPressed = event.ctrlKey || event.metaKey;
if (modifierPressed && event.key === "Enter") {
if (event.altKey) {
yield api.interrupt();
useToastStore().add({
severity: "info",
summary: "Interrupted",
detail: "Execution has been interrupted",
life: 1e3
});
return;
}
app.queuePrompt(event.shiftKey ? -1 : 0).then();
return;
}
app.queuePrompt(event.shiftKey ? -1 : 0).then();
return;
}
const target = event.composedPath()[0];
if (["INPUT", "TEXTAREA"].includes(target.tagName)) {
return;
}
const modifierKeyIdMap = {
s: "#comfy-save-button",
o: "#comfy-file-input",
Backspace: "#comfy-clear-button",
d: "#comfy-load-default-button"
};
const modifierKeybindId = modifierKeyIdMap[event.key];
if (modifierPressed && modifierKeybindId) {
event.preventDefault();
const elem = document.querySelector(modifierKeybindId);
elem.click();
return;
}
if (event.ctrlKey || event.altKey || event.metaKey) {
return;
}
if (event.key === "Escape") {
const modals = document.querySelectorAll(".comfy-modal");
const modal = Array.from(modals).find(
(modal2) => window.getComputedStyle(modal2).getPropertyValue("display") !== "none"
);
if (modal) {
modal.style.display = "none";
const target = event.composedPath()[0];
if (["INPUT", "TEXTAREA"].includes(target.tagName)) {
return;
}
;
[...document.querySelectorAll("dialog")].forEach((d) => {
d.close();
});
}
const keyIdMap = {
q: ".queue-tab-button.side-bar-button",
h: ".queue-tab-button.side-bar-button",
r: "#comfy-refresh-button"
};
const buttonId = keyIdMap[event.key];
if (buttonId) {
const button = document.querySelector(buttonId);
button.click();
}
const modifierKeyIdMap = {
s: "#comfy-save-button",
o: "#comfy-file-input",
Backspace: "#comfy-clear-button",
d: "#comfy-load-default-button"
};
const modifierKeybindId = modifierKeyIdMap[event.key];
if (modifierPressed && modifierKeybindId) {
event.preventDefault();
const elem = document.querySelector(modifierKeybindId);
elem.click();
return;
}
if (event.ctrlKey || event.altKey || event.metaKey) {
return;
}
if (event.key === "Escape") {
const modals = document.querySelectorAll(".comfy-modal");
const modal = Array.from(modals).find(
(modal2) => window.getComputedStyle(modal2).getPropertyValue("display") !== "none"
);
if (modal) {
modal.style.display = "none";
}
;
[...document.querySelectorAll("dialog")].forEach((d) => {
d.close();
});
}
const keyIdMap = {
q: ".queue-tab-button.side-bar-button",
h: ".queue-tab-button.side-bar-button",
r: "#comfy-refresh-button"
};
const buttonId = keyIdMap[event.key];
if (buttonId) {
const button = document.querySelector(buttonId);
button.click();
}
});
}, "keybindListener");
window.addEventListener("keydown", keybindListener, true);
}
@ -4037,6 +4039,7 @@ const ext = {
return __async(this, null, function* () {
app2.ui.settings.addSetting({
id: id$1,
category: ["Comfy", "Graph", "LinkRenderMode"],
name: "Link Render Mode",
defaultValue: 2,
type: "combo",
@ -5684,7 +5687,9 @@ app.registerExtension({
LiteGraph.middle_click_slot_add_default_node = true;
this.suggestionsNumber = app.ui.settings.addSetting({
id: "Comfy.NodeSuggestions.number",
category: ["Comfy", "Node Search Box", "NodeSuggestions"],
name: "Number of nodes suggestions",
tooltip: "Only for litegraph searchbox/context menu",
type: "slider",
attrs: {
min: 1,
@ -5766,7 +5771,8 @@ app.registerExtension({
init() {
app.ui.settings.addSetting({
id: "Comfy.SnapToGrid.GridSize",
name: "Grid Size",
category: ["Comfy", "Graph", "GridSize"],
name: "Snap to gird size",
type: "slider",
attrs: {
min: 1,
@ -6176,4 +6182,4 @@ app.registerExtension({
};
}
});
//# sourceMappingURL=index--0nRVkuV.js.map
//# sourceMappingURL=index-DkvOTKox.js.map

comfy/web/assets/index-DkvOTKox.js.map generated vendored Normal file

File diff suppressed because one or more lines are too long

View File

@ -20,7 +20,7 @@ var __async = (__this, __arguments, generator) => {
step((generator = generator.apply(__this, __arguments)).next());
});
};
import { j as createSpinner, g as api, $ as $el } from "./index-D8Zp4vRl.js";
import { j as createSpinner, g as api, $ as $el } from "./index-CaD4RONs.js";
const _UserSelectionScreen = class _UserSelectionScreen {
show(users, user) {
return __async(this, null, function* () {
@@ -139,4 +139,4 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
export {
UserSelectionScreen
};
//# sourceMappingURL=userSelection-CH4RQEqW.js.map
//# sourceMappingURL=userSelection-GRU1gtOt.js.map

File diff suppressed because one or more lines are too long

View File

@@ -14,8 +14,8 @@
</style> -->
<link rel="stylesheet" type="text/css" href="user.css" />
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
<script type="module" crossorigin src="/assets/index-D8Zp4vRl.js"></script>
<link rel="stylesheet" crossorigin href="/assets/index-BHzRuMlR.css">
<script type="module" crossorigin src="./assets/index-CaD4RONs.js"></script>
<link rel="stylesheet" crossorigin href="./assets/index-DAK31IJJ.css">
</head>
<body class="litegraph">
<div id="vue-app"></div>

View File

@@ -332,6 +332,25 @@ class VAESave:
utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
return {}
class ModelSave:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@classmethod
def INPUT_TYPES(s):
return {"required": { "model": ("MODEL",),
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
RETURN_TYPES = ()
FUNCTION = "save"
OUTPUT_NODE = True
CATEGORY = "advanced/model_merging"
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
return {}
NODE_CLASS_MAPPINGS = {
"ModelMergeSimple": ModelMergeSimple,
"ModelMergeBlocks": ModelMergeBlocks,
@@ -343,4 +362,9 @@ NODE_CLASS_MAPPINGS = {
"CLIPMergeAdd": CLIPAdd,
"CLIPSave": CLIPSave,
"VAESave": VAESave,
"ModelSave": ModelSave,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"CheckpointSave": "Save Checkpoint",
}

View File

@@ -4,14 +4,14 @@ class Example:
Class methods
-------------
INPUT_TYPES (dict):
INPUT_TYPES (dict):
Tell the main program input parameters of nodes.
IS_CHANGED:
optional method to control when the node is re-executed.
Attributes
----------
RETURN_TYPES (`tuple`):
RETURN_TYPES (`tuple`):
The type of each element in the output tuple.
RETURN_NAMES (`tuple`):
Optional: The name of each output in the output tuple.
@@ -23,13 +23,19 @@
Assumed to be False if not present.
CATEGORY (`str`):
The category the node should appear in the UI.
DEPRECATED (`bool`):
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
functional in existing workflows that use them.
EXPERIMENTAL (`bool`):
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
significant changes or removal in future versions. Use with caution in production workflows.
execute(s) -> tuple || None:
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
"""
def __init__(self):
pass
@classmethod
def INPUT_TYPES(s):
"""

View File

@@ -0,0 +1,115 @@
import pytest
from aiohttp import web
from unittest.mock import MagicMock, patch
from api_server.routes.internal.internal_routes import InternalRoutes
from api_server.services.file_service import FileService
from folder_paths import models_dir, user_directory, output_directory
@pytest.fixture
def internal_routes():
return InternalRoutes()
@pytest.fixture
def aiohttp_client_factory(aiohttp_client, internal_routes):
async def _get_client():
app = internal_routes.get_app()
return await aiohttp_client(app)
return _get_client
@pytest.mark.asyncio
async def test_list_files_valid_directory(aiohttp_client_factory, internal_routes):
mock_file_list = [
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
{"name": "dir1", "path": "dir1", "type": "directory"}
]
internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
client = await aiohttp_client_factory()
resp = await client.get('/files?directory=models')
assert resp.status == 200
data = await resp.json()
assert 'files' in data
assert len(data['files']) == 2
assert data['files'] == mock_file_list
# Check other valid directories
resp = await client.get('/files?directory=user')
assert resp.status == 200
resp = await client.get('/files?directory=output')
assert resp.status == 200
@pytest.mark.asyncio
async def test_list_files_invalid_directory(aiohttp_client_factory, internal_routes):
internal_routes.file_service.list_files = MagicMock(side_effect=ValueError("Invalid directory key"))
client = await aiohttp_client_factory()
resp = await client.get('/files?directory=invalid')
assert resp.status == 400
data = await resp.json()
assert 'error' in data
assert data['error'] == "Invalid directory key"
@pytest.mark.asyncio
async def test_list_files_exception(aiohttp_client_factory, internal_routes):
internal_routes.file_service.list_files = MagicMock(side_effect=Exception("Unexpected error"))
client = await aiohttp_client_factory()
resp = await client.get('/files?directory=models')
assert resp.status == 500
data = await resp.json()
assert 'error' in data
assert data['error'] == "Unexpected error"
@pytest.mark.asyncio
async def test_list_files_no_directory_param(aiohttp_client_factory, internal_routes):
mock_file_list = []
internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
client = await aiohttp_client_factory()
resp = await client.get('/files')
assert resp.status == 200
data = await resp.json()
assert 'files' in data
assert len(data['files']) == 0
def test_setup_routes(internal_routes):
internal_routes.setup_routes()
routes = internal_routes.routes
assert any(route.method == 'GET' and str(route.path) == '/files' for route in routes)
def test_get_app(internal_routes):
app = internal_routes.get_app()
assert isinstance(app, web.Application)
assert internal_routes._app is not None
def test_get_app_reuse(internal_routes):
app1 = internal_routes.get_app()
app2 = internal_routes.get_app()
assert app1 is app2
@pytest.mark.asyncio
async def test_routes_added_to_app(aiohttp_client_factory, internal_routes):
client = await aiohttp_client_factory()
try:
resp = await client.get('/files')
print(f"Response received: status {resp.status}")
except Exception as e:
print(f"Exception occurred during GET request: {e}")
raise
assert resp.status != 404, "Route /files does not exist"
@pytest.mark.asyncio
async def test_file_service_initialization():
with patch('api_server.routes.internal.internal_routes.FileService') as MockFileService:
# Create a mock instance
mock_file_service_instance = MagicMock(spec=FileService)
MockFileService.return_value = mock_file_service_instance
internal_routes = InternalRoutes()
# Check if FileService was initialized with the correct parameters
MockFileService.assert_called_once_with({
"models": models_dir,
"user": user_directory,
"output": output_directory
})
# Verify that the file_service attribute of InternalRoutes is set
assert internal_routes.file_service == mock_file_service_instance
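The status codes and payload shapes asserted above imply a fairly small route handler. The sketch below is a hedged reconstruction consistent with these tests, not the actual `InternalRoutes` code; the factory function, the handler name, and the `models` default for a missing `directory` parameter are assumptions.

```python
from aiohttp import web

from api_server.services.file_service import FileService


# Hypothetical factory producing a GET /files handler that matches the tests above.
def make_list_files_handler(file_service: FileService):
    async def list_files(request: web.Request) -> web.Response:
        directory_key = request.query.get("directory", "models")  # default key is an assumption
        try:
            files = file_service.list_files(directory_key)            # delegate to FileService
            return web.json_response({"files": files})                # 200 with the listing
        except ValueError as e:
            return web.json_response({"error": str(e)}, status=400)   # unknown directory key
        except Exception as e:
            return web.json_response({"error": str(e)}, status=500)   # unexpected failure

    return list_files
```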

View File

@@ -0,0 +1,54 @@
import pytest
from unittest.mock import MagicMock
from api_server.services.file_service import FileService
@pytest.fixture
def mock_file_system_ops():
return MagicMock()
@pytest.fixture
def file_service(mock_file_system_ops):
allowed_directories = {
"models": "/path/to/models",
"user": "/path/to/user",
"output": "/path/to/output"
}
return FileService(allowed_directories, file_system_ops=mock_file_system_ops)
def test_list_files_valid_directory(file_service, mock_file_system_ops):
mock_file_system_ops.walk_directory.return_value = [
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
{"name": "dir1", "path": "dir1", "type": "directory"}
]
result = file_service.list_files("models")
assert len(result) == 2
assert result[0]["name"] == "file1.txt"
assert result[1]["name"] == "dir1"
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
def test_list_files_invalid_directory(file_service):
# Does not support walking directories outside of the allowed directories
with pytest.raises(ValueError, match="Invalid directory key"):
file_service.list_files("invalid_key")
def test_list_files_empty_directory(file_service, mock_file_system_ops):
mock_file_system_ops.walk_directory.return_value = []
result = file_service.list_files("models")
assert len(result) == 0
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
@pytest.mark.parametrize("directory_key", ["models", "user", "output"])
def test_list_files_all_allowed_directories(file_service, mock_file_system_ops, directory_key):
mock_file_system_ops.walk_directory.return_value = [
{"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
]
result = file_service.list_files(directory_key)
assert len(result) == 1
assert result[0]["name"] == f"file_{directory_key}.txt"
mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
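Taken together, these tests pin down the full contract of `list_files`: reject unknown keys with `ValueError("Invalid directory key")`, otherwise resolve the key to its configured path and hand it to `walk_directory`. A minimal sketch that would satisfy them, shown here only as an illustration and not as the shipped `FileService`:

```python
from api_server.utils.file_operations import FileSystemOperations


# Sketch of the behaviour exercised by the tests above; not the shipped FileService.
class FileServiceSketch:
    def __init__(self, allowed_directories: dict, file_system_ops=None):
        # e.g. {"models": "/path/to/models", "user": "/path/to/user", "output": "/path/to/output"}
        self.allowed_directories = allowed_directories
        self.file_system_ops = file_system_ops or FileSystemOperations

    def list_files(self, directory_key: str):
        if directory_key not in self.allowed_directories:
            raise ValueError("Invalid directory key")  # exact message the tests assert on
        directory_path = self.allowed_directories[directory_key]
        return self.file_system_ops.walk_directory(directory_path)
```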

View File

@@ -0,0 +1,42 @@
import pytest
from typing import List
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem, is_file_info
@pytest.fixture
def temp_directory(tmp_path):
# Create a temporary directory structure
dir1 = tmp_path / "dir1"
dir2 = tmp_path / "dir2"
dir1.mkdir()
dir2.mkdir()
(dir1 / "file1.txt").write_text("content1")
(dir2 / "file2.txt").write_text("content2")
(tmp_path / "file3.txt").write_text("content3")
return tmp_path
def test_walk_directory(temp_directory):
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
assert len(result) == 5 # 2 directories and 3 files
files = [item for item in result if item['type'] == 'file']
dirs = [item for item in result if item['type'] == 'directory']
assert len(files) == 3
assert len(dirs) == 2
file_names = {file['name'] for file in files}
assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'}
dir_names = {dir['name'] for dir in dirs}
assert dir_names == {'dir1', 'dir2'}
def test_walk_directory_empty(tmp_path):
result = FileSystemOperations.walk_directory(str(tmp_path))
assert len(result) == 0
def test_walk_directory_file_size(temp_directory):
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
files = [item for item in result if is_file_info(item)]
for file in files:
assert file['size'] > 0 # Assuming all files have some content
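These assertions describe `walk_directory` as a flat listing of every file and directory under the root, with paths made relative to it and a `size` field only on files. A sketch of that behaviour built on `os.walk` follows; it is an assumption for illustration, and the real `FileSystemOperations` implementation may differ in details such as ordering or symlink handling.

```python
import os
from typing import Dict, List, Union


# Sketch of the traversal the tests assert; not the actual FileSystemOperations code.
def walk_directory_sketch(directory: str) -> List[Dict[str, Union[str, int]]]:
    items: List[Dict[str, Union[str, int]]] = []
    for root, dirs, files in os.walk(directory):
        for name in dirs:
            full_path = os.path.join(root, name)
            items.append({
                "name": name,
                "path": os.path.relpath(full_path, directory),
                "type": "directory",
            })
        for name in files:
            full_path = os.path.join(root, name)
            items.append({
                "name": name,
                "path": os.path.relpath(full_path, directory),
                "type": "file",
                "size": os.path.getsize(full_path),  # tests only require a positive size
            })
    return items
```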

View File

@@ -361,19 +361,6 @@ class TestExecution:
for i in range(3):
assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
async def test_output_reuse(self, client: Client, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
output1 = g.node("PreviewImage", images=input1.out(0))
output2 = g.node("PreviewImage", images=input1.out(0))
result = await client.run(g)
images1 = result.get_images(output1)
images2 = result.get_images(output2)
assert len(images1) == 1, "Should have 1 image"
assert len(images2) == 1, "Should have 1 image"
async def test_mixed_lazy_results(self, client: Client, builder: GraphBuilder):
g = builder
val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0)
@@ -390,3 +377,55 @@
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be 0.0"
assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5"
assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0"
async def test_missing_node_error(self, client: Client, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0))
# We have multiple outputs. The first is invalid, but the second is valid
g.node("SaveImage", images=mix1.out(0))
g.node("SaveImage", images=mix2.out(0))
g.remove_node("removeme")
await client.run(g)
# Add back in the missing node to make sure the error doesn't break the server
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
await client.run(g)
async def test_output_reuse(self, client: Client, builder: GraphBuilder):
g = builder
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
output1 = g.node("SaveImage", images=input1.out(0))
output2 = g.node("SaveImage", images=input1.out(0))
result = await client.run(g)
images1 = result.get_images(output1)
images2 = result.get_images(output2)
assert len(images1) == 1, "Should have 1 image"
assert len(images2) == 1, "Should have 1 image"
# This tests that only constant outputs are used in the call to `IS_CHANGED`
async def test_is_changed_with_outputs(self, client: Client, builder: GraphBuilder):
g = builder
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
output = g.node("PreviewImage", images=test_node.out(0))
result = await client.run(g)
images = result.get_images(output)
assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
result = await client.run(g)
images = result.get_images(output)
assert len(images) == 1, "Should have 1 image"
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
assert not result.did_run(test_node), "The execution should have been cached"

View File

@@ -95,6 +95,31 @@ class TestCustomIsChanged:
else:
return False
class TestIsChangedWithConstants:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "custom_is_changed"
CATEGORY = "Testing/Nodes"
def custom_is_changed(self, image, value):
return (image * value,)
@classmethod
def IS_CHANGED(cls, image, value):
if image is None:
return value
else:
return image.mean().item() * value
class TestCustomValidation1:
@classmethod
def INPUT_TYPES(cls):
@@ -287,6 +312,7 @@ TEST_NODE_CLASS_MAPPINGS = {
"TestLazyMixImages": TestLazyMixImages,
"TestVariadicAverage": TestVariadicAverage,
"TestCustomIsChanged": TestCustomIsChanged,
"TestIsChangedWithConstants": TestIsChangedWithConstants,
"TestCustomValidation1": TestCustomValidation1,
"TestCustomValidation2": TestCustomValidation2,
"TestCustomValidation3": TestCustomValidation3,
@@ -299,6 +325,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
"TestLazyMixImages": "Lazy Mix Images",
"TestVariadicAverage": "Variadic Average",
"TestCustomIsChanged": "Custom IsChanged",
"TestIsChangedWithConstants": "IsChanged With Constants",
"TestCustomValidation1": "Custom Validation 1",
"TestCustomValidation2": "Custom Validation 2",
"TestCustomValidation3": "Custom Validation 3",

View File

@@ -28,6 +28,28 @@ class StubImage:
elif content == "NOISE":
return (torch.rand(batch_size, height, width, 3),)
class StubConstantImage:
def __init__(self):
pass
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "stub_constant_image"
CATEGORY = "Testing/Stub Nodes"
def stub_constant_image(self, value, height, width, batch_size):
return (torch.ones(batch_size, height, width, 3) * value,)
class StubMask:
def __init__(self):
pass
@@ -93,12 +115,14 @@ class StubFloat:
TEST_STUB_NODE_CLASS_MAPPINGS = {
"StubImage": StubImage,
"StubConstantImage": StubConstantImage,
"StubMask": StubMask,
"StubInt": StubInt,
"StubFloat": StubFloat,
}
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
"StubImage": "Stub Image",
"StubConstantImage": "Stub Constant Image",
"StubMask": "Stub Mask",
"StubInt": "Stub Int",
"StubFloat": "Stub Float",