Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-23 04:17:34 +08:00)

Commit 0fd407ae87: Merge branch 'master' of github.com:comfyanonymous/ComfyUI
@@ -10,6 +10,7 @@
/tests-unit/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/notebooks/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/script_examples/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink
/.github/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata @Kosinkadink

# Python web server
/api_server/ @yoland68 @robinjhuang @huchenlei @webfiltered @pythongosssss @ltdrdata
README.md (24 changed lines)

@@ -20,10 +20,21 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu
### Upstream Features

- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
- Fully supports SD1.x, SD2.x, [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/), [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/), [SD3](https://comfyanonymous.github.io/ComfyUI_examples/sd3/) and [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- Image Models
   - SD1.x, SD2.x,
   - [SDXL](https://comfyanonymous.github.io/ComfyUI_examples/sdxl/), [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
   - [Stable Cascade](https://comfyanonymous.github.io/ComfyUI_examples/stable_cascade/)
   - [SD3 and SD3.5](https://comfyanonymous.github.io/ComfyUI_examples/sd3/)
   - Pixart Alpha and Sigma
   - [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
   - [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
   - [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- Video Models
   - [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
   - [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
   - [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
   - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that changes between executions.
- Smart memory management: can automatically run models on GPUs with as low as 1GB vram.
@@ -43,9 +54,6 @@ A vanilla, up-to-date fork of [ComfyUI](https://github.com/comfyanonymous/comfyu
- [GLIGEN](https://comfyanonymous.github.io/ComfyUI_examples/gligen/)
- [Model Merging](https://comfyanonymous.github.io/ComfyUI_examples/model_merging/)
- [LCM models and Loras](https://comfyanonymous.github.io/ComfyUI_examples/lcm/)
- [SDXL Turbo](https://comfyanonymous.github.io/ComfyUI_examples/sdturbo/)
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- Latent previews with [TAESD](#how-to-show-high-quality-previews)
- Starts up very fast.
- Works fully offline: will never download anything.
@@ -812,6 +820,8 @@ The default installation includes a fast latent preview method that's low-resolu
| `Q` | Toggle visibility of the queue |
| `H` | Toggle visibility of history |
| `R` | Refresh graph |
| `F` | Show/Hide menu |
| `.` | Fit view to selection (Whole graph when nothing is selected) |
| Double-Click LMB | Open node quick search palette |
| `Shift` + Drag | Move multiple wires at once |
| `Ctrl` + `Alt` + LMB | Disconnect all wires from clicked slot |
app/model_manager.py (new file, 184 lines)

@@ -0,0 +1,184 @@
from __future__ import annotations

import os
import base64
import json
import time
import logging
import folder_paths
import glob
import comfy.utils

from aiohttp import web
from PIL import Image
from io import BytesIO
from folder_paths import map_legacy, filter_files_extensions, filter_files_content_types


class ModelFileManager:
    def __init__(self) -> None:
        self.cache: dict[str, tuple[list[dict], dict[str, float], float]] = {}

    def get_cache(self, key: str, default=None) -> tuple[list[dict], dict[str, float], float] | None:
        return self.cache.get(key, default)

    def set_cache(self, key: str, value: tuple[list[dict], dict[str, float], float]):
        self.cache[key] = value

    def clear_cache(self):
        self.cache.clear()

    def add_routes(self, routes):
        # NOTE: This is an experiment to replace `/models`
        @routes.get("/experiment/models")
        async def get_model_folders(request):
            model_types = list(folder_paths.folder_names_and_paths.keys())
            folder_black_list = ["configs", "custom_nodes"]
            output_folders: list[dict] = []
            for folder in model_types:
                if folder in folder_black_list:
                    continue
                output_folders.append({"name": folder, "folders": folder_paths.get_folder_paths(folder)})
            return web.json_response(output_folders)

        # NOTE: This is an experiment to replace `/models/{folder}`
        @routes.get("/experiment/models/{folder}")
        async def get_all_models(request):
            folder = request.match_info.get("folder", None)
            if not folder in folder_paths.folder_names_and_paths:
                return web.Response(status=404)
            files = self.get_model_file_list(folder)
            return web.json_response(files)

        @routes.get("/experiment/models/preview/{folder}/{path_index}/{filename:.*}")
        async def get_model_preview(request):
            folder_name = request.match_info.get("folder", None)
            path_index = int(request.match_info.get("path_index", None))
            filename = request.match_info.get("filename", None)

            if not folder_name in folder_paths.folder_names_and_paths:
                return web.Response(status=404)

            folders = folder_paths.folder_names_and_paths[folder_name]
            folder = folders[0][path_index]
            full_filename = os.path.join(folder, filename)

            previews = self.get_model_previews(full_filename)
            default_preview = previews[0] if len(previews) > 0 else None
            if default_preview is None or (isinstance(default_preview, str) and not os.path.isfile(default_preview)):
                return web.Response(status=404)

            try:
                with Image.open(default_preview) as img:
                    img_bytes = BytesIO()
                    img.save(img_bytes, format="WEBP")
                    img_bytes.seek(0)
                    return web.Response(body=img_bytes.getvalue(), content_type="image/webp")
            except:
                return web.Response(status=404)

    def get_model_file_list(self, folder_name: str):
        folder_name = map_legacy(folder_name)
        folders = folder_paths.folder_names_and_paths[folder_name]
        output_list: list[dict] = []

        for index, folder in enumerate(folders[0]):
            if not os.path.isdir(folder):
                continue
            out = self.cache_model_file_list_(folder)
            if out is None:
                out = self.recursive_search_models_(folder, index)
                self.set_cache(folder, out)
            output_list.extend(out[0])

        return output_list

    def cache_model_file_list_(self, folder: str):
        model_file_list_cache = self.get_cache(folder)

        if model_file_list_cache is None:
            return None
        if not os.path.isdir(folder):
            return None
        if os.path.getmtime(folder) != model_file_list_cache[1]:
            return None
        for x in model_file_list_cache[1]:
            time_modified = model_file_list_cache[1][x]
            folder = x
            if os.path.getmtime(folder) != time_modified:
                return None

        return model_file_list_cache

    def recursive_search_models_(self, directory: str, pathIndex: int) -> tuple[list[str], dict[str, float], float]:
        if not os.path.isdir(directory):
            return [], {}, time.perf_counter()

        excluded_dir_names = [".git"]
        # TODO use settings
        include_hidden_files = False

        result: list[str] = []
        dirs: dict[str, float] = {}

        for dirpath, subdirs, filenames in os.walk(directory, followlinks=True, topdown=True):
            subdirs[:] = [d for d in subdirs if d not in excluded_dir_names]
            if not include_hidden_files:
                subdirs[:] = [d for d in subdirs if not d.startswith(".")]
                filenames = [f for f in filenames if not f.startswith(".")]

            filenames = filter_files_extensions(filenames, folder_paths.supported_pt_extensions)

            for file_name in filenames:
                try:
                    relative_path = os.path.relpath(os.path.join(dirpath, file_name), directory)
                    result.append(relative_path)
                except:
                    logging.warning(f"Warning: Unable to access {file_name}. Skipping this file.")
                    continue

            for d in subdirs:
                path: str = os.path.join(dirpath, d)
                try:
                    dirs[path] = os.path.getmtime(path)
                except FileNotFoundError:
                    logging.warning(f"Warning: Unable to access {path}. Skipping this path.")
                    continue

        return [{"name": f, "pathIndex": pathIndex} for f in result], dirs, time.perf_counter()

    def get_model_previews(self, filepath: str) -> list[str | BytesIO]:
        dirname = os.path.dirname(filepath)

        if not os.path.exists(dirname):
            return []

        basename = os.path.splitext(filepath)[0]
        match_files = glob.glob(f"{basename}.*", recursive=False)
        image_files = filter_files_content_types(match_files, "image")
        safetensors_file = next(filter(lambda x: x.endswith(".safetensors"), match_files), None)
        safetensors_metadata = {}

        result: list[str | BytesIO] = []

        for filename in image_files:
            _basename = os.path.splitext(filename)[0]
            if _basename == basename:
                result.append(filename)
            if _basename == f"{basename}.preview":
                result.append(filename)

        if safetensors_file:
            safetensors_filepath = os.path.join(dirname, safetensors_file)
            header = comfy.utils.safetensors_header(safetensors_filepath, max_size=8*1024*1024)
            if header:
                safetensors_metadata = json.loads(header)
            safetensors_images = safetensors_metadata.get("__metadata__", {}).get("ssmd_cover_images", None)
            if safetensors_images:
                safetensors_images = json.loads(safetensors_images)
                for image in safetensors_images:
                    result.append(BytesIO(base64.b64decode(image)))

        return result

    def __exit__(self, exc_type, exc_value, traceback):
        self.clear_cache()
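The three experimental routes registered above can be exercised against a running server. A minimal sketch, not part of the diff: the base address/port and the "checkpoints" folder name are assumptions about a default local install, and the response shapes are inferred from the handlers above.

import json
import urllib.error
import urllib.parse
import urllib.request

BASE = "http://127.0.0.1:8188"  # assumed default local ComfyUI address/port

# GET /experiment/models -> [{"name": "checkpoints", "folders": ["/path/to/models/checkpoints", ...]}, ...]
with urllib.request.urlopen(f"{BASE}/experiment/models") as resp:
    folder_types = json.load(resp)

# GET /experiment/models/{folder} -> [{"name": "subdir/model.safetensors", "pathIndex": 0}, ...]
with urllib.request.urlopen(f"{BASE}/experiment/models/checkpoints") as resp:
    files = json.load(resp)

# GET /experiment/models/preview/{folder}/{path_index}/{filename} -> image/webp bytes, or 404 if no preview exists
if files:
    entry = files[0]
    url = "{}/experiment/models/preview/checkpoints/{}/{}".format(
        BASE, entry["pathIndex"], urllib.parse.quote(entry["name"]))
    try:
        with urllib.request.urlopen(url) as resp:
            preview_webp = resp.read()
    except urllib.error.HTTPError:
        preview_webp = None  # no sibling image or embedded cover image was found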
@@ -41,8 +41,8 @@ class UserManager():
        if not os.path.exists(user_directory):
            os.makedirs(user_directory, exist_ok=True)
            if not args.multi_user:
                print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
                print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")
                logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******")
                logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

        if args.multi_user:
            if os.path.isfile(self.get_users_file()):
@@ -162,7 +162,6 @@ class ControlNet(nn.Module):
            if isinstance(self.num_classes, int):
                self.label_emb = nn.Embedding(num_classes, time_embed_dim)
            elif self.num_classes == "continuous":
                print("setting up linear c_adm embedding layer")
                self.label_emb = nn.Linear(1, time_embed_dim)
            elif self.num_classes == "sequential":
                assert adm_in_channels is not None
@@ -415,7 +414,6 @@ class ControlNet(nn.Module):
        out_output = []
        out_middle = []

        hs = []
        if self.num_classes is not None:
            assert y.shape[0] == x.shape[0], "There may be a mismatch between the ControlNet and Diffusion models being used"
            emb = emb + self.label_emb(y)
@@ -92,6 +92,7 @@ def _create_parser() -> EnhancedConfigArgParser:

    parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
    cache_group = parser.add_mutually_exclusive_group()
    cache_group.add_argument("--cache-classic", action="store_true", help="WARNING: Unused. Use the old style (aggressive) caching.")
    cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
    attn_group = parser.add_mutually_exclusive_group()
    attn_group.add_argument("--use-split-cross-attention", action="store_true",
@@ -100,10 +101,9 @@ def _create_parser() -> EnhancedConfigArgParser:
                            help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
    attn_group.add_argument("--use-pytorch-cross-attention", action="store_true",
                            help="Use the new pytorch 2.0 cross attention function.")
    attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")

    parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
    parser.add_argument("--disable-flash-attn", action="store_true", help="Disable Flash Attention")
    parser.add_argument("--disable-sage-attention", action="store_true", help="Disable Sage Attention")

    upcast = parser.add_mutually_exclusive_group()
    upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
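Because --use-sage-attention joins the existing mutually exclusive attention group, argparse itself rejects conflicting backend flags. A standalone illustration of that behavior (this is plain argparse, not the project's EnhancedConfigArgParser):

import argparse

parser = argparse.ArgumentParser()
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true")
attn_group.add_argument("--use-sage-attention", action="store_true")

parser.parse_args(["--use-sage-attention"])  # accepted
# parser.parse_args(["--use-sage-attention", "--use-pytorch-cross-attention"])
# would exit with an error like:
# "argument --use-pytorch-cross-attention: not allowed with argument --use-sage-attention"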
@@ -79,9 +79,8 @@ class Configuration(dict):
        use_split_cross_attention (bool): Use split cross-attention optimization.
        use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization.
        use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function.
        use_sage_attention (bool): Use Sage Attention
        disable_xformers (bool): Disable xformers.
        disable_flash_attn (bool): Disable flash_attn package attention.
        disable_sage_attention (bool): Disable sage attention package attention.
        gpu_only (bool): Run everything on the GPU.
        highvram (bool): Keep models in GPU memory.
        normalvram (bool): Default VRAM usage setting.
@@ -165,9 +164,8 @@ class Configuration(dict):
        self.use_split_cross_attention: bool = False
        self.use_quad_cross_attention: bool = False
        self.use_pytorch_cross_attention: bool = False
        self.use_sage_attention: bool = False
        self.disable_xformers: bool = False
        self.disable_flash_attn: bool = False
        self.disable_sage_attention: bool = False
        self.gpu_only: bool = False
        self.highvram: bool = False
        self.normalvram: bool = False
@@ -190,11 +190,16 @@ def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func:

    results = []

    def process_inputs(inputs, index=None):
    def process_inputs(inputs, index=None, input_is_list=False):
        if allow_interrupt:
            interruption.throw_exception_if_processing_interrupted()
        execution_block = None
        for k, v in inputs.items():
            if input_is_list:
                for e in v:
                    if isinstance(e, ExecutionBlocker):
                        v = e
                        break
            if isinstance(v, ExecutionBlocker):
                execution_block = execution_block_cb(v) if execution_block_cb else v
                break
@@ -206,7 +211,7 @@ def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func:
            results.append(execution_block)

    if input_is_list:
        process_inputs(input_data_all, 0)
        process_inputs(input_data_all, 0, input_is_list=input_is_list)
    elif max_len_input == 0:
        process_inputs({})
    else:
@@ -904,7 +909,7 @@ def _validate_prompt(prompt: typing.Mapping[str, typing.Any]) -> ValidationTuple
        if 'class_type' not in prompt[x]:
            error = {
                "type": "invalid_prompt",
                "message": f"Cannot execute because a node is missing the class_type property.",
                "message": "Cannot execute because a node is missing the class_type property.",
                "details": f"Node ID '#{x}'",
                "extra_info": {}
            }
@@ -1,5 +1,6 @@
from __future__ import annotations

import collections.abc
import logging
import mimetypes
import os
@@ -324,7 +325,7 @@ def recursive_search(directory, excluded_dir_names=None) -> tuple[list[str], dic
    return result, dirs


def filter_files_extensions(files, extensions):
def filter_files_extensions(files: collections.abc.Collection[str], extensions: collections.abc.Collection[str]):
    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))
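The new Collection-typed signature does not change behavior; the one-liner still lower-cases each extension before matching and passes everything through when the extension set is empty. A quick sketch using a local copy of the function (so nothing from folder_paths needs importing):

import os

def filter_files_extensions(files, extensions):
    # same one-liner as in the hunk above
    return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))

files = ["model.SAFETENSORS", "notes.txt", "vae.pt"]
print(filter_files_extensions(files, {".safetensors", ".pt"}))  # ['model.SAFETENSORS', 'vae.pt']
print(filter_files_extensions(files, set()))                    # empty set -> everything passes, sorted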
@@ -1,3 +1,4 @@
import collections
import os
from pathlib import Path
from typing import Optional, List, Literal, Tuple, Union, Dict
@@ -76,7 +77,7 @@ def recursive_search(
) -> Tuple[List[str], Dict[str, float]]: ...


def filter_files_extensions(files: List[str], extensions: set[str]) -> List[str]: ...
def filter_files_extensions(files: collections.abc.Collection[str], extensions: collections.abc.Collection[str]) -> List[str]: ...


def get_full_path(folder_name: str, filename: str) -> Optional[Union[str, bytes, os.PathLike]]: ...
@@ -32,4 +32,4 @@ def update_windows_updater():
    except:
        pass
    shutil.copy(bat_path, dest_bat_path)
    print("Updated the windows standalone package updater.")
    print("Updated the windows standalone package updater.") # noqa: T201
@@ -170,6 +170,7 @@ class PromptServer(ExecutorToClientProgress):

        self.address: str = "0.0.0.0"
        self.user_manager = UserManager()
        self.model_file_manager = ModelFileManager()
        self.internal_routes = InternalRoutes(self)
        # todo: this is probably read by custom nodes elsewhere
        self.supports: List[str] = ["custom_nodes_from_web"]
@@ -471,7 +472,21 @@ class PromptServer(ExecutorToClientProgress):
                    return web.Response(body=alpha_buffer.read(), content_type='image/png',
                                        headers={"Content-Disposition": f"filename=\"{filename}\""})
                else:
                    return web.FileResponse(file, headers={"Content-Disposition": f"filename=\"{filename}\""})
                    # Get content type from mimetype, defaulting to 'application/octet-stream'
                    content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'

                    # For security, force certain extensions to download instead of display
                    file_extension = os.path.splitext(filename)[1].lower()
                    if file_extension in {'.html', '.htm', '.js', '.css'}:
                        content_type = 'application/octet-stream' # Forces download

                    return web.FileResponse(
                        file,
                        headers={
                            "Content-Disposition": f"filename=\"{filename}\"",
                            "Content-Type": content_type
                        }
                    )
            return web.Response(status=404)

        @routes.get("/view_metadata/{folder_name}")
@@ -573,7 +588,7 @@ class PromptServer(ExecutorToClientProgress):
            for x in self.nodes.NODE_CLASS_MAPPINGS:
                try:
                    out[x] = node_info(x)
                except Exception as e:
                except Exception:
                    logger.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
                    logger.error(traceback.format_exc())
            return web.json_response(out)
@@ -594,7 +609,7 @@ class PromptServer(ExecutorToClientProgress):
            return web.json_response(self.prompt_queue.get_history(max_items=max_items))

        @routes.get("/history/{prompt_id}")
        async def get_history_prompt(request):
        async def get_history_prompt_id(request):
            prompt_id = request.match_info.get("prompt_id", None)
            return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))

@@ -905,6 +920,7 @@ class PromptServer(ExecutorToClientProgress):

    def add_routes(self):
        self.user_manager.add_routes(self.routes)
        self.model_file_manager.add_routes(self.routes)
        self.app.add_subapp('/internal', self.internal_routes.get_app())

        # Prefix every route with /api for easier matching for delegation.
@@ -1049,8 +1065,8 @@ class PromptServer(ExecutorToClientProgress):
        for handler in self.on_prompt_handlers:
            try:
                json_data = handler(json_data)
            except Exception as e:
                logger.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
            except Exception:
                logger.warning("[ERROR] An error occurred during the on_prompt_handler processing")
                logger.warning(traceback.format_exc())

        return json_data
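The /view change above picks a Content-Type from the filename and deliberately downgrades browser-renderable extensions to a forced download. A small self-contained restatement of just that logic (names here are illustrative, not the server's):

import mimetypes
import os

def pick_content_type(filename: str) -> str:
    # guess from the extension, fall back to a generic binary type,
    # and force a download for extensions a browser would render or execute inline
    content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
    if os.path.splitext(filename)[1].lower() in {'.html', '.htm', '.js', '.css'}:
        content_type = 'application/octet-stream'
    return content_type

print(pick_content_type("preview.png"))   # image/png
print(pick_content_type("payload.html"))  # application/octet-stream (forced download)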
@@ -301,7 +301,6 @@ class ControlLoraOps:
    class Linear(torch.nn.Module, ops.CastWeightBiasOp):
        def __init__(self, in_features: int, out_features: int, bias: bool = True,
                     device=None, dtype=None) -> None:
            factory_kwargs = {'device': device, 'dtype': dtype}
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
@@ -386,7 +385,6 @@ class ControlLora(ControlNet):
        self.control_model.to(model_management.get_torch_device())
        diffusion_model = model.diffusion_model
        sd = diffusion_model.state_dict()
        cm = self.control_model.state_dict()

        for k in sd:
            weight = sd[k]
@@ -157,16 +157,23 @@ vae_conversion_map_attn = [
]


def reshape_weight_for_sd(w):
def reshape_weight_for_sd(w, conv3d=False):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)
    if conv3d:
        return w.reshape(*w.shape, 1, 1, 1)
    else:
        return w.reshape(*w.shape, 1, 1)


def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    conv3d = False
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        if v.endswith(".conv.weight"):
            if not conv3d and vae_state_dict[k].ndim == 5:
                conv3d = True
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
@@ -179,7 +186,7 @@ def convert_vae_state_dict(vae_state_dict):
    for weight_name in weights_to_convert:
        if f"mid.attn_1.{weight_name}.weight" in k:
            logging.debug(f"Reshaping {k} for SD format")
            new_state_dict[k] = reshape_weight_for_sd(v)
            new_state_dict[k] = reshape_weight_for_sd(v, conv3d=conv3d)
    return new_state_dict
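A quick shape check of what the new conv3d branch produces, not part of the diff: an HF linear weight of shape (out_features, in_features) gains trailing singleton spatial dimensions so it can be loaded as a conv kernel.

import torch

w = torch.randn(512, 512)
print(w.reshape(*w.shape, 1, 1).shape)     # torch.Size([512, 512, 1, 1])    conv2d case
print(w.reshape(*w.shape, 1, 1, 1).shape)  # torch.Size([512, 512, 1, 1, 1]) conv3d case (5-dim video VAE weights)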
@@ -2,6 +2,7 @@

import torch
import math
import logging

from tqdm.auto import trange

@@ -476,7 +477,7 @@ class UniPC:
        return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)

    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
        print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

@@ -520,7 +521,6 @@ class UniPC:
            A_p = C_inv_p

        if use_corrector:
            print('using corrector')
            C_inv = torch.linalg.inv(C)
            A_c = C_inv

@@ -707,7 +707,6 @@ class UniPC:
    ):
        # t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
        # t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        steps = len(timesteps) - 1
        if method == 'multistep':
            assert steps >= order
@@ -4,7 +4,7 @@ import enum
import itertools
import logging
import math
from typing import Callable, Any, TYPE_CHECKING
from typing import Callable, TYPE_CHECKING

import numpy as np
import torch
@@ -142,7 +142,7 @@ class WeightHook(Hook):
            weights = self.weights
        else:
            weights = self.weights_clip
        k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
        model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
        registered.append(self)
        return True
        # TODO: add logs about any keys that were not applied
@@ -301,7 +301,7 @@ class HookGroup:
        return d

    def get_hooks_for_clip_schedule(self):
        scheduled_hooks: dict[WeightHook, list[tuple[tuple[float, float], HookKeyframe]]] = {}
        scheduled_hooks: dict[Hook, list[tuple[tuple[float, float], HookKeyframe]]] = {}
        for hook in self.hooks:
            # only care about WeightHooks, for now
            if hook.hook_type == EnumHookType.Weight:
@@ -354,7 +354,7 @@ class HookGroup:
            hook.reset()

    @staticmethod
    def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup':
    def combine_all_hooks(hooks_list: list['HookGroup'], require_count=0) -> 'HookGroup' | None:
        actual: list[HookGroup] = []
        for group in hooks_list:
            if group is not None:
@@ -367,7 +367,7 @@ class HookGroup:
        # if only 1 hook, just return itself without cloning
        elif len(actual) == 1:
            return actual[0]
        final_hook: HookGroup = None
        final_hook: HookGroup | None = None
        for hook in actual:
            if final_hook is None:
                final_hook = hook.clone()
@@ -394,7 +394,7 @@ class HookKeyframe:
class HookKeyframeGroup:
    def __init__(self):
        self.keyframes: list[HookKeyframe] = []
        self._current_keyframe: HookKeyframe = None
        self._current_keyframe: HookKeyframe | None = None
        self._current_used_steps = 0
        self._current_index = 0
        self._current_strength = None
@@ -626,7 +626,9 @@ def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, H
            c_dict[hooks_key] = cache[hooks_tuple]


def conditioning_set_values_with_hooks(conditioning, values={}, append_hooks=True):
def conditioning_set_values_with_hooks(conditioning, values: dict[str, HookGroup] = None, append_hooks=True):
    if values is None:
        values = {}
    c = []
    hooks_combine_cache: dict[tuple[HookGroup, HookGroup], HookGroup] = {}
    for t in conditioning:
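The values={} to values=None change above avoids Python's shared-mutable-default pitfall: a default dict is created once at function definition time and reused across calls. A standalone illustration of that general Python behavior (the function names here are made up for the example):

def bad(values={}):          # one dict object shared by every call
    values.setdefault("n", 0)
    values["n"] += 1
    return values

def good(values=None):       # fresh dict per call, as in the diff above
    if values is None:
        values = {}
    values.setdefault("n", 0)
    values["n"] += 1
    return values

print(bad())   # {'n': 1}
print(bad())   # {'n': 2}  state leaked between calls
print(good())  # {'n': 1}
print(good())  # {'n': 1}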
@@ -11,7 +11,6 @@ import numpy as np
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.

def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d
@@ -366,3 +366,27 @@ class LTXV(LatentFormat):
            ]

        self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512]

class HunyuanVideo(LatentFormat):
    latent_channels = 16
    scale_factor = 0.476986
    latent_rgb_factors = [
        [-0.0395, -0.0331, 0.0445],
        [ 0.0696, 0.0795, 0.0518],
        [ 0.0135, -0.0945, -0.0282],
        [ 0.0108, -0.0250, -0.0765],
        [-0.0209, 0.0032, 0.0224],
        [-0.0804, -0.0254, -0.0639],
        [-0.0991, 0.0271, -0.0669],
        [-0.0646, -0.0422, -0.0400],
        [-0.0696, -0.0595, -0.0894],
        [-0.0799, -0.0208, -0.0375],
        [ 0.1166, 0.1627, 0.0962],
        [ 0.1165, 0.0432, 0.0407],
        [-0.2315, -0.1920, -0.1355],
        [-0.0270, 0.0401, -0.0821],
        [-0.0616, -0.0997, -0.0727],
        [ 0.0249, -0.0469, -0.1703]
    ]

    latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
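A 16x3 factor table plus a 3-element bias like the one above is presumably used to project a 16-channel HunyuanVideo latent down to an approximate RGB preview; the actual preview code lives elsewhere in the codebase, so the following is only a sketch of the shape math with random stand-in factors.

import torch

latent = torch.randn(1, 16, 32, 32)               # (batch, channels, h, w)
factors = torch.randn(16, 3)                      # stand-in for latent_rgb_factors
bias = torch.tensor([0.0259, -0.0192, -0.0761])   # latent_rgb_factors_bias from the diff

rgb = torch.einsum("bchw,cr->brhw", latent, factors) + bias.view(1, 3, 1, 1)
print(rgb.shape)  # torch.Size([1, 3, 32, 32])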
@@ -98,7 +98,7 @@ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False,

    # todo: ??? Not existing
    # if antialias:
    # act = Activation1d(act)
    # act = Activation1d(act) # noqa: F821 Activation1d is not defined

    return act
@@ -1,17 +1,17 @@
# code adapted from: https://github.com/Stability-AI/stable-audio-tools
from einops.layers.torch import Rearrange

from ..modules.attention import optimized_attention
import math
import typing as tp

import torch

from einops import rearrange
from torch import nn
from torch.nn import functional as F
import math

from ..modules.attention import optimized_attention
from ... import ops


class FourierFeatures(nn.Module):
    def __init__(self, in_features, out_features, std=1., dtype=None, device=None):
        super().__init__()
@@ -161,7 +161,6 @@ class RotaryEmbedding(nn.Module):
        seq_len = 0
        # device = self.inv_freq.device
        device = t.device
        dtype = t.dtype

        # t = t.to(torch.float32)

@@ -173,7 +172,7 @@ class RotaryEmbedding(nn.Module):
        if self.scale is None:
            return freqs, 1.

        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
        power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base # noqa: F821 seq_len is not defined
        scale = ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
        scale = torch.cat((scale, scale), dim = -1)

@@ -232,9 +231,9 @@ class FeedForward(nn.Module):
            linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
        else:
            linear_in = nn.Sequential(
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
                Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                rearrange('b n d -> b d n') if use_conv else nn.Identity(),
                activation
            )

@@ -249,9 +248,9 @@ class FeedForward(nn.Module):

        self.ff = nn.Sequential(
            linear_in,
            Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
            rearrange('b d n -> b n d') if use_conv else nn.Identity(),
            linear_out,
            Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
            rearrange('b n d -> b d n') if use_conv else nn.Identity(),
        )

    def forward(self, x):
@@ -349,7 +348,6 @@ class Attention(nn.Module):

        # determine masking
        masks = []
        # todo: ???

        if input_mask is not None:
            input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
@@ -357,7 +355,7 @@ class Attention(nn.Module):

        # Other masks will be added here later

        n, device = q.shape[-2], q.device
        n = q.shape[-2]

        causal = self.causal if causal is None else causal
@@ -147,7 +147,6 @@ class DoubleAttention(nn.Module):

        bsz, seqlen1, _ = c.shape
        bsz, seqlen2, _ = x.shape
        seqlen = seqlen1 + seqlen2

        cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
        cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)
@@ -382,7 +381,6 @@ class MMDiT(nn.Module):
        pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1)
        self.positional_encoding.data = pe_new.unsqueeze(0).contiguous()
        self.h_max, self.w_max = target_dim
        print("PE extended to", target_dim)

    def pe_selection_index_based_on_dim(self, h, w):
        h_p, w_p = h // self.patch_size, w // self.patch_size
@@ -6,9 +6,12 @@ from comfy import ops
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
        padding_mode = "reflect"
    pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
    pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
    return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)

    pad = ()
    for i in range(img.ndim - 2):
        pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad

    return torch.nn.functional.pad(img, pad, mode=padding_mode)

try:
    rms_norm_torch = torch.nn.functional.rms_norm # pylint: disable=no-member
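A worked example of the generalized padding loop above, not part of the diff: for a 5-dimensional video latent and a (t, h, w) patch size, the loop builds the pad tuple from the last dimension backwards, which is exactly the order torch.nn.functional.pad expects.

import torch
import torch.nn.functional as F

img = torch.randn(1, 16, 5, 30, 31)      # (b, c, t, h, w)
patch_size = (1, 2, 2)

pad = ()
for i in range(img.ndim - 2):
    pad = (0, (patch_size[i] - img.shape[i + 2] % patch_size[i]) % patch_size[i]) + pad
print(pad)                                      # (0, 1, 0, 0, 0, 0): only the width needs one extra column
print(F.pad(img, pad, mode="circular").shape)   # torch.Size([1, 16, 5, 30, 32])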
@@ -113,7 +113,7 @@ class Modulation(nn.Module):


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, dtype=None, device=None, operations=None):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
@@ -140,8 +140,9 @@ class DoubleStreamBlock(nn.Module):
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor):
    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

@@ -161,12 +162,22 @@ class DoubleStreamBlock(nn.Module):
        txt_q, txt_k, txt_v = torch.unbind(txt_qkv, dim=0)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        attn = attention(torch.cat((txt_q, img_q), dim=2),
                         torch.cat((txt_k, img_k), dim=2),
                         torch.cat((txt_v, img_v), dim=2), pe=pe)
        if self.flipped_img_txt:
            # run actual attention
            attn = attention(torch.cat((img_q, txt_q), dim=2),
                             torch.cat((img_k, txt_k), dim=2),
                             torch.cat((img_v, txt_v), dim=2),
                             pe=pe, mask=attn_mask)

            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
            img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:]
        else:
            # run actual attention
            attn = attention(torch.cat((txt_q, img_q), dim=2),
                             torch.cat((txt_k, img_k), dim=2),
                             torch.cat((txt_v, img_v), dim=2),
                             pe=pe, mask=attn_mask)

            txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        # calculate the img bloks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
@@ -218,7 +229,7 @@ class SingleStreamBlock(nn.Module):
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
    def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
@@ -228,7 +239,7 @@ class SingleStreamBlock(nn.Module):
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe)
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x = x + mod.gate * output
@@ -1,15 +1,15 @@
import torch
from einops import rearrange
from torch import Tensor

from ..modules.attention import optimized_attention
from ... import model_management


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
    q, k = apply_rope(q, k, pe)

    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True)
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
    return x


@@ -34,3 +34,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
@@ -1,8 +1,9 @@
#Original code can be found on: https://github.com/black-forest-labs/flux
# Original code can be found on: https://github.com/black-forest-labs/flux

from dataclasses import dataclass

import torch
from einops import rearrange, repeat
from torch import Tensor, nn

from .layers import (
@@ -13,10 +14,9 @@ from .layers import (
    SingleStreamBlock,
    timestep_embedding,
)

from einops import rearrange, repeat
from .. import common_dit


@dataclass
class FluxParams:
    in_channels: int
@@ -91,16 +91,17 @@ class Flux(nn.Module):
        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)

    def forward_orig(
            self,
            img: Tensor,
            img_ids: Tensor,
            txt: Tensor,
            txt_ids: Tensor,
            timesteps: Tensor,
            y: Tensor,
            guidance: Tensor = None,
            control=None,
            transformer_options={},
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor = None,
        control=None,
        transformer_options={},
        attn_mask: Tensor = None,
    ) -> Tensor:
        patches_replace = transformer_options.get("patches_replace", {})
        if img.ndim != 3 or txt.ndim != 3:
@@ -114,7 +115,7 @@ class Flux(nn.Module):
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))

        vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
        vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
        txt = self.txt_in(txt)

        ids = torch.cat((txt_ids, img_ids), dim=1)
@@ -125,16 +126,29 @@ class Flux(nn.Module):
            if ("double_block", i) in blocks_replace:
                def block_wrap_1(args):
                    out = {}
                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
                    out["img"], out["txt"] = block(img=args["img"],
                                                   txt=args["txt"],
                                                   vec=args["vec"],
                                                   pe=args["pe"],
                                                   attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap_1})
                out = blocks_replace[("double_block", i)]({"img": img,
                                                           "txt": txt,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap_1})
                txt = out["txt"]
                img = out["img"]
            else:
                img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
                img, txt = block(img=img,
                                 txt=txt,
                                 vec=vec,
                                 pe=pe,
                                 attn_mask=attn_mask)

            if control is not None: # Controlnet
                control_i = control.get("input")
                if i < len(control_i):
                    add = control_i[i]
@@ -147,22 +161,29 @@ class Flux(nn.Module):
            if ("single_block", i) in blocks_replace:
                def block_wrap_2(args):
                    out = {}
                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
                    out["img"] = block(args["img"],
                                       vec=args["vec"],
                                       pe=args["pe"],
                                       attn_mask=args.get("attn_mask"))
                    return out

                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap_2})
                out = blocks_replace[("single_block", i)]({"img": img,
                                                           "vec": vec,
                                                           "pe": pe,
                                                           "attn_mask": attn_mask},
                                                          {"original_block": block_wrap_2})
                img = out["img"]
            else:
                img = block(img, vec=vec, pe=pe)
                img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)

            if control is not None: # Controlnet
                control_o = control.get("output")
                if i < len(control_o):
                    add = control_o[i]
                    if add is not None:
                        img[:, txt.shape[1] :, ...] += add
                        img[:, txt.shape[1]:, ...] += add

        img = img[:, txt.shape[1] :, ...]
        img = img[:, txt.shape[1]:, ...]

        img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
        return img
@@ -182,5 +203,5 @@ class Flux(nn.Module):
        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
        return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h, :w]
@@ -1,5 +1,5 @@
# original code from https://github.com/genmoai/models under apache 2.0 license
# adapted to ComfyUIfrom typing import Dict, List, Optional, Tuple
# adapted to ComfyUI
from typing import Tuple, List, Dict, Optional

import torch
@@ -459,7 +459,6 @@ class AsymmDiTJoint(nn.Module):
        pH, pW = H // self.patch_size, W // self.patch_size
        x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
        assert x.ndim == 3
        B = x.size(0)

        pH, pW = H // self.patch_size, W // self.patch_size
        N = T * pH * pW
330
comfy/ldm/hunyuan_video/model.py
Normal file
330
comfy/ldm/hunyuan_video/model.py
Normal file
@ -0,0 +1,330 @@
|
||||
#Based on Flux code because of weird hunyuan video code license.
|
||||
|
||||
import torch
|
||||
import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from einops import repeat
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
from comfy.ldm.flux.layers import (
|
||||
DoubleStreamBlock,
|
||||
EmbedND,
|
||||
LastLayer,
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding
|
||||
)
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
|
||||
|
||||
@dataclass
|
||||
class HunyuanVideoParams:
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
vec_in_dim: int
|
||||
context_in_dim: int
|
||||
hidden_size: int
|
||||
mlp_ratio: float
|
||||
num_heads: int
|
||||
depth: int
|
||||
depth_single_blocks: int
|
||||
axes_dim: list
|
||||
theta: int
|
||||
patch_size: list
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
def __init__(self, dim: int, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
|
||||
class TokenRefinerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
heads,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
mlp_hidden_dim = hidden_size * 4
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
self.self_attn = SelfAttentionRef(hidden_size, True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
mod1, mod2 = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
|
||||
norm_x = self.norm1(x)
|
||||
qkv = self.self_attn.qkv(norm_x)
|
||||
q, k, v = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.heads, -1).permute(2, 0, 3, 1, 4)
|
||||
attn = optimized_attention(q, k, v, self.heads, mask=mask, skip_reshape=True)
|
||||
|
||||
x = x + self.self_attn.proj(attn) * mod1.unsqueeze(1)
|
||||
x = x + self.mlp(self.norm2(x)) * mod2.unsqueeze(1)
|
||||
return x
|
||||
|
||||
|
||||
class IndividualTokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size,
|
||||
heads,
|
||||
num_blocks,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
TokenRefinerBlock(
|
||||
hidden_size=hidden_size,
|
||||
heads=heads,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x, c, mask):
|
||||
m = None
|
||||
if mask is not None:
|
||||
m = mask.view(mask.shape[0], 1, 1, mask.shape[1]).repeat(1, 1, mask.shape[1], 1)
|
||||
m = m + m.transpose(2, 3)
|
||||
|
||||
for block in self.blocks:
|
||||
x = block(x, c, m)
|
||||
return x
|
||||
|
||||
|
||||
|
||||
class TokenRefiner(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
text_dim,
|
||||
hidden_size,
|
||||
heads,
|
||||
num_blocks,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.input_embedder = operations.Linear(text_dim, hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.t_embedder = MLPEmbedder(256, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.c_embedder = MLPEmbedder(text_dim, hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.individual_token_refiner = IndividualTokenRefiner(hidden_size, heads, num_blocks, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
timesteps,
|
||||
mask,
|
||||
):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
x = self.individual_token_refiner(x, c, mask)
|
||||
return x
|
||||
|
||||
class HunyuanVideo(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
"""
|
||||
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
params = HunyuanVideoParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
)
|
||||
pe_dim = params.hidden_size // params.num_heads
|
||||
if sum(params.axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
|
||||
self.img_in = comfy.ldm.modules.diffusionmodules.mmdit.PatchEmbed(None, self.patch_size, self.in_channels, self.hidden_size, conv3d=True, dtype=dtype, device=device, operations=operations)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
|
||||
self.txt_in = TokenRefiner(params.context_in_dim, self.hidden_size, self.num_heads, 2, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
DoubleStreamBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
flipped_img_txt=True,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
img_ids: Tensor,
|
||||
txt: Tensor,
|
||||
txt_ids: Tensor,
|
||||
txt_mask: Tensor,
|
||||
timesteps: Tensor,
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control=None,
|
||||
transformer_options={},
|
||||
) -> Tensor:
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
|
||||
initial_shape = list(img.shape)
|
||||
# running on sequences img
|
||||
img = self.img_in(img)
|
||||
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
|
||||
if self.params.guidance_embed:
|
||||
if guidance is None:
|
||||
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
if txt_mask is not None and not torch.is_floating_point(txt_mask):
|
||||
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
img_len = img.shape[1]
|
||||
if txt_mask is not None:
|
||||
attn_mask_len = img_len + txt.shape[1]
|
||||
attn_mask = torch.zeros((1, 1, attn_mask_len), dtype=img.dtype, device=img.device)
|
||||
attn_mask[:, 0, img_len:] = txt_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
txt = out["txt"]
|
||||
img = out["img"]
|
||||
else:
|
||||
img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
if i < len(control_i):
|
||||
add = control_i[i]
|
||||
if add is not None:
|
||||
img += add
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
if i < len(control_o):
|
||||
add = control_o[i]
|
||||
if add is not None:
|
||||
img[:, : img_len] += add
|
||||
|
||||
img = img[:, : img_len]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
|
||||
shape = initial_shape[-3:]
|
||||
for i in range(len(shape)):
|
||||
shape[i] = shape[i] // self.patch_size[i]
|
||||
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
|
||||
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
|
||||
img = img.reshape(initial_shape)
|
||||
return img
|
||||
|
||||
def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
|
||||
img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
|
||||
return out
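The t_len/h_len/w_len expressions above round the latent extents up to whole patches before building position ids; for even patch sizes the formula matches ceil division. A quick illustration, with patch_size (1, 2, 2) assumed:

import math

patch_size = (1, 2, 2)
t, h, w = 9, 97, 72
t_len = (t + patch_size[0] // 2) // patch_size[0]   # 9
h_len = (h + patch_size[1] // 2) // patch_size[1]   # 49, i.e. 97 rounded up to whole patches
w_len = (w + patch_size[2] // 2) // patch_size[2]   # 36
assert (t_len, h_len, w_len) == (math.ceil(t / 1), math.ceil(h / 2), math.ceil(w / 2))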
|
||||
@ -158,9 +158,6 @@ class HunYuanControlNet(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
# Image embedding
|
||||
num_patches = self.x_embedder.num_patches
|
||||
|
||||
# HunYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
|
||||
@ -248,9 +248,6 @@ class HunYuanDiT(nn.Module):
|
||||
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
# Image embedding
|
||||
num_patches = self.x_embedder.num_patches
|
||||
|
||||
# HunYuanDiT Blocks
|
||||
self.blocks = nn.ModuleList([
|
||||
HunYuanDiTBlock(hidden_size=hidden_size,
|
||||
|
||||
@ -53,7 +53,7 @@ class Patchifier(ABC):
        grid_h = torch.arange(h, dtype=torch.float32, device=device)
        grid_w = torch.arange(w, dtype=torch.float32, device=device)
        grid_f = torch.arange(f, dtype=torch.float32, device=device)
        grid = torch.meshgrid(grid_f, grid_h, grid_w)
        grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
        grid = torch.stack(grid, dim=0)
        grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
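The only change in this hunk is the explicit indexing='ij'; it preserves the existing behaviour and silences PyTorch's warning about the implicit default. A quick check:

import torch

f, h, w = torch.arange(2.), torch.arange(3.), torch.arange(4.)
explicit = torch.meshgrid(f, h, w, indexing='ij')
legacy = torch.meshgrid(f, h, w)          # warns on recent PyTorch, same result
assert all(torch.equal(a, b) for a, b in zip(explicit, legacy))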
|
||||
@ -6,7 +6,9 @@ from einops import rearrange
|
||||
from typing import Optional, Tuple, Union
|
||||
from .conv_nd_factory import make_conv_nd, make_linear_nd
|
||||
from .pixel_norm import PixelNorm
|
||||
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
class Encoder(nn.Module):
|
||||
r"""
|
||||
@ -236,6 +238,7 @@ class Decoder(nn.Module):
|
||||
patch_size: int = 1,
|
||||
norm_layer: str = "group_norm",
|
||||
causal: bool = True,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
@ -250,6 +253,8 @@ class Decoder(nn.Module):
|
||||
block_params = block_params if isinstance(block_params, dict) else {}
|
||||
if block_name == "res_x_y":
|
||||
output_channel = output_channel * block_params.get("multiplier", 2)
|
||||
if block_name == "compress_all":
|
||||
output_channel = output_channel * block_params.get("multiplier", 1)
|
||||
|
||||
self.conv_in = make_conv_nd(
|
||||
dims,
|
||||
@ -276,6 +281,19 @@ class Decoder(nn.Module):
|
||||
resnet_eps=1e-6,
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
elif block_name == "attn_res_x":
|
||||
block = UNetMidBlock3D(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
num_layers=block_params["num_layers"],
|
||||
resnet_groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
attention_head_dim=block_params["attention_head_dim"],
|
||||
)
|
||||
elif block_name == "res_x_y":
|
||||
output_channel = output_channel // block_params.get("multiplier", 2)
|
||||
@ -286,6 +304,8 @@ class Decoder(nn.Module):
|
||||
eps=1e-6,
|
||||
groups=norm_num_groups,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=block_params.get("inject_noise", False),
|
||||
timestep_conditioning=False,
|
||||
)
|
||||
elif block_name == "compress_time":
|
||||
block = DepthToSpaceUpsample(
|
||||
@ -296,11 +316,13 @@ class Decoder(nn.Module):
|
||||
dims=dims, in_channels=input_channel, stride=(1, 2, 2)
|
||||
)
|
||||
elif block_name == "compress_all":
|
||||
output_channel = output_channel // block_params.get("multiplier", 1)
|
||||
block = DepthToSpaceUpsample(
|
||||
dims=dims,
|
||||
in_channels=input_channel,
|
||||
stride=(2, 2, 2),
|
||||
residual=block_params.get("residual", False),
|
||||
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown layer: {block_name}")
|
||||
@ -323,27 +345,75 @@ class Decoder(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
self.timestep_scale_multiplier = nn.Parameter(
|
||||
torch.tensor(1000.0, dtype=torch.float32)
|
||||
)
|
||||
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
output_channel * 2, 0, operations=ops,
|
||||
)
|
||||
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
|
||||
|
||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
||||
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.FloatTensor,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Decoder` class."""
|
||||
# assert target_shape is not None, "target_shape must be provided"
|
||||
batch_size = sample.shape[0]
|
||||
|
||||
sample = self.conv_in(sample, causal=self.causal)
|
||||
|
||||
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
||||
|
||||
checkpoint_fn = (
|
||||
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
||||
if self.gradient_checkpointing and self.training
|
||||
else lambda x: x
|
||||
)
|
||||
|
||||
sample = sample.to(upscale_dtype)
|
||||
scaled_timestep = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
scaled_timestep = timestep * self.timestep_scale_multiplier
|
||||
|
||||
for up_block in self.up_blocks:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||
sample = checkpoint_fn(up_block)(
|
||||
sample, causal=self.causal, timestep=scaled_timestep
|
||||
)
|
||||
else:
|
||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||
|
||||
sample = self.conv_norm_out(sample)
|
||||
|
||||
if self.timestep_conditioning:
|
||||
embedded_timestep = self.last_time_embedder(
|
||||
timestep=scaled_timestep.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=sample.shape[0],
|
||||
hidden_dtype=sample.dtype,
|
||||
)
|
||||
embedded_timestep = embedded_timestep.view(
|
||||
batch_size, embedded_timestep.shape[-1], 1, 1, 1
|
||||
)
|
||||
ada_values = self.last_scale_shift_table[
|
||||
None, ..., None, None, None
|
||||
] + embedded_timestep.reshape(
|
||||
batch_size,
|
||||
2,
|
||||
-1,
|
||||
embedded_timestep.shape[-3],
|
||||
embedded_timestep.shape[-2],
|
||||
embedded_timestep.shape[-1],
|
||||
)
|
||||
shift, scale = ada_values.unbind(dim=1)
|
||||
sample = sample * (1 + scale) + shift
|
||||
|
||||
sample = self.conv_act(sample)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
|
||||
@ -379,12 +449,21 @@ class UNetMidBlock3D(nn.Module):
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_groups: int = 32,
|
||||
norm_layer: str = "group_norm",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
resnet_groups = (
|
||||
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
||||
in_channels * 4, 0, operations=ops,
|
||||
)
|
||||
|
||||
self.res_blocks = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock3D(
|
||||
@ -395,25 +474,48 @@ class UNetMidBlock3D(nn.Module):
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
norm_layer=norm_layer,
|
||||
inject_noise=inject_noise,
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.FloatTensor, causal: bool = True
|
||||
self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
|
||||
) -> torch.FloatTensor:
|
||||
timestep_embed = None
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
batch_size = hidden_states.shape[0]
|
||||
timestep_embed = self.time_embedder(
|
||||
timestep=timestep.flatten(),
|
||||
resolution=None,
|
||||
aspect_ratio=None,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
timestep_embed = timestep_embed.view(
|
||||
batch_size, timestep_embed.shape[-1], 1, 1, 1
|
||||
)
|
||||
|
||||
for resnet in self.res_blocks:
|
||||
hidden_states = resnet(hidden_states, causal=causal)
|
||||
hidden_states = resnet(hidden_states, causal=causal, timestep=timestep_embed)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DepthToSpaceUpsample(nn.Module):
|
||||
def __init__(self, dims, in_channels, stride, residual=False):
|
||||
def __init__(
|
||||
self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
|
||||
):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.out_channels = math.prod(stride) * in_channels
|
||||
self.out_channels = (
|
||||
math.prod(stride) * in_channels // out_channels_reduction_factor
|
||||
)
|
||||
self.conv = make_conv_nd(
|
||||
dims=dims,
|
||||
in_channels=in_channels,
|
||||
@ -423,8 +525,9 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
causal=True,
|
||||
)
|
||||
self.residual = residual
|
||||
self.out_channels_reduction_factor = out_channels_reduction_factor
|
||||
|
||||
def forward(self, x, causal: bool = True):
|
||||
def forward(self, x, causal: bool = True, timestep: Optional[torch.Tensor] = None):
|
||||
if self.residual:
|
||||
# Reshape and duplicate the input to match the output shape
|
||||
x_in = rearrange(
|
||||
@ -434,7 +537,8 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
p2=self.stride[1],
|
||||
p3=self.stride[2],
|
||||
)
|
||||
x_in = x_in.repeat(1, math.prod(self.stride), 1, 1, 1)
|
||||
num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor
|
||||
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
|
||||
if self.stride[0] == 2:
|
||||
x_in = x_in[:, :, 1:, :, :]
|
||||
x = self.conv(x, causal=causal)
|
||||
@ -451,7 +555,6 @@ class DepthToSpaceUpsample(nn.Module):
|
||||
x = x + x_in
|
||||
return x
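The residual branch above duplicates the input so it can be added to the conv output after both have gone through the channels-to-space rearrangement. A minimal sketch of that 3-D pixel shuffle; the stride (2, 2, 2) and the exact einops pattern are assumed from the repeat/split logic shown here:

import torch
from einops import rearrange

x = torch.randn(1, 8 * 64, 4, 16, 16)     # conv output: channels folded by prod(stride)
x = rearrange(x, "b (c p1 p2 p3) t h w -> b c (t p1) (h p2) (w p3)", p1=2, p2=2, p3=2)
print(x.shape)                             # torch.Size([1, 64, 8, 32, 32])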
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
||||
super().__init__()
|
||||
@ -486,11 +589,14 @@ class ResnetBlock3D(nn.Module):
|
||||
groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
norm_layer: str = "group_norm",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.inject_noise = inject_noise
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm1 = nn.GroupNorm(
|
||||
@ -513,6 +619,9 @@ class ResnetBlock3D(nn.Module):
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if inject_noise:
|
||||
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
||||
|
||||
if norm_layer == "group_norm":
|
||||
self.norm2 = nn.GroupNorm(
|
||||
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
||||
@ -534,6 +643,9 @@ class ResnetBlock3D(nn.Module):
|
||||
causal=True,
|
||||
)
|
||||
|
||||
if inject_noise:
|
||||
self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
||||
|
||||
self.conv_shortcut = (
|
||||
make_linear_nd(
|
||||
dims=dims, in_channels=in_channels, out_channels=out_channels
|
||||
@ -548,29 +660,84 @@ class ResnetBlock3D(nn.Module):
|
||||
else nn.Identity()
|
||||
)
|
||||
|
||||
self.timestep_conditioning = timestep_conditioning
|
||||
|
||||
if timestep_conditioning:
|
||||
self.scale_shift_table = nn.Parameter(
|
||||
torch.randn(4, in_channels) / in_channels**0.5
|
||||
)
|
||||
|
||||
def _feed_spatial_noise(
|
||||
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
|
||||
) -> torch.FloatTensor:
|
||||
spatial_shape = hidden_states.shape[-2:]
|
||||
device = hidden_states.device
|
||||
dtype = hidden_states.dtype
|
||||
|
||||
# similar to the "explicit noise inputs" method in style-gan
|
||||
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
|
||||
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
|
||||
hidden_states = hidden_states + scaled_noise
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_tensor: torch.FloatTensor,
|
||||
causal: bool = True,
|
||||
timestep: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
hidden_states = input_tensor
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
if self.timestep_conditioning:
|
||||
assert (
|
||||
timestep is not None
|
||||
), "should pass timestep with timestep_conditioning=True"
|
||||
ada_values = self.scale_shift_table[
|
||||
None, ..., None, None, None
|
||||
] + timestep.reshape(
|
||||
batch_size,
|
||||
4,
|
||||
-1,
|
||||
timestep.shape[-3],
|
||||
timestep.shape[-2],
|
||||
timestep.shape[-1],
|
||||
)
|
||||
shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
|
||||
|
||||
hidden_states = hidden_states * (1 + scale1) + shift1
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states, causal=causal)
|
||||
|
||||
if self.inject_noise:
|
||||
hidden_states = self._feed_spatial_noise(
|
||||
hidden_states, self.per_channel_scale1
|
||||
)
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
if self.timestep_conditioning:
|
||||
hidden_states = hidden_states * (1 + scale2) + shift2
|
||||
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = self.conv2(hidden_states, causal=causal)
|
||||
|
||||
if self.inject_noise:
|
||||
hidden_states = self._feed_spatial_noise(
|
||||
hidden_states, self.per_channel_scale2
|
||||
)
|
||||
|
||||
input_tensor = self.norm3(input_tensor)
|
||||
|
||||
batch_size = input_tensor.shape[0]
|
||||
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
output_tensor = input_tensor + hidden_states
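The scale_shift_table logic above is PixArt-style adaLN applied per channel: a learned (4, C) table plus the reshaped timestep embedding yields two shift/scale pairs, one pair applied after each of the two norms. A shape-only sketch; the batch and channel sizes are illustrative:

import torch

B, C = 2, 128
scale_shift_table = torch.randn(4, C) / C**0.5            # learned per-channel table
timestep_embed = torch.randn(B, 4 * C, 1, 1, 1)            # from the timestep embedder, reshaped
ada = scale_shift_table[None, ..., None, None, None] + timestep_embed.reshape(B, 4, C, 1, 1, 1)
shift1, scale1, shift2, scale2 = ada.unbind(dim=1)          # each (B, C, 1, 1, 1), broadcast over T, H, W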
|
||||
@ -634,33 +801,71 @@ class processor(nn.Module):
|
||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||
|
||||
class VideoVAE(nn.Module):
|
||||
def __init__(self):
|
||||
def __init__(self, version=0):
|
||||
super().__init__()
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"blocks": [
|
||||
["res_x", 4],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x", 3],
|
||||
["res_x", 4],
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
}
|
||||
|
||||
if version == 0:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"blocks": [
|
||||
["res_x", 4],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x_y", 1],
|
||||
["res_x", 3],
|
||||
["compress_all", 1],
|
||||
["res_x", 3],
|
||||
["res_x", 4],
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
"_class_name": "CausalVideoAutoencoder",
|
||||
"dims": 3,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"decoder_blocks": [
|
||||
["res_x", {"num_layers": 5, "inject_noise": True}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 6, "inject_noise": True}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 7, "inject_noise": True}],
|
||||
["compress_all", {"residual": True, "multiplier": 2}],
|
||||
["res_x", {"num_layers": 8, "inject_noise": False}]
|
||||
],
|
||||
"encoder_blocks": [
|
||||
["res_x", {"num_layers": 4}],
|
||||
["compress_all", {}],
|
||||
["res_x_y", 1],
|
||||
["res_x", {"num_layers": 3}],
|
||||
["compress_all", {}],
|
||||
["res_x_y", 1],
|
||||
["res_x", {"num_layers": 3}],
|
||||
["compress_all", {}],
|
||||
["res_x", {"num_layers": 3}],
|
||||
["res_x", {"num_layers": 4}]
|
||||
],
|
||||
"scaling_factor": 1.0,
|
||||
"norm_layer": "pixel_norm",
|
||||
"patch_size": 4,
|
||||
"latent_log_var": "uniform",
|
||||
"use_quant_conv": False,
|
||||
"causal_decoder": False,
|
||||
"timestep_conditioning": True,
|
||||
}
|
||||
|
||||
double_z = config.get("double_z", True)
|
||||
latent_log_var = config.get(
|
||||
@ -671,7 +876,7 @@ class VideoVAE(nn.Module):
|
||||
dims=config["dims"],
|
||||
in_channels=config.get("in_channels", 3),
|
||||
out_channels=config["latent_channels"],
|
||||
blocks=config.get("encoder_blocks", config.get("blocks")),
|
||||
blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
latent_log_var=latent_log_var,
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
@ -681,18 +886,22 @@ class VideoVAE(nn.Module):
|
||||
dims=config["dims"],
|
||||
in_channels=config["latent_channels"],
|
||||
out_channels=config.get("out_channels", 3),
|
||||
blocks=config.get("decoder_blocks", config.get("blocks")),
|
||||
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
|
||||
patch_size=config.get("patch_size", 1),
|
||||
norm_layer=config.get("norm_layer", "group_norm"),
|
||||
causal=config.get("causal_decoder", False),
|
||||
timestep_conditioning=config.get("timestep_conditioning", False),
|
||||
)
|
||||
|
||||
self.timestep_conditioning = config.get("timestep_conditioning", False)
|
||||
self.per_channel_statistics = processor()
|
||||
|
||||
def encode(self, x):
|
||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
||||
return self.per_channel_statistics.normalize(means)
|
||||
|
||||
    def decode(self, x):
        return self.decoder(self.per_channel_statistics.un_normalize(x))
    def decode(self, x, timestep=0.05, noise_scale=0.025):
        if self.timestep_conditioning: #TODO: seed
            x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
        return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
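With timestep conditioning enabled, decode now perturbs the latents slightly and feeds a small timestep to the decoder. A hypothetical usage sketch; the latent shape is illustrative and not taken from this diff:

import torch

vae = VideoVAE(version=1)                  # the version 1 config enables timestep_conditioning
latents = torch.randn(1, 128, 3, 16, 16)   # (B, latent_channels, T, H, W), shape assumed
frames = vae.decode(latents)               # default timestep=0.05, noise_scale=0.025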
|
||||
|
||||
@ -1,15 +1,16 @@
|
||||
import torch
|
||||
import logging
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
import logging as logpy
|
||||
from typing import Any, Dict, Tuple, Union, Callable
|
||||
|
||||
import torch
|
||||
|
||||
from ..modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from ..util import instantiate_from_config, get_obj_from_str
|
||||
from ..modules.ema import LitEma
|
||||
from ..util import instantiate_from_config, get_obj_from_str
|
||||
from ... import ops
|
||||
|
||||
|
||||
class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
def __init__(self, sample: bool = True):
|
||||
super().__init__()
|
||||
@ -39,11 +40,11 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ema_decay: Union[None, float] = None,
|
||||
monitor: Union[None, str] = None,
|
||||
input_key: str = "jpg",
|
||||
**kwargs,
|
||||
self,
|
||||
ema_decay: Union[None, float] = None,
|
||||
monitor: Union[None, str] = None,
|
||||
input_key: str = "jpg",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -54,7 +55,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self, decay=ema_decay)
|
||||
logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
def get_input(self, batch) -> Any:
|
||||
raise NotImplementedError()
|
||||
@ -70,14 +71,14 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
logpy.info(f"{context}: Switched to EMA weights")
|
||||
logging.info(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
logpy.info(f"{context}: Restored training weights")
|
||||
logging.info(f"{context}: Restored training weights")
|
||||
|
||||
def encode(self, *args, **kwargs) -> torch.Tensor:
|
||||
raise NotImplementedError("encode()-method of abstract base class called")
|
||||
@ -86,7 +87,7 @@ class AbstractAutoencoder(torch.nn.Module):
|
||||
raise NotImplementedError("decode()-method of abstract base class called")
|
||||
|
||||
def instantiate_optimizer_from_config(self, params, lr, cfg):
|
||||
logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
|
||||
return get_obj_from_str(cfg["target"])(
|
||||
params, lr=lr, **cfg.get("params", dict())
|
||||
)
|
||||
@ -103,18 +104,18 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
encoder_config: Dict,
|
||||
decoder_config: Dict,
|
||||
regularizer_config: Dict,
|
||||
**kwargs,
|
||||
self,
|
||||
*args,
|
||||
encoder_config: Dict,
|
||||
decoder_config: Dict,
|
||||
regularizer_config: Dict,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
|
||||
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
|
||||
self.regularization: DiagonalGaussianRegularizer = instantiate_from_config(
|
||||
self.regularization: Callable[[torch.Tensor], tuple[torch.Tensor, dict]] = instantiate_from_config(
|
||||
regularizer_config
|
||||
)
|
||||
|
||||
@ -122,10 +123,10 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
return self.decoder.get_last_layer()
|
||||
|
||||
def encode(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
return_reg_log: bool = False,
|
||||
unregularized: bool = False,
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
return_reg_log: bool = False,
|
||||
unregularized: bool = False,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
z = self.encoder(x)
|
||||
if unregularized:
|
||||
@ -140,7 +141,7 @@ class AutoencodingEngine(AbstractAutoencoder):
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, **additional_decode_kwargs
|
||||
self, x: torch.Tensor, **additional_decode_kwargs
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, dict]:
|
||||
z, reg_log = self.encode(x, return_reg_log=True)
|
||||
dec = self.decode(z, **additional_decode_kwargs)
|
||||
@ -162,16 +163,23 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
self.quant_conv = ops.disable_weight_init.Conv2d(
|
||||
|
||||
if ddconfig.get("conv3d", False):
|
||||
conv_op = ops.disable_weight_init.Conv3d
|
||||
else:
|
||||
conv_op = ops.disable_weight_init.Conv2d
|
||||
|
||||
self.quant_conv = conv_op(
|
||||
(1 + ddconfig["double_z"]) * ddconfig["z_channels"],
|
||||
(1 + ddconfig["double_z"]) * embed_dim,
|
||||
1,
|
||||
)
|
||||
self.post_quant_conv = ops.disable_weight_init.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
|
||||
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
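Because the quant convs use kernel size 1, the only thing the conv3d switch changes is the tensor rank they accept; the channel mixing is identical. A tiny illustration with assumed channel counts:

import torch

z_img = torch.randn(1, 8, 32, 32)             # 2-D latents
z_vid = torch.randn(1, 8, 5, 32, 32)          # 3-D (video) latents
print(torch.nn.Conv2d(8, 4, 1)(z_img).shape)  # torch.Size([1, 4, 32, 32])
print(torch.nn.Conv3d(8, 4, 1)(z_vid).shape)  # torch.Size([1, 4, 5, 32, 32])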
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_reg_log: bool = False
|
||||
self, x: torch.Tensor, return_reg_log: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]:
|
||||
if self.max_batch_size is None:
|
||||
z = self.encoder(x)
|
||||
@ -182,7 +190,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
n_batches = int(math.ceil(N / bs))
|
||||
z = list()
|
||||
for i_batch in range(n_batches):
|
||||
z_batch = self.encoder(x[i_batch * bs : (i_batch + 1) * bs])
|
||||
z_batch = self.encoder(x[i_batch * bs: (i_batch + 1) * bs])
|
||||
z_batch = self.quant_conv(z_batch)
|
||||
z.append(z_batch)
|
||||
z = torch.cat(z, 0)
|
||||
@ -202,7 +210,7 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
n_batches = int(math.ceil(N / bs))
|
||||
dec = list()
|
||||
for i_batch in range(n_batches):
|
||||
dec_batch = self.post_quant_conv(z[i_batch * bs : (i_batch + 1) * bs])
|
||||
dec_batch = self.post_quant_conv(z[i_batch * bs: (i_batch + 1) * bs])
|
||||
dec_batch = self.decoder(dec_batch, **decoder_kwargs)
|
||||
dec.append(dec_batch)
|
||||
dec = torch.cat(dec, 0)
|
||||
|
||||
@ -19,9 +19,6 @@ if model_management.xformers_enabled():
|
||||
if model_management.sage_attention_enabled():
|
||||
from sageattention import sageattn # pylint: disable=import-error
|
||||
|
||||
if model_management.flash_attn_enabled():
|
||||
from flash_attn import flash_attn_func # pylint: disable=import-error
|
||||
|
||||
from ...cli_args import args
|
||||
from ... import ops
|
||||
|
||||
@ -170,8 +167,6 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
b, _, dim_head = query.shape
|
||||
dim_head //= heads
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
|
||||
if skip_reshape:
|
||||
query = query.reshape(b * heads, -1, dim_head)
|
||||
value = value.reshape(b * heads, -1, dim_head)
|
||||
@ -189,9 +184,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
|
||||
bytes_per_token = torch.finfo(query.dtype).bits // 8
|
||||
batch_x_heads, q_tokens, _ = query.shape
|
||||
_, _, k_tokens = key.shape
|
||||
qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
|
||||
|
||||
mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
|
||||
mem_free_total, _ = model_management.get_free_memory(query.device, True)
|
||||
|
||||
kv_chunk_size_min = None
|
||||
kv_chunk_size = None
|
||||
@ -243,7 +237,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
|
||||
scale = dim_head ** -0.5
|
||||
|
||||
h = heads
|
||||
if skip_reshape:
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||
@ -349,12 +342,9 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
|
||||
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
|
||||
b = q.shape[0]
|
||||
dim_head = q.shape[-1]
|
||||
# check to make sure xformers isn't broken
|
||||
disabled_xformers = False
|
||||
|
||||
if not disabled_xformers:
|
||||
@ -365,35 +355,44 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
|
||||
return attention_pytorch(q, k, v, heads, mask, skip_reshape=skip_reshape)
|
||||
|
||||
if skip_reshape:
|
||||
# b h k d -> b k h d
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b * heads, -1, dim_head),
|
||||
lambda t: t.permute(0, 2, 1, 3),
|
||||
(q, k, v),
|
||||
)
|
||||
# actually do the reshaping
|
||||
else:
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.reshape(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
# add a singleton batch dimension
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a singleton heads dimension
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
# pad to a multiple of 8
|
||||
pad = 8 - mask.shape[-1] % 8
|
||||
mask_out = torch.empty([q.shape[0], q.shape[2], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
|
||||
# the xformers docs says that it's allowed to have a mask of shape (1, Nq, Nk)
|
||||
# but when using separated heads, the shape has to be (B, H, Nq, Nk)
|
||||
# in flux, this matrix ends up being over 1GB
|
||||
# here, we create a mask with the same batch/head size as the input mask (potentially singleton or full)
|
||||
mask_out = torch.empty([mask.shape[0], mask.shape[1], q.shape[1], mask.shape[-1] + pad], dtype=q.dtype, device=q.device)
|
||||
|
||||
mask_out[..., :mask.shape[-1]] = mask
|
||||
# doesn't this remove the padding again??
|
||||
mask = mask_out[..., :mask.shape[-1]]
|
||||
mask = mask.expand(b, heads, -1, -1)
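The slice back to mask.shape[-1] looks like it undoes the padding, but it only trims the logical shape: the result is a view whose storage still has the padded row length, which is what xformers' memory-efficient attention needs for its alignment check. A quick way to see it, with assumed sizes:

import torch

n_k = 21
pad = 8 - n_k % 8
buf = torch.empty(1, 1, 77, n_k + pad)
bias = buf[..., :n_k]
print(bias.shape)     # torch.Size([1, 1, 77, 21])
print(bias.stride())  # rows are still 24 floats apart, so the storage stays 8-aligned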
|
||||
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=mask) # pylint: disable=possibly-used-before-assignment
|
||||
|
||||
if skip_reshape:
|
||||
out = (
|
||||
out.unsqueeze(0)
|
||||
.reshape(b, heads, -1, dim_head)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = (
|
||||
out.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
out = (
|
||||
out.reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
|
||||
return out
|
||||
|
||||
@ -406,6 +405,11 @@ else:
|
||||
|
||||
|
||||
def pytorch_style_decl(func):
|
||||
"""
|
||||
Wraps a PyTorch-style attention function into one that can be used with ComfyUI.
|
||||
:param func:
|
||||
:return:
|
||||
"""
|
||||
@wraps(func)
|
||||
def wrapper(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
if skip_reshape:
|
||||
@ -432,38 +436,96 @@ def pytorch_style_decl(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
@pytorch_style_decl
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
if SDP_BATCH_LIMIT >= b:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = torch.empty((b, q.shape[2], heads * dim_head), dtype=q.dtype, layout=q.layout, device=q.device)
|
||||
for i in range(0, b, SDP_BATCH_LIMIT):
|
||||
m = mask
|
||||
if mask is not None:
|
||||
if mask.shape[0] > 1:
|
||||
m = mask[i : i + SDP_BATCH_LIMIT]
|
||||
|
||||
out[i : i + SDP_BATCH_LIMIT] = torch.nn.functional.scaled_dot_product_attention(
|
||||
q[i : i + SDP_BATCH_LIMIT],
|
||||
k[i : i + SDP_BATCH_LIMIT],
|
||||
v[i : i + SDP_BATCH_LIMIT],
|
||||
attn_mask=m,
|
||||
dropout_p=0.0, is_causal=False
|
||||
).transpose(1, 2).reshape(-1, q.shape[2], heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
@pytorch_style_decl
|
||||
def attention_sageattn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
return sageattn(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) # pylint: disable=possibly-used-before-assignment
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout="HND"
|
||||
else:
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
lambda t: t.view(b, -1, heads, dim_head),
|
||||
(q, k, v),
|
||||
)
|
||||
tensor_layout="NHD"
|
||||
|
||||
if mask is not None:
|
||||
# add a batch dimension if there isn't already one
|
||||
if mask.ndim == 2:
|
||||
mask = mask.unsqueeze(0)
|
||||
# add a heads dimension if there isn't already one
|
||||
if mask.ndim == 3:
|
||||
mask = mask.unsqueeze(1)
|
||||
|
||||
@pytorch_style_decl
|
||||
def attention_flash_attn(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
|
||||
return flash_attn_func(q, k, v) # pylint: disable=possibly-used-before-assignment
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
if tensor_layout == "HND":
|
||||
out = (
|
||||
out.transpose(1, 2).reshape(b, -1, heads * dim_head)
|
||||
)
|
||||
else:
|
||||
out = out.reshape(b, -1, heads * dim_head)
|
||||
return out
|
||||
|
||||
|
||||
optimized_attention = attention_basic
|
||||
|
||||
if model_management.sage_attention_enabled():
|
||||
logger.info("Using sage attention")
|
||||
optimized_attention = attention_sageattn
|
||||
optimized_attention = attention_sage
|
||||
elif model_management.xformers_enabled():
|
||||
logger.info("Using xformers cross attention")
|
||||
logger.info("Using xformers attention")
|
||||
optimized_attention = attention_xformers
|
||||
elif model_management.pytorch_attention_enabled():
|
||||
logger.info("Using pytorch cross attention")
|
||||
logger.info("Using pytorch attention")
|
||||
optimized_attention = attention_pytorch
|
||||
else:
|
||||
if args.use_split_cross_attention:
|
||||
logger.info("Using split optimization for cross attention")
|
||||
logger.info("Using split optimization for attention")
|
||||
optimized_attention = attention_split
|
||||
else:
|
||||
logger.info("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
logger.info("Using sub quadratic optimization for attention, if you have memory or speed issues try using: --use-split-cross-attention")
|
||||
optimized_attention = attention_sub_quad
|
||||
|
||||
optimized_attention_masked = optimized_attention
|
||||
|
||||
@ -73,45 +73,33 @@ class PatchEmbed(nn.Module):
|
||||
strict_img_size: bool = True,
|
||||
dynamic_img_pad: bool = True,
|
||||
padding_mode='circular',
|
||||
conv3d=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
try:
|
||||
len(patch_size)
|
||||
self.patch_size = patch_size
|
||||
except:
|
||||
if conv3d:
|
||||
self.patch_size = (patch_size, patch_size, patch_size)
|
||||
else:
|
||||
self.patch_size = (patch_size, patch_size)
|
||||
self.padding_mode = padding_mode
|
||||
if img_size is not None:
|
||||
self.img_size = (img_size, img_size)
|
||||
self.grid_size = tuple([s // p for s, p in zip(self.img_size, self.patch_size)])
|
||||
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
||||
else:
|
||||
self.img_size = None
|
||||
self.grid_size = None
|
||||
self.num_patches = None
|
||||
|
||||
# flatten spatial dim and transpose to channels last, kept for bwd compat
|
||||
self.flatten = flatten
|
||||
self.strict_img_size = strict_img_size
|
||||
self.dynamic_img_pad = dynamic_img_pad
|
||||
|
||||
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||
if conv3d:
|
||||
self.proj = operations.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||
else:
|
||||
self.proj = operations.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, dtype=dtype, device=device)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
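With conv3d enabled the patchifier turns a video latent into a token sequence the same way the 2-D path does, just with a temporal patch axis added. A minimal sketch; the flatten/transpose at the end mirrors the usual PatchEmbed forward and the sizes are assumptions:

import torch

x = torch.randn(1, 16, 8, 64, 64)                       # (B, C, T, H, W)
proj = torch.nn.Conv3d(16, 1024, kernel_size=2, stride=2)
tokens = proj(x).flatten(2).transpose(1, 2)
print(tokens.shape)                                      # torch.Size([1, 4096, 1024]), i.e. 4*32*32 tokens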
|
||||
|
||||
def forward(self, x):
|
||||
# B, C, H, W = x.shape
|
||||
# if self.img_size is not None:
|
||||
# if self.strict_img_size:
|
||||
# _assert(H == self.img_size[0], f"Input height ({H}) doesn't match model ({self.img_size[0]}).")
|
||||
# _assert(W == self.img_size[1], f"Input width ({W}) doesn't match model ({self.img_size[1]}).")
|
||||
# elif not self.dynamic_img_pad:
|
||||
# _assert(
|
||||
# H % self.patch_size[0] == 0,
|
||||
# f"Input height ({H}) should be divisible by patch size ({self.patch_size[0]})."
|
||||
# )
|
||||
# _assert(
|
||||
# W % self.patch_size[1] == 0,
|
||||
# f"Input width ({W}) should be divisible by patch size ({self.patch_size[1]})."
|
||||
# )
|
||||
if self.dynamic_img_pad:
|
||||
x = common_dit.pad_to_patch_size(x, self.patch_size, padding_mode=self.padding_mode)
|
||||
x = self.proj(x)
|
||||
|
||||
@ -1,17 +1,20 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import logging
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
from .... import model_management
|
||||
from .... import ops
|
||||
|
||||
ops = ops.disable_weight_init
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
import xformers # pylint: disable=import-error
|
||||
import xformers.ops # pylint: disable=import-error
|
||||
import xformers.ops # pylint: disable=import-error
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
@ -30,64 +33,115 @@ def get_timestep_embedding(timesteps, embedding_dim):
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
class VideoConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
self.padding_mode = padding_mode
|
||||
if padding != 0:
|
||||
padding = (padding, padding, padding, padding, kernel_size - 1, 0)
|
||||
else:
|
||||
kwargs["padding"] = padding
|
||||
|
||||
self.padding = padding
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
try:
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
except: #operation not implemented for bf16
|
||||
b, c, h, w = x.shape
|
||||
out = torch.empty((b, c, h*2, w*2), dtype=x.dtype, layout=x.layout, device=x.device)
|
||||
split = 8
|
||||
l = out.shape[1] // split
|
||||
for i in range(0, out.shape[1], l):
|
||||
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
|
||||
del x
|
||||
x = out
|
||||
if self.padding != 0:
|
||||
x = torch.nn.functional.pad(x, self.padding, mode=self.padding_mode)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def interpolate_up(x, scale_factor):
|
||||
try:
|
||||
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
|
||||
except: # operation not implemented for bf16
|
||||
orig_shape = list(x.shape)
|
||||
out_shape = orig_shape[:2]
|
||||
for i in range(len(orig_shape) - 2):
|
||||
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
|
||||
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
|
||||
split = 8
|
||||
l = out.shape[1] // split
|
||||
for i in range(0, out.shape[1], l):
|
||||
out[:, i:i + l] = torch.nn.functional.interpolate(x[:, i:i + l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
|
||||
return out
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
if self.with_conv:
|
||||
self.conv = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
scale_factor = self.scale_factor
|
||||
if isinstance(scale_factor, (int, float)):
|
||||
scale_factor = (scale_factor,) * (x.ndim - 2)
|
||||
|
||||
if x.ndim == 5 and scale_factor[0] > 1.0:
|
||||
t = x.shape[2]
|
||||
if t > 1:
|
||||
a, b = x.split((1, t - 1), dim=2)
|
||||
del x
|
||||
b = interpolate_up(b, scale_factor)
|
||||
else:
|
||||
a = x
|
||||
|
||||
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
|
||||
if t > 1:
|
||||
x = torch.cat((a, b), dim=2)
|
||||
else:
|
||||
x = a
|
||||
else:
|
||||
x = interpolate_up(x, scale_factor)
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
def __init__(self, in_channels, with_conv, stride=2, conv_op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0)
|
||||
self.conv = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0,1,0,1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
if x.ndim == 4:
|
||||
pad = (0, 1, 0, 1)
|
||||
mode = "constant"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
|
||||
elif x.ndim == 5:
|
||||
pad = (1, 1, 1, 1, 2, 0)
|
||||
mode = "replicate"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
@ -96,7 +150,7 @@ class Downsample(nn.Module):
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout, temb_channels=512):
|
||||
dropout, temb_channels=512, conv_op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
@ -105,34 +159,34 @@ class ResnetBlock(nn.Module):
|
||||
|
||||
self.swish = torch.nn.SiLU(inplace=True)
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = ops.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv1 = conv_op(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = ops.Linear(temb_channels,
|
||||
out_channels)
|
||||
out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout, inplace=True)
|
||||
self.conv2 = ops.Conv2d(out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv2 = conv_op(out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = ops.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_shortcut = conv_op(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = ops.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.nin_shortcut = conv_op(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
@ -141,7 +195,7 @@ class ResnetBlock(nn.Module):
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
|
||||
h = h + self.temb_proj(self.swish(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.swish(h)
|
||||
@ -154,22 +208,22 @@ class ResnetBlock(nn.Module):
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x+h
|
||||
return x + h
|
||||
|
||||
|
||||
def slice_attention(q, k, v):
|
||||
r1 = torch.zeros_like(k, device=q.device)
|
||||
scale = (int(q.shape[-1])**(-0.5))
|
||||
scale = (int(q.shape[-1]) ** (-0.5))
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
|
||||
modifier = 3 if q.element_size() == 2 else 2.5
|
||||
mem_required = tensor_size * modifier
|
||||
steps = 1
|
||||
|
||||
if mem_required > mem_free_total:
|
||||
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))
|
||||
|
||||
while True:
|
||||
try:
|
||||
@ -178,7 +232,7 @@ def slice_attention(q, k, v):
|
||||
end = i + slice_size
|
||||
s1 = torch.bmm(q[:, i:end], k) * scale
|
||||
|
||||
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0,2,1)
|
||||
s2 = torch.nn.functional.softmax(s1, dim=2).permute(0, 2, 1)
|
||||
del s1
|
||||
|
||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||
@ -193,23 +247,29 @@ def slice_attention(q, k, v):
|
||||
|
||||
return r1
|
||||
|
||||
|
||||
def normal_attention(q, k, v):
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
orig_shape = q.shape
|
||||
b = orig_shape[0]
|
||||
c = orig_shape[1]
|
||||
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
v = v.reshape(b,c,h*w)
|
||||
q = q.reshape(b, c, -1)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, -1) # b,c,hw
|
||||
v = v.reshape(b, c, -1)
|
||||
|
||||
r1 = slice_attention(q, k, v)
|
||||
h_ = r1.reshape(b,c,h,w)
|
||||
h_ = r1.reshape(orig_shape)
|
||||
del r1
|
||||
return h_
|
||||
|
||||
|
||||
def xformers_attention(q, k, v):
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
orig_shape = q.shape
|
||||
B = orig_shape[0]
|
||||
C = orig_shape[1]
|
||||
q, k, v = map(
|
||||
lambda t: t.view(B, C, -1).transpose(1, 2).contiguous(),
|
||||
(q, k, v),
|
||||
@ -217,14 +277,17 @@ def xformers_attention(q, k, v):
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
|
||||
out = out.transpose(1, 2).reshape(B, C, H, W)
|
||||
out = out.transpose(1, 2).reshape(orig_shape)
|
||||
else:
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
return out
|
||||
|
||||
|
||||
def pytorch_attention(q, k, v):
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
orig_shape = q.shape
|
||||
B = orig_shape[0]
|
||||
C = orig_shape[1]
|
||||
q, k, v = map(
|
||||
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||
(q, k, v),
|
||||
@ -232,39 +295,39 @@ def pytorch_attention(q, k, v):
|
||||
|
||||
try:
|
||||
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(B, C, H, W)
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
|
||||
return out
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
def __init__(self, in_channels, conv_op=ops.Conv2d):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = ops.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.q = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = conv_op(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
if model_management.xformers_enabled_vae():
|
||||
logging.debug("Using xformers attention in VAE")
|
||||
@ -287,21 +350,21 @@ class AttnBlock(nn.Module):
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
return x + h_
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None):
|
||||
return AttnBlock(in_channels)
|
||||
def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None, conv_op=ops.Conv2d):
|
||||
return AttnBlock(in_channels, conv_op=conv_op)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch*4
|
||||
self.temb_ch = self.ch * 4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
@ -313,26 +376,26 @@ class Model(nn.Module):
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList([
|
||||
ops.Linear(self.ch,
|
||||
self.temb_ch),
|
||||
self.temb_ch),
|
||||
ops.Linear(self.temb_ch,
|
||||
self.temb_ch),
|
||||
self.temb_ch),
|
||||
])
|
||||
|
||||
# downsampling
|
||||
self.conv_in = ops.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
@ -344,7 +407,7 @@ class Model(nn.Module):
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
@ -366,12 +429,12 @@ class Model(nn.Module):
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch*ch_mult[i_level]
|
||||
skip_in = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
block_out = ch * ch_mult[i_level]
|
||||
skip_in = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch*in_ch_mult[i_level]
|
||||
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
||||
skip_in = ch * in_ch_mult[i_level]
|
||||
block.append(ResnetBlock(in_channels=block_in + skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
@ -384,18 +447,18 @@ class Model(nn.Module):
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = ops.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x, t=None, context=None):
|
||||
#assert x.shape[2] == x.shape[3] == self.resolution
|
||||
# assert x.shape[2] == x.shape[3] == self.resolution
|
||||
if context is not None:
|
||||
# assume aligned context, cat along channel axis
|
||||
x = torch.cat((x, context), dim=1)
|
||||
@ -417,7 +480,7 @@ class Model(nn.Module):
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
@ -428,7 +491,7 @@ class Model(nn.Module):
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](
|
||||
torch.cat([h, hs.pop()], dim=1), temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
@ -447,9 +510,10 @@ class Model(nn.Module):
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
||||
conv3d=False, time_compress=None,
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
@ -460,35 +524,47 @@ class Encoder(nn.Module):
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
if conv3d:
|
||||
conv_op = VideoConv3d
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
mid_attn_conv_op = ops.Conv2d
|
||||
|
||||
# downsampling
|
||||
self.conv_in = ops.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_in = conv_op(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
in_ch_mult = (1,) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
dropout=dropout,
|
||||
conv_op=conv_op))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
attn.append(make_attn(block_in, attn_type=attn_type, conv_op=conv_op))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
stride = 2
|
||||
if time_compress is not None:
|
||||
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
|
||||
stride = (1, 2, 2)
|
||||
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
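
A minimal sketch of the stride rule above, assuming a hypothetical helper name (downsample_strides is not part of this change): only the last log2(time_compress) downsampling levels touch the time axis, so a 3D encoder can compress space by 2**(num_resolutions-1) while compressing time by only time_compress. The Decoder below mirrors this with an upsample scale_factor of (1.0, 2.0, 2.0) for the same levels.

import math

def downsample_strides(num_resolutions, time_compress):
    # Hypothetical helper (not in this diff): reproduces the per-level stride rule above.
    strides = []
    for i_level in range(num_resolutions - 1):      # the last level has no downsample
        stride = 2                                  # halve time, height and width
        if time_compress is not None:
            if (num_resolutions - 1 - i_level) > math.log2(time_compress):
                stride = (1, 2, 2)                  # keep time, halve height/width only
        strides.append(stride)
    return strides

print(downsample_strides(4, 4))   # [(1, 2, 2), 2, 2] -> 4x temporal, 8x spatial compression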
|
||||
|
||||
@ -497,20 +573,22 @@ class Encoder(nn.Module):
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
dropout=dropout,
|
||||
conv_op=conv_op)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type, conv_op=mid_attn_conv_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
dropout=dropout,
|
||||
conv_op=conv_op)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = ops.Conv2d(block_in,
|
||||
2*z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_out = conv_op(block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
@ -522,7 +600,7 @@ class Encoder(nn.Module):
|
||||
h = self.down[i_level].block[i_block](h, temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
# middle
|
||||
@ -538,15 +616,16 @@ class Encoder(nn.Module):
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||
conv_out_op=ops.Conv2d,
|
||||
resnet_op=ResnetBlock,
|
||||
attn_op=AttnBlock,
|
||||
**ignorekwargs):
|
||||
conv3d=False,
|
||||
time_compress=None,
|
||||
**ignorekwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
@ -556,65 +635,80 @@ class Decoder(nn.Module):
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||
if conv3d:
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
mid_attn_conv_op = ops.Conv2d
|
||||
|
||||
# compute block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
logging.debug("Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = ops.Conv2d(z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.conv_in = conv_op(z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = resnet_op(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = attn_op(block_in)
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
conv_op=conv_op)
|
||||
self.mid.attn_1 = attn_op(block_in, conv_op=mid_attn_conv_op)
|
||||
self.mid.block_2 = resnet_op(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
conv_op=conv_op)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(resnet_op(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
conv_op=conv_op))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(attn_op(block_in))
|
||||
attn.append(attn_op(block_in, conv_op=conv_op))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
scale_factor = 2.0
|
||||
if time_compress is not None:
|
||||
if i_level > math.log2(time_compress):
|
||||
scale_factor = (1.0, 2.0, 2.0)
|
||||
|
||||
up.upsample = Upsample(block_in, resamp_with_conv, conv_op=conv_op, scale_factor=scale_factor)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = conv_out_op(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, z, **kwargs):
|
||||
#assert z.shape[1:] == self.z_shape[1:]
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
@ -630,7 +724,7 @@ class Decoder(nn.Module):
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, **kwargs)
|
||||
|
||||
@ -9,6 +9,7 @@
|
||||
|
||||
|
||||
import math
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
@ -130,7 +131,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
steps_out = ddim_timesteps + 1
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
logging.info(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
return steps_out
|
||||
|
||||
|
||||
@ -142,8 +143,8 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
# according to the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||
if verbose:
|
||||
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
print(f'For the chosen value of eta, which is {eta}, '
|
||||
logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
logging.info(f'For the chosen value of eta, which is {eta}, '
|
||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
||||
return sigmas, alphas, alphas_prev
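
A quick numeric sanity check of the sigma formula logged above (toy alpha values, not from a real model):

import numpy as np

alphas = np.array([0.9, 0.5])          # alphacums at the selected ddim steps (toy values)
alphas_prev = np.array([0.99, 0.9])
for eta in (0.0, 1.0):
    sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
    print(eta, sigmas)
# eta=0.0 gives all-zero sigmas (deterministic DDIM); eta=1.0 recovers DDPM-like noise.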
|
||||
|
||||
|
||||
@ -30,10 +30,10 @@ class DiagonalGaussianDistribution(object):
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape, device=self.parameters.device)
|
||||
return x
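
The change above builds the noise tensor directly on the target device instead of allocating on CPU and copying. A minimal sketch of the reparameterized draw that sample() performs, with toy shapes:

import torch

parameters = torch.randn(1, 8, 4, 4)               # toy encoder output: [mean | logvar] on dim 1
mean, logvar = torch.chunk(parameters, 2, dim=1)
std = torch.exp(0.5 * logvar)
# eps is created on parameters.device in one step, as in the updated sample() above
x = mean + std * torch.randn(mean.shape, device=parameters.device)
print(x.shape)                                      # torch.Size([1, 4, 4, 4])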
|
||||
|
||||
def kl(self, other=None):
|
||||
|
||||
@ -22,7 +22,6 @@ except ImportError:
|
||||
from typing import Optional, NamedTuple, List
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from torch import Tensor
|
||||
from typing import List
|
||||
|
||||
from ... import model_management
|
||||
@ -172,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
|
||||
del attn_scores
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
|
||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
|
||||
torch.exp(attn_scores, out=attn_scores)
|
||||
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
|
||||
attn_scores /= summed
|
||||
|
||||
@ -194,6 +194,7 @@ def make_time_attn(
|
||||
attn_kwargs=None,
|
||||
alpha: float = 0,
|
||||
merge_strategy: str = "learned",
|
||||
conv_op=ops.Conv2d,
|
||||
):
|
||||
return partialclass(
|
||||
AttnVideoBlock, in_channels, alpha=alpha, merge_strategy=merge_strategy
|
||||
|
||||
comfy/ldm/pixart/blocks.py (new file, 381 lines)
@ -0,0 +1,381 @@
|
||||
# Based on:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
if model_management.xformers_enabled():
|
||||
import xformers.ops
|
||||
if int((xformers.__version__).split(".")[2]) >= 28:
|
||||
block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
|
||||
else:
|
||||
block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
|
||||
|
||||
def modulate(x, shift, scale):
|
||||
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
||||
|
||||
def t2i_modulate(x, shift, scale):
|
||||
return x * (1 + scale) + shift
|
||||
|
||||
class MultiHeadCrossAttention(nn.Module):
|
||||
def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None, **kwargs):
|
||||
super(MultiHeadCrossAttention, self).__init__()
|
||||
assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
|
||||
|
||||
self.d_model = d_model
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = d_model // num_heads
|
||||
|
||||
self.q_linear = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.kv_linear = operations.Linear(d_model, d_model*2, dtype=dtype, device=device)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, cond, mask=None):
|
||||
# query/value: img tokens; key: condition; mask: if padding tokens
|
||||
B, N, C = x.shape
|
||||
|
||||
q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
|
||||
kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
|
||||
k, v = kv.unbind(2)
|
||||
|
||||
assert mask is None # TODO?
|
||||
# # TODO: xformers needs separate mask logic here
|
||||
# if model_management.xformers_enabled():
|
||||
# attn_bias = None
|
||||
# if mask is not None:
|
||||
# attn_bias = block_diagonal_mask_from_seqlens([N] * B, mask)
|
||||
# x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
|
||||
# else:
|
||||
# q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
|
||||
# attn_mask = None
|
||||
# mask = torch.ones(())
|
||||
# if mask is not None and len(mask) > 1:
|
||||
# # Create equivalent of xformer diagonal block mask, still only correct for square masks
|
||||
# # But depth doesn't matter as tensors can expand in that dimension
|
||||
# attn_mask_template = torch.ones(
|
||||
# [q.shape[2] // B, mask[0]],
|
||||
# dtype=torch.bool,
|
||||
# device=q.device
|
||||
# )
|
||||
# attn_mask = torch.block_diag(attn_mask_template)
|
||||
#
|
||||
# # create a mask on the diagonal for each mask in the batch
|
||||
# for _ in range(B - 1):
|
||||
# attn_mask = torch.block_diag(attn_mask, attn_mask_template)
|
||||
# x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
|
||||
|
||||
x = optimized_attention(q.view(B, -1, C), k.view(B, -1, C), v.view(B, -1, C), self.num_heads, mask=None)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttentionKVCompress(nn.Module):
|
||||
"""Multi-head Attention block with KV token compression and qk norm."""
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=True, sampling='conv', sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
dim (int): Number of input channels.
|
||||
num_heads (int): Number of attention heads.
|
||||
qkv_bias (bool): If True, add a learnable bias to query, key, value.
|
||||
"""
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim ** -0.5
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
|
||||
self.sampling=sampling # ['conv', 'ave', 'uniform', 'uniform_every']
|
||||
self.sr_ratio = sr_ratio
|
||||
if sr_ratio > 1 and sampling == 'conv':
|
||||
# Avg Conv Init.
|
||||
self.sr = operations.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, dtype=dtype, device=device)
|
||||
# self.sr.weight.data.fill_(1/sr_ratio**2)
|
||||
# self.sr.bias.data.zero_()
|
||||
self.norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
if qk_norm:
|
||||
self.q_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.q_norm = nn.Identity()
|
||||
self.k_norm = nn.Identity()
|
||||
|
||||
def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
|
||||
if sampling is None or scale_factor == 1:
|
||||
return tensor
|
||||
B, N, C = tensor.shape
|
||||
|
||||
if sampling == 'uniform_every':
|
||||
return tensor[:, ::scale_factor], int(N // scale_factor)
|
||||
|
||||
tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
|
||||
new_H, new_W = int(H / scale_factor), int(W / scale_factor)
|
||||
new_N = new_H * new_W
|
||||
|
||||
if sampling == 'ave':
|
||||
tensor = F.interpolate(
|
||||
tensor, scale_factor=1 / scale_factor, mode='nearest'
|
||||
).permute(0, 2, 3, 1)
|
||||
elif sampling == 'uniform':
|
||||
tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
|
||||
elif sampling == 'conv':
|
||||
tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
|
||||
tensor = self.norm(tensor)
|
||||
else:
|
||||
raise ValueError
|
||||
|
||||
return tensor.reshape(B, new_N, C).contiguous(), new_N
|
||||
|
||||
def forward(self, x, mask=None, HW=None, block_id=None):
|
||||
B, N, C = x.shape # 2 4096 1152
|
||||
new_N = N
|
||||
if HW is None:
|
||||
H = W = int(N ** 0.5)
|
||||
else:
|
||||
H, W = HW
|
||||
qkv = self.qkv(x).reshape(B, N, 3, C)
|
||||
|
||||
q, k, v = qkv.unbind(2)
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# KV compression
|
||||
if self.sr_ratio > 1:
|
||||
k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
|
||||
v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
|
||||
|
||||
q = q.reshape(B, N, self.num_heads, C // self.num_heads)
|
||||
k = k.reshape(B, new_N, self.num_heads, C // self.num_heads)
|
||||
v = v.reshape(B, new_N, self.num_heads, C // self.num_heads)
|
||||
|
||||
if mask is not None:
|
||||
raise NotImplementedError("Attn mask logic not added for self attention")
|
||||
|
||||
# This is never called at the moment
|
||||
# attn_bias = None
|
||||
# if mask is not None:
|
||||
# attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
|
||||
# attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
|
||||
|
||||
# attention 2
|
||||
q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
|
||||
x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
|
||||
|
||||
x = x.view(B, N, C)
|
||||
x = self.proj(x)
|
||||
return x
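
Shape arithmetic behind the KV compression above (illustrative sizes): queries keep the full token count while keys/values are pooled by sr_ratio along each spatial axis, shrinking the attention matrix accordingly.

H = W = 64                      # token grid for a 1024px PixArt image (patch_size=2 on a 128px latent)
N = H * W                       # 4096 query tokens
sr_ratio = 2
new_N = (H // sr_ratio) * (W // sr_ratio)   # 1024 key/value tokens after downsample_2d
print(N * N, N * new_N)         # 16777216 vs 4194304 attention entries (4x smaller)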
class FinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
|
||||
def forward(self, x, c):
|
||||
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
class T2IFinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5)
|
||||
self.out_channels = out_channels
|
||||
|
||||
def forward(self, x, t):
|
||||
shift, scale = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t[:, None]).chunk(2, dim=1)
|
||||
x = t2i_modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class MaskFinalLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
def forward(self, x, t):
|
||||
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
|
||||
x = modulate(self.norm_final(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
"""
|
||||
The final layer of PixArt.
|
||||
"""
|
||||
def __init__(self, hidden_size, decoder_hidden_size, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_decoder = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, decoder_hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
def forward(self, x, t):
|
||||
shift, scale = self.adaLN_modulation(t).chunk(2, dim=1)
|
||||
x = modulate(self.norm_decoder(x), shift, scale)
|
||||
x = self.linear(x)
|
||||
return x
|
||||
|
||||
|
||||
class SizeEmbedder(TimestepEmbedder):
|
||||
"""
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size, operations=operations)
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
self.outdim = hidden_size
|
||||
|
||||
def forward(self, s, bs):
|
||||
if s.ndim == 1:
|
||||
s = s[:, None]
|
||||
assert s.ndim == 2
|
||||
if s.shape[0] != bs:
|
||||
s = s.repeat(bs//s.shape[0], 1)
|
||||
assert s.shape[0] == bs
|
||||
b, dims = s.shape[0], s.shape[1]
|
||||
s = rearrange(s, "b d -> (b d)")
|
||||
s_freq = timestep_embedding(s, self.frequency_embedding_size)
|
||||
s_emb = self.mlp(s_freq.to(s.dtype))
|
||||
s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim)
|
||||
return s_emb
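
For reference, how PixArtMS (added later in this diff) sizes these embedders for micro-conditioning; illustrative arithmetic only:

hidden_size = 1152                        # PixArtMS default
csize_dim = 2 * (hidden_size // 3)        # SizeEmbedder(hidden_size//3) applied to (height, width)
car_dim = 1 * (hidden_size // 3)          # SizeEmbedder(hidden_size//3) applied to (aspect_ratio,)
assert csize_dim + car_dim == hidden_size # so torch.cat([c_size, c_ar], dim=1) can be added to t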
class LabelEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
def __init__(self, num_classes, hidden_size, dropout_prob, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
use_cfg_embedding = dropout_prob > 0
|
||||
self.embedding_table = operations.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device),
|
||||
self.num_classes = num_classes
|
||||
self.dropout_prob = dropout_prob
|
||||
|
||||
def token_drop(self, labels, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
labels = torch.where(drop_ids, self.num_classes, labels)
|
||||
return labels
|
||||
|
||||
def forward(self, labels, train, force_drop_ids=None):
|
||||
use_dropout = self.dropout_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
labels = self.token_drop(labels, force_drop_ids)
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CaptionEmbedder(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.y_proj = Mlp(
|
||||
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5))
|
||||
self.uncond_prob = uncond_prob
|
||||
|
||||
def token_drop(self, caption, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||
return caption
|
||||
|
||||
def forward(self, caption, train, force_drop_ids=None):
|
||||
if train:
|
||||
assert caption.shape[2:] == self.y_embedding.shape
|
||||
use_dropout = self.uncond_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
caption = self.token_drop(caption, force_drop_ids)
|
||||
caption = self.y_proj(caption)
|
||||
return caption
|
||||
|
||||
|
||||
class CaptionEmbedderDoubleBr(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
"""
|
||||
def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.proj = Mlp(
|
||||
in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5)
|
||||
self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5)
|
||||
self.uncond_prob = uncond_prob
|
||||
|
||||
def token_drop(self, global_caption, caption, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob
|
||||
else:
|
||||
drop_ids = force_drop_ids == 1
|
||||
global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption)
|
||||
caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption)
|
||||
return global_caption, caption
|
||||
|
||||
def forward(self, caption, train, force_drop_ids=None):
|
||||
assert caption.shape[2: ] == self.y_embedding.shape
|
||||
global_caption = caption.mean(dim=2).squeeze()
|
||||
use_dropout = self.uncond_prob > 0
|
||||
if (train and use_dropout) or (force_drop_ids is not None):
|
||||
global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids)
|
||||
y_embed = self.proj(global_caption)
|
||||
return y_embed, caption
|
||||
comfy/ldm/pixart/pixartms.py (new file, 256 lines)
@ -0,0 +1,256 @@
|
||||
# Based on:
|
||||
# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
|
||||
# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .blocks import (
|
||||
t2i_modulate,
|
||||
CaptionEmbedder,
|
||||
AttentionKVCompress,
|
||||
MultiHeadCrossAttention,
|
||||
T2IFinalLayer,
|
||||
SizeEmbedder,
|
||||
)
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp, get_1d_sincos_pos_embed_from_grid_torch
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
|
||||
grid_h, grid_w = torch.meshgrid(
|
||||
torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
|
||||
torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
|
||||
indexing='ij'
|
||||
)
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
|
||||
emb = torch.cat([emb_w, emb_h], dim=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
class PixArtMSBlock(nn.Module):
|
||||
"""
|
||||
A PixArt block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
|
||||
sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.attn = AttentionKVCompress(
|
||||
hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
|
||||
qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.cross_attn = MultiHeadCrossAttention(
|
||||
hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
|
||||
)
|
||||
self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
# to be compatible with lower version pytorch
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
|
||||
|
||||
def forward(self, x, y, t, mask=None, HW=None, **kwargs):
|
||||
B, N, C = x.shape
|
||||
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(dtype=x.dtype, device=x.device) + t.reshape(B, 6, -1)).chunk(6, dim=1)
|
||||
x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
|
||||
x = x + self.cross_attn(x, y, mask)
|
||||
x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
### Core PixArt Model ###
|
||||
class PixArtMS(nn.Module):
|
||||
"""
|
||||
Diffusion model with a Transformer backbone.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
input_size=32,
|
||||
patch_size=2,
|
||||
in_channels=4,
|
||||
hidden_size=1152,
|
||||
depth=28,
|
||||
num_heads=16,
|
||||
mlp_ratio=4.0,
|
||||
class_dropout_prob=0.1,
|
||||
learn_sigma=True,
|
||||
pred_sigma=True,
|
||||
drop_path: float = 0.,
|
||||
caption_channels=4096,
|
||||
pe_interpolation=None,
|
||||
pe_precision=None,
|
||||
config=None,
|
||||
model_max_length=120,
|
||||
micro_condition=True,
|
||||
qk_norm=False,
|
||||
kv_compress_config=None,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
self.dtype = dtype
|
||||
self.pred_sigma = pred_sigma
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels * 2 if pred_sigma else in_channels
|
||||
self.patch_size = patch_size
|
||||
self.num_heads = num_heads
|
||||
self.pe_interpolation = pe_interpolation
|
||||
self.pe_precision = pe_precision
|
||||
self.hidden_size = hidden_size
|
||||
self.depth = depth
|
||||
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.t_block = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device)
|
||||
)
|
||||
self.x_embedder = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_chans=in_channels,
|
||||
embed_dim=hidden_size,
|
||||
bias=True,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations
|
||||
)
|
||||
self.t_embedder = TimestepEmbedder(
|
||||
hidden_size, dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
self.y_embedder = CaptionEmbedder(
|
||||
in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob,
|
||||
act_layer=approx_gelu, token_num=model_max_length,
|
||||
dtype=dtype, device=device, operations=operations,
|
||||
)
|
||||
|
||||
self.micro_conditioning = micro_condition
|
||||
if self.micro_conditioning:
|
||||
self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# For fixed sin-cos embedding:
|
||||
# num_patches = (input_size // patch_size) * (input_size // patch_size)
|
||||
# self.base_size = input_size // self.patch_size
|
||||
# self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size))
|
||||
|
||||
drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule
|
||||
if kv_compress_config is None:
|
||||
kv_compress_config = {
|
||||
'sampling': None,
|
||||
'scale_factor': 1,
|
||||
'kv_compress_layer': [],
|
||||
}
|
||||
self.blocks = nn.ModuleList([
|
||||
PixArtMSBlock(
|
||||
hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i],
|
||||
sampling=kv_compress_config['sampling'],
|
||||
sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1,
|
||||
qk_norm=qk_norm,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
operations=operations,
|
||||
)
|
||||
for i in range(depth)
|
||||
])
|
||||
self.final_layer = T2IFinalLayer(
|
||||
hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs):
|
||||
"""
|
||||
Original forward pass of PixArt.
|
||||
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
|
||||
t: (N,) tensor of diffusion timesteps
|
||||
y: (N, 1, 120, C) conditioning
|
||||
ar: (N, 1): aspect ratio
|
||||
cs: (N ,2) size conditioning for height/width
|
||||
"""
|
||||
B, C, H, W = x.shape
|
||||
c_res = (H + W) // 2
|
||||
pe_interpolation = self.pe_interpolation
|
||||
if pe_interpolation is None or self.pe_precision is not None:
|
||||
# calculate pe_interpolation on-the-fly
|
||||
pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0)
|
||||
|
||||
pos_embed = get_2d_sincos_pos_embed_torch(
|
||||
self.hidden_size,
|
||||
h=(H // self.patch_size),
|
||||
w=(W // self.patch_size),
|
||||
pe_interpolation=pe_interpolation,
|
||||
base_size=((round(c_res / 64) * 64) // self.patch_size),
|
||||
device=x.device,
|
||||
dtype=x.dtype,
|
||||
).unsqueeze(0)
|
||||
|
||||
x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2
|
||||
t = self.t_embedder(timestep, x.dtype) # (N, D)
|
||||
|
||||
if self.micro_conditioning and (c_size is not None and c_ar is not None):
|
||||
bs = x.shape[0]
|
||||
c_size = self.csize_embedder(c_size, bs) # (N, D)
|
||||
c_ar = self.ar_embedder(c_ar, bs) # (N, D)
|
||||
t = t + torch.cat([c_size, c_ar], dim=1)
|
||||
|
||||
t0 = self.t_block(t)
|
||||
y = self.y_embedder(y, self.training) # (N, D)
|
||||
|
||||
if mask is not None:
|
||||
if mask.shape[0] != y.shape[0]:
|
||||
mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
|
||||
mask = mask.squeeze(1).squeeze(1)
|
||||
y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
|
||||
y_lens = mask.sum(dim=1).tolist()
|
||||
else:
|
||||
y_lens = None
|
||||
y = y.squeeze(1).view(1, -1, x.shape[-1])
|
||||
for block in self.blocks:
|
||||
x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D)
|
||||
|
||||
x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels)
|
||||
x = self.unpatchify(x, H, W) # (N, out_channels, H, W)
|
||||
|
||||
return x
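
Worked numbers for the on-the-fly positional-embedding scaling in forward_orig above (illustrative: a 1024x1024 image, i.e. a 128x128 latent, with patch_size=2 and pe_precision left unset):

H = W = 128                                              # latent height/width
patch_size = 2
c_res = (H + W) // 2                                     # 128
pe_interpolation = round(c_res / (512 / 8.0), 0)         # 2.0
base_size = (round(c_res / 64) * 64) // patch_size       # 64
token_grid = (H // patch_size, W // patch_size)          # (64, 64) positions requested
print(pe_interpolation, base_size, token_grid)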
|
||||
|
||||
def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# Fallback for missing microconds
|
||||
if self.micro_conditioning:
|
||||
if c_size is None:
|
||||
c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
if c_ar is None:
|
||||
c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
|
||||
|
||||
## Still accepts the input w/o that dim but returns garbage
|
||||
if len(context.shape) == 3:
|
||||
context = context.unsqueeze(1)
|
||||
|
||||
## run original forward pass
|
||||
out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar)
|
||||
|
||||
## only return EPS
|
||||
if self.pred_sigma:
|
||||
return out[:, :self.in_channels]
|
||||
return out
|
||||
|
||||
def unpatchify(self, x, h, w):
|
||||
"""
|
||||
x: (N, T, patch_size**2 * C)
|
||||
imgs: (N, H, W, C)
|
||||
"""
|
||||
c = self.out_channels
|
||||
p = self.x_embedder.patch_size[0]
|
||||
h = h // self.patch_size
|
||||
w = w // self.patch_size
|
||||
assert h * w == x.shape[1]
|
||||
|
||||
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
|
||||
x = torch.einsum('nhwpqc->nchpwq', x)
|
||||
imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
|
||||
return imgs
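
A standalone shape check of unpatchify (toy sizes, not tied to any checkpoint):

import torch
N, h, w, p, c = 1, 64, 64, 2, 8             # 64x64 tokens, patch 2, out_channels 8
x = torch.randn(N, h * w, p * p * c)         # (N, T, patch_size**2 * C)
x = x.reshape(N, h, w, p, p, c)
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(N, c, h * p, w * p)
print(imgs.shape)                            # torch.Size([1, 8, 128, 128])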
|
||||
@ -1,4 +1,5 @@
|
||||
import importlib
|
||||
import logging
|
||||
|
||||
import torch
|
||||
from torch import optim
|
||||
@ -23,7 +24,7 @@ def log_txt_as_img(wh, xc, size=10):
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
logging.warning("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
@ -65,7 +66,7 @@ def mean_flat(tensor):
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
||||
logging.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
@ -133,7 +134,6 @@ class AdamWwithEMAandWings(optim.Optimizer):
|
||||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
ema_params_with_grad = []
|
||||
state_sums = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps = []
|
||||
amsgrad = group['amsgrad']
|
||||
|
||||
@ -354,6 +354,20 @@ def model_lora_keys_unet(model, key_map=None):
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) # simpletrainer and probably regular diffusers lora format
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, model_base.PixArt):
|
||||
diffusers_keys = utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = "transformer.{}".format(k[:-len(".weight")]) # default format
|
||||
key_map[key_lora] = to
|
||||
|
||||
key_lora = "base_model.model.{}".format(k[:-len(".weight")]) # diffusers training script
|
||||
key_map[key_lora] = to
|
||||
|
||||
key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")]) # old reference peft script
|
||||
key_map[key_lora] = to
|
||||
|
||||
if isinstance(model, model_base.HunyuanDiT):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
@ -375,6 +389,18 @@ def model_lora_keys_unet(model, key_map=None):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, model_base.HunyuanVideo):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
# diffusion-pipe lora format
|
||||
key_lora = k
|
||||
key_lora = key_lora.replace("_mod.lin.", "_mod.linear.").replace("_attn.qkv.", "_attn_qkv.").replace("_attn.proj.", "_attn_proj.")
|
||||
key_lora = key_lora.replace("mlp.0.", "mlp.fc1.").replace("mlp.2.", "mlp.fc2.")
|
||||
key_lora = key_lora.replace(".modulation.lin.", ".modulation.linear.")
|
||||
key_lora = key_lora[len("diffusion_model."):-len(".weight")]
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@ -35,12 +35,14 @@ from .ldm.cascade.stage_b import StageB
|
||||
from .ldm.cascade.stage_c import StageC
|
||||
from .ldm.flux import model as flux_model
|
||||
from .ldm.genmo.joint_model.asymm_models_joint import AsymmDiTJoint
|
||||
from .ldm.hunyuan_video.model import HunyuanVideo as HunyuanVideoModel
|
||||
from .ldm.hydit.models import HunYuanDiT
|
||||
from .ldm.lightricks.model import LTXVModel
|
||||
from .ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper
|
||||
from .ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||
from .ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAugmentation
|
||||
from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentation
|
||||
from .ldm.pixart.pixartms import PixArtMS
|
||||
from .model_management_types import ModelManageable
|
||||
from .ops import Operations
|
||||
from .patcher_extension import WrapperExecutor, WrappersMP, get_all_wrappers
|
||||
@ -450,7 +452,6 @@ class SVD_img2vid(BaseModel):
|
||||
|
||||
latent_image = kwargs.get("concat_latent_image", None)
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if latent_image is None:
|
||||
latent_image = torch.zeros_like(noise)
|
||||
@ -742,8 +743,6 @@ class HunyuanDiT(BaseModel):
|
||||
|
||||
width = kwargs.get("width", 768)
|
||||
height = kwargs.get("height", 768)
|
||||
crop_w = kwargs.get("crop_w", 0)
|
||||
crop_h = kwargs.get("crop_h", 0)
|
||||
target_width = kwargs.get("target_width", width)
|
||||
target_height = kwargs.get("target_height", height)
|
||||
|
||||
@ -751,6 +750,26 @@ class HunyuanDiT(BaseModel):
|
||||
return out
|
||||
|
||||
|
||||
class PixArt(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=PixArtMS)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = conds.CONDRegular(cross_attn)
|
||||
|
||||
width = kwargs.get("width", None)
|
||||
height = kwargs.get("height", None)
|
||||
if width is not None and height is not None:
|
||||
out["c_size"] = conds.CONDRegular(torch.FloatTensor([[height, width]]))
|
||||
out["c_ar"] = conds.CONDRegular(torch.FloatTensor([[kwargs.get("aspect_ratio", height / width)]]))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Flux(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=flux_model.Flux)
|
||||
@ -787,7 +806,6 @@ class Flux(BaseModel):
|
||||
mask = torch.ones_like(noise)[:, :1]
|
||||
|
||||
mask = torch.mean(mask, dim=1, keepdim=True)
|
||||
print(mask.shape)
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center")
|
||||
mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
@ -801,6 +819,16 @@ class Flux(BaseModel):
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = conds.CONDRegular(cross_attn)
|
||||
# upscale the attention mask, since now we
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
shape = kwargs["noise"].shape
|
||||
mask_ref_size = kwargs["attention_mask_img_shape"]
|
||||
# the model will pad to the patch size, and then divide
|
||||
# essentially dividing and rounding up
|
||||
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
|
||||
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
|
||||
out['attention_mask'] = conds.CONDRegular(attention_mask)
|
||||
out['guidance'] = conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
|
||||
return out
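
Illustrative token-grid numbers for the attention-mask upscaling above (a 1024x1024 image gives a 128x128 latent; Flux patches the latent with patch_size=2):

import math
latent_h, latent_w = 128, 128
patch_size = 2
h_tok = math.ceil(latent_h / patch_size)      # 64
w_tok = math.ceil(latent_w / patch_size)      # 64
# the mask is resized from its reference grid (attention_mask_img_shape) to 64x64 tokens
print(h_tok, w_tok)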
|
||||
|
||||
@ -844,3 +872,22 @@ class LTXV(BaseModel):
|
||||
|
||||
out['frame_rate'] = conds.CONDConstant(kwargs.get("frame_rate", 25))
|
||||
return out
|
||||
|
||||
|
||||
class HunyuanVideo(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=HunyuanVideoModel)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = conds.CONDRegular(cross_attn)
|
||||
out['guidance'] = conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
|
||||
return out
|
||||
|
||||
@ -137,6 +137,26 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
unet_config["image_model"] = "hydit1"
|
||||
return unet_config
|
||||
|
||||
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "hunyuan_video"
|
||||
dit_config["in_channels"] = 16
|
||||
dit_config["patch_size"] = [1, 2, 2]
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["vec_in_dim"] = 768
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 256
|
||||
dit_config["qkv_bias"] = True
|
||||
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: # Flux
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
@ -187,11 +207,42 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
return dit_config
|
||||
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys and '{}pos_embed.proj.bias'.format(key_prefix) in state_dict_keys:
|
||||
# PixArt diffusers
|
||||
return None
|
||||
|
||||
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ltxv"
|
||||
return dit_config
|
||||
|
||||
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
|
||||
patch_size = 2
|
||||
dit_config = {}
|
||||
dit_config["num_heads"] = 16
|
||||
dit_config["patch_size"] = patch_size
|
||||
dit_config["hidden_size"] = 1152
|
||||
dit_config["in_channels"] = 4
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
|
||||
|
||||
y_key = "{}y_embedder.y_embedding".format(key_prefix)
|
||||
if y_key in state_dict_keys:
|
||||
dit_config["model_max_length"] = state_dict[y_key].shape[0]
|
||||
|
||||
pe_key = "{}pos_embed".format(key_prefix)
|
||||
if pe_key in state_dict_keys:
|
||||
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size
|
||||
dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess
|
||||
|
||||
ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix)
|
||||
if ar_key in state_dict_keys:
|
||||
dit_config["image_model"] = "pixart_alpha"
|
||||
dit_config["micro_condition"] = True
|
||||
else:
|
||||
dit_config["image_model"] = "pixart_sigma"
|
||||
dit_config["micro_condition"] = False
|
||||
return dit_config
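
Illustrative values for the shape-based guesses above (a PixArt checkpoint trained at 1024px stores a 64x64 token pos_embed):

import math
pos_embed_tokens = 4096                   # state_dict[pe_key].shape[1] for a 1024px model
patch_size = 2
input_size = int(math.sqrt(pos_embed_tokens)) * patch_size   # 128 (latent pixels)
pe_interpolation = input_size // (512 // 8)                  # 2
print(input_size, pe_interpolation)
# model_max_length is read from y_embedder.y_embedding (typically 120 for alpha, 300 for sigma checkpoints)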
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@ -220,7 +271,6 @@ def detect_unet_config(state_dict, key_prefix):
|
||||
|
||||
num_res_blocks = []
|
||||
channel_mult = []
|
||||
attention_resolutions = []
|
||||
transformer_depth = []
|
||||
transformer_depth_output = []
|
||||
context_dim = None
|
||||
@ -394,7 +444,6 @@ def convert_config(unet_config):
|
||||
t_out += [d] * (res + 1)
|
||||
s *= 2
|
||||
transformer_depth = t_in
|
||||
transformer_depth_output = t_out
|
||||
new_config["transformer_depth"] = t_in
|
||||
new_config["transformer_depth_output"] = t_out
|
||||
new_config["transformer_depth_middle"] = transformer_depth_middle
|
||||
@ -562,6 +611,9 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.')
|
||||
num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
sd_map = utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix)
|
||||
elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt
|
||||
num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix)
|
||||
elif 'x_embedder.weight' in state_dict: # Flux
|
||||
depth = count_blocks(state_dict, 'transformer_blocks.{}.')
|
||||
depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.')
|
||||
|
||||
@ -43,6 +43,7 @@ model_management_lock = RLock()
|
||||
# This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
|
||||
torch.set_float32_matmul_precision("high")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class VRAMState(Enum):
|
||||
DISABLED = 0 # No vram present: no need to move models to vram
|
||||
@ -76,7 +77,7 @@ except:
|
||||
|
||||
lowvram_available = True
|
||||
if args.deterministic:
|
||||
logging.info("Using deterministic algorithms for pytorch")
|
||||
logger.info("Using deterministic algorithms for pytorch")
|
||||
torch.use_deterministic_algorithms(True, warn_only=True)
|
||||
|
||||
directml_device = None
|
||||
@ -88,7 +89,7 @@ if args.directml is not None:
|
||||
directml_device = torch_directml.device()
|
||||
else:
|
||||
directml_device = torch_directml.device(device_index)
|
||||
logging.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
|
||||
logger.info("Using directml with device: {}".format(torch_directml.device_name(device_index)))
|
||||
# torch_directml.disable_tiled_resources(True)
|
||||
lowvram_available = False # TODO: need to find a way to get free memory in directml before this can be enabled by default.
|
||||
|
||||
@ -169,13 +170,13 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
|
||||
# we're required to call get_device_name early on to initialize the methods get_total_memory will call
|
||||
if torch.cuda.is_available() and hasattr(torch.version, "hip") and torch.version.hip is not None:
|
||||
logging.info(f"Detected HIP device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
||||
logger.info(f"Detected HIP device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
|
||||
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
|
||||
total_ram = psutil.virtual_memory().total / (1024 * 1024)
|
||||
logging.debug("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
logger.debug("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
|
||||
|
||||
try:
|
||||
logging.debug("pytorch version: {}".format(torch_version))
|
||||
logger.debug("pytorch version: {}".format(torch_version))
|
||||
except:
|
||||
pass
|
||||
|
||||
@ -205,10 +206,10 @@ else:
|
||||
pass
|
||||
try:
|
||||
XFORMERS_VERSION = xformers.version.__version__
|
||||
logging.debug("xformers version: {}".format(XFORMERS_VERSION))
|
||||
logger.debug("xformers version: {}".format(XFORMERS_VERSION))
|
||||
if XFORMERS_VERSION.startswith("0.0.18"):
|
||||
logging.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
|
||||
logging.warning("Please downgrade or upgrade xformers to a different version.\n")
|
||||
logger.warning("\nWARNING: This version of xformers has a major bug where you will get black images when generating high resolution images.")
|
||||
logger.warning("Please downgrade or upgrade xformers to a different version.\n")
|
||||
XFORMERS_ENABLED_VAE = False
|
||||
except:
|
||||
pass
|
||||
@ -232,6 +233,10 @@ def is_amd():
|
||||
return False
|
||||
|
||||
|
||||
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
||||
if is_nvidia():
|
||||
MIN_WEIGHT_MEMORY_RATIO = 0.2
|
||||
|
||||
ENABLE_PYTORCH_ATTENTION = False
|
||||
if args.use_pytorch_cross_attention:
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
@ -275,15 +280,15 @@ FORCE_FP32 = False
|
||||
FORCE_FP16 = False
|
||||
FORCE_BF16 = False
|
||||
if args.force_fp32:
|
||||
logging.info("Forcing FP32, if this improves things please report it.")
|
||||
logger.info("Forcing FP32, if this improves things please report it.")
|
||||
FORCE_FP32 = True
|
||||
|
||||
if args.force_fp16 or cpu_state == CPUState.MPS:
|
||||
logging.info("Forcing FP16.")
|
||||
logger.info("Forcing FP16.")
|
||||
FORCE_FP16 = True
|
||||
|
||||
if args.force_bf16:
|
||||
logging.info("Force BF16")
|
||||
logger.info("Force BF16")
|
||||
FORCE_BF16 = True
|
||||
|
||||
if lowvram_available:
|
||||
@ -296,12 +301,12 @@ if cpu_state != CPUState.GPU:
|
||||
if cpu_state == CPUState.MPS:
|
||||
vram_state = VRAMState.SHARED
|
||||
|
||||
logging.debug(f"Set vram state to: {vram_state.name}")
|
||||
logger.debug(f"Set vram state to: {vram_state.name}")
|
||||
|
||||
DISABLE_SMART_MEMORY = args.disable_smart_memory
|
||||
|
||||
if DISABLE_SMART_MEMORY:
|
||||
logging.debug("Disabling smart memory management")
|
||||
logger.debug("Disabling smart memory management")
|
||||
|
||||
|
||||
def get_torch_device_name(device):
|
||||
@ -321,9 +326,9 @@ def get_torch_device_name(device):
|
||||
|
||||
|
||||
try:
|
||||
logging.debug("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
logger.debug("Device: {}".format(get_torch_device_name(get_torch_device())))
|
||||
except:
|
||||
logging.warning("Could not pick default device.")
|
||||
logger.warning("Could not pick default device.")
|
||||
|
||||
current_loaded_models: Final[List["LoadedModel"]] = []
|
||||
|
||||
@ -364,6 +369,9 @@ class LoadedModel:
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
|
||||
def model_loaded_memory(self):
|
||||
return self.model.loaded_size()
|
||||
|
||||
def model_offloaded_memory(self):
|
||||
return self.model.model_size() - self.model.loaded_size()
|
||||
|
||||
@ -453,7 +461,7 @@ if WINDOWS:
|
||||
|
||||
if args.reserve_vram is not None:
|
||||
EXTRA_RESERVED_VRAM = args.reserve_vram * 1024 * 1024 * 1024
|
||||
logging.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
|
||||
logger.debug("Reserving {}MB vram for other applications.".format(EXTRA_RESERVED_VRAM / (1024 * 1024)))
|
||||
|
||||
|
||||
def extra_reserved_memory():
|
||||
@ -495,7 +503,7 @@ def _free_memory(memory_required, device, keep_loaded=[]):
|
||||
if free_mem > memory_required:
|
||||
break
|
||||
memory_to_free = memory_required - free_mem
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
logger.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
if current_loaded_models[i].model_unload(memory_to_free):
|
||||
unloaded_model.append(i)
|
||||
|
||||
@ -521,7 +529,7 @@ def load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0,
|
||||
_load_models_gpu(models, memory_required, force_patch_weights, minimum_memory_required, force_full_load)
|
||||
to_load = list(map(str, models))
|
||||
span.set_attribute("models", to_load)
|
||||
logging.info(f"Loaded {to_load}")
|
||||
logger.info(f"Loaded {to_load}")
|
||||
|
||||
|
||||
def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0, force_patch_weights=False, minimum_memory_required=None, force_full_load=False) -> None:
|
||||
@ -575,7 +583,7 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
|
||||
if free_mem < minimum_memory_required:
|
||||
models_l = free_memory(minimum_memory_required, device)
|
||||
models_freed += models_l
|
||||
logging.debug("{} models unloaded.".format(len(models_l)))
|
||||
logger.debug("{} models unloaded.".format(len(models_l)))
|
||||
|
||||
for loaded_model in models_to_load:
|
||||
model = loaded_model.model
|
||||
@ -587,15 +595,18 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
model_size = loaded_model.model_memory_required(torch_dev)
current_free_mem = get_free_memory(torch_dev)
lowvram_model_memory = max(64 * (1024 * 1024), (current_free_mem - minimum_memory_required), min(current_free_mem * 0.4, current_free_mem - minimum_inference_memory()))
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory

lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
if model_size <= lowvram_model_memory: # only switch to lowvram if really necessary
lowvram_model_memory = 0

if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024

cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)

span = get_current_span()
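
For orientation, a minimal, hedged sketch of how the new lowvram weight budget above works out. The byte counts are made up, minimum_inference_memory is treated as a plain constant rather than the real function, and MIN_WEIGHT_MEMORY_RATIO mirrors the 0.2/0.4 constants introduced earlier in this diff.

GB = 1024 * 1024 * 1024
MIN_WEIGHT_MEMORY_RATIO = 0.2       # NVIDIA value from earlier in this diff

current_free_mem = 10 * GB          # free VRAM plus this model's already-loaded weights
minimum_memory_required = 6 * GB    # memory the caller asked to keep free
minimum_inference_memory = 1 * GB   # assumed baseline inference reserve
loaded_memory = 2 * GB              # weights of this model already on the GPU

budget = max(64 * 1024 * 1024,
             current_free_mem - minimum_memory_required,
             min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO,
                 current_free_mem - minimum_inference_memory))
budget = max(0.1, budget - loaded_memory)
print(budget / GB)  # 2.0 -> roughly 2 GB of additional weights may stay resident
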
@ -625,7 +636,7 @@ def cleanup_models_gc():
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
||||
logger.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
||||
do_gc = True
|
||||
break
|
||||
|
||||
@ -636,7 +647,7 @@ def cleanup_models_gc():
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
logger.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
|
||||
|
||||
def cleanup_models():
|
||||
@ -673,7 +684,7 @@ def unet_offload_device():
|
||||
|
||||
def unet_initial_load_device(parameters, dtype):
|
||||
torch_dev = get_torch_device()
|
||||
if vram_state == VRAMState.HIGH_VRAM:
|
||||
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
|
||||
return torch_dev
|
||||
|
||||
cpu_dev = torch.device("cpu")
|
||||
@ -793,7 +804,7 @@ def text_encoder_initial_device(load_device, offload_device, model_size=0):
|
||||
return offload_device
|
||||
|
||||
if is_device_mps(load_device):
|
||||
return offload_device
|
||||
return load_device
|
||||
|
||||
mem_l = get_free_memory(load_device)
|
||||
mem_o = get_free_memory(offload_device)
|
||||
@ -951,6 +962,10 @@ def cast_to_device(tensor, device, dtype, copy=False):
|
||||
return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy)
|
||||
|
||||
|
||||
def sage_attention_enabled():
|
||||
return args.use_sage_attention
|
||||
|
||||
|
||||
FLASH_ATTENTION_ENABLED = False
|
||||
if not args.disable_flash_attn:
|
||||
try:
|
||||
@ -1143,10 +1158,10 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if props.major < 6:
|
||||
return False
|
||||
except (ValueError, RuntimeError):
|
||||
logging.warning("No CUDA devices were present, even though CUDA is available in this torch installation. This assumes the CPU device will be selected for computation")
|
||||
logger.warning("No CUDA devices were present, even though CUDA is available in this torch installation. This assumes the CPU device will be selected for computation")
|
||||
return False
|
||||
except AssertionError:
|
||||
logging.warning("Torch was not compiled with cuda support")
|
||||
logger.warning("Torch was not compiled with cuda support")
|
||||
return False
|
||||
|
||||
# FP16 is confirmed working on a 1080 (GP104) and on latest pytorch actually seems faster than fp32
|
||||
@ -1207,10 +1222,10 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
||||
if props_major >= 8:
|
||||
return True
|
||||
except (ValueError, RuntimeError):
|
||||
logging.warning("No CUDA devices were present, even though CUDA is available in this torch installation. This assumes the CPU device will be selected for computation")
|
||||
logger.warning("No CUDA devices were present, even though CUDA is available in this torch installation. This assumes the CPU device will be selected for computation")
|
||||
return False
|
||||
except AssertionError:
|
||||
logging.warning("Torch was not compiled with CUDA support")
|
||||
logger.warning("Torch was not compiled with CUDA support")
|
||||
return False
|
||||
|
||||
bf16_works = torch.cuda.is_bf16_supported()
|
||||
@ -1267,8 +1282,8 @@ def unload_all_models():
|
||||
free_memory(1e30, get_torch_device())
|
||||
|
||||
|
||||
def resolve_lowvram_weight(weight, model, key): # TODO: remove
|
||||
warnings.warn("The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.", category=DeprecationWarning)
|
||||
def resolve_lowvram_weight(weight, model, key): #TODO: remove
|
||||
logger.warning("The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.")
|
||||
return weight
|
||||
|
||||
|
||||
|
||||
@ -1093,7 +1093,7 @@ class ModelPatcher(ModelManageable):
|
||||
if cached_weights is not None:
|
||||
for key in cached_weights:
|
||||
if key not in model_sd_keys:
|
||||
print(f"WARNING cached hook could not patch. key does not exist in model: {key}")
|
||||
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
|
||||
continue
|
||||
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
|
||||
else:
|
||||
@ -1103,7 +1103,7 @@ class ModelPatcher(ModelManageable):
|
||||
original_weights = self.get_key_patches()
|
||||
for key in relevant_patches:
|
||||
if key not in model_sd_keys:
|
||||
print(f"WARNING cached hook would not patch. key does not exist in model: {key}")
|
||||
logging.warning(f"Cached hook would not patch. Key does not exist in model: {key}")
|
||||
continue
|
||||
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
|
||||
memory_counter=memory_counter)
|
||||
|
||||
@ -290,7 +290,7 @@ class VAEDecodeTiled:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
|
||||
"tile_size": ("INT", {"default": 512, "min": 128, "max": 4096, "step": 32}),
|
||||
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
|
||||
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
|
||||
}}
|
||||
RETURN_TYPES = ("IMAGE",)
|
||||
@ -324,15 +324,16 @@ class VAEEncodeTiled:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
|
||||
"tile_size": ("INT", {"default": 512, "min": 320, "max": 4096, "step": 64})
|
||||
"tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
|
||||
"overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
|
||||
CATEGORY = "_for_testing"
|
||||
|
||||
def encode(self, vae, pixels, tile_size):
|
||||
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, )
|
||||
def encode(self, vae, pixels, tile_size, overlap):
|
||||
t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap)
|
||||
return ({"samples":t}, )
|
||||
|
||||
class VAEEncodeForInpaint:
|
||||
@ -926,7 +927,7 @@ class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (get_filename_list_with_downloadable("text_encoders", KNOWN_CLIP_MODELS),),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ),
|
||||
"type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart"], ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
@ -947,6 +948,8 @@ class CLIPLoader:
|
||||
clip_type = sd.CLIPType.MOCHI
|
||||
elif type == "ltxv":
|
||||
clip_type = sd.CLIPType.LTXV
|
||||
elif type == "pixart":
|
||||
clip_type = comfy.sd.CLIPType.PIXART
|
||||
else:
|
||||
logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
|
||||
|
||||
@ -959,7 +962,7 @@ class DualCLIPLoader:
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name1": (get_filename_list_with_downloadable("text_encoders"),), "clip_name2": (
|
||||
get_filename_list_with_downloadable("text_encoders"),),
|
||||
"type": (["sdxl", "sd3", "flux"], ),
|
||||
"type": (["sdxl", "sd3", "flux", "hunyuan_video"], ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
FUNCTION = "load_clip"
|
||||
@ -977,6 +980,8 @@ class DualCLIPLoader:
|
||||
clip_type = sd.CLIPType.SD3
|
||||
elif type == "flux":
|
||||
clip_type = sd.CLIPType.FLUX
|
||||
elif type == "hunyuan_video":
|
||||
clip_type = sd.CLIPType.HUNYUAN_VIDEO
|
||||
else:
|
||||
raise ValueError(f"Unknown clip type argument passed: {type} for model {clip_name1} and {clip_name2}")
|
||||
|
||||
@ -1044,23 +1049,58 @@ class StyleModelApply:
|
||||
"strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.001}),
|
||||
},
|
||||
"optional": {
|
||||
"strength_type": (["multiply"], {"default": "multiply"}),
|
||||
"strength_type": (["multiply", "attn_bias"], {"default": "multiply"}),
|
||||
}}
|
||||
RETURN_TYPES = ("CONDITIONING",)
|
||||
FUNCTION = "apply_stylemodel"
|
||||
|
||||
CATEGORY = "conditioning/style_model"
|
||||
|
||||
def apply_stylemodel(self, clip_vision_output, style_model, conditioning, strength=1.0, strength_type="multiply"):
|
||||
def apply_stylemodel(self, conditioning, style_model, clip_vision_output, strength=1.0, strength_type="multiply"):
|
||||
cond = style_model.get_cond(clip_vision_output).flatten(start_dim=0, end_dim=1).unsqueeze(dim=0)
|
||||
if strength_type == "multiply":
|
||||
cond *= strength
|
||||
|
||||
c = []
|
||||
n = cond.shape[1]
|
||||
c_out = []
|
||||
for t in conditioning:
|
||||
n = [torch.cat((t[0], cond), dim=1), t[1].copy()]
|
||||
c.append(n)
|
||||
return (c, )
|
||||
(txt, keys) = t
|
||||
keys = keys.copy()
|
||||
if strength_type == "attn_bias" and strength != 1.0:
|
||||
# math.log raises an error if the argument is zero
|
||||
# torch.log returns -inf, which is what we want
|
||||
attn_bias = torch.log(torch.Tensor([strength]))
|
||||
# get the size of the mask image
|
||||
mask_ref_size = keys.get("attention_mask_img_shape", (1, 1))
|
||||
n_ref = mask_ref_size[0] * mask_ref_size[1]
|
||||
n_txt = txt.shape[1]
|
||||
# grab the existing mask
|
||||
mask = keys.get("attention_mask", None)
|
||||
# create a default mask if it doesn't exist
|
||||
if mask is None:
|
||||
mask = torch.zeros((txt.shape[0], n_txt + n_ref, n_txt + n_ref), dtype=torch.float16)
|
||||
# convert the mask dtype, because it might be boolean
|
||||
# we want it to be interpreted as a bias
|
||||
if mask.dtype == torch.bool:
|
||||
# log(True) = log(1) = 0
|
||||
# log(False) = log(0) = -inf
|
||||
mask = torch.log(mask.to(dtype=torch.float16))
|
||||
# now we make the mask bigger to add space for our new tokens
|
||||
new_mask = torch.zeros((txt.shape[0], n_txt + n + n_ref, n_txt + n + n_ref), dtype=torch.float16)
|
||||
# copy over the old mask, in quadrants
|
||||
new_mask[:, :n_txt, :n_txt] = mask[:, :n_txt, :n_txt]
|
||||
new_mask[:, :n_txt, n_txt+n:] = mask[:, :n_txt, n_txt:]
|
||||
new_mask[:, n_txt+n:, :n_txt] = mask[:, n_txt:, :n_txt]
|
||||
new_mask[:, n_txt+n:, n_txt+n:] = mask[:, n_txt:, n_txt:]
|
||||
# now fill in the attention bias to our redux tokens
|
||||
new_mask[:, :n_txt, n_txt:n_txt+n] = attn_bias
|
||||
new_mask[:, n_txt+n:, n_txt:n_txt+n] = attn_bias
|
||||
keys["attention_mask"] = new_mask.to(txt.device)
|
||||
keys["attention_mask_img_shape"] = mask_ref_size
|
||||
|
||||
c_out.append([torch.cat((txt, cond), dim=1), keys])
|
||||
|
||||
return (c_out,)
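
The attn_bias path above relies on a standard identity: adding log(strength) to selected attention logits before the softmax is the same as multiplying those entries' unnormalized weights by strength and renormalizing. A small hedged check in plain torch (not the ComfyUI attention code):

import torch

logits = torch.tensor([1.0, 2.0, 3.0])
strength = 0.5

# bias only the last entry by log(strength)
bias = torch.tensor([0.0, 0.0, torch.log(torch.tensor(strength)).item()])
biased = torch.softmax(logits + bias, dim=0)

# the same thing done multiplicatively on the unnormalized weights
w = torch.exp(logits)
w[2] = w[2] * strength
manual = w / w.sum()

print(torch.allclose(biased, manual))  # True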
class unCLIPConditioning:
|
||||
@classmethod
|
||||
|
||||
@ -113,7 +113,7 @@ class WrapperExecutor:
|
||||
def _create_next_executor(self) -> 'WrapperExecutor':
|
||||
new_idx = self.idx + 1
|
||||
if new_idx > len(self.wrappers):
|
||||
raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
|
||||
raise Exception("Wrapper idx exceeded available wrappers; something went very wrong.")
|
||||
if self.class_obj is None:
|
||||
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
|
||||
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)
|
||||
|
||||
@ -107,7 +107,6 @@ def cleanup_additional_models(models):
|
||||
|
||||
|
||||
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
|
||||
device = model.load_device
|
||||
real_model: 'BaseModel' = None
|
||||
models, inference_memory = get_additional_models(conds, model.model_dtype())
|
||||
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?
|
||||
|
||||
@ -137,11 +137,6 @@ def can_concat_cond(c1, c2):
|
||||
|
||||
|
||||
def cond_cat(c_list):
|
||||
c_crossattn = []
|
||||
c_concat = []
|
||||
c_adm = []
|
||||
crossattn_max_len = 0
|
||||
|
||||
temp = {}
|
||||
for x in c_list:
|
||||
for k in x:
|
||||
@ -360,7 +355,7 @@ def cfg_function(model, cond_pred, uncond_pred, cond_scale, x, timestep, model_o
|
||||
cfg_result = uncond_pred + (cond_pred - uncond_pred) * cond_scale
|
||||
|
||||
for fn in model_options.get("sampler_post_cfg_function", []):
|
||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||
args = {"denoised": cfg_result, "cond": cond, "uncond": uncond, "cond_scale": cond_scale, "model": model, "uncond_denoised": uncond_pred, "cond_denoised": cond_pred,
|
||||
"sigma": timestep, "model_options": model_options, "input": x}
|
||||
cfg_result = fn(args)
|
||||
|
||||
@ -639,8 +634,6 @@ def pre_run_control(model, conds):
|
||||
for t in range(len(conds)):
|
||||
x = conds[t]
|
||||
|
||||
timestep_start = None
|
||||
timestep_end = None
|
||||
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
|
||||
if 'control' in x:
|
||||
x['control'].pre_run(model, percent_to_timestep_function)
|
||||
|
||||
114  comfy/sd.py
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import math
|
||||
import os.path
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
@ -35,9 +36,11 @@ from .taesd import taesd
|
||||
from .text_encoders import aura_t5
|
||||
from .text_encoders import flux
|
||||
from .text_encoders import genmo
|
||||
from .text_encoders import hunyuan_video
|
||||
from .text_encoders import hydit
|
||||
from .text_encoders import long_clipl
|
||||
from .text_encoders import lt
|
||||
from .text_encoders import pixart_t5
|
||||
from .text_encoders import sa_t5
|
||||
from .text_encoders import sd2_clip
|
||||
from .text_encoders import sd3_clip
|
||||
@ -45,6 +48,7 @@ from .utils import ProgressBar
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_lora_for_models(model, clip, _lora, strength_model, strength_clip):
|
||||
key_map = {}
|
||||
if model is not None:
|
||||
@ -317,8 +321,8 @@ class VAE:
|
||||
self.upscale_ratio = 4
|
||||
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
if 'quant_conv.weight' in sd:
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=4)
|
||||
if 'post_quant_conv.weight' in sd:
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
else:
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': ddconfig},
|
||||
@ -346,15 +350,35 @@ class VAE:
|
||||
self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
|
||||
self.working_dtypes = [torch.float16, torch.float32]
|
||||
elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: # lightricks ltxv
|
||||
self.first_stage_model = lightricks.VideoVAE()
|
||||
tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
|
||||
version = 0
|
||||
if tensor_conv1.shape[0] == 512:
|
||||
version = 0
|
||||
elif tensor_conv1.shape[0] == 1024:
|
||||
version = 1
|
||||
self.first_stage_model = lightricks.VideoVAE(version=version)
|
||||
self.latent_channels = 128
|
||||
self.latent_dim = 3
|
||||
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||
elif "decoder.conv_in.conv.weight" in sd:
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
ddconfig["conv3d"] = True
|
||||
ddconfig["time_compress"] = 4
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
else:
|
||||
logger.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
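
As a hedged aside, the LTXV upscale/downscale lambdas above encode the 8x temporal compression with a one-frame offset; a quick round-trip check:

import math

up = lambda a: max(0, a * 8 - 7)                  # latent frames -> pixel frames
down = lambda a: max(0, math.floor((a + 7) / 8))  # pixel frames -> latent frames

print(down(49))      # 7 latent frames for a 49-frame clip
print(up(down(49)))  # 49, exact round trip on 8k+1 frame counts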
@ -384,10 +408,12 @@ class VAE:
|
||||
logger.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||
|
||||
def vae_encode_crop_pixels(self, pixels):
|
||||
downscale_ratio = self.spacial_compression_encode()
|
||||
|
||||
dims = pixels.shape[1:-1]
|
||||
for d in range(len(dims)):
|
||||
x = (dims[d] // self.downscale_ratio) * self.downscale_ratio
|
||||
x_offset = (dims[d] % self.downscale_ratio) // 2
|
||||
x = (dims[d] // downscale_ratio) * downscale_ratio
|
||||
x_offset = (dims[d] % downscale_ratio) // 2
|
||||
if x != dims[d]:
|
||||
pixels = pixels.narrow(d + 1, x_offset, x)
|
||||
return pixels
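
A hedged worked example of the crop above, assuming a spatial compression factor of 8 (typical for these VAEs); odd-sized inputs are center-cropped down to the nearest multiple before encoding:

downscale_ratio = 8
for dim in (517, 389):
    x = (dim // downscale_ratio) * downscale_ratio
    x_offset = (dim % downscale_ratio) // 2
    print(f"{dim} -> {x} (offset {x_offset})")
# 517 -> 512 (offset 2)
# 389 -> 384 (offset 2)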
@ -408,7 +434,7 @@ class VAE:
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
return utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)
|
||||
return self.process_output(utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
@ -431,6 +457,10 @@ class VAE:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1 / self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
|
||||
|
||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in):
|
||||
pixel_samples = None
|
||||
try:
|
||||
@ -446,7 +476,7 @@ class VAE:
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples[x:x + batch_number] = out
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logger.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
dims = samples_in.ndim - 2
|
||||
if dims == 1:
|
||||
@ -503,20 +533,48 @@ class VAE:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logger.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
if len(pixel_samples.shape) == 3:
|
||||
if self.latent_dim == 3:
|
||||
tile = 256
|
||||
overlap = tile // 4
|
||||
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
|
||||
elif self.latent_dim == 1:
|
||||
samples = self.encode_tiled_1d(pixel_samples)
|
||||
else:
|
||||
samples = self.encode_tiled_(pixel_samples)
|
||||
|
||||
return samples
|
||||
|
||||
def encode_tiled(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
|
||||
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None):
|
||||
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
|
||||
load_models_gpu([self.patcher])
|
||||
dims = self.latent_dim
|
||||
pixel_samples = pixel_samples.movedim(-1, 1)
|
||||
samples = self.encode_tiled_(pixel_samples, tile_x=tile_x, tile_y=tile_y, overlap=overlap)
|
||||
if dims == 3:
|
||||
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
|
||||
|
||||
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
|
||||
load_models_gpu([self.patcher], memory_required=memory_used)
|
||||
|
||||
args = {}
|
||||
if tile_x is not None:
|
||||
args["tile_x"] = tile_x
|
||||
if tile_y is not None:
|
||||
args["tile_y"] = tile_y
|
||||
if overlap is not None:
|
||||
args["overlap"] = overlap
|
||||
|
||||
samples = None
|
||||
if dims == 1:
|
||||
args.pop("tile_y")
|
||||
samples = self.encode_tiled_1d(pixel_samples, **args)
|
||||
elif dims == 2:
|
||||
samples = self.encode_tiled_(pixel_samples, **args)
|
||||
elif dims == 3:
|
||||
samples = self.encode_tiled_3d(pixel_samples, **args)
|
||||
else:
|
||||
raise ValueError(f"unsupported values dim {dims}")
|
||||
|
||||
return samples
|
||||
|
||||
def get_sd(self):
|
||||
@ -528,6 +586,12 @@ class VAE:
|
||||
except:
|
||||
return self.upscale_ratio
|
||||
|
||||
def spacial_compression_encode(self):
|
||||
try:
|
||||
return self.downscale_ratio[-1]
|
||||
except:
|
||||
return self.downscale_ratio
|
||||
|
||||
|
||||
class StyleModel:
|
||||
def __init__(self, model, device="cpu"):
|
||||
@ -559,6 +623,8 @@ class CLIPType(Enum):
|
||||
FLUX = 6
|
||||
MOCHI = 7
|
||||
LTXV = 8
|
||||
HUNYUAN_VIDEO = 9
|
||||
PIXART = 10
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@ -585,6 +651,7 @@ class TEModel(Enum):
|
||||
T5_XXL = 4
|
||||
T5_XL = 5
|
||||
T5_BASE = 6
|
||||
LLAMA3_8 = 7
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
@ -602,6 +669,8 @@ def detect_te_model(sd):
|
||||
return TEModel.T5_XL
|
||||
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
|
||||
return TEModel.T5_BASE
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.LLAMA3_8
|
||||
return None
|
||||
|
||||
|
||||
@ -615,6 +684,16 @@ def t5xxl_detect(clip_data):
|
||||
return {}
|
||||
|
||||
|
||||
def llama_detect(clip_data):
|
||||
weight_name = "model.layers.0.self_attn.k_proj.weight"
|
||||
|
||||
for sd in clip_data:
|
||||
if weight_name in sd:
|
||||
return hunyuan_video.llama_detect(sd)
|
||||
|
||||
return {}
|
||||
|
||||
|
||||
def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}, textmodel_json_config=None):
|
||||
clip_data = state_dicts
|
||||
|
||||
@ -652,6 +731,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.LTXV:
|
||||
clip_target.clip = lt.ltxv_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = lt.LTXVT5Tokenizer
|
||||
elif clip_type == CLIPType.PIXART:
|
||||
clip_target.clip = pixart_t5.pixart_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = pixart_t5.PixArtTokenizer
|
||||
else: # CLIPType.MOCHI
|
||||
clip_target.clip = genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = genmo.MochiT5Tokenizer
|
||||
@ -679,6 +761,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.FLUX:
|
||||
clip_target.clip = flux.flux_clip(**t5xxl_detect(clip_data))
|
||||
clip_target.tokenizer = flux.FluxTokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_VIDEO:
|
||||
clip_target.clip = hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = hunyuan_video.HunyuanVideoTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
@ -720,7 +805,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
|
||||
config = yaml.safe_load(stream)
|
||||
model_config_params = config['model']['params']
|
||||
clip_config = model_config_params['cond_stage_config']
|
||||
scale_factor = model_config_params['scale_factor']
|
||||
|
||||
if "parameterization" in model_config_params:
|
||||
if model_config_params["parameterization"] == "v":
|
||||
@ -906,12 +990,12 @@ def load_diffusion_model(unet_path, model_options: dict = None):
|
||||
|
||||
|
||||
def load_unet(unet_path, dtype=None):
|
||||
print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||
logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
|
||||
return load_diffusion_model(unet_path, model_options={"dtype": dtype})
|
||||
|
||||
|
||||
def load_unet_state_dict(sd, dtype=None):
|
||||
print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
|
||||
logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
|
||||
return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
|
||||
|
||||
|
||||
|
||||
@ -52,7 +52,10 @@ class ClipTokenWeightEncoder:
|
||||
|
||||
sections = len(to_encode)
|
||||
if has_weights or sections == 0 and hasattr(self, "special_tokens"):
|
||||
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) # pylint: disable=no-member
|
||||
if hasattr(self, "gen_empty_tokens"):
|
||||
to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len))
|
||||
else:
|
||||
to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) # pylint: disable=no-member
|
||||
|
||||
assert hasattr(self, "encode")
|
||||
assert isinstance(self.encode, Callable) # pylint: disable=no-member
|
||||
@ -216,11 +219,18 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
attention_mask = None
|
||||
if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
|
||||
attention_mask = torch.zeros_like(tokens)
|
||||
end_token = self.special_tokens.get("end", -1)
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
if end_token is None:
|
||||
cmp_token = self.special_tokens.get("pad", -1)
|
||||
else:
|
||||
cmp_token = end_token
|
||||
|
||||
for x in range(attention_mask.shape[0]):
|
||||
for y in range(attention_mask.shape[1]):
|
||||
attention_mask[x, y] = 1
|
||||
if tokens[x, y] == end_token:
|
||||
if tokens[x, y] == cmp_token:
|
||||
if end_token is None:
|
||||
attention_mask[x, y] = 0
|
||||
break
|
||||
|
||||
attention_mask_model = None
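
A toy, hedged illustration of the mask logic above (the token ids are invented): positions are marked 1 up to the stop token, and for pad-only models the pad position itself is zeroed before breaking out of the loop.

import torch

tokens = torch.tensor([[101, 5, 6, 7, 0, 0]])
cmp_token, end_token = 0, None   # pad-only model, as in the new branch above

mask = torch.zeros_like(tokens)
for x in range(mask.shape[0]):
    for y in range(mask.shape[1]):
        mask[x, y] = 1
        if tokens[x, y] == cmp_token:
            if end_token is None:
                mask[x, y] = 0
            break

print(mask)  # tensor([[1, 1, 1, 1, 0, 0]])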
@ -403,7 +413,6 @@ def expand_directory_list(directories):
|
||||
|
||||
|
||||
def bundled_embed(embed, prefix, suffix): # bundled embedding in lora format
|
||||
i = 0
|
||||
out_list = []
|
||||
for k in embed:
|
||||
if k.startswith(prefix) and k.endswith(suffix):
|
||||
@ -460,7 +469,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
|
||||
embed_out = safe_load_embed_zip(embed_path)
|
||||
else:
|
||||
embed = torch.load(embed_path, map_location="cpu")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
|
||||
return None
|
||||
|
||||
@ -493,7 +502,7 @@ SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer')
|
||||
|
||||
|
||||
class SDTokenizer:
|
||||
def __init__(self, tokenizer_path: torch.Tensor | bytes | bytearray | memoryview | str | Path | Traversable = None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, pad_to_max_length=True, min_length=None, pad_token=None, tokenizer_data=None):
|
||||
def __init__(self, tokenizer_path: torch.Tensor | bytes | bytearray | memoryview | str | Path | Traversable = None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data=None):
|
||||
if tokenizer_data is None:
|
||||
tokenizer_data = dict()
|
||||
if tokenizer_path is None:
|
||||
@ -511,16 +520,25 @@ class SDTokenizer:
|
||||
self.tokenizer: PreTrainedTokenizerBase | SPieceTokenizer = tokenizer_class.from_pretrained(tokenizer_path)
|
||||
self.max_length = max_length
|
||||
self.min_length = min_length
|
||||
self.end_token = None
|
||||
|
||||
empty = self.tokenizer('')["input_ids"]
|
||||
self.tokenizer_adds_end_token = has_end_token
|
||||
if has_start_token:
|
||||
self.tokens_start = 1
|
||||
self.start_token = empty[0]
|
||||
self.end_token = empty[1]
|
||||
if end_token is not None:
|
||||
self.end_token = end_token
|
||||
else:
|
||||
if has_end_token:
|
||||
self.end_token = empty[1]
|
||||
else:
|
||||
self.tokens_start = 0
|
||||
self.start_token = None
|
||||
self.end_token = empty[0]
|
||||
if end_token is not None:
|
||||
self.end_token = end_token
|
||||
else:
|
||||
self.end_token = empty[0]
|
||||
|
||||
if pad_token is not None:
|
||||
self.pad_token = pad_token
|
||||
@ -558,13 +576,16 @@ class SDTokenizer:
|
||||
Takes a potential embedding name and tries to retrieve it.
|
||||
Returns a Tuple consisting of the embedding and any leftover string, embedding can be None.
|
||||
'''
|
||||
split_embed = embedding_name.split()
|
||||
embedding_name = split_embed[0]
|
||||
leftover = ' '.join(split_embed[1:])
|
||||
embed = load_embed(embedding_name, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
if embed is None:
|
||||
stripped = embedding_name.strip(',')
|
||||
if len(stripped) < len(embedding_name):
|
||||
embed = load_embed(stripped, self.embedding_directory, self.embedding_size, self.embedding_key)
|
||||
return (embed, embedding_name[len(stripped):])
|
||||
return (embed, "")
|
||||
return (embed, "{} {}".format(embedding_name[len(stripped):], leftover))
|
||||
return (embed, leftover)
|
||||
|
||||
def tokenize_with_weights(self, text: str, return_word_ids=False):
|
||||
'''
|
||||
@ -581,7 +602,12 @@ class SDTokenizer:
|
||||
# tokenize words
|
||||
tokens = []
|
||||
for weighted_segment, weight in parsed_weights:
|
||||
to_tokenize = unescape_important(weighted_segment).replace("\n", " ").split(' ')
|
||||
to_tokenize = unescape_important(weighted_segment)
|
||||
split = re.split(' {0}|\n{0}'.format(self.embedding_identifier), to_tokenize)
|
||||
to_tokenize = [split[0]]
|
||||
for i in range(1, len(split)):
|
||||
to_tokenize.append("{}{}".format(self.embedding_identifier, split[i]))
|
||||
|
||||
to_tokenize = [x for x in to_tokenize if x != ""]
|
||||
for word in to_tokenize:
|
||||
# if we find an embedding, deal with the embedding
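
A hedged illustration of the split above; "embedding:" is the usual embedding_identifier and the prompt text is invented. The split keeps each embedding reference attached to its identifier so it can be looked up as a unit.

import re

ident = "embedding:"
text = "a photo of embedding:mystyle standing in rain"
split = re.split(' {0}|\n{0}'.format(ident), text)
parts = [split[0]] + ["{}{}".format(ident, s) for s in split[1:]]
parts = [p for p in parts if p != ""]
print(parts)  # ['a photo of', 'embedding:mystyle standing in rain']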
@ -600,6 +626,9 @@ class SDTokenizer:
|
||||
word = leftover
|
||||
else:
|
||||
continue
|
||||
end = 999999999999
|
||||
if self.tokenizer_adds_end_token:
|
||||
end = -1
|
||||
# parse word
|
||||
exact_word = f"{word}</w>"
|
||||
if hasattr(self.tokenizer, "eos_token") and word == self.tokenizer.eos_token:
|
||||
@ -607,7 +636,7 @@ class SDTokenizer:
|
||||
elif exact_word in vocab:
|
||||
tokenizer_result = [vocab[exact_word]]
|
||||
else:
|
||||
tokenizer_result = self.tokenizer(word)["input_ids"][self.tokens_start:-1]
|
||||
tokenizer_result = self.tokenizer(word)["input_ids"][self.tokens_start:end]
|
||||
tokens.append([(t, weight) for t in tokenizer_result])
|
||||
|
||||
# reshape token array to CLIP input size
|
||||
@ -619,18 +648,24 @@ class SDTokenizer:
|
||||
for i, t_group in enumerate(tokens):
|
||||
# determine if we're going to try and keep the tokens in a single batch
|
||||
is_large = len(t_group) >= self.max_word_length
|
||||
if self.end_token is not None:
|
||||
has_end_token = 1
|
||||
else:
|
||||
has_end_token = 0
|
||||
|
||||
while len(t_group) > 0:
|
||||
if len(t_group) + len(batch) > self.max_length - 1:
|
||||
remaining_length = self.max_length - len(batch) - 1
|
||||
if len(t_group) + len(batch) > self.max_length - has_end_token:
|
||||
remaining_length = self.max_length - len(batch) - has_end_token
|
||||
# break word in two and add end token
|
||||
if is_large:
|
||||
batch.extend([(t, w, i + 1) for t, w in t_group[:remaining_length]])
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
t_group = t_group[remaining_length:]
|
||||
# add end token and pad
|
||||
else:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length))
|
||||
# start new batch
|
||||
@ -643,7 +678,8 @@ class SDTokenizer:
|
||||
t_group = []
|
||||
|
||||
# fill last batch
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.end_token is not None:
|
||||
batch.append((self.end_token, 1.0, 0))
|
||||
if self.pad_to_max_length:
|
||||
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
|
||||
if self.min_length is not None and len(batch) < self.min_length:
|
||||
|
||||
@ -1,22 +1,23 @@
|
||||
import torch
|
||||
from . import model_base
|
||||
from . import utils
|
||||
|
||||
from . import sd1_clip
|
||||
from . import sdxl_clip
|
||||
from .text_encoders import sd2_clip
|
||||
from .text_encoders import sd3_clip
|
||||
from .text_encoders import sa_t5
|
||||
from .text_encoders import aura_t5
|
||||
from .text_encoders import hydit
|
||||
from .text_encoders import flux
|
||||
from .text_encoders import genmo
|
||||
from .text_encoders import lt
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
|
||||
from . import diffusers_convert
|
||||
from . import latent_formats
|
||||
from . import model_base
|
||||
from . import sd1_clip
|
||||
from . import sdxl_clip
|
||||
from . import supported_models_base
|
||||
from . import utils
|
||||
from .text_encoders import aura_t5
|
||||
from .text_encoders import flux
|
||||
from .text_encoders import genmo
|
||||
from .text_encoders import hunyuan_video
|
||||
from .text_encoders import hydit
|
||||
from .text_encoders import lt
|
||||
from .text_encoders import pixart_t5
|
||||
from .text_encoders import sa_t5
|
||||
from .text_encoders import sd2_clip
|
||||
from .text_encoders import sd3_clip
|
||||
|
||||
|
||||
class SD15(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
@ -64,6 +65,7 @@ class SD15(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sd1_clip.SD1Tokenizer, sd1_clip.SD1ClipModel)
|
||||
|
||||
|
||||
class SD20(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"context_dim": 1024,
|
||||
@ -83,16 +85,16 @@ class SD20(supported_models_base.BASE):
|
||||
memory_usage_factor = 1.0
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
|
||||
if self.unet_config["in_channels"] == 4: # SD2.0 inpainting models are not v prediction
|
||||
k = "{}output_blocks.11.1.transformer_blocks.0.norm1.bias".format(prefix)
|
||||
out = state_dict.get(k, None)
|
||||
if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
||||
if out is not None and torch.std(out, unbiased=False) > 0.09: # not sure how well this will actually work. I guess we will find out.
|
||||
return model_base.ModelType.V_PREDICTION
|
||||
return model_base.ModelType.EPS
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
replace_prefix = {}
|
||||
replace_prefix["conditioner.embedders.0.model."] = "clip_h." #SD2 in sgm format
|
||||
replace_prefix["conditioner.embedders.0.model."] = "clip_h." # SD2 in sgm format
|
||||
replace_prefix["cond_stage_model.model."] = "clip_h."
|
||||
state_dict = utils.state_dict_prefix_replace(state_dict, replace_prefix, filter_keys=True)
|
||||
state_dict = utils.clip_text_transformers_convert(state_dict, "clip_h.", "clip_h.transformer.")
|
||||
@ -108,6 +110,7 @@ class SD20(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sd2_clip.SD2Tokenizer, sd2_clip.SD2ClipModel)
|
||||
|
||||
|
||||
class SD21UnclipL(SD20):
|
||||
unet_config = {
|
||||
"context_dim": 1024,
|
||||
@ -133,6 +136,7 @@ class SD21UnclipH(SD20):
|
||||
clip_vision_prefix = "embedder.model.visual."
|
||||
noise_aug_config = {"noise_schedule_config": {"timesteps": 1000, "beta_schedule": "squaredcos_cap_v2"}, "timestep_dim": 1024}
|
||||
|
||||
|
||||
class SDXLRefiner(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"model_channels": 384,
|
||||
@ -171,6 +175,7 @@ class SDXLRefiner(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLRefinerClipModel)
|
||||
|
||||
|
||||
class SDXL(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
@ -186,7 +191,7 @@ class SDXL(supported_models_base.BASE):
|
||||
memory_usage_factor = 0.8
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
|
||||
if 'edm_mean' in state_dict and 'edm_std' in state_dict: # Playground V2.5
|
||||
self.latent_format = latent_formats.SDXL_Playground_2_5()
|
||||
self.sampling_settings["sigma_data"] = 0.5
|
||||
self.sampling_settings["sigma_max"] = 80.0
|
||||
@ -198,7 +203,7 @@ class SDXL(supported_models_base.BASE):
|
||||
self.sampling_settings["sigma_min"] = float(state_dict["edm_vpred.sigma_min"].item())
|
||||
return model_base.ModelType.V_PREDICTION_EDM
|
||||
elif "v_pred" in state_dict:
|
||||
if "ztsnr" in state_dict: #Some zsnr anime checkpoints
|
||||
if "ztsnr" in state_dict: # Some zsnr anime checkpoints
|
||||
self.sampling_settings["zsnr"] = True
|
||||
return model_base.ModelType.V_PREDICTION
|
||||
else:
|
||||
@ -224,7 +229,6 @@ class SDXL(supported_models_base.BASE):
|
||||
|
||||
def process_clip_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {}
|
||||
keys_to_replace = {}
|
||||
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
|
||||
for k in state_dict:
|
||||
if k.startswith("clip_l"):
|
||||
@ -244,6 +248,7 @@ class SDXL(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.SDXLTokenizer, sdxl_clip.SDXLClipModel)
|
||||
|
||||
|
||||
class SSD1B(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
@ -254,6 +259,7 @@ class SSD1B(SDXL):
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
|
||||
class Segmind_Vega(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
@ -264,6 +270,7 @@ class Segmind_Vega(SDXL):
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
|
||||
class KOALA_700M(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
@ -274,6 +281,7 @@ class KOALA_700M(SDXL):
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
|
||||
class KOALA_1B(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
@ -284,6 +292,7 @@ class KOALA_1B(SDXL):
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
|
||||
class SVD_img2vid(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
@ -315,6 +324,7 @@ class SVD_img2vid(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
|
||||
class SV3D_u(SVD_img2vid):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
@ -333,6 +343,7 @@ class SV3D_u(SVD_img2vid):
|
||||
out = model_base.SV3D_u(self, device=device)
|
||||
return out
|
||||
|
||||
|
||||
class SV3D_p(SV3D_u):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
@ -345,11 +356,11 @@ class SV3D_p(SV3D_u):
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.SV3D_p(self, device=device)
|
||||
return out
|
||||
|
||||
|
||||
class Stable_Zero123(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
@ -381,6 +392,7 @@ class Stable_Zero123(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
|
||||
class SD_X4Upscaler(SD20):
|
||||
unet_config = {
|
||||
"context_dim": 1024,
|
||||
@ -409,6 +421,7 @@ class SD_X4Upscaler(SD20):
|
||||
out = model_base.SD_X4Upscaler(self, device=device)
|
||||
return out
|
||||
|
||||
|
||||
class Stable_Cascade_C(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"stable_cascade_stage": 'c',
|
||||
@ -439,7 +452,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
|
||||
for x in range(3):
|
||||
p = ["to_q", "to_k", "to_v"]
|
||||
k_to = "{}.{}.{}".format(prefix, p[x], y)
|
||||
state_dict[k_to] = weights[shape_from*x:shape_from*(x + 1)]
|
||||
state_dict[k_to] = weights[shape_from * x:shape_from * (x + 1)]
|
||||
return state_dict
|
||||
|
||||
def process_clip_state_dict(self, state_dict):
|
||||
@ -455,6 +468,7 @@ class Stable_Cascade_C(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sdxl_clip.StableCascadeTokenizer, sdxl_clip.StableCascadeClipModel)
|
||||
|
||||
|
||||
class Stable_Cascade_B(Stable_Cascade_C):
|
||||
unet_config = {
|
||||
"stable_cascade_stage": 'b',
|
||||
@ -475,6 +489,7 @@ class Stable_Cascade_B(Stable_Cascade_C):
|
||||
out = model_base.StableCascade_B(self, device=device)
|
||||
return out
|
||||
|
||||
|
||||
class SD15_instructpix2pix(SD15):
|
||||
unet_config = {
|
||||
"context_dim": 768,
|
||||
@ -488,6 +503,7 @@ class SD15_instructpix2pix(SD15):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SD15_instructpix2pix(self, device=device)
|
||||
|
||||
|
||||
class SDXL_instructpix2pix(SDXL):
|
||||
unet_config = {
|
||||
"model_channels": 320,
|
||||
@ -502,6 +518,7 @@ class SDXL_instructpix2pix(SDXL):
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
|
||||
|
||||
|
||||
class SD3(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"in_channels": 16,
|
||||
@ -527,7 +544,6 @@ class SD3(supported_models_base.BASE):
|
||||
clip_l = False
|
||||
clip_g = False
|
||||
t5 = False
|
||||
dtype_t5 = None
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
|
||||
clip_l = True
|
||||
@ -539,6 +555,7 @@ class SD3(supported_models_base.BASE):
|
||||
|
||||
return supported_models_base.ClipTarget(sd3_clip.SD3Tokenizer, sd3_clip.sd3_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, **t5_detect))
|
||||
|
||||
|
||||
class StableAudio(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"audio_model": "dit1.0",
|
||||
@ -559,7 +576,7 @@ class StableAudio(supported_models_base.BASE):
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
for k in list(state_dict.keys()):
|
||||
if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): #These weights are all zero
|
||||
if k.endswith(".cross_attend_norm.beta") or k.endswith(".ff_norm.beta") or k.endswith(".pre_norm.beta"): # These weights are all zero
|
||||
state_dict.pop(k)
|
||||
return state_dict
|
||||
|
||||
@ -570,6 +587,7 @@ class StableAudio(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(sa_t5.SAT5Tokenizer, sa_t5.SAT5Model)
|
||||
|
||||
|
||||
class AuraFlow(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"cond_seq_dim": 2048,
|
||||
@ -593,6 +611,42 @@ class AuraFlow(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(aura_t5.AuraT5Tokenizer, aura_t5.AuraT5Model)
|
||||
|
||||
|
||||
class PixArtAlpha(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "pixart_alpha",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"beta_schedule": "sqrt_linear",
|
||||
"linear_start": 0.0001,
|
||||
"linear_end": 0.02,
|
||||
"timesteps": 1000,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.SD15
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.PixArt(self, device=device)
|
||||
return out.eval()
|
||||
|
||||
def clip_target(self, state_dict=None):
|
||||
if state_dict is None:
|
||||
state_dict = {}
|
||||
return supported_models_base.ClipTarget(pixart_t5.PixArtTokenizer, pixart_t5.PixArtT5XXL)
|
||||
|
||||
|
||||
class PixArtSigma(PixArtAlpha):
|
||||
unet_config = {
|
||||
"image_model": "pixart_sigma",
|
||||
}
|
||||
latent_format = latent_formats.SDXL
|
||||
|
||||
|
||||
class HunyuanDiT(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hydit",
|
||||
@ -619,6 +673,7 @@ class HunyuanDiT(supported_models_base.BASE):
|
||||
def clip_target(self, state_dict={}):
|
||||
return supported_models_base.ClipTarget(hydit.HyditTokenizer, hydit.HyditModel)
|
||||
|
||||
|
||||
class HunyuanDiT1(HunyuanDiT):
|
||||
unet_config = {
|
||||
"image_model": "hydit1",
|
||||
@ -627,10 +682,11 @@ class HunyuanDiT1(HunyuanDiT):
|
||||
unet_extra_config = {}
|
||||
|
||||
sampling_settings = {
|
||||
"linear_start" : 0.00085,
|
||||
"linear_end" : 0.03,
|
||||
"linear_start": 0.00085,
|
||||
"linear_end": 0.03,
|
||||
}
|
||||
|
||||
|
||||
class Flux(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "flux",
|
||||
@ -659,6 +715,7 @@ class Flux(supported_models_base.BASE):
|
||||
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(flux.FluxTokenizer, flux.flux_clip(**t5_detect))
|
||||
|
||||
|
||||
class FluxInpaint(Flux):
|
||||
unet_config = {
|
||||
"image_model": "flux",
|
||||
@ -668,6 +725,7 @@ class FluxInpaint(Flux):
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
|
||||
class FluxSchnell(Flux):
|
||||
unet_config = {
|
||||
"image_model": "flux",
|
||||
@ -683,6 +741,7 @@ class FluxSchnell(Flux):
|
||||
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
||||
return out
|
||||
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "mochi_preview",
|
||||
@ -696,7 +755,7 @@ class GenmoMochi(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Mochi
|
||||
|
||||
memory_usage_factor = 2.0 #TODO
|
||||
memory_usage_factor = 2.0 # TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
@ -714,6 +773,7 @@ class GenmoMochi(supported_models_base.BASE):
|
||||
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(genmo.MochiT5Tokenizer, genmo.mochi_te(**t5_detect))
|
||||
|
||||
|
||||
class LTXV(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ltxv",
|
||||
@ -742,6 +802,58 @@ class LTXV(supported_models_base.BASE):
|
||||
t5_detect = sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(lt.LTXVT5Tokenizer, lt.ltxv_te(**t5_detect))
|
||||
|
||||
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
|
||||
|
||||
class HunyuanVideo(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 7.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.HunyuanVideo
|
||||
|
||||
memory_usage_factor = 2.0 # TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideo(self, device=device)
|
||||
return out
|
||||
|
||||
def process_unet_state_dict(self, state_dict):
|
||||
out_sd = {}
|
||||
for k in list(state_dict.keys()):
|
||||
key_out = k
|
||||
key_out = key_out.replace("txt_in.t_embedder.mlp.0.", "txt_in.t_embedder.in_layer.").replace("txt_in.t_embedder.mlp.2.", "txt_in.t_embedder.out_layer.")
|
||||
key_out = key_out.replace("txt_in.c_embedder.linear_1.", "txt_in.c_embedder.in_layer.").replace("txt_in.c_embedder.linear_2.", "txt_in.c_embedder.out_layer.")
|
||||
key_out = key_out.replace("_mod.linear.", "_mod.lin.").replace("_attn_qkv.", "_attn.qkv.")
|
||||
key_out = key_out.replace("mlp.fc1.", "mlp.0.").replace("mlp.fc2.", "mlp.2.")
|
||||
key_out = key_out.replace("_attn_q_norm.weight", "_attn.norm.query_norm.scale").replace("_attn_k_norm.weight", "_attn.norm.key_norm.scale")
|
||||
key_out = key_out.replace(".q_norm.weight", ".norm.query_norm.scale").replace(".k_norm.weight", ".norm.key_norm.scale")
|
||||
key_out = key_out.replace("_attn_proj.", "_attn.proj.")
|
||||
key_out = key_out.replace(".modulation.linear.", ".modulation.lin.")
|
||||
key_out = key_out.replace("_in.mlp.2.", "_in.out_layer.").replace("_in.mlp.0.", "_in.in_layer.")
|
||||
out_sd[key_out] = state_dict[k]
|
||||
return out_sd
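
A hedged, reduced sketch of the key renaming above, using only two of the replacement rules; the sample keys are illustrative and not verified checkpoint names.

def remap(k: str) -> str:
    k = k.replace("txt_in.t_embedder.mlp.0.", "txt_in.t_embedder.in_layer.")
    k = k.replace("_mod.linear.", "_mod.lin.")
    return k

print(remap("txt_in.t_embedder.mlp.0.weight"))       # txt_in.t_embedder.in_layer.weight
print(remap("double_blocks.0.img_mod.linear.bias"))  # double_blocks.0.img_mod.lin.bias
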
def process_unet_state_dict_for_saving(self, state_dict):
|
||||
replace_prefix = {"": "model.model."}
|
||||
return utils.state_dict_prefix_replace(state_dict, replace_prefix)
|
||||
|
||||
def clip_target(self, state_dict=None):
|
||||
if state_dict is None:
|
||||
state_dict = {}
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(hunyuan_video.HunyuanVideoTokenizer, hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo]
models += [SVD_img2vid]

112  comfy/text_encoders/hunyuan_video.py  Normal file
@ -0,0 +1,112 @@
from comfy import sd1_clip
import comfy.model_management
import comfy.text_encoders.llama
from transformers import LlamaTokenizerFast
import torch
import os


def llama_detect(state_dict, prefix=""):
    out = {}
    t5_key = "{}model.norm.weight".format(prefix)
    if t5_key in state_dict:
        out["dtype_llama"] = state_dict[t5_key].dtype

    scaled_fp8_key = "{}scaled_fp8".format(prefix)
    if scaled_fp8_key in state_dict:
        out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype

    return out


class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)

class LLAMAModel(sd1_clip.SDClipModel):
    def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
        llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
        if llama_scaled_fp8 is not None:
            model_options = model_options.copy()
            model_options["scaled_fp8"] = llama_scaled_fp8

        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)


class HunyuanVideoTokenizer:
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
        self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
        self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"""  # 95 tokens
        self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)

    def tokenize_with_weights(self, text: str, return_word_ids=False):
        out = {}
        out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)

        llama_text = "{}{}".format(self.llama_template, text)
        out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
        return out

    def untokenize(self, token_weight_pair):
        return self.clip_l.untokenize(token_weight_pair)

    def state_dict(self):
        return {}


class HunyuanVideoClipModel(torch.nn.Module):
    def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
        super().__init__()
        dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
        clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
        self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
        self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
        self.dtypes = set([dtype, dtype_llama])

    def set_clip_options(self, options):
        self.clip_l.set_clip_options(options)
        self.llama.set_clip_options(options)

    def reset_clip_options(self):
        self.clip_l.reset_clip_options()
        self.llama.reset_clip_options()

    def encode_token_weights(self, token_weight_pairs):
        token_weight_pairs_l = token_weight_pairs["l"]
        token_weight_pairs_llama = token_weight_pairs["llama"]

        llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)

        template_end = 0
        for i, v in enumerate(token_weight_pairs_llama[0]):
            if v[0] == 128007:  # <|end_header_id|>
                template_end = i

        if llama_out.shape[1] > (template_end + 2):
            if token_weight_pairs_llama[0][template_end + 1][0] == 271:
                template_end += 2
        llama_out = llama_out[:, template_end:]
        llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
        if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
            llama_extra_out.pop("attention_mask")  # attention mask is useless if no masked elements

        l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
        return llama_out, l_pooled, llama_extra_out

    def load_sd(self, sd):
        if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
            return self.clip_l.load_sd(sd)
        else:
            return self.llama.load_sd(sd)


def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
    class HunyuanVideoClipModel_(HunyuanVideoClipModel):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["llama_scaled_fp8"] = llama_scaled_fp8
            super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
    return HunyuanVideoClipModel_

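The template stripping in encode_token_weights above drops the embeddings that correspond to the fixed system prompt, keeping only the user text. A standalone illustration of that slicing logic with made-up token ids and a dummy embedding tensor (nothing here is real tokenizer output):

    import torch

    # pretend tokenized prompt: (token_id, weight) pairs; 128007 marks <|end_header_id|>
    tokens = [(128000, 1.0), (9906, 1.0), (128007, 1.0), (271, 1.0), (1234, 1.0), (5678, 1.0)]
    llama_out = torch.randn(1, len(tokens), 8)  # dummy per-token embeddings

    template_end = 0
    for i, (tok, _w) in enumerate(tokens):
        if tok == 128007:          # last <|end_header_id|> of the template
            template_end = i
    if llama_out.shape[1] > (template_end + 2) and tokens[template_end + 1][0] == 271:
        template_end += 2          # also drop the "\n\n" token that follows it

    print(llama_out[:, template_end:].shape)  # torch.Size([1, 2, 8]) -> only the user tokens remain
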
226 comfy/text_encoders/llama.py (new file)
@@ -0,0 +1,226 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Any

from comfy.ldm.modules.attention import optimized_attention_for_device
import comfy.model_management
import comfy.ldm.common_dit

import comfy.model_management

@dataclass
class Llama2Config:
    vocab_size: int = 128320
    hidden_size: int = 4096
    intermediate_size: int = 14336
    num_hidden_layers: int = 32
    num_attention_heads: int = 32
    num_key_value_heads: int = 8
    max_position_embeddings: int = 8192
    rms_norm_eps: float = 1e-5
    rope_theta: float = 500000.0

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))

    def forward(self, x: torch.Tensor):
        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def precompute_freqs_cis(head_dim, seq_len, theta, device=None):
    theta_numerator = torch.arange(0, head_dim, 2, device=device).float()
    inv_freq = 1.0 / (theta ** (theta_numerator / head_dim))

    position_ids = torch.arange(0, seq_len, device=device).unsqueeze(0)

    inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
    position_ids_expanded = position_ids[:, None, :].float()
    freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
    emb = torch.cat((freqs, freqs), dim=-1)
    cos = emb.cos()
    sin = emb.sin()
    return (cos, sin)


def apply_rope(xq, xk, freqs_cis):
    cos = freqs_cis[0].unsqueeze(1)
    sin = freqs_cis[1].unsqueeze(1)
    q_embed = (xq * cos) + (rotate_half(xq) * sin)
    k_embed = (xk * cos) + (rotate_half(xk) * sin)
    return q_embed, k_embed


class Attention(nn.Module):
    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.hidden_size = config.hidden_size
        self.head_dim = self.hidden_size // self.num_heads

        ops = ops or nn
        self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
        self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
        self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
        self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
        batch_size, seq_length, _ = hidden_states.shape

        xq = self.q_proj(hidden_states)
        xk = self.k_proj(hidden_states)
        xv = self.v_proj(hidden_states)

        xq = xq.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        xk = xk.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)
        xv = xv.view(batch_size, seq_length, self.num_kv_heads, self.head_dim).transpose(1, 2)

        xq, xk = apply_rope(xq, xk, freqs_cis=freqs_cis)

        xk = xk.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        xv = xv.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)

        output = optimized_attention(xq, xk, xv, self.num_heads, mask=attention_mask, skip_reshape=True)
        return self.o_proj(output)

class MLP(nn.Module):
    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        ops = ops or nn
        self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
        self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
        self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

class TransformerBlock(nn.Module):
    def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
        super().__init__()
        self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
        self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[torch.Tensor] = None,
        optimized_attention=None,
    ):
        # Self Attention
        residual = x
        x = self.input_layernorm(x)
        x = self.self_attn(
            hidden_states=x,
            attention_mask=attention_mask,
            freqs_cis=freqs_cis,
            optimized_attention=optimized_attention,
        )
        x = residual + x

        # MLP
        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = residual + x

        return x

class Llama2_(nn.Module):
    def __init__(self, config, device=None, dtype=None, ops=None):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = ops.Embedding(
            config.vocab_size,
            config.hidden_size,
            device=device,
            dtype=dtype
        )
        self.layers = nn.ModuleList([
            TransformerBlock(config, device=device, dtype=dtype, ops=ops)
            for _ in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
        # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)

    def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
        x = self.embed_tokens(x, out_dtype=dtype)

        freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads,
                                         x.shape[1],
                                         self.config.rope_theta,
                                         device=x.device)

        mask = None
        if attention_mask is not None:
            mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
            mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))

        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
        if mask is not None:
            mask += causal_mask
        else:
            mask = causal_mask
        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)

        intermediate = None
        if intermediate_output is not None:
            if intermediate_output < 0:
                intermediate_output = len(self.layers) + intermediate_output

        for i, layer in enumerate(self.layers):
            x = layer(
                x=x,
                attention_mask=mask,
                freqs_cis=freqs_cis,
                optimized_attention=optimized_attention,
            )
            if i == intermediate_output:
                intermediate = x.clone()

        x = self.norm(x)
        if intermediate is not None and final_layer_norm_intermediate:
            intermediate = self.norm(intermediate)

        return x, intermediate


class Llama2(torch.nn.Module):
    def __init__(self, config_dict, dtype, device, operations):
        super().__init__()
        config = Llama2Config(**config_dict)
        self.num_layers = config.num_hidden_layers

        self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
        self.dtype = dtype

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, embeddings):
        self.model.embed_tokens = embeddings

    def forward(self, input_ids, *args, **kwargs):
        return self.model(input_ids, *args, **kwargs)
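The attention block above implements grouped-query attention: 32 query heads share 8 key/value heads, so K and V are expanded with repeat_interleave before the fused attention call, and positions are injected through the rotary helpers defined above. A small shape-only sketch that mirrors those steps without importing the comfy package (tiny sizes, illustrative only):

    import torch

    batch, seq, n_heads, n_kv_heads, head_dim = 1, 5, 32, 8, 4

    xq = torch.randn(batch, n_heads, seq, head_dim)
    xk = torch.randn(batch, n_kv_heads, seq, head_dim)

    # each group of 32 // 8 = 4 query heads attends to one shared kv head
    xk = xk.repeat_interleave(n_heads // n_kv_heads, dim=1)
    print(xk.shape)  # torch.Size([1, 32, 5, 4]) -> now matches the query head count

    # rotary embedding: cos/sin tables broadcast over the head dimension,
    # mirroring precompute_freqs_cis/apply_rope above
    inv_freq = 1.0 / (500000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = torch.arange(seq).float()[:, None] * inv_freq[None, :]
    emb = torch.cat((freqs, freqs), dim=-1)          # (seq, head_dim)
    cos, sin = emb.cos()[None, None], emb.sin()[None, None]

    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    q_rot = (xq * cos) + (rotate_half(xq) * sin)
    print(q_rot.shape)  # torch.Size([1, 32, 5, 4]) -> positions encoded, shape unchanged
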
410579 comfy/text_encoders/llama_tokenizer/tokenizer.json (new file): file diff suppressed because it is too large
2095 comfy/text_encoders/llama_tokenizer/tokenizer_config.json (new file): file diff suppressed because it is too large
42 comfy/text_encoders/pixart_t5.py (new file)
@@ -0,0 +1,42 @@
import os

from comfy import sd1_clip
import comfy.text_encoders.t5
import comfy.text_encoders.sd3_clip
from comfy.sd1_clip import gen_empty_tokens

from transformers import T5TokenizerFast

class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def gen_empty_tokens(self, special_tokens, *args, **kwargs):
        # PixArt expects the negative to be all pad tokens
        special_tokens = special_tokens.copy()
        special_tokens.pop("end")
        return gen_empty_tokens(special_tokens, *args, **kwargs)

class PixArtT5XXL(sd1_clip.SD1ClipModel):
    def __init__(self, device="cpu", dtype=None, model_options={}):
        super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options)

class T5XXLTokenizer(sd1_clip.SDTokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
        super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1)  # no padding

class PixArtTokenizer(sd1_clip.SD1Tokenizer):
    def __init__(self, embedding_directory=None, tokenizer_data={}):
        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)

def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
    class PixArtTEModel_(PixArtT5XXL):
        def __init__(self, device="cpu", dtype=None, model_options={}):
            if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                model_options = model_options.copy()
                model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
            if dtype is None:
                dtype = dtype_t5
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return PixArtTEModel_

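pixart_te above follows the same pattern as hunyuan_video_clip: options detected from the checkpoint (dtype, scaled-fp8 flag) are baked into a subclass via a closure, so the caller only supplies device and dtype at construction time. A standalone sketch of that closure pattern with a stand-in base class (DummyTE and make_te below are illustrative, not repo code):

    # the same "bake detected options into a class" pattern as pixart_te/hunyuan_video_clip,
    # shown with a stand-in base class so it runs without the comfy package
    class DummyTE:
        def __init__(self, device="cpu", dtype=None, model_options={}):
            self.device, self.dtype, self.model_options = device, dtype, model_options

    def make_te(dtype_t5=None, t5xxl_scaled_fp8=None):
        class TEModel_(DummyTE):
            def __init__(self, device="cpu", dtype=None, model_options={}):
                if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                    model_options = dict(model_options, t5xxl_scaled_fp8=t5xxl_scaled_fp8)
                if dtype is None:
                    dtype = dtype_t5
                super().__init__(device=device, dtype=dtype, model_options=model_options)
        return TEModel_

    TE = make_te(dtype_t5="float16")   # detection result baked in
    te = TE(device="cpu")              # caller only supplies runtime placement
    print(te.dtype, te.model_options)  # float16 {}
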
@@ -181,7 +181,6 @@ class T5LayerSelfAttention(torch.nn.Module):
        # self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
-        normed_hidden_states = self.layer_norm(x)
        output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention)
        # x = x + self.dropout(attention_output)
        x += output

151 comfy/utils.py
@@ -36,6 +36,8 @@ import numpy as np
import safetensors.torch
import torch
from PIL import Image
+from einops import rearrange
+from torch.nn.functional import interpolate
from tqdm import tqdm

from . import checkpoint_pickle, interruption
@@ -97,7 +99,12 @@ def load_torch_file(ckpt: str, safe_load=False, device=None):
        if "state_dict" in pl_sd:
            sd = pl_sd["state_dict"]
        else:
-            sd = pl_sd
+            if len(pl_sd) == 1:
+                key = list(pl_sd.keys())[0]
+                sd = pl_sd[key]
+                if not isinstance(sd, dict):
+                    sd = pl_sd
+            else:
+                sd = pl_sd
    except UnpicklingError as exc_info:
        try:
            # wrong extension is most likely, try to load as safetensors anyway
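The new branch in load_torch_file above unwraps .pt checkpoints that nest the weights under a single top-level key (for example a hypothetical {"module": {...}} wrapper), but only when the inner value is itself a dict. A minimal illustration of just that rule, with plain dicts and no file I/O:

    def unwrap(pl_sd):
        # mirror of the added logic: unwrap a lone top-level key only if it holds a dict
        if "state_dict" in pl_sd:
            return pl_sd["state_dict"]
        if len(pl_sd) == 1:
            sd = pl_sd[next(iter(pl_sd))]
            return sd if isinstance(sd, dict) else pl_sd
        return pl_sd

    print(unwrap({"module": {"w": 1}}))  # {'w': 1}  (unwrapped)
    print(unwrap({"w": 1}))              # {'w': 1}  (single key but not a dict value, kept as-is)
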
@@ -452,6 +459,80 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""):
    return key_map


PIXART_MAP_BASIC = {
    ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"),
    ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"),
    ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"),
    ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"),
    ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"),
    ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"),
    ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"),
    ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"),
    ("x_embedder.proj.weight", "pos_embed.proj.weight"),
    ("x_embedder.proj.bias", "pos_embed.proj.bias"),
    ("y_embedder.y_embedding", "caption_projection.y_embedding"),
    ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"),
    ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"),
    ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"),
    ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"),
    ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"),
    ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"),
    ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"),
    ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"),
    ("t_block.1.weight", "adaln_single.linear.weight"),
    ("t_block.1.bias", "adaln_single.linear.bias"),
    ("final_layer.linear.weight", "proj_out.weight"),
    ("final_layer.linear.bias", "proj_out.bias"),
    ("final_layer.scale_shift_table", "scale_shift_table"),
}

PIXART_MAP_BLOCK = {
    ("scale_shift_table", "scale_shift_table"),
    ("attn.proj.weight", "attn1.to_out.0.weight"),
    ("attn.proj.bias", "attn1.to_out.0.bias"),
    ("mlp.fc1.weight", "ff.net.0.proj.weight"),
    ("mlp.fc1.bias", "ff.net.0.proj.bias"),
    ("mlp.fc2.weight", "ff.net.2.weight"),
    ("mlp.fc2.bias", "ff.net.2.bias"),
    ("cross_attn.proj.weight", "attn2.to_out.0.weight"),
    ("cross_attn.proj.bias", "attn2.to_out.0.bias"),
}


def pixart_to_diffusers(mmdit_config, output_prefix=""):
    key_map = {}

    depth = mmdit_config.get("depth", 0)
    offset = mmdit_config.get("hidden_size", 1152)

    for i in range(depth):
        block_from = "transformer_blocks.{}".format(i)
        block_to = "{}blocks.{}".format(output_prefix, i)

        for end in ("weight", "bias"):
            s = "{}.attn1.".format(block_from)
            qkv = "{}.attn.qkv.{}".format(block_to, end)
            key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset))
            key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset))
            key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset))

            s = "{}.attn2.".format(block_from)
            q = "{}.cross_attn.q_linear.{}".format(block_to, end)
            kv = "{}.cross_attn.kv_linear.{}".format(block_to, end)

            key_map["{}to_q.{}".format(s, end)] = q
            key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset))
            key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset))

        for k in PIXART_MAP_BLOCK:
            key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0])

    for k in PIXART_MAP_BASIC:
        key_map[k[1]] = "{}{}".format(output_prefix, k[0])

    return key_map


def auraflow_to_diffusers(mmdit_config, output_prefix=""):
    n_double_layers = mmdit_config.get("n_double_layers", 0)
    n_layers = mmdit_config.get("n_layers", 0)

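In the PixArt mapping above, attention keys map to a tuple of the form (target_key, (dim, offset, size)), which tells the checkpoint converter to copy the separate diffusers to_q/to_k/to_v tensors into slices of a single fused qkv tensor. A small sketch of how such a tuple would be consumed; apply_slice is a hypothetical helper, not a function in comfy.utils:

    import torch

    hidden = 8  # stands in for the real hidden_size (1152 for PixArt)
    fused = torch.zeros(3 * hidden, hidden)          # destination qkv weight

    def apply_slice(dest, src, rule):
        dim, offset, size = rule
        dest.narrow(dim, offset, size).copy_(src)    # write src into the mapped slice

    to_q, to_k, to_v = (torch.full((hidden, hidden), float(i)) for i in (1, 2, 3))
    apply_slice(fused, to_q, (0, 0, hidden))
    apply_slice(fused, to_k, (0, hidden, hidden))
    apply_slice(fused, to_v, (0, 2 * hidden, hidden))
    print(fused[0, 0].item(), fused[hidden, 0].item(), fused[2 * hidden, 0].item())  # 1.0 2.0 3.0
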
@@ -830,7 +911,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):


@torch.inference_mode()
-def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
+def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, pbar=None):
    dims = len(tile)

    if not (isinstance(upscale_amount, (tuple, list))):
@@ -846,10 +927,22 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
        else:
            return up * val

+    def get_downscale(dim, val):
+        up = upscale_amount[dim]
+        if callable(up):
+            return up(val)
+        else:
+            return val / up
+
+    if downscale:
+        get_scale = get_downscale
+    else:
+        get_scale = get_upscale
+
    def mult_list_upscale(a):
        out = []
        for i in range(len(a)):
-            out.append(round(get_upscale(i, a[i])))
+            out.append(round(get_scale(i, a[i])))
        return out

    output = torch.empty([samples.shape[0], out_channels] + mult_list_upscale(samples.shape[2:]), device=output_device)
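With the downscale flag added above, the same tiling helper can also drive functions that shrink their input (a VAE encode, for instance): tile positions and the output buffer are divided by upscale_amount instead of multiplied. A hedged usage sketch, assuming the comfy package from this repo is importable and using a trivial 2x average pool as a stand-in for the real model call:

    import torch
    import torch.nn.functional as F
    import comfy.utils

    samples = torch.randn(1, 3, 64, 64)

    def encode(x):
        # stand-in "encoder" that halves spatial size; out_channels must match its output
        return F.avg_pool2d(x, 2)

    out = comfy.utils.tiled_scale_multidim(
        samples, encode, tile=(32, 32), overlap=4,
        upscale_amount=2, out_channels=3, downscale=True)
    print(out.shape)  # torch.Size([1, 3, 32, 32])
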
@@ -873,16 +966,18 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
            upscaled = []

            for d in range(dims):
-                pos = max(0, min(s.shape[d + 2] - (overlap[d] + 1), it[d]))
+                pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                l = min(tile[d], s.shape[d + 2] - pos)
                s_in = s_in.narrow(d + 2, pos, l)
-                upscaled.append(round(get_upscale(d, pos)))
+                upscaled.append(round(get_scale(d, pos)))

            ps = function(s_in).to(output_device)
            mask = torch.ones_like(ps)

            for d in range(2, dims + 2):
-                feather = round(get_upscale(d - 2, overlap[d - 2]))
+                feather = round(get_scale(d - 2, overlap[d - 2]))
+                if feather >= mask.shape[d]:
+                    continue
                for t in range(feather):
                    a = (t + 1) / feather
                    mask.narrow(d, t, 1).mul_(a)
@@ -905,7 +1000,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am


def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", pbar=None):
-    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap, upscale_amount, out_channels, output_device, pbar)
+    return tiled_scale_multidim(samples, function, (tile_y, tile_x), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=output_device, pbar=pbar)


def _progress_bar_update(value: float, total: float, preview_image_or_data: Optional[Any] = None, client_id: Optional[str] = None, server: Optional[ExecutorToClientProgress] = None):
@@ -1072,3 +1167,45 @@ def reshape_mask(input_mask, output_shape):
    mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:, :output_shape[1]]
    mask = repeat_to_batch_size(mask, output_shape[0])
    return mask


def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out):
    hi, wi = img_size_in
    ho, wo = img_size_out
    # if it's already the correct size, no need to do anything
    if (hi, wi) == (ho, wo):
        return mask
    if mask.ndim == 2:
        mask = mask.unsqueeze(0)
    if mask.ndim != 3:
        raise ValueError(f"Got a mask of shape {list(mask.shape)}, expected [b, q, k] or [q, k]")
    txt_tokens = mask.shape[1] - (hi * wi)
    # quadrants of the mask
    txt_to_txt = mask[:, :txt_tokens, :txt_tokens]
    txt_to_img = mask[:, :txt_tokens, txt_tokens:]
    img_to_img = mask[:, txt_tokens:, txt_tokens:]
    img_to_txt = mask[:, txt_tokens:, :txt_tokens]

    # convert to 1d x 2d, interpolate, then back to 1d x 1d
    txt_to_img = rearrange(txt_to_img, "b t (h w) -> b t h w", h=hi, w=wi)
    txt_to_img = interpolate(txt_to_img, size=img_size_out, mode="bilinear")
    txt_to_img = rearrange(txt_to_img, "b t h w -> b t (h w)")
    # this one is hard because we have to do it twice
    # convert to 1d x 2d, interpolate, then to 2d x 1d, interpolate, then 1d x 1d
    img_to_img = rearrange(img_to_img, "b hw (h w) -> b hw h w", h=hi, w=wi)
    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
    img_to_img = rearrange(img_to_img, "b (hk wk) hq wq -> b (hq wq) hk wk", hk=hi, wk=wi)
    img_to_img = interpolate(img_to_img, size=img_size_out, mode="bilinear")
    img_to_img = rearrange(img_to_img, "b (hq wq) hk wk -> b (hk wk) (hq wq)", hq=ho, wq=wo)
    # convert to 2d x 1d, interpolate, then back to 1d x 1d
    img_to_txt = rearrange(img_to_txt, "b (h w) t -> b t h w", h=hi, w=wi)
    img_to_txt = interpolate(img_to_txt, size=img_size_out, mode="bilinear")
    img_to_txt = rearrange(img_to_txt, "b t h w -> b (h w) t")

    # reassemble the mask from blocks
    out = torch.cat([
        torch.cat([txt_to_txt, txt_to_img], dim=2),
        torch.cat([img_to_txt, img_to_img], dim=2)],
        dim=1
    )
    return out

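upscale_dit_mask above resizes an attention mask whose rows and columns are laid out as text tokens followed by hi*wi image tokens: it splits the mask into four quadrants and bilinearly interpolates only the image axes. A small shape check under that layout assumption (again assuming this repo's comfy package is importable):

    import torch
    import comfy.utils

    txt, (hi, wi), (ho, wo) = 4, (8, 8), (16, 16)
    mask = torch.ones(1, txt + hi * wi, txt + hi * wi)

    out = comfy.utils.upscale_dit_mask(mask, (hi, wi), (ho, wo))
    print(out.shape)  # torch.Size([1, 260, 260]) -> 4 text tokens + 16*16 image tokens
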
58 comfy/web/assets/DownloadGitView-B3f7KHY3.js (generated, vendored, new file)
@@ -0,0 +1,58 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { a as defineComponent, f as openBlock, g as createElementBlock, A as createBaseVNode, a8 as toDisplayString, h as createVNode, z as unref, D as script, bU as useRouter } from "./index-DIU5yZe9.js";
|
||||
const _hoisted_1 = { class: "font-sans w-screen h-screen mx-0 grid place-items-center justify-center items-center text-neutral-900 bg-neutral-300 pointer-events-auto" };
|
||||
const _hoisted_2 = { class: "col-start-1 h-screen row-start-1 place-content-center mx-auto overflow-y-auto" };
|
||||
const _hoisted_3 = { class: "max-w-screen-sm flex flex-col gap-8 p-8 bg-[url('/assets/images/Git-Logo-White.svg')] bg-no-repeat bg-right-top bg-origin-padding" };
|
||||
const _hoisted_4 = { class: "mt-24 text-4xl font-bold text-red-500" };
|
||||
const _hoisted_5 = { class: "space-y-4" };
|
||||
const _hoisted_6 = { class: "text-xl" };
|
||||
const _hoisted_7 = { class: "text-xl" };
|
||||
const _hoisted_8 = { class: "text-m" };
|
||||
const _hoisted_9 = { class: "flex gap-4 flex-row-reverse" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "DownloadGitView",
|
||||
setup(__props) {
|
||||
const openGitDownloads = /* @__PURE__ */ __name(() => {
|
||||
window.open("https://git-scm.com/downloads/", "_blank");
|
||||
}, "openGitDownloads");
|
||||
const skipGit = /* @__PURE__ */ __name(() => {
|
||||
console.warn("pushing");
|
||||
const router = useRouter();
|
||||
router.push("install");
|
||||
}, "skipGit");
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createElementBlock("div", _hoisted_1, [
|
||||
createBaseVNode("div", _hoisted_2, [
|
||||
createBaseVNode("div", _hoisted_3, [
|
||||
createBaseVNode("h1", _hoisted_4, toDisplayString(_ctx.$t("downloadGit.title")), 1),
|
||||
createBaseVNode("div", _hoisted_5, [
|
||||
createBaseVNode("p", _hoisted_6, toDisplayString(_ctx.$t("downloadGit.message")), 1),
|
||||
createBaseVNode("p", _hoisted_7, toDisplayString(_ctx.$t("downloadGit.instructions")), 1),
|
||||
createBaseVNode("p", _hoisted_8, toDisplayString(_ctx.$t("downloadGit.warning")), 1)
|
||||
]),
|
||||
createBaseVNode("div", _hoisted_9, [
|
||||
createVNode(unref(script), {
|
||||
label: _ctx.$t("downloadGit.gitWebsite"),
|
||||
icon: "pi pi-external-link",
|
||||
"icon-pos": "right",
|
||||
onClick: openGitDownloads,
|
||||
severity: "primary"
|
||||
}, null, 8, ["label"]),
|
||||
createVNode(unref(script), {
|
||||
label: _ctx.$t("downloadGit.skip"),
|
||||
icon: "pi pi-exclamation-triangle",
|
||||
onClick: skipGit,
|
||||
severity: "secondary"
|
||||
}, null, 8, ["label"])
|
||||
])
|
||||
])
|
||||
])
|
||||
]);
|
||||
};
|
||||
}
|
||||
});
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
//# sourceMappingURL=DownloadGitView-B3f7KHY3.js.map
|
||||
1 comfy/web/assets/DownloadGitView-B3f7KHY3.js.map (generated, vendored, new file)
@@ -0,0 +1 @@
|
||||
{"version":3,"file":"DownloadGitView-B3f7KHY3.js","sources":["../../src/views/DownloadGitView.vue"],"sourcesContent":["<template>\n <div\n class=\"font-sans w-screen h-screen mx-0 grid place-items-center justify-center items-center text-neutral-900 bg-neutral-300 pointer-events-auto\"\n >\n <div\n class=\"col-start-1 h-screen row-start-1 place-content-center mx-auto overflow-y-auto\"\n >\n <div\n class=\"max-w-screen-sm flex flex-col gap-8 p-8 bg-[url('/assets/images/Git-Logo-White.svg')] bg-no-repeat bg-right-top bg-origin-padding\"\n >\n <!-- Header -->\n <h1 class=\"mt-24 text-4xl font-bold text-red-500\">\n {{ $t('downloadGit.title') }}\n </h1>\n\n <!-- Message -->\n <div class=\"space-y-4\">\n <p class=\"text-xl\">\n {{ $t('downloadGit.message') }}\n </p>\n <p class=\"text-xl\">\n {{ $t('downloadGit.instructions') }}\n </p>\n <p class=\"text-m\">\n {{ $t('downloadGit.warning') }}\n </p>\n </div>\n\n <!-- Actions -->\n <div class=\"flex gap-4 flex-row-reverse\">\n <Button\n :label=\"$t('downloadGit.gitWebsite')\"\n icon=\"pi pi-external-link\"\n icon-pos=\"right\"\n @click=\"openGitDownloads\"\n severity=\"primary\"\n />\n <Button\n :label=\"$t('downloadGit.skip')\"\n icon=\"pi pi-exclamation-triangle\"\n @click=\"skipGit\"\n severity=\"secondary\"\n />\n </div>\n </div>\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport { useRouter } from 'vue-router'\n\nconst openGitDownloads = () => {\n window.open('https://git-scm.com/downloads/', '_blank')\n}\n\nconst skipGit = () => {\n console.warn('pushing')\n const router = useRouter()\n router.push('install')\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;AAqDA,UAAM,mBAAmB,6BAAM;AACtB,aAAA,KAAK,kCAAkC,QAAQ;AAAA,IAAA,GAD/B;AAIzB,UAAM,UAAU,6BAAM;AACpB,cAAQ,KAAK,SAAS;AACtB,YAAM,SAAS;AACf,aAAO,KAAK,SAAS;AAAA,IAAA,GAHP;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
@ -1,8 +1,14 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
<<<<<<<< HEAD:comfy/web/assets/ExtensionPanel-vSDJrNxh.js
|
||||
import { a as defineComponent, r as ref, cg as FilterMatchMode, ck as useExtensionStore, u as useSettingStore, o as onMounted, q as computed, f as openBlock, x as createBlock, y as withCtx, h as createVNode, ch as SearchBox, z as unref, bS as script, A as createBaseVNode, g as createElementBlock, Q as renderList, a6 as toDisplayString, ax as createTextVNode, P as Fragment, D as script$1, i as createCommentVNode, c1 as script$3, ci as _sfc_main$1 } from "./index-BQYg0VNJ.js";
|
||||
import { s as script$2, a as script$4 } from "./index-CMsGQEqY.js";
|
||||
import "./index-DJqEjTnE.js";
|
||||
========
|
||||
import { a as defineComponent, r as ref, ck as FilterMatchMode, co as useExtensionStore, u as useSettingStore, o as onMounted, q as computed, f as openBlock, x as createBlock, y as withCtx, h as createVNode, cl as SearchBox, z as unref, bW as script, A as createBaseVNode, g as createElementBlock, Q as renderList, a8 as toDisplayString, ay as createTextVNode, P as Fragment, D as script$1, i as createCommentVNode, c5 as script$3, cm as _sfc_main$1 } from "./index-DIU5yZe9.js";
|
||||
import { s as script$2, a as script$4 } from "./index-D3u7l7ha.js";
|
||||
import "./index-d698Brhb.js";
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/ExtensionPanel-ByeZ01RF.js
|
||||
const _hoisted_1 = { class: "flex justify-end" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "ExtensionPanel",
|
||||
@ -114,4 +120,8 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
<<<<<<<< HEAD:comfy/web/assets/ExtensionPanel-vSDJrNxh.js
|
||||
//# sourceMappingURL=ExtensionPanel-vSDJrNxh.js.map
|
||||
========
|
||||
//# sourceMappingURL=ExtensionPanel-ByeZ01RF.js.map
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/ExtensionPanel-ByeZ01RF.js
|
||||
1 comfy/web/assets/ExtensionPanel-ByeZ01RF.js.map (generated, vendored, new file)
@@ -0,0 +1 @@
|
||||
{"version":3,"file":"ExtensionPanel-ByeZ01RF.js","sources":["../../src/components/dialog/content/setting/ExtensionPanel.vue"],"sourcesContent":["<template>\n <PanelTemplate value=\"Extension\" class=\"extension-panel\">\n <template #header>\n <SearchBox\n v-model=\"filters['global'].value\"\n :placeholder=\"$t('g.searchExtensions') + '...'\"\n />\n <Message v-if=\"hasChanges\" severity=\"info\" pt:text=\"w-full\">\n <ul>\n <li v-for=\"ext in changedExtensions\" :key=\"ext.name\">\n <span>\n {{ extensionStore.isExtensionEnabled(ext.name) ? '[-]' : '[+]' }}\n </span>\n {{ ext.name }}\n </li>\n </ul>\n <div class=\"flex justify-end\">\n <Button\n :label=\"$t('g.reloadToApplyChanges')\"\n @click=\"applyChanges\"\n outlined\n severity=\"danger\"\n />\n </div>\n </Message>\n </template>\n <DataTable\n :value=\"extensionStore.extensions\"\n stripedRows\n size=\"small\"\n :filters=\"filters\"\n >\n <Column field=\"name\" :header=\"$t('g.extensionName')\" sortable></Column>\n <Column\n :pt=\"{\n bodyCell: 'flex items-center justify-end'\n }\"\n >\n <template #body=\"slotProps\">\n <ToggleSwitch\n v-model=\"editingEnabledExtensions[slotProps.data.name]\"\n @change=\"updateExtensionStatus\"\n />\n </template>\n </Column>\n </DataTable>\n </PanelTemplate>\n</template>\n\n<script setup lang=\"ts\">\nimport { ref, computed, onMounted } from 'vue'\nimport { useExtensionStore } from '@/stores/extensionStore'\nimport { useSettingStore } from '@/stores/settingStore'\nimport DataTable from 'primevue/datatable'\nimport Column from 'primevue/column'\nimport ToggleSwitch from 'primevue/toggleswitch'\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\nimport { FilterMatchMode } from '@primevue/core/api'\nimport PanelTemplate from './PanelTemplate.vue'\nimport SearchBox from '@/components/common/SearchBox.vue'\n\nconst filters = ref({\n global: { value: '', matchMode: FilterMatchMode.CONTAINS }\n})\n\nconst extensionStore = useExtensionStore()\nconst settingStore = useSettingStore()\n\nconst editingEnabledExtensions = ref<Record<string, boolean>>({})\n\nonMounted(() => {\n extensionStore.extensions.forEach((ext) => {\n editingEnabledExtensions.value[ext.name] =\n extensionStore.isExtensionEnabled(ext.name)\n })\n})\n\nconst changedExtensions = computed(() => {\n return extensionStore.extensions.filter(\n (ext) =>\n editingEnabledExtensions.value[ext.name] !==\n extensionStore.isExtensionEnabled(ext.name)\n )\n})\n\nconst hasChanges = computed(() => {\n return changedExtensions.value.length > 0\n})\n\nconst updateExtensionStatus = () => {\n const editingDisabledExtensionNames = Object.entries(\n editingEnabledExtensions.value\n )\n .filter(([_, enabled]) => !enabled)\n .map(([name]) => name)\n\n settingStore.set('Comfy.Extension.Disabled', [\n ...extensionStore.inactiveDisabledExtensionNames,\n ...editingDisabledExtensionNames\n ])\n}\n\nconst applyChanges = () => {\n // Refresh the page to apply changes\n 
window.location.reload()\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;AA8DA,UAAM,UAAU,IAAI;AAAA,MAClB,QAAQ,EAAE,OAAO,IAAI,WAAW,gBAAgB,SAAS;AAAA,IAAA,CAC1D;AAED,UAAM,iBAAiB;AACvB,UAAM,eAAe;AAEf,UAAA,2BAA2B,IAA6B,CAAA,CAAE;AAEhE,cAAU,MAAM;AACC,qBAAA,WAAW,QAAQ,CAAC,QAAQ;AACzC,iCAAyB,MAAM,IAAI,IAAI,IACrC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA,CAC7C;AAAA,IAAA,CACF;AAEK,UAAA,oBAAoB,SAAS,MAAM;AACvC,aAAO,eAAe,WAAW;AAAA,QAC/B,CAAC,QACC,yBAAyB,MAAM,IAAI,IAAI,MACvC,eAAe,mBAAmB,IAAI,IAAI;AAAA,MAAA;AAAA,IAC9C,CACD;AAEK,UAAA,aAAa,SAAS,MAAM;AACzB,aAAA,kBAAkB,MAAM,SAAS;AAAA,IAAA,CACzC;AAED,UAAM,wBAAwB,6BAAM;AAClC,YAAM,gCAAgC,OAAO;AAAA,QAC3C,yBAAyB;AAAA,MAExB,EAAA,OAAO,CAAC,CAAC,GAAG,OAAO,MAAM,CAAC,OAAO,EACjC,IAAI,CAAC,CAAC,IAAI,MAAM,IAAI;AAEvB,mBAAa,IAAI,4BAA4B;AAAA,QAC3C,GAAG,eAAe;AAAA,QAClB,GAAG;AAAA,MAAA,CACJ;AAAA,IAAA,GAV2B;AAa9B,UAAM,eAAe,6BAAM;AAEzB,aAAO,SAAS;IAAO,GAFJ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
48 comfy/web/assets/GraphView-DzvxEUM8.css → comfy/web/assets/GraphView-B3TpSwhZ.css (generated, vendored)
@@ -125,7 +125,11 @@
|
||||
align-items: flex-start !important;
|
||||
}
|
||||
|
||||
<<<<<<<< HEAD:comfy/web/assets/GraphView-DzvxEUM8.css
|
||||
.node-tooltip[data-v-259081e0] {
|
||||
========
|
||||
.node-tooltip[data-v-9ecc8adc] {
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/GraphView-B3TpSwhZ.css
|
||||
background: var(--comfy-input-bg);
|
||||
border-radius: 5px;
|
||||
box-shadow: 0 0 5px rgba(0, 0, 0, 0.4);
|
||||
@ -153,22 +157,30 @@
|
||||
border-radius: 0;
|
||||
}
|
||||
|
||||
.comfy-menu-hamburger[data-v-2ddd26e8] {
|
||||
.comfy-menu-hamburger[data-v-962c4073] {
|
||||
pointer-events: auto;
|
||||
position: fixed;
|
||||
z-index: 9999;
|
||||
}
|
||||
|
||||
<<<<<<<< HEAD:comfy/web/assets/GraphView-DzvxEUM8.css
|
||||
[data-v-012040ee] .p-togglebutton::before {
|
||||
display: none
|
||||
}
|
||||
[data-v-012040ee] .p-togglebutton {
|
||||
========
|
||||
[data-v-4cb762cb] .p-togglebutton::before {
|
||||
display: none
|
||||
}
|
||||
[data-v-4cb762cb] .p-togglebutton {
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/GraphView-B3TpSwhZ.css
|
||||
position: relative;
|
||||
flex-shrink: 0;
|
||||
border-radius: 0px;
|
||||
background-color: transparent;
|
||||
padding: 0px
|
||||
}
|
||||
<<<<<<<< HEAD:comfy/web/assets/GraphView-DzvxEUM8.css
|
||||
[data-v-012040ee] .p-togglebutton.p-togglebutton-checked {
|
||||
border-bottom-width: 2px;
|
||||
border-bottom-color: var(--p-button-text-primary-color)
|
||||
@ -177,6 +189,16 @@
|
||||
visibility: visible
|
||||
}
|
||||
.status-indicator[data-v-012040ee] {
|
||||
========
|
||||
[data-v-4cb762cb] .p-togglebutton.p-togglebutton-checked {
|
||||
border-bottom-width: 2px;
|
||||
border-bottom-color: var(--p-button-text-primary-color)
|
||||
}
|
||||
[data-v-4cb762cb] .p-togglebutton-checked .close-button,[data-v-4cb762cb] .p-togglebutton:hover .close-button {
|
||||
visibility: visible
|
||||
}
|
||||
.status-indicator[data-v-4cb762cb] {
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/GraphView-B3TpSwhZ.css
|
||||
position: absolute;
|
||||
font-weight: 700;
|
||||
font-size: 1.5rem;
|
||||
@ -184,10 +206,17 @@
|
||||
left: 50%;
|
||||
transform: translate(-50%, -50%)
|
||||
}
|
||||
<<<<<<<< HEAD:comfy/web/assets/GraphView-DzvxEUM8.css
|
||||
[data-v-012040ee] .p-togglebutton:hover .status-indicator {
|
||||
display: none
|
||||
}
|
||||
[data-v-012040ee] .p-togglebutton .close-button {
|
||||
========
|
||||
[data-v-4cb762cb] .p-togglebutton:hover .status-indicator {
|
||||
display: none
|
||||
}
|
||||
[data-v-4cb762cb] .p-togglebutton .close-button {
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/GraphView-B3TpSwhZ.css
|
||||
visibility: hidden
|
||||
}
|
||||
|
||||
@ -238,7 +267,11 @@
|
||||
display: none;
|
||||
}
|
||||
|
||||
<<<<<<<< HEAD:comfy/web/assets/GraphView-DzvxEUM8.css
|
||||
.comfyui-menu[data-v-51020bc7] {
|
||||
========
|
||||
.comfyui-menu[data-v-d792da31] {
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/GraphView-B3TpSwhZ.css
|
||||
width: 100vw;
|
||||
background: var(--comfy-menu-bg);
|
||||
color: var(--fg-color);
|
||||
@ -251,6 +284,7 @@
|
||||
grid-column: 1/-1;
|
||||
max-height: 90vh;
|
||||
}
|
||||
<<<<<<<< HEAD:comfy/web/assets/GraphView-DzvxEUM8.css
|
||||
.comfyui-menu.dropzone[data-v-51020bc7] {
|
||||
background: var(--p-highlight-background);
|
||||
}
|
||||
@ -261,6 +295,18 @@
|
||||
line-height: revert;
|
||||
}
|
||||
.comfyui-logo[data-v-51020bc7] {
|
||||
========
|
||||
.comfyui-menu.dropzone[data-v-d792da31] {
|
||||
background: var(--p-highlight-background);
|
||||
}
|
||||
.comfyui-menu.dropzone-active[data-v-d792da31] {
|
||||
background: var(--p-highlight-background-focus);
|
||||
}
|
||||
[data-v-d792da31] .p-menubar-item-label {
|
||||
line-height: revert;
|
||||
}
|
||||
.comfyui-logo[data-v-d792da31] {
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/GraphView-B3TpSwhZ.css
|
||||
font-size: 1.2em;
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
267 comfy/web/assets/GraphView-BAZmiysO.js → comfy/web/assets/GraphView-BWxgNrh6.js (generated, vendored): file diff suppressed because one or more lines are too long
1 comfy/web/assets/GraphView-BWxgNrh6.js.map (generated, vendored, new file): file diff suppressed because one or more lines are too long
111 comfy/web/assets/InstallView-BDCw9SUg.js → comfy/web/assets/InstallView-DbHtR5YG.js (generated, vendored): file diff suppressed because one or more lines are too long
1 comfy/web/assets/InstallView-DbHtR5YG.js.map (generated, vendored, new file): file diff suppressed because one or more lines are too long
8 comfy/web/assets/KeybindingPanel-C3wT8hYZ.css (generated, vendored, new file)
@@ -0,0 +1,8 @@
|
||||
|
||||
[data-v-c20ad403] .p-datatable-tbody > tr > td {
|
||||
padding: 0.25rem;
|
||||
min-height: 2rem
|
||||
}
|
||||
[data-v-c20ad403] .p-datatable-row-selected .actions,[data-v-c20ad403] .p-datatable-selectable-row:hover .actions {
|
||||
visibility: visible
|
||||
}
|
||||
@ -1,8 +1,14 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
<<<<<<<< HEAD:comfy/web/assets/KeybindingPanel-DfPGcDsG.js
|
||||
import { a as defineComponent, q as computed, f as openBlock, g as createElementBlock, P as Fragment, Q as renderList, h as createVNode, y as withCtx, ax as createTextVNode, a6 as toDisplayString, z as unref, aB as script, i as createCommentVNode, r as ref, cg as FilterMatchMode, O as useKeybindingStore, F as useCommandStore, aK as watchEffect, bi as useToast, t as resolveDirective, x as createBlock, ch as SearchBox, A as createBaseVNode, D as script$2, ap as script$4, bm as withModifiers, bS as script$5, aH as script$6, v as withDirectives, ci as _sfc_main$2, ca as KeyComboImpl, cj as KeybindingImpl, _ as _export_sfc } from "./index-BQYg0VNJ.js";
|
||||
import { s as script$1, a as script$3 } from "./index-CMsGQEqY.js";
|
||||
import "./index-DJqEjTnE.js";
|
||||
========
|
||||
import { a as defineComponent, q as computed, f as openBlock, g as createElementBlock, P as Fragment, Q as renderList, h as createVNode, y as withCtx, ay as createTextVNode, a8 as toDisplayString, z as unref, aC as script, i as createCommentVNode, r as ref, ck as FilterMatchMode, O as useKeybindingStore, F as useCommandStore, I as useI18n, aS as normalizeI18nKey, aL as watchEffect, bn as useToast, t as resolveDirective, x as createBlock, cl as SearchBox, A as createBaseVNode, D as script$2, aq as script$4, br as withModifiers, bW as script$5, aI as script$6, v as withDirectives, cm as _sfc_main$2, R as pushScopeId, U as popScopeId, ce as KeyComboImpl, cn as KeybindingImpl, _ as _export_sfc } from "./index-DIU5yZe9.js";
|
||||
import { s as script$1, a as script$3 } from "./index-D3u7l7ha.js";
|
||||
import "./index-d698Brhb.js";
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/KeybindingPanel-DC2AxNNa.js
|
||||
const _hoisted_1$1 = {
|
||||
key: 0,
|
||||
class: "px-2"
|
||||
@ -35,6 +41,10 @@ const _sfc_main$1 = /* @__PURE__ */ defineComponent({
|
||||
};
|
||||
}
|
||||
});
|
||||
<<<<<<<< HEAD:comfy/web/assets/KeybindingPanel-DfPGcDsG.js
|
||||
========
|
||||
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-c20ad403"), n = n(), popScopeId(), n), "_withScopeId");
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/KeybindingPanel-DC2AxNNa.js
|
||||
const _hoisted_1 = { class: "actions invisible flex flex-row" };
|
||||
const _hoisted_2 = ["title"];
|
||||
const _hoisted_3 = { key: 1 };
|
||||
@ -46,9 +56,11 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
});
|
||||
const keybindingStore = useKeybindingStore();
|
||||
const commandStore = useCommandStore();
|
||||
const { t } = useI18n();
|
||||
const commandsData = computed(() => {
|
||||
return Object.values(commandStore.commands).map((command) => ({
|
||||
id: command.id,
|
||||
label: t(`commands.${normalizeI18nKey(command.id)}.label`, command.label),
|
||||
keybinding: keybindingStore.getKeybindingByCommandId(command.id)
|
||||
}));
|
||||
});
|
||||
@ -187,7 +199,7 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
createBaseVNode("div", {
|
||||
class: "overflow-hidden text-ellipsis whitespace-nowrap",
|
||||
title: slotProps.data.id
|
||||
}, toDisplayString(slotProps.data.id), 9, _hoisted_2)
|
||||
}, toDisplayString(slotProps.data.label), 9, _hoisted_2)
|
||||
]),
|
||||
_: 1
|
||||
}),
|
||||
@ -271,8 +283,16 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
};
|
||||
}
|
||||
});
|
||||
<<<<<<<< HEAD:comfy/web/assets/KeybindingPanel-DfPGcDsG.js
|
||||
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-8cf0326b"]]);
|
||||
export {
|
||||
KeybindingPanel as default
|
||||
};
|
||||
//# sourceMappingURL=KeybindingPanel-DfPGcDsG.js.map
|
||||
========
|
||||
const KeybindingPanel = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-c20ad403"]]);
|
||||
export {
|
||||
KeybindingPanel as default
|
||||
};
|
||||
//# sourceMappingURL=KeybindingPanel-DC2AxNNa.js.map
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/KeybindingPanel-DC2AxNNa.js
|
||||
1 comfy/web/assets/KeybindingPanel-DC2AxNNa.js.map (generated, vendored, new file): file diff suppressed because one or more lines are too long
82 comfy/web/assets/NotSupportedView-C8O1Ed5c.js (generated, vendored, new file)
@@ -0,0 +1,82 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { a as defineComponent, bU as useRouter, t as resolveDirective, f as openBlock, g as createElementBlock, A as createBaseVNode, a8 as toDisplayString, h as createVNode, z as unref, D as script, v as withDirectives } from "./index-DIU5yZe9.js";
|
||||
const _imports_0 = "" + new URL("images/sad_girl.png", import.meta.url).href;
|
||||
const _hoisted_1 = { class: "font-sans w-screen h-screen flex items-center m-0 text-neutral-900 bg-neutral-300 pointer-events-auto" };
|
||||
const _hoisted_2 = { class: "flex-grow flex items-center justify-center" };
|
||||
const _hoisted_3 = { class: "flex flex-col gap-8 p-8" };
|
||||
const _hoisted_4 = { class: "text-4xl font-bold text-red-500" };
|
||||
const _hoisted_5 = { class: "space-y-4" };
|
||||
const _hoisted_6 = { class: "text-xl" };
|
||||
const _hoisted_7 = { class: "list-disc list-inside space-y-1 text-neutral-800" };
|
||||
const _hoisted_8 = { class: "flex gap-4" };
|
||||
const _hoisted_9 = /* @__PURE__ */ createBaseVNode("div", { class: "h-screen flex-grow-0" }, [
|
||||
/* @__PURE__ */ createBaseVNode("img", {
|
||||
src: _imports_0,
|
||||
alt: "Sad girl illustration",
|
||||
class: "h-full object-cover"
|
||||
})
|
||||
], -1);
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "NotSupportedView",
|
||||
setup(__props) {
|
||||
const openDocs = /* @__PURE__ */ __name(() => {
|
||||
window.open(
|
||||
"https://github.com/Comfy-Org/desktop#currently-supported-platforms",
|
||||
"_blank"
|
||||
);
|
||||
}, "openDocs");
|
||||
const reportIssue = /* @__PURE__ */ __name(() => {
|
||||
window.open("https://forum.comfy.org/c/v1-feedback/", "_blank");
|
||||
}, "reportIssue");
|
||||
const router = useRouter();
|
||||
const continueToInstall = /* @__PURE__ */ __name(() => {
|
||||
router.push("/install");
|
||||
}, "continueToInstall");
|
||||
return (_ctx, _cache) => {
|
||||
const _directive_tooltip = resolveDirective("tooltip");
|
||||
return openBlock(), createElementBlock("div", _hoisted_1, [
|
||||
createBaseVNode("div", _hoisted_2, [
|
||||
createBaseVNode("div", _hoisted_3, [
|
||||
createBaseVNode("h1", _hoisted_4, toDisplayString(_ctx.$t("notSupported.title")), 1),
|
||||
createBaseVNode("div", _hoisted_5, [
|
||||
createBaseVNode("p", _hoisted_6, toDisplayString(_ctx.$t("notSupported.message")), 1),
|
||||
createBaseVNode("ul", _hoisted_7, [
|
||||
createBaseVNode("li", null, toDisplayString(_ctx.$t("notSupported.supportedDevices.macos")), 1),
|
||||
createBaseVNode("li", null, toDisplayString(_ctx.$t("notSupported.supportedDevices.windows")), 1)
|
||||
])
|
||||
]),
|
||||
createBaseVNode("div", _hoisted_8, [
|
||||
createVNode(unref(script), {
|
||||
label: _ctx.$t("notSupported.learnMore"),
|
||||
icon: "pi pi-github",
|
||||
onClick: openDocs,
|
||||
severity: "secondary"
|
||||
}, null, 8, ["label"]),
|
||||
createVNode(unref(script), {
|
||||
label: _ctx.$t("notSupported.reportIssue"),
|
||||
icon: "pi pi-flag",
|
||||
onClick: reportIssue,
|
||||
severity: "secondary"
|
||||
}, null, 8, ["label"]),
|
||||
withDirectives(createVNode(unref(script), {
|
||||
label: _ctx.$t("notSupported.continue"),
|
||||
icon: "pi pi-arrow-right",
|
||||
iconPos: "right",
|
||||
onClick: continueToInstall,
|
||||
severity: "danger"
|
||||
}, null, 8, ["label"]), [
|
||||
[_directive_tooltip, _ctx.$t("notSupported.continueTooltip")]
|
||||
])
|
||||
])
|
||||
])
|
||||
]),
|
||||
_hoisted_9
|
||||
]);
|
||||
};
|
||||
}
|
||||
});
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
//# sourceMappingURL=NotSupportedView-C8O1Ed5c.js.map
|
||||
1 comfy/web/assets/NotSupportedView-C8O1Ed5c.js.map (generated, vendored, new file)
@@ -0,0 +1 @@
|
||||
{"version":3,"file":"NotSupportedView-C8O1Ed5c.js","sources":["../../../../../../../assets/images/sad_girl.png","../../src/views/NotSupportedView.vue"],"sourcesContent":["export default \"__VITE_PUBLIC_ASSET__b82952e7__\"","<template>\n <div\n class=\"font-sans w-screen h-screen flex items-center m-0 text-neutral-900 bg-neutral-300 pointer-events-auto\"\n >\n <div class=\"flex-grow flex items-center justify-center\">\n <div class=\"flex flex-col gap-8 p-8\">\n <!-- Header -->\n <h1 class=\"text-4xl font-bold text-red-500\">\n {{ $t('notSupported.title') }}\n </h1>\n\n <!-- Message -->\n <div class=\"space-y-4\">\n <p class=\"text-xl\">\n {{ $t('notSupported.message') }}\n </p>\n <ul class=\"list-disc list-inside space-y-1 text-neutral-800\">\n <li>{{ $t('notSupported.supportedDevices.macos') }}</li>\n <li>{{ $t('notSupported.supportedDevices.windows') }}</li>\n </ul>\n </div>\n\n <!-- Actions -->\n <div class=\"flex gap-4\">\n <Button\n :label=\"$t('notSupported.learnMore')\"\n icon=\"pi pi-github\"\n @click=\"openDocs\"\n severity=\"secondary\"\n />\n <Button\n :label=\"$t('notSupported.reportIssue')\"\n icon=\"pi pi-flag\"\n @click=\"reportIssue\"\n severity=\"secondary\"\n />\n <Button\n :label=\"$t('notSupported.continue')\"\n icon=\"pi pi-arrow-right\"\n iconPos=\"right\"\n @click=\"continueToInstall\"\n severity=\"danger\"\n v-tooltip=\"$t('notSupported.continueTooltip')\"\n />\n </div>\n </div>\n </div>\n\n <!-- Right side image -->\n <div class=\"h-screen flex-grow-0\">\n <img\n src=\"/assets/images/sad_girl.png\"\n alt=\"Sad girl illustration\"\n class=\"h-full object-cover\"\n />\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport { useRouter } from 'vue-router'\n\nconst openDocs = () => {\n window.open(\n 'https://github.com/Comfy-Org/desktop#currently-supported-platforms',\n '_blank'\n )\n}\n\nconst reportIssue = () => {\n window.open('https://forum.comfy.org/c/v1-feedback/', '_blank')\n}\n\nconst router = useRouter()\nconst continueToInstall = () => {\n router.push('/install')\n}\n</script>\n"],"names":[],"mappings":";;;AAAA,MAAe,aAAA,KAAA,IAAA,IAAA,uBAAA,YAAA,GAAA,EAAA;;;;;;;;;;;;;;;;;;;AC+Df,UAAM,WAAW,6BAAM;AACd,aAAA;AAAA,QACL;AAAA,QACA;AAAA,MAAA;AAAA,IACF,GAJe;AAOjB,UAAM,cAAc,6BAAM;AACjB,aAAA,KAAK,0CAA0C,QAAQ;AAAA,IAAA,GAD5C;AAIpB,UAAM,SAAS;AACf,UAAM,oBAAoB,6BAAM;AAC9B,aAAO,KAAK,UAAU;AAAA,IAAA,GADE;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
@ -1,7 +1,12 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
<<<<<<<< HEAD:comfy/web/assets/ServerConfigPanel-B7Ic27AR.js
|
||||
import { f as openBlock, g as createElementBlock, A as createBaseVNode, aW as markRaw, a as defineComponent, u as useSettingStore, aJ as storeToRefs, w as watch, cF as useCopyToClipboard, I as useI18n, x as createBlock, y as withCtx, z as unref, bS as script, a6 as toDisplayString, Q as renderList, P as Fragment, h as createVNode, D as script$1, i as createCommentVNode, bJ as script$2, cG as FormItem, ci as _sfc_main$1, bV as electronAPI } from "./index-BQYg0VNJ.js";
|
||||
import { u as useServerConfigStore } from "./serverConfigStore-DulDGgjD.js";
|
||||
========
|
||||
import { A as createBaseVNode, f as openBlock, g as createElementBlock, aZ as markRaw, a as defineComponent, u as useSettingStore, aK as storeToRefs, w as watch, cL as useCopyToClipboard, I as useI18n, x as createBlock, y as withCtx, z as unref, bW as script, a8 as toDisplayString, Q as renderList, P as Fragment, h as createVNode, D as script$1, i as createCommentVNode, bN as script$2, cM as FormItem, cm as _sfc_main$1, bZ as electronAPI } from "./index-DIU5yZe9.js";
|
||||
import { u as useServerConfigStore } from "./serverConfigStore-DYv7_Nld.js";
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/ServerConfigPanel-CvXC1Xmx.js
|
||||
const _hoisted_1$1 = {
|
||||
viewBox: "0 0 24 24",
|
||||
width: "1.2em",
|
||||
@ -153,4 +158,8 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
<<<<<<<< HEAD:comfy/web/assets/ServerConfigPanel-B7Ic27AR.js
|
||||
//# sourceMappingURL=ServerConfigPanel-B7Ic27AR.js.map
|
||||
========
|
||||
//# sourceMappingURL=ServerConfigPanel-CvXC1Xmx.js.map
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/ServerConfigPanel-CvXC1Xmx.js
|
||||
1
comfy/web/assets/ServerConfigPanel-CvXC1Xmx.js.map
generated
vendored
Normal file
@ -0,0 +1 @@
|
||||
{"version":3,"file":"ServerConfigPanel-CvXC1Xmx.js","sources":["../../src/components/dialog/content/setting/ServerConfigPanel.vue"],"sourcesContent":["<template>\n <PanelTemplate value=\"Server-Config\" class=\"server-config-panel\">\n <template #header>\n <div class=\"flex flex-col gap-2\">\n <Message\n v-if=\"modifiedConfigs.length > 0\"\n severity=\"info\"\n pt:text=\"w-full\"\n >\n <p>\n {{ $t('serverConfig.modifiedConfigs') }}\n </p>\n <ul>\n <li v-for=\"config in modifiedConfigs\" :key=\"config.id\">\n {{ config.name }}: {{ config.initialValue }} → {{ config.value }}\n </li>\n </ul>\n <div class=\"flex justify-end gap-2\">\n <Button\n :label=\"$t('serverConfig.revertChanges')\"\n @click=\"revertChanges\"\n outlined\n />\n <Button\n :label=\"$t('serverConfig.restart')\"\n @click=\"restartApp\"\n outlined\n severity=\"danger\"\n />\n </div>\n </Message>\n <Message v-if=\"commandLineArgs\" severity=\"secondary\" pt:text=\"w-full\">\n <template #icon>\n <i-lucide:terminal class=\"text-xl font-bold\" />\n </template>\n <div class=\"flex items-center justify-between\">\n <p>{{ commandLineArgs }}</p>\n <Button\n icon=\"pi pi-clipboard\"\n @click=\"copyCommandLineArgs\"\n severity=\"secondary\"\n text\n />\n </div>\n </Message>\n </div>\n </template>\n <div\n v-for=\"([label, items], i) in Object.entries(serverConfigsByCategory)\"\n :key=\"label\"\n >\n <Divider v-if=\"i > 0\" />\n <h3>{{ $t(`serverConfigCategories.${label}`, label) }}</h3>\n <div\n v-for=\"item in items\"\n :key=\"item.name\"\n class=\"flex items-center mb-4\"\n >\n <FormItem\n :item=\"translateItem(item)\"\n v-model:formValue=\"item.value\"\n :id=\"item.id\"\n :labelClass=\"{\n 'text-highlight': item.initialValue !== item.value\n }\"\n />\n </div>\n </div>\n </PanelTemplate>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport Message from 'primevue/message'\nimport Divider from 'primevue/divider'\nimport FormItem from '@/components/common/FormItem.vue'\nimport PanelTemplate from './PanelTemplate.vue'\nimport { useServerConfigStore } from '@/stores/serverConfigStore'\nimport { storeToRefs } from 'pinia'\nimport { electronAPI } from '@/utils/envUtil'\nimport { useSettingStore } from '@/stores/settingStore'\nimport { watch } from 'vue'\nimport { useCopyToClipboard } from '@/hooks/clipboardHooks'\nimport type { FormItem as FormItemType } from '@/types/settingTypes'\nimport type { ServerConfig } from '@/constants/serverConfig'\nimport { useI18n } from 'vue-i18n'\n\nconst settingStore = useSettingStore()\nconst serverConfigStore = useServerConfigStore()\nconst {\n serverConfigsByCategory,\n serverConfigValues,\n launchArgs,\n commandLineArgs,\n modifiedConfigs\n} = storeToRefs(serverConfigStore)\n\nconst revertChanges = () => {\n serverConfigStore.revertChanges()\n}\n\nconst restartApp = () => {\n electronAPI().restartApp()\n}\n\nwatch(launchArgs, (newVal) => {\n settingStore.set('Comfy.Server.LaunchArgs', newVal)\n})\n\nwatch(serverConfigValues, (newVal) => {\n settingStore.set('Comfy.Server.ServerConfigValues', newVal)\n})\n\nconst { copyToClipboard } = useCopyToClipboard()\nconst copyCommandLineArgs = async () => {\n await copyToClipboard(commandLineArgs.value)\n}\n\nconst { t } = useI18n()\nconst translateItem = (item: ServerConfig<any>): FormItemType => {\n return {\n ...item,\n name: t(`serverConfigItems.${item.id}.name`, item.name),\n tooltip: item.tooltip\n ? 
t(`serverConfigItems.${item.id}.tooltip`, item.tooltip)\n : undefined\n }\n}\n</script>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;AAuFA,UAAM,eAAe;AACrB,UAAM,oBAAoB;AACpB,UAAA;AAAA,MACJ;AAAA,MACA;AAAA,MACA;AAAA,MACA;AAAA,MACA;AAAA,IAAA,IACE,YAAY,iBAAiB;AAEjC,UAAM,gBAAgB,6BAAM;AAC1B,wBAAkB,cAAc;AAAA,IAAA,GADZ;AAItB,UAAM,aAAa,6BAAM;AACvB,kBAAA,EAAc;IAAW,GADR;AAIb,UAAA,YAAY,CAAC,WAAW;AACf,mBAAA,IAAI,2BAA2B,MAAM;AAAA,IAAA,CACnD;AAEK,UAAA,oBAAoB,CAAC,WAAW;AACvB,mBAAA,IAAI,mCAAmC,MAAM;AAAA,IAAA,CAC3D;AAEK,UAAA,EAAE,oBAAoB;AAC5B,UAAM,sBAAsB,mCAAY;AAChC,YAAA,gBAAgB,gBAAgB,KAAK;AAAA,IAAA,GADjB;AAItB,UAAA,EAAE,MAAM;AACR,UAAA,gBAAgB,wBAAC,SAA0C;AACxD,aAAA;AAAA,QACL,GAAG;AAAA,QACH,MAAM,EAAE,qBAAqB,KAAK,EAAE,SAAS,KAAK,IAAI;AAAA,QACtD,SAAS,KAAK,UACV,EAAE,qBAAqB,KAAK,EAAE,YAAY,KAAK,OAAO,IACtD;AAAA,MAAA;AAAA,IACN,GAPoB;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
9
comfy/web/assets/ServerStartView-BHqjjHcl.css
generated
vendored
Normal file
@ -0,0 +1,9 @@
|
||||
|
||||
<<<<<<<< HEAD:comfy/web/assets/ServerStartView-DxIUrclT.css
|
||||
[data-v-95e9eb99] .xterm-helper-textarea {
|
||||
========
|
||||
[data-v-c0d3157e] .xterm-helper-textarea {
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/ServerStartView-BHqjjHcl.css
|
||||
/* Hide this as it moves all over when uv is running */
|
||||
display: none;
|
||||
}
|
||||
@ -1,13 +1,19 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
<<<<<<<< HEAD:comfy/web/assets/ServerStartView-kWuF5BS_.js
|
||||
import { a as defineComponent, I as useI18n, r as ref, bT as ProgressStatus, o as onMounted, f as openBlock, g as createElementBlock, A as createBaseVNode, ax as createTextVNode, a6 as toDisplayString, z as unref, i as createCommentVNode, h as createVNode, D as script, bU as BaseTerminal, bV as electronAPI, _ as _export_sfc } from "./index-BQYg0VNJ.js";
|
||||
========
|
||||
import { a as defineComponent, I as useI18n, r as ref, bX as ProgressStatus, o as onMounted, f as openBlock, g as createElementBlock, A as createBaseVNode, ay as createTextVNode, a8 as toDisplayString, z as unref, i as createCommentVNode, h as createVNode, D as script, x as createBlock, v as withDirectives, ad as vShow, bY as BaseTerminal, R as pushScopeId, U as popScopeId, bZ as electronAPI, _ as _export_sfc } from "./index-DIU5yZe9.js";
|
||||
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-c0d3157e"), n = n(), popScopeId(), n), "_withScopeId");
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/ServerStartView-BvuHEhuL.js
|
||||
const _hoisted_1 = { class: "font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto" };
|
||||
const _hoisted_2 = { class: "text-2xl font-bold" };
|
||||
const _hoisted_3 = { key: 0 };
|
||||
const _hoisted_4 = {
|
||||
key: 0,
|
||||
class: "flex items-center my-4 gap-2"
|
||||
class: "flex flex-col items-center gap-4"
|
||||
};
|
||||
const _hoisted_5 = { class: "flex items-center my-4 gap-2" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "ServerStartView",
|
||||
setup(__props) {
|
||||
@ -16,9 +22,15 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
const status = ref(ProgressStatus.INITIAL_STATE);
|
||||
const electronVersion = ref("");
|
||||
let xterm;
|
||||
const terminalVisible = ref(true);
|
||||
const updateProgress = /* @__PURE__ */ __name(({ status: newStatus }) => {
|
||||
status.value = newStatus;
|
||||
<<<<<<<< HEAD:comfy/web/assets/ServerStartView-kWuF5BS_.js
|
||||
if (newStatus !== ProgressStatus.ERROR) xterm?.clear();
|
||||
========
|
||||
if (newStatus === ProgressStatus.ERROR) terminalVisible.value = false;
|
||||
else xterm?.clear();
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/ServerStartView-BvuHEhuL.js
|
||||
}, "updateProgress");
|
||||
const terminalCreated = /* @__PURE__ */ __name(({ terminal, useAutoSize }, root) => {
|
||||
xterm = terminal;
|
||||
@ -47,31 +59,50 @@ const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("span", _hoisted_3, " v" + toDisplayString(electronVersion.value), 1)) : createCommentVNode("", true)
|
||||
]),
|
||||
status.value === unref(ProgressStatus).ERROR ? (openBlock(), createElementBlock("div", _hoisted_4, [
|
||||
createVNode(unref(script), {
|
||||
icon: "pi pi-flag",
|
||||
createBaseVNode("div", _hoisted_5, [
|
||||
createVNode(unref(script), {
|
||||
icon: "pi pi-flag",
|
||||
severity: "secondary",
|
||||
label: unref(t)("serverStart.reportIssue"),
|
||||
onClick: reportIssue
|
||||
}, null, 8, ["label"]),
|
||||
createVNode(unref(script), {
|
||||
icon: "pi pi-file",
|
||||
severity: "secondary",
|
||||
label: unref(t)("serverStart.openLogs"),
|
||||
onClick: openLogs
|
||||
}, null, 8, ["label"]),
|
||||
createVNode(unref(script), {
|
||||
icon: "pi pi-refresh",
|
||||
label: unref(t)("serverStart.reinstall"),
|
||||
onClick: reinstall
|
||||
}, null, 8, ["label"])
|
||||
]),
|
||||
!terminalVisible.value ? (openBlock(), createBlock(unref(script), {
|
||||
key: 0,
|
||||
icon: "pi pi-search",
|
||||
severity: "secondary",
|
||||
label: unref(t)("serverStart.reportIssue"),
|
||||
onClick: reportIssue
|
||||
}, null, 8, ["label"]),
|
||||
createVNode(unref(script), {
|
||||
icon: "pi pi-file",
|
||||
severity: "secondary",
|
||||
label: unref(t)("serverStart.openLogs"),
|
||||
onClick: openLogs
|
||||
}, null, 8, ["label"]),
|
||||
createVNode(unref(script), {
|
||||
icon: "pi pi-refresh",
|
||||
label: unref(t)("serverStart.reinstall"),
|
||||
onClick: reinstall
|
||||
}, null, 8, ["label"])
|
||||
label: unref(t)("serverStart.showTerminal"),
|
||||
onClick: _cache[0] || (_cache[0] = ($event) => terminalVisible.value = true)
|
||||
}, null, 8, ["label"])) : createCommentVNode("", true)
|
||||
])) : createCommentVNode("", true),
|
||||
createVNode(BaseTerminal, { onCreated: terminalCreated })
|
||||
withDirectives(createVNode(BaseTerminal, { onCreated: terminalCreated }, null, 512), [
|
||||
[vShow, terminalVisible.value]
|
||||
])
|
||||
]);
|
||||
};
|
||||
}
|
||||
});
|
||||
<<<<<<<< HEAD:comfy/web/assets/ServerStartView-kWuF5BS_.js
|
||||
const ServerStartView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-95e9eb99"]]);
|
||||
export {
|
||||
ServerStartView as default
|
||||
};
|
||||
//# sourceMappingURL=ServerStartView-kWuF5BS_.js.map
|
||||
========
|
||||
const ServerStartView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-v-c0d3157e"]]);
|
||||
export {
|
||||
ServerStartView as default
|
||||
};
|
||||
//# sourceMappingURL=ServerStartView-BvuHEhuL.js.map
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/ServerStartView-BvuHEhuL.js
|
||||
1
comfy/web/assets/ServerStartView-BvuHEhuL.js.map
generated
vendored
Normal file
@ -0,0 +1 @@
|
||||
{"version":3,"file":"ServerStartView-BvuHEhuL.js","sources":["../../src/views/ServerStartView.vue"],"sourcesContent":["<template>\n <div\n class=\"font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto\"\n >\n <h2 class=\"text-2xl font-bold\">\n {{ t(`serverStart.process.${status}`) }}\n <span v-if=\"status === ProgressStatus.ERROR\">\n v{{ electronVersion }}\n </span>\n </h2>\n <div\n v-if=\"status === ProgressStatus.ERROR\"\n class=\"flex flex-col items-center gap-4\"\n >\n <div class=\"flex items-center my-4 gap-2\">\n <Button\n icon=\"pi pi-flag\"\n severity=\"secondary\"\n :label=\"t('serverStart.reportIssue')\"\n @click=\"reportIssue\"\n />\n <Button\n icon=\"pi pi-file\"\n severity=\"secondary\"\n :label=\"t('serverStart.openLogs')\"\n @click=\"openLogs\"\n />\n <Button\n icon=\"pi pi-refresh\"\n :label=\"t('serverStart.reinstall')\"\n @click=\"reinstall\"\n />\n </div>\n <Button\n v-if=\"!terminalVisible\"\n icon=\"pi pi-search\"\n severity=\"secondary\"\n :label=\"t('serverStart.showTerminal')\"\n @click=\"terminalVisible = true\"\n />\n </div>\n <BaseTerminal v-show=\"terminalVisible\" @created=\"terminalCreated\" />\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport { ref, onMounted, Ref } from 'vue'\nimport BaseTerminal from '@/components/bottomPanel/tabs/terminal/BaseTerminal.vue'\nimport { ProgressStatus } from '@comfyorg/comfyui-electron-types'\nimport { electronAPI } from '@/utils/envUtil'\nimport type { useTerminal } from '@/hooks/bottomPanelTabs/useTerminal'\nimport { Terminal } from '@xterm/xterm'\nimport { useI18n } from 'vue-i18n'\n\nconst electron = electronAPI()\nconst { t } = useI18n()\n\nconst status = ref<ProgressStatus>(ProgressStatus.INITIAL_STATE)\nconst electronVersion = ref<string>('')\nlet xterm: Terminal | undefined\n\nconst terminalVisible = ref(true)\n\nconst updateProgress = ({ status: newStatus }: { status: ProgressStatus }) => {\n status.value = newStatus\n\n // Make critical error screen more obvious.\n if (newStatus === ProgressStatus.ERROR) terminalVisible.value = false\n else xterm?.clear()\n}\n\nconst terminalCreated = (\n { terminal, useAutoSize }: ReturnType<typeof useTerminal>,\n root: Ref<HTMLElement>\n) => {\n xterm = terminal\n\n useAutoSize(root, true, true)\n electron.onLogMessage((message: string) => {\n terminal.write(message)\n })\n\n terminal.options.cursorBlink = false\n terminal.options.disableStdin = true\n terminal.options.cursorInactiveStyle = 'block'\n}\n\nconst reinstall = () => electron.reinstall()\nconst reportIssue = () => {\n window.open('https://forum.comfy.org/c/v1-feedback/', '_blank')\n}\nconst openLogs = () => electron.openLogsFolder()\n\nonMounted(async () => {\n electron.sendReady()\n electron.onProgressUpdate(updateProgress)\n electronVersion.value = await electron.getElectronVersion()\n})\n</script>\n\n<style scoped>\n:deep(.xterm-helper-textarea) {\n /* Hide this as it moves all over when uv is running */\n display: 
none;\n}\n</style>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;AAuDA,UAAM,WAAW;AACX,UAAA,EAAE,MAAM;AAER,UAAA,SAAS,IAAoB,eAAe,aAAa;AACzD,UAAA,kBAAkB,IAAY,EAAE;AAClC,QAAA;AAEE,UAAA,kBAAkB,IAAI,IAAI;AAEhC,UAAM,iBAAiB,wBAAC,EAAE,QAAQ,gBAA4C;AAC5E,aAAO,QAAQ;AAGf,UAAI,cAAc,eAAe,MAAO,iBAAgB,QAAQ;AAAA,kBACpD,MAAM;AAAA,IAAA,GALG;AAQvB,UAAM,kBAAkB,wBACtB,EAAE,UAAU,YAAA,GACZ,SACG;AACK,cAAA;AAEI,kBAAA,MAAM,MAAM,IAAI;AACnB,eAAA,aAAa,CAAC,YAAoB;AACzC,iBAAS,MAAM,OAAO;AAAA,MAAA,CACvB;AAED,eAAS,QAAQ,cAAc;AAC/B,eAAS,QAAQ,eAAe;AAChC,eAAS,QAAQ,sBAAsB;AAAA,IAAA,GAbjB;AAgBlB,UAAA,YAAY,6BAAM,SAAS,aAAf;AAClB,UAAM,cAAc,6BAAM;AACjB,aAAA,KAAK,0CAA0C,QAAQ;AAAA,IAAA,GAD5C;AAGd,UAAA,WAAW,6BAAM,SAAS,kBAAf;AAEjB,cAAU,YAAY;AACpB,eAAS,UAAU;AACnB,eAAS,iBAAiB,cAAc;AACxB,sBAAA,QAAQ,MAAM,SAAS,mBAAmB;AAAA,IAAA,CAC3D;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
5
comfy/web/assets/ServerStartView-DxIUrclT.css
generated
vendored
@ -1,5 +0,0 @@
|
||||
|
||||
[data-v-95e9eb99] .xterm-helper-textarea {
|
||||
/* Hide this as it moves all over when uv is running */
|
||||
display: none;
|
||||
}
|
||||
98
comfy/web/assets/UserSelectView-C_4L-Yqf.js
generated
vendored
Normal file
@ -0,0 +1,98 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { a as defineComponent, J as useUserStore, bU as useRouter, r as ref, q as computed, o as onMounted, f as openBlock, g as createElementBlock, A as createBaseVNode, a8 as toDisplayString, h as createVNode, z as unref, aq as script, bN as script$1, bV as script$2, x as createBlock, y as withCtx, ay as createTextVNode, bW as script$3, i as createCommentVNode, D as script$4 } from "./index-DIU5yZe9.js";
|
||||
const _hoisted_1 = {
|
||||
id: "comfy-user-selection",
|
||||
class: "font-sans flex flex-col items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto"
|
||||
};
|
||||
const _hoisted_2 = { class: "mt-[5vh] 2xl:mt-[20vh] min-w-84 relative rounded-lg bg-[var(--comfy-menu-bg)] p-5 px-10 shadow-lg" };
|
||||
const _hoisted_3 = /* @__PURE__ */ createBaseVNode("h1", { class: "my-2.5 mb-7 font-normal" }, "ComfyUI", -1);
|
||||
const _hoisted_4 = { class: "flex w-full flex-col items-center" };
|
||||
const _hoisted_5 = { class: "flex w-full flex-col gap-2" };
|
||||
const _hoisted_6 = { for: "new-user-input" };
|
||||
const _hoisted_7 = { class: "flex w-full flex-col gap-2" };
|
||||
const _hoisted_8 = { for: "existing-user-select" };
|
||||
const _hoisted_9 = { class: "mt-5" };
|
||||
const _sfc_main = /* @__PURE__ */ defineComponent({
|
||||
__name: "UserSelectView",
|
||||
setup(__props) {
|
||||
const userStore = useUserStore();
|
||||
const router = useRouter();
|
||||
const selectedUser = ref(null);
|
||||
const newUsername = ref("");
|
||||
const loginError = ref("");
|
||||
const createNewUser = computed(() => newUsername.value.trim() !== "");
|
||||
const newUserExistsError = computed(() => {
|
||||
return userStore.users.find((user) => user.username === newUsername.value) ? `User "${newUsername.value}" already exists` : "";
|
||||
});
|
||||
const error = computed(() => newUserExistsError.value || loginError.value);
|
||||
const login = /* @__PURE__ */ __name(async () => {
|
||||
try {
|
||||
const user = createNewUser.value ? await userStore.createUser(newUsername.value) : selectedUser.value;
|
||||
if (!user) {
|
||||
throw new Error("No user selected");
|
||||
}
|
||||
userStore.login(user);
|
||||
router.push("/");
|
||||
} catch (err) {
|
||||
loginError.value = err.message ?? JSON.stringify(err);
|
||||
}
|
||||
}, "login");
|
||||
onMounted(async () => {
|
||||
if (!userStore.initialized) {
|
||||
await userStore.initialize();
|
||||
}
|
||||
});
|
||||
return (_ctx, _cache) => {
|
||||
return openBlock(), createElementBlock("div", _hoisted_1, [
|
||||
createBaseVNode("main", _hoisted_2, [
|
||||
_hoisted_3,
|
||||
createBaseVNode("form", _hoisted_4, [
|
||||
createBaseVNode("div", _hoisted_5, [
|
||||
createBaseVNode("label", _hoisted_6, toDisplayString(_ctx.$t("userSelect.newUser")) + ":", 1),
|
||||
createVNode(unref(script), {
|
||||
id: "new-user-input",
|
||||
modelValue: newUsername.value,
|
||||
"onUpdate:modelValue": _cache[0] || (_cache[0] = ($event) => newUsername.value = $event),
|
||||
placeholder: _ctx.$t("userSelect.enterUsername")
|
||||
}, null, 8, ["modelValue", "placeholder"])
|
||||
]),
|
||||
createVNode(unref(script$1)),
|
||||
createBaseVNode("div", _hoisted_7, [
|
||||
createBaseVNode("label", _hoisted_8, toDisplayString(_ctx.$t("userSelect.existingUser")) + ":", 1),
|
||||
createVNode(unref(script$2), {
|
||||
modelValue: selectedUser.value,
|
||||
"onUpdate:modelValue": _cache[1] || (_cache[1] = ($event) => selectedUser.value = $event),
|
||||
class: "w-full",
|
||||
inputId: "existing-user-select",
|
||||
options: unref(userStore).users,
|
||||
"option-label": "username",
|
||||
placeholder: _ctx.$t("userSelect.selectUser"),
|
||||
disabled: createNewUser.value
|
||||
}, null, 8, ["modelValue", "options", "placeholder", "disabled"]),
|
||||
error.value ? (openBlock(), createBlock(unref(script$3), {
|
||||
key: 0,
|
||||
severity: "error"
|
||||
}, {
|
||||
default: withCtx(() => [
|
||||
createTextVNode(toDisplayString(error.value), 1)
|
||||
]),
|
||||
_: 1
|
||||
})) : createCommentVNode("", true)
|
||||
]),
|
||||
createBaseVNode("footer", _hoisted_9, [
|
||||
createVNode(unref(script$4), {
|
||||
label: _ctx.$t("userSelect.next"),
|
||||
onClick: login
|
||||
}, null, 8, ["label"])
|
||||
])
|
||||
])
|
||||
])
|
||||
]);
|
||||
};
|
||||
}
|
||||
});
|
||||
export {
|
||||
_sfc_main as default
|
||||
};
|
||||
//# sourceMappingURL=UserSelectView-C_4L-Yqf.js.map
|
||||
1
comfy/web/assets/UserSelectView-C_4L-Yqf.js.map
generated
vendored
Normal file
@ -0,0 +1 @@
|
||||
{"version":3,"file":"UserSelectView-C_4L-Yqf.js","sources":["../../src/views/UserSelectView.vue"],"sourcesContent":["<template>\n <div\n id=\"comfy-user-selection\"\n class=\"font-sans flex flex-col items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto\"\n >\n <main\n class=\"mt-[5vh] 2xl:mt-[20vh] min-w-84 relative rounded-lg bg-[var(--comfy-menu-bg)] p-5 px-10 shadow-lg\"\n >\n <h1 class=\"my-2.5 mb-7 font-normal\">ComfyUI</h1>\n <form class=\"flex w-full flex-col items-center\">\n <div class=\"flex w-full flex-col gap-2\">\n <label for=\"new-user-input\">{{ $t('userSelect.newUser') }}:</label>\n <InputText\n id=\"new-user-input\"\n v-model=\"newUsername\"\n :placeholder=\"$t('userSelect.enterUsername')\"\n />\n </div>\n <Divider />\n <div class=\"flex w-full flex-col gap-2\">\n <label for=\"existing-user-select\"\n >{{ $t('userSelect.existingUser') }}:</label\n >\n <Select\n v-model=\"selectedUser\"\n class=\"w-full\"\n inputId=\"existing-user-select\"\n :options=\"userStore.users\"\n option-label=\"username\"\n :placeholder=\"$t('userSelect.selectUser')\"\n :disabled=\"createNewUser\"\n />\n <Message v-if=\"error\" severity=\"error\">{{ error }}</Message>\n </div>\n <footer class=\"mt-5\">\n <Button :label=\"$t('userSelect.next')\" @click=\"login\" />\n </footer>\n </form>\n </main>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport Divider from 'primevue/divider'\nimport InputText from 'primevue/inputtext'\nimport Select from 'primevue/select'\nimport Message from 'primevue/message'\nimport { User, useUserStore } from '@/stores/userStore'\nimport { useRouter } from 'vue-router'\nimport { computed, onMounted, ref } from 'vue'\n\nconst userStore = useUserStore()\nconst router = useRouter()\n\nconst selectedUser = ref<User | null>(null)\nconst newUsername = ref('')\nconst loginError = ref('')\n\nconst createNewUser = computed(() => newUsername.value.trim() !== '')\nconst newUserExistsError = computed(() => {\n return userStore.users.find((user) => user.username === newUsername.value)\n ? `User \"${newUsername.value}\" already exists`\n : ''\n})\nconst error = computed(() => newUserExistsError.value || loginError.value)\n\nconst login = async () => {\n try {\n const user = createNewUser.value\n ? await userStore.createUser(newUsername.value)\n : selectedUser.value\n\n if (!user) {\n throw new Error('No user selected')\n }\n\n userStore.login(user)\n router.push('/')\n } catch (err) {\n loginError.value = err.message ?? JSON.stringify(err)\n }\n}\n\nonMounted(async () => {\n if (!userStore.initialized) {\n await userStore.initialize()\n }\n})\n</script>\n"],"names":[],"mappings":";;;;;;;;;;;;;;;;;;AAoDA,UAAM,YAAY;AAClB,UAAM,SAAS;AAET,UAAA,eAAe,IAAiB,IAAI;AACpC,UAAA,cAAc,IAAI,EAAE;AACpB,UAAA,aAAa,IAAI,EAAE;AAEzB,UAAM,gBAAgB,SAAS,MAAM,YAAY,MAAM,KAAA,MAAW,EAAE;AAC9D,UAAA,qBAAqB,SAAS,MAAM;AACxC,aAAO,UAAU,MAAM,KAAK,CAAC,SAAS,KAAK,aAAa,YAAY,KAAK,IACrE,SAAS,YAAY,KAAK,qBAC1B;AAAA,IAAA,CACL;AACD,UAAM,QAAQ,SAAS,MAAM,mBAAmB,SAAS,WAAW,KAAK;AAEzE,UAAM,QAAQ,mCAAY;AACpB,UAAA;AACI,cAAA,OAAO,cAAc,QACvB,MAAM,UAAU,WAAW,YAAY,KAAK,IAC5C,aAAa;AAEjB,YAAI,CAAC,MAAM;AACH,gBAAA,IAAI,MAAM,kBAAkB;AAAA,QACpC;AAEA,kBAAU,MAAM,IAAI;AACpB,eAAO,KAAK,GAAG;AAAA,eACR,KAAK;AACZ,mBAAW,QAAQ,IAAI,WAAW,KAAK,UAAU,GAAG;AAAA,MACtD;AAAA,IAAA,GAdY;AAiBd,cAAU,YAAY;AAChB,UAAA,CAAC,UAAU,aAAa;AAC1B,cAAM,UAAU;MAClB;AAAA,IAAA,CACD;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;"}
|
||||
9
comfy/web/assets/WelcomeView-C5iZ_tHc.js → comfy/web/assets/WelcomeView-Db7ZDfZo.js
generated
vendored
@ -1,6 +1,11 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
<<<<<<<< HEAD:comfy/web/assets/WelcomeView-C5iZ_tHc.js
|
||||
import { a as defineComponent, bQ as useRouter, f as openBlock, g as createElementBlock, A as createBaseVNode, a6 as toDisplayString, h as createVNode, z as unref, D as script, _ as _export_sfc } from "./index-BQYg0VNJ.js";
|
||||
========
|
||||
import { a as defineComponent, bU as useRouter, f as openBlock, g as createElementBlock, A as createBaseVNode, a8 as toDisplayString, h as createVNode, z as unref, D as script, R as pushScopeId, U as popScopeId, _ as _export_sfc } from "./index-DIU5yZe9.js";
|
||||
const _withScopeId = /* @__PURE__ */ __name((n) => (pushScopeId("data-v-c4d014c5"), n = n(), popScopeId(), n), "_withScopeId");
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/WelcomeView-Db7ZDfZo.js
|
||||
const _hoisted_1 = { class: "font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto" };
|
||||
const _hoisted_2 = { class: "flex flex-col items-center justify-center gap-8 p-8" };
|
||||
const _hoisted_3 = { class: "animated-gradient-text text-glow select-none" };
|
||||
@ -33,4 +38,8 @@ const WelcomeView = /* @__PURE__ */ _export_sfc(_sfc_main, [["__scopeId", "data-
|
||||
export {
|
||||
WelcomeView as default
|
||||
};
|
||||
<<<<<<<< HEAD:comfy/web/assets/WelcomeView-C5iZ_tHc.js
|
||||
//# sourceMappingURL=WelcomeView-C5iZ_tHc.js.map
|
||||
========
|
||||
//# sourceMappingURL=WelcomeView-Db7ZDfZo.js.map
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/WelcomeView-Db7ZDfZo.js
|
||||
1
comfy/web/assets/WelcomeView-Db7ZDfZo.js.map
generated
vendored
Normal file
@ -0,0 +1 @@
|
||||
{"version":3,"file":"WelcomeView-Db7ZDfZo.js","sources":["../../src/views/WelcomeView.vue"],"sourcesContent":["<template>\n <div\n class=\"font-sans flex flex-col justify-center items-center h-screen m-0 text-neutral-300 bg-neutral-900 dark-theme pointer-events-auto\"\n >\n <div class=\"flex flex-col items-center justify-center gap-8 p-8\">\n <!-- Header -->\n <h1 class=\"animated-gradient-text text-glow select-none\">\n {{ $t('welcome.title') }}\n </h1>\n\n <!-- Get Started Button -->\n <Button\n :label=\"$t('welcome.getStarted')\"\n icon=\"pi pi-arrow-right\"\n iconPos=\"right\"\n size=\"large\"\n rounded\n @click=\"navigateTo('/install')\"\n class=\"p-4 text-lg fade-in-up\"\n />\n </div>\n </div>\n</template>\n\n<script setup lang=\"ts\">\nimport Button from 'primevue/button'\nimport { useRouter } from 'vue-router'\n\nconst router = useRouter()\nconst navigateTo = (path: string) => {\n router.push(path)\n}\n</script>\n\n<style scoped>\n.animated-gradient-text {\n @apply font-bold;\n font-size: clamp(2rem, 8vw, 4rem);\n background: linear-gradient(to right, #12c2e9, #c471ed, #f64f59, #12c2e9);\n background-size: 300% auto;\n background-clip: text;\n -webkit-background-clip: text;\n -webkit-text-fill-color: transparent;\n animation: gradient 8s linear infinite;\n}\n\n.text-glow {\n filter: drop-shadow(0 0 8px rgba(255, 255, 255, 0.3));\n}\n\n@keyframes gradient {\n 0% {\n background-position: 0% center;\n }\n\n 100% {\n background-position: 300% center;\n }\n}\n\n.fade-in-up {\n animation: fadeInUp 1.5s ease-out;\n animation-fill-mode: both;\n}\n\n@keyframes fadeInUp {\n 0% {\n opacity: 0;\n transform: translateY(20px);\n }\n\n 100% {\n opacity: 1;\n transform: translateY(0);\n }\n}\n</style>\n"],"names":[],"mappings":";;;;;;;;;;AA4BA,UAAM,SAAS;AACT,UAAA,aAAa,wBAAC,SAAiB;AACnC,aAAO,KAAK,IAAI;AAAA,IAAA,GADC;;;;;;;;;;;;;;;;;;;;"}
|
||||
68
comfy/web/assets/index-BY_-AxSO.css → comfy/web/assets/index-1vLlIVor.css
generated
vendored
@ -81,13 +81,13 @@
|
||||
border: none !important;
|
||||
}
|
||||
|
||||
.form-input[data-v-4fbf09d8] .input-slider .p-inputnumber input,
|
||||
.form-input[data-v-4fbf09d8] .input-slider .slider-part {
|
||||
.form-input[data-v-e54b447b] .input-slider .p-inputnumber input,
|
||||
.form-input[data-v-e54b447b] .input-slider .slider-part {
|
||||
|
||||
width: 5rem
|
||||
}
|
||||
.form-input[data-v-4fbf09d8] .p-inputtext,
|
||||
.form-input[data-v-4fbf09d8] .p-select {
|
||||
.form-input[data-v-e54b447b] .p-inputtext,
|
||||
.form-input[data-v-e54b447b] .p-select {
|
||||
|
||||
width: 11rem
|
||||
}
|
||||
@ -345,7 +345,11 @@
|
||||
padding-top: 0px !important;
|
||||
}
|
||||
|
||||
<<<<<<<< HEAD:comfy/web/assets/index-BY_-AxSO.css
|
||||
.settings-container[data-v-5a032e0c] {
|
||||
========
|
||||
.settings-container[data-v-d85d6e64] {
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/index-1vLlIVor.css
|
||||
display: flex;
|
||||
height: 70vh;
|
||||
width: 60vw;
|
||||
@ -353,21 +357,39 @@
|
||||
overflow: hidden;
|
||||
}
|
||||
@media (max-width: 768px) {
|
||||
<<<<<<<< HEAD:comfy/web/assets/index-BY_-AxSO.css
|
||||
.settings-container[data-v-5a032e0c] {
|
||||
flex-direction: column;
|
||||
height: auto;
|
||||
}
|
||||
.settings-sidebar[data-v-5a032e0c] {
|
||||
========
|
||||
.settings-container[data-v-d85d6e64] {
|
||||
flex-direction: column;
|
||||
height: auto;
|
||||
}
|
||||
.settings-sidebar[data-v-d85d6e64] {
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/index-1vLlIVor.css
|
||||
width: 100%;
|
||||
}
|
||||
.settings-content[data-v-d85d6e64] {
|
||||
height: 350px;
|
||||
}
|
||||
}
|
||||
|
||||
/* Show a separator line above the Keybinding tab */
|
||||
/* This indicates the start of custom setting panels */
|
||||
<<<<<<<< HEAD:comfy/web/assets/index-BY_-AxSO.css
|
||||
.settings-sidebar[data-v-5a032e0c] .p-listbox-option[aria-label='Keybinding'] {
|
||||
position: relative;
|
||||
}
|
||||
.settings-sidebar[data-v-5a032e0c] .p-listbox-option[aria-label='Keybinding']::before {
|
||||
========
|
||||
.settings-sidebar[data-v-d85d6e64] .p-listbox-option[aria-label='Keybinding'] {
|
||||
position: relative;
|
||||
}
|
||||
.settings-sidebar[data-v-d85d6e64] .p-listbox-option[aria-label='Keybinding']::before {
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/index-1vLlIVor.css
|
||||
position: absolute;
|
||||
top: 0px;
|
||||
left: 0px;
|
||||
@ -2060,6 +2082,13 @@ img.galleria-image {
|
||||
margin-top: 0.25rem;
|
||||
margin-bottom: 0.25rem;
|
||||
}
|
||||
<<<<<<<< HEAD:comfy/web/assets/index-BY_-AxSO.css
|
||||
========
|
||||
.my-2{
|
||||
margin-top: 0.5rem;
|
||||
margin-bottom: 0.5rem;
|
||||
}
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/index-1vLlIVor.css
|
||||
.my-2\.5{
|
||||
margin-top: 0.625rem;
|
||||
margin-bottom: 0.625rem;
|
||||
@ -2425,6 +2454,10 @@ img.galleria-image {
|
||||
--tw-bg-opacity: 1;
|
||||
background-color: rgb(212 212 212 / var(--tw-bg-opacity, 1));
|
||||
}
|
||||
.bg-neutral-300{
|
||||
--tw-bg-opacity: 1;
|
||||
background-color: rgb(212 212 212 / var(--tw-bg-opacity));
|
||||
}
|
||||
.bg-neutral-800{
|
||||
--tw-bg-opacity: 1;
|
||||
background-color: rgb(38 38 38 / var(--tw-bg-opacity, 1));
|
||||
@ -2592,6 +2625,10 @@ img.galleria-image {
|
||||
--tw-text-opacity: 1;
|
||||
color: rgb(150 206 76 / var(--tw-text-opacity, 1));
|
||||
}
|
||||
.text-green-500{
|
||||
--tw-text-opacity: 1;
|
||||
color: rgb(150 206 76 / var(--tw-text-opacity));
|
||||
}
|
||||
.text-highlight{
|
||||
color: var(--p-primary-color);
|
||||
}
|
||||
@ -2626,6 +2663,18 @@ img.galleria-image {
|
||||
--tw-text-opacity: 1;
|
||||
color: rgb(239 68 68 / var(--tw-text-opacity, 1));
|
||||
}
|
||||
.text-neutral-800{
|
||||
--tw-text-opacity: 1;
|
||||
color: rgb(38 38 38 / var(--tw-text-opacity));
|
||||
}
|
||||
.text-neutral-900{
|
||||
--tw-text-opacity: 1;
|
||||
color: rgb(23 23 23 / var(--tw-text-opacity));
|
||||
}
|
||||
.text-red-500{
|
||||
--tw-text-opacity: 1;
|
||||
color: rgb(239 68 68 / var(--tw-text-opacity));
|
||||
}
|
||||
.no-underline{
|
||||
text-decoration-line: none;
|
||||
}
|
||||
@ -3409,6 +3458,17 @@ audio.comfy-audio.empty-audio-widget {
|
||||
opacity: 1;
|
||||
}
|
||||
|
||||
@media (min-width: 768px){
|
||||
|
||||
.md\:flex{
|
||||
display: flex;
|
||||
}
|
||||
|
||||
.md\:hidden{
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
|
||||
@media (min-width: 1536px){
|
||||
|
||||
.\32xl\:mx-4{
|
||||
9
comfy/web/assets/index-CMsGQEqY.js → comfy/web/assets/index-D3u7l7ha.js
generated
vendored
@ -1,7 +1,12 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
<<<<<<<< HEAD:comfy/web/assets/index-CMsGQEqY.js
|
||||
import { ba as script$s, f as openBlock, g as createElementBlock, m as mergeProps, A as createBaseVNode, B as BaseStyle, R as script$t, a6 as toDisplayString, a1 as Ripple, t as resolveDirective, v as withDirectives, x as createBlock, M as resolveDynamicComponent, bR as script$u, l as resolveComponent, C as normalizeClass, aw as createSlots, y as withCtx, bz as script$v, bq as script$w, P as Fragment, Q as renderList, ax as createTextVNode, bg as setAttribute, be as normalizeProps, p as renderSlot, i as createCommentVNode, ad as script$x, a4 as equals, bb as script$y, c0 as script$z, cl as getFirstFocusableElement, ah as OverlayEventBus, a8 as getVNodeProp, ag as resolveFieldData, cm as invokeElementMethod, a2 as getAttribute, cn as getNextElementSibling, Y as getOuterWidth, co as getPreviousElementSibling, D as script$A, as as script$B, a0 as script$C, bd as script$E, ac as isNotEmpty, bm as withModifiers, W as getOuterHeight, ae as UniqueComponentId, cp as _default, af as ZIndex, a3 as focus, aj as addStyle, al as absolutePosition, am as ConnectedOverlayScrollHandler, an as isTouchDevice, cq as FilterOperator, ar as script$F, cr as FocusTrap, h as createVNode, av as Transition, cs as withKeys, ct as getIndex, k as script$H, cu as isClickable, cv as clearSelection, cw as localeComparator, cx as sort, cy as FilterService, cg as FilterMatchMode, V as findSingle, bW as findIndexInList, bX as find, cz as exportCSV, X as getOffset, cA as getHiddenElementOuterWidth, cB as getHiddenElementOuterHeight, cC as reorderArray, cD as removeClass, cE as addClass, ai as isEmpty, aq as script$I, at as script$J } from "./index-BQYg0VNJ.js";
|
||||
import { s as script$D, a as script$G } from "./index-DJqEjTnE.js";
|
||||
========
|
||||
import { cp as script$s, A as createBaseVNode, f as openBlock, g as createElementBlock, m as mergeProps, B as BaseStyle, V as script$t, a8 as toDisplayString, a3 as Ripple, t as resolveDirective, v as withDirectives, x as createBlock, M as resolveDynamicComponent, bV as script$u, l as resolveComponent, C as normalizeClass, ax as createSlots, y as withCtx, bD as script$v, bv as script$w, P as Fragment, Q as renderList, ay as createTextVNode, bl as setAttribute, af as UniqueComponentId, bj as normalizeProps, p as renderSlot, i as createCommentVNode, a6 as equals, bf as script$x, c4 as script$y, cq as getFirstFocusableElement, ai as OverlayEventBus, aa as getVNodeProp, ah as resolveFieldData, cr as invokeElementMethod, a4 as getAttribute, cs as getNextElementSibling, $ as getOuterWidth, ct as getPreviousElementSibling, D as script$z, at as script$A, a2 as script$B, bi as script$D, ae as isNotEmpty, br as withModifiers, Y as getOuterHeight, cu as _default, ag as ZIndex, a5 as focus, ak as addStyle, am as absolutePosition, an as ConnectedOverlayScrollHandler, ao as isTouchDevice, cv as FilterOperator, as as script$E, cw as FocusTrap, h as createVNode, aw as Transition, cx as withKeys, cy as getIndex, k as script$G, cz as isClickable, cA as clearSelection, cB as localeComparator, cC as sort, cD as FilterService, ck as FilterMatchMode, X as findSingle, b_ as findIndexInList, b$ as find, cE as exportCSV, Z as getOffset, cF as getHiddenElementOuterWidth, cG as getHiddenElementOuterHeight, cH as reorderArray, cI as getWindowScrollTop, cJ as removeClass, cK as addClass, aj as isEmpty, ar as script$H, au as script$I } from "./index-DIU5yZe9.js";
|
||||
import { s as script$C, a as script$F } from "./index-d698Brhb.js";
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/index-D3u7l7ha.js
|
||||
var script$r = {
|
||||
name: "ArrowDownIcon",
|
||||
"extends": script$s
|
||||
@ -8829,4 +8834,8 @@ export {
|
||||
script$d as a,
|
||||
script as s
|
||||
};
|
||||
<<<<<<<< HEAD:comfy/web/assets/index-CMsGQEqY.js
|
||||
//# sourceMappingURL=index-CMsGQEqY.js.map
|
||||
========
|
||||
//# sourceMappingURL=index-D3u7l7ha.js.map
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/index-D3u7l7ha.js
|
||||
1
comfy/web/assets/index-D3u7l7ha.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
39101
comfy/web/assets/index-BQYg0VNJ.js → comfy/web/assets/index-DIU5yZe9.js
generated
vendored
File diff suppressed because one or more lines are too long
1
comfy/web/assets/index-DIU5yZe9.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
50
comfy/web/assets/index-d698Brhb.js
generated
vendored
Normal file
@ -0,0 +1,50 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
import { cp as script$2, A as createBaseVNode, f as openBlock, g as createElementBlock, m as mergeProps } from "./index-DIU5yZe9.js";
|
||||
var script$1 = {
|
||||
name: "BarsIcon",
|
||||
"extends": script$2
|
||||
};
|
||||
var _hoisted_1$1 = /* @__PURE__ */ createBaseVNode("path", {
|
||||
"fill-rule": "evenodd",
|
||||
"clip-rule": "evenodd",
|
||||
d: "M13.3226 3.6129H0.677419C0.497757 3.6129 0.325452 3.54152 0.198411 3.41448C0.0713707 3.28744 0 3.11514 0 2.93548C0 2.75581 0.0713707 2.58351 0.198411 2.45647C0.325452 2.32943 0.497757 2.25806 0.677419 2.25806H13.3226C13.5022 2.25806 13.6745 2.32943 13.8016 2.45647C13.9286 2.58351 14 2.75581 14 2.93548C14 3.11514 13.9286 3.28744 13.8016 3.41448C13.6745 3.54152 13.5022 3.6129 13.3226 3.6129ZM13.3226 7.67741H0.677419C0.497757 7.67741 0.325452 7.60604 0.198411 7.479C0.0713707 7.35196 0 7.17965 0 6.99999C0 6.82033 0.0713707 6.64802 0.198411 6.52098C0.325452 6.39394 0.497757 6.32257 0.677419 6.32257H13.3226C13.5022 6.32257 13.6745 6.39394 13.8016 6.52098C13.9286 6.64802 14 6.82033 14 6.99999C14 7.17965 13.9286 7.35196 13.8016 7.479C13.6745 7.60604 13.5022 7.67741 13.3226 7.67741ZM0.677419 11.7419H13.3226C13.5022 11.7419 13.6745 11.6706 13.8016 11.5435C13.9286 11.4165 14 11.2442 14 11.0645C14 10.8848 13.9286 10.7125 13.8016 10.5855C13.6745 10.4585 13.5022 10.3871 13.3226 10.3871H0.677419C0.497757 10.3871 0.325452 10.4585 0.198411 10.5855C0.0713707 10.7125 0 10.8848 0 11.0645C0 11.2442 0.0713707 11.4165 0.198411 11.5435C0.325452 11.6706 0.497757 11.7419 0.677419 11.7419Z",
|
||||
fill: "currentColor"
|
||||
}, null, -1);
|
||||
var _hoisted_2$1 = [_hoisted_1$1];
|
||||
function render$1(_ctx, _cache, $props, $setup, $data, $options) {
|
||||
return openBlock(), createElementBlock("svg", mergeProps({
|
||||
width: "14",
|
||||
height: "14",
|
||||
viewBox: "0 0 14 14",
|
||||
fill: "none",
|
||||
xmlns: "http://www.w3.org/2000/svg"
|
||||
}, _ctx.pti()), _hoisted_2$1, 16);
|
||||
}
|
||||
__name(render$1, "render$1");
|
||||
script$1.render = render$1;
|
||||
var script = {
|
||||
name: "PlusIcon",
|
||||
"extends": script$2
|
||||
};
|
||||
var _hoisted_1 = /* @__PURE__ */ createBaseVNode("path", {
|
||||
d: "M7.67742 6.32258V0.677419C7.67742 0.497757 7.60605 0.325452 7.47901 0.198411C7.35197 0.0713707 7.17966 0 7 0C6.82034 0 6.64803 0.0713707 6.52099 0.198411C6.39395 0.325452 6.32258 0.497757 6.32258 0.677419V6.32258H0.677419C0.497757 6.32258 0.325452 6.39395 0.198411 6.52099C0.0713707 6.64803 0 6.82034 0 7C0 7.17966 0.0713707 7.35197 0.198411 7.47901C0.325452 7.60605 0.497757 7.67742 0.677419 7.67742H6.32258V13.3226C6.32492 13.5015 6.39704 13.6725 6.52358 13.799C6.65012 13.9255 6.82106 13.9977 7 14C7.17966 14 7.35197 13.9286 7.47901 13.8016C7.60605 13.6745 7.67742 13.5022 7.67742 13.3226V7.67742H13.3226C13.5022 7.67742 13.6745 7.60605 13.8016 7.47901C13.9286 7.35197 14 7.17966 14 7C13.9977 6.82106 13.9255 6.65012 13.799 6.52358C13.6725 6.39704 13.5015 6.32492 13.3226 6.32258H7.67742Z",
|
||||
fill: "currentColor"
|
||||
}, null, -1);
|
||||
var _hoisted_2 = [_hoisted_1];
|
||||
function render(_ctx, _cache, $props, $setup, $data, $options) {
|
||||
return openBlock(), createElementBlock("svg", mergeProps({
|
||||
width: "14",
|
||||
height: "14",
|
||||
viewBox: "0 0 14 14",
|
||||
fill: "none",
|
||||
xmlns: "http://www.w3.org/2000/svg"
|
||||
}, _ctx.pti()), _hoisted_2, 16);
|
||||
}
|
||||
__name(render, "render");
|
||||
script.render = render;
|
||||
export {
|
||||
script as a,
|
||||
script$1 as s
|
||||
};
|
||||
//# sourceMappingURL=index-d698Brhb.js.map
|
||||
1
comfy/web/assets/index-d698Brhb.js.map
generated
vendored
Normal file
@ -0,0 +1 @@
|
||||
{"version":3,"file":"index-d698Brhb.js","sources":["../../node_modules/@primevue/icons/bars/index.mjs","../../node_modules/@primevue/icons/plus/index.mjs"],"sourcesContent":["import BaseIcon from '@primevue/icons/baseicon';\nimport { openBlock, createElementBlock, mergeProps, createElementVNode } from 'vue';\n\nvar script = {\n name: 'BarsIcon',\n \"extends\": BaseIcon\n};\n\nvar _hoisted_1 = /*#__PURE__*/createElementVNode(\"path\", {\n \"fill-rule\": \"evenodd\",\n \"clip-rule\": \"evenodd\",\n d: \"M13.3226 3.6129H0.677419C0.497757 3.6129 0.325452 3.54152 0.198411 3.41448C0.0713707 3.28744 0 3.11514 0 2.93548C0 2.75581 0.0713707 2.58351 0.198411 2.45647C0.325452 2.32943 0.497757 2.25806 0.677419 2.25806H13.3226C13.5022 2.25806 13.6745 2.32943 13.8016 2.45647C13.9286 2.58351 14 2.75581 14 2.93548C14 3.11514 13.9286 3.28744 13.8016 3.41448C13.6745 3.54152 13.5022 3.6129 13.3226 3.6129ZM13.3226 7.67741H0.677419C0.497757 7.67741 0.325452 7.60604 0.198411 7.479C0.0713707 7.35196 0 7.17965 0 6.99999C0 6.82033 0.0713707 6.64802 0.198411 6.52098C0.325452 6.39394 0.497757 6.32257 0.677419 6.32257H13.3226C13.5022 6.32257 13.6745 6.39394 13.8016 6.52098C13.9286 6.64802 14 6.82033 14 6.99999C14 7.17965 13.9286 7.35196 13.8016 7.479C13.6745 7.60604 13.5022 7.67741 13.3226 7.67741ZM0.677419 11.7419H13.3226C13.5022 11.7419 13.6745 11.6706 13.8016 11.5435C13.9286 11.4165 14 11.2442 14 11.0645C14 10.8848 13.9286 10.7125 13.8016 10.5855C13.6745 10.4585 13.5022 10.3871 13.3226 10.3871H0.677419C0.497757 10.3871 0.325452 10.4585 0.198411 10.5855C0.0713707 10.7125 0 10.8848 0 11.0645C0 11.2442 0.0713707 11.4165 0.198411 11.5435C0.325452 11.6706 0.497757 11.7419 0.677419 11.7419Z\",\n fill: \"currentColor\"\n}, null, -1);\nvar _hoisted_2 = [_hoisted_1];\nfunction render(_ctx, _cache, $props, $setup, $data, $options) {\n return openBlock(), createElementBlock(\"svg\", mergeProps({\n width: \"14\",\n height: \"14\",\n viewBox: \"0 0 14 14\",\n fill: \"none\",\n xmlns: \"http://www.w3.org/2000/svg\"\n }, _ctx.pti()), _hoisted_2, 16);\n}\n\nscript.render = render;\n\nexport { script as default };\n//# sourceMappingURL=index.mjs.map\n","import BaseIcon from '@primevue/icons/baseicon';\nimport { openBlock, createElementBlock, mergeProps, createElementVNode } from 'vue';\n\nvar script = {\n name: 'PlusIcon',\n \"extends\": BaseIcon\n};\n\nvar _hoisted_1 = /*#__PURE__*/createElementVNode(\"path\", {\n d: \"M7.67742 6.32258V0.677419C7.67742 0.497757 7.60605 0.325452 7.47901 0.198411C7.35197 0.0713707 7.17966 0 7 0C6.82034 0 6.64803 0.0713707 6.52099 0.198411C6.39395 0.325452 6.32258 0.497757 6.32258 0.677419V6.32258H0.677419C0.497757 6.32258 0.325452 6.39395 0.198411 6.52099C0.0713707 6.64803 0 6.82034 0 7C0 7.17966 0.0713707 7.35197 0.198411 7.47901C0.325452 7.60605 0.497757 7.67742 0.677419 7.67742H6.32258V13.3226C6.32492 13.5015 6.39704 13.6725 6.52358 13.799C6.65012 13.9255 6.82106 13.9977 7 14C7.17966 14 7.35197 13.9286 7.47901 13.8016C7.60605 13.6745 7.67742 13.5022 7.67742 13.3226V7.67742H13.3226C13.5022 7.67742 13.6745 7.60605 13.8016 7.47901C13.9286 7.35197 14 7.17966 14 7C13.9977 6.82106 13.9255 6.65012 13.799 6.52358C13.6725 6.39704 13.5015 6.32492 13.3226 6.32258H7.67742Z\",\n fill: \"currentColor\"\n}, null, -1);\nvar _hoisted_2 = [_hoisted_1];\nfunction render(_ctx, _cache, $props, $setup, $data, $options) {\n return openBlock(), createElementBlock(\"svg\", mergeProps({\n width: \"14\",\n height: \"14\",\n viewBox: \"0 0 14 14\",\n fill: \"none\",\n xmlns: \"http://www.w3.org/2000/svg\"\n }, _ctx.pti()), 
_hoisted_2, 16);\n}\n\nscript.render = render;\n\nexport { script as default };\n//# sourceMappingURL=index.mjs.map\n"],"names":["script","BaseIcon","_hoisted_1","createElementVNode","_hoisted_2","render"],"mappings":";;;AAGG,IAACA,WAAS;AAAA,EACX,MAAM;AAAA,EACN,WAAWC;AACb;AAEA,IAAIC,eAA0BC,gCAAmB,QAAQ;AAAA,EACvD,aAAa;AAAA,EACb,aAAa;AAAA,EACb,GAAG;AAAA,EACH,MAAM;AACR,GAAG,MAAM,EAAE;AACX,IAAIC,eAAa,CAACF,YAAU;AAC5B,SAASG,SAAO,MAAM,QAAQ,QAAQ,QAAQ,OAAO,UAAU;AAC7D,SAAO,UAAW,GAAE,mBAAmB,OAAO,WAAW;AAAA,IACvD,OAAO;AAAA,IACP,QAAQ;AAAA,IACR,SAAS;AAAA,IACT,MAAM;AAAA,IACN,OAAO;AAAA,EACR,GAAE,KAAK,IAAG,CAAE,GAAGD,cAAY,EAAE;AAChC;AARSC;AAUTL,SAAO,SAASK;ACtBb,IAAC,SAAS;AAAA,EACX,MAAM;AAAA,EACN,WAAWJ;AACb;AAEA,IAAI,aAA0BE,gCAAmB,QAAQ;AAAA,EACvD,GAAG;AAAA,EACH,MAAM;AACR,GAAG,MAAM,EAAE;AACX,IAAI,aAAa,CAAC,UAAU;AAC5B,SAAS,OAAO,MAAM,QAAQ,QAAQ,QAAQ,OAAO,UAAU;AAC7D,SAAO,UAAW,GAAE,mBAAmB,OAAO,WAAW;AAAA,IACvD,OAAO;AAAA,IACP,QAAQ;AAAA,IACR,SAAS;AAAA,IACT,MAAM;AAAA,IACN,OAAO;AAAA,EACR,GAAE,KAAK,IAAG,CAAE,GAAG,YAAY,EAAE;AAChC;AARS;AAUT,OAAO,SAAS;","x_google_ignoreList":[0,1]}
|
||||
31
comfy/web/assets/index-BfiYPlqA.js → comfy/web/assets/index-p6KSJ2Zq.js
generated
vendored
@ -1,7 +1,12 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
<<<<<<<< HEAD:comfy/web/assets/index-BfiYPlqA.js
|
||||
import { c2 as ComfyDialog, c3 as $el, c4 as ComfyApp, b as app, j as LiteGraph, b6 as LGraphCanvas, c5 as DraggableList, bj as useToastStore, c6 as showPromptDialog, c7 as t, c8 as serialise, aE as useNodeDefStore, c9 as deserialiseAndCreate, a$ as api, u as useSettingStore, L as LGraphGroup, ca as KeyComboImpl, O as useKeybindingStore, F as useCommandStore, c as LGraphNode, cb as ComfyWidgets, cc as applyTextReplacements, cd as isElectron, bV as electronAPI, ce as showConfirmationDialog, aP as nextTick } from "./index-BQYg0VNJ.js";
|
||||
import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs-BMOQhk10.js";
|
||||
========
|
||||
import { c6 as ComfyDialog, c7 as $el, c8 as ComfyApp, b as app, j as LiteGraph, b9 as LGraphCanvas, c9 as DraggableList, bo as useToastStore, ca as showPromptDialog, cb as t, cc as serialise, aF as useNodeDefStore, cd as deserialiseAndCreate, b2 as api, u as useSettingStore, L as LGraphGroup, ce as KeyComboImpl, O as useKeybindingStore, F as useCommandStore, c as LGraphNode, cf as ComfyWidgets, cg as applyTextReplacements, ch as isElectron, bZ as electronAPI, ci as showConfirmationDialog, aQ as nextTick } from "./index-DIU5yZe9.js";
|
||||
import { mergeIfValid, getWidgetConfig, setWidgetConfig } from "./widgetInputs-Bvm3AgOa.js";
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/index-p6KSJ2Zq.js
|
||||
class ClipspaceDialog extends ComfyDialog {
|
||||
static {
|
||||
__name(this, "ClipspaceDialog");
|
||||
@ -736,7 +741,7 @@ class ManageGroupDialog extends ComfyDialog {
|
||||
groupNodes.map(
|
||||
(g) => $el("option", {
|
||||
textContent: g,
|
||||
selected: `${PREFIX$1}${SEPARATOR$1}` + g === type,
|
||||
selected: `${PREFIX$1}${SEPARATOR$1}${g}` === type,
|
||||
value: g
|
||||
})
|
||||
)
|
||||
@ -854,7 +859,7 @@ class ManageGroupDialog extends ComfyDialog {
|
||||
]);
|
||||
this.element.replaceChildren(outer);
|
||||
this.changeGroup(
|
||||
type ? groupNodes.find((g) => `${PREFIX$1}${SEPARATOR$1}` + g === type) : groupNodes[0]
|
||||
type ? groupNodes.find((g) => `${PREFIX$1}${SEPARATOR$1}${g}` === type) ?? groupNodes[0] : groupNodes[0]
|
||||
);
|
||||
this.element.showModal();
|
||||
this.element.addEventListener("close", () => {
|
||||
@ -1978,7 +1983,7 @@ function addConvertToGroupOptions() {
|
||||
options.splice(index + 1, null, {
|
||||
content: `Manage Group Nodes`,
|
||||
disabled,
|
||||
callback: manageGroupNodes
|
||||
callback: /* @__PURE__ */ __name(() => manageGroupNodes(), "callback")
|
||||
});
|
||||
}
|
||||
__name(addManageOption, "addManageOption");
|
||||
@ -8531,6 +8536,14 @@ app.registerExtension({
|
||||
window.open("https://forum.comfy.org/c/v1-feedback/", "_blank");
|
||||
}
|
||||
},
|
||||
{
|
||||
id: "Comfy-Desktop.OpenUserGuide",
|
||||
label: "Desktop User Guide",
|
||||
icon: "pi pi-book",
|
||||
function() {
|
||||
window.open("https://comfyorg.notion.site/", "_blank");
|
||||
}
|
||||
},
|
||||
{
|
||||
id: "Comfy-Desktop.Reinstall",
|
||||
label: "Reinstall",
|
||||
@ -8556,7 +8569,10 @@ app.registerExtension({
|
||||
menuCommands: [
|
||||
{
|
||||
path: ["Help"],
|
||||
commands: ["Comfy-Desktop.OpenFeedbackPage"]
|
||||
commands: [
|
||||
"Comfy-Desktop.OpenUserGuide",
|
||||
"Comfy-Desktop.OpenFeedbackPage"
|
||||
]
|
||||
},
|
||||
{
|
||||
path: ["Help"],
|
||||
@ -51681,6 +51697,9 @@ class Load3d {
|
||||
});
|
||||
}
|
||||
setViewPosition(position) {
|
||||
if (!this.currentModel) {
|
||||
return;
|
||||
}
|
||||
const box = new Box3();
|
||||
let center = new Vector3();
|
||||
let size = new Vector3();
|
||||
@ -52389,4 +52408,8 @@ app.registerExtension({
|
||||
);
|
||||
}
|
||||
});
|
||||
<<<<<<<< HEAD:comfy/web/assets/index-BfiYPlqA.js
|
||||
//# sourceMappingURL=index-BfiYPlqA.js.map
|
||||
========
|
||||
//# sourceMappingURL=index-p6KSJ2Zq.js.map
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/index-p6KSJ2Zq.js
|
||||
1
comfy/web/assets/index-p6KSJ2Zq.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
@ -1,6 +1,10 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
<<<<<<<< HEAD:comfy/web/assets/serverConfigStore-DulDGgjD.js
|
||||
import { d as defineStore, r as ref, q as computed } from "./index-BQYg0VNJ.js";
|
||||
========
|
||||
import { d as defineStore, r as ref, q as computed } from "./index-DIU5yZe9.js";
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/serverConfigStore-DYv7_Nld.js
|
||||
const useServerConfigStore = defineStore("serverConfig", () => {
|
||||
const serverConfigById = ref({});
|
||||
const serverConfigs = computed(() => {
|
||||
@ -87,4 +91,8 @@ const useServerConfigStore = defineStore("serverConfig", () => {
|
||||
export {
|
||||
useServerConfigStore as u
|
||||
};
|
||||
<<<<<<<< HEAD:comfy/web/assets/serverConfigStore-DulDGgjD.js
|
||||
//# sourceMappingURL=serverConfigStore-DulDGgjD.js.map
|
||||
========
|
||||
//# sourceMappingURL=serverConfigStore-DYv7_Nld.js.map
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/serverConfigStore-DYv7_Nld.js
|
||||
5
comfy/web/assets/serverConfigStore-DYv7_Nld.js.map
generated
vendored
Normal file
File diff suppressed because one or more lines are too long
1
comfy/web/assets/serverConfigStore-DulDGgjD.js.map
generated
vendored
File diff suppressed because one or more lines are too long
@ -1,6 +1,10 @@
|
||||
var __defProp = Object.defineProperty;
|
||||
var __name = (target, value) => __defProp(target, "name", { value, configurable: true });
|
||||
<<<<<<<< HEAD:comfy/web/assets/widgetInputs-BMOQhk10.js
|
||||
import { c as LGraphNode, b as app, cc as applyTextReplacements, cb as ComfyWidgets, cf as addValueControlWidgets, j as LiteGraph } from "./index-BQYg0VNJ.js";
|
||||
========
|
||||
import { c as LGraphNode, b as app, cg as applyTextReplacements, cf as ComfyWidgets, cj as addValueControlWidgets, j as LiteGraph } from "./index-DIU5yZe9.js";
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/widgetInputs-Bvm3AgOa.js
|
||||
const CONVERTED_TYPE = "converted-widget";
|
||||
const VALID_TYPES = [
|
||||
"STRING",
|
||||
@ -763,4 +767,8 @@ export {
|
||||
mergeIfValid,
|
||||
setWidgetConfig
|
||||
};
|
||||
<<<<<<<< HEAD:comfy/web/assets/widgetInputs-BMOQhk10.js
|
||||
//# sourceMappingURL=widgetInputs-BMOQhk10.js.map
|
||||
========
|
||||
//# sourceMappingURL=widgetInputs-Bvm3AgOa.js.map
|
||||
>>>>>>>> 57f330caf91af37dda67c4202bb27cdebb7161d8:comfy/web/assets/widgetInputs-Bvm3AgOa.js
|
||||
Some files were not shown because too many files have changed in this diff.