mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-13 07:40:50 +08:00
Merge WIP
This commit is contained in:
commit
5155a3e248
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
2
.ci/windows_nightly_base_files/run_nvidia_gpu_fast.bat
Normal file
@ -0,0 +1,2 @@
|
||||
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast
|
||||
pause
|
||||
3
.github/ISSUE_TEMPLATE/config.yml
vendored
3
.github/ISSUE_TEMPLATE/config.yml
vendored
@ -1,5 +1,8 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: ComfyUI Frontend Issues
|
||||
url: https://github.com/Comfy-Org/ComfyUI_frontend/issues
|
||||
about: Issues related to the ComfyUI frontend (display issues, user interaction bugs), please go to the frontend repo to file the issue
|
||||
- name: ComfyUI Matrix Space
|
||||
url: https://app.element.io/#/room/%23comfyui_space%3Amatrix.org
|
||||
about: The ComfyUI Matrix Space is available for support and general discussion related to ComfyUI (Matrix is like Discord but open source).
|
||||
|
||||
44
README.md
44
README.md
@ -598,6 +598,7 @@ The default installation includes a fast latent preview method that's low-resolu
|
||||
|------------------------------------|--------------------------------------------------------------------------------------------------------------------|
|
||||
| Ctrl + Enter | Queue up current graph for generation |
|
||||
| Ctrl + Shift + Enter | Queue up current graph as first for generation |
|
||||
| Ctrl + Alt + Enter | Cancel current generation |
|
||||
| Ctrl + Z/Ctrl + Y | Undo/Redo |
|
||||
| Ctrl + S | Save workflow |
|
||||
| Ctrl + O | Load workflow |
|
||||
@ -620,6 +621,8 @@ The default installation includes a fast latent preview method that's low-resolu
|
||||
| H | Toggle visibility of history |
|
||||
| R | Refresh graph |
|
||||
| Double-Click LMB | Open node quick search palette |
|
||||
| Shift + Drag | Move multiple wires at once |
|
||||
| Ctrl + Alt + LMB | Disconnect all wires from clicked slot |
|
||||
|
||||
Ctrl can also be replaced with Cmd instead for macOS users
|
||||
|
||||
@ -1013,6 +1016,47 @@ To run:
|
||||
docker run -it -v ./output:/workspace/output -v ./models:/workspace/models --gpus=all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --rm hiddenswitch/comfyui
|
||||
```
|
||||
|
||||
## Frontend Development
|
||||
|
||||
As of August 15, 2024, we have transitioned to a new frontend, which is now hosted in a separate repository: [ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend). This repository now hosts the compiled JS (from TS/Vue) under the `web/` directory.
|
||||
|
||||
### Reporting Issues and Requesting Features
|
||||
|
||||
For any bugs, issues, or feature requests related to the frontend, please use the [ComfyUI Frontend repository](https://github.com/Comfy-Org/ComfyUI_frontend). This will help us manage and address frontend-specific concerns more efficiently.
|
||||
|
||||
### Using the Latest Frontend
|
||||
|
||||
The new frontend is now the default for ComfyUI. However, please note:
|
||||
|
||||
1. The frontend in the main ComfyUI repository is updated weekly.
|
||||
2. Daily releases are available in the separate frontend repository.
|
||||
|
||||
To use the most up-to-date frontend version:
|
||||
|
||||
1. For the latest daily release, launch ComfyUI with this command line argument:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_frontend@latest
|
||||
```
|
||||
|
||||
2. For a specific version, replace `latest` with the desired version number:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
|
||||
```
|
||||
|
||||
This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||
|
||||
### Accessing the Legacy Frontend
|
||||
|
||||
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
|
||||
```
|
||||
|
||||
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
|
||||
|
||||
## Community
|
||||
|
||||
[Chat on Matrix: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org), an alternative to Discord.
|
||||
|
||||
0
api_server/__init__.py
Normal file
0
api_server/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
0
api_server/routes/__init__.py
Normal file
3
api_server/routes/internal/README.md
Normal file
3
api_server/routes/internal/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
# ComfyUI Internal Routes
|
||||
|
||||
All routes under the `/internal` path are designated for **internal use by ComfyUI only**. These routes are not intended for use by external applications may change at any time without notice.
|
||||
0
api_server/routes/internal/__init__.py
Normal file
0
api_server/routes/internal/__init__.py
Normal file
40
api_server/routes/internal/internal_routes.py
Normal file
40
api_server/routes/internal/internal_routes.py
Normal file
@ -0,0 +1,40 @@
|
||||
from aiohttp import web
|
||||
from typing import Optional
|
||||
from folder_paths import models_dir, user_directory, output_directory
|
||||
from api_server.services.file_service import FileService
|
||||
|
||||
class InternalRoutes:
|
||||
'''
|
||||
The top level web router for internal routes: /internal/*
|
||||
The endpoints here should NOT be depended upon. It is for ComfyUI frontend use only.
|
||||
Check README.md for more information.
|
||||
|
||||
'''
|
||||
def __init__(self):
|
||||
self.routes: web.RouteTableDef = web.RouteTableDef()
|
||||
self._app: Optional[web.Application] = None
|
||||
self.file_service = FileService({
|
||||
"models": models_dir,
|
||||
"user": user_directory,
|
||||
"output": output_directory
|
||||
})
|
||||
|
||||
def setup_routes(self):
|
||||
@self.routes.get('/files')
|
||||
async def list_files(request):
|
||||
directory_key = request.query.get('directory', '')
|
||||
try:
|
||||
file_list = self.file_service.list_files(directory_key)
|
||||
return web.json_response({"files": file_list})
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
except Exception as e:
|
||||
return web.json_response({"error": str(e)}, status=500)
|
||||
|
||||
|
||||
def get_app(self):
|
||||
if self._app is None:
|
||||
self._app = web.Application()
|
||||
self.setup_routes()
|
||||
self._app.add_routes(self.routes)
|
||||
return self._app
|
||||
0
api_server/services/__init__.py
Normal file
0
api_server/services/__init__.py
Normal file
13
api_server/services/file_service.py
Normal file
13
api_server/services/file_service.py
Normal file
@ -0,0 +1,13 @@
|
||||
from typing import Dict, List, Optional
|
||||
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
|
||||
|
||||
class FileService:
|
||||
def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
|
||||
self.allowed_directories: Dict[str, str] = allowed_directories
|
||||
self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
|
||||
|
||||
def list_files(self, directory_key: str) -> List[FileSystemItem]:
|
||||
if directory_key not in self.allowed_directories:
|
||||
raise ValueError("Invalid directory key")
|
||||
directory_path: str = self.allowed_directories[directory_key]
|
||||
return self.file_system_ops.walk_directory(directory_path)
|
||||
42
api_server/utils/file_operations.py
Normal file
42
api_server/utils/file_operations.py
Normal file
@ -0,0 +1,42 @@
|
||||
import os
|
||||
from typing import List, Union, TypedDict, Literal
|
||||
from typing_extensions import TypeGuard
|
||||
class FileInfo(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
type: Literal["file"]
|
||||
size: int
|
||||
|
||||
class DirectoryInfo(TypedDict):
|
||||
name: str
|
||||
path: str
|
||||
type: Literal["directory"]
|
||||
|
||||
FileSystemItem = Union[FileInfo, DirectoryInfo]
|
||||
|
||||
def is_file_info(item: FileSystemItem) -> TypeGuard[FileInfo]:
|
||||
return item["type"] == "file"
|
||||
|
||||
class FileSystemOperations:
|
||||
@staticmethod
|
||||
def walk_directory(directory: str) -> List[FileSystemItem]:
|
||||
file_list: List[FileSystemItem] = []
|
||||
for root, dirs, files in os.walk(directory):
|
||||
for name in files:
|
||||
file_path = os.path.join(root, name)
|
||||
relative_path = os.path.relpath(file_path, directory)
|
||||
file_list.append({
|
||||
"name": name,
|
||||
"path": relative_path,
|
||||
"type": "file",
|
||||
"size": os.path.getsize(file_path)
|
||||
})
|
||||
for name in dirs:
|
||||
dir_path = os.path.join(root, name)
|
||||
relative_path = os.path.relpath(dir_path, directory)
|
||||
file_list.append({
|
||||
"name": name,
|
||||
"path": relative_path,
|
||||
"type": "directory"
|
||||
})
|
||||
return file_list
|
||||
@ -59,6 +59,8 @@ class CacheKeySetID(CacheKeySet):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = (node_id, node["class_type"])
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
@ -78,6 +80,8 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
for node_id in node_ids:
|
||||
if node_id in self.keys:
|
||||
continue
|
||||
if not self.dynprompt.has_node(node_id):
|
||||
continue
|
||||
node = self.dynprompt.get_node(node_id)
|
||||
self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
|
||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||
@ -91,6 +95,9 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
return to_hashable(signature)
|
||||
|
||||
def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||
if not dynprompt.has_node(node_id):
|
||||
# This node doesn't exist -- we can't cache it.
|
||||
return [float("NaN")]
|
||||
node = dynprompt.get_node(node_id)
|
||||
class_type = node["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
@ -116,6 +123,8 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
return ancestors, order_mapping
|
||||
|
||||
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
||||
if not dynprompt.has_node(node_id):
|
||||
return
|
||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||
input_keys = sorted(inputs.keys())
|
||||
for key in input_keys:
|
||||
|
||||
@ -113,11 +113,13 @@ def _create_parser() -> EnhancedConfigArgParser:
|
||||
vram_group.add_argument("--novram", action="store_true", help="When lowvram isn't enough.")
|
||||
vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).")
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
|
||||
parser.add_argument("--disable-smart-memory", action="store_true",
|
||||
help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
|
||||
parser.add_argument("--deterministic", action="store_true",
|
||||
help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
|
||||
parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
|
||||
|
||||
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
|
||||
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI. Raises an error if nodes cannot be imported,")
|
||||
|
||||
@ -81,6 +81,8 @@ class Configuration(dict):
|
||||
lowvram (bool): Reduce UNet's VRAM usage.
|
||||
novram (bool): Minimize VRAM usage.
|
||||
cpu (bool): Use CPU for processing.
|
||||
fast (bool): Enable some untested and potentially quality deteriorating optimizations
|
||||
reserve_vram (Optional[float]): Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS
|
||||
disable_smart_memory (bool): Disable smart memory management.
|
||||
deterministic (bool): Use deterministic algorithms where possible.
|
||||
dont_print_server (bool): Suppress server output.
|
||||
@ -157,6 +159,8 @@ class Configuration(dict):
|
||||
self.lowvram: bool = False
|
||||
self.novram: bool = False
|
||||
self.cpu: bool = False
|
||||
self.fast: bool = False
|
||||
self.reserve_vram: Optional[float] = None
|
||||
self.disable_smart_memory: bool = False
|
||||
self.deterministic: bool = False
|
||||
self.dont_print_server: bool = False
|
||||
|
||||
@ -89,10 +89,11 @@ class CLIPTextModel_(torch.nn.Module):
|
||||
intermediate_size = config_dict["intermediate_size"]
|
||||
intermediate_activation = config_dict["hidden_act"]
|
||||
vocab_size = config_dict["vocab_size"]
|
||||
num_positions = config_dict["max_position_embeddings"]
|
||||
self.eos_token_id = config_dict["eos_token_id"]
|
||||
|
||||
super().__init__()
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, vocab_size=vocab_size, dtype=dtype, device=device, operations=operations)
|
||||
self.embeddings = CLIPEmbeddings(embed_dim, vocab_size=vocab_size, num_positions=num_positions, dtype=dtype, device=device, operations=operations)
|
||||
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
|
||||
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
|
||||
|
||||
@ -124,7 +125,6 @@ class CLIPTextModel(torch.nn.Module):
|
||||
self.text_model = CLIPTextModel_(config_dict, dtype, device, operations)
|
||||
embed_dim = config_dict["hidden_size"]
|
||||
self.text_projection = operations.Linear(embed_dim, embed_dim, bias=False, dtype=dtype, device=device)
|
||||
self.text_projection.weight.copy_(torch.eye(embed_dim))
|
||||
self.dtype = dtype
|
||||
|
||||
def get_input_embeddings(self):
|
||||
|
||||
@ -61,7 +61,8 @@ class IsChangedCache:
|
||||
self.is_changed[node_id] = node["is_changed"]
|
||||
return self.is_changed[node_id]
|
||||
|
||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, self.outputs_cache)
|
||||
# Intentionally do not use cached outputs here. We only want constants in IS_CHANGED
|
||||
input_data_all, _ = get_input_data(node["inputs"], class_def, node_id, None)
|
||||
try:
|
||||
is_changed = map_node_over_list(class_def, input_data_all, "IS_CHANGED")
|
||||
node["is_changed"] = [None if isinstance(x, ExecutionBlocker) else x for x in is_changed]
|
||||
@ -567,6 +568,7 @@ class PromptExecutor:
|
||||
break
|
||||
|
||||
result, error, ex = execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
@ -99,7 +99,7 @@ folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(mo
|
||||
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["unet"] = FolderPathsTuple("unet", [os.path.join(models_dir, "unet")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["unet"] = folder_names_and_paths["diffusion_models"] = FolderPathsTuple("diffusion_models", [os.path.join(models_dir, "unet"), os.path.join(models_dir, "diffusion_models")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["clip_vision"] = FolderPathsTuple("clip_vision", [os.path.join(models_dir, "clip_vision")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["style_models"] = FolderPathsTuple("style_models", [os.path.join(models_dir, "style_models")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["embeddings"] = FolderPathsTuple("embeddings", [os.path.join(models_dir, "embeddings")], set(supported_pt_extensions))
|
||||
|
||||
@ -188,6 +188,7 @@ async def main():
|
||||
folder_paths.add_model_folder_path("checkpoints", os.path.join(folder_paths.get_output_directory(), "checkpoints"))
|
||||
folder_paths.add_model_folder_path("clip", os.path.join(folder_paths.get_output_directory(), "clip"))
|
||||
folder_paths.add_model_folder_path("vae", os.path.join(folder_paths.get_output_directory(), "vae"))
|
||||
folder_paths.add_model_folder_path("diffusion_models", os.path.join(folder_paths.get_output_directory(), "diffusion_models"))
|
||||
|
||||
if args.input_directory:
|
||||
input_dir = os.path.abspath(args.input_directory)
|
||||
|
||||
@ -25,6 +25,7 @@ from aiohttp import web
|
||||
from can_ada import URL, parse as urlparse # pylint: disable=no-name-in-module
|
||||
from typing_extensions import NamedTuple
|
||||
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from ..model_filemanager import download_model, DownloadModelStatus
|
||||
from .latent_preview_image_encoding import encode_preview_image
|
||||
from .. import interruption
|
||||
@ -97,6 +98,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
|
||||
self.address: str = "0.0.0.0"
|
||||
self.user_manager = UserManager()
|
||||
self.internal_routes = InternalRoutes()
|
||||
# todo: this is probably read by custom nodes elsewhere
|
||||
self.supports: List[str] = ["custom_nodes_from_web"]
|
||||
self.prompt_queue: AbstractPromptQueue | AsyncAbstractPromptQueue | None = None
|
||||
@ -170,6 +172,14 @@ class PromptServer(ExecutorToClientProgress):
|
||||
embeddings = folder_paths.get_filename_list("embeddings")
|
||||
return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings)))
|
||||
|
||||
@routes.get("/models/{folder}")
|
||||
async def get_models(request):
|
||||
folder = request.match_info.get("folder", None)
|
||||
if not folder in folder_paths.folder_names_and_paths:
|
||||
return web.Response(status=404)
|
||||
files = folder_paths.get_filename_list(folder)
|
||||
return web.json_response(files)
|
||||
|
||||
@routes.get("/extensions")
|
||||
async def get_extensions(request):
|
||||
files = glob.glob(os.path.join(glob.escape(self.web_root), 'extensions/**/*.js'), recursive=True)
|
||||
@ -461,6 +471,11 @@ class PromptServer(ExecutorToClientProgress):
|
||||
|
||||
if hasattr(obj_class, 'OUTPUT_TOOLTIPS'):
|
||||
info['output_tooltips'] = obj_class.OUTPUT_TOOLTIPS
|
||||
|
||||
if getattr(obj_class, "DEPRECATED", False):
|
||||
info['deprecated'] = True
|
||||
if getattr(obj_class, "EXPERIMENTAL", False):
|
||||
info['experimental'] = True
|
||||
return info
|
||||
|
||||
@routes.get("/object_info")
|
||||
@ -764,6 +779,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
|
||||
def add_routes(self):
|
||||
self.user_manager.add_routes(self.routes)
|
||||
self.app.add_subapp('/internal', self.internal_routes.get_app())
|
||||
|
||||
# Prefix every route with /api for easier matching for delegation.
|
||||
# This is very useful for frontend dev server, which need to forward
|
||||
|
||||
@ -391,7 +391,8 @@ def controlnet_config(sd):
|
||||
else:
|
||||
operations = ops.disable_weight_init
|
||||
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype
|
||||
offload_device = model_management.unet_offload_device()
|
||||
return model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device
|
||||
|
||||
def controlnet_load_state_dict(control_model, sd):
|
||||
missing, unexpected = control_model.load_state_dict(sd, strict=False)
|
||||
@ -405,12 +406,12 @@ def controlnet_load_state_dict(control_model, sd):
|
||||
|
||||
def load_controlnet_mmdit(sd):
|
||||
new_sd = model_detection.convert_diffusers_mmdit(sd, "")
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(new_sd)
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
|
||||
num_blocks = model_detection.count_blocks(new_sd, 'joint_blocks.{}.')
|
||||
for k in sd:
|
||||
new_sd[k] = sd[k]
|
||||
|
||||
control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = mmdit.ControlNet(num_blocks=num_blocks, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, new_sd)
|
||||
|
||||
latent_format = latent_formats.SD3()
|
||||
@ -420,9 +421,9 @@ def load_controlnet_mmdit(sd):
|
||||
|
||||
|
||||
def load_controlnet_hunyuandit(controlnet_data):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(controlnet_data)
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data)
|
||||
|
||||
control_model = hydit.controlnet.HunYuanControlNet(operations=operations, device=load_device, dtype=unet_dtype)
|
||||
control_model = hydit.controlnet.HunYuanControlNet(operations=operations, device=offload_device, dtype=unet_dtype)
|
||||
control_model = controlnet_load_state_dict(control_model, controlnet_data)
|
||||
|
||||
latent_format = latent_formats.SDXL()
|
||||
@ -431,8 +432,8 @@ def load_controlnet_hunyuandit(controlnet_data):
|
||||
return control
|
||||
|
||||
def load_controlnet_flux_xlabs(sd):
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype = controlnet_config(sd)
|
||||
control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=load_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
|
||||
control_model = flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
|
||||
control_model = controlnet_load_state_dict(control_model, sd)
|
||||
extra_conds = ['y', 'guidance']
|
||||
control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||
@ -535,6 +536,7 @@ def load_controlnet(ckpt_path, model=None):
|
||||
if manual_cast_dtype is not None:
|
||||
controlnet_config["operations"] = ops.manual_cast
|
||||
controlnet_config["dtype"] = unet_dtype
|
||||
controlnet_config["device"] = model_management.unet_offload_device()
|
||||
controlnet_config.pop("out_channels")
|
||||
controlnet_config["hint_channels"] = controlnet_data["{}input_hint_block.0.weight".format(prefix)].shape[1]
|
||||
control_model = cldm.ControlNet(**controlnet_config)
|
||||
|
||||
59
comfy/float.py
Normal file
59
comfy/float.py
Normal file
@ -0,0 +1,59 @@
|
||||
import torch
|
||||
|
||||
#Not 100% sure about this
|
||||
def manual_stochastic_round_to_float8(x, dtype):
|
||||
if dtype == torch.float8_e4m3fn:
|
||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 4, 3, 7
|
||||
elif dtype == torch.float8_e5m2:
|
||||
EXPONENT_BITS, MANTISSA_BITS, EXPONENT_BIAS = 5, 2, 15
|
||||
else:
|
||||
raise ValueError("Unsupported dtype")
|
||||
|
||||
sign = torch.sign(x)
|
||||
abs_x = x.abs()
|
||||
|
||||
# Combine exponent calculation and clamping
|
||||
exponent = torch.clamp(
|
||||
torch.floor(torch.log2(abs_x)).to(torch.int32) + EXPONENT_BIAS,
|
||||
0, 2**EXPONENT_BITS - 1
|
||||
)
|
||||
|
||||
# Combine mantissa calculation and rounding
|
||||
# min_normal = 2.0 ** (-EXPONENT_BIAS + 1)
|
||||
# zero_mask = (abs_x == 0)
|
||||
# subnormal_mask = (exponent == 0) & (abs_x != 0)
|
||||
normal_mask = ~(exponent == 0)
|
||||
|
||||
mantissa_scaled = torch.where(
|
||||
normal_mask,
|
||||
(abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS),
|
||||
(abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)))
|
||||
)
|
||||
mantissa_floor = mantissa_scaled.floor()
|
||||
mantissa = torch.where(
|
||||
torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor),
|
||||
(mantissa_floor + 1) / (2**MANTISSA_BITS),
|
||||
mantissa_floor / (2**MANTISSA_BITS)
|
||||
)
|
||||
result = torch.where(
|
||||
normal_mask,
|
||||
sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa),
|
||||
sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa
|
||||
)
|
||||
|
||||
result = torch.where(abs_x == 0, 0, result)
|
||||
return result.to(dtype=dtype)
|
||||
|
||||
|
||||
|
||||
def stochastic_rounding(value, dtype):
|
||||
if dtype == torch.float32:
|
||||
return value.to(dtype=torch.float32)
|
||||
if dtype == torch.float16:
|
||||
return value.to(dtype=torch.float16)
|
||||
if dtype == torch.bfloat16:
|
||||
return value.to(dtype=torch.bfloat16)
|
||||
if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
|
||||
return manual_stochastic_round_to_float8(value, dtype)
|
||||
|
||||
return value.to(dtype=dtype)
|
||||
@ -172,19 +172,41 @@ class ExecutionList(TopologicalSort):
|
||||
"current_inputs": []
|
||||
}
|
||||
return None, error_details, ex
|
||||
next_node = available[0]
|
||||
|
||||
self.staged_node_id = self.ux_friendly_pick_node(available)
|
||||
return self.staged_node_id, None, None
|
||||
|
||||
def ux_friendly_pick_node(self, node_list):
|
||||
# If an output node is available, do that first.
|
||||
# Technically this has no effect on the overall length of execution, but it feels better as a user
|
||||
# for a PreviewImage to display a result as soon as it can
|
||||
# Some other heuristics could probably be used here to improve the UX further.
|
||||
for node_id in available:
|
||||
def is_output(node_id):
|
||||
class_type = self.dynprompt.get_node(node_id)["class_type"]
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
if hasattr(class_def, 'OUTPUT_NODE') and class_def.OUTPUT_NODE == True:
|
||||
next_node = node_id
|
||||
break
|
||||
self.staged_node_id = next_node
|
||||
return self.staged_node_id, None, None
|
||||
return True
|
||||
return False
|
||||
|
||||
for node_id in node_list:
|
||||
if is_output(node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAEDecode -> preview case
|
||||
for node_id in node_list:
|
||||
for blocked_node_id in self.blocking[node_id]:
|
||||
if is_output(blocked_node_id):
|
||||
return node_id
|
||||
|
||||
#This should handle the VAELoader -> VAEDecode -> preview case
|
||||
for node_id in node_list:
|
||||
for blocked_node_id in self.blocking[node_id]:
|
||||
for blocked_node_id1 in self.blocking[blocked_node_id]:
|
||||
if is_output(blocked_node_id1):
|
||||
return node_id
|
||||
|
||||
#TODO: this function should be improved
|
||||
return node_list[0]
|
||||
|
||||
def unstage_node_execution(self):
|
||||
assert self.staged_node_id is not None
|
||||
|
||||
@ -9,6 +9,7 @@ from tqdm.auto import trange, tqdm
|
||||
from . import utils
|
||||
from . import deis
|
||||
from .. import model_patcher
|
||||
from .. import model_sampling
|
||||
|
||||
def append_zero(x):
|
||||
return torch.cat([x, x.new_zeros([1])])
|
||||
@ -509,6 +510,9 @@ def sample_dpm_adaptive(model, x, sigma_min, sigma_max, extra_args=None, callbac
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
if isinstance(model.inner_model.inner_model.model_sampling, model_sampling.CONST):
|
||||
return sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args, callback, disable, eta, s_noise, noise_sampler)
|
||||
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
@ -541,6 +545,55 @@ def sample_dpmpp_2s_ancestral(model, x, sigmas, extra_args=None, callback=None,
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_2s_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
|
||||
"""Ancestral sampling with DPM-Solver++(2S) second-order steps."""
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
sigma_fn = lambda lbda: (lbda.exp() + 1) ** -1
|
||||
lambda_fn = lambda sigma: ((1-sigma)/sigma).log()
|
||||
|
||||
# logged_x = x.unsqueeze(0)
|
||||
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
if sigmas[i + 1] == 0:
|
||||
# Euler method
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
else:
|
||||
# DPM-Solver++(2S)
|
||||
if sigmas[i] == 1.0:
|
||||
sigma_s = 0.9999
|
||||
else:
|
||||
t_i, t_down = lambda_fn(sigmas[i]), lambda_fn(sigma_down)
|
||||
r = 1 / 2
|
||||
h = t_down - t_i
|
||||
s = t_i + r * h
|
||||
sigma_s = sigma_fn(s)
|
||||
# sigma_s = sigmas[i+1]
|
||||
sigma_s_i_ratio = sigma_s / sigmas[i]
|
||||
u = sigma_s_i_ratio * x + (1 - sigma_s_i_ratio) * denoised
|
||||
D_i = model(u, sigma_s * s_in, **extra_args)
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * D_i
|
||||
# print("sigma_i", sigmas[i], "sigma_ip1", sigmas[i+1],"sigma_down", sigma_down, "sigma_down_i_ratio", sigma_down_i_ratio, "sigma_s_i_ratio", sigma_s_i_ratio, "renoise_coeff", renoise_coeff)
|
||||
# Noise addition
|
||||
if sigmas[i + 1] > 0 and eta > 0:
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
# logged_x = torch.cat((logged_x, x.unsqueeze(0)), dim=0)
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
|
||||
"""DPM-Solver++ (stochastic)."""
|
||||
|
||||
@ -178,7 +178,7 @@ class DoubleStreamBlock(nn.Module):
|
||||
txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
||||
|
||||
if txt.dtype == torch.float16:
|
||||
txt = txt.clip(-65504, 65504)
|
||||
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
|
||||
|
||||
return img, txt
|
||||
|
||||
@ -233,7 +233,7 @@ class SingleStreamBlock(nn.Module):
|
||||
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
||||
x += mod.gate * output
|
||||
if x.dtype == torch.float16:
|
||||
x = x.clip(-65504, 65504)
|
||||
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
195
comfy/lora.py
195
comfy/lora.py
@ -17,8 +17,10 @@
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
from . import utils
|
||||
from . import model_base
|
||||
from . import model_management
|
||||
|
||||
LORA_CLIP_MAP = {
|
||||
"mlp.fc1": "mlp_fc1",
|
||||
@ -316,7 +318,196 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map[key_lora] = to
|
||||
key_map["transformer.{}".format(k[:-len(".weight")])] = to #simpletrainer and probably regular diffusers flux lora format
|
||||
key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype):
|
||||
dora_scale = model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
.norm(dim=1, keepdim=True)
|
||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||
if strength != 1.0:
|
||||
weight_calc -= weight
|
||||
weight += strength * (weight_calc)
|
||||
else:
|
||||
weight[:] = weight_calc
|
||||
return weight
|
||||
|
||||
def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
v = p[1]
|
||||
strength_model = p[2]
|
||||
offset = p[3]
|
||||
function = p[4]
|
||||
if function is None:
|
||||
function = lambda a: a
|
||||
|
||||
old_weight = None
|
||||
if offset is not None:
|
||||
old_weight = weight
|
||||
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||
|
||||
if strength_model != 1.0:
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (calculate_weight(v[1:], v[0].clone(), key, intermediate_dtype=intermediate_dtype), )
|
||||
|
||||
if len(v) == 1:
|
||||
patch_type = "diff"
|
||||
elif len(v) == 2:
|
||||
patch_type = v[0]
|
||||
v = v[1]
|
||||
|
||||
if patch_type == "diff":
|
||||
w1 = v[0]
|
||||
if strength != 0.0:
|
||||
if w1.shape != weight.shape:
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += function(strength * model_management.cast_to_device(w1, weight.device, weight.dtype))
|
||||
elif patch_type == "lora": #lora/locon
|
||||
mat1 = model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
|
||||
mat2 = model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
|
||||
dora_scale = v[4]
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
if v[3] is not None:
|
||||
#locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
mat3 = model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
try:
|
||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "lokr":
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
dora_scale = v[8]
|
||||
dim = None
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
|
||||
model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w1 = model_management.cast_to_device(w1, weight.device, intermediate_dtype)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
|
||||
model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
|
||||
model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
|
||||
else:
|
||||
w2 = model_management.cast_to_device(w2, weight.device, intermediate_dtype)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
if v[2] is not None and dim is not None:
|
||||
alpha = v[2] / dim
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "loha":
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / w1b.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
dora_scale = v[7]
|
||||
if v[5] is not None: #cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
model_management.cast_to_device(t1, weight.device, intermediate_dtype),
|
||||
model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
|
||||
model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
|
||||
|
||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
model_management.cast_to_device(t2, weight.device, intermediate_dtype),
|
||||
model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
|
||||
model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
|
||||
else:
|
||||
m1 = torch.mm(model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
|
||||
model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
|
||||
m2 = torch.mm(model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
|
||||
model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
|
||||
|
||||
try:
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "glora":
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / v[0].shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
dora_scale = v[5]
|
||||
|
||||
a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
|
||||
|
||||
try:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
else:
|
||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||
|
||||
if old_weight is not None:
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
@ -103,10 +103,7 @@ class BaseModel(torch.nn.Module):
|
||||
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
if self.manual_cast_dtype is not None:
|
||||
operations = ops.manual_cast
|
||||
else:
|
||||
operations = ops.disable_weight_init
|
||||
operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
|
||||
@ -471,9 +471,15 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
|
||||
'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
|
||||
'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
|
||||
'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
|
||||
'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
|
||||
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
|
||||
'use_temporal_attention': False, 'use_temporal_resblock': False}
|
||||
|
||||
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p]
|
||||
supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
|
||||
|
||||
for unet_config in supported_models:
|
||||
matches = True
|
||||
|
||||
@ -424,7 +424,7 @@ KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
|
||||
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors"),
|
||||
HuggingFile("black-forest-labs/FLUX.1-schnell", "flux1-schnell.safetensors"),
|
||||
HuggingFile("black-forest-labs/FLUX.1-dev", "flux1-dev.safetensors"),
|
||||
], folder_name="unet")
|
||||
], folder_name="diffusion_models")
|
||||
|
||||
KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([
|
||||
# todo: is this correct?
|
||||
|
||||
@ -60,9 +60,14 @@ cpu_state = CPUState.GPU
|
||||
|
||||
total_vram = 0
|
||||
|
||||
lowvram_available = True
|
||||
xpu_available = False
|
||||
try:
|
||||
torch_version = torch.version.__version__
|
||||
xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
|
||||
except:
|
||||
pass
|
||||
|
||||
lowvram_available = True
|
||||
if args.deterministic:
|
||||
logging.info("Using deterministic algorithms for pytorch")
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
@ -83,10 +88,10 @@ if args.directml is not None:
|
||||
try:
|
||||
import intel_extension_for_pytorch as ipex # pylint: disable=import-error
|
||||
|
||||
if torch.xpu.is_available():
|
||||
xpu_available = True
|
||||
_ = torch.xpu.device_count()
|
||||
xpu_available = torch.xpu.is_available()
|
||||
except:
|
||||
pass
|
||||
xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
|
||||
|
||||
try:
|
||||
if torch.backends.mps.is_available():
|
||||
@ -224,7 +229,6 @@ VAE_DTYPES = [torch.float32]
|
||||
|
||||
try:
|
||||
if is_nvidia() or is_amd():
|
||||
torch_version = torch.version.__version__
|
||||
if int(torch_version[0]) >= 2:
|
||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
@ -356,17 +360,15 @@ class LoadedModel:
|
||||
self.model_use_more_vram(use_more_vram)
|
||||
else:
|
||||
try:
|
||||
if lowvram_model_memory > 0 and load_weights:
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
else:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, load_weights=load_weights, force_patch_weights=force_patch_weights)
|
||||
except Exception as e:
|
||||
self.model.unpatch_model(self.model.offload_device)
|
||||
self.model_unload()
|
||||
raise e
|
||||
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize:
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), graph_mode=True, concat_linear=True)
|
||||
if is_intel_xpu() and not args.disable_ipex_optimize and self.real_model is not None:
|
||||
with torch.no_grad():
|
||||
self.real_model = ipex.optimize(self.real_model.eval(), inplace=True, graph_mode=True, concat_linear=True)
|
||||
|
||||
self.weights_loaded = True
|
||||
return self.real_model
|
||||
@ -421,6 +423,19 @@ def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 1.2
|
||||
|
||||
|
||||
EXTRA_RESERVED_VRAM = 200 * 1024 * 1024
|
||||
if any(platform.win32_ver()):
|
||||
EXTRA_RESERVED_VRAM = 500 * 1024 * 1024 # Windows is higher because of the shared vram issue
|
||||
|
||||
if args.reserve_vram is not None:
|
||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
|
||||
|
||||
|
||||
def extra_reserved_memory():
|
||||
return EXTRA_RESERVED_VRAM
|
||||
|
||||
|
||||
def unload_model_clones(model, unload_weights_only=True, force_unload=True) -> bool | Literal[None]:
|
||||
with model_management_lock:
|
||||
return _unload_model_clones(model, unload_weights_only, force_unload)
|
||||
@ -519,11 +534,11 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
|
||||
global vram_state
|
||||
|
||||
inference_memory = minimum_inference_memory()
|
||||
extra_mem = max(inference_memory, memory_required + 300 * 1024 * 1024)
|
||||
extra_mem = max(inference_memory, memory_required + extra_reserved_memory())
|
||||
if minimum_memory_required is None:
|
||||
minimum_memory_required = extra_mem
|
||||
else:
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + 300 * 1024 * 1024)
|
||||
minimum_memory_required = max(inference_memory, minimum_memory_required + extra_reserved_memory())
|
||||
|
||||
models = set(models)
|
||||
|
||||
@ -636,7 +651,9 @@ def cleanup_models(keep_clone_weights_loaded=False):
|
||||
with model_management_lock:
|
||||
to_delete = []
|
||||
for i in range(len(current_loaded_models)):
|
||||
if sys.getrefcount(current_loaded_models[i].model) <= 2:
|
||||
# TODO: very fragile function needs improvement
|
||||
num_refs = sys.getrefcount(current_loaded_models[i].model)
|
||||
if num_refs <= 2:
|
||||
if not keep_clone_weights_loaded:
|
||||
to_delete = [i] + to_delete
|
||||
# TODO: find a less fragile way to do this.
|
||||
@ -749,6 +766,7 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=(torch.flo
|
||||
if bf16_supported and weight_dtype == torch.bfloat16:
|
||||
return None
|
||||
|
||||
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
|
||||
for dt in supported_dtypes:
|
||||
if dt == torch.float16 and fp16_supported:
|
||||
return torch.float16
|
||||
@ -984,7 +1002,8 @@ def pytorch_attention_flash_attention():
|
||||
def force_upcast_attention_dtype():
|
||||
upcast = args.force_upcast_attention
|
||||
try:
|
||||
if platform.mac_ver()[0] in ['14.5']: # black image bug on OSX Sonoma 14.5
|
||||
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||
if (14, 5) <= macos_version < (14, 7): # black image bug on recent versions of MacOS
|
||||
upcast = True
|
||||
except:
|
||||
pass
|
||||
@ -1084,7 +1103,7 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if is_amd():
|
||||
return True
|
||||
try:
|
||||
props = torch.cuda.get_device_properties("cuda")
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 8:
|
||||
return True
|
||||
|
||||
@ -1094,16 +1113,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
logging.warning("Torch was not compiled with cuda support")
|
||||
return False
|
||||
|
||||
fp16_works = False
|
||||
# FP16 is confirmed working on a 1080 (GP104) but it's a bit slower than FP32 so it should only be enabled
|
||||
# when the model doesn't actually fit on the card
|
||||
# TODO: actually test if GP106 and others have the same type of behavior
|
||||
# FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
|
||||
nvidia_10_series = ["1080", "1070", "titan x", "p3000", "p3200", "p4000", "p4200", "p5000", "p5200", "p6000", "1060", "1050", "p40", "p100", "p6", "p4"]
|
||||
for x in nvidia_10_series:
|
||||
if x in props.name.lower():
|
||||
fp16_works = True
|
||||
return True
|
||||
|
||||
if fp16_works or manual_cast:
|
||||
if manual_cast:
|
||||
free_model_memory = maximum_vram_for_weights(device)
|
||||
if (not prioritize_performance) or model_params * 4 > free_model_memory:
|
||||
return True
|
||||
@ -1168,6 +1184,17 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
return False
|
||||
|
||||
|
||||
def supports_fp8_compute(device=None):
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major >= 9:
|
||||
return True
|
||||
if props.major < 8:
|
||||
return False
|
||||
if props.minor < 9:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
with model_management_lock:
|
||||
global cpu_state
|
||||
|
||||
@ -60,12 +60,6 @@ class ModelManageable(Protocol):
|
||||
return self.model
|
||||
|
||||
def patch_model(self, device_to: torch.device | None = None, patch_weights: bool = True) -> torch.nn.Module:
|
||||
"""
|
||||
Loads the model to the device
|
||||
:param device_to: the device to move the model weights to
|
||||
:param patch_weights: True if the patch's weights should also be moved
|
||||
:return:
|
||||
"""
|
||||
...
|
||||
|
||||
def unpatch_model(self, offload_device: torch.device | None = None, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
|
||||
@ -25,34 +25,14 @@ from typing import Optional
|
||||
import torch
|
||||
import torch.nn
|
||||
|
||||
from . import model_management
|
||||
from . import model_management, lora
|
||||
from . import utils
|
||||
from .float import stochastic_rounding
|
||||
from .model_base import BaseModel
|
||||
from .model_management_types import ModelManageable, MemoryMeasurements
|
||||
from .types import UnetWrapperFunction
|
||||
|
||||
|
||||
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength):
|
||||
dora_scale = model_management.cast_to_device(dora_scale, weight.device, torch.float32)
|
||||
lora_diff *= alpha
|
||||
weight_calc = weight + lora_diff.type(weight.dtype)
|
||||
weight_norm = (
|
||||
weight_calc.transpose(0, 1)
|
||||
.reshape(weight_calc.shape[1], -1)
|
||||
.norm(dim=1, keepdim=True)
|
||||
.reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
|
||||
if strength != 1.0:
|
||||
weight_calc -= weight
|
||||
weight += strength * (weight_calc)
|
||||
else:
|
||||
weight[:] = weight_calc
|
||||
return weight
|
||||
|
||||
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
to = model_options["transformer_options"].copy()
|
||||
|
||||
@ -98,12 +78,12 @@ def wipe_lowvram_weight(m):
|
||||
|
||||
|
||||
class LowVramPatch:
|
||||
def __init__(self, key, model_patcher):
|
||||
def __init__(self, key, patches):
|
||||
self.key = key
|
||||
self.model_patcher = model_patcher
|
||||
self.patches = patches
|
||||
|
||||
def __call__(self, weight):
|
||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||
return lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
|
||||
|
||||
class ModelPatcher(ModelManageable):
|
||||
@ -332,38 +312,18 @@ class ModelPatcher(ModelManageable):
|
||||
temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
out_weight = self.calculate_weight(self.patches[key], temp_weight, key).to(weight.dtype)
|
||||
out_weight = lora.calculate_weight(self.patches[key], temp_weight, key)
|
||||
out_weight = stochastic_rounding(out_weight, weight.dtype)
|
||||
if inplace_update:
|
||||
utils.copy_to_param(self.model, key, out_weight)
|
||||
else:
|
||||
utils.set_attr_param(self.model, key, out_weight)
|
||||
|
||||
def patch_model(self, device_to=None, patch_weights=True):
|
||||
for k in self.object_patches:
|
||||
old = utils.set_attr(self.model, k, self.object_patches[k])
|
||||
if k not in self.object_patches_backup:
|
||||
self.object_patches_backup[k] = old
|
||||
|
||||
if patch_weights:
|
||||
model_sd = self.model_state_dict()
|
||||
for key in self.patches:
|
||||
if key not in model_sd:
|
||||
logging.warning("could not patch. key doesn't exist in model: {}".format(key))
|
||||
continue
|
||||
|
||||
self.patch_weight_to_device(key, device_to)
|
||||
|
||||
if device_to is not None:
|
||||
self.model.to(device_to)
|
||||
self.model_device = device_to
|
||||
self._memory_measurements.model_loaded_weight_memory = self.model_size()
|
||||
|
||||
return self.model
|
||||
|
||||
def lowvram_load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
lowvram_counter = 0
|
||||
load_completely = []
|
||||
for n, m in self.model.named_modules():
|
||||
lowvram_weight = False
|
||||
|
||||
@ -372,7 +332,7 @@ class ModelPatcher(ModelManageable):
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
if m.comfy_cast_weights:
|
||||
if hasattr(m, "prev_comfy_cast_weights"): # Already lowvramed
|
||||
continue
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
@ -383,13 +343,13 @@ class ModelPatcher(ModelManageable):
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
m.weight_function = LowVramPatch(weight_key, self.patches)
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
m.bias_function = LowVramPatch(bias_key, self.patches)
|
||||
patch_counter += 1
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
@ -400,206 +360,58 @@ class ModelPatcher(ModelManageable):
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if hasattr(m, "weight"):
|
||||
mem_counter += model_management.module_size(m)
|
||||
param = list(m.parameters())
|
||||
if len(param) > 0:
|
||||
weight = param[0]
|
||||
if weight.device == device_to:
|
||||
continue
|
||||
mem_used = model_management.module_size(m)
|
||||
mem_counter += mem_used
|
||||
load_completely.append((mem_used, n, m))
|
||||
|
||||
weight_to = None
|
||||
if full_load: # TODO
|
||||
weight_to = device_to
|
||||
self.patch_weight_to_device(weight_key, device_to=weight_to) # TODO: speed this up without OOM
|
||||
self.patch_weight_to_device(bias_key, device_to=weight_to)
|
||||
m.to(device_to)
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
load_completely.sort(reverse=True)
|
||||
for x in load_completely:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
if hasattr(m, "comfy_patched_weights"):
|
||||
if m.comfy_patched_weights == True:
|
||||
continue
|
||||
|
||||
self.patch_weight_to_device(weight_key, device_to=device_to)
|
||||
self.patch_weight_to_device(bias_key, device_to=device_to)
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
|
||||
for x in load_completely:
|
||||
x[2].to(device_to)
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.debug("loaded partially {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), patch_counter))
|
||||
self._memory_measurements.model_lowvram = True
|
||||
else:
|
||||
logging.debug("loaded completely {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024)))
|
||||
logging.info("loaded completely {} {} {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
self._memory_measurements.model_lowvram = False
|
||||
if full_load:
|
||||
self.model.to(device_to)
|
||||
mem_counter = self.model_size()
|
||||
|
||||
self._memory_measurements.lowvram_patch_counter += patch_counter
|
||||
self._memory_measurements.model_loaded_weight_memory = mem_counter
|
||||
|
||||
self.model_device = device_to
|
||||
self._memory_measurements.model_loaded_weight_memory = mem_counter
|
||||
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||
self.patch_model(device_to, patch_weights=False)
|
||||
self.lowvram_load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
def patch_model(self, device_to=None, lowvram_model_memory=0, load_weights=True, force_patch_weights=False):
|
||||
for k in self.object_patches:
|
||||
old = utils.set_attr(self.model, k, self.object_patches[k])
|
||||
if k not in self.object_patches_backup:
|
||||
self.object_patches_backup[k] = old
|
||||
|
||||
if lowvram_model_memory == 0:
|
||||
full_load = True
|
||||
else:
|
||||
full_load = False
|
||||
|
||||
if load_weights:
|
||||
self.load(device_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights, full_load=full_load)
|
||||
return self.model
|
||||
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
for p in patches:
|
||||
strength = p[0]
|
||||
v = p[1]
|
||||
strength_model = p[2]
|
||||
offset = p[3]
|
||||
function = p[4]
|
||||
if function is None:
|
||||
function = lambda a: a
|
||||
|
||||
old_weight = None
|
||||
if offset is not None:
|
||||
old_weight = weight
|
||||
weight = weight.narrow(offset[0], offset[1], offset[2])
|
||||
|
||||
if strength_model != 1.0:
|
||||
weight *= strength_model
|
||||
|
||||
if isinstance(v, list):
|
||||
v = (self.calculate_weight(v[1:], v[0].clone(), key),)
|
||||
|
||||
patch_type = "diff"
|
||||
if len(v) == 2:
|
||||
patch_type = v[0]
|
||||
v = v[1]
|
||||
elif len(v) != 1:
|
||||
logging.warning("patch {} not recognized: {}".format(key, v))
|
||||
continue
|
||||
|
||||
if patch_type == "diff":
|
||||
w1 = v[0]
|
||||
if strength != 0.0:
|
||||
if w1.shape != weight.shape:
|
||||
logging.warning("WARNING SHAPE MISMATCH {} WEIGHT NOT MERGED {} != {}".format(key, w1.shape, weight.shape))
|
||||
else:
|
||||
weight += function(strength * model_management.cast_to_device(w1, weight.device, weight.dtype))
|
||||
elif patch_type == "lora": # lora/locon
|
||||
mat1 = model_management.cast_to_device(v[0], weight.device, torch.float32)
|
||||
mat2 = model_management.cast_to_device(v[1], weight.device, torch.float32)
|
||||
dora_scale = v[4]
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / mat2.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
if v[3] is not None:
|
||||
# locon mid weights, hopefully the math is fine because I didn't properly test it
|
||||
mat3 = model_management.cast_to_device(v[3], weight.device, torch.float32)
|
||||
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
|
||||
mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
|
||||
try:
|
||||
lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "lokr":
|
||||
w1 = v[0]
|
||||
w2 = v[1]
|
||||
w1_a = v[3]
|
||||
w1_b = v[4]
|
||||
w2_a = v[5]
|
||||
w2_b = v[6]
|
||||
t2 = v[7]
|
||||
dora_scale = v[8]
|
||||
dim = None
|
||||
|
||||
if w1 is None:
|
||||
dim = w1_b.shape[0]
|
||||
w1 = torch.mm(model_management.cast_to_device(w1_a, weight.device, torch.float32),
|
||||
model_management.cast_to_device(w1_b, weight.device, torch.float32))
|
||||
else:
|
||||
w1 = model_management.cast_to_device(w1, weight.device, torch.float32)
|
||||
|
||||
if w2 is None:
|
||||
dim = w2_b.shape[0]
|
||||
if t2 is None:
|
||||
w2 = torch.mm(model_management.cast_to_device(w2_a, weight.device, torch.float32),
|
||||
model_management.cast_to_device(w2_b, weight.device, torch.float32))
|
||||
else:
|
||||
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
model_management.cast_to_device(t2, weight.device, torch.float32),
|
||||
model_management.cast_to_device(w2_b, weight.device, torch.float32),
|
||||
model_management.cast_to_device(w2_a, weight.device, torch.float32))
|
||||
else:
|
||||
w2 = model_management.cast_to_device(w2, weight.device, torch.float32)
|
||||
|
||||
if len(w2.shape) == 4:
|
||||
w1 = w1.unsqueeze(2).unsqueeze(2)
|
||||
if v[2] is not None and dim is not None:
|
||||
alpha = v[2] / dim
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
try:
|
||||
lora_diff = torch.kron(w1, w2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "loha":
|
||||
w1a = v[0]
|
||||
w1b = v[1]
|
||||
if v[2] is not None:
|
||||
alpha = v[2] / w1b.shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
w2a = v[3]
|
||||
w2b = v[4]
|
||||
dora_scale = v[7]
|
||||
if v[5] is not None: # cp decomposition
|
||||
t1 = v[5]
|
||||
t2 = v[6]
|
||||
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
model_management.cast_to_device(t1, weight.device, torch.float32),
|
||||
model_management.cast_to_device(w1b, weight.device, torch.float32),
|
||||
model_management.cast_to_device(w1a, weight.device, torch.float32))
|
||||
|
||||
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
|
||||
model_management.cast_to_device(t2, weight.device, torch.float32),
|
||||
model_management.cast_to_device(w2b, weight.device, torch.float32),
|
||||
model_management.cast_to_device(w2a, weight.device, torch.float32))
|
||||
else:
|
||||
m1 = torch.mm(model_management.cast_to_device(w1a, weight.device, torch.float32),
|
||||
model_management.cast_to_device(w1b, weight.device, torch.float32))
|
||||
m2 = torch.mm(model_management.cast_to_device(w2a, weight.device, torch.float32),
|
||||
model_management.cast_to_device(w2b, weight.device, torch.float32))
|
||||
|
||||
try:
|
||||
lora_diff = (m1 * m2).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
elif patch_type == "glora":
|
||||
if v[4] is not None:
|
||||
alpha = v[4] / v[0].shape[0]
|
||||
else:
|
||||
alpha = 1.0
|
||||
|
||||
dora_scale = v[5]
|
||||
|
||||
a1 = model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, torch.float32)
|
||||
a2 = model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, torch.float32)
|
||||
b1 = model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, torch.float32)
|
||||
b2 = model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, torch.float32)
|
||||
|
||||
try:
|
||||
lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
|
||||
if dora_scale is not None:
|
||||
weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength))
|
||||
else:
|
||||
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
|
||||
except Exception as e:
|
||||
logging.error("ERROR {} {} {}".format(patch_type, key, e))
|
||||
else:
|
||||
logging.warning("patch type not recognized {} {}".format(patch_type, key))
|
||||
|
||||
if old_weight is not None:
|
||||
weight = old_weight
|
||||
|
||||
return weight
|
||||
|
||||
def unpatch_model(self, device_to=None, unpatch_weights=True):
|
||||
if unpatch_weights:
|
||||
if self._memory_measurements.model_lowvram:
|
||||
@ -625,6 +437,10 @@ class ModelPatcher(ModelManageable):
|
||||
self.model_device = device_to
|
||||
self._memory_measurements.model_loaded_weight_memory = 0
|
||||
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, "comfy_patched_weights"):
|
||||
del m.comfy_patched_weights
|
||||
|
||||
keys = list(self.object_patches_backup.keys())
|
||||
for k in keys:
|
||||
utils.set_attr(self.model, k, self.object_patches_backup[k])
|
||||
@ -634,39 +450,47 @@ class ModelPatcher(ModelManageable):
|
||||
def partially_unload(self, device_to, memory_to_free=0):
|
||||
memory_freed = 0
|
||||
patch_counter = 0
|
||||
unload_list = []
|
||||
|
||||
for n, m in list(self.model.named_modules())[::-1]:
|
||||
if memory_to_free < memory_freed:
|
||||
break
|
||||
|
||||
for n, m in self.model.named_modules():
|
||||
shift_lowvram = False
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
module_mem = model_management.module_size(m)
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
unload_list.append((module_mem, n, m))
|
||||
|
||||
if m.weight is not None and m.weight.device != device_to:
|
||||
for key in [weight_key, bias_key]:
|
||||
bk = self.backup.get(key, None)
|
||||
if bk is not None:
|
||||
if bk.inplace_update:
|
||||
utils.copy_to_param(self.model, key, bk.weight)
|
||||
else:
|
||||
utils.set_attr_param(self.model, key, bk.weight)
|
||||
self.backup.pop(key)
|
||||
unload_list.sort()
|
||||
for unload in unload_list:
|
||||
if memory_to_free < memory_freed:
|
||||
break
|
||||
module_mem = unload[0]
|
||||
n = unload[1]
|
||||
m = unload[2]
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
m.to(device_to)
|
||||
if weight_key in self.patches:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
patch_counter += 1
|
||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||
for key in [weight_key, bias_key]:
|
||||
bk = self.backup.get(key, None)
|
||||
if bk is not None:
|
||||
if bk.inplace_update:
|
||||
utils.copy_to_param(self.model, key, bk.weight)
|
||||
else:
|
||||
utils.set_attr_param(self.model, key, bk.weight)
|
||||
self.backup.pop(key)
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
memory_freed += module_mem
|
||||
logging.debug("freed {}".format(n))
|
||||
m.to(device_to)
|
||||
if weight_key in self.patches:
|
||||
m.weight_function = LowVramPatch(weight_key, self.patches)
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
m.bias_function = LowVramPatch(bias_key, self.patches)
|
||||
patch_counter += 1
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
m.comfy_patched_weights = False
|
||||
memory_freed += module_mem
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
self._memory_measurements.model_lowvram = True
|
||||
self._memory_measurements.lowvram_patch_counter += patch_counter
|
||||
@ -675,14 +499,14 @@ class ModelPatcher(ModelManageable):
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0):
|
||||
self.unpatch_model(unpatch_weights=False)
|
||||
self.patch_model(patch_weights=False)
|
||||
self.patch_model(load_weights=False)
|
||||
full_load = False
|
||||
if not self._memory_measurements.model_lowvram:
|
||||
return 0
|
||||
if self._memory_measurements.model_loaded_weight_memory + extra_memory > self.model_size():
|
||||
full_load = True
|
||||
current_used = self._memory_measurements.model_loaded_weight_memory
|
||||
self.lowvram_load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
|
||||
self.load(device_to, lowvram_model_memory=current_used + extra_memory, full_load=full_load)
|
||||
return self._memory_measurements.model_loaded_weight_memory - current_used
|
||||
|
||||
def current_loaded_device(self):
|
||||
@ -697,3 +521,7 @@ class ModelPatcher(ModelManageable):
|
||||
return f"<ModelPatcher for {self.ckpt_name} ({self.model.__class__.__name__})>"
|
||||
else:
|
||||
return f"<ModelPatcher for {self.model.__class__.__name__}>"
|
||||
|
||||
def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32):
|
||||
print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead")
|
||||
return lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
||||
|
||||
@ -853,7 +853,7 @@ class ControlNetApplyAdvanced:
|
||||
class UNETLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (get_filename_list_with_downloadable("unet", KNOWN_UNET_MODELS),),
|
||||
return {"required": { "unet_name": (get_filename_list_with_downloadable("diffusion_models", KNOWN_UNET_MODELS),),
|
||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e5m2"],)
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
@ -868,7 +868,7 @@ class UNETLoader:
|
||||
elif weight_dtype == "fp8_e5m2":
|
||||
model_options["dtype"] = torch.float8_e5m2
|
||||
|
||||
unet_path = get_or_download("unet", unet_name, KNOWN_UNET_MODELS)
|
||||
unet_path = get_or_download("diffusion_models", unet_name, KNOWN_UNET_MODELS)
|
||||
model = sd.load_diffusion_model(unet_path, model_options=model_options)
|
||||
return (model,)
|
||||
|
||||
|
||||
93
comfy/ops.py
93
comfy/ops.py
@ -19,31 +19,44 @@
|
||||
import torch
|
||||
|
||||
from . import model_management
|
||||
from .cli_args import args
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
|
||||
if device is None or weight.device == device:
|
||||
if not copy:
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
return r
|
||||
|
||||
|
||||
def cast_to(weight, dtype=None, device=None, non_blocking=False):
|
||||
return weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
def cast_to_input(weight, input, non_blocking=False, copy=True):
|
||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
def cast_to_input(weight, input, non_blocking=False):
|
||||
return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
|
||||
|
||||
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None):
|
||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
|
||||
if input is not None:
|
||||
if dtype is None:
|
||||
dtype = input.dtype
|
||||
if bias_dtype is None:
|
||||
bias_dtype = dtype
|
||||
if device is None:
|
||||
device = input.device
|
||||
|
||||
bias = None
|
||||
non_blocking = model_management.device_should_use_non_blocking(device)
|
||||
non_blocking = model_management.device_supports_non_blocking(device)
|
||||
if s.bias is not None:
|
||||
bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
|
||||
if s.bias_function is not None:
|
||||
has_function = s.bias_function is not None
|
||||
bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
bias = s.bias_function(bias)
|
||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
|
||||
if s.weight_function is not None:
|
||||
|
||||
has_function = s.weight_function is not None
|
||||
weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
|
||||
if has_function:
|
||||
weight = s.weight_function(weight)
|
||||
return weight, bias
|
||||
|
||||
@ -242,3 +255,59 @@ class manual_cast(disable_weight_init):
|
||||
|
||||
class Embedding(disable_weight_init.Embedding):
|
||||
comfy_cast_weights = True
|
||||
|
||||
|
||||
def fp8_linear(self, input):
|
||||
dtype = self.weight.dtype
|
||||
if dtype not in [torch.float8_e4m3fn]:
|
||||
return None
|
||||
|
||||
if len(input.shape) == 3:
|
||||
inn = input.reshape(-1, input.shape[2]).to(dtype)
|
||||
non_blocking = model_management.device_supports_non_blocking(input.device)
|
||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
||||
w = w.t()
|
||||
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
if scale_weight is None:
|
||||
scale_weight = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||
if scale_input is None:
|
||||
scale_input = scale_weight
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((1), device=input.device, dtype=torch.float32)
|
||||
|
||||
if bias is not None:
|
||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||
else:
|
||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||
|
||||
if isinstance(o, tuple):
|
||||
o = o[0]
|
||||
|
||||
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
||||
return None
|
||||
|
||||
class fp8_ops(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def reset_parameters(self):
|
||||
self.scale_weight = None
|
||||
self.scale_input = None
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias = cast_bias_weight(self, input)
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None):
|
||||
if compute_dtype is None or weight_dtype == compute_dtype:
|
||||
return disable_weight_init
|
||||
if args.fast:
|
||||
if model_management.supports_fp8_compute(load_device):
|
||||
return fp8_ops
|
||||
return manual_cast
|
||||
|
||||
46
comfy/sd.py
46
comfy/sd.py
@ -33,6 +33,7 @@ from .text_encoders import hydit
|
||||
from .text_encoders import sa_t5
|
||||
from .text_encoders import sd2_clip
|
||||
from .text_encoders import sd3_clip
|
||||
from .text_encoders import long_clipl
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
|
||||
@ -66,7 +67,7 @@ def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0):
|
||||
def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0, model_options={}):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = dict()
|
||||
if no_init:
|
||||
@ -77,12 +78,17 @@ class CLIP:
|
||||
|
||||
load_device = model_management.text_encoder_device()
|
||||
offload_device = model_management.text_encoder_offload_device()
|
||||
dtype = model_management.text_encoder_dtype(load_device)
|
||||
dtype = model_options.get("dtype", None)
|
||||
if dtype is None:
|
||||
dtype = model_management.text_encoder_dtype(load_device)
|
||||
|
||||
params['dtype'] = dtype
|
||||
params['device'] = model_management.text_encoder_initial_device(load_device, offload_device, parameters * model_management.dtype_size(dtype))
|
||||
if "textmodel_json_config" not in params and textmodel_json_config is not None:
|
||||
params['textmodel_json_config'] = textmodel_json_config
|
||||
|
||||
params['model_options'] = model_options
|
||||
|
||||
self.cond_stage_model = clip(**(params))
|
||||
|
||||
for dt in self.cond_stage_model.dtypes:
|
||||
@ -416,10 +422,18 @@ class CLIPTarget:
|
||||
tokenizer: Optional[Any] = None
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, textmodel_json_config: str | dict | None = None):
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, textmodel_json_config: str | dict | None = None, model_options=None):
|
||||
if model_options is None:
|
||||
model_options = dict()
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(utils.load_torch_file(p, safe_load=True))
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, textmodel_json_config=textmodel_json_config)
|
||||
|
||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, textmodel_json_config=None):
|
||||
clip_data = state_dicts
|
||||
class EmptyClass:
|
||||
pass
|
||||
|
||||
for i in range(len(clip_data)):
|
||||
if "transformer.resblocks.0.ln_1.weight" in clip_data[i]:
|
||||
@ -454,8 +468,13 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
clip_target.clip = sa_t5.SAT5Model
|
||||
clip_target.tokenizer = sa_t5.SAT5Tokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
w = clip_data[0].get("text_model.embeddings.position_embedding.weight", None)
|
||||
if w is not None and w.shape[0] == 248:
|
||||
clip_target.clip = long_clipl.LongClipModel
|
||||
clip_target.tokenizer = long_clipl.LongClipTokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
clip_target.tokenizer = sd1_clip.SD1Tokenizer
|
||||
elif len(clip_data) == 2:
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=True, t5=False)
|
||||
@ -483,7 +502,7 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI
|
||||
for c in clip_data:
|
||||
parameters += utils.calculate_parameters(c)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters, model_options=model_options)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
@ -529,16 +548,14 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
|
||||
return (model, clip, vae)
|
||||
|
||||
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}):
|
||||
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
|
||||
sd = utils.load_torch_file(ckpt_path)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, ckpt_path=ckpt_path)
|
||||
out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
|
||||
if out is None:
|
||||
raise RuntimeError("Could not detect model type of: {}".format(ckpt_path))
|
||||
return out
|
||||
|
||||
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, ckpt_path: str | None = None):
|
||||
def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, ckpt_path=""):
|
||||
clip = None
|
||||
clipvision = None
|
||||
vae = None
|
||||
@ -589,7 +606,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
parameters = utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
@ -697,10 +714,13 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
if clip is not None:
|
||||
load_models.append(clip.load_model())
|
||||
clip_sd = clip.get_sd()
|
||||
vae_sd = None
|
||||
if vae is not None:
|
||||
vae_sd = vae.get_sd()
|
||||
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||
for k in extra_keys:
|
||||
sd[k] = extra_keys[k]
|
||||
|
||||
|
||||
@ -94,7 +94,6 @@ class ClipTokenWeightEncoder:
|
||||
|
||||
|
||||
class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = [
|
||||
"last",
|
||||
"pooled",
|
||||
@ -104,7 +103,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_length=77,
|
||||
freeze=True, layer="last", layer_idx=None, textmodel_json_config: str | dict | None = None, dtype=None, model_class=clip_model.CLIPTextModel,
|
||||
special_tokens=None, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||
return_projected_pooled=True, return_attention_masks=False): # clip-vit-base-patch32
|
||||
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
if special_tokens is None:
|
||||
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
|
||||
@ -112,7 +111,11 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
|
||||
config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json", package=__package__)
|
||||
|
||||
self.operations = ops.manual_cast
|
||||
operations = model_options.get("custom_operations", None)
|
||||
if operations is None:
|
||||
operations = ops.manual_cast
|
||||
|
||||
self.operations = operations
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
@ -680,9 +683,12 @@ class SD1Tokenizer:
|
||||
def state_dict(self):
|
||||
return {}
|
||||
|
||||
class SD1CheckpointClipModel(SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, textmodel_json_config=None):
|
||||
super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config)
|
||||
|
||||
class SD1ClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None, clip_name="l", clip_model=SDClipModel, textmodel_json_config=None, name=None, **kwargs):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, clip_name="l", clip_model=SD1CheckpointClipModel, textmodel_json_config=None, name=None, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
if name is not None:
|
||||
@ -692,7 +698,7 @@ class SD1ClipModel(torch.nn.Module):
|
||||
self.clip_name = clip_name
|
||||
self.clip = "clip_{}".format(self.clip_name)
|
||||
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, textmodel_json_config=textmodel_json_config, **kwargs))
|
||||
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config, **kwargs))
|
||||
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
|
||||
@ -7,14 +7,14 @@ from .component_model.files import get_path_as_dict
|
||||
|
||||
|
||||
class SDXLClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None, model_options={}):
|
||||
if layer == "penultimate":
|
||||
layer = "hidden"
|
||||
layer_idx = -2
|
||||
|
||||
textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
@ -50,11 +50,11 @@ class SDXLTokenizer:
|
||||
|
||||
|
||||
class SDXLClipModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype)
|
||||
self.dtypes = set([dtype])
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
|
||||
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes = {dtype}
|
||||
|
||||
def set_clip_options(self, options):
|
||||
self.clip_l.set_clip_options(options)
|
||||
@ -79,8 +79,8 @@ class SDXLClipModel(torch.nn.Module):
|
||||
|
||||
|
||||
class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, textmodel_json_config=textmodel_json_config)
|
||||
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=SDXLClipG, model_options=model_options, textmodel_json_config=textmodel_json_config)
|
||||
|
||||
|
||||
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
|
||||
@ -94,15 +94,15 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
|
||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None, model_options={}):
|
||||
textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return super().load_sd(sd)
|
||||
|
||||
|
||||
class StableCascadeClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, textmodel_json_config=textmodel_json_config)
|
||||
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="g", clip_model=StableCascadeClipG, textmodel_json_config=textmodel_json_config, model_options=model_options)
|
||||
|
||||
@ -654,6 +654,7 @@ class Flux(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
|
||||
dtype_t5 = None
|
||||
if t5_key in state_dict:
|
||||
dtype_t5 = state_dict[t5_key].dtype
|
||||
else:
|
||||
|
||||
@ -6,7 +6,9 @@ from ..text_encoders import t5
|
||||
from ..component_model.files import get_path_as_dict
|
||||
|
||||
class PT5XlModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
|
||||
if model_options is None:
|
||||
model_options = dict()
|
||||
textmodel_json_config = get_path_as_dict(textmodel_json_config, "t5_pile_config_xl.json", package=__package__)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 2, "pad": 1}, model_class=t5.T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
@ -25,5 +27,5 @@ class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
|
||||
class AuraT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, name="pile_t5xl", clip_model=PT5XlModel, **kwargs)
|
||||
|
||||
@ -9,7 +9,9 @@ from ..component_model import files
|
||||
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
|
||||
if model_options is None:
|
||||
model_options = dict()
|
||||
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json", package=__package__)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)
|
||||
|
||||
@ -46,11 +48,11 @@ class FluxTokenizer:
|
||||
|
||||
|
||||
class FluxClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None):
|
||||
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
dtype_t5 = model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||
self.dtypes = set([dtype, dtype_t5])
|
||||
|
||||
def set_clip_options(self, options):
|
||||
@ -78,7 +80,6 @@ class FluxClipModel(torch.nn.Module):
|
||||
|
||||
def flux_clip(dtype_t5=None):
|
||||
class FluxClipModel_(FluxClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype)
|
||||
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
return FluxClipModel_
|
||||
|
||||
@ -11,7 +11,9 @@ from ..component_model.files import get_path_as_dict, get_package_as_path
|
||||
|
||||
|
||||
class HyditBertModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
|
||||
if model_options is None:
|
||||
model_options = dict()
|
||||
textmodel_json_config = get_path_as_dict(textmodel_json_config, "hydit_clip.json", package=__package__)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
|
||||
|
||||
@ -23,7 +25,9 @@ class HyditBertTokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
|
||||
class MT5XLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
|
||||
if model_options is None:
|
||||
model_options = dict()
|
||||
textmodel_json_config = get_path_as_dict(textmodel_json_config, "mt5_config_xl.json", package=__package__)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, return_attention_masks=True)
|
||||
|
||||
@ -66,10 +70,12 @@ class HyditTokenizer:
|
||||
|
||||
|
||||
class HyditModel(torch.nn.Module):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
def __init__(self, device="cpu", dtype=None, model_options=None):
|
||||
super().__init__()
|
||||
self.hydit_clip = HyditBertModel(dtype=dtype)
|
||||
self.mt5xl = MT5XLModel(dtype=dtype)
|
||||
if model_options is None:
|
||||
model_options = dict()
|
||||
self.hydit_clip = HyditBertModel(dtype=dtype, model_options=model_options)
|
||||
self.mt5xl = MT5XLModel(dtype=dtype, model_options=model_options)
|
||||
|
||||
self.dtypes = set()
|
||||
if dtype is not None:
|
||||
|
||||
25
comfy/text_encoders/long_clipl.json
Normal file
25
comfy/text_encoders/long_clipl.json
Normal file
@ -0,0 +1,25 @@
|
||||
{
|
||||
"_name_or_path": "openai/clip-vit-large-patch14",
|
||||
"architectures": [
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"attention_dropout": 0.0,
|
||||
"bos_token_id": 0,
|
||||
"dropout": 0.0,
|
||||
"eos_token_id": 49407,
|
||||
"hidden_act": "quick_gelu",
|
||||
"hidden_size": 768,
|
||||
"initializer_factor": 1.0,
|
||||
"initializer_range": 0.02,
|
||||
"intermediate_size": 3072,
|
||||
"layer_norm_eps": 1e-05,
|
||||
"max_position_embeddings": 248,
|
||||
"model_type": "clip_text_model",
|
||||
"num_attention_heads": 12,
|
||||
"num_hidden_layers": 12,
|
||||
"pad_token_id": 1,
|
||||
"projection_dim": 768,
|
||||
"torch_dtype": "float32",
|
||||
"transformers_version": "4.24.0",
|
||||
"vocab_size": 49408
|
||||
}
|
||||
19
comfy/text_encoders/long_clipl.py
Normal file
19
comfy/text_encoders/long_clipl.py
Normal file
@ -0,0 +1,19 @@
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
|
||||
class LongClipTokenizer_(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
class LongClipModel_(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
|
||||
super().__init__(device=device, textmodel_json_config=textmodel_json_config, return_projected_pooled=False, dtype=dtype, model_options=model_options)
|
||||
|
||||
class LongClipTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
|
||||
|
||||
class LongClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
|
||||
@ -6,9 +6,11 @@ from ..component_model import files
|
||||
|
||||
|
||||
class T5BaseModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
|
||||
if model_options is None:
|
||||
model_options = dict()
|
||||
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_base.json", package=__package__)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, zero_out_masked=True)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, model_options=model_options, enable_attention_masks=True, zero_out_masked=True)
|
||||
|
||||
|
||||
class T5BaseTokenizer(sd1_clip.SDTokenizer):
|
||||
@ -25,5 +27,5 @@ class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
|
||||
class SAT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, name="t5base", clip_model=T5BaseModel, **kwargs)
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, name="t5base", clip_model=T5BaseModel, **kwargs)
|
||||
|
||||
@ -4,13 +4,13 @@ from ..component_model.files import get_path_as_dict
|
||||
|
||||
|
||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
def __init__(self, arch="ViT-H-14", device="cpu", max_length=77, freeze=True, layer="penultimate", layer_idx=None, dtype=None, textmodel_json_config=None, model_options={}):
|
||||
if layer == "penultimate":
|
||||
layer = "hidden"
|
||||
layer_idx = -2
|
||||
|
||||
textmodel_json_config = get_path_as_dict(textmodel_json_config, "sd2_clip_config.json", package=__package__)
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=True, model_options=model_options)
|
||||
|
||||
|
||||
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
|
||||
@ -26,5 +26,5 @@ class SD2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
|
||||
class SD2ClipModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||
def __init__(self, device="cpu", dtype=None, textmodel_json_config=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, clip_name="h", clip_model=SD2ClipHModel, model_options=model_options, textmodel_json_config=textmodel_json_config, **kwargs)
|
||||
|
||||
@ -11,9 +11,9 @@ from ..component_model import files
|
||||
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, textmodel_json_config=None, model_options={}):
|
||||
textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json", package=__package__)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5)
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, model_options=model_options)
|
||||
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
@ -21,7 +21,7 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = dict()
|
||||
tokenizer_path = files.get_package_as_path("comfy.text_encoders.t5_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||
super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
|
||||
|
||||
|
||||
class SD3Tokenizer:
|
||||
@ -50,24 +50,24 @@ class SD3Tokenizer:
|
||||
|
||||
|
||||
class SD3ClipModel(torch.nn.Module):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None):
|
||||
def __init__(self, clip_l=True, clip_g=True, t5=True, dtype_t5=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
self.dtypes = set()
|
||||
if clip_l:
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False)
|
||||
self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_l = None
|
||||
|
||||
if clip_g:
|
||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype)
|
||||
self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
|
||||
self.dtypes.add(dtype)
|
||||
else:
|
||||
self.clip_g = None
|
||||
|
||||
if t5:
|
||||
dtype_t5 = model_management.pick_weight_dtype(dtype_t5, dtype, device)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5)
|
||||
self.t5xxl = T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
|
||||
self.dtypes.add(dtype_t5)
|
||||
else:
|
||||
self.t5xxl = None
|
||||
@ -145,7 +145,6 @@ class SD3ClipModel(torch.nn.Module):
|
||||
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None):
|
||||
class SD3ClipModel_(SD3ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None):
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype)
|
||||
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
return SD3ClipModel_
|
||||
|
||||
1
comfy/web/assets/index--0nRVkuV.js.map
generated
vendored
1
comfy/web/assets/index--0nRVkuV.js.map
generated
vendored
File diff suppressed because one or more lines are too long
53618
comfy/web/assets/index-D8Zp4vRl.js → comfy/web/assets/index-CaD4RONs.js
generated
vendored
53618
comfy/web/assets/index-D8Zp4vRl.js → comfy/web/assets/index-CaD4RONs.js
generated
vendored
File diff suppressed because one or more lines are too long
1
comfy/web/assets/index-CaD4RONs.js.map
generated
vendored
Normal file
1
comfy/web/assets/index-CaD4RONs.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
comfy/web/assets/index-D8Zp4vRl.js.map
generated
vendored
1
comfy/web/assets/index-D8Zp4vRl.js.map
generated
vendored
File diff suppressed because one or more lines are too long
500
comfy/web/assets/index-BHzRuMlR.css → comfy/web/assets/index-DAK31IJJ.css
generated
vendored
500
comfy/web/assets/index-BHzRuMlR.css → comfy/web/assets/index-DAK31IJJ.css
generated
vendored
@ -1,8 +1,8 @@
|
||||
@font-face {
|
||||
font-family: 'primeicons';
|
||||
font-display: block;
|
||||
src: url('/assets/primeicons-DMOk5skT.eot');
|
||||
src: url('/assets/primeicons-DMOk5skT.eot?#iefix') format('embedded-opentype'), url('/assets/primeicons-C6QP2o4f.woff2') format('woff2'), url('/assets/primeicons-WjwUDZjB.woff') format('woff'), url('/assets/primeicons-MpK4pl85.ttf') format('truetype'), url('/assets/primeicons-Dr5RGzOO.svg?#primeicons') format('svg');
|
||||
src: url('./primeicons-DMOk5skT.eot');
|
||||
src: url('./primeicons-DMOk5skT.eot?#iefix') format('embedded-opentype'), url('./primeicons-C6QP2o4f.woff2') format('woff2'), url('./primeicons-WjwUDZjB.woff') format('woff'), url('./primeicons-MpK4pl85.ttf') format('truetype'), url('./primeicons-Dr5RGzOO.svg?#primeicons') format('svg');
|
||||
font-weight: normal;
|
||||
font-style: normal;
|
||||
}
|
||||
@ -1330,6 +1330,184 @@
|
||||
.comfyui-body-right .side-bar-button.side-bar-button-selected[data-v-7a0b94a3]:hover {
|
||||
border-right: 4px solid var(--p-button-text-primary-color);
|
||||
}
|
||||
|
||||
:root {
|
||||
--red-600: #dc3545;
|
||||
}
|
||||
|
||||
.comfy-missing-nodes[data-v-286402f2] {
|
||||
font-family: monospace;
|
||||
color: var(--red-600);
|
||||
padding: 1.5rem;
|
||||
background-color: var(--surface-ground);
|
||||
border-radius: var(--border-radius);
|
||||
box-shadow: var(--card-shadow);
|
||||
}
|
||||
.warning-title[data-v-286402f2] {
|
||||
margin-top: 0;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.warning-description[data-v-286402f2] {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.missing-nodes-list[data-v-286402f2] {
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.missing-nodes-list.maximized[data-v-286402f2] {
|
||||
max-height: unset;
|
||||
}
|
||||
.missing-node-item[data-v-286402f2] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 0.5rem;
|
||||
}
|
||||
.node-type[data-v-286402f2] {
|
||||
font-weight: 600;
|
||||
color: var(--text-color);
|
||||
}
|
||||
.node-hint[data-v-286402f2] {
|
||||
margin-left: 0.5rem;
|
||||
font-style: italic;
|
||||
color: var(--text-color-secondary);
|
||||
}
|
||||
[data-v-286402f2] .p-button {
|
||||
margin-left: auto;
|
||||
}
|
||||
.added-nodes-warning[data-v-286402f2] {
|
||||
margin-top: 1rem;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.input-slider[data-v-fbaf7a8c] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 1rem;
|
||||
}
|
||||
.slider-part[data-v-fbaf7a8c] {
|
||||
flex-grow: 1;
|
||||
}
|
||||
.input-part[data-v-fbaf7a8c] {
|
||||
width: 5rem !important;
|
||||
}
|
||||
|
||||
.info-chip[data-v-6361f2fb] {
|
||||
background: transparent;
|
||||
}
|
||||
.setting-item[data-v-6361f2fb] {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.setting-label[data-v-6361f2fb] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
flex: 1;
|
||||
}
|
||||
.setting-input[data-v-6361f2fb] {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
margin-left: 1rem;
|
||||
}
|
||||
|
||||
/* Ensure PrimeVue components take full width of their container */
|
||||
.setting-input[data-v-6361f2fb] .p-inputtext,
|
||||
.setting-input[data-v-6361f2fb] .input-slider,
|
||||
.setting-input[data-v-6361f2fb] .p-select,
|
||||
.setting-input[data-v-6361f2fb] .p-togglebutton {
|
||||
width: 100%;
|
||||
max-width: 200px;
|
||||
}
|
||||
.setting-input[data-v-6361f2fb] .p-inputtext {
|
||||
max-width: unset;
|
||||
}
|
||||
|
||||
/* Special case for ToggleSwitch to align it to the right */
|
||||
.setting-input[data-v-6361f2fb] .p-toggleswitch {
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.search-box-input[data-v-8160f15b] {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.no-results-placeholder[data-v-5a7d148a] {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100%;
|
||||
padding: 2rem;
|
||||
}
|
||||
.no-results-placeholder[data-v-5a7d148a] .p-card {
|
||||
background-color: var(--surface-ground);
|
||||
text-align: center;
|
||||
}
|
||||
.no-results-placeholder h3[data-v-5a7d148a] {
|
||||
color: var(--text-color);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
.no-results-placeholder p[data-v-5a7d148a] {
|
||||
color: var(--text-color-secondary);
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
/* Remove after we have tailwind setup */
|
||||
.border-none {
|
||||
border: none !important;
|
||||
}
|
||||
.settings-tab-panels {
|
||||
padding-top: 0px !important;
|
||||
}
|
||||
|
||||
.settings-container[data-v-29723d1f] {
|
||||
display: flex;
|
||||
height: 70vh;
|
||||
width: 60vw;
|
||||
max-width: 1000px;
|
||||
overflow: hidden;
|
||||
/* Prevents container from scrolling */
|
||||
}
|
||||
.settings-sidebar[data-v-29723d1f] {
|
||||
width: 250px;
|
||||
flex-shrink: 0;
|
||||
/* Prevents sidebar from shrinking */
|
||||
overflow-y: auto;
|
||||
padding: 10px;
|
||||
}
|
||||
.settings-search-box[data-v-29723d1f] {
|
||||
width: 100%;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.settings-content[data-v-29723d1f] {
|
||||
flex-grow: 1;
|
||||
overflow-y: auto;
|
||||
/* Allows vertical scrolling */
|
||||
}
|
||||
|
||||
/* Ensure the Listbox takes full width of the sidebar */
|
||||
.settings-sidebar[data-v-29723d1f] .p-listbox {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
/* Optional: Style scrollbars for webkit browsers */
|
||||
.settings-sidebar[data-v-29723d1f]::-webkit-scrollbar,
|
||||
.settings-content[data-v-29723d1f]::-webkit-scrollbar {
|
||||
width: 1px;
|
||||
}
|
||||
.settings-sidebar[data-v-29723d1f]::-webkit-scrollbar-thumb,
|
||||
.settings-content[data-v-29723d1f]::-webkit-scrollbar-thumb {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
.pi-cog[data-v-969a1066] {
|
||||
font-size: 1.25rem;
|
||||
margin-right: 0.5rem;
|
||||
}
|
||||
.version-tag[data-v-969a1066] {
|
||||
margin-left: 0.5rem;
|
||||
}
|
||||
.lds-ring {
|
||||
display: inline-block;
|
||||
position: relative;
|
||||
@ -2881,6 +3059,7 @@ body {
|
||||
#graph-canvas {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
touch-action: none;
|
||||
}
|
||||
|
||||
.comfyui-body-right {
|
||||
@ -3482,184 +3661,6 @@ audio.comfy-audio.empty-audio-widget {
|
||||
max-width: 25vw;
|
||||
}
|
||||
|
||||
:root {
|
||||
--red-600: #dc3545;
|
||||
}
|
||||
|
||||
.comfy-missing-nodes[data-v-286402f2] {
|
||||
font-family: monospace;
|
||||
color: var(--red-600);
|
||||
padding: 1.5rem;
|
||||
background-color: var(--surface-ground);
|
||||
border-radius: var(--border-radius);
|
||||
box-shadow: var(--card-shadow);
|
||||
}
|
||||
.warning-title[data-v-286402f2] {
|
||||
margin-top: 0;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.warning-description[data-v-286402f2] {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.missing-nodes-list[data-v-286402f2] {
|
||||
max-height: 300px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.missing-nodes-list.maximized[data-v-286402f2] {
|
||||
max-height: unset;
|
||||
}
|
||||
.missing-node-item[data-v-286402f2] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
padding: 0.5rem;
|
||||
}
|
||||
.node-type[data-v-286402f2] {
|
||||
font-weight: 600;
|
||||
color: var(--text-color);
|
||||
}
|
||||
.node-hint[data-v-286402f2] {
|
||||
margin-left: 0.5rem;
|
||||
font-style: italic;
|
||||
color: var(--text-color-secondary);
|
||||
}
|
||||
[data-v-286402f2] .p-button {
|
||||
margin-left: auto;
|
||||
}
|
||||
.added-nodes-warning[data-v-286402f2] {
|
||||
margin-top: 1rem;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.input-slider[data-v-fbaf7a8c] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 1rem;
|
||||
}
|
||||
.slider-part[data-v-fbaf7a8c] {
|
||||
flex-grow: 1;
|
||||
}
|
||||
.input-part[data-v-fbaf7a8c] {
|
||||
width: 5rem !important;
|
||||
}
|
||||
|
||||
.info-chip[data-v-4feeb3d2] {
|
||||
background: transparent;
|
||||
}
|
||||
.setting-item[data-v-4feeb3d2] {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.setting-label[data-v-4feeb3d2] {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
flex: 1;
|
||||
}
|
||||
.setting-input[data-v-4feeb3d2] {
|
||||
flex: 1;
|
||||
display: flex;
|
||||
justify-content: flex-end;
|
||||
margin-left: 1rem;
|
||||
}
|
||||
|
||||
/* Ensure PrimeVue components take full width of their container */
|
||||
.setting-input[data-v-4feeb3d2] .p-inputtext,
|
||||
.setting-input[data-v-4feeb3d2] .input-slider,
|
||||
.setting-input[data-v-4feeb3d2] .p-select,
|
||||
.setting-input[data-v-4feeb3d2] .p-togglebutton {
|
||||
width: 100%;
|
||||
max-width: 200px;
|
||||
}
|
||||
.setting-input[data-v-4feeb3d2] .p-inputtext {
|
||||
max-width: unset;
|
||||
}
|
||||
|
||||
/* Special case for ToggleSwitch to align it to the right */
|
||||
.setting-input[data-v-4feeb3d2] .p-toggleswitch {
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.search-box-input[data-v-3bbe5335] {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
.no-results-placeholder[data-v-5a7d148a] {
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
height: 100%;
|
||||
padding: 2rem;
|
||||
}
|
||||
.no-results-placeholder[data-v-5a7d148a] .p-card {
|
||||
background-color: var(--surface-ground);
|
||||
text-align: center;
|
||||
}
|
||||
.no-results-placeholder h3[data-v-5a7d148a] {
|
||||
color: var(--text-color);
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
.no-results-placeholder p[data-v-5a7d148a] {
|
||||
color: var(--text-color-secondary);
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
/* Remove after we have tailwind setup */
|
||||
.border-none {
|
||||
border: none !important;
|
||||
}
|
||||
.settings-tab-panels {
|
||||
padding-top: 0px !important;
|
||||
}
|
||||
|
||||
.settings-container[data-v-833dbfbb] {
|
||||
display: flex;
|
||||
height: 70vh;
|
||||
width: 60vw;
|
||||
max-width: 1000px;
|
||||
overflow: hidden;
|
||||
/* Prevents container from scrolling */
|
||||
}
|
||||
.settings-sidebar[data-v-833dbfbb] {
|
||||
width: 250px;
|
||||
flex-shrink: 0;
|
||||
/* Prevents sidebar from shrinking */
|
||||
overflow-y: auto;
|
||||
padding: 10px;
|
||||
}
|
||||
.settings-search-box[data-v-833dbfbb] {
|
||||
width: 100%;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
.settings-content[data-v-833dbfbb] {
|
||||
flex-grow: 1;
|
||||
overflow-y: auto;
|
||||
/* Allows vertical scrolling */
|
||||
}
|
||||
|
||||
/* Ensure the Listbox takes full width of the sidebar */
|
||||
.settings-sidebar[data-v-833dbfbb] .p-listbox {
|
||||
width: 100%;
|
||||
}
|
||||
|
||||
/* Optional: Style scrollbars for webkit browsers */
|
||||
.settings-sidebar[data-v-833dbfbb]::-webkit-scrollbar,
|
||||
.settings-content[data-v-833dbfbb]::-webkit-scrollbar {
|
||||
width: 1px;
|
||||
}
|
||||
.settings-sidebar[data-v-833dbfbb]::-webkit-scrollbar-thumb,
|
||||
.settings-content[data-v-833dbfbb]::-webkit-scrollbar-thumb {
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
.pi-cog[data-v-969a1066] {
|
||||
font-size: 1.25rem;
|
||||
margin-right: 0.5rem;
|
||||
}
|
||||
.version-tag[data-v-969a1066] {
|
||||
margin-left: 0.5rem;
|
||||
}
|
||||
|
||||
:root {
|
||||
--sidebar-width: 64px;
|
||||
--sidebar-icon-size: 1.5rem;
|
||||
@ -3869,77 +3870,81 @@ audio.comfy-audio.empty-audio-widget {
|
||||
color: var(--error-text);
|
||||
}
|
||||
|
||||
.comfy-vue-node-search-container[data-v-b8a4ffdc] {
|
||||
.comfy-vue-node-search-container[data-v-ba2c5897] {
|
||||
display: flex;
|
||||
width: 100%;
|
||||
min-width: 24rem;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
}
|
||||
.comfy-vue-node-search-container[data-v-b8a4ffdc] * {
|
||||
.comfy-vue-node-search-container[data-v-ba2c5897] * {
|
||||
pointer-events: auto;
|
||||
}
|
||||
.comfy-vue-node-preview-container[data-v-b8a4ffdc] {
|
||||
.comfy-vue-node-preview-container[data-v-ba2c5897] {
|
||||
position: absolute;
|
||||
left: -350px;
|
||||
top: 50px;
|
||||
}
|
||||
.comfy-vue-node-search-box[data-v-b8a4ffdc] {
|
||||
.comfy-vue-node-search-box[data-v-ba2c5897] {
|
||||
z-index: 10;
|
||||
flex-grow: 1;
|
||||
}
|
||||
.option-container[data-v-b8a4ffdc] {
|
||||
.option-container[data-v-ba2c5897] {
|
||||
display: flex;
|
||||
width: 100%;
|
||||
cursor: pointer;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
overflow: hidden;
|
||||
padding-left: 1rem;
|
||||
padding-right: 1rem;
|
||||
padding-top: 0.5rem;
|
||||
padding-bottom: 0.5rem;
|
||||
padding-left: 0.5rem;
|
||||
padding-right: 0.5rem;
|
||||
padding-top: 0px;
|
||||
padding-bottom: 0px;
|
||||
}
|
||||
.option-display-name[data-v-b8a4ffdc] {
|
||||
.option-display-name[data-v-ba2c5897] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
font-weight: 600;
|
||||
}
|
||||
.option-category[data-v-b8a4ffdc] {
|
||||
.option-category[data-v-ba2c5897] {
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
font-size: 0.875rem;
|
||||
line-height: 1.25rem;
|
||||
font-weight: 300;
|
||||
--tw-text-opacity: 1;
|
||||
color: rgb(156 163 175 / var(--tw-text-opacity));
|
||||
/* Keeps the text on a single line by default */
|
||||
white-space: nowrap;
|
||||
}
|
||||
.i-badge[data-v-b8a4ffdc] {
|
||||
.i-badge[data-v-ba2c5897] {
|
||||
--tw-bg-opacity: 1;
|
||||
background-color: rgb(34 197 94 / var(--tw-bg-opacity));
|
||||
--tw-text-opacity: 1;
|
||||
color: rgb(255 255 255 / var(--tw-text-opacity));
|
||||
}
|
||||
.o-badge[data-v-b8a4ffdc] {
|
||||
.o-badge[data-v-ba2c5897] {
|
||||
--tw-bg-opacity: 1;
|
||||
background-color: rgb(239 68 68 / var(--tw-bg-opacity));
|
||||
--tw-text-opacity: 1;
|
||||
color: rgb(255 255 255 / var(--tw-text-opacity));
|
||||
}
|
||||
.c-badge[data-v-b8a4ffdc] {
|
||||
.c-badge[data-v-ba2c5897] {
|
||||
--tw-bg-opacity: 1;
|
||||
background-color: rgb(59 130 246 / var(--tw-bg-opacity));
|
||||
--tw-text-opacity: 1;
|
||||
color: rgb(255 255 255 / var(--tw-text-opacity));
|
||||
}
|
||||
.s-badge[data-v-b8a4ffdc] {
|
||||
.s-badge[data-v-ba2c5897] {
|
||||
--tw-bg-opacity: 1;
|
||||
background-color: rgb(234 179 8 / var(--tw-bg-opacity));
|
||||
}
|
||||
[data-v-b8a4ffdc] .highlight {
|
||||
[data-v-ba2c5897] .highlight {
|
||||
background-color: var(--p-primary-color);
|
||||
color: var(--p-primary-contrast-color);
|
||||
font-weight: bold;
|
||||
border-radius: 0.25rem;
|
||||
padding: 0.125rem 0.25rem;
|
||||
padding: 0rem 0.125rem;
|
||||
margin: -0.125rem 0.125rem;
|
||||
}
|
||||
|
||||
@ -3971,46 +3976,57 @@ audio.comfy-audio.empty-audio-widget {
|
||||
z-index: 99999;
|
||||
}
|
||||
|
||||
.result-container[data-v-0fac61d9] {
|
||||
.broken-image[data-v-1a883642] {
|
||||
display: none;
|
||||
}
|
||||
.broken-image-placeholder[data-v-1a883642] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
margin: 2rem;
|
||||
}
|
||||
.broken-image-placeholder i[data-v-1a883642] {
|
||||
font-size: 3rem;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
|
||||
.result-container[data-v-7ceacc88] {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
aspect-ratio: 1 / 1;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
}
|
||||
[data-v-0fac61d9] img {
|
||||
position: absolute;
|
||||
top: 50%;
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
[data-v-7ceacc88] .task-output-image {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
-o-object-fit: cover;
|
||||
object-fit: cover;
|
||||
-o-object-position: center;
|
||||
object-position: center;
|
||||
}
|
||||
.p-image-preview[data-v-0fac61d9] {
|
||||
position: static;
|
||||
display: contents;
|
||||
}
|
||||
[data-v-0fac61d9] .image-preview-mask {
|
||||
.image-preview-mask[data-v-7ceacc88] {
|
||||
position: absolute;
|
||||
left: 50%;
|
||||
top: 50%;
|
||||
transform: translate(-50%, -50%);
|
||||
width: auto;
|
||||
height: auto;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
opacity: 0;
|
||||
padding: 10px;
|
||||
cursor: pointer;
|
||||
background: rgba(0, 0, 0, 0.5);
|
||||
color: var(--p-image-preview-mask-color);
|
||||
transition:
|
||||
opacity var(--p-image-transition-duration),
|
||||
background var(--p-image-transition-duration);
|
||||
border-radius: 50%;
|
||||
transition: opacity 0.3s ease;
|
||||
}
|
||||
.result-container:hover .image-preview-mask[data-v-7ceacc88] {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
.task-result-preview[data-v-6cf8179c] {
|
||||
.task-result-preview[data-v-7c099cb7] {
|
||||
aspect-ratio: 1 / 1;
|
||||
overflow: hidden;
|
||||
display: flex;
|
||||
@ -4019,18 +4035,18 @@ audio.comfy-audio.empty-audio-widget {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
.task-result-preview i[data-v-6cf8179c],
|
||||
.task-result-preview span[data-v-6cf8179c] {
|
||||
.task-result-preview i[data-v-7c099cb7],
|
||||
.task-result-preview span[data-v-7c099cb7] {
|
||||
font-size: 2rem;
|
||||
}
|
||||
.task-item[data-v-6cf8179c] {
|
||||
.task-item[data-v-7c099cb7] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
border-radius: 4px;
|
||||
overflow: hidden;
|
||||
position: relative;
|
||||
}
|
||||
.task-item-details[data-v-6cf8179c] {
|
||||
.task-item-details[data-v-7c099cb7] {
|
||||
position: absolute;
|
||||
bottom: 0;
|
||||
padding: 0.6rem;
|
||||
@ -4041,12 +4057,23 @@ audio.comfy-audio.empty-audio-widget {
|
||||
|
||||
/* In dark mode, transparent background color for tags is not ideal for tags that
|
||||
are floating on top of images. */
|
||||
.tag-wrapper[data-v-6cf8179c] {
|
||||
.tag-wrapper[data-v-7c099cb7] {
|
||||
background-color: var(--p-primary-contrast-color);
|
||||
border-radius: 6px;
|
||||
display: inline-flex;
|
||||
}
|
||||
|
||||
/* PrimeVue's galleria teleports the fullscreen gallery out of subtree so we
|
||||
cannot use scoped style here. */
|
||||
img.galleria-image {
|
||||
max-width: 100vw;
|
||||
max-height: 100vh;
|
||||
-o-object-fit: contain;
|
||||
object-fit: contain;
|
||||
/* Set z-index so the close button doesn't get hidden behind the image when image is large */
|
||||
z-index: -1;
|
||||
}
|
||||
|
||||
.comfy-vue-side-bar-container[data-v-bde767d2] {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
@ -4078,14 +4105,33 @@ are floating on top of images. */
|
||||
background-color: transparent;
|
||||
}
|
||||
|
||||
.queue-grid[data-v-7f831ee9] {
|
||||
.scroll-container[data-v-bd027c46] {
|
||||
height: 100%;
|
||||
overflow-y: auto;
|
||||
}
|
||||
.queue-grid[data-v-bd027c46] {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
|
||||
padding: 0.5rem;
|
||||
gap: 0.5rem;
|
||||
}
|
||||
|
||||
.spinner[data-v-40c18658] {
|
||||
.node-lib-tree-node-label {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
margin-left: var(--p-tree-node-gap);
|
||||
}
|
||||
|
||||
[data-v-11e12183] .node-lib-search-box {
|
||||
margin-left: 1rem;
|
||||
margin-right: 1rem;
|
||||
margin-top: 1rem;
|
||||
}
|
||||
[data-v-11e12183] .comfy-vue-side-bar-body {
|
||||
background: var(--p-tree-background);
|
||||
}
|
||||
|
||||
.spinner[data-v-afce9bd6] {
|
||||
position: absolute;
|
||||
inset: 0px;
|
||||
display: flex;
|
||||
136
comfy/web/assets/index--0nRVkuV.js → comfy/web/assets/index-DkvOTKox.js
generated
vendored
136
comfy/web/assets/index--0nRVkuV.js → comfy/web/assets/index-DkvOTKox.js
generated
vendored
@ -48,7 +48,7 @@ var __async = (__this, __arguments, generator) => {
|
||||
});
|
||||
};
|
||||
var _PrimitiveNode_instances, onFirstConnection_fn, createWidget_fn, mergeWidgetConfig_fn, isValidConnection_fn, removeWidgets_fn, _convertedToProcess;
|
||||
import { C as ComfyDialog, $ as $el, a as ComfyApp, b as app, L as LGraphCanvas, c as LiteGraph, d as applyTextReplacements, e as ComfyWidgets, f as addValueControlWidgets, D as DraggableList, g as api, h as LGraphGroup, i as LGraphNode } from "./index-D8Zp4vRl.js";
|
||||
import { C as ComfyDialog, $ as $el, a as ComfyApp, b as app, L as LGraphCanvas, c as LiteGraph, d as applyTextReplacements, e as ComfyWidgets, f as addValueControlWidgets, D as DraggableList, g as api, u as useToastStore, h as LGraphGroup, i as LGraphNode } from "./index-CaD4RONs.js";
|
||||
const _ClipspaceDialog = class _ClipspaceDialog extends ComfyDialog {
|
||||
static registerButton(name, contextPredicate, callback) {
|
||||
const item = $el("button", {
|
||||
@ -891,6 +891,7 @@ app.registerExtension({
|
||||
});
|
||||
app.ui.settings.addSetting({
|
||||
id: id$4,
|
||||
category: ["Comfy", "ColorPalette"],
|
||||
name: "Color Palette",
|
||||
type: /* @__PURE__ */ __name((name, setter, value) => {
|
||||
const options = [
|
||||
@ -923,12 +924,6 @@ app.registerExtension({
|
||||
options
|
||||
);
|
||||
return $el("tr", [
|
||||
$el("td", [
|
||||
$el("label", {
|
||||
for: id$4.replaceAll(".", "-"),
|
||||
textContent: "Color palette"
|
||||
})
|
||||
]),
|
||||
$el("td", [
|
||||
els.select,
|
||||
$el(
|
||||
@ -1221,7 +1216,6 @@ app.registerExtension({
|
||||
const floatWeight = parseFloat(weight);
|
||||
if (isNaN(floatWeight)) return weight;
|
||||
const newWeight = floatWeight + delta;
|
||||
if (newWeight < 0) return "0";
|
||||
return String(Number(newWeight.toFixed(10)));
|
||||
}
|
||||
__name(incrementWeight, "incrementWeight");
|
||||
@ -1302,7 +1296,7 @@ app.registerExtension({
|
||||
selectedText = addWeightToParentheses(selectedText);
|
||||
const weightDelta = event.key === "ArrowUp" ? delta : -delta;
|
||||
const updatedText = selectedText.replace(
|
||||
/\((.*):(\d+(?:\.\d+)?)\)/,
|
||||
/\((.*):([+-]?\d+(?:\.\d+)?)\)/,
|
||||
(match, text, weight) => {
|
||||
weight = incrementWeight(weight, weightDelta);
|
||||
if (weight == 1) {
|
||||
@ -1620,9 +1614,9 @@ function getWidgetConfig(slot) {
|
||||
}
|
||||
__name(getWidgetConfig, "getWidgetConfig");
|
||||
function getConfig(widgetName) {
|
||||
var _a, _b, _c, _d;
|
||||
var _a, _b, _c, _d, _e;
|
||||
const { nodeData } = this.constructor;
|
||||
return (_d = (_a = nodeData == null ? void 0 : nodeData.input) == null ? void 0 : _a.required[widgetName]) != null ? _d : (_c = (_b = nodeData == null ? void 0 : nodeData.input) == null ? void 0 : _b.optional) == null ? void 0 : _c[widgetName];
|
||||
return (_e = (_b = (_a = nodeData == null ? void 0 : nodeData.input) == null ? void 0 : _a.required) == null ? void 0 : _b[widgetName]) != null ? _e : (_d = (_c = nodeData == null ? void 0 : nodeData.input) == null ? void 0 : _c.optional) == null ? void 0 : _d[widgetName];
|
||||
}
|
||||
__name(getConfig, "getConfig");
|
||||
function isConvertibleWidget(widget, config) {
|
||||
@ -1854,8 +1848,7 @@ app.registerExtension({
|
||||
init() {
|
||||
useConversionSubmenusSetting = app.ui.settings.addSetting({
|
||||
id: "Comfy.NodeInputConversionSubmenus",
|
||||
name: "Node widget/input conversion sub-menus",
|
||||
tooltip: "In the node context menu, place the entries that convert between input/widget in sub-menus.",
|
||||
name: "In the node context menu, place the entries that convert between input/widget in sub-menus.",
|
||||
type: "boolean",
|
||||
defaultValue: true
|
||||
});
|
||||
@ -3957,7 +3950,8 @@ app.registerExtension({
|
||||
}, "replace");
|
||||
app.ui.settings.addSetting({
|
||||
id: id$2,
|
||||
name: "Invert Menu Scrolling",
|
||||
category: ["Comfy", "Graph", "InvertMenuScrolling"],
|
||||
name: "Invert Context Menu Scrolling",
|
||||
type: "boolean",
|
||||
defaultValue: false,
|
||||
onChange(value) {
|
||||
@ -3974,58 +3968,66 @@ app.registerExtension({
|
||||
name: "Comfy.Keybinds",
|
||||
init() {
|
||||
const keybindListener = /* @__PURE__ */ __name(function(event) {
|
||||
const modifierPressed = event.ctrlKey || event.metaKey;
|
||||
if (modifierPressed && event.key === "Enter") {
|
||||
if (event.altKey) {
|
||||
api.interrupt();
|
||||
return __async(this, null, function* () {
|
||||
const modifierPressed = event.ctrlKey || event.metaKey;
|
||||
if (modifierPressed && event.key === "Enter") {
|
||||
if (event.altKey) {
|
||||
yield api.interrupt();
|
||||
useToastStore().add({
|
||||
severity: "info",
|
||||
summary: "Interrupted",
|
||||
detail: "Execution has been interrupted",
|
||||
life: 1e3
|
||||
});
|
||||
return;
|
||||
}
|
||||
app.queuePrompt(event.shiftKey ? -1 : 0).then();
|
||||
return;
|
||||
}
|
||||
app.queuePrompt(event.shiftKey ? -1 : 0).then();
|
||||
return;
|
||||
}
|
||||
const target = event.composedPath()[0];
|
||||
if (["INPUT", "TEXTAREA"].includes(target.tagName)) {
|
||||
return;
|
||||
}
|
||||
const modifierKeyIdMap = {
|
||||
s: "#comfy-save-button",
|
||||
o: "#comfy-file-input",
|
||||
Backspace: "#comfy-clear-button",
|
||||
d: "#comfy-load-default-button"
|
||||
};
|
||||
const modifierKeybindId = modifierKeyIdMap[event.key];
|
||||
if (modifierPressed && modifierKeybindId) {
|
||||
event.preventDefault();
|
||||
const elem = document.querySelector(modifierKeybindId);
|
||||
elem.click();
|
||||
return;
|
||||
}
|
||||
if (event.ctrlKey || event.altKey || event.metaKey) {
|
||||
return;
|
||||
}
|
||||
if (event.key === "Escape") {
|
||||
const modals = document.querySelectorAll(".comfy-modal");
|
||||
const modal = Array.from(modals).find(
|
||||
(modal2) => window.getComputedStyle(modal2).getPropertyValue("display") !== "none"
|
||||
);
|
||||
if (modal) {
|
||||
modal.style.display = "none";
|
||||
const target = event.composedPath()[0];
|
||||
if (["INPUT", "TEXTAREA"].includes(target.tagName)) {
|
||||
return;
|
||||
}
|
||||
;
|
||||
[...document.querySelectorAll("dialog")].forEach((d) => {
|
||||
d.close();
|
||||
});
|
||||
}
|
||||
const keyIdMap = {
|
||||
q: ".queue-tab-button.side-bar-button",
|
||||
h: ".queue-tab-button.side-bar-button",
|
||||
r: "#comfy-refresh-button"
|
||||
};
|
||||
const buttonId = keyIdMap[event.key];
|
||||
if (buttonId) {
|
||||
const button = document.querySelector(buttonId);
|
||||
button.click();
|
||||
}
|
||||
const modifierKeyIdMap = {
|
||||
s: "#comfy-save-button",
|
||||
o: "#comfy-file-input",
|
||||
Backspace: "#comfy-clear-button",
|
||||
d: "#comfy-load-default-button"
|
||||
};
|
||||
const modifierKeybindId = modifierKeyIdMap[event.key];
|
||||
if (modifierPressed && modifierKeybindId) {
|
||||
event.preventDefault();
|
||||
const elem = document.querySelector(modifierKeybindId);
|
||||
elem.click();
|
||||
return;
|
||||
}
|
||||
if (event.ctrlKey || event.altKey || event.metaKey) {
|
||||
return;
|
||||
}
|
||||
if (event.key === "Escape") {
|
||||
const modals = document.querySelectorAll(".comfy-modal");
|
||||
const modal = Array.from(modals).find(
|
||||
(modal2) => window.getComputedStyle(modal2).getPropertyValue("display") !== "none"
|
||||
);
|
||||
if (modal) {
|
||||
modal.style.display = "none";
|
||||
}
|
||||
;
|
||||
[...document.querySelectorAll("dialog")].forEach((d) => {
|
||||
d.close();
|
||||
});
|
||||
}
|
||||
const keyIdMap = {
|
||||
q: ".queue-tab-button.side-bar-button",
|
||||
h: ".queue-tab-button.side-bar-button",
|
||||
r: "#comfy-refresh-button"
|
||||
};
|
||||
const buttonId = keyIdMap[event.key];
|
||||
if (buttonId) {
|
||||
const button = document.querySelector(buttonId);
|
||||
button.click();
|
||||
}
|
||||
});
|
||||
}, "keybindListener");
|
||||
window.addEventListener("keydown", keybindListener, true);
|
||||
}
|
||||
@ -4037,6 +4039,7 @@ const ext = {
|
||||
return __async(this, null, function* () {
|
||||
app2.ui.settings.addSetting({
|
||||
id: id$1,
|
||||
category: ["Comfy", "Graph", "LinkRenderMode"],
|
||||
name: "Link Render Mode",
|
||||
defaultValue: 2,
|
||||
type: "combo",
|
||||
@ -5684,7 +5687,9 @@ app.registerExtension({
|
||||
LiteGraph.middle_click_slot_add_default_node = true;
|
||||
this.suggestionsNumber = app.ui.settings.addSetting({
|
||||
id: "Comfy.NodeSuggestions.number",
|
||||
category: ["Comfy", "Node Search Box", "NodeSuggestions"],
|
||||
name: "Number of nodes suggestions",
|
||||
tooltip: "Only for litegraph searchbox/context menu",
|
||||
type: "slider",
|
||||
attrs: {
|
||||
min: 1,
|
||||
@ -5766,7 +5771,8 @@ app.registerExtension({
|
||||
init() {
|
||||
app.ui.settings.addSetting({
|
||||
id: "Comfy.SnapToGrid.GridSize",
|
||||
name: "Grid Size",
|
||||
category: ["Comfy", "Graph", "GridSize"],
|
||||
name: "Snap to gird size",
|
||||
type: "slider",
|
||||
attrs: {
|
||||
min: 1,
|
||||
@ -6176,4 +6182,4 @@ app.registerExtension({
|
||||
};
|
||||
}
|
||||
});
|
||||
//# sourceMappingURL=index--0nRVkuV.js.map
|
||||
//# sourceMappingURL=index-DkvOTKox.js.map
|
||||
1
comfy/web/assets/index-DkvOTKox.js.map
generated
vendored
Normal file
1
comfy/web/assets/index-DkvOTKox.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -20,7 +20,7 @@ var __async = (__this, __arguments, generator) => {
|
||||
step((generator = generator.apply(__this, __arguments)).next());
|
||||
});
|
||||
};
|
||||
import { j as createSpinner, g as api, $ as $el } from "./index-D8Zp4vRl.js";
|
||||
import { j as createSpinner, g as api, $ as $el } from "./index-CaD4RONs.js";
|
||||
const _UserSelectionScreen = class _UserSelectionScreen {
|
||||
show(users, user) {
|
||||
return __async(this, null, function* () {
|
||||
@ -139,4 +139,4 @@ window.comfyAPI.userSelection.UserSelectionScreen = UserSelectionScreen;
|
||||
export {
|
||||
UserSelectionScreen
|
||||
};
|
||||
//# sourceMappingURL=userSelection-CH4RQEqW.js.map
|
||||
//# sourceMappingURL=userSelection-GRU1gtOt.js.map
|
||||
File diff suppressed because one or more lines are too long
4
comfy/web/index.html
vendored
4
comfy/web/index.html
vendored
@ -14,8 +14,8 @@
|
||||
</style> -->
|
||||
<link rel="stylesheet" type="text/css" href="user.css" />
|
||||
<link rel="stylesheet" type="text/css" href="materialdesignicons.min.css" />
|
||||
<script type="module" crossorigin src="/assets/index-D8Zp4vRl.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="/assets/index-BHzRuMlR.css">
|
||||
<script type="module" crossorigin src="./assets/index-CaD4RONs.js"></script>
|
||||
<link rel="stylesheet" crossorigin href="./assets/index-DAK31IJJ.css">
|
||||
</head>
|
||||
<body class="litegraph">
|
||||
<div id="vue-app"></div>
|
||||
|
||||
@ -332,6 +332,25 @@ class VAESave:
|
||||
utils.save_torch_file(vae.get_sd(), output_checkpoint, metadata=metadata)
|
||||
return {}
|
||||
|
||||
class ModelSave:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"filename_prefix": ("STRING", {"default": "diffusion_models/ComfyUI"}),},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},}
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save"
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "advanced/model_merging"
|
||||
|
||||
def save(self, model, filename_prefix, prompt=None, extra_pnginfo=None):
|
||||
save_checkpoint(model, filename_prefix=filename_prefix, output_dir=self.output_dir, prompt=prompt, extra_pnginfo=extra_pnginfo)
|
||||
return {}
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"ModelMergeSimple": ModelMergeSimple,
|
||||
"ModelMergeBlocks": ModelMergeBlocks,
|
||||
@ -343,4 +362,9 @@ NODE_CLASS_MAPPINGS = {
|
||||
"CLIPMergeAdd": CLIPAdd,
|
||||
"CLIPSave": CLIPSave,
|
||||
"VAESave": VAESave,
|
||||
"ModelSave": ModelSave,
|
||||
}
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"CheckpointSave": "Save Checkpoint",
|
||||
}
|
||||
|
||||
@ -4,14 +4,14 @@ class Example:
|
||||
|
||||
Class methods
|
||||
-------------
|
||||
INPUT_TYPES (dict):
|
||||
INPUT_TYPES (dict):
|
||||
Tell the main program input parameters of nodes.
|
||||
IS_CHANGED:
|
||||
optional method to control when the node is re executed.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
RETURN_TYPES (`tuple`):
|
||||
RETURN_TYPES (`tuple`):
|
||||
The type of each element in the output tuple.
|
||||
RETURN_NAMES (`tuple`):
|
||||
Optional: The name of each output in the output tuple.
|
||||
@ -23,13 +23,19 @@ class Example:
|
||||
Assumed to be False if not present.
|
||||
CATEGORY (`str`):
|
||||
The category the node should appear in the UI.
|
||||
DEPRECATED (`bool`):
|
||||
Indicates whether the node is deprecated. Deprecated nodes are hidden by default in the UI, but remain
|
||||
functional in existing workflows that use them.
|
||||
EXPERIMENTAL (`bool`):
|
||||
Indicates whether the node is experimental. Experimental nodes are marked as such in the UI and may be subject to
|
||||
significant changes or removal in future versions. Use with caution in production workflows.
|
||||
execute(s) -> tuple || None:
|
||||
The entry point method. The name of this method must be the same as the value of property `FUNCTION`.
|
||||
For example, if `FUNCTION = "execute"` then this method's name must be `execute`, if `FUNCTION = "foo"` then it must be `foo`.
|
||||
"""
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
"""
|
||||
|
||||
115
tests-unit/server/routes/internal_routes_test.py
Normal file
115
tests-unit/server/routes/internal_routes_test.py
Normal file
@ -0,0 +1,115 @@
|
||||
import pytest
|
||||
from aiohttp import web
|
||||
from unittest.mock import MagicMock, patch
|
||||
from api_server.routes.internal.internal_routes import InternalRoutes
|
||||
from api_server.services.file_service import FileService
|
||||
from folder_paths import models_dir, user_directory, output_directory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def internal_routes():
|
||||
return InternalRoutes()
|
||||
|
||||
@pytest.fixture
|
||||
def aiohttp_client_factory(aiohttp_client, internal_routes):
|
||||
async def _get_client():
|
||||
app = internal_routes.get_app()
|
||||
return await aiohttp_client(app)
|
||||
return _get_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files_valid_directory(aiohttp_client_factory, internal_routes):
|
||||
mock_file_list = [
|
||||
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
|
||||
{"name": "dir1", "path": "dir1", "type": "directory"}
|
||||
]
|
||||
internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
|
||||
client = await aiohttp_client_factory()
|
||||
resp = await client.get('/files?directory=models')
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert 'files' in data
|
||||
assert len(data['files']) == 2
|
||||
assert data['files'] == mock_file_list
|
||||
|
||||
# Check other valid directories
|
||||
resp = await client.get('/files?directory=user')
|
||||
assert resp.status == 200
|
||||
resp = await client.get('/files?directory=output')
|
||||
assert resp.status == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files_invalid_directory(aiohttp_client_factory, internal_routes):
|
||||
internal_routes.file_service.list_files = MagicMock(side_effect=ValueError("Invalid directory key"))
|
||||
client = await aiohttp_client_factory()
|
||||
resp = await client.get('/files?directory=invalid')
|
||||
assert resp.status == 400
|
||||
data = await resp.json()
|
||||
assert 'error' in data
|
||||
assert data['error'] == "Invalid directory key"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files_exception(aiohttp_client_factory, internal_routes):
|
||||
internal_routes.file_service.list_files = MagicMock(side_effect=Exception("Unexpected error"))
|
||||
client = await aiohttp_client_factory()
|
||||
resp = await client.get('/files?directory=models')
|
||||
assert resp.status == 500
|
||||
data = await resp.json()
|
||||
assert 'error' in data
|
||||
assert data['error'] == "Unexpected error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_files_no_directory_param(aiohttp_client_factory, internal_routes):
|
||||
mock_file_list = []
|
||||
internal_routes.file_service.list_files = MagicMock(return_value=mock_file_list)
|
||||
client = await aiohttp_client_factory()
|
||||
resp = await client.get('/files')
|
||||
assert resp.status == 200
|
||||
data = await resp.json()
|
||||
assert 'files' in data
|
||||
assert len(data['files']) == 0
|
||||
|
||||
def test_setup_routes(internal_routes):
|
||||
internal_routes.setup_routes()
|
||||
routes = internal_routes.routes
|
||||
assert any(route.method == 'GET' and str(route.path) == '/files' for route in routes)
|
||||
|
||||
def test_get_app(internal_routes):
|
||||
app = internal_routes.get_app()
|
||||
assert isinstance(app, web.Application)
|
||||
assert internal_routes._app is not None
|
||||
|
||||
def test_get_app_reuse(internal_routes):
|
||||
app1 = internal_routes.get_app()
|
||||
app2 = internal_routes.get_app()
|
||||
assert app1 is app2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_added_to_app(aiohttp_client_factory, internal_routes):
|
||||
client = await aiohttp_client_factory()
|
||||
try:
|
||||
resp = await client.get('/files')
|
||||
print(f"Response received: status {resp.status}")
|
||||
except Exception as e:
|
||||
print(f"Exception occurred during GET request: {e}")
|
||||
raise
|
||||
|
||||
assert resp.status != 404, "Route /files does not exist"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_file_service_initialization():
|
||||
with patch('api_server.routes.internal.internal_routes.FileService') as MockFileService:
|
||||
# Create a mock instance
|
||||
mock_file_service_instance = MagicMock(spec=FileService)
|
||||
MockFileService.return_value = mock_file_service_instance
|
||||
internal_routes = InternalRoutes()
|
||||
|
||||
# Check if FileService was initialized with the correct parameters
|
||||
MockFileService.assert_called_once_with({
|
||||
"models": models_dir,
|
||||
"user": user_directory,
|
||||
"output": output_directory
|
||||
})
|
||||
|
||||
# Verify that the file_service attribute of InternalRoutes is set
|
||||
assert internal_routes.file_service == mock_file_service_instance
|
||||
54
tests-unit/server/services/file_service_test.py
Normal file
54
tests-unit/server/services/file_service_test.py
Normal file
@ -0,0 +1,54 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock
|
||||
from api_server.services.file_service import FileService
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file_system_ops():
|
||||
return MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def file_service(mock_file_system_ops):
|
||||
allowed_directories = {
|
||||
"models": "/path/to/models",
|
||||
"user": "/path/to/user",
|
||||
"output": "/path/to/output"
|
||||
}
|
||||
return FileService(allowed_directories, file_system_ops=mock_file_system_ops)
|
||||
|
||||
def test_list_files_valid_directory(file_service, mock_file_system_ops):
|
||||
mock_file_system_ops.walk_directory.return_value = [
|
||||
{"name": "file1.txt", "path": "file1.txt", "type": "file", "size": 100},
|
||||
{"name": "dir1", "path": "dir1", "type": "directory"}
|
||||
]
|
||||
|
||||
result = file_service.list_files("models")
|
||||
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "file1.txt"
|
||||
assert result[1]["name"] == "dir1"
|
||||
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
|
||||
|
||||
def test_list_files_invalid_directory(file_service):
|
||||
# Does not support walking directories outside of the allowed directories
|
||||
with pytest.raises(ValueError, match="Invalid directory key"):
|
||||
file_service.list_files("invalid_key")
|
||||
|
||||
def test_list_files_empty_directory(file_service, mock_file_system_ops):
|
||||
mock_file_system_ops.walk_directory.return_value = []
|
||||
|
||||
result = file_service.list_files("models")
|
||||
|
||||
assert len(result) == 0
|
||||
mock_file_system_ops.walk_directory.assert_called_once_with("/path/to/models")
|
||||
|
||||
@pytest.mark.parametrize("directory_key", ["models", "user", "output"])
|
||||
def test_list_files_all_allowed_directories(file_service, mock_file_system_ops, directory_key):
|
||||
mock_file_system_ops.walk_directory.return_value = [
|
||||
{"name": f"file_{directory_key}.txt", "path": f"file_{directory_key}.txt", "type": "file", "size": 100}
|
||||
]
|
||||
|
||||
result = file_service.list_files(directory_key)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == f"file_{directory_key}.txt"
|
||||
mock_file_system_ops.walk_directory.assert_called_once_with(f"/path/to/{directory_key}")
|
||||
42
tests-unit/server/utils/file_operations_test.py
Normal file
42
tests-unit/server/utils/file_operations_test.py
Normal file
@ -0,0 +1,42 @@
|
||||
import pytest
|
||||
from typing import List
|
||||
from api_server.utils.file_operations import FileSystemOperations, FileSystemItem, is_file_info
|
||||
|
||||
@pytest.fixture
|
||||
def temp_directory(tmp_path):
|
||||
# Create a temporary directory structure
|
||||
dir1 = tmp_path / "dir1"
|
||||
dir2 = tmp_path / "dir2"
|
||||
dir1.mkdir()
|
||||
dir2.mkdir()
|
||||
(dir1 / "file1.txt").write_text("content1")
|
||||
(dir2 / "file2.txt").write_text("content2")
|
||||
(tmp_path / "file3.txt").write_text("content3")
|
||||
return tmp_path
|
||||
|
||||
def test_walk_directory(temp_directory):
|
||||
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
|
||||
|
||||
assert len(result) == 5 # 2 directories and 3 files
|
||||
|
||||
files = [item for item in result if item['type'] == 'file']
|
||||
dirs = [item for item in result if item['type'] == 'directory']
|
||||
|
||||
assert len(files) == 3
|
||||
assert len(dirs) == 2
|
||||
|
||||
file_names = {file['name'] for file in files}
|
||||
assert file_names == {'file1.txt', 'file2.txt', 'file3.txt'}
|
||||
|
||||
dir_names = {dir['name'] for dir in dirs}
|
||||
assert dir_names == {'dir1', 'dir2'}
|
||||
|
||||
def test_walk_directory_empty(tmp_path):
|
||||
result = FileSystemOperations.walk_directory(str(tmp_path))
|
||||
assert len(result) == 0
|
||||
|
||||
def test_walk_directory_file_size(temp_directory):
|
||||
result: List[FileSystemItem] = FileSystemOperations.walk_directory(str(temp_directory))
|
||||
files = [item for item in result if is_file_info(item)]
|
||||
for file in files:
|
||||
assert file['size'] > 0 # Assuming all files have some content
|
||||
@ -361,19 +361,6 @@ class TestExecution:
|
||||
for i in range(3):
|
||||
assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
|
||||
|
||||
async def test_output_reuse(self, client: Client, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
output1 = g.node("PreviewImage", images=input1.out(0))
|
||||
output2 = g.node("PreviewImage", images=input1.out(0))
|
||||
|
||||
result = await client.run(g)
|
||||
images1 = result.get_images(output1)
|
||||
images2 = result.get_images(output2)
|
||||
assert len(images1) == 1, "Should have 1 image"
|
||||
assert len(images2) == 1, "Should have 1 image"
|
||||
|
||||
async def test_mixed_lazy_results(self, client: Client, builder: GraphBuilder):
|
||||
g = builder
|
||||
val_list = g.node("TestMakeListNode", value1=0.0, value2=0.5, value3=1.0)
|
||||
@ -390,3 +377,55 @@ class TestExecution:
|
||||
assert numpy.array(images[0]).min() == 0 and numpy.array(images[0]).max() == 0, "First image should be 0.0"
|
||||
assert numpy.array(images[1]).min() == 127 and numpy.array(images[1]).max() == 127, "Second image should be 0.5"
|
||||
assert numpy.array(images[2]).min() == 255 and numpy.array(images[2]).max() == 255, "Third image should be 1.0"
|
||||
|
||||
async def test_missing_node_error(self, client: Client, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
|
||||
input3 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
mask = g.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
|
||||
mix1 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
|
||||
mix2 = g.node("TestLazyMixImages", image1=input1.out(0), image2=input3.out(0), mask=mask.out(0))
|
||||
# We have multiple outputs. The first is invalid, but the second is valid
|
||||
g.node("SaveImage", images=mix1.out(0))
|
||||
g.node("SaveImage", images=mix2.out(0))
|
||||
g.remove_node("removeme")
|
||||
|
||||
await client.run(g)
|
||||
|
||||
# Add back in the missing node to make sure the error doesn't break the server
|
||||
input2 = g.node("StubImage", id="removeme", content="WHITE", height=512, width=512, batch_size=1)
|
||||
await client.run(g)
|
||||
|
||||
async def test_output_reuse(self, client: Client, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
output1 = g.node("SaveImage", images=input1.out(0))
|
||||
output2 = g.node("SaveImage", images=input1.out(0))
|
||||
|
||||
result = await client.run(g)
|
||||
images1 = result.get_images(output1)
|
||||
images2 = result.get_images(output2)
|
||||
assert len(images1) == 1, "Should have 1 image"
|
||||
assert len(images2) == 1, "Should have 1 image"
|
||||
|
||||
|
||||
# This tests that only constant outputs are used in the call to `IS_CHANGED`
|
||||
async def test_is_changed_with_outputs(self, client: Client, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
|
||||
test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
|
||||
|
||||
output = g.node("PreviewImage", images=test_node.out(0))
|
||||
|
||||
result = await client.run(g)
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have 1 image"
|
||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||
|
||||
result = await client.run(g)
|
||||
images = result.get_images(output)
|
||||
assert len(images) == 1, "Should have 1 image"
|
||||
assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
|
||||
assert not result.did_run(test_node), "The execution should have been cached"
|
||||
|
||||
@ -95,6 +95,31 @@ class TestCustomIsChanged:
|
||||
else:
|
||||
return False
|
||||
|
||||
class TestIsChangedWithConstants:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"image": ("IMAGE",),
|
||||
"value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "custom_is_changed"
|
||||
|
||||
CATEGORY = "Testing/Nodes"
|
||||
|
||||
def custom_is_changed(self, image, value):
|
||||
return (image * value,)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(cls, image, value):
|
||||
if image is None:
|
||||
return value
|
||||
else:
|
||||
return image.mean().item() * value
|
||||
|
||||
class TestCustomValidation1:
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
@ -287,6 +312,7 @@ TEST_NODE_CLASS_MAPPINGS = {
|
||||
"TestLazyMixImages": TestLazyMixImages,
|
||||
"TestVariadicAverage": TestVariadicAverage,
|
||||
"TestCustomIsChanged": TestCustomIsChanged,
|
||||
"TestIsChangedWithConstants": TestIsChangedWithConstants,
|
||||
"TestCustomValidation1": TestCustomValidation1,
|
||||
"TestCustomValidation2": TestCustomValidation2,
|
||||
"TestCustomValidation3": TestCustomValidation3,
|
||||
@ -299,6 +325,7 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"TestLazyMixImages": "Lazy Mix Images",
|
||||
"TestVariadicAverage": "Variadic Average",
|
||||
"TestCustomIsChanged": "Custom IsChanged",
|
||||
"TestIsChangedWithConstants": "IsChanged With Constants",
|
||||
"TestCustomValidation1": "Custom Validation 1",
|
||||
"TestCustomValidation2": "Custom Validation 2",
|
||||
"TestCustomValidation3": "Custom Validation 3",
|
||||
|
||||
@ -28,6 +28,28 @@ class StubImage:
|
||||
elif content == "NOISE":
|
||||
return (torch.rand(batch_size, height, width, 3),)
|
||||
|
||||
class StubConstantImage:
|
||||
def __init__(self):
|
||||
pass
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
"height": ("INT", {"default": 512, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||
"width": ("INT", {"default": 512, "min": 1, "max": 4096 ** 3, "step": 1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 1024 ** 3, "step": 1}),
|
||||
},
|
||||
}
|
||||
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
FUNCTION = "stub_constant_image"
|
||||
|
||||
CATEGORY = "Testing/Stub Nodes"
|
||||
|
||||
def stub_constant_image(self, value, height, width, batch_size):
|
||||
return (torch.ones(batch_size, height, width, 3) * value,)
|
||||
|
||||
class StubMask:
|
||||
def __init__(self):
|
||||
pass
|
||||
@ -93,12 +115,14 @@ class StubFloat:
|
||||
|
||||
TEST_STUB_NODE_CLASS_MAPPINGS = {
|
||||
"StubImage": StubImage,
|
||||
"StubConstantImage": StubConstantImage,
|
||||
"StubMask": StubMask,
|
||||
"StubInt": StubInt,
|
||||
"StubFloat": StubFloat,
|
||||
}
|
||||
TEST_STUB_NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"StubImage": "Stub Image",
|
||||
"StubConstantImage": "Stub Constant Image",
|
||||
"StubMask": "Stub Mask",
|
||||
"StubInt": "Stub Int",
|
||||
"StubFloat": "Stub Float",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user