mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-16 01:37:04 +08:00
Compare commits
157 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ac3b26a7d | ||
|
|
6592bffc60 | ||
|
|
971cefe7d4 | ||
|
|
da2bfb5b0a | ||
|
|
c5a47a1692 | ||
|
|
908fd7d749 | ||
|
|
5495589db3 | ||
|
|
982876d59a | ||
|
|
338d9ae3bb | ||
|
|
eeb020b9b7 | ||
|
|
ae65433a60 | ||
|
|
fdebe18296 | ||
|
|
f8321eb57b | ||
|
|
93948e3fc5 | ||
|
|
e711aaf1a7 | ||
|
|
57ddb7fd13 | ||
|
|
17c92a9f28 | ||
|
|
36357bbcc3 | ||
|
|
f668c2e3c9 | ||
|
|
fc657f471a | ||
|
|
791e30ff50 | ||
|
|
e2a800e7ef | ||
|
|
9d252f3b70 | ||
|
|
b9fb542703 | ||
|
|
cabc4d351f | ||
|
|
e136b6dbb0 | ||
|
|
d50f342c90 | ||
|
|
3b0368aa34 | ||
|
|
935493f6c1 | ||
|
|
60ee574748 | ||
|
|
8e889c535d | ||
|
|
fd271dedfd | ||
|
|
c3c6313fc7 | ||
|
|
85c4b4ae26 | ||
|
|
058f084371 | ||
|
|
ec7f65187d | ||
|
|
56fa7dbe38 | ||
|
|
329480da5a | ||
|
|
4086acf3c2 | ||
|
|
50ca97e776 | ||
|
|
7ac7d69d94 | ||
|
|
76f18e955d | ||
|
|
d7a0aef650 | ||
|
|
913f86b727 | ||
|
|
117bf3f2bd | ||
|
|
ae676ed105 | ||
|
|
fd109325db | ||
|
|
bed12674a1 | ||
|
|
092ee8a500 | ||
|
|
79d17ba233 | ||
|
|
6fd463aec9 | ||
|
|
43071e3de3 | ||
|
|
0ec05b1481 | ||
|
|
35fa091340 | ||
|
|
3c8456223c | ||
|
|
9bc893c5bb | ||
|
|
f4bdf5f830 | ||
|
|
6be85c7920 | ||
|
|
ea17add3c6 | ||
|
|
ecdc8697d5 | ||
|
|
dce518c2b4 | ||
|
|
440268d394 | ||
|
|
87c104bfc1 | ||
|
|
19f2192d69 | ||
|
|
519c941165 | ||
|
|
861817d22d | ||
|
|
c120eee5ba | ||
|
|
73f5649196 | ||
|
|
3f512f5659 | ||
|
|
b94d394a64 | ||
|
|
277237ccc1 | ||
|
|
daaceac769 | ||
|
|
33d6aec3b7 | ||
|
|
44baa0b7f3 | ||
|
|
a17cf1c387 | ||
|
|
b4a20acc54 | ||
|
|
c55dc857d5 | ||
|
|
878db3a727 | ||
|
|
30c259cac8 | ||
|
|
1cb7e22a95 | ||
|
|
2640acb31c | ||
|
|
7dbd5dfe91 | ||
|
|
f8b981ae9a | ||
|
|
4967f81778 | ||
|
|
0a6746898d | ||
|
|
5151cff293 | ||
|
|
af96d9812d | ||
|
|
52a32e2b32 | ||
|
|
b907085709 | ||
|
|
065a2fbbec | ||
|
|
0ff0457892 | ||
|
|
6484ac89dc | ||
|
|
f55c98a89f | ||
|
|
ca7808f240 | ||
|
|
52e778fff3 | ||
|
|
9d8a817985 | ||
|
|
b59750a86a | ||
|
|
3f382a4f98 | ||
|
|
f17251bec6 | ||
|
|
c38e7d6599 | ||
|
|
eaf68c9b5b | ||
|
|
cc6a8dcd1a | ||
|
|
a2d60aad0f | ||
|
|
d8433c63fd | ||
|
|
dd41b74549 | ||
|
|
55f654db3d | ||
|
|
58c6ed541d | ||
|
|
234c3dc85f | ||
|
|
8908ee2628 | ||
|
|
1105e0d139 | ||
|
|
8938aa3f30 | ||
|
|
f16219e3aa | ||
|
|
8402c8700a | ||
|
|
58b8574661 | ||
|
|
90b3995ec8 | ||
|
|
bdb10a583f | ||
|
|
0e24dbb19f | ||
|
|
e9aae31fa2 | ||
|
|
0c18842acb | ||
|
|
d196a905bb | ||
|
|
18b79acba9 | ||
|
|
dff996ca39 | ||
|
|
828b1b9953 | ||
|
|
af81cb962d | ||
|
|
5c7b08ca58 | ||
|
|
6b573ae0cb | ||
|
|
015a0599d0 | ||
|
|
acfaa5c4a1 | ||
|
|
b6805429b9 | ||
|
|
25022e0b09 | ||
|
|
22a2644e57 | ||
|
|
b2ef58e2b1 | ||
|
|
6a6d456c88 | ||
|
|
3d1fdaf9f4 | ||
|
|
1286fcfe40 | ||
|
|
3bd71554a2 | ||
|
|
f66183a541 | ||
|
|
cbd68e3d58 | ||
|
|
d89c29f259 | ||
|
|
a9c35256bc | ||
|
|
532938b16b | ||
|
|
ecb683b057 | ||
|
|
c55fd74816 | ||
|
|
3398123752 | ||
|
|
943b3b615d | ||
|
|
10e90a5757 | ||
|
|
b75d349f25 | ||
|
|
7b8389578e | ||
|
|
9e00ce5b76 | ||
|
|
f5e66d5e47 | ||
|
|
87b0359392 | ||
|
|
cb96d4d18c | ||
|
|
394348f5ca | ||
|
|
7601e89255 | ||
|
|
6a1d3a1ae1 | ||
|
|
65ee24c978 | ||
|
|
17027f2a6a |
@ -53,6 +53,16 @@ try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash") # noqa: T201
|
||||
except:
|
||||
print("Could not stash, cleaning index and trying again.") # noqa: T201
|
||||
repo.state_cleanup()
|
||||
repo.index.read_tree(repo.head.peel().tree)
|
||||
repo.index.write()
|
||||
try:
|
||||
repo.stash(ident)
|
||||
except KeyError:
|
||||
print("nothing to stash.") # noqa: T201
|
||||
|
||||
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
|
||||
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
|
||||
try:
|
||||
@ -66,8 +76,10 @@ if branch is None:
|
||||
try:
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
except:
|
||||
print("pulling.") # noqa: T201
|
||||
pull(repo)
|
||||
print("fetching.") # noqa: T201
|
||||
for remote in repo.remotes:
|
||||
if remote.name == "origin":
|
||||
remote.fetch()
|
||||
ref = repo.lookup_reference('refs/remotes/origin/master')
|
||||
repo.checkout(ref)
|
||||
branch = repo.lookup_branch('master')
|
||||
@ -149,3 +161,4 @@ try:
|
||||
shutil.copy(stable_update_script, stable_update_script_to)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
As of the time of writing this you need this preview driver for best results:
|
||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-PREVIEW.html
|
||||
As of the time of writing this you need this driver for best results:
|
||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
|
||||
|
||||
HOW TO RUN:
|
||||
|
||||
@ -25,3 +25,4 @@ In the ComfyUI directory you will find a file: extra_model_paths.yaml.example
|
||||
Rename this file to: extra_model_paths.yaml and edit it with your favorite text editor.
|
||||
|
||||
|
||||
|
||||
|
||||
6
.github/workflows/release-stable-all.yml
vendored
6
.github/workflows/release-stable-all.yml
vendored
@ -14,7 +14,7 @@ jobs:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release NVIDIA Default (cu129)"
|
||||
name: "Release NVIDIA Default (cu130)"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
@ -65,11 +65,11 @@ jobs:
|
||||
contents: "write"
|
||||
packages: "write"
|
||||
pull-requests: "read"
|
||||
name: "Release AMD ROCm 6.4.4"
|
||||
name: "Release AMD ROCm 7.1.1"
|
||||
uses: ./.github/workflows/stable-release.yml
|
||||
with:
|
||||
git_tag: ${{ inputs.git_tag }}
|
||||
cache_tag: "rocm644"
|
||||
cache_tag: "rocm711"
|
||||
python_minor: "12"
|
||||
python_patch: "10"
|
||||
rel_name: "amd"
|
||||
|
||||
@ -1,3 +1,2 @@
|
||||
# Admins
|
||||
* @comfyanonymous
|
||||
* @kosinkadink
|
||||
* @comfyanonymous @kosinkadink @guill
|
||||
|
||||
29
README.md
29
README.md
@ -67,6 +67,8 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
|
||||
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
|
||||
- [Hunyuan Image 2.1](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_image/)
|
||||
- [Flux 2](https://comfyanonymous.github.io/ComfyUI_examples/flux2/)
|
||||
- [Z Image](https://comfyanonymous.github.io/ComfyUI_examples/z_image/)
|
||||
- Image Editing Models
|
||||
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
|
||||
- [Flux Kontext](https://comfyanonymous.github.io/ComfyUI_examples/flux/#flux-kontext-image-editing-model)
|
||||
@ -79,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
|
||||
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
|
||||
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
|
||||
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
|
||||
- [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
|
||||
- Audio Models
|
||||
- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
- [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
|
||||
@ -317,6 +320,32 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
||||
1. Install the Iluvatar Corex Toolkit by adhering to the platform-specific instructions on the [Installation](https://support.iluvatar.com/#/DocumentCentre?id=1&nameCenter=2&productId=520117912052801536)
|
||||
2. Launch ComfyUI by running `python main.py`
|
||||
|
||||
|
||||
## [ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4)
|
||||
|
||||
**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI.
|
||||
|
||||
### Setup
|
||||
|
||||
1. Install the manager dependencies:
|
||||
```bash
|
||||
pip install -r manager_requirements.txt
|
||||
```
|
||||
|
||||
2. Enable the manager with the `--enable-manager` flag when running ComfyUI:
|
||||
```bash
|
||||
python main.py --enable-manager
|
||||
```
|
||||
|
||||
### Command Line Options
|
||||
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
|
||||
|
||||
# Running
|
||||
|
||||
```python main.py```
|
||||
|
||||
@ -58,8 +58,13 @@ class InternalRoutes:
|
||||
return web.json_response({"error": "Invalid directory type"}, status=400)
|
||||
|
||||
directory = get_directory_by_type(directory_type)
|
||||
|
||||
def is_visible_file(entry: os.DirEntry) -> bool:
|
||||
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
|
||||
return entry.is_file() and not entry.name.startswith('.')
|
||||
|
||||
sorted_files = sorted(
|
||||
(entry for entry in os.scandir(directory) if entry.is_file()),
|
||||
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
|
||||
key=lambda entry: -entry.stat().st_mtime
|
||||
)
|
||||
return web.json_response([entry.name for entry in sorted_files], status=200)
|
||||
|
||||
@ -10,7 +10,8 @@ import importlib
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from pathlib import Path
|
||||
from typing import TypedDict, Optional
|
||||
from typing import Dict, TypedDict, Optional
|
||||
from aiohttp import web
|
||||
from importlib.metadata import version
|
||||
|
||||
import requests
|
||||
@ -257,7 +258,54 @@ comfyui-frontend-package is not installed.
|
||||
sys.exit(-1)
|
||||
|
||||
@classmethod
|
||||
def templates_path(cls) -> str:
|
||||
def template_asset_map(cls) -> Optional[Dict[str, str]]:
|
||||
"""Return a mapping of template asset names to their absolute paths."""
|
||||
try:
|
||||
from comfyui_workflow_templates import (
|
||||
get_asset_path,
|
||||
iter_templates,
|
||||
)
|
||||
except ImportError:
|
||||
logging.error(
|
||||
f"""
|
||||
********** ERROR ***********
|
||||
|
||||
comfyui-workflow-templates is not installed.
|
||||
|
||||
{frontend_install_warning_message()}
|
||||
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
template_entries = list(iter_templates())
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to enumerate workflow templates: {exc}")
|
||||
return None
|
||||
|
||||
asset_map: Dict[str, str] = {}
|
||||
try:
|
||||
for entry in template_entries:
|
||||
for asset in entry.assets:
|
||||
asset_map[asset.filename] = get_asset_path(
|
||||
entry.template_id, asset.filename
|
||||
)
|
||||
except Exception as exc:
|
||||
logging.error(f"Failed to resolve template asset paths: {exc}")
|
||||
return None
|
||||
|
||||
if not asset_map:
|
||||
logging.error("No workflow template assets found. Did the packages install correctly?")
|
||||
return None
|
||||
|
||||
return asset_map
|
||||
|
||||
|
||||
@classmethod
|
||||
def legacy_templates_path(cls) -> Optional[str]:
|
||||
"""Return the legacy templates directory shipped inside the meta package."""
|
||||
try:
|
||||
import comfyui_workflow_templates
|
||||
|
||||
@ -276,6 +324,7 @@ comfyui-workflow-templates is not installed.
|
||||
********** ERROR ***********
|
||||
""".strip()
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def embedded_docs_path(cls) -> str:
|
||||
@ -392,3 +441,17 @@ comfyui-workflow-templates is not installed.
|
||||
logging.info("Falling back to the default frontend.")
|
||||
check_frontend_version()
|
||||
return cls.default_frontend_path()
|
||||
@classmethod
|
||||
def template_asset_handler(cls):
|
||||
assets = cls.template_asset_map()
|
||||
if not assets:
|
||||
return None
|
||||
|
||||
async def serve_template(request: web.Request) -> web.StreamResponse:
|
||||
rel_path = request.match_info.get("path", "")
|
||||
target = assets.get(rel_path)
|
||||
if target is None:
|
||||
raise web.HTTPNotFound()
|
||||
return web.FileResponse(target)
|
||||
|
||||
return serve_template
|
||||
|
||||
@ -59,6 +59,9 @@ class UserManager():
|
||||
user = "default"
|
||||
if args.multi_user and "comfy-user" in request.headers:
|
||||
user = request.headers["comfy-user"]
|
||||
# Block System Users (use same error message to prevent probing)
|
||||
if user.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise KeyError("Unknown user: " + user)
|
||||
|
||||
if user not in self.users:
|
||||
raise KeyError("Unknown user: " + user)
|
||||
@ -66,15 +69,16 @@ class UserManager():
|
||||
return user
|
||||
|
||||
def get_request_user_filepath(self, request, file, type="userdata", create_dir=True):
|
||||
user_directory = folder_paths.get_user_directory()
|
||||
|
||||
if type == "userdata":
|
||||
root_dir = user_directory
|
||||
root_dir = folder_paths.get_user_directory()
|
||||
else:
|
||||
raise KeyError("Unknown filepath type:" + type)
|
||||
|
||||
user = self.get_request_user_id(request)
|
||||
path = user_root = os.path.abspath(os.path.join(root_dir, user))
|
||||
user_root = folder_paths.get_public_user_directory(user)
|
||||
if user_root is None:
|
||||
return None
|
||||
path = user_root
|
||||
|
||||
# prevent leaving /{type}
|
||||
if os.path.commonpath((root_dir, user_root)) != root_dir:
|
||||
@ -101,7 +105,11 @@ class UserManager():
|
||||
name = name.strip()
|
||||
if not name:
|
||||
raise ValueError("username not provided")
|
||||
if name.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name)
|
||||
if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX):
|
||||
raise ValueError("System User prefix not allowed")
|
||||
user_id = user_id + "_" + str(uuid.uuid4())
|
||||
|
||||
self.users[user_id] = name
|
||||
@ -132,7 +140,10 @@ class UserManager():
|
||||
if username in self.users.values():
|
||||
return web.json_response({"error": "Duplicate username."}, status=400)
|
||||
|
||||
user_id = self.add_user(username)
|
||||
try:
|
||||
user_id = self.add_user(username)
|
||||
except ValueError as e:
|
||||
return web.json_response({"error": str(e)}, status=400)
|
||||
return web.json_response(user_id)
|
||||
|
||||
@routes.get("/userdata")
|
||||
@ -424,7 +435,7 @@ class UserManager():
|
||||
return source
|
||||
|
||||
dest = get_user_data_path(request, check_exists=False, param="dest")
|
||||
if not isinstance(source, str):
|
||||
if not isinstance(dest, str):
|
||||
return dest
|
||||
|
||||
overwrite = request.query.get("overwrite", 'true') != "false"
|
||||
|
||||
@ -413,7 +413,8 @@ class ControlNet(nn.Module):
|
||||
out_middle = []
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
if y is None:
|
||||
raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?")
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x
|
||||
|
||||
@ -121,6 +121,12 @@ upcast.add_argument("--force-upcast-attention", action="store_true", help="Force
|
||||
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
|
||||
|
||||
|
||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
||||
manager_group = parser.add_mutually_exclusive_group()
|
||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
||||
|
||||
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||
vram_group.add_argument("--highvram", action="store_true", help="By default models will be unloaded to CPU memory after being used. This option keeps them in GPU memory.")
|
||||
@ -131,7 +137,8 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
|
||||
|
||||
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
|
||||
|
||||
parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@ -160,13 +167,14 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
|
||||
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
|
||||
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
|
||||
parser.add_argument("--whitelist-custom-nodes", type=str, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
|
||||
parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.")
|
||||
|
||||
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
|
||||
|
||||
parser.add_argument("--verbose", default='INFO', const='DEBUG', nargs="?", choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], help='Set the logging level')
|
||||
parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
|
||||
|
||||
|
||||
# The default built-in provider hosted under web/
|
||||
DEFAULT_VERSION_STRING = "comfyanonymous/ComfyUI@latest"
|
||||
|
||||
|
||||
@ -51,26 +51,36 @@ class ContextHandlerABC(ABC):
|
||||
|
||||
|
||||
class IndexListContextWindow(ContextWindowABC):
|
||||
def __init__(self, index_list: list[int], dim: int=0):
|
||||
def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
|
||||
self.index_list = index_list
|
||||
self.context_length = len(index_list)
|
||||
self.dim = dim
|
||||
self.total_frames = total_frames
|
||||
self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
|
||||
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
|
||||
def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
if dim == 0 and full.shape[dim] == 1:
|
||||
return full
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
return full[idx].to(device)
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
window = full[idx]
|
||||
if retain_index_list:
|
||||
idx = tuple([slice(None)] * dim + [retain_index_list])
|
||||
window[idx] = full[idx]
|
||||
return window.to(device)
|
||||
|
||||
def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
|
||||
if dim is None:
|
||||
dim = self.dim
|
||||
idx = [slice(None)] * dim + [self.index_list]
|
||||
idx = tuple([slice(None)] * dim + [self.index_list])
|
||||
full[idx] += to_add
|
||||
return full
|
||||
|
||||
def get_region_index(self, num_regions: int) -> int:
|
||||
region_idx = int(self.center_ratio * num_regions)
|
||||
return min(max(region_idx, 0), num_regions - 1)
|
||||
|
||||
|
||||
class IndexListCallbacks:
|
||||
EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
|
||||
@ -94,7 +104,8 @@ class ContextFuseMethod:
|
||||
|
||||
ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
|
||||
class IndexListContextHandler(ContextHandlerABC):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
|
||||
def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
|
||||
closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
|
||||
self.context_schedule = context_schedule
|
||||
self.fuse_method = fuse_method
|
||||
self.context_length = context_length
|
||||
@ -103,13 +114,18 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
self.closed_loop = closed_loop
|
||||
self.dim = dim
|
||||
self._step = 0
|
||||
self.freenoise = freenoise
|
||||
self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
|
||||
self.split_conds_to_windows = split_conds_to_windows
|
||||
|
||||
self.callbacks = {}
|
||||
|
||||
def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
|
||||
# for now, assume first dim is batch - should have stored on BaseModel in actual implementation
|
||||
if x_in.size(self.dim) > self.context_length:
|
||||
logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
|
||||
logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
|
||||
if self.cond_retain_index_list:
|
||||
logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -123,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return None
|
||||
# reuse or resize cond items to match context requirements
|
||||
resized_cond = []
|
||||
# if multiple conds, split based on primary region
|
||||
if self.split_conds_to_windows and len(cond_in) > 1:
|
||||
region = window.get_region_index(len(cond_in))
|
||||
logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
|
||||
cond_in = [cond_in[region]]
|
||||
# cond object is a list containing a dict - outer list is irrelevant, so just loop through it
|
||||
for actual_cond in cond_in:
|
||||
resized_actual_cond = actual_cond.copy()
|
||||
@ -146,12 +167,19 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
# when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
|
||||
for cond_key, cond_value in new_cond_item.items():
|
||||
if isinstance(cond_value, torch.Tensor):
|
||||
if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
|
||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||
# Handle audio_embed (temporal dim is 1)
|
||||
elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
audio_cond = cond_value.cond
|
||||
if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
|
||||
# if has cond that is a Tensor, check if needs to be subset
|
||||
elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
|
||||
if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
|
||||
(cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
|
||||
new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
|
||||
elif cond_key == "num_video_frames": # for SVD
|
||||
new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
|
||||
new_cond_item[cond_key].cond = window.context_length
|
||||
@ -164,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
return resized_cond
|
||||
|
||||
def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
|
||||
mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
|
||||
matches = torch.nonzero(mask)
|
||||
if torch.numel(matches) == 0:
|
||||
raise Exception("No sample_sigmas matched current timestep; something went wrong.")
|
||||
@ -173,7 +201,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
|
||||
full_length = x_in.size(self.dim) # TODO: choose dim based on model
|
||||
context_windows = self.context_schedule.func(full_length, self, model_options)
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
|
||||
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
|
||||
return context_windows
|
||||
|
||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||
@ -250,8 +278,8 @@ class IndexListContextHandler(ContextHandlerABC):
|
||||
prev_weight = (bias_total / (bias_total + bias))
|
||||
new_weight = (bias / (bias_total + bias))
|
||||
# account for dims of tensors
|
||||
idx_window = [slice(None)] * self.dim + [idx]
|
||||
pos_window = [slice(None)] * self.dim + [pos]
|
||||
idx_window = tuple([slice(None)] * self.dim + [idx])
|
||||
pos_window = tuple([slice(None)] * self.dim + [pos])
|
||||
# apply new values
|
||||
conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
|
||||
biases_final[i][idx] = bias_total + bias
|
||||
@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
|
||||
)
|
||||
|
||||
|
||||
def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
|
||||
model_options = extra_args.get("model_options", None)
|
||||
if model_options is None:
|
||||
raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
handler: IndexListContextHandler = model_options.get("context_handler", None)
|
||||
if handler is None:
|
||||
raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
|
||||
if not handler.freenoise:
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
|
||||
|
||||
return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
|
||||
|
||||
|
||||
def create_sampler_sample_wrapper(model: ModelPatcher):
|
||||
model.add_wrapper_with_key(
|
||||
comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
|
||||
"ContextWindows_sampler_sample",
|
||||
_sampler_sample_wrapper
|
||||
)
|
||||
|
||||
|
||||
def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
|
||||
total_dims = len(x_in.shape)
|
||||
weights_tensor = torch.Tensor(weights).to(device=device)
|
||||
@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
|
||||
for i in range(len(window)):
|
||||
# 2) add end_delta to each val to slide windows to end
|
||||
window[i] = window[i] + end_delta
|
||||
|
||||
|
||||
# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
|
||||
def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
|
||||
logging.info("Context windows: Applying FreeNoise")
|
||||
generator = torch.Generator(device='cpu').manual_seed(seed)
|
||||
latent_video_length = noise.shape[dim]
|
||||
delta = context_length - context_overlap
|
||||
|
||||
for start_idx in range(0, latent_video_length - context_length, delta):
|
||||
place_idx = start_idx + context_length
|
||||
|
||||
actual_delta = min(delta, latent_video_length - place_idx)
|
||||
if actual_delta <= 0:
|
||||
break
|
||||
|
||||
list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
|
||||
|
||||
source_slice = [slice(None)] * noise.ndim
|
||||
source_slice[dim] = list_idx
|
||||
target_slice = [slice(None)] * noise.ndim
|
||||
target_slice[dim] = slice(place_idx, place_idx + actual_delta)
|
||||
|
||||
noise[tuple(target_slice)] = noise[tuple(source_slice)]
|
||||
|
||||
return noise
|
||||
|
||||
@ -1557,10 +1557,13 @@ def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
|
||||
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
|
||||
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
|
||||
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
|
||||
"""
|
||||
if solver_type not in {"phi_1", "phi_2"}:
|
||||
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
|
||||
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
seed = extra_args.get("seed", None)
|
||||
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
|
||||
@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
||||
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
|
||||
|
||||
# Step 2
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
if solver_type == "phi_1":
|
||||
denoised_d = torch.lerp(denoised, denoised_2, fac)
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
|
||||
elif solver_type == "phi_2":
|
||||
b2 = ei_h_phi_2(-h_eta) / r
|
||||
b1 = ei_h_phi_1(-h_eta) - b2
|
||||
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
|
||||
|
||||
if inject_noise:
|
||||
segment_factor = (r - 1) * h * eta
|
||||
sde_noise = sde_noise * segment_factor.exp()
|
||||
|
||||
@ -6,6 +6,7 @@ class LatentFormat:
|
||||
latent_dimensions = 2
|
||||
latent_rgb_factors = None
|
||||
latent_rgb_factors_bias = None
|
||||
latent_rgb_factors_reshape = None
|
||||
taesd_decoder_name = None
|
||||
|
||||
def process_in(self, latent):
|
||||
@ -178,6 +179,54 @@ class Flux(SD3):
|
||||
def process_out(self, latent):
|
||||
return (latent / self.scale_factor) + self.shift_factor
|
||||
|
||||
class Flux2(LatentFormat):
|
||||
latent_channels = 128
|
||||
|
||||
def __init__(self):
|
||||
self.latent_rgb_factors =[
|
||||
[0.0058, 0.0113, 0.0073],
|
||||
[0.0495, 0.0443, 0.0836],
|
||||
[-0.0099, 0.0096, 0.0644],
|
||||
[0.2144, 0.3009, 0.3652],
|
||||
[0.0166, -0.0039, -0.0054],
|
||||
[0.0157, 0.0103, -0.0160],
|
||||
[-0.0398, 0.0902, -0.0235],
|
||||
[-0.0052, 0.0095, 0.0109],
|
||||
[-0.3527, -0.2712, -0.1666],
|
||||
[-0.0301, -0.0356, -0.0180],
|
||||
[-0.0107, 0.0078, 0.0013],
|
||||
[0.0746, 0.0090, -0.0941],
|
||||
[0.0156, 0.0169, 0.0070],
|
||||
[-0.0034, -0.0040, -0.0114],
|
||||
[0.0032, 0.0181, 0.0080],
|
||||
[-0.0939, -0.0008, 0.0186],
|
||||
[0.0018, 0.0043, 0.0104],
|
||||
[0.0284, 0.0056, -0.0127],
|
||||
[-0.0024, -0.0022, -0.0030],
|
||||
[0.1207, -0.0026, 0.0065],
|
||||
[0.0128, 0.0101, 0.0142],
|
||||
[0.0137, -0.0072, -0.0007],
|
||||
[0.0095, 0.0092, -0.0059],
|
||||
[0.0000, -0.0077, -0.0049],
|
||||
[-0.0465, -0.0204, -0.0312],
|
||||
[0.0095, 0.0012, -0.0066],
|
||||
[0.0290, -0.0034, 0.0025],
|
||||
[0.0220, 0.0169, -0.0048],
|
||||
[-0.0332, -0.0457, -0.0468],
|
||||
[-0.0085, 0.0389, 0.0609],
|
||||
[-0.0076, 0.0003, -0.0043],
|
||||
[-0.0111, -0.0460, -0.0614],
|
||||
]
|
||||
|
||||
self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
|
||||
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
|
||||
|
||||
def process_in(self, latent):
|
||||
return latent
|
||||
|
||||
def process_out(self, latent):
|
||||
return latent
|
||||
|
||||
class Mochi(LatentFormat):
|
||||
latent_channels = 12
|
||||
latent_dimensions = 3
|
||||
@ -382,6 +431,7 @@ class HunyuanVideo(LatentFormat):
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761]
|
||||
taesd_decoder_name = "taehv"
|
||||
|
||||
class Cosmos1CV8x8x8(LatentFormat):
|
||||
latent_channels = 16
|
||||
@ -445,7 +495,7 @@ class Wan21(LatentFormat):
|
||||
]).view(1, self.latent_channels, 1, 1, 1)
|
||||
|
||||
|
||||
self.taesd_decoder_name = None #TODO
|
||||
self.taesd_decoder_name = "lighttaew2_1"
|
||||
|
||||
def process_in(self, latent):
|
||||
latents_mean = self.latents_mean.to(latent.device, latent.dtype)
|
||||
@ -516,6 +566,7 @@ class Wan22(Wan21):
|
||||
|
||||
def __init__(self):
|
||||
self.scale_factor = 1.0
|
||||
self.taesd_decoder_name = "lighttaew2_2"
|
||||
self.latents_mean = torch.tensor([
|
||||
-0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557,
|
||||
-0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825,
|
||||
@ -611,6 +662,67 @@ class HunyuanImage21Refiner(LatentFormat):
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
|
||||
def process_in(self, latent):
|
||||
out = latent * self.scale_factor
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
return out
|
||||
|
||||
def process_out(self, latent):
|
||||
z = latent / self.scale_factor
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
return z
|
||||
|
||||
class HunyuanVideo15(LatentFormat):
|
||||
latent_rgb_factors = [
|
||||
[ 0.0568, -0.0521, -0.0131],
|
||||
[ 0.0014, 0.0735, 0.0326],
|
||||
[ 0.0186, 0.0531, -0.0138],
|
||||
[-0.0031, 0.0051, 0.0288],
|
||||
[ 0.0110, 0.0556, 0.0432],
|
||||
[-0.0041, -0.0023, -0.0485],
|
||||
[ 0.0530, 0.0413, 0.0253],
|
||||
[ 0.0283, 0.0251, 0.0339],
|
||||
[ 0.0277, -0.0372, -0.0093],
|
||||
[ 0.0393, 0.0944, 0.1131],
|
||||
[ 0.0020, 0.0251, 0.0037],
|
||||
[-0.0017, 0.0012, 0.0234],
|
||||
[ 0.0468, 0.0436, 0.0203],
|
||||
[ 0.0354, 0.0439, -0.0233],
|
||||
[ 0.0090, 0.0123, 0.0346],
|
||||
[ 0.0382, 0.0029, 0.0217],
|
||||
[ 0.0261, -0.0300, 0.0030],
|
||||
[-0.0088, -0.0220, -0.0283],
|
||||
[-0.0272, -0.0121, -0.0363],
|
||||
[-0.0664, -0.0622, 0.0144],
|
||||
[ 0.0414, 0.0479, 0.0529],
|
||||
[ 0.0355, 0.0612, -0.0247],
|
||||
[ 0.0147, 0.0264, 0.0174],
|
||||
[ 0.0438, 0.0038, 0.0542],
|
||||
[ 0.0431, -0.0573, -0.0033],
|
||||
[-0.0162, -0.0211, -0.0406],
|
||||
[-0.0487, -0.0295, -0.0393],
|
||||
[ 0.0005, -0.0109, 0.0253],
|
||||
[ 0.0296, 0.0591, 0.0353],
|
||||
[ 0.0119, 0.0181, -0.0306],
|
||||
[-0.0085, -0.0362, 0.0229],
|
||||
[ 0.0005, -0.0106, 0.0242]
|
||||
]
|
||||
|
||||
latent_rgb_factors_bias = [ 0.0456, -0.0202, -0.0644]
|
||||
latent_channels = 32
|
||||
latent_dimensions = 3
|
||||
scale_factor = 1.03682
|
||||
taesd_decoder_name = "lighttaehy1_5"
|
||||
|
||||
class Hunyuan3Dv2(LatentFormat):
|
||||
latent_channels = 64
|
||||
latent_dimensions = 1
|
||||
|
||||
@ -40,7 +40,8 @@ class ChromaParams:
|
||||
out_dim: int
|
||||
hidden_dim: int
|
||||
n_layers: int
|
||||
|
||||
txt_ids_dims: list
|
||||
vec_in_dim: int
|
||||
|
||||
|
||||
|
||||
@ -179,7 +180,10 @@ class Chroma(nn.Module):
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_mmdit:
|
||||
double_mod = (
|
||||
self.get_modulations(mod_vectors, "double_img", idx=i),
|
||||
@ -222,7 +226,10 @@ class Chroma(nn.Module):
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if i not in self.skip_dit:
|
||||
single_mod = self.get_modulations(mod_vectors, "single", idx=i)
|
||||
if ("single_block", i) in blocks_replace:
|
||||
|
||||
@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
|
||||
nerf_final_head_type: str
|
||||
# None means use the same dtype as the model.
|
||||
nerf_embedder_dtype: Optional[torch.dtype]
|
||||
|
||||
use_x0: bool
|
||||
|
||||
class ChromaRadiance(Chroma):
|
||||
"""
|
||||
@ -159,6 +159,9 @@ class ChromaRadiance(Chroma):
|
||||
self.skip_dit = []
|
||||
self.lite = False
|
||||
|
||||
if params.use_x0:
|
||||
self.register_buffer("__x0__", torch.tensor([]))
|
||||
|
||||
@property
|
||||
def _nerf_final_layer(self) -> nn.Module:
|
||||
if self.params.nerf_final_head_type == "linear":
|
||||
@ -276,6 +279,12 @@ class ChromaRadiance(Chroma):
|
||||
params_dict |= overrides
|
||||
return params.__class__(**params_dict)
|
||||
|
||||
def _apply_x0_residual(self, predicted, noisy, timesteps):
|
||||
|
||||
# non zero during training to prevent 0 div
|
||||
eps = 0.0
|
||||
return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
|
||||
|
||||
def _forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
@ -316,4 +325,11 @@ class ChromaRadiance(Chroma):
|
||||
transformer_options,
|
||||
attn_mask=kwargs.get("attention_mask", None),
|
||||
)
|
||||
return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
|
||||
out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
|
||||
|
||||
# If x0 variant → v-pred, just return this instead
|
||||
if hasattr(self, "__x0__"):
|
||||
out = self._apply_x0_residual(out, img, timestep)
|
||||
return out
|
||||
|
||||
|
||||
@ -48,15 +48,44 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
|
||||
return embedding
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None):
|
||||
def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device)
|
||||
self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
class YakMLP(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device)
|
||||
self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
return down_proj
|
||||
|
||||
def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
if yak_mlp:
|
||||
return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations)
|
||||
if mlp_silu_act:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device),
|
||||
SiLUActivation(),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device),
|
||||
)
|
||||
else:
|
||||
return nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, dtype=None, device=None, operations=None):
|
||||
@ -80,14 +109,14 @@ class QKNorm(torch.nn.Module):
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
|
||||
self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -98,11 +127,11 @@ class ModulationOut:
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None):
|
||||
def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.is_double = double
|
||||
self.multiplier = 6 if double else 3
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
|
||||
self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, vec: Tensor) -> tuple:
|
||||
if vec.ndim == 2:
|
||||
@ -129,8 +158,18 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
return tensor
|
||||
|
||||
|
||||
class SiLUActivation(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.gate_fn = nn.SiLU()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x1, x2 = x.chunk(2, dim=-1)
|
||||
return self.gate_fn(x1) * x2
|
||||
|
||||
|
||||
class DoubleStreamBlock(nn.Module):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
@ -142,27 +181,22 @@ class DoubleStreamBlock(nn.Module):
|
||||
self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.img_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if self.modulation:
|
||||
self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
|
||||
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.txt_mlp = nn.Sequential(
|
||||
operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
|
||||
nn.GELU(approximate="tanh"),
|
||||
operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
|
||||
self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.flipped_img_txt = flipped_img_txt
|
||||
|
||||
def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}):
|
||||
@ -246,6 +280,9 @@ class SingleStreamBlock(nn.Module):
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: float = None,
|
||||
modulation=True,
|
||||
mlp_silu_act=False,
|
||||
bias=True,
|
||||
yak_mlp=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None
|
||||
@ -257,17 +294,29 @@ class SingleStreamBlock(nn.Module):
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
|
||||
self.mlp_hidden_dim_first = self.mlp_hidden_dim
|
||||
self.yak_mlp = yak_mlp
|
||||
if mlp_silu_act:
|
||||
self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2)
|
||||
self.mlp_act = SiLUActivation()
|
||||
else:
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
if self.yak_mlp:
|
||||
self.mlp_hidden_dim_first *= 2
|
||||
self.mlp_act = nn.SiLU()
|
||||
|
||||
# qkv and mlp_in
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
|
||||
self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device)
|
||||
# proj and mlp_out
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
|
||||
self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device)
|
||||
|
||||
self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
if modulation:
|
||||
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
@ -279,7 +328,7 @@ class SingleStreamBlock(nn.Module):
|
||||
else:
|
||||
mod = vec
|
||||
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||
|
||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
del qkv
|
||||
@ -289,7 +338,10 @@ class SingleStreamBlock(nn.Module):
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
# compute activation in mlp stream, cat again and run second linear layer
|
||||
mlp = self.mlp_act(mlp)
|
||||
if self.yak_mlp:
|
||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||
else:
|
||||
mlp = self.mlp_act(mlp)
|
||||
output = self.linear2(torch.cat((attn, mlp), 2))
|
||||
x += apply_mod(output, mod.gate, None, modulation_dims)
|
||||
if x.dtype == torch.float16:
|
||||
@ -298,11 +350,11 @@ class SingleStreamBlock(nn.Module):
|
||||
|
||||
|
||||
class LastLayer(nn.Module):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
|
||||
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device))
|
||||
|
||||
def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
|
||||
if vec.ndim == 2:
|
||||
|
||||
@ -15,6 +15,8 @@ from .layers import (
|
||||
MLPEmbedder,
|
||||
SingleStreamBlock,
|
||||
timestep_embedding,
|
||||
Modulation,
|
||||
RMSNorm
|
||||
)
|
||||
|
||||
@dataclass
|
||||
@ -33,6 +35,14 @@ class FluxParams:
|
||||
patch_size: int
|
||||
qkv_bias: bool
|
||||
guidance_embed: bool
|
||||
txt_ids_dims: list
|
||||
global_modulation: bool = False
|
||||
mlp_silu_act: bool = False
|
||||
ops_bias: bool = True
|
||||
default_ref_method: str = "offset"
|
||||
ref_index_scale: float = 1.0
|
||||
yak_mlp: bool = False
|
||||
txt_norm: bool = False
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
@ -58,13 +68,22 @@ class Flux(nn.Module):
|
||||
self.hidden_size = params.hidden_size
|
||||
self.num_heads = params.num_heads
|
||||
self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
if params.vec_in_dim is not None:
|
||||
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.vector_in = None
|
||||
|
||||
self.guidance_in = (
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity()
|
||||
)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
|
||||
self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device)
|
||||
|
||||
if params.txt_norm:
|
||||
self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations)
|
||||
else:
|
||||
self.txt_norm = None
|
||||
|
||||
self.double_blocks = nn.ModuleList(
|
||||
[
|
||||
@ -73,6 +92,10 @@ class Flux(nn.Module):
|
||||
self.num_heads,
|
||||
mlp_ratio=params.mlp_ratio,
|
||||
qkv_bias=params.qkv_bias,
|
||||
modulation=params.global_modulation is False,
|
||||
mlp_silu_act=params.mlp_silu_act,
|
||||
proj_bias=params.ops_bias,
|
||||
yak_mlp=params.yak_mlp,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(params.depth)
|
||||
@ -81,13 +104,30 @@ class Flux(nn.Module):
|
||||
|
||||
self.single_blocks = nn.ModuleList(
|
||||
[
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
|
||||
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations)
|
||||
for _ in range(params.depth_single_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if params.global_modulation:
|
||||
self.double_stream_modulation_img = Modulation(
|
||||
self.hidden_size,
|
||||
double=True,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.double_stream_modulation_txt = Modulation(
|
||||
self.hidden_size,
|
||||
double=True,
|
||||
bias=False,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.single_stream_modulation = Modulation(
|
||||
self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
@ -103,9 +143,6 @@ class Flux(nn.Module):
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
if img.ndim != 3 or txt.ndim != 3:
|
||||
@ -118,9 +155,19 @@ class Flux(nn.Module):
|
||||
if guidance is not None:
|
||||
vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
|
||||
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
if self.vector_in is not None:
|
||||
if y is None:
|
||||
y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype)
|
||||
vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
|
||||
|
||||
if self.txt_norm is not None:
|
||||
txt = self.txt_norm(txt)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||
@ -136,7 +183,10 @@ class Flux(nn.Module):
|
||||
pe = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -177,7 +227,13 @@ class Flux(nn.Module):
|
||||
|
||||
img = torch.cat((txt, img), 1)
|
||||
|
||||
if self.params.global_modulation:
|
||||
vec, _ = self.single_stream_modulation(vec_orig)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -207,7 +263,7 @@ class Flux(nn.Module):
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
|
||||
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||
@ -234,10 +290,10 @@ class Flux(nn.Module):
|
||||
h_offset += rope_options.get("shift_y", 0.0)
|
||||
w_offset += rope_options.get("shift_x", 0.0)
|
||||
|
||||
img_ids = torch.zeros((steps_h, steps_w, 3), device=x.device, dtype=x.dtype)
|
||||
img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
img_ids[:, :, 0] = img_ids[:, :, 1] + index
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=x.dtype).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1)
|
||||
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0)
|
||||
return img, repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options={}, **kwargs):
|
||||
@ -259,10 +315,10 @@ class Flux(nn.Module):
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_latents_method = kwargs.get("ref_latents_method", "offset")
|
||||
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
||||
for ref in ref_latents:
|
||||
if ref_latents_method == "index":
|
||||
index += 1
|
||||
index += self.params.ref_index_scale
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
elif ref_latents_method == "uxo":
|
||||
@ -286,7 +342,12 @@ class Flux(nn.Module):
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
if len(self.params.txt_ids_dims) > 0:
|
||||
for i in self.params.txt_ids_dims:
|
||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h_orig,:w_orig]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
||||
|
||||
@ -6,7 +6,6 @@ import comfy.ldm.flux.layers
|
||||
import comfy.ldm.modules.diffusionmodules.mmdit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
|
||||
|
||||
from dataclasses import dataclass
|
||||
from einops import repeat
|
||||
|
||||
@ -42,6 +41,9 @@ class HunyuanVideoParams:
|
||||
guidance_embed: bool
|
||||
byt5: bool
|
||||
meanflow: bool
|
||||
use_cond_type_embedding: bool
|
||||
vision_in_dim: int
|
||||
meanflow_sum: bool
|
||||
|
||||
|
||||
class SelfAttentionRef(nn.Module):
|
||||
@ -157,7 +159,10 @@ class TokenRefiner(nn.Module):
|
||||
t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype))
|
||||
# m = mask.float().unsqueeze(-1)
|
||||
# c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
if x.dtype == torch.float16:
|
||||
c = x.float().sum(dim=1) / x.shape[1]
|
||||
else:
|
||||
c = x.sum(dim=1) / x.shape[1]
|
||||
|
||||
c = t + self.c_embedder(c.to(x.dtype))
|
||||
x = self.input_embedder(x)
|
||||
@ -196,11 +201,15 @@ class HunyuanVideo(nn.Module):
|
||||
def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
|
||||
super().__init__()
|
||||
self.dtype = dtype
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
params = HunyuanVideoParams(**kwargs)
|
||||
self.params = params
|
||||
self.patch_size = params.patch_size
|
||||
self.in_channels = params.in_channels
|
||||
self.out_channels = params.out_channels
|
||||
self.use_cond_type_embedding = params.use_cond_type_embedding
|
||||
self.vision_in_dim = params.vision_in_dim
|
||||
if params.hidden_size % params.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
|
||||
@ -266,6 +275,18 @@ class HunyuanVideo(nn.Module):
|
||||
if final_layer:
|
||||
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
# HunyuanVideo 1.5 specific modules
|
||||
if self.vision_in_dim is not None:
|
||||
from comfy.ldm.wan.model import MLPProj
|
||||
self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings)
|
||||
else:
|
||||
self.vision_in = None
|
||||
if self.use_cond_type_embedding:
|
||||
# 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature
|
||||
self.cond_type_embedding = nn.Embedding(3, self.hidden_size)
|
||||
else:
|
||||
self.cond_type_embedding = None
|
||||
|
||||
def forward_orig(
|
||||
self,
|
||||
img: Tensor,
|
||||
@ -276,6 +297,7 @@ class HunyuanVideo(nn.Module):
|
||||
timesteps: Tensor,
|
||||
y: Tensor = None,
|
||||
txt_byt5=None,
|
||||
clip_fea=None,
|
||||
guidance: Tensor = None,
|
||||
guiding_frame_index=None,
|
||||
ref_latent=None,
|
||||
@ -296,7 +318,7 @@ class HunyuanVideo(nn.Module):
|
||||
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
|
||||
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
|
||||
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
|
||||
vec = (vec + vec_r) / 2
|
||||
vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
|
||||
|
||||
if ref_latent is not None:
|
||||
ref_latent_ids = self.img_ids(ref_latent)
|
||||
@ -331,12 +353,31 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options)
|
||||
|
||||
if self.cond_type_embedding is not None:
|
||||
self.cond_type_embedding.to(txt.device)
|
||||
cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long))
|
||||
txt = txt + cond_emb.to(txt.dtype)
|
||||
|
||||
if self.byt5_in is not None and txt_byt5 is not None:
|
||||
txt_byt5 = self.byt5_in(txt_byt5)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long))
|
||||
txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype)
|
||||
txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5
|
||||
else:
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt = torch.cat((txt, txt_byt5), dim=1)
|
||||
txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1)
|
||||
|
||||
if clip_fea is not None:
|
||||
txt_vision_states = self.vision_in(clip_fea)
|
||||
if self.cond_type_embedding is not None:
|
||||
cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device))
|
||||
txt_vision_states = txt_vision_states + cond_emb
|
||||
txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1)
|
||||
extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype)
|
||||
txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1)
|
||||
|
||||
ids = torch.cat((img_ids, txt_ids), dim=1)
|
||||
pe = self.pe_embedder(ids)
|
||||
|
||||
@ -349,7 +390,10 @@ class HunyuanVideo(nn.Module):
|
||||
attn_mask = None
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.double_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -371,7 +415,10 @@ class HunyuanVideo(nn.Module):
|
||||
|
||||
img = torch.cat((img, txt), 1)
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
for i, block in enumerate(self.single_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("single_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
@ -430,14 +477,14 @@ class HunyuanVideo(nn.Module):
|
||||
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
|
||||
return repeat(img_ids, "h w c -> b (h w) c", b=bs)
|
||||
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs)
|
||||
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options={}, **kwargs):
|
||||
bs = x.shape[0]
|
||||
if len(self.patch_size) == 3:
|
||||
img_ids = self.img_ids(x)
|
||||
@ -445,5 +492,5 @@ class HunyuanVideo(nn.Module):
|
||||
else:
|
||||
img_ids = self.img_ids_2d(x)
|
||||
txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options)
|
||||
return out
|
||||
|
||||
121
comfy/ldm/hunyuan_video/upsampler.py
Normal file
121
comfy/ldm/hunyuan_video/upsampler.py
Normal file
@ -0,0 +1,121 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
|
||||
from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
|
||||
import model_management, model_patcher
|
||||
|
||||
class SRResidualCausalBlock3D(nn.Module):
|
||||
def __init__(self, channels: int):
|
||||
super().__init__()
|
||||
self.block = nn.Sequential(
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
nn.SiLU(inplace=True),
|
||||
VideoConv3d(channels, channels, kernel_size=3),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x + self.block(x)
|
||||
|
||||
class SRModel3DV2(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
hidden_channels: int = 64,
|
||||
num_blocks: int = 6,
|
||||
global_residual: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3)
|
||||
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
|
||||
self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3)
|
||||
self.global_residual = bool(global_residual)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
y = self.in_conv(x)
|
||||
for blk in self.blocks:
|
||||
y = blk(y)
|
||||
y = self.out_conv(y)
|
||||
if self.global_residual and (y.shape == residual.shape):
|
||||
y = y + residual
|
||||
return y
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
z_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: tuple[int, ...],
|
||||
num_res_blocks: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.block_out_channels = block_out_channels
|
||||
self.z_channels = z_channels
|
||||
|
||||
ch = block_out_channels[0]
|
||||
self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
|
||||
for i, tgt in enumerate(block_out_channels):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt,
|
||||
out_channels=tgt,
|
||||
temb_channels=0,
|
||||
conv_shortcut=False,
|
||||
conv_op=VideoConv3d, norm_op=RMS_norm)
|
||||
for j in range(num_res_blocks + 1)])
|
||||
ch = tgt
|
||||
self.up.append(stage)
|
||||
|
||||
self.norm_out = RMS_norm(ch)
|
||||
self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3)
|
||||
|
||||
def forward(self, z):
|
||||
"""
|
||||
Args:
|
||||
z: (B, C, T, H, W)
|
||||
target_shape: (H, W)
|
||||
"""
|
||||
# z to block_in
|
||||
repeats = self.block_out_channels[0] // (self.z_channels)
|
||||
x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
|
||||
|
||||
# upsampling
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x)))
|
||||
return out
|
||||
|
||||
UPSAMPLERS = {
|
||||
"720p": SRModel3DV2,
|
||||
"1080p": Upsampler,
|
||||
}
|
||||
|
||||
class HunyuanVideo15SRModel():
|
||||
def __init__(self, model_type, config):
|
||||
self.load_device = model_management.vae_device()
|
||||
offload_device = model_management.vae_offload_device()
|
||||
self.dtype = model_management.vae_dtype(self.load_device)
|
||||
self.model_class = UPSAMPLERS.get(model_type)
|
||||
self.model = self.model_class(**config).eval()
|
||||
|
||||
self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=True)
|
||||
|
||||
def get_sd(self):
|
||||
return self.model.state_dict()
|
||||
|
||||
def resample_latent(self, latent):
|
||||
model_management.load_model_gpu(self.patcher)
|
||||
return self.model(latent.to(self.load_device))
|
||||
@ -1,11 +1,13 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize
|
||||
from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, torch_cat_if_needed
|
||||
import comfy.ops
|
||||
import comfy.ldm.models.autoencoder
|
||||
import comfy.model_management
|
||||
ops = comfy.ops.disable_weight_init
|
||||
|
||||
|
||||
class RMS_norm(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
@ -14,10 +16,10 @@ class RMS_norm(nn.Module):
|
||||
self.gamma = nn.Parameter(torch.empty(shape))
|
||||
|
||||
def forward(self, x):
|
||||
return F.normalize(x, dim=1) * self.scale * self.gamma
|
||||
return F.normalize(x, dim=1) * self.scale * comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device)
|
||||
|
||||
class DnSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d):
|
||||
def __init__(self, ic, oc, tds, refiner_vae, op):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tds else 1 * 2 * 2
|
||||
assert oc % fct == 0
|
||||
@ -27,11 +29,12 @@ class DnSmpl(nn.Module):
|
||||
self.tds = tds
|
||||
self.gs = fct * ic // oc
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
r1 = 2 if self.tds else 1
|
||||
h = self.conv(x)
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
|
||||
if self.tds and self.refiner_vae:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2)
|
||||
@ -39,14 +42,7 @@ class DnSmpl(nn.Module):
|
||||
hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2)
|
||||
hf = torch.cat([hf, hf], dim=1)
|
||||
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nf = frms // r1
|
||||
hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
h = h[:, :, 1:, :, :]
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
@ -54,38 +50,36 @@ class DnSmpl(nn.Module):
|
||||
xf = xf.permute(0, 4, 6, 1, 2, 3, 5)
|
||||
xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xf.shape
|
||||
xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2)
|
||||
|
||||
xn = x[:, :, 1:, :, :]
|
||||
b, ci, frms, ht, wd = xn.shape
|
||||
nf = frms // r1
|
||||
xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = xn.shape
|
||||
xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
if h.shape[2] == 0:
|
||||
return hf + xf
|
||||
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
sc = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = sc.shape
|
||||
sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nf = frms // r1
|
||||
h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
h = h.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2)
|
||||
|
||||
return h + sc
|
||||
b, ci, frms, ht, wd = x.shape
|
||||
nf = frms // r1
|
||||
x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2)
|
||||
x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)
|
||||
x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2)
|
||||
B, C, T, H, W = x.shape
|
||||
x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2)
|
||||
|
||||
if self.tds and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
|
||||
|
||||
class UpSmpl(nn.Module):
|
||||
def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d):
|
||||
def __init__(self, ic, oc, tus, refiner_vae, op):
|
||||
super().__init__()
|
||||
fct = 2 * 2 * 2 if tus else 1 * 2 * 2
|
||||
self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1)
|
||||
@ -94,11 +88,11 @@ class UpSmpl(nn.Module):
|
||||
self.tus = tus
|
||||
self.rp = fct * oc // ic
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
r1 = 2 if self.tus else 1
|
||||
h = self.conv(x)
|
||||
h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
|
||||
if self.tus and self.refiner_vae:
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
hf = h[:, :, :1, :, :]
|
||||
b, c, f, ht, wd = hf.shape
|
||||
nc = c // (2 * 2)
|
||||
@ -107,14 +101,7 @@ class UpSmpl(nn.Module):
|
||||
hf = hf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
hf = hf[:, : hf.shape[1] // 2]
|
||||
|
||||
hn = h[:, :, 1:, :, :]
|
||||
b, c, frms, ht, wd = hn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
h = torch.cat([hf, hn], dim=2)
|
||||
h = h[:, :, 1:, :, :]
|
||||
|
||||
xf = x[:, :, :1, :, :]
|
||||
b, ci, f, ht, wd = xf.shape
|
||||
@ -125,29 +112,26 @@ class UpSmpl(nn.Module):
|
||||
xf = xf.permute(0, 3, 4, 5, 1, 6, 2)
|
||||
xf = xf.reshape(b, nc, f, ht * 2, wd * 2)
|
||||
|
||||
xn = x[:, :, 1:, :, :]
|
||||
xn = xn.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = xn.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
sc = torch.cat([xf, xn], dim=2)
|
||||
else:
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
x = x[:, :, 1:, :, :]
|
||||
|
||||
sc = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = sc.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
b, c, frms, ht, wd = h.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
h = h.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
return h + sc
|
||||
x = x.repeat_interleave(repeats=self.rp, dim=1)
|
||||
b, c, frms, ht, wd = x.shape
|
||||
nc = c // (r1 * 2 * 2)
|
||||
x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd)
|
||||
x = x.permute(0, 4, 5, 1, 6, 2, 7, 3)
|
||||
x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2)
|
||||
|
||||
if self.tus and self.refiner_vae and conv_carry_in is None:
|
||||
h = torch.cat([hf, h], dim=2)
|
||||
x = torch.cat([xf, x], dim=2)
|
||||
|
||||
return h + x
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
|
||||
@ -160,7 +144,7 @@ class Encoder(nn.Module):
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = VideoConv3d
|
||||
conv_op = CarriedConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
@ -188,9 +172,9 @@ class Encoder(nn.Module):
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.norm_out = norm_op(ch)
|
||||
self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1)
|
||||
@ -201,31 +185,48 @@ class Encoder(nn.Module):
|
||||
if not self.refiner_vae and x.shape[2] == 1:
|
||||
x = x.expand(-1, -1, self.ffactor_temporal, -1, -1)
|
||||
|
||||
x = self.conv_in(x)
|
||||
if self.refiner_vae:
|
||||
xl = [x[:, :, :1, :, :]]
|
||||
if x.shape[2] > self.ffactor_temporal:
|
||||
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2)
|
||||
x = xl
|
||||
else:
|
||||
x = [x]
|
||||
out = []
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x = stage.downsample(x)
|
||||
conv_carry_in = None
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
|
||||
x1 = [ x1 ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for stage in self.down:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'downsample'):
|
||||
x1 = stage.downsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out)))
|
||||
del out
|
||||
|
||||
b, c, t, h, w = x.shape
|
||||
grp = c // (self.z_channels << 1)
|
||||
skip = x.view(b, c // grp, grp, t, h, w).mean(2)
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x))) + skip
|
||||
out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip
|
||||
|
||||
if self.refiner_vae:
|
||||
out = self.regul(out)[0]
|
||||
|
||||
out = torch.cat((out[:, :, :1], out), dim=2)
|
||||
out = out.permute(0, 2, 1, 3, 4)
|
||||
b, f_times_2, c, h, w = out.shape
|
||||
out = out.reshape(b, f_times_2 // 2, 2 * c, h, w)
|
||||
out = out.permute(0, 2, 1, 3, 4).contiguous()
|
||||
|
||||
return out
|
||||
|
||||
class Decoder(nn.Module):
|
||||
@ -239,7 +240,7 @@ class Decoder(nn.Module):
|
||||
|
||||
self.refiner_vae = refiner_vae
|
||||
if self.refiner_vae:
|
||||
conv_op = VideoConv3d
|
||||
conv_op = CarriedConv3d
|
||||
norm_op = RMS_norm
|
||||
else:
|
||||
conv_op = ops.Conv3d
|
||||
@ -249,9 +250,9 @@ class Decoder(nn.Module):
|
||||
self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
depth = (ffactor_spatial >> 1).bit_length()
|
||||
@ -275,27 +276,38 @@ class Decoder(nn.Module):
|
||||
self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
if self.refiner_vae:
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
b, f, c, h, w = z.shape
|
||||
z = z.reshape(b, f, 2, c // 2, h, w)
|
||||
z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w)
|
||||
z = z.permute(0, 2, 1, 3, 4)
|
||||
z = z[:, :, 1:]
|
||||
|
||||
x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1)
|
||||
x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x)))
|
||||
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x = blk(x)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x = stage.upsample(x)
|
||||
if self.refiner_vae:
|
||||
x = torch.split(x, 2, dim=2)
|
||||
else:
|
||||
x = [ x ]
|
||||
out = []
|
||||
|
||||
out = self.conv_out(F.silu(self.norm_out(x)))
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
for stage in self.up:
|
||||
for blk in stage.block:
|
||||
x1 = blk(x1, None, conv_carry_in, conv_carry_out)
|
||||
if hasattr(stage, 'upsample'):
|
||||
x1 = stage.upsample(x1, conv_carry_in, conv_carry_out)
|
||||
|
||||
x1 = [ F.silu(self.norm_out(x1)) ]
|
||||
x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||
out.append(x1)
|
||||
conv_carry_in = conv_carry_out
|
||||
del x
|
||||
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
if not self.refiner_vae:
|
||||
if z.shape[-3] == 1:
|
||||
out = out[:, :, -1:]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
413
comfy/ldm/kandinsky5/model.py
Normal file
413
comfy/ldm/kandinsky5/model.py
Normal file
@ -0,0 +1,413 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
import math
|
||||
|
||||
import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
|
||||
def attention(q, k, v, heads, transformer_options={}):
|
||||
return optimized_attention(
|
||||
q.transpose(1, 2),
|
||||
k.transpose(1, 2),
|
||||
v.transpose(1, 2),
|
||||
heads=heads,
|
||||
skip_reshape=True,
|
||||
transformer_options=transformer_options
|
||||
)
|
||||
|
||||
def apply_scale_shift_norm(norm, x, scale, shift):
|
||||
return torch.addcmul(shift, norm(x), scale + 1.0)
|
||||
|
||||
def apply_gate_sum(x, out, gate):
|
||||
return torch.addcmul(x, gate, out)
|
||||
|
||||
def get_shift_scale_gate(params):
|
||||
shift, scale, gate = torch.chunk(params, 3, dim=-1)
|
||||
return tuple(x.unsqueeze(1) for x in (shift, scale, gate))
|
||||
|
||||
def get_freqs(dim, max_period=10000.0):
|
||||
return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim)
|
||||
|
||||
|
||||
class TimeEmbeddings(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None):
|
||||
super().__init__()
|
||||
assert model_dim % 2 == 0
|
||||
self.model_dim = model_dim
|
||||
self.max_period = max_period
|
||||
self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False)
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, timestep, dtype):
|
||||
args = torch.outer(timestep, self.freqs.to(device=timestep.device))
|
||||
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype)
|
||||
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
|
||||
return time_embed
|
||||
|
||||
|
||||
class TextEmbeddings(nn.Module):
|
||||
def __init__(self, text_dim, model_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, text_embed):
|
||||
text_embed = self.in_layer(text_embed)
|
||||
return self.norm(text_embed).type_as(text_embed)
|
||||
|
||||
|
||||
class VisualEmbeddings(nn.Module):
|
||||
def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
x = x.movedim(1, -1) # B C T H W -> B T H W C
|
||||
B, T, H, W, dim = x.shape
|
||||
pt, ph, pw = self.patch_size
|
||||
|
||||
x = x.view(
|
||||
B,
|
||||
T // pt, pt,
|
||||
H // ph, ph,
|
||||
W // pw, pw,
|
||||
dim,
|
||||
).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7)
|
||||
|
||||
return self.in_layer(x)
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
def __init__(self, time_dim, model_dim, num_params, operation_settings=None):
|
||||
super().__init__()
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.activation(x))
|
||||
|
||||
|
||||
class SelfAttention(nn.Module):
|
||||
def __init__(self, num_channels, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
assert num_channels % head_dim == 0
|
||||
self.num_heads = num_channels // head_dim
|
||||
self.head_dim = head_dim
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 2
|
||||
|
||||
def _compute_qk(self, x, freqs, proj_fn, norm_fn):
|
||||
result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
return apply_rope1(norm_fn(result), freqs)
|
||||
|
||||
def _forward(self, x, freqs, transformer_options={}):
|
||||
q = self._compute_qk(x, freqs, self.to_query, self.query_norm)
|
||||
k = self._compute_qk(x, freqs, self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def _forward_chunked(self, x, freqs, transformer_options={}):
|
||||
def process_chunks(proj_fn, norm_fn):
|
||||
x_chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1)
|
||||
chunks = []
|
||||
for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks):
|
||||
chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn))
|
||||
return torch.cat(chunks, dim=1)
|
||||
|
||||
q = process_chunks(self.to_query, self.query_norm)
|
||||
k = process_chunks(self.to_key, self.key_norm)
|
||||
v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
out = attention(q, k, v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
def forward(self, x, freqs, transformer_options={}):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x, freqs, transformer_options=transformer_options)
|
||||
else:
|
||||
return self._forward(x, freqs, transformer_options=transformer_options)
|
||||
|
||||
|
||||
class CrossAttention(SelfAttention):
|
||||
def get_qkv(self, x, context):
|
||||
q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1)
|
||||
k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1)
|
||||
return q, k, v
|
||||
|
||||
def forward(self, x, context, transformer_options={}):
|
||||
q, k, v = self.get_qkv(x, context)
|
||||
out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options)
|
||||
return self.out_layer(out)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, ff_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
operations = operation_settings.get("operations")
|
||||
self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.activation = nn.GELU()
|
||||
self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.num_chunks = 4
|
||||
|
||||
def _forward(self, x):
|
||||
return self.out_layer(self.activation(self.in_layer(x)))
|
||||
|
||||
def _forward_chunked(self, x):
|
||||
chunks = torch.chunk(x, self.num_chunks, dim=1)
|
||||
output_chunks = []
|
||||
for chunk in chunks:
|
||||
output_chunks.append(self._forward(chunk))
|
||||
return torch.cat(output_chunks, dim=1)
|
||||
|
||||
def forward(self, x):
|
||||
if x.shape[1] > 8192:
|
||||
return self._forward_chunked(x)
|
||||
else:
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
class OutLayer(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, visual_embed, time_embed):
|
||||
B, T, H, W, _ = visual_embed.shape
|
||||
shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1)
|
||||
scale = scale[:, None, None, None, :]
|
||||
shift = shift[:, None, None, None, :]
|
||||
visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift)
|
||||
x = self.out_layer(visual_embed)
|
||||
|
||||
out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2])
|
||||
x = x.view(
|
||||
B, T, H, W,
|
||||
out_dim,
|
||||
self.patch_size[0], self.patch_size[1], self.patch_size[2]
|
||||
)
|
||||
return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 3).flatten(3, 4).flatten(4, 5)
|
||||
|
||||
|
||||
class TransformerEncoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings)
|
||||
operations = operation_settings.get("operations")
|
||||
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, x, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1)
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift)
|
||||
out = self.self_attention(out, freqs, transformer_options=transformer_options)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift)
|
||||
out = self.feed_forward(out)
|
||||
x = apply_gate_sum(x, out, gate)
|
||||
return x
|
||||
|
||||
|
||||
class TransformerDecoderBlock(nn.Module):
|
||||
def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None):
|
||||
super().__init__()
|
||||
self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings)
|
||||
|
||||
operations = operation_settings.get("operations")
|
||||
self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings)
|
||||
|
||||
self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings)
|
||||
|
||||
def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}):
|
||||
self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1)
|
||||
# self attention
|
||||
shift, scale, gate = get_shift_scale_gate(self_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# cross attention
|
||||
shift, scale, gate = get_shift_scale_gate(cross_attn_params)
|
||||
visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift)
|
||||
visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
# feed forward
|
||||
shift, scale, gate = get_shift_scale_gate(ff_params)
|
||||
visual_out = apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift)
|
||||
visual_out = self.feed_forward(visual_out)
|
||||
visual_embed = apply_gate_sum(visual_embed, visual_out, gate)
|
||||
return visual_embed
|
||||
|
||||
|
||||
class Kandinsky5(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512,
|
||||
model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32,
|
||||
axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0),
|
||||
dtype=None, device=None, operations=None, **kwargs
|
||||
):
|
||||
super().__init__()
|
||||
head_dim = sum(axes_dims)
|
||||
self.rope_scale_factor = rope_scale_factor
|
||||
self.in_visual_dim = in_visual_dim
|
||||
self.model_dim = model_dim
|
||||
self.patch_size = patch_size
|
||||
self.visual_embed_dim = visual_embed_dim
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings)
|
||||
self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings)
|
||||
self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings)
|
||||
self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.text_transformer_blocks = nn.ModuleList(
|
||||
[TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)]
|
||||
)
|
||||
|
||||
self.visual_transformer_blocks = nn.ModuleList(
|
||||
[TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)]
|
||||
)
|
||||
|
||||
self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings)
|
||||
|
||||
self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims)
|
||||
self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim])
|
||||
|
||||
def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}):
|
||||
steps = seq_len if steps is None else steps
|
||||
seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype)
|
||||
seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1)
|
||||
freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
|
||||
|
||||
if steps_t is None:
|
||||
steps_t = t_len
|
||||
if steps_h is None:
|
||||
steps_h = h_len
|
||||
if steps_w is None:
|
||||
steps_w = w_len
|
||||
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
if rope_options is not None:
|
||||
t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
|
||||
h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
|
||||
w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
|
||||
|
||||
t_start += rope_options.get("shift_t", 0.0)
|
||||
h_start += rope_options.get("shift_y", 0.0)
|
||||
w_start += rope_options.get("shift_x", 0.0)
|
||||
else:
|
||||
rope_scale_factor = self.rope_scale_factor
|
||||
if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions
|
||||
if h * w >= 14080:
|
||||
rope_scale_factor = (1.0, 3.16, 3.16)
|
||||
|
||||
t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0
|
||||
h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0
|
||||
w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0
|
||||
|
||||
img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
|
||||
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
|
||||
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
|
||||
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder_3d(img_ids).movedim(1, 2)
|
||||
return freqs
|
||||
|
||||
def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs):
|
||||
patches_replace = transformer_options.get("patches_replace", {})
|
||||
context = self.text_embeddings(context)
|
||||
time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y)
|
||||
|
||||
for block in self.text_transformer_blocks:
|
||||
context = block(context, time_embed, freqs_text, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = self.visual_embeddings(x)
|
||||
visual_shape = visual_embed.shape[:-1]
|
||||
visual_embed = visual_embed.flatten(1, -2)
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.visual_transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.visual_transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options"))
|
||||
visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"]
|
||||
else:
|
||||
visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options)
|
||||
|
||||
visual_embed = visual_embed.reshape(*visual_shape, -1)
|
||||
return self.out_layer(visual_embed, time_embed)
|
||||
|
||||
def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
original_dims = x.ndim
|
||||
if original_dims == 4:
|
||||
x = x.unsqueeze(2)
|
||||
bs, c, t_len, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if time_dim_replace is not None:
|
||||
time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size)
|
||||
x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace
|
||||
|
||||
freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
|
||||
out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
|
||||
if original_dims == 4:
|
||||
out = out.squeeze(2)
|
||||
return out
|
||||
|
||||
def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
|
||||
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
|
||||
self._forward,
|
||||
self,
|
||||
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
|
||||
).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs)
|
||||
160
comfy/ldm/lumina/controlnet.py
Normal file
160
comfy/ldm/lumina/controlnet.py
Normal file
@ -0,0 +1,160 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .model import JointTransformerBlock
|
||||
|
||||
class ZImageControlTransformerBlock(JointTransformerBlock):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
multiple_of: int,
|
||||
ffn_dim_multiplier: float,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
block_id=0,
|
||||
operation_settings=None,
|
||||
):
|
||||
super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings)
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
|
||||
def forward(self, c, x, **kwargs):
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
c = super().forward(c, **kwargs)
|
||||
c_skip = self.after_proj(c)
|
||||
return c_skip, c
|
||||
|
||||
class ZImage_Control(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 3840,
|
||||
n_heads: int = 30,
|
||||
n_kv_heads: int = 30,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: float = (8.0 / 3.0),
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = True,
|
||||
n_control_layers=6,
|
||||
control_in_dim=16,
|
||||
additional_in_dim=0,
|
||||
broken=False,
|
||||
refiner_control=False,
|
||||
dtype=None,
|
||||
device=None,
|
||||
operations=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
|
||||
|
||||
self.broken = broken
|
||||
self.additional_in_dim = additional_in_dim
|
||||
self.control_in_dim = control_in_dim
|
||||
n_refiner_layers = 2
|
||||
self.n_control_layers = n_control_layers
|
||||
self.control_layers = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
i,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
block_id=i,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for i in range(self.n_control_layers)
|
||||
]
|
||||
)
|
||||
|
||||
all_x_embedder = {}
|
||||
patch_size = 2
|
||||
f_patch_size = 1
|
||||
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
self.refiner_control = refiner_control
|
||||
|
||||
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
if self.refiner_control:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
block_id=layer_id,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
multiple_of,
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=True,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
|
||||
patch_size = 2
|
||||
f_patch_size = 1
|
||||
pH = pW = patch_size
|
||||
B, C, H, W = control_context.shape
|
||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
x_attn_mask = None
|
||||
if not self.refiner_control:
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
|
||||
|
||||
return control_context
|
||||
|
||||
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
if self.refiner_control:
|
||||
if self.broken:
|
||||
if layer_id == 0:
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if layer_id > 0:
|
||||
out = None
|
||||
for i in range(1, len(self.control_layers)):
|
||||
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
if out is None:
|
||||
out = o
|
||||
|
||||
return (out, control_context)
|
||||
else:
|
||||
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
else:
|
||||
return (None, control_context)
|
||||
|
||||
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
|
||||
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
|
||||
@ -11,6 +11,7 @@ import comfy.ldm.common_dit
|
||||
from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope
|
||||
import comfy.patcher_extension
|
||||
|
||||
|
||||
@ -21,6 +22,10 @@ def modulate(x, scale):
|
||||
# Core NextDiT Model #
|
||||
#############################################################################
|
||||
|
||||
def clamp_fp16(x):
|
||||
if x.dtype == torch.float16:
|
||||
return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
|
||||
return x
|
||||
|
||||
class JointAttention(nn.Module):
|
||||
"""Multi-head attention module."""
|
||||
@ -31,6 +36,7 @@ class JointAttention(nn.Module):
|
||||
n_heads: int,
|
||||
n_kv_heads: Optional[int],
|
||||
qk_norm: bool,
|
||||
out_bias: bool = False,
|
||||
operation_settings={},
|
||||
):
|
||||
"""
|
||||
@ -59,7 +65,7 @@ class JointAttention(nn.Module):
|
||||
self.out = operation_settings.get("operations").Linear(
|
||||
n_heads * self.head_dim,
|
||||
dim,
|
||||
bias=False,
|
||||
bias=out_bias,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
@ -70,35 +76,6 @@ class JointAttention(nn.Module):
|
||||
else:
|
||||
self.q_norm = self.k_norm = nn.Identity()
|
||||
|
||||
@staticmethod
|
||||
def apply_rotary_emb(
|
||||
x_in: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply rotary embeddings to input tensors using the given frequency
|
||||
tensor.
|
||||
|
||||
This function applies rotary embeddings to the given query 'xq' and
|
||||
key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The
|
||||
input tensors are reshaped as complex numbers, and the frequency tensor
|
||||
is reshaped for broadcasting compatibility. The resulting tensors
|
||||
contain rotary embeddings and are returned as real tensors.
|
||||
|
||||
Args:
|
||||
x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
|
||||
freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
|
||||
exponentials.
|
||||
|
||||
Returns:
|
||||
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor
|
||||
and key tensor with rotary embeddings.
|
||||
"""
|
||||
|
||||
t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x_in.shape)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
@ -134,8 +111,7 @@ class JointAttention(nn.Module):
|
||||
xq = self.q_norm(xq)
|
||||
xk = self.k_norm(xk)
|
||||
|
||||
xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
|
||||
xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
|
||||
xq, xk = apply_rope(xq, xk, freqs_cis)
|
||||
|
||||
n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
if n_rep >= 1:
|
||||
@ -197,7 +173,7 @@ class FeedForward(nn.Module):
|
||||
|
||||
# @torch.compile
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
return clamp_fp16(F.silu(x1) * x3)
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
@ -215,6 +191,8 @@ class JointTransformerBlock(nn.Module):
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
z_image_modulation=False,
|
||||
attn_out_bias=False,
|
||||
operation_settings={},
|
||||
) -> None:
|
||||
"""
|
||||
@ -235,10 +213,10 @@ class JointTransformerBlock(nn.Module):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
|
||||
self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings)
|
||||
self.feed_forward = FeedForward(
|
||||
dim=dim,
|
||||
hidden_dim=4 * dim,
|
||||
hidden_dim=dim,
|
||||
multiple_of=multiple_of,
|
||||
ffn_dim_multiplier=ffn_dim_multiplier,
|
||||
operation_settings=operation_settings,
|
||||
@ -252,16 +230,27 @@ class JointTransformerBlock(nn.Module):
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
if z_image_modulation:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 256),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024),
|
||||
4 * dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -288,27 +277,27 @@ class JointTransformerBlock(nn.Module):
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
||||
|
||||
x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
|
||||
self.attention(
|
||||
clamp_fp16(self.attention(
|
||||
modulate(self.attention_norm1(x), scale_msa),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
))
|
||||
)
|
||||
x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
clamp_fp16(self.feed_forward(
|
||||
modulate(self.ffn_norm1(x), scale_mlp),
|
||||
)
|
||||
))
|
||||
)
|
||||
else:
|
||||
assert adaln_input is None
|
||||
x = x + self.attention_norm2(
|
||||
self.attention(
|
||||
clamp_fp16(self.attention(
|
||||
self.attention_norm1(x),
|
||||
x_mask,
|
||||
freqs_cis,
|
||||
transformer_options=transformer_options,
|
||||
)
|
||||
))
|
||||
)
|
||||
x = x + self.ffn_norm2(
|
||||
self.feed_forward(
|
||||
@ -323,7 +312,7 @@ class FinalLayer(nn.Module):
|
||||
The final layer of NextDiT.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
|
||||
def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings={}):
|
||||
super().__init__()
|
||||
self.norm_final = operation_settings.get("operations").LayerNorm(
|
||||
hidden_size,
|
||||
@ -340,10 +329,15 @@ class FinalLayer(nn.Module):
|
||||
dtype=operation_settings.get("dtype"),
|
||||
)
|
||||
|
||||
if z_image_modulation:
|
||||
min_mod = 256
|
||||
else:
|
||||
min_mod = 1024
|
||||
|
||||
self.adaLN_modulation = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(hidden_size, 1024),
|
||||
min(hidden_size, min_mod),
|
||||
hidden_size,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
@ -373,12 +367,17 @@ class NextDiT(nn.Module):
|
||||
n_heads: int = 32,
|
||||
n_kv_heads: Optional[int] = None,
|
||||
multiple_of: int = 256,
|
||||
ffn_dim_multiplier: Optional[float] = None,
|
||||
ffn_dim_multiplier: float = 4.0,
|
||||
norm_eps: float = 1e-5,
|
||||
qk_norm: bool = False,
|
||||
cap_feat_dim: int = 5120,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (1, 512, 512),
|
||||
rope_theta=10000.0,
|
||||
z_image_modulation=False,
|
||||
time_scale=1.0,
|
||||
pad_tokens_multiple=None,
|
||||
clip_text_dim=None,
|
||||
image_model=None,
|
||||
device=None,
|
||||
dtype=None,
|
||||
@ -390,6 +389,8 @@ class NextDiT(nn.Module):
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.time_scale = time_scale
|
||||
self.pad_tokens_multiple = pad_tokens_multiple
|
||||
|
||||
self.x_embedder = operation_settings.get("operations").Linear(
|
||||
in_features=patch_size * patch_size * in_channels,
|
||||
@ -411,6 +412,7 @@ class NextDiT(nn.Module):
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
z_image_modulation=z_image_modulation,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
@ -434,7 +436,7 @@ class NextDiT(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
|
||||
self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings)
|
||||
self.cap_embedder = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
@ -446,6 +448,31 @@ class NextDiT(nn.Module):
|
||||
),
|
||||
)
|
||||
|
||||
self.clip_text_pooled_proj = None
|
||||
|
||||
if clip_text_dim is not None:
|
||||
self.clip_text_dim = clip_text_dim
|
||||
self.clip_text_pooled_proj = nn.Sequential(
|
||||
operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
|
||||
operation_settings.get("operations").Linear(
|
||||
clip_text_dim,
|
||||
clip_text_dim,
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
self.time_text_embed = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
operation_settings.get("operations").Linear(
|
||||
min(dim, 1024) + clip_text_dim,
|
||||
min(dim, 1024),
|
||||
bias=True,
|
||||
device=operation_settings.get("device"),
|
||||
dtype=operation_settings.get("dtype"),
|
||||
),
|
||||
)
|
||||
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
JointTransformerBlock(
|
||||
@ -457,18 +484,24 @@ class NextDiT(nn.Module):
|
||||
ffn_dim_multiplier,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
z_image_modulation=z_image_modulation,
|
||||
attn_out_bias=False,
|
||||
operation_settings=operation_settings,
|
||||
)
|
||||
for layer_id in range(n_layers)
|
||||
]
|
||||
)
|
||||
self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
|
||||
self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings)
|
||||
|
||||
if self.pad_tokens_multiple is not None:
|
||||
self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype))
|
||||
|
||||
assert (dim // n_heads) == sum(axes_dims)
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
|
||||
self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims)
|
||||
self.dim = dim
|
||||
self.n_heads = n_heads
|
||||
|
||||
@ -503,108 +536,63 @@ class NextDiT(nn.Module):
|
||||
bsz = len(x)
|
||||
pH = pW = self.patch_size
|
||||
device = x[0].device
|
||||
dtype = x[0].dtype
|
||||
orig_x = x
|
||||
|
||||
if cap_mask is not None:
|
||||
l_effective_cap_len = cap_mask.sum(dim=1).tolist()
|
||||
else:
|
||||
l_effective_cap_len = [num_tokens] * bsz
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
|
||||
cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1)
|
||||
|
||||
if cap_mask is not None and not torch.is_floating_point(cap_mask):
|
||||
cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
|
||||
cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device)
|
||||
cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0
|
||||
|
||||
img_sizes = [(img.size(1), img.size(2)) for img in x]
|
||||
l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
|
||||
B, C, H, W = x.shape
|
||||
x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
|
||||
|
||||
max_seq_len = max(
|
||||
(cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
|
||||
)
|
||||
max_cap_len = max(l_effective_cap_len)
|
||||
max_img_len = max(l_effective_img_len)
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
|
||||
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
H, W = img_sizes[i]
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
assert H_tokens * W_tokens == img_len
|
||||
H_tokens, W_tokens = H // pH, W // pW
|
||||
x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device)
|
||||
x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1
|
||||
x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
|
||||
rope_options = transformer_options.get("rope_options", None)
|
||||
h_scale = 1.0
|
||||
w_scale = 1.0
|
||||
h_start = 0
|
||||
w_start = 0
|
||||
if rope_options is not None:
|
||||
h_scale = rope_options.get("scale_y", 1.0)
|
||||
w_scale = rope_options.get("scale_x", 1.0)
|
||||
if self.pad_tokens_multiple is not None:
|
||||
pad_extra = (-x.shape[1]) % self.pad_tokens_multiple
|
||||
x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1)
|
||||
x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra))
|
||||
|
||||
h_start = rope_options.get("shift_y", 0.0)
|
||||
w_start = rope_options.get("shift_x", 0.0)
|
||||
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
|
||||
|
||||
position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
|
||||
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
|
||||
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
|
||||
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
|
||||
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
|
||||
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
|
||||
|
||||
freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
|
||||
|
||||
# build freqs_cis for cap and image individually
|
||||
cap_freqs_cis_shape = list(freqs_cis.shape)
|
||||
# cap_freqs_cis_shape[1] = max_cap_len
|
||||
cap_freqs_cis_shape[1] = cap_feats.shape[1]
|
||||
cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
img_freqs_cis_shape = list(freqs_cis.shape)
|
||||
img_freqs_cis_shape[1] = max_img_len
|
||||
img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
|
||||
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
|
||||
img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
|
||||
patches = transformer_options.get("patches", {})
|
||||
|
||||
# refine context
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options)
|
||||
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
|
||||
|
||||
# refine image
|
||||
flat_x = []
|
||||
for i in range(bsz):
|
||||
img = x[i]
|
||||
C, H, W = img.size()
|
||||
img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
|
||||
flat_x.append(img)
|
||||
x = flat_x
|
||||
padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
|
||||
padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
|
||||
for i in range(bsz):
|
||||
padded_img_embed[i, :l_effective_img_len[i]] = x[i]
|
||||
padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
|
||||
|
||||
padded_img_embed = self.x_embedder(padded_img_embed)
|
||||
padded_img_mask = padded_img_mask.unsqueeze(1)
|
||||
for layer in self.noise_refiner:
|
||||
padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options)
|
||||
|
||||
if cap_mask is not None:
|
||||
mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
|
||||
mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
|
||||
else:
|
||||
mask = None
|
||||
|
||||
padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
|
||||
for i in range(bsz):
|
||||
cap_len = l_effective_cap_len[i]
|
||||
img_len = l_effective_img_len[i]
|
||||
|
||||
padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
|
||||
padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
|
||||
padded_img_mask = None
|
||||
x_input = x
|
||||
for i, layer in enumerate(self.noise_refiner):
|
||||
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
|
||||
if "noise_refiner" in patches:
|
||||
for p in patches["noise_refiner"]:
|
||||
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
|
||||
if "img" in out:
|
||||
x = out["img"]
|
||||
|
||||
padded_full_embed = torch.cat((cap_feats, x), dim=1)
|
||||
mask = None
|
||||
img_sizes = [(H, W)] * bsz
|
||||
l_effective_cap_len = [cap_feats.shape[1]] * bsz
|
||||
return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
|
||||
|
||||
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
@ -615,7 +603,7 @@ class NextDiT(nn.Module):
|
||||
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
|
||||
|
||||
# def forward(self, x, t, cap_feats, cap_mask):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
|
||||
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options={}, **kwargs):
|
||||
t = 1.0 - timesteps
|
||||
cap_feats = context
|
||||
cap_mask = attention_mask
|
||||
@ -627,21 +615,38 @@ class NextDiT(nn.Module):
|
||||
y: (N,) tensor of text tokens/features
|
||||
"""
|
||||
|
||||
t = self.t_embedder(t, dtype=x.dtype) # (N, D)
|
||||
t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D)
|
||||
adaln_input = t
|
||||
|
||||
cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
|
||||
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
if self.clip_text_pooled_proj is not None:
|
||||
pooled = kwargs.get("clip_text_pooled", None)
|
||||
if pooled is not None:
|
||||
pooled = self.clip_text_pooled_proj(pooled)
|
||||
else:
|
||||
pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
|
||||
|
||||
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
|
||||
|
||||
patches = transformer_options.get("patches", {})
|
||||
x_is_tensor = isinstance(x, torch.Tensor)
|
||||
x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(x.device)
|
||||
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
|
||||
freqs_cis = freqs_cis.to(img.device)
|
||||
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
img_input = img
|
||||
for i, layer in enumerate(self.layers):
|
||||
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
|
||||
if "double_block" in patches:
|
||||
for p in patches["double_block"]:
|
||||
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
|
||||
if "img" in out:
|
||||
img[:, cap_size[0]:] = out["img"]
|
||||
if "txt" in out:
|
||||
img[:, :cap_size[0]] = out["txt"]
|
||||
|
||||
x = self.final_layer(x, adaln_input)
|
||||
x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
|
||||
img = self.final_layer(img, adaln_input)
|
||||
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
|
||||
|
||||
return -x
|
||||
return -img
|
||||
|
||||
|
||||
@ -9,6 +9,8 @@ from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistri
|
||||
from comfy.ldm.util import get_obj_from_str, instantiate_from_config
|
||||
from comfy.ldm.modules.ema import LitEma
|
||||
import comfy.ops
|
||||
from einops import rearrange
|
||||
import comfy.model_management
|
||||
|
||||
class DiagonalGaussianRegularizer(torch.nn.Module):
|
||||
def __init__(self, sample: bool = False):
|
||||
@ -179,6 +181,21 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
if ddconfig.get("batch_norm_latent", False):
|
||||
self.bn_eps = 1e-4
|
||||
self.bn_momentum = 0.1
|
||||
self.ps = [2, 2]
|
||||
self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"],
|
||||
eps=self.bn_eps,
|
||||
momentum=self.bn_momentum,
|
||||
affine=False,
|
||||
track_running_stats=True,
|
||||
)
|
||||
self.bn.eval()
|
||||
else:
|
||||
self.bn = None
|
||||
|
||||
|
||||
def get_autoencoder_params(self) -> list:
|
||||
params = super().get_autoencoder_params()
|
||||
return params
|
||||
@ -201,11 +218,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine):
|
||||
z = torch.cat(z, 0)
|
||||
|
||||
z, reg_log = self.regularization(z)
|
||||
|
||||
if self.bn is not None:
|
||||
z = rearrange(z,
|
||||
"... c (i pi) (j pj) -> ... (c pi pj) i j",
|
||||
pi=self.ps[0],
|
||||
pj=self.ps[1],
|
||||
)
|
||||
|
||||
z = torch.nn.functional.batch_norm(z,
|
||||
comfy.model_management.cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device),
|
||||
comfy.model_management.cast_to(self.bn.running_var, dtype=z.dtype, device=z.device),
|
||||
momentum=self.bn_momentum,
|
||||
eps=self.bn_eps)
|
||||
|
||||
if return_reg_log:
|
||||
return z, reg_log
|
||||
return z
|
||||
|
||||
def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor:
|
||||
if self.bn is not None:
|
||||
s = torch.sqrt(comfy.model_management.cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps)
|
||||
m = comfy.model_management.cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device)
|
||||
z = z * s + m
|
||||
z = rearrange(
|
||||
z,
|
||||
"... (c pi pj) i j -> ... c (i pi) (j pj)",
|
||||
pi=self.ps[0],
|
||||
pj=self.ps[1],
|
||||
)
|
||||
|
||||
if self.max_batch_size is None:
|
||||
dec = self.post_quant_conv(z)
|
||||
dec = self.decoder(dec, **decoder_kwargs)
|
||||
|
||||
@ -517,6 +517,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
|
||||
|
||||
@wrap_attn
|
||||
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
|
||||
exception_fallback = False
|
||||
if skip_reshape:
|
||||
b, _, _, dim_head = q.shape
|
||||
tensor_layout = "HND"
|
||||
@ -541,6 +542,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
|
||||
out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
|
||||
except Exception as e:
|
||||
logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
|
||||
exception_fallback = True
|
||||
if exception_fallback:
|
||||
if tensor_layout == "NHD":
|
||||
q, k, v = map(
|
||||
lambda t: t.transpose(1, 2),
|
||||
|
||||
@ -211,12 +211,14 @@ class TimestepEmbedder(nn.Module):
|
||||
Embeds scalar timesteps into vector representations.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
|
||||
def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
if output_size is None:
|
||||
output_size = hidden_size
|
||||
self.mlp = nn.Sequential(
|
||||
operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
nn.SiLU(),
|
||||
operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device),
|
||||
operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device),
|
||||
)
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
|
||||
@ -13,6 +13,12 @@ if model_management.xformers_enabled_vae():
|
||||
import xformers
|
||||
import xformers.ops
|
||||
|
||||
def torch_cat_if_needed(xl, dim):
|
||||
if len(xl) > 1:
|
||||
return torch.cat(xl, dim)
|
||||
else:
|
||||
return xl[0]
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
@ -43,6 +49,37 @@ def Normalize(in_channels, num_groups=32):
|
||||
return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class CarriedConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs):
|
||||
super().__init__()
|
||||
self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None):
|
||||
|
||||
x = xl[0]
|
||||
xl.clear()
|
||||
|
||||
if isinstance(op, CarriedConv3d):
|
||||
if conv_carry_in is None:
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2, 0), mode = 'replicate')
|
||||
else:
|
||||
carry_len = conv_carry_in[0].shape[2]
|
||||
x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode = 'replicate')
|
||||
x = torch.cat([conv_carry_in.pop(0), x], dim=2)
|
||||
|
||||
if conv_carry_out is not None:
|
||||
to_push = x[:, :, -2:, :, :].clone()
|
||||
conv_carry_out.append(to_push)
|
||||
|
||||
out = op(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class VideoConv3d(nn.Module):
|
||||
def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs):
|
||||
super().__init__()
|
||||
@ -89,29 +126,24 @@ class Upsample(nn.Module):
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
scale_factor = self.scale_factor
|
||||
if isinstance(scale_factor, (int, float)):
|
||||
scale_factor = (scale_factor,) * (x.ndim - 2)
|
||||
|
||||
if x.ndim == 5 and scale_factor[0] > 1.0:
|
||||
t = x.shape[2]
|
||||
if t > 1:
|
||||
a, b = x.split((1, t - 1), dim=2)
|
||||
del x
|
||||
b = interpolate_up(b, scale_factor)
|
||||
else:
|
||||
a = x
|
||||
|
||||
a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)
|
||||
if t > 1:
|
||||
x = torch.cat((a, b), dim=2)
|
||||
else:
|
||||
x = a
|
||||
results = []
|
||||
if conv_carry_in is None:
|
||||
first = x[:, :, :1, :, :]
|
||||
results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2))
|
||||
x = x[:, :, 1:, :, :]
|
||||
if x.shape[2] > 0:
|
||||
results.append(interpolate_up(x, scale_factor))
|
||||
x = torch_cat_if_needed(results, dim=2)
|
||||
else:
|
||||
x = interpolate_up(x, scale_factor)
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
return x
|
||||
|
||||
|
||||
@ -127,17 +159,20 @@ class Downsample(nn.Module):
|
||||
stride=stride,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, conv_carry_in=None, conv_carry_out=None):
|
||||
if self.with_conv:
|
||||
if x.ndim == 4:
|
||||
if isinstance(self.conv, CarriedConv3d):
|
||||
x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out)
|
||||
elif x.ndim == 4:
|
||||
pad = (0, 1, 0, 1)
|
||||
mode = "constant"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode, value=0)
|
||||
x = self.conv(x)
|
||||
elif x.ndim == 5:
|
||||
pad = (1, 1, 1, 1, 2, 0)
|
||||
mode = "replicate"
|
||||
x = torch.nn.functional.pad(x, pad, mode=mode)
|
||||
x = self.conv(x)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
@ -183,23 +218,23 @@ class ResnetBlock(nn.Module):
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb=None):
|
||||
def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = self.swish(h)
|
||||
h = self.conv1(h)
|
||||
h = [ self.swish(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.swish(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.swish(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
h = [ self.dropout(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
@ -279,6 +314,7 @@ def pytorch_attention(q, k, v):
|
||||
orig_shape = q.shape
|
||||
B = orig_shape[0]
|
||||
C = orig_shape[1]
|
||||
oom_fallback = False
|
||||
q, k, v = map(
|
||||
lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(),
|
||||
(q, k, v),
|
||||
@ -289,6 +325,8 @@ def pytorch_attention(q, k, v):
|
||||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
oom_fallback = True
|
||||
if oom_fallback:
|
||||
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape)
|
||||
return out
|
||||
|
||||
@ -517,9 +555,14 @@ class Encoder(nn.Module):
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.carried = False
|
||||
|
||||
if conv3d:
|
||||
conv_op = VideoConv3d
|
||||
if not attn_resolutions:
|
||||
conv_op = CarriedConv3d
|
||||
self.carried = True
|
||||
else:
|
||||
conv_op = VideoConv3d
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
@ -532,6 +575,7 @@ class Encoder(nn.Module):
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
self.time_compress = 1
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
@ -558,10 +602,15 @@ class Encoder(nn.Module):
|
||||
if time_compress is not None:
|
||||
if (self.num_resolutions - 1 - i_level) > math.log2(time_compress):
|
||||
stride = (1, 2, 2)
|
||||
else:
|
||||
self.time_compress *= 2
|
||||
down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
if time_compress is not None:
|
||||
self.time_compress = time_compress
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
@ -587,15 +636,42 @@ class Encoder(nn.Module):
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
# downsampling
|
||||
h = self.conv_in(x)
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](h, temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
h = self.down[i_level].downsample(h)
|
||||
|
||||
if self.carried:
|
||||
xl = [x[:, :, :1, :, :]]
|
||||
if x.shape[2] > self.time_compress:
|
||||
tc = self.time_compress
|
||||
xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim = 2)
|
||||
x = xl
|
||||
else:
|
||||
x = [x]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
for i, x1 in enumerate(x):
|
||||
conv_carry_out = []
|
||||
if i == len(x) - 1:
|
||||
conv_carry_out = None
|
||||
|
||||
# downsampling
|
||||
x1 = [ x1 ]
|
||||
h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out)
|
||||
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
assert i == 0 #carried should not happen if attn exists
|
||||
h1 = self.down[i_level].attn[i_block](h1)
|
||||
if i_level != self.num_resolutions-1:
|
||||
h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out)
|
||||
|
||||
out.append(h1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
h = torch_cat_if_needed(out, dim=2)
|
||||
del out
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
@ -604,15 +680,15 @@ class Encoder(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
h = [ nonlinearity(h) ]
|
||||
h = conv_carry_causal_3d(h, self.conv_out)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||
resolution, z_channels, tanh_out=False, use_linear_attn=False,
|
||||
conv_out_op=ops.Conv2d,
|
||||
resnet_op=ResnetBlock,
|
||||
attn_op=AttnBlock,
|
||||
@ -626,12 +702,18 @@ class Decoder(nn.Module):
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
self.carried = False
|
||||
|
||||
if conv3d:
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
if not attn_resolutions and resnet_op == ResnetBlock:
|
||||
conv_op = CarriedConv3d
|
||||
conv_out_op = CarriedConv3d
|
||||
self.carried = True
|
||||
else:
|
||||
conv_op = VideoConv3d
|
||||
conv_out_op = VideoConv3d
|
||||
|
||||
mid_attn_conv_op = ops.Conv3d
|
||||
else:
|
||||
conv_op = ops.Conv2d
|
||||
@ -706,29 +788,43 @@ class Decoder(nn.Module):
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
h = conv_carry_causal_3d([z], self.conv_in)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb, **kwargs)
|
||||
h = self.mid.attn_1(h, **kwargs)
|
||||
h = self.mid.block_2(h, temb, **kwargs)
|
||||
|
||||
if self.carried:
|
||||
h = torch.split(h, 2, dim=2)
|
||||
else:
|
||||
h = [ h ]
|
||||
out = []
|
||||
|
||||
conv_carry_in = None
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h, **kwargs)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
for i, h1 in enumerate(h):
|
||||
conv_carry_out = []
|
||||
if i == len(h) - 1:
|
||||
conv_carry_out = None
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
assert i == 0 #carried should not happen if attn exists
|
||||
h1 = self.up[i_level].attn[i_block](h1, **kwargs)
|
||||
if i_level != 0:
|
||||
h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
h1 = self.norm_out(h1)
|
||||
h1 = [ nonlinearity(h1) ]
|
||||
h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out)
|
||||
if self.tanh_out:
|
||||
h1 = torch.tanh(h1)
|
||||
out.append(h1)
|
||||
conv_carry_in = conv_carry_out
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h, **kwargs)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
out = torch_cat_if_needed(out, dim=2)
|
||||
|
||||
return out
|
||||
|
||||
@ -439,7 +439,10 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
transformer_options["block_index"] = i
|
||||
if ("double_block", i) in blocks_replace:
|
||||
def block_wrap(args):
|
||||
out = {}
|
||||
|
||||
@ -313,6 +313,23 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k #SimpleTuner lycoris format
|
||||
|
||||
if isinstance(model, comfy.model_base.Lumina2):
|
||||
diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
|
||||
for k in diffusers_keys:
|
||||
if k.endswith(".weight"):
|
||||
to = diffusers_keys[k]
|
||||
key_lora = k[:-len(".weight")]
|
||||
key_map["diffusion_model.{}".format(key_lora)] = to
|
||||
key_map["transformer.{}".format(key_lora)] = to
|
||||
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to
|
||||
|
||||
if isinstance(model, comfy.model_base.Kandinsky5):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@ -47,6 +47,7 @@ import comfy.ldm.chroma_radiance.model
|
||||
import comfy.ldm.ace.model
|
||||
import comfy.ldm.omnigen.omnigen2
|
||||
import comfy.ldm.qwen_image.model
|
||||
import comfy.ldm.kandinsky5.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -134,7 +135,7 @@ class BaseModel(torch.nn.Module):
|
||||
if not unet_config.get("disable_unet_model_creation", False):
|
||||
if model_config.custom_operations is None:
|
||||
fp8 = model_config.optimizations.get("fp8", False)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
|
||||
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
|
||||
else:
|
||||
operations = model_config.custom_operations
|
||||
self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
|
||||
@ -329,18 +330,6 @@ class BaseModel(torch.nn.Module):
|
||||
extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
|
||||
|
||||
unet_state_dict = self.diffusion_model.state_dict()
|
||||
|
||||
if self.model_config.scaled_fp8 is not None:
|
||||
unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
|
||||
|
||||
# Save mixed precision metadata
|
||||
if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
|
||||
metadata = {
|
||||
"format_version": "1.0",
|
||||
"layers": self.model_config.layer_quant_config
|
||||
}
|
||||
unet_state_dict["_quantization_metadata"] = metadata
|
||||
|
||||
unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
|
||||
|
||||
if self.model_type == ModelType.V_PREDICTION:
|
||||
@ -898,12 +887,13 @@ class Flux(BaseModel):
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
shape = kwargs["noise"].shape
|
||||
mask_ref_size = kwargs["attention_mask_img_shape"]
|
||||
# the model will pad to the patch size, and then divide
|
||||
# essentially dividing and rounding up
|
||||
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
|
||||
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
mask_ref_size = kwargs.get("attention_mask_img_shape", None)
|
||||
if mask_ref_size is not None:
|
||||
# the model will pad to the patch size, and then divide
|
||||
# essentially dividing and rounding up
|
||||
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
|
||||
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
|
||||
guidance = kwargs.get("guidance", 3.5)
|
||||
if guidance is not None:
|
||||
@ -925,9 +915,19 @@ class Flux(BaseModel):
|
||||
out = {}
|
||||
ref_latents = kwargs.get("reference_latents", None)
|
||||
if ref_latents is not None:
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16])
|
||||
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
|
||||
return out
|
||||
|
||||
class Flux2(Flux):
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
target_text_len = 512
|
||||
if cross_attn.shape[1] < target_text_len:
|
||||
cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0))
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
return out
|
||||
|
||||
class GenmoMochi(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
@ -1103,9 +1103,17 @@ class Lumina2(BaseModel):
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
|
||||
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
if 'num_tokens' not in out:
|
||||
out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])
|
||||
|
||||
clip_text_pooled = kwargs["pooled_output"] # Newbie
|
||||
if clip_text_pooled is not None:
|
||||
out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
|
||||
|
||||
return out
|
||||
|
||||
class WAN21(BaseModel):
|
||||
@ -1536,3 +1544,140 @@ class HunyuanImage21Refiner(HunyuanImage21):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(True)
|
||||
return out
|
||||
|
||||
class HunyuanVideo15(HunyuanVideo):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 #noise 32 img cond 32 + mask 1
|
||||
if extra_channels == 0:
|
||||
return None
|
||||
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
shape_image = list(noise.shape)
|
||||
shape_image[1] = extra_channels
|
||||
image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
|
||||
else:
|
||||
latent_dim = self.latent_format.latent_channels
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
for i in range(0, image.shape[1], latent_dim):
|
||||
image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim])
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :1]
|
||||
else:
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
if torch.numel(attention_mask) != attention_mask.sum():
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
conditioning_byt5small = kwargs.get("conditioning_byt5small", None)
|
||||
if conditioning_byt5small is not None:
|
||||
out['txt_byt5'] = comfy.conds.CONDRegular(conditioning_byt5small)
|
||||
|
||||
guidance = kwargs.get("guidance", 6.0)
|
||||
if guidance is not None:
|
||||
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
|
||||
|
||||
clip_vision_output = kwargs.get("clip_vision_output", None)
|
||||
if clip_vision_output is not None:
|
||||
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.last_hidden_state)
|
||||
|
||||
return out
|
||||
|
||||
class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
image = kwargs.get("concat_latent_image", None)
|
||||
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
|
||||
device = kwargs["device"]
|
||||
|
||||
if image is None:
|
||||
image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device())
|
||||
else:
|
||||
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
#image = self.process_latent_in(image) # scaling wasn't applied in reference code
|
||||
image = utils.resize_to_batch_size(image, noise.shape[0])
|
||||
lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1)
|
||||
if noise_augmentation > 0:
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(kwargs.get("seed", 0) - 10)
|
||||
noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
|
||||
image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice]
|
||||
else:
|
||||
image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice]
|
||||
return image
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
out['disable_time_r'] = comfy.conds.CONDConstant(False)
|
||||
return out
|
||||
|
||||
class Kandinsky5(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.kandinsky5.model.Kandinsky5)
|
||||
|
||||
def encode_adm(self, **kwargs):
|
||||
return kwargs["pooled_output"]
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
noise = kwargs.get("noise", None)
|
||||
device = kwargs["device"]
|
||||
image = torch.zeros_like(noise)
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
mask = torch.zeros_like(noise)[:, :1]
|
||||
else:
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
|
||||
return torch.cat((image, mask), dim=1)
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
if attention_mask is not None:
|
||||
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
|
||||
cross_attn = kwargs.get("cross_attn", None)
|
||||
if cross_attn is not None:
|
||||
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
|
||||
|
||||
time_dim_replace = kwargs.get("time_dim_replace", None)
|
||||
if time_dim_replace is not None:
|
||||
out['time_dim_replace'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_replace))
|
||||
|
||||
return out
|
||||
|
||||
class Kandinsky5Image(Kandinsky5):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
return None
|
||||
|
||||
@ -6,20 +6,6 @@ import math
|
||||
import logging
|
||||
import torch
|
||||
|
||||
|
||||
def detect_layer_quantization(metadata):
|
||||
quant_key = "_quantization_metadata"
|
||||
if metadata is not None and quant_key in metadata:
|
||||
quant_metadata = metadata.pop(quant_key)
|
||||
quant_metadata = json.loads(quant_metadata)
|
||||
if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
|
||||
logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
|
||||
return quant_metadata["layers"]
|
||||
else:
|
||||
raise ValueError("Invalid quantization metadata format")
|
||||
return None
|
||||
|
||||
|
||||
def count_blocks(state_dict_keys, prefix_string):
|
||||
count = 0
|
||||
while True:
|
||||
@ -186,30 +172,73 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
|
||||
guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys))
|
||||
dit_config["guidance_embed"] = len(guidance_keys) > 0
|
||||
|
||||
# HunyuanVideo 1.5
|
||||
if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["use_cond_type_embedding"] = True
|
||||
else:
|
||||
dit_config["use_cond_type_embedding"] = False
|
||||
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["meanflow_sum"] = True
|
||||
else:
|
||||
dit_config["vision_in_dim"] = None
|
||||
dit_config["meanflow_sum"] = False
|
||||
return dit_config
|
||||
|
||||
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "flux"
|
||||
if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["image_model"] = "flux2"
|
||||
dit_config["axes_dim"] = [32, 32, 32, 32]
|
||||
dit_config["num_heads"] = 48
|
||||
dit_config["mlp_ratio"] = 3.0
|
||||
dit_config["theta"] = 2000
|
||||
dit_config["out_channels"] = 128
|
||||
dit_config["global_modulation"] = True
|
||||
dit_config["mlp_silu_act"] = True
|
||||
dit_config["qkv_bias"] = False
|
||||
dit_config["ops_bias"] = False
|
||||
dit_config["default_ref_method"] = "index"
|
||||
dit_config["ref_index_scale"] = 10.0
|
||||
dit_config["txt_ids_dims"] = [3]
|
||||
patch_size = 1
|
||||
else:
|
||||
dit_config["image_model"] = "flux"
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["num_heads"] = 24
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["out_channels"] = 16
|
||||
dit_config["qkv_bias"] = True
|
||||
dit_config["txt_ids_dims"] = []
|
||||
patch_size = 2
|
||||
|
||||
dit_config["in_channels"] = 16
|
||||
patch_size = 2
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["context_in_dim"] = 4096
|
||||
|
||||
dit_config["patch_size"] = patch_size
|
||||
in_key = "{}img_in.weight".format(key_prefix)
|
||||
if in_key in state_dict_keys:
|
||||
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
|
||||
dit_config["out_channels"] = 16
|
||||
w = state_dict[in_key]
|
||||
dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size)
|
||||
dit_config["hidden_size"] = w.shape[0]
|
||||
|
||||
txt_in_key = "{}txt_in.weight".format(key_prefix)
|
||||
if txt_in_key in state_dict_keys:
|
||||
w = state_dict[txt_in_key]
|
||||
dit_config["context_in_dim"] = w.shape[1]
|
||||
dit_config["hidden_size"] = w.shape[0]
|
||||
|
||||
vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
|
||||
if vec_in_key in state_dict_keys:
|
||||
dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
|
||||
dit_config["context_in_dim"] = 4096
|
||||
dit_config["hidden_size"] = 3072
|
||||
dit_config["mlp_ratio"] = 4.0
|
||||
dit_config["num_heads"] = 24
|
||||
else:
|
||||
dit_config["vec_in_dim"] = None
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["axes_dim"] = [16, 56, 56]
|
||||
dit_config["theta"] = 10000
|
||||
dit_config["qkv_bias"] = True
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
dit_config["image_model"] = "chroma"
|
||||
dit_config["in_channels"] = 64
|
||||
@ -230,8 +259,17 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["nerf_tile_size"] = 512
|
||||
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
|
||||
dit_config["nerf_embedder_dtype"] = torch.float32
|
||||
if "__x0__" in state_dict_keys: # x0 pred
|
||||
dit_config["use_x0"] = True
|
||||
else:
|
||||
dit_config["use_x0"] = False
|
||||
else:
|
||||
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
|
||||
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys
|
||||
if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model
|
||||
dit_config["txt_ids_dims"] = [1, 2]
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
|
||||
@ -378,14 +416,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["image_model"] = "lumina2"
|
||||
dit_config["patch_size"] = 2
|
||||
dit_config["in_channels"] = 16
|
||||
dit_config["dim"] = 2304
|
||||
dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1]
|
||||
w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)]
|
||||
dit_config["dim"] = w.shape[0]
|
||||
dit_config["cap_feat_dim"] = w.shape[1]
|
||||
dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.')
|
||||
dit_config["n_heads"] = 24
|
||||
dit_config["n_kv_heads"] = 8
|
||||
dit_config["qk_norm"] = True
|
||||
dit_config["axes_dims"] = [32, 32, 32]
|
||||
dit_config["axes_lens"] = [300, 512, 512]
|
||||
|
||||
if dit_config["dim"] == 2304: # Original Lumina 2
|
||||
dit_config["n_heads"] = 24
|
||||
dit_config["n_kv_heads"] = 8
|
||||
dit_config["axes_dims"] = [32, 32, 32]
|
||||
dit_config["axes_lens"] = [300, 512, 512]
|
||||
dit_config["rope_theta"] = 10000.0
|
||||
dit_config["ffn_dim_multiplier"] = 4.0
|
||||
ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
|
||||
if ctd_weight is not None:
|
||||
dit_config["clip_text_dim"] = ctd_weight.shape[0]
|
||||
elif dit_config["dim"] == 3840: # Z image
|
||||
dit_config["n_heads"] = 30
|
||||
dit_config["n_kv_heads"] = 30
|
||||
dit_config["axes_dims"] = [32, 48, 48]
|
||||
dit_config["axes_lens"] = [1536, 512, 512]
|
||||
dit_config["rope_theta"] = 256.0
|
||||
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
|
||||
dit_config["z_image_modulation"] = True
|
||||
dit_config["time_scale"] = 1000.0
|
||||
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["pad_tokens_multiple"] = 32
|
||||
|
||||
return dit_config
|
||||
|
||||
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
|
||||
@ -562,6 +620,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5
|
||||
dit_config = {}
|
||||
model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
dit_config["model_dim"] = model_dim
|
||||
if model_dim in [4096, 2560]: # pro video and lite image
|
||||
dit_config["axes_dims"] = (32, 48, 48)
|
||||
if model_dim == 2560: # lite image
|
||||
dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0)
|
||||
elif model_dim == 1792: # lite video
|
||||
dit_config["axes_dims"] = (16, 24, 24)
|
||||
dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0]
|
||||
dit_config["image_model"] = "kandinsky5"
|
||||
dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["visual_embed_dim"] = state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.')
|
||||
return dit_config
|
||||
|
||||
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
|
||||
return None
|
||||
|
||||
@ -704,22 +780,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
|
||||
if model_config is None and use_base_if_no_match:
|
||||
model_config = comfy.supported_models_base.BASE(unet_config)
|
||||
|
||||
scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
|
||||
model_config.scaled_fp8 = scaled_fp8_weight.dtype
|
||||
if model_config.scaled_fp8 == torch.float32:
|
||||
model_config.scaled_fp8 = torch.float8_e4m3fn
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
model_config.optimizations["fp8"] = False
|
||||
else:
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
# Detect per-layer quantization (mixed precision)
|
||||
layer_quant_config = detect_layer_quantization(metadata)
|
||||
if layer_quant_config:
|
||||
model_config.layer_quant_config = layer_quant_config
|
||||
logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
|
||||
quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
|
||||
if quant_config:
|
||||
model_config.quant_config = quant_config
|
||||
logging.info("Detected mixed precision quantization")
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
loaded_memory = loaded_model.model_loaded_memory()
|
||||
current_free_mem = get_free_memory(torch_dev) + loaded_memory
|
||||
|
||||
lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
|
||||
lowvram_model_memory = lowvram_model_memory - loaded_memory
|
||||
|
||||
if lowvram_model_memory == 0:
|
||||
@ -1012,9 +1012,18 @@ def force_channels_last():
|
||||
|
||||
|
||||
STREAMS = {}
|
||||
NUM_STREAMS = 1
|
||||
if args.async_offload:
|
||||
NUM_STREAMS = 2
|
||||
NUM_STREAMS = 0
|
||||
if args.async_offload is not None:
|
||||
NUM_STREAMS = args.async_offload
|
||||
else:
|
||||
# Enable by default on Nvidia
|
||||
if is_nvidia():
|
||||
NUM_STREAMS = 2
|
||||
|
||||
if args.disable_async_offload:
|
||||
NUM_STREAMS = 0
|
||||
|
||||
if NUM_STREAMS > 0:
|
||||
logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
|
||||
|
||||
def current_stream(device):
|
||||
@ -1030,7 +1039,10 @@ def current_stream(device):
|
||||
stream_counters = {}
|
||||
def get_offload_stream(device):
|
||||
stream_counter = stream_counters.get(device, 0)
|
||||
if NUM_STREAMS <= 1:
|
||||
if NUM_STREAMS == 0:
|
||||
return None
|
||||
|
||||
if torch.compiler.is_compiling():
|
||||
return None
|
||||
|
||||
if device in STREAMS:
|
||||
@ -1043,7 +1055,9 @@ def get_offload_stream(device):
|
||||
elif is_device_cuda(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.cuda.Stream(device=device, priority=0))
|
||||
s1 = torch.cuda.Stream(device=device, priority=0)
|
||||
s1.as_context = torch.cuda.stream
|
||||
ss.append(s1)
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counters[device] = stream_counter
|
||||
@ -1051,7 +1065,9 @@ def get_offload_stream(device):
|
||||
elif is_device_xpu(device):
|
||||
ss = []
|
||||
for k in range(NUM_STREAMS):
|
||||
ss.append(torch.xpu.Stream(device=device, priority=0))
|
||||
s1 = torch.xpu.Stream(device=device, priority=0)
|
||||
s1.as_context = torch.xpu.stream
|
||||
ss.append(s1)
|
||||
STREAMS[device] = ss
|
||||
s = ss[stream_counter]
|
||||
stream_counters[device] = stream_counter
|
||||
@ -1069,12 +1085,19 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
if dtype is None or weight.dtype == dtype:
|
||||
return weight
|
||||
if stream is not None:
|
||||
with stream:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
return weight.to(dtype=dtype, copy=copy)
|
||||
|
||||
|
||||
if stream is not None:
|
||||
with stream:
|
||||
wf_context = stream
|
||||
if hasattr(wf_context, "as_context"):
|
||||
wf_context = wf_context.as_context(stream)
|
||||
with wf_context:
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
else:
|
||||
@ -1098,13 +1121,14 @@ if not args.disable_pinned_memory:
|
||||
MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95
|
||||
logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024)))
|
||||
|
||||
PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"])
|
||||
|
||||
def pin_memory(tensor):
|
||||
global TOTAL_PINNED_MEMORY
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
|
||||
if type(tensor) is not torch.nn.parameter.Parameter:
|
||||
if type(tensor).__name__ not in PINNING_ALLOWED_TYPES:
|
||||
return False
|
||||
|
||||
if not is_device_cpu(tensor.device):
|
||||
@ -1124,6 +1148,9 @@ def pin_memory(tensor):
|
||||
return False
|
||||
|
||||
ptr = tensor.data_ptr()
|
||||
if ptr == 0:
|
||||
return False
|
||||
|
||||
if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0:
|
||||
PINNED_MEMORY[ptr] = size
|
||||
TOTAL_PINNED_MEMORY += size
|
||||
@ -1465,6 +1492,20 @@ def extended_fp16_support():
|
||||
|
||||
return True
|
||||
|
||||
LORA_COMPUTE_DTYPES = {}
|
||||
def lora_compute_dtype(device):
|
||||
dtype = LORA_COMPUTE_DTYPES.get(device, None)
|
||||
if dtype is not None:
|
||||
return dtype
|
||||
|
||||
if should_use_fp16(device):
|
||||
dtype = torch.float16
|
||||
else:
|
||||
dtype = torch.float32
|
||||
|
||||
LORA_COMPUTE_DTYPES[device] = dtype
|
||||
return dtype
|
||||
|
||||
def soft_empty_cache(force=False):
|
||||
global cpu_state
|
||||
if cpu_state == CPUState.MPS:
|
||||
|
||||
@ -35,6 +35,7 @@ import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
import comfy.utils
|
||||
from comfy.comfy_types import UnetWrapperFunction
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP
|
||||
|
||||
|
||||
@ -126,27 +127,23 @@ class LowVramPatch:
|
||||
def __init__(self, key, patches, convert_func=None, set_func=None):
|
||||
self.key = key
|
||||
self.patches = patches
|
||||
self.convert_func = convert_func
|
||||
self.convert_func = convert_func # TODO: remove
|
||||
self.set_func = set_func
|
||||
|
||||
def __call__(self, weight):
|
||||
intermediate_dtype = weight.dtype
|
||||
if self.convert_func is not None:
|
||||
weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True)
|
||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
||||
|
||||
if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
|
||||
intermediate_dtype = torch.float32
|
||||
out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
|
||||
if self.set_func is None:
|
||||
return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
|
||||
else:
|
||||
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
|
||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
|
||||
|
||||
out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
|
||||
if self.set_func is not None:
|
||||
return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
|
||||
else:
|
||||
return out
|
||||
def low_vram_patch_estimate_vram(model, key):
|
||||
weight, set_func, convert_func = get_key_weight(model, key)
|
||||
if weight is None:
|
||||
return 0
|
||||
model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
|
||||
if model_dtype is None:
|
||||
model_dtype = weight.dtype
|
||||
|
||||
return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
|
||||
|
||||
def get_key_weight(model, key):
|
||||
set_func = None
|
||||
@ -231,7 +228,6 @@ class ModelPatcher:
|
||||
self.object_patches_backup = {}
|
||||
self.weight_wrapper_patches = {}
|
||||
self.model_options = {"transformer_options":{}}
|
||||
self.model_size()
|
||||
self.load_device = load_device
|
||||
self.offload_device = offload_device
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
@ -270,6 +266,9 @@ class ModelPatcher:
|
||||
if not hasattr(self.model, 'current_weight_patches_uuid'):
|
||||
self.model.current_weight_patches_uuid = None
|
||||
|
||||
if not hasattr(self.model, 'model_offload_buffer_memory'):
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
return self.size
|
||||
@ -286,7 +285,7 @@ class ModelPatcher:
|
||||
return self.model.lowvram_patch_counter
|
||||
|
||||
def clone(self):
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update)
|
||||
n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update)
|
||||
n.patches = {}
|
||||
for k in self.patches:
|
||||
n.patches[k] = self.patches[k][:]
|
||||
@ -455,6 +454,9 @@ class ModelPatcher:
|
||||
def set_model_post_input_patch(self, patch):
|
||||
self.set_model_patch(patch, "post_input")
|
||||
|
||||
def set_model_noise_refiner_patch(self, patch):
|
||||
self.set_model_patch(patch, "noise_refiner")
|
||||
|
||||
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
|
||||
rope_options = self.model_options["transformer_options"].get("rope_options", {})
|
||||
rope_options["scale_x"] = scale_x
|
||||
@ -619,10 +621,11 @@ class ModelPatcher:
|
||||
if key not in self.backup:
|
||||
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
|
||||
|
||||
temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
|
||||
if device_to is not None:
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
|
||||
temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
|
||||
else:
|
||||
temp_weight = weight.to(torch.float32, copy=True)
|
||||
temp_weight = weight.to(temp_dtype, copy=True)
|
||||
if convert_func is not None:
|
||||
temp_weight = convert_func(temp_weight, inplace=True)
|
||||
|
||||
@ -663,7 +666,22 @@ class ModelPatcher:
|
||||
skip = True # skip random weights in non leaf modules
|
||||
break
|
||||
if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
loading.append((comfy.model_management.module_size(m), n, m, params))
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
def check_module_offload_mem(key):
|
||||
if key in self.patches:
|
||||
return low_vram_patch_estimate_vram(self.model, key)
|
||||
model_dtype = getattr(self.model, "manual_cast_dtype", None)
|
||||
weight, _, _ = get_key_weight(self.model, key)
|
||||
if model_dtype is None or weight is None:
|
||||
return 0
|
||||
if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
|
||||
return weight.numel() * model_dtype.itemsize
|
||||
return 0
|
||||
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
|
||||
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
|
||||
loading.append((module_offload_mem, module_mem, n, m, params))
|
||||
return loading
|
||||
|
||||
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
|
||||
@ -677,20 +695,22 @@ class ModelPatcher:
|
||||
|
||||
load_completely = []
|
||||
offloaded = []
|
||||
offload_buffer = 0
|
||||
loading.sort(reverse=True)
|
||||
for x in loading:
|
||||
n = x[1]
|
||||
m = x[2]
|
||||
params = x[3]
|
||||
module_mem = x[0]
|
||||
for i, x in enumerate(loading):
|
||||
module_offload_mem, module_mem, n, m, params = x
|
||||
|
||||
lowvram_weight = False
|
||||
|
||||
potential_offload = max(offload_buffer, module_offload_mem + sum([ x1[1] for x1 in loading[i+1:i+1+comfy.model_management.NUM_STREAMS]]))
|
||||
lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
|
||||
|
||||
weight_key = "{}.weight".format(n)
|
||||
bias_key = "{}.bias".format(n)
|
||||
|
||||
if not full_load and hasattr(m, "comfy_cast_weights"):
|
||||
if mem_counter + module_mem >= lowvram_model_memory:
|
||||
if not lowvram_fits:
|
||||
offload_buffer = potential_offload
|
||||
lowvram_weight = True
|
||||
lowvram_counter += 1
|
||||
lowvram_mem_counter += module_mem
|
||||
@ -724,9 +744,11 @@ class ModelPatcher:
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
wipe_lowvram_weight(m)
|
||||
|
||||
if full_load or mem_counter + module_mem < lowvram_model_memory:
|
||||
if full_load or lowvram_fits:
|
||||
mem_counter += module_mem
|
||||
load_completely.append((module_mem, n, m, params))
|
||||
else:
|
||||
offload_buffer = potential_offload
|
||||
|
||||
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
@ -753,6 +775,8 @@ class ModelPatcher:
|
||||
key = "{}.{}".format(n, param)
|
||||
self.unpin_weight(key)
|
||||
self.patch_weight_to_device(key, device_to=device_to)
|
||||
if comfy.model_management.is_device_cuda(device_to):
|
||||
torch.cuda.synchronize()
|
||||
|
||||
logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
|
||||
m.comfy_patched_weights = True
|
||||
@ -767,7 +791,7 @@ class ModelPatcher:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
if lowvram_counter > 0:
|
||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
|
||||
logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
|
||||
self.model.model_lowvram = True
|
||||
else:
|
||||
logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
|
||||
@ -779,6 +803,7 @@ class ModelPatcher:
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = mem_counter
|
||||
self.model.model_offload_buffer_memory = offload_buffer
|
||||
self.model.current_weight_patches_uuid = self.patches_uuid
|
||||
|
||||
for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
|
||||
@ -832,6 +857,7 @@ class ModelPatcher:
|
||||
self.model.to(device_to)
|
||||
self.model.device = device_to
|
||||
self.model.model_loaded_weight_memory = 0
|
||||
self.model.model_offload_buffer_memory = 0
|
||||
|
||||
for m in self.model.modules():
|
||||
if hasattr(m, "comfy_patched_weights"):
|
||||
@ -850,13 +876,18 @@ class ModelPatcher:
|
||||
patch_counter = 0
|
||||
unload_list = self._load_list()
|
||||
unload_list.sort()
|
||||
|
||||
offload_buffer = self.model.model_offload_buffer_memory
|
||||
if len(unload_list) > 0:
|
||||
NS = comfy.model_management.NUM_STREAMS
|
||||
offload_weight_factor = [ min(offload_buffer / (NS + 1), unload_list[0][1]) ] * NS
|
||||
|
||||
for unload in unload_list:
|
||||
if memory_to_free < memory_freed:
|
||||
if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
|
||||
break
|
||||
module_mem = unload[0]
|
||||
n = unload[1]
|
||||
m = unload[2]
|
||||
params = unload[3]
|
||||
module_offload_mem, module_mem, n, m, params = unload
|
||||
|
||||
potential_offload = module_offload_mem + sum(offload_weight_factor)
|
||||
|
||||
lowvram_possible = hasattr(m, "comfy_cast_weights")
|
||||
if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
|
||||
@ -902,20 +933,25 @@ class ModelPatcher:
|
||||
patch_counter += 1
|
||||
cast_weight = True
|
||||
|
||||
if cast_weight:
|
||||
if cast_weight and hasattr(m, "comfy_cast_weights"):
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
m.comfy_patched_weights = False
|
||||
memory_freed += module_mem
|
||||
offload_buffer = max(offload_buffer, potential_offload)
|
||||
offload_weight_factor.append(module_mem)
|
||||
offload_weight_factor.pop(0)
|
||||
logging.debug("freed {}".format(n))
|
||||
|
||||
for param in params:
|
||||
self.pin_weight_to_device("{}.{}".format(n, param))
|
||||
|
||||
|
||||
self.model.model_lowvram = True
|
||||
self.model.lowvram_patch_counter += patch_counter
|
||||
self.model.model_loaded_weight_memory -= memory_freed
|
||||
logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||
self.model.model_offload_buffer_memory = offload_buffer
|
||||
logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
|
||||
return memory_freed
|
||||
|
||||
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
|
||||
|
||||
373
comfy/ops.py
373
comfy/ops.py
@ -22,7 +22,7 @@ import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy.float
|
||||
import comfy.rmsnorm
|
||||
import contextlib
|
||||
import json
|
||||
|
||||
def run_every_op():
|
||||
if torch.compiler.is_compiling():
|
||||
@ -58,7 +58,8 @@ except (ModuleNotFoundError, TypeError):
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
|
||||
try:
|
||||
if comfy.model_management.is_nvidia():
|
||||
if torch.backends.cudnn.version() >= 91002 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
|
||||
cudnn_version = torch.backends.cudnn.version()
|
||||
if (cudnn_version >= 91002 and cudnn_version < 91500) and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
|
||||
#TODO: change upper bound version once it's fixed'
|
||||
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
|
||||
logging.info("working around nvidia conv3d memory bug.")
|
||||
@ -92,11 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
else:
|
||||
offload_stream = None
|
||||
|
||||
if offload_stream is not None:
|
||||
wf_context = offload_stream
|
||||
else:
|
||||
wf_context = contextlib.nullcontext()
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
weight_has_function = len(s.weight_function) > 0
|
||||
@ -108,20 +104,24 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
if s.bias is not None:
|
||||
bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream)
|
||||
|
||||
if bias_has_function:
|
||||
with wf_context:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
|
||||
bias_a = bias
|
||||
weight_a = weight
|
||||
|
||||
if s.bias is not None:
|
||||
for f in s.bias_function:
|
||||
bias = f(bias)
|
||||
|
||||
if weight_has_function or weight.dtype != dtype:
|
||||
with wf_context:
|
||||
weight = weight.to(dtype=dtype)
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
weight = weight.to(dtype=dtype)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
for f in s.weight_function:
|
||||
weight = f(weight)
|
||||
|
||||
comfy.model_management.sync_stream(device, offload_stream)
|
||||
if offloadable:
|
||||
return weight, bias, offload_stream
|
||||
return weight, bias, (offload_stream, weight_a, bias_a)
|
||||
else:
|
||||
#Legacy function signature
|
||||
return weight, bias
|
||||
@ -130,13 +130,16 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||
if offload_stream is None:
|
||||
return
|
||||
if weight is not None:
|
||||
device = weight.device
|
||||
os, weight_a, bias_a = offload_stream
|
||||
if os is None:
|
||||
return
|
||||
if weight_a is not None:
|
||||
device = weight_a.device
|
||||
else:
|
||||
if bias is None:
|
||||
if bias_a is None:
|
||||
return
|
||||
device = bias.device
|
||||
offload_stream.wait_stream(comfy.model_management.current_stream(device))
|
||||
device = bias_a.device
|
||||
os.wait_stream(comfy.model_management.current_stream(device))
|
||||
|
||||
|
||||
class CastWeightBiasOp:
|
||||
@ -412,22 +415,12 @@ def fp8_linear(self, input):
|
||||
|
||||
if input.ndim == 3 or input.ndim == 2:
|
||||
w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True)
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
|
||||
scale_weight = self.scale_weight
|
||||
scale_input = self.scale_input
|
||||
if scale_weight is None:
|
||||
scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
else:
|
||||
scale_weight = scale_weight.to(input.device)
|
||||
|
||||
if scale_input is None:
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
else:
|
||||
scale_input = scale_input.to(input.device)
|
||||
quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype)
|
||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||
layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype}
|
||||
quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight)
|
||||
|
||||
# Wrap weight in QuantizedTensor - this enables unified dispatch
|
||||
# Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py!
|
||||
@ -448,7 +441,7 @@ class fp8_ops(manual_cast):
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if not self.training:
|
||||
if len(self.weight_function) == 0 and len(self.bias_function) == 0:
|
||||
try:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
@ -461,59 +454,6 @@ class fp8_ops(manual_cast):
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
|
||||
logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
|
||||
class scaled_fp8_op(manual_cast):
|
||||
class Linear(manual_cast.Linear):
|
||||
def __init__(self, *args, **kwargs):
|
||||
if override_dtype is not None:
|
||||
kwargs['dtype'] = override_dtype
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def reset_parameters(self):
|
||||
if not hasattr(self, 'scale_weight'):
|
||||
self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
|
||||
if not scale_input:
|
||||
self.scale_input = None
|
||||
|
||||
if not hasattr(self, 'scale_input'):
|
||||
self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
|
||||
return None
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
if fp8_matrix_mult:
|
||||
out = fp8_linear(self, input)
|
||||
if out is not None:
|
||||
return out
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
|
||||
if weight.numel() < input.numel(): #TODO: optimize
|
||||
x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
|
||||
else:
|
||||
x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if inplace:
|
||||
weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
return weight
|
||||
else:
|
||||
return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
|
||||
if return_weight:
|
||||
return weight
|
||||
if inplace_update:
|
||||
self.weight.data.copy_(weight)
|
||||
else:
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
return scaled_fp8_op
|
||||
|
||||
CUBLAS_IS_AVAILABLE = False
|
||||
try:
|
||||
from cublas_ops import CublasLinear
|
||||
@ -539,117 +479,180 @@ if CUBLAS_IS_AVAILABLE:
|
||||
# ==============================================================================
|
||||
from .quant_ops import QuantizedTensor, QUANT_ALGOS
|
||||
|
||||
class MixedPrecisionOps(disable_weight_init):
|
||||
_layer_quant_config = {}
|
||||
_compute_dtype = torch.bfloat16
|
||||
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_quant_config = quant_config
|
||||
_compute_dtype = compute_dtype
|
||||
_full_precision_mm = full_precision_mm
|
||||
|
||||
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
|
||||
# self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
class Linear(torch.nn.Module, CastWeightBiasOp):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
if bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
if dtype is None:
|
||||
dtype = MixedPrecisionOps._compute_dtype
|
||||
|
||||
self.tensor_class = None
|
||||
self.factory_kwargs = {"device": device, "dtype": dtype}
|
||||
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
self._has_bias = bias
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
self.tensor_class = None
|
||||
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
|
||||
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
def reset_parameters(self):
|
||||
return None
|
||||
|
||||
manually_loaded_keys = [weight_key]
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if layer_name not in MixedPrecisionOps._layer_quant_config:
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
|
||||
else:
|
||||
quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None)
|
||||
if quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
device = self.factory_kwargs["device"]
|
||||
layer_name = prefix.rstrip('.')
|
||||
weight_key = f"{prefix}weight"
|
||||
weight = state_dict.pop(weight_key, None)
|
||||
if weight is None:
|
||||
raise ValueError(f"Missing weight for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
manually_loaded_keys = [weight_key]
|
||||
|
||||
weight_scale_key = f"{prefix}weight_scale"
|
||||
layout_params = {
|
||||
'scale': state_dict.pop(weight_scale_key, None),
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||
'block_size': qconfig.get("group_size", None),
|
||||
}
|
||||
if layout_params['scale'] is not None:
|
||||
manually_loaded_keys.append(weight_scale_key)
|
||||
layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
|
||||
if layer_conf is not None:
|
||||
layer_conf = json.loads(layer_conf.numpy().tobytes())
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device), self.layout_type, layout_params),
|
||||
requires_grad=False
|
||||
)
|
||||
if layer_conf is None:
|
||||
dtype = self.factory_kwargs["dtype"]
|
||||
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
|
||||
if dtype != MixedPrecisionOps._compute_dtype:
|
||||
self.comfy_cast_weights = True
|
||||
if self._has_bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
else:
|
||||
self.quant_format = layer_conf.get("format", None)
|
||||
if not self._full_precision_mm:
|
||||
self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
if self.quant_format is None:
|
||||
raise ValueError(f"Unknown quantization format for layer {layer_name}")
|
||||
|
||||
qconfig = QUANT_ALGOS[self.quant_format]
|
||||
self.layout_type = qconfig["comfy_tensor_layout"]
|
||||
|
||||
weight_scale_key = f"{prefix}weight_scale"
|
||||
scale = state_dict.pop(weight_scale_key, None)
|
||||
if scale is not None:
|
||||
scale = scale.to(device)
|
||||
layout_params = {
|
||||
'scale': scale,
|
||||
'orig_dtype': MixedPrecisionOps._compute_dtype,
|
||||
'block_size': qconfig.get("group_size", None),
|
||||
}
|
||||
|
||||
if scale is not None:
|
||||
manually_loaded_keys.append(weight_scale_key)
|
||||
|
||||
self.weight = torch.nn.Parameter(
|
||||
QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params),
|
||||
requires_grad=False
|
||||
)
|
||||
|
||||
if self._has_bias:
|
||||
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
|
||||
else:
|
||||
self.register_parameter("bias", None)
|
||||
|
||||
for param_name in qconfig["parameters"]:
|
||||
param_key = f"{prefix}{param_name}"
|
||||
_v = state_dict.pop(param_key, None)
|
||||
if _v is None:
|
||||
continue
|
||||
self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
|
||||
def state_dict(self, *args, destination=None, prefix="", **kwargs):
|
||||
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
|
||||
if isinstance(self.weight, QuantizedTensor):
|
||||
sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale']
|
||||
quant_conf = {"format": self.quant_format}
|
||||
if self._full_precision_mm:
|
||||
quant_conf["full_precision_matrix_mult"] = True
|
||||
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
|
||||
return sd
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
run_every_op()
|
||||
|
||||
if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
|
||||
def convert_weight(self, weight, inplace=False, **kwargs):
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
return weight.dequantize()
|
||||
else:
|
||||
return weight
|
||||
|
||||
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
|
||||
if getattr(self, 'layout_type', None) is not None:
|
||||
weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)
|
||||
else:
|
||||
weight = weight.to(self.weight.dtype)
|
||||
if return_weight:
|
||||
return weight
|
||||
|
||||
assert inplace_update is False # TODO: eventually remove the inplace_update stuff
|
||||
self.weight = torch.nn.Parameter(weight, requires_grad=False)
|
||||
|
||||
def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working
|
||||
if recurse:
|
||||
for module in self.children():
|
||||
module._apply(fn)
|
||||
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
|
||||
manually_loaded_keys.append(param_key)
|
||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
return self
|
||||
|
||||
super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
return MixedPrecisionOps
|
||||
|
||||
for key in manually_loaded_keys:
|
||||
if key in missing_keys:
|
||||
missing_keys.remove(key)
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
|
||||
def _forward(self, input, weight, bias):
|
||||
return torch.nn.functional.linear(input, weight, bias)
|
||||
|
||||
def forward_comfy_cast_weights(self, input):
|
||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
|
||||
x = self._forward(input, weight, bias)
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return x
|
||||
|
||||
def forward(self, input, *args, **kwargs):
|
||||
run_every_op()
|
||||
|
||||
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
|
||||
return self.forward_comfy_cast_weights(input, *args, **kwargs)
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
getattr(self, 'input_scale', None) is not None and
|
||||
not isinstance(input, QuantizedTensor)):
|
||||
input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype)
|
||||
return self._forward(input, self.weight, self.bias)
|
||||
|
||||
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None):
|
||||
if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config:
|
||||
MixedPrecisionOps._layer_quant_config = model_config.layer_quant_config
|
||||
MixedPrecisionOps._compute_dtype = compute_dtype
|
||||
logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers")
|
||||
return MixedPrecisionOps
|
||||
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
|
||||
if scaled_fp8 is not None:
|
||||
return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
|
||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||
logging.info("Using mixed precision operations")
|
||||
return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute)
|
||||
|
||||
if (
|
||||
fp8_compute and
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import logging
|
||||
from typing import Tuple, Dict
|
||||
import comfy.float
|
||||
|
||||
_LAYOUT_REGISTRY = {}
|
||||
_GENERIC_UTILS = {}
|
||||
@ -228,6 +229,17 @@ class QuantizedTensor(torch.Tensor):
|
||||
new_kwargs = dequant_arg(kwargs)
|
||||
return func(*new_args, **new_kwargs)
|
||||
|
||||
def data_ptr(self):
|
||||
return self._qdata.data_ptr()
|
||||
|
||||
def is_pinned(self):
|
||||
return self._qdata.is_pinned()
|
||||
|
||||
def is_contiguous(self, *arg, **kwargs):
|
||||
return self._qdata.is_contiguous(*arg, **kwargs)
|
||||
|
||||
def storage(self):
|
||||
return self._qdata.storage()
|
||||
|
||||
# ==============================================================================
|
||||
# Generic Utilities (Layout-Agnostic Operations)
|
||||
@ -240,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn):
|
||||
|
||||
|
||||
def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"):
|
||||
if target_dtype is not None and target_dtype != qt.dtype:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: dtype conversion requested to {target_dtype}, "
|
||||
f"but not supported for quantized tensors. Ignoring dtype."
|
||||
)
|
||||
|
||||
if target_layout is not None and target_layout != torch.strided:
|
||||
logging.warning(
|
||||
f"QuantizedTensor: layout change requested to {target_layout}, "
|
||||
@ -265,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
|
||||
new_q_data = qt._qdata.to(device=target_device)
|
||||
new_params = _move_layout_params_to_device(qt._layout_params, target_device)
|
||||
if target_dtype is not None:
|
||||
new_params["orig_dtype"] = target_dtype
|
||||
new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
|
||||
logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}")
|
||||
return new_qt
|
||||
@ -330,7 +338,9 @@ def generic_copy_(func, args, kwargs):
|
||||
# Copy from another quantized tensor
|
||||
qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking)
|
||||
qt_dest._layout_type = src._layout_type
|
||||
orig_dtype = qt_dest._layout_params["orig_dtype"]
|
||||
_copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking)
|
||||
qt_dest._layout_params["orig_dtype"] = orig_dtype
|
||||
else:
|
||||
# Copy from regular tensor - just copy raw data
|
||||
qt_dest._qdata.copy_(src)
|
||||
@ -338,6 +348,18 @@ def generic_copy_(func, args, kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten.to.dtype)
|
||||
def generic_to_dtype(func, args, kwargs):
|
||||
"""Handle .to(dtype) calls - dtype conversion only."""
|
||||
src = args[0]
|
||||
if isinstance(src, QuantizedTensor):
|
||||
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
|
||||
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
|
||||
src._layout_params["orig_dtype"] = target_dtype
|
||||
return src
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
|
||||
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
|
||||
return True
|
||||
@ -373,32 +395,45 @@ class TensorCoreFP8Layout(QuantizedLayout):
|
||||
- orig_dtype: Original dtype before quantization (for casting back)
|
||||
"""
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
|
||||
def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False):
|
||||
orig_dtype = tensor.dtype
|
||||
|
||||
if scale is None:
|
||||
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
|
||||
if isinstance(scale, str) and scale == "recalculate":
|
||||
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
|
||||
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
|
||||
tensor_info = torch.finfo(tensor.dtype)
|
||||
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
|
||||
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
if scale is not None:
|
||||
if not isinstance(scale, torch.Tensor):
|
||||
scale = torch.tensor(scale)
|
||||
scale = scale.to(device=tensor.device, dtype=torch.float32)
|
||||
|
||||
tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
|
||||
# lp_amax = torch.finfo(dtype).max
|
||||
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
|
||||
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)
|
||||
if inplace_ops:
|
||||
tensor *= (1.0 / scale).to(tensor.dtype)
|
||||
else:
|
||||
tensor = tensor * (1.0 / scale).to(tensor.dtype)
|
||||
else:
|
||||
scale = torch.ones((), device=tensor.device, dtype=torch.float32)
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding)
|
||||
else:
|
||||
lp_amax = torch.finfo(dtype).max
|
||||
torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor)
|
||||
tensor = tensor.to(dtype, memory_format=torch.contiguous_format)
|
||||
|
||||
layout_params = {
|
||||
'scale': scale,
|
||||
'orig_dtype': orig_dtype
|
||||
}
|
||||
return qdata, layout_params
|
||||
return tensor, layout_params
|
||||
|
||||
@staticmethod
|
||||
def dequantize(qdata, scale, orig_dtype, **kwargs):
|
||||
plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
|
||||
return plain_tensor * scale
|
||||
plain_tensor.mul_(scale)
|
||||
return plain_tensor
|
||||
|
||||
@classmethod
|
||||
def get_plain_tensors(cls, qtensor):
|
||||
|
||||
225
comfy/sd.py
225
comfy/sd.py
@ -52,6 +52,9 @@ import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.z_image
|
||||
import comfy.text_encoders.ovis
|
||||
import comfy.text_encoders.kandinsky5
|
||||
|
||||
import comfy.model_patcher
|
||||
import comfy.lora
|
||||
@ -59,6 +62,8 @@ import comfy.lora_convert
|
||||
import comfy.hooks
|
||||
import comfy.t2i_adapter.adapter
|
||||
import comfy.taesd.taesd
|
||||
import comfy.taesd.taehv
|
||||
import comfy.latent_formats
|
||||
|
||||
import comfy.ldm.flux.redux
|
||||
|
||||
@ -94,7 +99,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
|
||||
|
||||
|
||||
class CLIP:
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}):
|
||||
def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}):
|
||||
if no_init:
|
||||
return
|
||||
params = target.params.copy()
|
||||
@ -122,9 +127,32 @@ class CLIP:
|
||||
|
||||
self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
|
||||
#Match torch.float32 hardcode upcast in TE implemention
|
||||
self.patcher.set_model_compute_dtype(torch.float32)
|
||||
self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
|
||||
self.patcher.is_clip = True
|
||||
self.apply_hooks_to_conds = None
|
||||
if len(state_dict) > 0:
|
||||
if isinstance(state_dict, list):
|
||||
for c in state_dict:
|
||||
m, u = self.load_sd(c)
|
||||
if len(m) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected: {}".format(u))
|
||||
else:
|
||||
m, u = self.load_sd(state_dict, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
if len(m_filter) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
else:
|
||||
logging.debug("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected {}:".format(u))
|
||||
|
||||
if params['device'] == load_device:
|
||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||
self.layer_idx = None
|
||||
@ -189,6 +217,7 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
all_hooks.reset()
|
||||
self.patcher.patch_hooks(None)
|
||||
if show_pbar:
|
||||
@ -236,6 +265,7 @@ class CLIP:
|
||||
self.cond_stage_model.set_clip_options({"projected_pooled": False})
|
||||
|
||||
self.load_model()
|
||||
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
|
||||
o = self.cond_stage_model.encode_token_weights(tokens)
|
||||
cond, pooled = o[:2]
|
||||
if return_dict:
|
||||
@ -356,7 +386,7 @@ class VAE:
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype)
|
||||
elif sd['decoder.conv_in.weight'].shape[1] == 32:
|
||||
elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False}
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
@ -382,6 +412,17 @@ class VAE:
|
||||
self.upscale_ratio = 4
|
||||
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1]
|
||||
if 'decoder.post_quant_conv.weight' in sd:
|
||||
sd = comfy.utils.state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."})
|
||||
|
||||
if 'bn.running_mean' in sd:
|
||||
ddconfig["batch_norm_latent"] = True
|
||||
self.downscale_ratio *= 2
|
||||
self.upscale_ratio *= 2
|
||||
self.latent_channels *= 4
|
||||
old_memory_used_decode = self.memory_used_decode
|
||||
self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0
|
||||
|
||||
if 'post_quant_conv.weight' in sd:
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
else:
|
||||
@ -441,20 +482,20 @@ class VAE:
|
||||
elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32:
|
||||
ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True}
|
||||
ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.latent_channels = 64
|
||||
self.latent_channels = 32
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
self.latent_dim = 3
|
||||
self.not_video = True
|
||||
self.not_video = False
|
||||
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"},
|
||||
encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig},
|
||||
decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig})
|
||||
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
|
||||
elif "decoder.conv_in.conv.weight" in sd:
|
||||
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
|
||||
ddconfig["conv3d"] = True
|
||||
@ -466,8 +507,10 @@ class VAE:
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
|
||||
self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
|
||||
self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
#This is likely to significantly over-estimate with single image or low frame counts as the
|
||||
#implementation is able to completely skip caching. Rework if used as an image only VAE
|
||||
self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
elif "decoder.unpatcher3d.wavelets" in sd:
|
||||
self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8)
|
||||
@ -496,17 +539,20 @@ class VAE:
|
||||
self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype)
|
||||
else: # Wan 2.1 VAE
|
||||
dim = sd["decoder.head.0.gamma"].shape[0]
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.latent_dim = 3
|
||||
self.latent_channels = 16
|
||||
ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
|
||||
self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
|
||||
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
|
||||
self.memory_used_encode = lambda shape, dtype: (1500 if shape[2]<=4 else 6000) * shape[3] * shape[4] * model_management.dtype_size(dtype)
|
||||
self.memory_used_decode = lambda shape, dtype: (2200 if shape[2]<=4 else 7000) * shape[3] * shape[4] * (8*8) * model_management.dtype_size(dtype)
|
||||
|
||||
|
||||
# Hunyuan 3d v2 2.0 & 2.1
|
||||
elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
|
||||
|
||||
@ -572,6 +618,35 @@ class VAE:
|
||||
self.process_input = lambda audio: audio
|
||||
self.working_dtypes = [torch.float32]
|
||||
self.crop_input = False
|
||||
elif "decoder.22.bias" in sd: # taehv, taew and lighttae
|
||||
self.latent_channels = sd["decoder.1.weight"].shape[1]
|
||||
self.latent_dim = 3
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16)
|
||||
self.upscale_index_formula = (4, 16, 16)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16)
|
||||
self.downscale_index_formula = (4, 16, 16)
|
||||
if self.latent_channels == 48: # Wan 2.2
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.process_output = lambda image: image
|
||||
self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype))
|
||||
elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=comfy.latent_formats.HunyuanVideo15)
|
||||
self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently"))
|
||||
self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
else:
|
||||
if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical
|
||||
latent_format=comfy.latent_formats.HunyuanVideo
|
||||
else:
|
||||
latent_format=None # lighttaew2_1 doesn't need scaling
|
||||
self.first_stage_model = comfy.taesd.taehv.TAEHV(latent_channels=self.latent_channels, latent_format=latent_format)
|
||||
self.process_input = self.process_output = lambda image: image
|
||||
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
|
||||
self.upscale_index_formula = (4, 8, 8)
|
||||
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
|
||||
self.downscale_index_formula = (4, 8, 8)
|
||||
self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype))
|
||||
self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype))
|
||||
else:
|
||||
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
|
||||
self.first_stage_model = None
|
||||
@ -696,6 +771,8 @@ class VAE:
|
||||
self.throw_exception_if_invalid()
|
||||
pixel_samples = None
|
||||
do_tile = False
|
||||
if self.latent_dim == 2 and samples_in.ndim == 5:
|
||||
samples_in = samples_in[:, :, 0]
|
||||
try:
|
||||
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
|
||||
model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
|
||||
@ -911,12 +988,19 @@ class CLIPType(Enum):
|
||||
OMNIGEN2 = 17
|
||||
QWEN_IMAGE = 18
|
||||
HUNYUAN_IMAGE = 19
|
||||
HUNYUAN_VIDEO_15 = 20
|
||||
OVIS = 21
|
||||
KANDINSKY5 = 22
|
||||
KANDINSKY5_IMAGE = 23
|
||||
|
||||
|
||||
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
|
||||
clip_data = []
|
||||
for p in ckpt_paths:
|
||||
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
|
||||
sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True)
|
||||
if model_options.get("custom_operations", None) is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata)
|
||||
clip_data.append(sd)
|
||||
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)
|
||||
|
||||
|
||||
@ -934,6 +1018,11 @@ class TEModel(Enum):
|
||||
QWEN25_7B = 11
|
||||
BYT5_SMALL_GLYPH = 12
|
||||
GEMMA_3_4B = 13
|
||||
MISTRAL3_24B = 14
|
||||
MISTRAL3_24B_PRUNED_FLUX2 = 15
|
||||
QWEN3_4B = 16
|
||||
QWEN3_2B = 17
|
||||
|
||||
|
||||
def detect_te_model(sd):
|
||||
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
|
||||
@ -966,6 +1055,18 @@ def detect_te_model(sd):
|
||||
if weight.shape[0] == 512:
|
||||
return TEModel.QWEN25_7B
|
||||
if "model.layers.0.post_attention_layernorm.weight" in sd:
|
||||
weight = sd['model.layers.0.post_attention_layernorm.weight']
|
||||
if 'model.layers.0.self_attn.q_norm.weight' in sd:
|
||||
if weight.shape[0] == 2560:
|
||||
return TEModel.QWEN3_4B
|
||||
elif weight.shape[0] == 2048:
|
||||
return TEModel.QWEN3_2B
|
||||
if weight.shape[0] == 5120:
|
||||
if "model.layers.39.post_attention_layernorm.weight" in sd:
|
||||
return TEModel.MISTRAL3_24B
|
||||
else:
|
||||
return TEModel.MISTRAL3_24B_PRUNED_FLUX2
|
||||
|
||||
return TEModel.LLAMA3_8
|
||||
return None
|
||||
|
||||
@ -1015,7 +1116,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
|
||||
@ -1039,7 +1140,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
|
||||
clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else: #CLIPType.MOCHI
|
||||
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
|
||||
@ -1068,7 +1169,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
|
||||
elif te_model == TEModel.LLAMA3_8:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
|
||||
clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
elif te_model == TEModel.QWEN25_3B:
|
||||
clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data))
|
||||
@ -1080,13 +1181,23 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
else:
|
||||
clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
|
||||
elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2:
|
||||
clip_target.clip = comfy.text_encoders.flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2)
|
||||
clip_target.tokenizer = comfy.text_encoders.flux.Flux2Tokenizer
|
||||
tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None)
|
||||
elif te_model == TEModel.QWEN3_4B:
|
||||
clip_target.clip = comfy.text_encoders.z_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.z_image.ZImageTokenizer
|
||||
elif te_model == TEModel.QWEN3_2B:
|
||||
clip_target.clip = comfy.text_encoders.ovis.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.ovis.OvisTokenizer
|
||||
else:
|
||||
# clip_l
|
||||
if clip_type == CLIPType.SD3:
|
||||
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
|
||||
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
|
||||
elif clip_type == CLIPType.HIDREAM:
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
|
||||
clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None)
|
||||
clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
|
||||
else:
|
||||
clip_target.clip = sd1_clip.SD1ClipModel
|
||||
@ -1126,6 +1237,15 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
elif clip_type == CLIPType.HUNYUAN_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
|
||||
elif clip_type == CLIPType.HUNYUAN_VIDEO_15:
|
||||
clip_target.clip = comfy.text_encoders.hunyuan_image.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer
|
||||
elif clip_type == CLIPType.KANDINSKY5:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer
|
||||
elif clip_type == CLIPType.KANDINSKY5_IMAGE:
|
||||
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
|
||||
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
|
||||
else:
|
||||
clip_target.clip = sdxl_clip.SDXLClipModel
|
||||
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
|
||||
@ -1141,14 +1261,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
|
||||
parameters += comfy.utils.calculate_parameters(c)
|
||||
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
|
||||
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options)
|
||||
for c in clip_data:
|
||||
m, u = clip.load_sd(c)
|
||||
if len(m) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected: {}".format(u))
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options)
|
||||
return clip
|
||||
|
||||
def load_gligen(ckpt_path):
|
||||
@ -1207,6 +1320,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
|
||||
load_device = model_management.get_torch_device()
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata)
|
||||
|
||||
model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
|
||||
if model_config is None:
|
||||
logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
|
||||
@ -1215,18 +1332,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
return None
|
||||
return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used'
|
||||
|
||||
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.scaled_fp8 is not None:
|
||||
if model_config.quant_config is not None:
|
||||
weight_dtype = None
|
||||
|
||||
model_config.custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
|
||||
|
||||
if unet_dtype is None:
|
||||
unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
|
||||
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
if model_config.quant_config is not None:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
|
||||
if model_config.clip_vision_prefix is not None:
|
||||
@ -1244,22 +1365,33 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
|
||||
vae = VAE(sd=vae_sd, metadata=metadata)
|
||||
|
||||
if output_clip:
|
||||
if te_model_options.get("custom_operations", None) is None:
|
||||
scaled_fp8_list = []
|
||||
for k in list(sd.keys()): # Convert scaled fp8 to mixed ops
|
||||
if k.endswith(".scaled_fp8"):
|
||||
scaled_fp8_list.append(k[:-len("scaled_fp8")])
|
||||
|
||||
if len(scaled_fp8_list) > 0:
|
||||
out_sd = {}
|
||||
for k in sd:
|
||||
skip = False
|
||||
for pref in scaled_fp8_list:
|
||||
skip = skip or k.startswith(pref)
|
||||
if not skip:
|
||||
out_sd[k] = sd[k]
|
||||
|
||||
for pref in scaled_fp8_list:
|
||||
quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={})
|
||||
for k in quant_sd:
|
||||
out_sd[k] = quant_sd[k]
|
||||
sd = out_sd
|
||||
|
||||
clip_target = model_config.clip_target(state_dict=sd)
|
||||
if clip_target is not None:
|
||||
clip_sd = model_config.process_clip_state_dict(sd)
|
||||
if len(clip_sd) > 0:
|
||||
parameters = comfy.utils.calculate_parameters(clip_sd)
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options)
|
||||
m, u = clip.load_sd(clip_sd, full_model=True)
|
||||
if len(m) > 0:
|
||||
m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m))
|
||||
if len(m_filter) > 0:
|
||||
logging.warning("clip missing: {}".format(m))
|
||||
else:
|
||||
logging.debug("clip missing: {}".format(m))
|
||||
|
||||
if len(u) > 0:
|
||||
logging.debug("clip unexpected {}:".format(u))
|
||||
clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options)
|
||||
else:
|
||||
logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.")
|
||||
|
||||
@ -1306,6 +1438,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
if len(temp_sd) > 0:
|
||||
sd = temp_sd
|
||||
|
||||
custom_operations = model_options.get("custom_operations", None)
|
||||
if custom_operations is None:
|
||||
sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata)
|
||||
parameters = comfy.utils.calculate_parameters(sd)
|
||||
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||
|
||||
@ -1336,7 +1471,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
|
||||
offload_device = model_management.unet_offload_device()
|
||||
unet_weight_dtype = list(model_config.supported_inference_dtypes)
|
||||
if model_config.scaled_fp8 is not None:
|
||||
if model_config.quant_config is not None:
|
||||
weight_dtype = None
|
||||
|
||||
if dtype is None:
|
||||
@ -1344,12 +1479,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None):
|
||||
else:
|
||||
unet_dtype = dtype
|
||||
|
||||
if model_config.layer_quant_config is not None:
|
||||
if model_config.quant_config is not None:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes)
|
||||
else:
|
||||
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
|
||||
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
|
||||
model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations)
|
||||
|
||||
if custom_operations is not None:
|
||||
model_config.custom_operations = custom_operations
|
||||
|
||||
if model_options.get("fp8_optimizations", False):
|
||||
model_config.optimizations["fp8"] = True
|
||||
|
||||
@ -1388,6 +1526,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
if vae is not None:
|
||||
vae_sd = vae.get_sd()
|
||||
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd)
|
||||
|
||||
@ -90,7 +90,6 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=True, enable_attention_masks=False, zero_out_masked=False,
|
||||
return_projected_pooled=True, return_attention_masks=False, model_options={}): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
|
||||
if textmodel_json_config is None:
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
|
||||
@ -108,19 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
config[k] = v
|
||||
|
||||
operations = model_options.get("custom_operations", None)
|
||||
scaled_fp8 = None
|
||||
quant_config = model_options.get("quantization_metadata", None)
|
||||
|
||||
if operations is None:
|
||||
scaled_fp8 = model_options.get("scaled_fp8", None)
|
||||
if scaled_fp8 is not None:
|
||||
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
|
||||
if quant_config is not None:
|
||||
operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True)
|
||||
logging.info("Using MixedPrecisionOps for text encoder")
|
||||
else:
|
||||
operations = comfy.ops.manual_cast
|
||||
|
||||
self.operations = operations
|
||||
self.transformer = model_class(config, dtype, device, self.operations)
|
||||
if scaled_fp8 is not None:
|
||||
self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8))
|
||||
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
@ -138,6 +135,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.layer_norm_hidden_state = layer_norm_hidden_state
|
||||
self.return_projected_pooled = return_projected_pooled
|
||||
self.return_attention_masks = return_attention_masks
|
||||
self.execution_device = None
|
||||
|
||||
if layer == "hidden":
|
||||
assert layer_idx is not None
|
||||
@ -154,7 +152,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
def set_clip_options(self, options):
|
||||
layer_idx = options.get("layer", self.layer_idx)
|
||||
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
|
||||
if self.layer == "all":
|
||||
self.execution_device = options.get("execution_device", self.execution_device)
|
||||
if isinstance(self.layer, list) or self.layer == "all":
|
||||
pass
|
||||
elif layer_idx is None or abs(layer_idx) > self.num_layers:
|
||||
self.layer = "last"
|
||||
@ -166,6 +165,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
self.layer = self.options_default[0]
|
||||
self.layer_idx = self.options_default[1]
|
||||
self.return_projected_pooled = self.options_default[2]
|
||||
self.execution_device = None
|
||||
|
||||
def process_tokens(self, tokens, device):
|
||||
end_token = self.special_tokens.get("end", None)
|
||||
@ -249,14 +249,20 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info
|
||||
|
||||
def forward(self, tokens):
|
||||
device = self.transformer.get_input_embeddings().weight.device
|
||||
if self.execution_device is None:
|
||||
device = self.transformer.get_input_embeddings().weight.device
|
||||
else:
|
||||
device = self.execution_device
|
||||
|
||||
embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device)
|
||||
|
||||
attention_mask_model = None
|
||||
if self.enable_attention_masks:
|
||||
attention_mask_model = attention_mask
|
||||
|
||||
if self.layer == "all":
|
||||
if isinstance(self.layer, list):
|
||||
intermediate_output = self.layer
|
||||
elif self.layer == "all":
|
||||
intermediate_output = "all"
|
||||
else:
|
||||
intermediate_output = self.layer_idx
|
||||
|
||||
@ -21,6 +21,8 @@ import comfy.text_encoders.ace
|
||||
import comfy.text_encoders.omnigen2
|
||||
import comfy.text_encoders.qwen_image
|
||||
import comfy.text_encoders.hunyuan_image
|
||||
import comfy.text_encoders.kandinsky5
|
||||
import comfy.text_encoders.z_image
|
||||
|
||||
from . import supported_models_base
|
||||
from . import latent_formats
|
||||
@ -539,7 +541,7 @@ class SD3(supported_models_base.BASE):
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.SD3
|
||||
|
||||
memory_usage_factor = 1.2
|
||||
memory_usage_factor = 1.6
|
||||
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
@ -741,6 +743,37 @@ class FluxSchnell(Flux):
|
||||
out = model_base.Flux(self, model_type=model_base.ModelType.FLOW, device=device)
|
||||
return out
|
||||
|
||||
class Flux2(Flux):
|
||||
unet_config = {
|
||||
"image_model": "flux2",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.02,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux2
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Flux2(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None # TODO
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
|
||||
|
||||
class GenmoMochi(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "mochi_preview",
|
||||
@ -932,7 +965,7 @@ class CosmosT2IPredict2(supported_models_base.BASE):
|
||||
|
||||
def __init__(self, unet_config):
|
||||
super().__init__(unet_config)
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.9
|
||||
self.memory_usage_factor = (unet_config.get("model_channels", 2048) / 2048) * 0.95
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.CosmosPredict2(self, device=device)
|
||||
@ -963,7 +996,7 @@ class Lumina2(supported_models_base.BASE):
|
||||
"shift": 6.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.2
|
||||
memory_usage_factor = 1.4
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
@ -982,6 +1015,26 @@ class Lumina2(supported_models_base.BASE):
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
|
||||
|
||||
class ZImage(Lumina2):
|
||||
unet_config = {
|
||||
"image_model": "lumina2",
|
||||
"dim": 3840,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"multiplier": 1.0,
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
memory_usage_factor = 2.0
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
|
||||
|
||||
class WAN21_T2V(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -1236,7 +1289,7 @@ class ChromaRadiance(Chroma):
|
||||
latent_format = comfy.latent_formats.ChromaRadiance
|
||||
|
||||
# Pixel-space model, no spatial compression for model input.
|
||||
memory_usage_factor = 0.038
|
||||
memory_usage_factor = 0.044
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.ChromaRadiance(self, device=device)
|
||||
@ -1279,7 +1332,7 @@ class Omnigen2(supported_models_base.BASE):
|
||||
"shift": 2.6,
|
||||
}
|
||||
|
||||
memory_usage_factor = 1.65 #TODO
|
||||
memory_usage_factor = 1.95 #TODO
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.Flux
|
||||
@ -1344,7 +1397,7 @@ class HunyuanImage21(HunyuanVideo):
|
||||
|
||||
latent_format = latent_formats.HunyuanImage21
|
||||
|
||||
memory_usage_factor = 7.7
|
||||
memory_usage_factor = 8.7
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
@ -1374,6 +1427,108 @@ class HunyuanImage21Refiner(HunyuanVideo):
|
||||
out = model_base.HunyuanImage21Refiner(self, device=device)
|
||||
return out
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage]
|
||||
class HunyuanVideo15(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"vision_in_dim": 1152,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 7.0,
|
||||
}
|
||||
memory_usage_factor = 4.0 #TODO
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
latent_format = latent_formats.HunyuanVideo15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideo15(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class HunyuanVideo15_SR_Distilled(HunyuanVideo):
|
||||
unet_config = {
|
||||
"image_model": "hunyuan_video",
|
||||
"vision_in_dim": 1152,
|
||||
"in_channels": 98,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 2.0,
|
||||
}
|
||||
memory_usage_factor = 4.0 #TODO
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
latent_format = latent_formats.HunyuanVideo15
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.HunyuanVideo15_SR_Distilled(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideo15Tokenizer, comfy.text_encoders.hunyuan_image.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class Kandinsky5(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "kandinsky5",
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 10.0,
|
||||
}
|
||||
|
||||
unet_extra_config = {}
|
||||
latent_format = latent_formats.HunyuanVideo
|
||||
|
||||
memory_usage_factor = 1.25 #TODO
|
||||
|
||||
supported_inference_dtypes = [torch.bfloat16, torch.float32]
|
||||
|
||||
vae_key_prefix = ["vae."]
|
||||
text_encoder_key_prefix = ["text_encoders."]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Kandinsky5(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5Tokenizer, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
class Kandinsky5Image(Kandinsky5):
|
||||
unet_config = {
|
||||
"image_model": "kandinsky5",
|
||||
"model_dim": 2560,
|
||||
"visual_embed_dim": 64,
|
||||
}
|
||||
|
||||
sampling_settings = {
|
||||
"shift": 3.0,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.Flux
|
||||
memory_usage_factor = 1.25 #TODO
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.Kandinsky5Image(self, device=device)
|
||||
return out
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
pref = self.text_encoder_key_prefix[0]
|
||||
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
|
||||
return supported_models_base.ClipTarget(comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage, comfy.text_encoders.kandinsky5.te(**hunyuan_detect))
|
||||
|
||||
|
||||
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5]
|
||||
|
||||
models += [SVD_img2vid]
|
||||
|
||||
@ -17,6 +17,7 @@
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
from . import model_base
|
||||
from . import utils
|
||||
from . import latent_formats
|
||||
@ -49,8 +50,7 @@ class BASE:
|
||||
|
||||
manual_cast_dtype = None
|
||||
custom_operations = None
|
||||
scaled_fp8 = None
|
||||
layer_quant_config = None # Per-layer quantization configuration for mixed precision
|
||||
quant_config = None # quantization configuration for mixed precision
|
||||
optimizations = {"fp8": False}
|
||||
|
||||
@classmethod
|
||||
@ -118,3 +118,7 @@ class BASE:
|
||||
def set_inference_dtype(self, dtype, manual_cast_dtype):
|
||||
self.unet_config['dtype'] = dtype
|
||||
self.manual_cast_dtype = manual_cast_dtype
|
||||
|
||||
def __getattr__(self, name):
|
||||
logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. Please fix your code.\n".format(name))
|
||||
return None
|
||||
|
||||
171
comfy/taesd/taehv.py
Normal file
171
comfy/taesd/taehv.py
Normal file
@ -0,0 +1,171 @@
|
||||
# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from tqdm.auto import tqdm
|
||||
from collections import namedtuple, deque
|
||||
|
||||
import comfy.ops
|
||||
operations=comfy.ops.disable_weight_init
|
||||
|
||||
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
||||
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
|
||||
|
||||
def conv(n_in, n_out, **kwargs):
|
||||
return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
||||
|
||||
class Clamp(nn.Module):
|
||||
def forward(self, x):
|
||||
return torch.tanh(x / 3) * 3
|
||||
|
||||
class MemBlock(nn.Module):
|
||||
def __init__(self, n_in, n_out, act_func):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
|
||||
self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
||||
self.act = act_func
|
||||
def forward(self, x, past):
|
||||
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
|
||||
|
||||
class TPool(nn.Module):
|
||||
def __init__(self, n_f, stride):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False)
|
||||
def forward(self, x):
|
||||
_NT, C, H, W = x.shape
|
||||
return self.conv(x.reshape(-1, self.stride * C, H, W))
|
||||
|
||||
class TGrow(nn.Module):
|
||||
def __init__(self, n_f, stride):
|
||||
super().__init__()
|
||||
self.stride = stride
|
||||
self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False)
|
||||
def forward(self, x):
|
||||
_NT, C, H, W = x.shape
|
||||
x = self.conv(x)
|
||||
return x.reshape(-1, C, H, W)
|
||||
|
||||
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
||||
|
||||
B, T, C, H, W = x.shape
|
||||
if parallel:
|
||||
x = x.reshape(B*T, C, H, W)
|
||||
# parallel over input timesteps, iterate over blocks
|
||||
for b in tqdm(model, disable=not show_progress_bar):
|
||||
if isinstance(b, MemBlock):
|
||||
BT, C, H, W = x.shape
|
||||
T = BT // B
|
||||
_x = x.reshape(B, T, C, H, W)
|
||||
mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape)
|
||||
x = b(x, mem)
|
||||
else:
|
||||
x = b(x)
|
||||
BT, C, H, W = x.shape
|
||||
T = BT // B
|
||||
x = x.view(B, T, C, H, W)
|
||||
else:
|
||||
out = []
|
||||
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
|
||||
progress_bar = tqdm(range(T), disable=not show_progress_bar)
|
||||
mem = [None] * len(model)
|
||||
while work_queue:
|
||||
xt, i = work_queue.popleft()
|
||||
if i == 0:
|
||||
progress_bar.update(1)
|
||||
if i == len(model):
|
||||
out.append(xt)
|
||||
del xt
|
||||
else:
|
||||
b = model[i]
|
||||
if isinstance(b, MemBlock):
|
||||
if mem[i] is None:
|
||||
xt_new = b(xt, xt * 0)
|
||||
mem[i] = xt.detach().clone()
|
||||
else:
|
||||
xt_new = b(xt, mem[i])
|
||||
mem[i] = xt.detach().clone()
|
||||
del xt
|
||||
work_queue.appendleft(TWorkItem(xt_new, i+1))
|
||||
elif isinstance(b, TPool):
|
||||
if mem[i] is None:
|
||||
mem[i] = []
|
||||
mem[i].append(xt.detach().clone())
|
||||
if len(mem[i]) == b.stride:
|
||||
B, C, H, W = xt.shape
|
||||
xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W))
|
||||
mem[i] = []
|
||||
work_queue.appendleft(TWorkItem(xt, i+1))
|
||||
elif isinstance(b, TGrow):
|
||||
xt = b(xt)
|
||||
NT, C, H, W = xt.shape
|
||||
for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)):
|
||||
work_queue.appendleft(TWorkItem(xt_next, i+1))
|
||||
del xt
|
||||
else:
|
||||
xt = b(xt)
|
||||
work_queue.appendleft(TWorkItem(xt, i+1))
|
||||
progress_bar.close()
|
||||
x = torch.stack(out, 1)
|
||||
return x
|
||||
|
||||
|
||||
class TAEHV(nn.Module):
|
||||
def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True):
|
||||
super().__init__()
|
||||
self.image_channels = 3
|
||||
self.patch_size = 1
|
||||
self.latent_channels = latent_channels
|
||||
self.parallel = parallel
|
||||
self.latent_format = latent_format
|
||||
self.show_progress_bar = show_progress_bar
|
||||
self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x)
|
||||
self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x)
|
||||
if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5
|
||||
self.patch_size = 2
|
||||
if self.latent_channels == 32: # HunyuanVideo1.5
|
||||
act_func = nn.LeakyReLU(0.2, inplace=True)
|
||||
else: # HunyuanVideo, Wan 2.1
|
||||
act_func = nn.ReLU(inplace=True)
|
||||
|
||||
self.encoder = nn.Sequential(
|
||||
conv(self.image_channels*self.patch_size**2, 64), act_func,
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func),
|
||||
conv(64, self.latent_channels),
|
||||
)
|
||||
n_f = [256, 128, 64, 64]
|
||||
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
||||
self.decoder = nn.Sequential(
|
||||
Clamp(), conv(self.latent_channels, n_f[0]), act_func,
|
||||
MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
|
||||
MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
|
||||
MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
|
||||
act_func, conv(n_f[3], self.image_channels*self.patch_size**2),
|
||||
)
|
||||
@property
|
||||
def show_progress_bar(self):
|
||||
return self._show_progress_bar
|
||||
|
||||
@show_progress_bar.setter
|
||||
def show_progress_bar(self, value):
|
||||
self._show_progress_bar = value
|
||||
|
||||
def encode(self, x, **kwargs):
|
||||
if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
|
||||
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
if x.shape[1] % 4 != 0:
|
||||
# pad at end to multiple of 4
|
||||
n_pad = 4 - x.shape[1] % 4
|
||||
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
|
||||
x = torch.cat([x, padding], 1)
|
||||
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
|
||||
return self.process_out(x)
|
||||
|
||||
def decode(self, x, **kwargs):
|
||||
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
|
||||
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
|
||||
if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
|
||||
return x[:, self.frames_to_trim:].movedim(2, 1)
|
||||
@ -7,10 +7,10 @@ from transformers import T5TokenizerFast
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json")
|
||||
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||
if t5xxl_scaled_fp8 is not None:
|
||||
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
|
||||
if t5xxl_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["quantization_metadata"] = t5xxl_quantization_metadata
|
||||
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options)
|
||||
|
||||
@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
def te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
class CosmosTEModel_(CosmosT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.text_encoders.t5
|
||||
import comfy.text_encoders.sd3_clip
|
||||
import comfy.text_encoders.llama
|
||||
import comfy.model_management
|
||||
from transformers import T5TokenizerFast
|
||||
from transformers import T5TokenizerFast, LlamaTokenizerFast
|
||||
import torch
|
||||
import os
|
||||
import json
|
||||
import base64
|
||||
|
||||
class T5XXLTokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
@ -60,11 +63,112 @@ class FluxClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
def flux_clip(dtype_t5=None, t5_quantization_metadata=None):
|
||||
class FluxClipModel_(FluxClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options)
|
||||
return FluxClipModel_
|
||||
|
||||
def load_mistral_tokenizer(data):
|
||||
if torch.is_tensor(data):
|
||||
data = data.numpy().tobytes()
|
||||
|
||||
try:
|
||||
from transformers.integrations.mistral import MistralConverter
|
||||
except ModuleNotFoundError:
|
||||
from transformers.models.pixtral.convert_pixtral_weights_to_hf import MistralConverter
|
||||
|
||||
mistral_vocab = json.loads(data)
|
||||
|
||||
special_tokens = {}
|
||||
vocab = {}
|
||||
|
||||
max_vocab = mistral_vocab["config"]["default_vocab_size"]
|
||||
max_vocab -= len(mistral_vocab["special_tokens"])
|
||||
|
||||
for w in mistral_vocab["vocab"]:
|
||||
r = w["rank"]
|
||||
if r >= max_vocab:
|
||||
continue
|
||||
|
||||
vocab[base64.b64decode(w["token_bytes"])] = r
|
||||
|
||||
for w in mistral_vocab["special_tokens"]:
|
||||
if "token_bytes" in w:
|
||||
special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"]
|
||||
else:
|
||||
special_tokens[w["token_str"]] = w["rank"]
|
||||
|
||||
all_special = []
|
||||
for v in special_tokens:
|
||||
all_special.append(v)
|
||||
|
||||
special_tokens.update(vocab)
|
||||
vocab = special_tokens
|
||||
return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False}
|
||||
|
||||
class MistralTokenizerClass:
|
||||
@staticmethod
|
||||
def from_pretrained(path, **kwargs):
|
||||
return LlamaTokenizerFast(**kwargs)
|
||||
|
||||
class Mistral3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
self.tekken_data = tokenizer_data.get("tekken_model", None)
|
||||
super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data)
|
||||
|
||||
def state_dict(self):
|
||||
return {"tekken_model": self.tekken_data}
|
||||
|
||||
class Flux2Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer)
|
||||
self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]'
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
return tokens
|
||||
|
||||
class Mistral3_24BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer=[10, 20, 30], layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
textmodel_json_config = {}
|
||||
num_layers = model_options.get("num_layers", None)
|
||||
if num_layers is not None:
|
||||
textmodel_json_config["num_hidden_layers"] = num_layers
|
||||
if num_layers < 40:
|
||||
textmodel_json_config["final_norm"] = False
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
class Flux2TEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel):
|
||||
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
out, pooled, extra = super().encode_token_weights(token_weight_pairs)
|
||||
|
||||
out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1)
|
||||
out = out.movedim(1, 2)
|
||||
out = out.reshape(out.shape[0], out.shape[1], -1)
|
||||
return out, pooled, extra
|
||||
|
||||
def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False):
|
||||
class Flux2TEModel_(Flux2TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if pruned:
|
||||
model_options = model_options.copy()
|
||||
model_options["num_layers"] = 30
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Flux2TEModel_
|
||||
|
||||
@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
|
||||
def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
class MochiTEModel_(MochiT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module):
|
||||
return self.llama.load_sd(sd)
|
||||
|
||||
|
||||
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
|
||||
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None):
|
||||
class HiDreamTEModel_(HiDreamTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return HiDreamTEModel_
|
||||
|
||||
@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer):
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel):
|
||||
else:
|
||||
return super().load_sd(sd)
|
||||
|
||||
def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
|
||||
def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None):
|
||||
class QwenImageTEModel_(HunyuanImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["qwen_scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
from comfy import sd1_clip
|
||||
import comfy.model_management
|
||||
import comfy.text_encoders.llama
|
||||
from .hunyuan_image import HunyuanImageTokenizer
|
||||
from transformers import LlamaTokenizerFast
|
||||
import torch
|
||||
import os
|
||||
import numbers
|
||||
|
||||
import comfy.utils
|
||||
|
||||
def llama_detect(state_dict, prefix=""):
|
||||
out = {}
|
||||
@ -13,9 +14,9 @@ def llama_detect(state_dict, prefix=""):
|
||||
if t5_key in state_dict:
|
||||
out["dtype_llama"] = state_dict[t5_key].dtype
|
||||
|
||||
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
|
||||
if quant is not None:
|
||||
out["llama_quantization_metadata"] = quant
|
||||
|
||||
return out
|
||||
|
||||
@ -27,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
|
||||
|
||||
class LLAMAModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
|
||||
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
|
||||
if llama_scaled_fp8 is not None:
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
|
||||
textmodel_json_config = {}
|
||||
vocab_size = model_options.get("vocab_size", None)
|
||||
@ -73,6 +74,14 @@ class HunyuanVideoTokenizer:
|
||||
return {}
|
||||
|
||||
|
||||
class HunyuanVideo15Tokenizer(HunyuanImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nYou are a helpful assistant. Describe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. camera angles, movements, and transitions used in the video.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
return super().tokenize_with_weights(text, return_word_ids, prevent_empty_text=True, **kwargs)
|
||||
|
||||
class HunyuanVideoClipModel(torch.nn.Module):
|
||||
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__()
|
||||
@ -149,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module):
|
||||
return self.llama.load_sd(sd)
|
||||
|
||||
|
||||
def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None):
|
||||
def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class HunyuanVideoClipModel_(HunyuanVideoClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
|
||||
return HunyuanVideoClipModel_
|
||||
|
||||
68
comfy/text_encoders/kandinsky5.py
Normal file
68
comfy/text_encoders/kandinsky5.py
Normal file
@ -0,0 +1,68 @@
|
||||
from comfy import sd1_clip
|
||||
from .qwen_image import QwenImageTokenizer, QwenImageTEModel
|
||||
from .llama import Qwen25_7BVLI
|
||||
|
||||
|
||||
class Kandinsky5Tokenizer(QwenImageTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
|
||||
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
|
||||
out = super().tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Kandinsky5TokenizerImage(Kandinsky5Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
|
||||
self.llama_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
|
||||
|
||||
|
||||
class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
|
||||
llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class Kandinsky5TEModel(QwenImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
|
||||
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs):
|
||||
cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1)
|
||||
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"])
|
||||
|
||||
return cond, l_pooled, extra
|
||||
|
||||
def set_clip_options(self, options):
|
||||
super().set_clip_options(options)
|
||||
self.clip_l.set_clip_options(options)
|
||||
|
||||
def reset_clip_options(self):
|
||||
super().reset_clip_options()
|
||||
self.clip_l.reset_clip_options()
|
||||
|
||||
def load_sd(self, sd):
|
||||
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
|
||||
return self.clip_l.load_sd(sd)
|
||||
else:
|
||||
return super().load_sd(sd)
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Kandinsky5TEModel_(Kandinsky5TEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["llama_quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return Kandinsky5TEModel_
|
||||
@ -32,6 +32,29 @@ class Llama2Config:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Mistral3Small24BConfig:
|
||||
vocab_size: int = 131072
|
||||
hidden_size: int = 5120
|
||||
intermediate_size: int = 32768
|
||||
num_hidden_layers: int = 40
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 8192
|
||||
rms_norm_eps: float = 1e-5
|
||||
rope_theta: float = 1000000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_3BConfig:
|
||||
@ -53,6 +76,51 @@ class Qwen25_3BConfig:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen3_4BConfig:
|
||||
vocab_size: int = 151936
|
||||
hidden_size: int = 2560
|
||||
intermediate_size: int = 9728
|
||||
num_hidden_layers: int = 36
|
||||
num_attention_heads: int = 32
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Ovis25_2BConfig:
|
||||
vocab_size: int = 151936
|
||||
hidden_size: int = 2048
|
||||
intermediate_size: int = 6144
|
||||
num_hidden_layers: int = 28
|
||||
num_attention_heads: int = 16
|
||||
num_key_value_heads: int = 8
|
||||
max_position_embeddings: int = 40960
|
||||
rms_norm_eps: float = 1e-6
|
||||
rope_theta: float = 1000000.0
|
||||
transformer_type: str = "llama"
|
||||
head_dim = 128
|
||||
rms_norm_add = False
|
||||
mlp_activation = "silu"
|
||||
qkv_bias = False
|
||||
rope_dims = None
|
||||
q_norm = "gemma3"
|
||||
k_norm = "gemma3"
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Qwen25_7BVLI_Config:
|
||||
@ -74,6 +142,7 @@ class Qwen25_7BVLI_Config:
|
||||
q_norm = None
|
||||
k_norm = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Gemma2_2B_Config:
|
||||
@ -96,6 +165,7 @@ class Gemma2_2B_Config:
|
||||
k_norm = None
|
||||
sliding_attention = None
|
||||
rope_scale = None
|
||||
final_norm: bool = True
|
||||
|
||||
@dataclass
|
||||
class Gemma3_4B_Config:
|
||||
@ -118,6 +188,7 @@ class Gemma3_4B_Config:
|
||||
k_norm = "gemma3"
|
||||
sliding_attention = [False, False, False, False, False, 1024]
|
||||
rope_scale = [1.0, 8.0]
|
||||
final_norm: bool = True
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
|
||||
@ -366,7 +437,12 @@ class Llama2_(nn.Module):
|
||||
transformer(config, index=i, device=device, dtype=dtype, ops=ops)
|
||||
for i in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
|
||||
if config.final_norm:
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]):
|
||||
@ -402,8 +478,12 @@ class Llama2_(nn.Module):
|
||||
|
||||
intermediate = None
|
||||
all_intermediate = None
|
||||
only_layers = None
|
||||
if intermediate_output is not None:
|
||||
if intermediate_output == "all":
|
||||
if isinstance(intermediate_output, list):
|
||||
all_intermediate = []
|
||||
only_layers = set(intermediate_output)
|
||||
elif intermediate_output == "all":
|
||||
all_intermediate = []
|
||||
intermediate_output = None
|
||||
elif intermediate_output < 0:
|
||||
@ -411,7 +491,8 @@ class Llama2_(nn.Module):
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
if only_layers is None or (i in only_layers):
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
x = layer(
|
||||
x=x,
|
||||
attention_mask=mask,
|
||||
@ -421,14 +502,17 @@ class Llama2_(nn.Module):
|
||||
if i == intermediate_output:
|
||||
intermediate = x.clone()
|
||||
|
||||
x = self.norm(x)
|
||||
if self.norm is not None:
|
||||
x = self.norm(x)
|
||||
|
||||
if all_intermediate is not None:
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
if only_layers is None or ((i + 1) in only_layers):
|
||||
all_intermediate.append(x.unsqueeze(1).clone())
|
||||
|
||||
if all_intermediate is not None:
|
||||
intermediate = torch.cat(all_intermediate, dim=1)
|
||||
|
||||
if intermediate is not None and final_layer_norm_intermediate:
|
||||
if intermediate is not None and final_layer_norm_intermediate and self.norm is not None:
|
||||
intermediate = self.norm(intermediate)
|
||||
|
||||
return x, intermediate
|
||||
@ -453,6 +537,15 @@ class Llama2(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Mistral3Small24B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Mistral3Small24BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
@ -462,6 +555,24 @@ class Qwen25_3B(BaseLlama, torch.nn.Module):
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen3_4B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Qwen3_4BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Ovis25_2B(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
config = Ovis25_2BConfig(**config_dict)
|
||||
self.num_layers = config.num_hidden_layers
|
||||
|
||||
self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
|
||||
self.dtype = dtype
|
||||
|
||||
class Qwen25_7BVLI(BaseLlama, torch.nn.Module):
|
||||
def __init__(self, config_dict, dtype, device, operations):
|
||||
super().__init__()
|
||||
|
||||
@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel):
|
||||
super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"):
|
||||
if model_type == "gemma2_2b":
|
||||
model = Gemma2_2BModel
|
||||
elif model_type == "gemma3_4b":
|
||||
@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"):
|
||||
|
||||
class LuminaTEModel_(LuminaModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model)
|
||||
|
||||
@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class Omnigen2TEModel_(Omnigen2Model):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
66
comfy/text_encoders/ovis.py
Normal file
66
comfy/text_encoders/ovis.py
Normal file
@ -0,0 +1,66 @@
|
||||
from transformers import Qwen2Tokenizer
|
||||
import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
import torch
|
||||
import numbers
|
||||
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class OvisTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer)
|
||||
self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
return tokens
|
||||
|
||||
class Ovis25_2BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options)
|
||||
|
||||
|
||||
class OvisTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options)
|
||||
|
||||
def encode_token_weights(self, token_weight_pairs, template_end=-1):
|
||||
out, pooled = super().encode_token_weights(token_weight_pairs)
|
||||
tok_pairs = token_weight_pairs["qwen3_2b"][0]
|
||||
count_im_start = 0
|
||||
if template_end == -1:
|
||||
for i, v in enumerate(tok_pairs):
|
||||
elem = v[0]
|
||||
if not torch.is_tensor(elem):
|
||||
if isinstance(elem, numbers.Integral):
|
||||
if elem == 4004 and count_im_start < 1:
|
||||
template_end = i
|
||||
count_im_start += 1
|
||||
|
||||
if out.shape[1] > (template_end + 1):
|
||||
if tok_pairs[template_end + 1][0] == 25:
|
||||
template_end += 1
|
||||
|
||||
out = out[:, template_end:]
|
||||
return out, pooled, {}
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class OvisTEModel_(OvisTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return OvisTEModel_
|
||||
@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer)
|
||||
|
||||
def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
class PixArtTEModel_(PixArtT5XXL):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -179,36 +179,36 @@
|
||||
"special": false
|
||||
},
|
||||
"151665": {
|
||||
"content": "<|img|>",
|
||||
"content": "<tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
},
|
||||
"151666": {
|
||||
"content": "<|endofimg|>",
|
||||
"content": "</tool_response>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
},
|
||||
"151667": {
|
||||
"content": "<|meta|>",
|
||||
"content": "<think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
},
|
||||
"151668": {
|
||||
"content": "<|endofmeta|>",
|
||||
"content": "</think>",
|
||||
"lstrip": false,
|
||||
"normalized": false,
|
||||
"rstrip": false,
|
||||
"single_word": false,
|
||||
"special": true
|
||||
"special": false
|
||||
}
|
||||
},
|
||||
"additional_special_tokens": [
|
||||
|
||||
@ -17,12 +17,14 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs):
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs):
|
||||
skip_template = False
|
||||
if text.startswith('<|im_start|>'):
|
||||
skip_template = True
|
||||
if text.startswith('<|start_header_id|>'):
|
||||
skip_template = True
|
||||
if prevent_empty_text and text == '':
|
||||
text = ' '
|
||||
|
||||
if skip_template:
|
||||
llama_text = text
|
||||
@ -83,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel):
|
||||
return out, pooled, extra
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_scaled_fp8=None):
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class QwenImageTEModel_(QwenImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = llama_scaled_fp8
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
@ -6,14 +6,15 @@ import torch
|
||||
import os
|
||||
import comfy.model_management
|
||||
import logging
|
||||
import comfy.utils
|
||||
|
||||
class T5XXLModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}):
|
||||
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json")
|
||||
t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None)
|
||||
if t5xxl_scaled_fp8 is not None:
|
||||
t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None)
|
||||
if t5xxl_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["quantization_metadata"] = t5xxl_quantization_metadata
|
||||
|
||||
model_options = {**model_options, "model_name": "t5xxl"}
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""):
|
||||
if t5_key in state_dict:
|
||||
out["dtype_t5"] = state_dict[t5_key].dtype
|
||||
|
||||
scaled_fp8_key = "{}scaled_fp8".format(prefix)
|
||||
if scaled_fp8_key in state_dict:
|
||||
out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype
|
||||
quant = comfy.utils.detect_layer_quantization(state_dict, prefix)
|
||||
if quant is not None:
|
||||
out["t5_quantization_metadata"] = quant
|
||||
|
||||
return out
|
||||
|
||||
@ -156,11 +157,11 @@ class SD3ClipModel(torch.nn.Module):
|
||||
else:
|
||||
return self.t5xxl.load_sd(sd)
|
||||
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False):
|
||||
def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False):
|
||||
class SD3ClipModel_(SD3ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options)
|
||||
return SD3ClipModel_
|
||||
|
||||
@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
|
||||
|
||||
def te(dtype_t5=None, t5xxl_scaled_fp8=None):
|
||||
def te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
class WanTEModel(WanT5Model):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["scaled_fp8"] = t5xxl_scaled_fp8
|
||||
model_options["quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype_t5 is not None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
|
||||
45
comfy/text_encoders/z_image.py
Normal file
45
comfy/text_encoders/z_image.py
Normal file
@ -0,0 +1,45 @@
|
||||
from transformers import Qwen2Tokenizer
|
||||
import comfy.text_encoders.llama
|
||||
from comfy import sd1_clip
|
||||
import os
|
||||
|
||||
class Qwen3Tokenizer(sd1_clip.SDTokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer")
|
||||
super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data)
|
||||
|
||||
|
||||
class ZImageTokenizer(sd1_clip.SD1Tokenizer):
|
||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer)
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs)
|
||||
return tokens
|
||||
|
||||
|
||||
class Qwen3_4BModel(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
|
||||
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
|
||||
|
||||
|
||||
class ZImageTEModel(sd1_clip.SD1ClipModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options)
|
||||
|
||||
|
||||
def te(dtype_llama=None, llama_quantization_metadata=None):
|
||||
class ZImageTEModel_(ZImageTEModel):
|
||||
def __init__(self, device="cpu", dtype=None, model_options={}):
|
||||
if dtype_llama is not None:
|
||||
dtype = dtype_llama
|
||||
if llama_quantization_metadata is not None:
|
||||
model_options["quantization_metadata"] = llama_quantization_metadata
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return ZImageTEModel_
|
||||
143
comfy/utils.py
143
comfy/utils.py
@ -29,6 +29,7 @@ import itertools
|
||||
from torch.nn.functional import interpolate
|
||||
from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
DISABLE_MMAP = args.disable_mmap
|
||||
@ -52,7 +53,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in
|
||||
ALWAYS_SAFE_LOAD = True
|
||||
logging.info("Checkpoint files will always be loaded safely.")
|
||||
else:
|
||||
logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
|
||||
logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
|
||||
if device is None:
|
||||
@ -675,6 +676,72 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
|
||||
|
||||
return key_map
|
||||
|
||||
def z_image_to_diffusers(mmdit_config, output_prefix=""):
|
||||
n_layers = mmdit_config.get("n_layers", 0)
|
||||
hidden_size = mmdit_config.get("dim", 0)
|
||||
n_context_refiner = mmdit_config.get("n_refiner_layers", 2)
|
||||
n_noise_refiner = mmdit_config.get("n_refiner_layers", 2)
|
||||
key_map = {}
|
||||
|
||||
def add_block_keys(prefix_from, prefix_to, has_adaln=True):
|
||||
for end in ("weight", "bias"):
|
||||
k = "{}.attention.".format(prefix_from)
|
||||
qkv = "{}.attention.qkv.{}".format(prefix_to, end)
|
||||
key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size))
|
||||
key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size))
|
||||
key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size))
|
||||
|
||||
block_map = {
|
||||
"attention.norm_q.weight": "attention.q_norm.weight",
|
||||
"attention.norm_k.weight": "attention.k_norm.weight",
|
||||
"attention.to_out.0.weight": "attention.out.weight",
|
||||
"attention.to_out.0.bias": "attention.out.bias",
|
||||
"attention_norm1.weight": "attention_norm1.weight",
|
||||
"attention_norm2.weight": "attention_norm2.weight",
|
||||
"feed_forward.w1.weight": "feed_forward.w1.weight",
|
||||
"feed_forward.w2.weight": "feed_forward.w2.weight",
|
||||
"feed_forward.w3.weight": "feed_forward.w3.weight",
|
||||
"ffn_norm1.weight": "ffn_norm1.weight",
|
||||
"ffn_norm2.weight": "ffn_norm2.weight",
|
||||
}
|
||||
if has_adaln:
|
||||
block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight"
|
||||
block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias"
|
||||
for k, v in block_map.items():
|
||||
key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v)
|
||||
|
||||
for i in range(n_layers):
|
||||
add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i))
|
||||
|
||||
for i in range(n_context_refiner):
|
||||
add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i))
|
||||
|
||||
for i in range(n_noise_refiner):
|
||||
add_block_keys("noise_refiner.{}".format(i), "{}noise_refiner.{}".format(output_prefix, i))
|
||||
|
||||
MAP_BASIC = [
|
||||
("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"),
|
||||
("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"),
|
||||
("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"),
|
||||
("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"),
|
||||
("x_embedder.weight", "all_x_embedder.2-1.weight"),
|
||||
("x_embedder.bias", "all_x_embedder.2-1.bias"),
|
||||
("x_pad_token", "x_pad_token"),
|
||||
("cap_embedder.0.weight", "cap_embedder.0.weight"),
|
||||
("cap_embedder.1.weight", "cap_embedder.1.weight"),
|
||||
("cap_embedder.1.bias", "cap_embedder.1.bias"),
|
||||
("cap_pad_token", "cap_pad_token"),
|
||||
("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"),
|
||||
("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"),
|
||||
("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"),
|
||||
("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"),
|
||||
]
|
||||
|
||||
for c, diffusers in MAP_BASIC:
|
||||
key_map[diffusers] = "{}{}".format(output_prefix, c)
|
||||
|
||||
return key_map
|
||||
|
||||
def repeat_to_batch_size(tensor, batch_size, dim=0):
|
||||
if tensor.shape[dim] > batch_size:
|
||||
return tensor.narrow(dim, 0, batch_size)
|
||||
@ -736,12 +803,17 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
|
||||
return None
|
||||
return f.read(length_of_header)
|
||||
|
||||
ATTR_UNSET={}
|
||||
|
||||
def set_attr(obj, attr, value):
|
||||
attrs = attr.split(".")
|
||||
for name in attrs[:-1]:
|
||||
obj = getattr(obj, name)
|
||||
prev = getattr(obj, attrs[-1])
|
||||
setattr(obj, attrs[-1], value)
|
||||
prev = getattr(obj, attrs[-1], ATTR_UNSET)
|
||||
if value is ATTR_UNSET:
|
||||
delattr(obj, attrs[-1])
|
||||
else:
|
||||
setattr(obj, attrs[-1], value)
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
@ -1128,3 +1200,68 @@ def unpack_latents(combined_latent, latent_shapes):
|
||||
else:
|
||||
output_tensors = combined_latent
|
||||
return output_tensors
|
||||
|
||||
def detect_layer_quantization(state_dict, prefix):
|
||||
for k in state_dict:
|
||||
if k.startswith(prefix) and k.endswith(".comfy_quant"):
|
||||
logging.info("Found quantization metadata version 1")
|
||||
return {"mixed_ops": True}
|
||||
return None
|
||||
|
||||
def convert_old_quants(state_dict, model_prefix="", metadata={}):
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
quant_metadata = None
|
||||
if "_quantization_metadata" not in metadata:
|
||||
scaled_fp8_key = "{}scaled_fp8".format(model_prefix)
|
||||
|
||||
if scaled_fp8_key in state_dict:
|
||||
scaled_fp8_weight = state_dict[scaled_fp8_key]
|
||||
scaled_fp8_dtype = scaled_fp8_weight.dtype
|
||||
if scaled_fp8_dtype == torch.float32:
|
||||
scaled_fp8_dtype = torch.float8_e4m3fn
|
||||
|
||||
if scaled_fp8_weight.nelement() == 2:
|
||||
full_precision_matrix_mult = True
|
||||
else:
|
||||
full_precision_matrix_mult = False
|
||||
|
||||
out_sd = {}
|
||||
layers = {}
|
||||
for k in list(state_dict.keys()):
|
||||
if not k.startswith(model_prefix):
|
||||
out_sd[k] = state_dict[k]
|
||||
continue
|
||||
k_out = k
|
||||
w = state_dict.pop(k)
|
||||
layer = None
|
||||
if k_out.endswith(".scale_weight"):
|
||||
layer = k_out[:-len(".scale_weight")]
|
||||
k_out = "{}.weight_scale".format(layer)
|
||||
|
||||
if layer is not None:
|
||||
layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints
|
||||
if full_precision_matrix_mult:
|
||||
layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult
|
||||
layers[layer] = layer_conf
|
||||
|
||||
if k_out.endswith(".scale_input"):
|
||||
layer = k_out[:-len(".scale_input")]
|
||||
k_out = "{}.input_scale".format(layer)
|
||||
if w.item() == 1.0:
|
||||
continue
|
||||
|
||||
out_sd[k_out] = w
|
||||
|
||||
state_dict = out_sd
|
||||
quant_metadata = {"layers": layers}
|
||||
else:
|
||||
quant_metadata = json.loads(metadata["_quantization_metadata"])
|
||||
|
||||
if quant_metadata is not None:
|
||||
layers = quant_metadata["layers"]
|
||||
for k, v in layers.items():
|
||||
state_dict["{}.comfy_quant".format(k)] = torch.tensor(list(json.dumps(v).encode('utf-8')), dtype=torch.uint8)
|
||||
|
||||
return state_dict, metadata
|
||||
|
||||
@ -194,6 +194,7 @@ class LoRAAdapter(WeightAdapterBase):
|
||||
lora_diff = torch.mm(
|
||||
mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
|
||||
).reshape(weight.shape)
|
||||
del mat1, mat2
|
||||
if dora_scale is not None:
|
||||
weight = weight_decompose(
|
||||
dora_scale,
|
||||
|
||||
@ -13,6 +13,7 @@ from comfy.cli_args import args
|
||||
SERVER_FEATURE_FLAGS: Dict[str, Any] = {
|
||||
"supports_preview_metadata": True,
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ import os
|
||||
import textwrap
|
||||
import threading
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, get_origin, get_args
|
||||
from typing import Optional, Type, get_origin, get_args, get_type_hints
|
||||
|
||||
|
||||
class TypeTracker:
|
||||
@ -220,11 +220,18 @@ class AsyncToSyncConverter:
|
||||
self._async_instance = async_class(*args, **kwargs)
|
||||
|
||||
# Handle annotated class attributes (like execution: Execution)
|
||||
# Get all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
# Get all annotations from the class hierarchy and resolve string annotations
|
||||
try:
|
||||
# get_type_hints resolves string annotations to actual type objects
|
||||
# This handles classes using 'from __future__ import annotations'
|
||||
all_annotations = get_type_hints(async_class)
|
||||
except Exception:
|
||||
# Fallback to raw annotations if get_type_hints fails
|
||||
# (e.g., for undefined forward references)
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
# For each annotated attribute, check if it needs to be created or wrapped
|
||||
for attr_name, attr_type in all_annotations.items():
|
||||
@ -625,15 +632,19 @@ class AsyncToSyncConverter:
|
||||
"""Extract class attributes that are classes themselves."""
|
||||
class_attributes = []
|
||||
|
||||
# Get resolved type hints to handle string annotations
|
||||
try:
|
||||
type_hints = get_type_hints(async_class)
|
||||
except Exception:
|
||||
type_hints = {}
|
||||
|
||||
# Look for class attributes that are classes
|
||||
for name, attr in sorted(inspect.getmembers(async_class)):
|
||||
if isinstance(attr, type) and not name.startswith("_"):
|
||||
class_attributes.append((name, attr))
|
||||
elif (
|
||||
hasattr(async_class, "__annotations__")
|
||||
and name in async_class.__annotations__
|
||||
):
|
||||
annotation = async_class.__annotations__[name]
|
||||
elif name in type_hints:
|
||||
# Use resolved type hint instead of raw annotation
|
||||
annotation = type_hints[name]
|
||||
if isinstance(annotation, type):
|
||||
class_attributes.append((name, annotation))
|
||||
|
||||
@ -908,11 +919,15 @@ class AsyncToSyncConverter:
|
||||
attribute_mappings = {}
|
||||
|
||||
# First check annotations for typed attributes (including from parent classes)
|
||||
# Collect all annotations from the class hierarchy
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
# Resolve string annotations to actual types
|
||||
try:
|
||||
all_annotations = get_type_hints(async_class)
|
||||
except Exception:
|
||||
# Fallback to raw annotations
|
||||
all_annotations = {}
|
||||
for base_class in reversed(inspect.getmro(async_class)):
|
||||
if hasattr(base_class, "__annotations__"):
|
||||
all_annotations.update(base_class.__annotations__)
|
||||
|
||||
for attr_name, attr_type in sorted(all_annotations.items()):
|
||||
for class_name, class_type in class_attributes:
|
||||
|
||||
@ -5,11 +5,11 @@ from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal import ComfyAPIBase
|
||||
from comfy_api.internal.singleton import ProxiedSingleton
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents
|
||||
from . import _io as io
|
||||
from . import _ui as ui
|
||||
from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput
|
||||
from ._input_impl import VideoFromFile, VideoFromComponents
|
||||
from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL
|
||||
from . import _io_public as io
|
||||
from . import _ui_public as ui
|
||||
# from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_execution.progress import get_progress_state, PreviewImageTuple
|
||||
@ -80,7 +80,7 @@ class ComfyExtension(ABC):
|
||||
async def on_load(self) -> None:
|
||||
"""
|
||||
Called when an extension is loaded.
|
||||
This should be used to initialize any global resources neeeded by the extension.
|
||||
This should be used to initialize any global resources needed by the extension.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@ -104,6 +104,8 @@ class Types:
|
||||
VideoCodec = VideoCodec
|
||||
VideoContainer = VideoContainer
|
||||
VideoComponents = VideoComponents
|
||||
MESH = MESH
|
||||
VOXEL = VOXEL
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from fractions import Fraction
|
||||
from typing import Optional, Union, IO
|
||||
import io
|
||||
import av
|
||||
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
class VideoInput(ABC):
|
||||
"""
|
||||
@ -72,6 +73,33 @@ class VideoInput(ABC):
|
||||
frame_count = components.images.shape[0]
|
||||
return float(frame_count / components.frame_rate)
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
"""
|
||||
Returns the number of frames in the video.
|
||||
|
||||
Default implementation uses :meth:`get_components`, which may require
|
||||
loading all frames into memory. File-based implementations should
|
||||
override this method and use container/stream metadata instead.
|
||||
|
||||
Returns:
|
||||
Total number of frames as an integer.
|
||||
"""
|
||||
return int(self.get_components().images.shape[0])
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
"""
|
||||
Returns the frame rate of the video.
|
||||
|
||||
Default implementation materializes the video into memory via
|
||||
`get_components()`. Subclasses that can inspect the underlying
|
||||
container (e.g. `VideoFromFile`) should override this with a more
|
||||
efficient implementation.
|
||||
|
||||
Returns:
|
||||
Frame rate as a Fraction.
|
||||
"""
|
||||
return self.get_components().frame_rate
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
|
||||
@ -3,14 +3,14 @@ from av.container import InputContainer
|
||||
from av.subtitles.stream import SubtitleStream
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import AudioInput, VideoInput
|
||||
from .._input import AudioInput, VideoInput
|
||||
import av
|
||||
import io
|
||||
import json
|
||||
import numpy as np
|
||||
import math
|
||||
import torch
|
||||
from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents
|
||||
from .._util import VideoContainer, VideoCodec, VideoComponents
|
||||
|
||||
|
||||
def container_to_output_format(container_format: str | None) -> str | None:
|
||||
@ -121,6 +121,71 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
raise ValueError(f"Could not determine duration for file '{self.__file}'")
|
||||
|
||||
def get_frame_count(self) -> int:
|
||||
"""
|
||||
Returns the number of frames in the video without materializing them as
|
||||
torch tensors.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# 1. Prefer the frames field if available
|
||||
if video_stream.frames and video_stream.frames > 0:
|
||||
return int(video_stream.frames)
|
||||
|
||||
# 2. Try to estimate from duration and average_rate using only metadata
|
||||
if container.duration is not None and video_stream.average_rate:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
if (
|
||||
getattr(video_stream, "duration", None) is not None
|
||||
and getattr(video_stream, "time_base", None) is not None
|
||||
and video_stream.average_rate
|
||||
):
|
||||
duration_seconds = float(video_stream.duration * video_stream.time_base)
|
||||
estimated_frames = int(round(duration_seconds * float(video_stream.average_rate)))
|
||||
if estimated_frames > 0:
|
||||
return estimated_frames
|
||||
|
||||
# 3. Last resort: decode frames and count them (streaming)
|
||||
frame_count = 0
|
||||
container.seek(0)
|
||||
for packet in container.demux(video_stream):
|
||||
for _ in packet.decode():
|
||||
frame_count += 1
|
||||
|
||||
if frame_count == 0:
|
||||
raise ValueError(f"Could not determine frame count for file '{self.__file}'")
|
||||
return frame_count
|
||||
|
||||
def get_frame_rate(self) -> Fraction:
|
||||
"""
|
||||
Returns the average frame rate of the video using container metadata
|
||||
without decoding all frames.
|
||||
"""
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0)
|
||||
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = self._get_first_video_stream(container)
|
||||
# Preferred: use PyAV's average_rate (usually already a Fraction-like)
|
||||
if video_stream.average_rate:
|
||||
return Fraction(video_stream.average_rate)
|
||||
|
||||
# Fallback: estimate from frames + duration if available
|
||||
if video_stream.frames and container.duration:
|
||||
duration_seconds = float(container.duration / av.time_base)
|
||||
if duration_seconds > 0:
|
||||
return Fraction(video_stream.frames / duration_seconds).limit_denominator()
|
||||
|
||||
# Last resort: match get_components_internal default
|
||||
return Fraction(1)
|
||||
|
||||
def get_container_format(self) -> str:
|
||||
"""
|
||||
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
|
||||
@ -238,6 +303,13 @@ class VideoFromFile(VideoInput):
|
||||
packet.stream = stream_map[packet.stream]
|
||||
output_container.mux(packet)
|
||||
|
||||
def _get_first_video_stream(self, container: InputContainer):
|
||||
video_stream = next((s for s in container.streams if s.type == "video"), None)
|
||||
if video_stream is None:
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
return video_stream
|
||||
|
||||
|
||||
class VideoFromComponents(VideoInput):
|
||||
"""
|
||||
Class representing video input from tensors.
|
||||
@ -264,7 +336,10 @@ class VideoFromComponents(VideoInput):
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
|
||||
extra_kwargs = {}
|
||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||
extra_kwargs["format"] = format.value
|
||||
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
|
||||
# Add metadata before writing any streams
|
||||
if metadata is not None:
|
||||
for key, value in metadata.items():
|
||||
|
||||
@ -4,7 +4,8 @@ import copy
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import Counter
|
||||
from dataclasses import asdict, dataclass
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING
|
||||
from typing_extensions import NotRequired, final
|
||||
@ -25,8 +26,9 @@ if TYPE_CHECKING:
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
|
||||
prune_dict, shallow_clone_class)
|
||||
from comfy_api.latest._resources import Resources, ResourcesLocal
|
||||
from ._resources import Resources, ResourcesLocal
|
||||
from comfy_execution.graph_utils import ExecutionBlocker
|
||||
from ._util import MESH, VOXEL
|
||||
|
||||
# from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference
|
||||
|
||||
@ -149,6 +151,9 @@ class _IO_V3:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def validate(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
def io_type(self):
|
||||
return self.Parent.io_type
|
||||
@ -181,6 +186,9 @@ class Input(_IO_V3):
|
||||
def get_io_type(self):
|
||||
return _StringIOType(self.io_type)
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self]
|
||||
|
||||
class WidgetInput(Input):
|
||||
'''
|
||||
Base class for a V3 Input with widget.
|
||||
@ -560,6 +568,8 @@ class Conditioning(ComfyTypeIO):
|
||||
'''Used by WAN Camera.'''
|
||||
time_dim_concat: NotRequired[torch.Tensor]
|
||||
'''Used by WAN Phantom Subject.'''
|
||||
time_dim_replace: NotRequired[torch.Tensor]
|
||||
'''Used by Kandinsky5 I2V.'''
|
||||
|
||||
CondList = list[tuple[torch.Tensor, PooledDict]]
|
||||
Type = CondList
|
||||
@ -628,6 +638,10 @@ class UpscaleModel(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
Type = ImageModelDescriptor
|
||||
|
||||
@comfytype(io_type="LATENT_UPSCALE_MODEL")
|
||||
class LatentUpscaleModel(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="AUDIO")
|
||||
class Audio(ComfyTypeIO):
|
||||
class AudioDict(TypedDict):
|
||||
@ -656,11 +670,11 @@ class LossMap(ComfyTypeIO):
|
||||
|
||||
@comfytype(io_type="VOXEL")
|
||||
class Voxel(ComfyTypeIO):
|
||||
Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
|
||||
Type = VOXEL
|
||||
|
||||
@comfytype(io_type="MESH")
|
||||
class Mesh(ComfyTypeIO):
|
||||
Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3
|
||||
Type = MESH
|
||||
|
||||
@comfytype(io_type="HOOKS")
|
||||
class Hooks(ComfyTypeIO):
|
||||
@ -760,6 +774,13 @@ class AudioEncoder(ComfyTypeIO):
|
||||
class AudioEncoderOutput(ComfyTypeIO):
|
||||
Type = Any
|
||||
|
||||
@comfytype(io_type="TRACKS")
|
||||
class Tracks(ComfyTypeIO):
|
||||
class TrackDict(TypedDict):
|
||||
track_path: torch.Tensor
|
||||
track_visibility: torch.Tensor
|
||||
Type = TrackDict
|
||||
|
||||
@comfytype(io_type="COMFY_MULTITYPED_V3")
|
||||
class MultiType:
|
||||
Type = Any
|
||||
@ -809,13 +830,61 @@ class MultiType:
|
||||
else:
|
||||
return super().as_dict()
|
||||
|
||||
@comfytype(io_type="COMFY_MATCHTYPE_V3")
|
||||
class MatchType(ComfyTypeIO):
|
||||
class Template:
|
||||
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType):
|
||||
self.template_id = template_id
|
||||
# account for syntactic sugar
|
||||
if not isinstance(allowed_types, Iterable):
|
||||
allowed_types = [allowed_types]
|
||||
for t in allowed_types:
|
||||
if not isinstance(t, type):
|
||||
if not isinstance(t, _ComfyType):
|
||||
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}")
|
||||
else:
|
||||
if not issubclass(t, _ComfyType):
|
||||
raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}")
|
||||
self.allowed_types = allowed_types
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"template_id": self.template_id,
|
||||
"allowed_types": ",".join([t.io_type for t in self.allowed_types]),
|
||||
}
|
||||
|
||||
class Input(Input):
|
||||
def __init__(self, id: str, template: MatchType.Template,
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self.template = template
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": self.template.as_dict(),
|
||||
})
|
||||
|
||||
class Output(Output):
|
||||
def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
self.template = template
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": self.template.as_dict(),
|
||||
})
|
||||
|
||||
class DynamicInput(Input, ABC):
|
||||
'''
|
||||
Abstract class for dynamic input registration.
|
||||
'''
|
||||
@abstractmethod
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
...
|
||||
return []
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
pass
|
||||
|
||||
|
||||
class DynamicOutput(Output, ABC):
|
||||
'''
|
||||
@ -825,99 +894,223 @@ class DynamicOutput(Output, ABC):
|
||||
is_output_list=False):
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
|
||||
@abstractmethod
|
||||
def get_dynamic(self) -> list[Output]:
|
||||
...
|
||||
return []
|
||||
|
||||
|
||||
@comfytype(io_type="COMFY_AUTOGROW_V3")
|
||||
class AutogrowDynamic(ComfyTypeI):
|
||||
Type = list[Any]
|
||||
class Input(DynamicInput):
|
||||
def __init__(self, id: str, template_input: Input, min: int=1, max: int=None,
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self.template_input = template_input
|
||||
if min is not None:
|
||||
assert(min >= 1)
|
||||
if max is not None:
|
||||
assert(max >= 1)
|
||||
class Autogrow(ComfyTypeI):
|
||||
Type = dict[str, Any]
|
||||
_MaxNames = 100 # NOTE: max 100 names for sanity
|
||||
|
||||
class _AutogrowTemplate:
|
||||
def __init__(self, input: Input):
|
||||
# dynamic inputs are not allowed as the template input
|
||||
assert(not isinstance(input, DynamicInput))
|
||||
self.input = copy.copy(input)
|
||||
if isinstance(self.input, WidgetInput):
|
||||
self.input.force_input = True
|
||||
self.names: list[str] = []
|
||||
self.cached_inputs = {}
|
||||
|
||||
def _create_input(self, input: Input, name: str):
|
||||
new_input = copy.copy(self.input)
|
||||
new_input.id = name
|
||||
return new_input
|
||||
|
||||
def _create_cached_inputs(self):
|
||||
for name in self.names:
|
||||
self.cached_inputs[name] = self._create_input(self.input, name)
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return list(self.cached_inputs.values())
|
||||
|
||||
def as_dict(self):
|
||||
return prune_dict({
|
||||
"input": create_input_dict_v1([self.input]),
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
self.input.validate()
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
real_inputs = []
|
||||
for name, input in self.cached_inputs.items():
|
||||
if name in live_inputs:
|
||||
real_inputs.append(input)
|
||||
add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix)
|
||||
add_dynamic_id_mapping(d, real_inputs, curr_prefix)
|
||||
|
||||
class TemplatePrefix(_AutogrowTemplate):
|
||||
def __init__(self, input: Input, prefix: str, min: int=1, max: int=10):
|
||||
super().__init__(input)
|
||||
self.prefix = prefix
|
||||
assert(min >= 0)
|
||||
assert(max >= 1)
|
||||
assert(max <= Autogrow._MaxNames)
|
||||
self.min = min
|
||||
self.max = max
|
||||
self.names = [f"{self.prefix}{i}" for i in range(self.max)]
|
||||
self._create_cached_inputs()
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"prefix": self.prefix,
|
||||
"min": self.min,
|
||||
"max": self.max,
|
||||
})
|
||||
|
||||
class TemplateNames(_AutogrowTemplate):
|
||||
def __init__(self, input: Input, names: list[str], min: int=1):
|
||||
super().__init__(input)
|
||||
self.names = names[:Autogrow._MaxNames]
|
||||
assert(min >= 0)
|
||||
self.min = min
|
||||
self._create_cached_inputs()
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"names": self.names,
|
||||
"min": self.min,
|
||||
})
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames,
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self.template = template
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": self.template.as_dict(),
|
||||
})
|
||||
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
curr_count = 1
|
||||
new_inputs = []
|
||||
for i in range(self.min):
|
||||
new_input = copy.copy(self.template_input)
|
||||
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
|
||||
if new_input.display_name is not None:
|
||||
new_input.display_name = f"{new_input.display_name}{curr_count}"
|
||||
new_input.optional = self.optional or new_input.optional
|
||||
if isinstance(self.template_input, WidgetInput):
|
||||
new_input.force_input = True
|
||||
new_inputs.append(new_input)
|
||||
curr_count += 1
|
||||
# pretend to expand up to max
|
||||
for i in range(curr_count-1, self.max):
|
||||
new_input = copy.copy(self.template_input)
|
||||
new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$"
|
||||
if new_input.display_name is not None:
|
||||
new_input.display_name = f"{new_input.display_name}{curr_count}"
|
||||
new_input.optional = True
|
||||
if isinstance(self.template_input, WidgetInput):
|
||||
new_input.force_input = True
|
||||
new_inputs.append(new_input)
|
||||
curr_count += 1
|
||||
return new_inputs
|
||||
return self.template.get_all()
|
||||
|
||||
@comfytype(io_type="COMFY_COMBODYNAMIC_V3")
|
||||
class ComboDynamic(ComfyTypeI):
|
||||
class Input(DynamicInput):
|
||||
def __init__(self, id: str):
|
||||
pass
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self] + self.template.get_all()
|
||||
|
||||
@comfytype(io_type="COMFY_MATCHTYPE_V3")
|
||||
class MatchType(ComfyTypeIO):
|
||||
class Template:
|
||||
def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]):
|
||||
self.template_id = template_id
|
||||
self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types
|
||||
def validate(self):
|
||||
self.template.validate()
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
curr_prefix = f"{curr_prefix}{self.id}."
|
||||
# need to remove self from expected inputs dictionary; replaced by template inputs in frontend
|
||||
for inner_dict in d.values():
|
||||
if self.id in inner_dict:
|
||||
del inner_dict[self.id]
|
||||
self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
||||
|
||||
@comfytype(io_type="COMFY_DYNAMICCOMBO_V3")
|
||||
class DynamicCombo(ComfyTypeI):
|
||||
Type = dict[str, Any]
|
||||
|
||||
class Option:
|
||||
def __init__(self, key: str, inputs: list[Input]):
|
||||
self.key = key
|
||||
self.inputs = inputs
|
||||
|
||||
def as_dict(self):
|
||||
return {
|
||||
"template_id": self.template_id,
|
||||
"allowed_types": "".join(t.io_type for t in self.allowed_types),
|
||||
"key": self.key,
|
||||
"inputs": create_input_dict_v1(self.inputs),
|
||||
}
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(self, id: str, template: MatchType.Template,
|
||||
def __init__(self, id: str, options: list[DynamicCombo.Option],
|
||||
display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, extra_dict)
|
||||
self.template = template
|
||||
self.options = options
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
# check if dynamic input's id is in live_inputs
|
||||
if self.id in live_inputs:
|
||||
curr_prefix = f"{curr_prefix}{self.id}."
|
||||
key = live_inputs[self.id]
|
||||
selected_option = None
|
||||
for option in self.options:
|
||||
if option.key == key:
|
||||
selected_option = option
|
||||
break
|
||||
if selected_option is not None:
|
||||
add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix)
|
||||
add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self)
|
||||
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
return [self]
|
||||
return [input for option in self.options for input in option.inputs]
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self] + [input for option in self.options for input in option.inputs]
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": self.template.as_dict(),
|
||||
"options": [o.as_dict() for o in self.options],
|
||||
})
|
||||
|
||||
class Output(DynamicOutput):
|
||||
def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None,
|
||||
is_output_list=False):
|
||||
super().__init__(id, display_name, tooltip, is_output_list)
|
||||
self.template = template
|
||||
def validate(self):
|
||||
# make sure all nested inputs are validated
|
||||
for option in self.options:
|
||||
for input in option.inputs:
|
||||
input.validate()
|
||||
|
||||
def get_dynamic(self) -> list[Output]:
|
||||
return [self]
|
||||
@comfytype(io_type="COMFY_DYNAMICSLOT_V3")
|
||||
class DynamicSlot(ComfyTypeI):
|
||||
Type = dict[str, Any]
|
||||
|
||||
class Input(DynamicInput):
|
||||
def __init__(self, slot: Input, inputs: list[Input],
|
||||
display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None):
|
||||
assert(not isinstance(slot, DynamicInput))
|
||||
self.slot = copy.copy(slot)
|
||||
self.slot.display_name = slot.display_name if slot.display_name is not None else display_name
|
||||
optional = True
|
||||
self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip
|
||||
self.slot.lazy = slot.lazy if slot.lazy is not None else lazy
|
||||
self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict
|
||||
super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict)
|
||||
self.inputs = inputs
|
||||
self.force_input = None
|
||||
# force widget inputs to have no widgets, otherwise this would be awkward
|
||||
if isinstance(self.slot, WidgetInput):
|
||||
self.force_input = True
|
||||
self.slot.force_input = True
|
||||
|
||||
def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''):
|
||||
if self.id in live_inputs:
|
||||
curr_prefix = f"{curr_prefix}{self.id}."
|
||||
add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix)
|
||||
add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix)
|
||||
|
||||
def get_dynamic(self) -> list[Input]:
|
||||
return [self.slot] + self.inputs
|
||||
|
||||
def get_all(self) -> list[Input]:
|
||||
return [self] + [self.slot] + self.inputs
|
||||
|
||||
def as_dict(self):
|
||||
return super().as_dict() | prune_dict({
|
||||
"template": self.template.as_dict(),
|
||||
"slotType": str(self.slot.get_io_type()),
|
||||
"inputs": create_input_dict_v1(self.inputs),
|
||||
"forceInput": self.force_input,
|
||||
})
|
||||
|
||||
def validate(self):
|
||||
self.slot.validate()
|
||||
for input in self.inputs:
|
||||
input.validate()
|
||||
|
||||
def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None):
|
||||
dynamic = d.setdefault("dynamic_paths", {})
|
||||
if self is not None:
|
||||
dynamic[self.id] = f"{curr_prefix}{self.id}"
|
||||
for i in inputs:
|
||||
if not isinstance(i, DynamicInput):
|
||||
dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}"
|
||||
|
||||
class V3Data(TypedDict):
|
||||
hidden_inputs: dict[str, Any]
|
||||
dynamic_paths: dict[str, Any]
|
||||
|
||||
class HiddenHolder:
|
||||
def __init__(self, unique_id: str, prompt: Any,
|
||||
@ -979,6 +1172,7 @@ class NodeInfoV1:
|
||||
output_is_list: list[bool]=None
|
||||
output_name: list[str]=None
|
||||
output_tooltips: list[str]=None
|
||||
output_matchtypes: list[str]=None
|
||||
name: str=None
|
||||
display_name: str=None
|
||||
description: str=None
|
||||
@ -1014,9 +1208,9 @@ class Schema:
|
||||
"""Display name of node."""
|
||||
category: str = "sd"
|
||||
"""The category of the node, as per the "Add Node" menu."""
|
||||
inputs: list[Input]=None
|
||||
outputs: list[Output]=None
|
||||
hidden: list[Hidden]=None
|
||||
inputs: list[Input] = field(default_factory=list)
|
||||
outputs: list[Output] = field(default_factory=list)
|
||||
hidden: list[Hidden] = field(default_factory=list)
|
||||
description: str=""
|
||||
"""Node description, shown as a tooltip when hovering over the node."""
|
||||
is_input_list: bool = False
|
||||
@ -1056,7 +1250,11 @@ class Schema:
|
||||
'''Validate the schema:
|
||||
- verify ids on inputs and outputs are unique - both internally and in relation to each other
|
||||
'''
|
||||
input_ids = [i.id for i in self.inputs] if self.inputs is not None else []
|
||||
nested_inputs: list[Input] = []
|
||||
if self.inputs is not None:
|
||||
for input in self.inputs:
|
||||
nested_inputs.extend(input.get_all())
|
||||
input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else []
|
||||
output_ids = [o.id for o in self.outputs] if self.outputs is not None else []
|
||||
input_set = set(input_ids)
|
||||
output_set = set(output_ids)
|
||||
@ -1072,6 +1270,13 @@ class Schema:
|
||||
issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.")
|
||||
if len(issues) > 0:
|
||||
raise ValueError("\n".join(issues))
|
||||
# validate inputs and outputs
|
||||
if self.inputs is not None:
|
||||
for input in self.inputs:
|
||||
input.validate()
|
||||
if self.outputs is not None:
|
||||
for output in self.outputs:
|
||||
output.validate()
|
||||
|
||||
def finalize(self):
|
||||
"""Add hidden based on selected schema options, and give outputs without ids default ids."""
|
||||
@ -1097,19 +1302,10 @@ class Schema:
|
||||
if output.id is None:
|
||||
output.id = f"_{i}_{output.io_type}_"
|
||||
|
||||
def get_v1_info(self, cls) -> NodeInfoV1:
|
||||
def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1:
|
||||
# NOTE: live_inputs will not be used anymore very soon and this will be done another way
|
||||
# get V1 inputs
|
||||
input = {
|
||||
"required": {}
|
||||
}
|
||||
if self.inputs:
|
||||
for i in self.inputs:
|
||||
if isinstance(i, DynamicInput):
|
||||
dynamic_inputs = i.get_dynamic()
|
||||
for d in dynamic_inputs:
|
||||
add_to_dict_v1(d, input)
|
||||
else:
|
||||
add_to_dict_v1(i, input)
|
||||
input = create_input_dict_v1(self.inputs, live_inputs)
|
||||
if self.hidden:
|
||||
for hidden in self.hidden:
|
||||
input.setdefault("hidden", {})[hidden.name] = (hidden.value,)
|
||||
@ -1118,12 +1314,24 @@ class Schema:
|
||||
output_is_list = []
|
||||
output_name = []
|
||||
output_tooltips = []
|
||||
output_matchtypes = []
|
||||
any_matchtypes = False
|
||||
if self.outputs:
|
||||
for o in self.outputs:
|
||||
output.append(o.io_type)
|
||||
output_is_list.append(o.is_output_list)
|
||||
output_name.append(o.display_name if o.display_name else o.io_type)
|
||||
output_tooltips.append(o.tooltip if o.tooltip else None)
|
||||
# special handling for MatchType
|
||||
if isinstance(o, MatchType.Output):
|
||||
output_matchtypes.append(o.template.template_id)
|
||||
any_matchtypes = True
|
||||
else:
|
||||
output_matchtypes.append(None)
|
||||
|
||||
# clear out lists that are all None
|
||||
if not any_matchtypes:
|
||||
output_matchtypes = None
|
||||
|
||||
info = NodeInfoV1(
|
||||
input=input,
|
||||
@ -1132,6 +1340,7 @@ class Schema:
|
||||
output_is_list=output_is_list,
|
||||
output_name=output_name,
|
||||
output_tooltips=output_tooltips,
|
||||
output_matchtypes=output_matchtypes,
|
||||
name=self.node_id,
|
||||
display_name=self.display_name,
|
||||
category=self.category,
|
||||
@ -1177,16 +1386,57 @@ class Schema:
|
||||
return info
|
||||
|
||||
|
||||
def add_to_dict_v1(i: Input, input: dict):
|
||||
def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict:
|
||||
input = {
|
||||
"required": {}
|
||||
}
|
||||
add_to_input_dict_v1(input, inputs, live_inputs)
|
||||
return input
|
||||
|
||||
def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''):
|
||||
for i in inputs:
|
||||
if isinstance(i, DynamicInput):
|
||||
add_to_dict_v1(i, d)
|
||||
if live_inputs is not None:
|
||||
i.expand_schema_for_dynamic(d, live_inputs, curr_prefix)
|
||||
else:
|
||||
add_to_dict_v1(i, d)
|
||||
|
||||
def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None):
|
||||
key = "optional" if i.optional else "required"
|
||||
as_dict = i.as_dict()
|
||||
# for v1, we don't want to include the optional key
|
||||
as_dict.pop("optional", None)
|
||||
input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict)
|
||||
if dynamic_dict is None:
|
||||
value = (i.get_io_type(), as_dict)
|
||||
else:
|
||||
value = (i.get_io_type(), as_dict, dynamic_dict)
|
||||
d.setdefault(key, {})[i.id] = value
|
||||
|
||||
def add_to_dict_v3(io: Input | Output, d: dict):
|
||||
d[io.id] = (io.get_io_type(), io.as_dict())
|
||||
|
||||
def build_nested_inputs(values: dict[str, Any], v3_data: V3Data):
|
||||
paths = v3_data.get("dynamic_paths", None)
|
||||
if paths is None:
|
||||
return values
|
||||
values = values.copy()
|
||||
result = {}
|
||||
|
||||
for key, path in paths.items():
|
||||
parts = path.split(".")
|
||||
current = result
|
||||
|
||||
for i, p in enumerate(parts):
|
||||
is_last = (i == len(parts) - 1)
|
||||
|
||||
if is_last:
|
||||
current[p] = values.pop(key, None)
|
||||
else:
|
||||
current = current.setdefault(p, {})
|
||||
|
||||
values.update(result)
|
||||
return values
|
||||
|
||||
|
||||
class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
@ -1306,12 +1556,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]:
|
||||
def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]:
|
||||
"""Creates clone of real node class to prevent monkey-patching."""
|
||||
c_type: type[ComfyNode] = cls if is_class(cls) else type(cls)
|
||||
type_clone: type[ComfyNode] = shallow_clone_class(c_type)
|
||||
# set hidden
|
||||
type_clone.hidden = HiddenHolder.from_dict(hidden_inputs)
|
||||
type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"])
|
||||
return type_clone
|
||||
|
||||
@final
|
||||
@ -1428,14 +1678,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal):
|
||||
|
||||
@final
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]:
|
||||
def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]:
|
||||
schema = cls.FINALIZE_SCHEMA()
|
||||
info = schema.get_v1_info(cls)
|
||||
info = schema.get_v1_info(cls, live_inputs)
|
||||
input = info.input
|
||||
if not include_hidden:
|
||||
input.pop("hidden", None)
|
||||
if return_schema:
|
||||
return input, schema
|
||||
v3_data: V3Data = {}
|
||||
dynamic = input.pop("dynamic_paths", None)
|
||||
if dynamic is not None:
|
||||
v3_data["dynamic_paths"] = dynamic
|
||||
return input, schema, v3_data
|
||||
return input
|
||||
|
||||
@final
|
||||
@ -1508,7 +1762,7 @@ class ComfyNode(_ComfyNodeBaseInternal):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, **kwargs) -> bool:
|
||||
def validate_inputs(cls, **kwargs) -> bool | str:
|
||||
"""Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS."""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1568,7 +1822,7 @@ class NodeOutput(_NodeOutputInternal):
|
||||
ui = data["ui"]
|
||||
if "expand" in data:
|
||||
expand = data["expand"]
|
||||
return cls(args=args, ui=ui, expand=expand)
|
||||
return cls(*args, ui=ui, expand=expand)
|
||||
|
||||
def __getitem__(self, index) -> Any:
|
||||
return self.args[index]
|
||||
@ -1623,6 +1877,7 @@ __all__ = [
|
||||
"StyleModel",
|
||||
"Gligen",
|
||||
"UpscaleModel",
|
||||
"LatentUpscaleModel",
|
||||
"Audio",
|
||||
"Video",
|
||||
"SVG",
|
||||
@ -1646,6 +1901,11 @@ __all__ = [
|
||||
"SEGS",
|
||||
"AnyType",
|
||||
"MultiType",
|
||||
"Tracks",
|
||||
# Dynamic Types
|
||||
"MatchType",
|
||||
# "DynamicCombo",
|
||||
# "Autogrow",
|
||||
# Other classes
|
||||
"HiddenHolder",
|
||||
"Hidden",
|
||||
@ -1656,4 +1916,5 @@ __all__ = [
|
||||
"NodeOutput",
|
||||
"add_to_dict_v1",
|
||||
"add_to_dict_v3",
|
||||
"V3Data",
|
||||
]
|
||||
|
||||
1
comfy_api/latest/_io_public.py
Normal file
1
comfy_api/latest/_io_public.py
Normal file
@ -0,0 +1 @@
|
||||
from ._io import * # noqa: F403
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Type
|
||||
|
||||
@ -21,7 +22,7 @@ import folder_paths
|
||||
|
||||
# used for image preview
|
||||
from comfy.cli_args import args
|
||||
from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
from ._io import ComfyNode, FolderType, Image, _UIOutput
|
||||
|
||||
|
||||
class SavedResult(dict):
|
||||
@ -318,9 +319,10 @@ class AudioSaveHelper:
|
||||
for key, value in metadata.items():
|
||||
output_container.metadata[key] = value
|
||||
|
||||
layout = "mono" if waveform.shape[0] == 1 else "stereo"
|
||||
# Set up the output stream with appropriate properties
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate)
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
|
||||
if quality == "64k":
|
||||
out_stream.bit_rate = 64000
|
||||
elif quality == "96k":
|
||||
@ -332,7 +334,7 @@ class AudioSaveHelper:
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
|
||||
if quality == "V0":
|
||||
# TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
@ -341,12 +343,12 @@ class AudioSaveHelper:
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: # format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate)
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(
|
||||
waveform.movedim(0, 1).reshape(1, -1).float().numpy(),
|
||||
format="flt",
|
||||
layout="mono" if waveform.shape[0] == 1 else "stereo",
|
||||
layout=layout,
|
||||
)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
@ -436,9 +438,19 @@ class PreviewUI3D(_UIOutput):
|
||||
def __init__(self, model_file, camera_info, **kwargs):
|
||||
self.model_file = model_file
|
||||
self.camera_info = camera_info
|
||||
self.bg_image_path = None
|
||||
bg_image = kwargs.get("bg_image", None)
|
||||
if bg_image is not None:
|
||||
img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8)
|
||||
img = PILImage.fromarray(img_array)
|
||||
temp_dir = folder_paths.get_temp_directory()
|
||||
filename = f"bg_{uuid.uuid4().hex}.png"
|
||||
bg_image_path = os.path.join(temp_dir, filename)
|
||||
img.save(bg_image_path, compress_level=1)
|
||||
self.bg_image_path = f"temp/{filename}"
|
||||
|
||||
def as_dict(self):
|
||||
return {"result": [self.model_file, self.camera_info]}
|
||||
return {"result": [self.model_file, self.camera_info, self.bg_image_path]}
|
||||
|
||||
|
||||
class PreviewText(_UIOutput):
|
||||
|
||||
1
comfy_api/latest/_ui_public.py
Normal file
1
comfy_api/latest/_ui_public.py
Normal file
@ -0,0 +1 @@
|
||||
from ._ui import * # noqa: F403
|
||||
@ -1,8 +1,11 @@
|
||||
from .video_types import VideoContainer, VideoCodec, VideoComponents
|
||||
from .geometry_types import VOXEL, MESH
|
||||
|
||||
__all__ = [
|
||||
# Utility Types
|
||||
"VideoContainer",
|
||||
"VideoCodec",
|
||||
"VideoComponents",
|
||||
"VOXEL",
|
||||
"MESH",
|
||||
]
|
||||
|
||||
12
comfy_api/latest/_util/geometry_types.py
Normal file
12
comfy_api/latest/_util/geometry_types.py
Normal file
@ -0,0 +1,12 @@
|
||||
import torch
|
||||
|
||||
|
||||
class VOXEL:
|
||||
def __init__(self, data: torch.Tensor):
|
||||
self.data = data
|
||||
|
||||
|
||||
class MESH:
|
||||
def __init__(self, vertices: torch.Tensor, faces: torch.Tensor):
|
||||
self.vertices = vertices
|
||||
self.faces = faces
|
||||
@ -3,7 +3,7 @@ from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
from comfy_api.latest._input import ImageInput, AudioInput
|
||||
from .._input import ImageInput, AudioInput
|
||||
|
||||
class VideoCodec(str, Enum):
|
||||
AUTO = "auto"
|
||||
|
||||
@ -6,7 +6,7 @@ from comfy_api.latest import (
|
||||
)
|
||||
from typing import Type, TYPE_CHECKING
|
||||
from comfy_api.internal.async_to_sync import create_sync_class
|
||||
from comfy_api.latest import io, ui, ComfyExtension #noqa: F401
|
||||
from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401
|
||||
|
||||
|
||||
class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest):
|
||||
@ -42,4 +42,8 @@ __all__ = [
|
||||
"InputImpl",
|
||||
"Types",
|
||||
"ComfyExtension",
|
||||
"io",
|
||||
"IO",
|
||||
"ui",
|
||||
"UI",
|
||||
]
|
||||
|
||||
@ -70,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel):
|
||||
# )
|
||||
|
||||
|
||||
class Flux2ProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
width: int = Field(1024, description="Must be a multiple of 32.")
|
||||
height: int = Field(768, description="Must be a multiple of 32.")
|
||||
seed: int | None = Field(None)
|
||||
prompt_upsampling: bool | None = Field(None)
|
||||
input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation")
|
||||
safety_tolerance: int | None = Field(
|
||||
5, description="Tolerance level for input and output moderation. Value 0 being most strict.", ge=0, le=5
|
||||
)
|
||||
output_format: str | None = Field(
|
||||
"png", description="Output format for the generated image. Can be 'jpeg' or 'png'."
|
||||
)
|
||||
|
||||
|
||||
class BFLFluxKontextProGenerateRequest(BaseModel):
|
||||
prompt: str = Field(..., description='The text prompt for what you wannt to edit.')
|
||||
input_image: Optional[str] = Field(None, description='Image to edit in base64 format')
|
||||
@ -109,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel):
|
||||
|
||||
|
||||
class BFLFluxProGenerateResponse(BaseModel):
|
||||
id: str = Field(..., description='The unique identifier for the generation task.')
|
||||
polling_url: str = Field(..., description='URL to poll for the generation result.')
|
||||
id: str = Field(..., description="The unique identifier for the generation task.")
|
||||
polling_url: str = Field(..., description="URL to poll for the generation result.")
|
||||
cost: float | None = Field(None, description="Price in cents")
|
||||
|
||||
|
||||
class BFLStatus(str, Enum):
|
||||
|
||||
144
comfy_api_nodes/apis/bytedance_api.py
Normal file
144
comfy_api_nodes/apis/bytedance_api.py
Normal file
@ -0,0 +1,144 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str | None = Field("url")
|
||||
size: str | None = Field(None)
|
||||
seed: int | None = Field(0, ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(True)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str | None = Field("url")
|
||||
image: str = Field(..., description="Base64 encoded string or image URL")
|
||||
size: str | None = Field("adaptive")
|
||||
seed: int | None = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: float | None = Field(..., ge=1.0, le=10.0)
|
||||
watermark: bool | None = Field(True)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
max_images: int = Field(15)
|
||||
|
||||
|
||||
class Seedream4TaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
prompt: str = Field(...)
|
||||
response_format: str = Field("url")
|
||||
image: list[str] | None = Field(None, description="Image URLs")
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
model: str = Field(...)
|
||||
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
|
||||
data: list = Field([], description="Contains information about the generated image(s).")
|
||||
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
|
||||
|
||||
|
||||
class TaskTextContent(BaseModel):
|
||||
type: str = Field("text")
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContent(BaseModel):
|
||||
type: str = Field("image_url")
|
||||
image_url: TaskImageContentUrl = Field(...)
|
||||
role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: TaskStatusError | None = Field(None)
|
||||
content: TaskStatusResult | None = Field(None)
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
("1024x1024 (1:1)", 1024, 1024),
|
||||
("864x1152 (3:4)", 864, 1152),
|
||||
("1152x864 (4:3)", 1152, 864),
|
||||
("1280x720 (16:9)", 1280, 720),
|
||||
("720x1280 (9:16)", 720, 1280),
|
||||
("832x1248 (2:3)", 832, 1248),
|
||||
("1248x832 (3:2)", 1248, 832),
|
||||
("1512x648 (21:9)", 1512, 648),
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("2304x1728 (4:3)", 2304, 1728),
|
||||
("1728x2304 (3:4)", 1728, 2304),
|
||||
("2560x1440 (16:9)", 2560, 1440),
|
||||
("1440x2560 (9:16)", 1440, 2560),
|
||||
("2496x1664 (3:2)", 2496, 1664),
|
||||
("1664x2496 (2:3)", 1664, 2496),
|
||||
("3024x1296 (21:9)", 3024, 1296),
|
||||
("4096x4096 (1:1)", 4096, 4096),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-lite-i2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-pro-250528": {
|
||||
"480p": 70,
|
||||
"720p": 85,
|
||||
"1080p": 115,
|
||||
},
|
||||
"seedance-1-0-pro-fast-251015": {
|
||||
"480p": 50,
|
||||
"720p": 65,
|
||||
"1080p": 100,
|
||||
},
|
||||
}
|
||||
@ -58,8 +58,14 @@ class GeminiInlineData(BaseModel):
|
||||
mimeType: GeminiMimeType | None = Field(None)
|
||||
|
||||
|
||||
class GeminiFileData(BaseModel):
|
||||
fileUri: str | None = Field(None)
|
||||
mimeType: GeminiMimeType | None = Field(None)
|
||||
|
||||
|
||||
class GeminiPart(BaseModel):
|
||||
inlineData: GeminiInlineData | None = Field(None)
|
||||
fileData: GeminiFileData | None = Field(None)
|
||||
text: str | None = Field(None)
|
||||
|
||||
|
||||
@ -68,7 +74,7 @@ class GeminiTextPart(BaseModel):
|
||||
|
||||
|
||||
class GeminiContent(BaseModel):
|
||||
parts: list[GeminiPart] = Field(...)
|
||||
parts: list[GeminiPart] = Field([])
|
||||
role: GeminiRole = Field(..., examples=["user"])
|
||||
|
||||
|
||||
@ -78,15 +84,7 @@ class GeminiSystemInstructionContent(BaseModel):
|
||||
description="A list of ordered parts that make up a single message. "
|
||||
"Different parts may have different IANA MIME types.",
|
||||
)
|
||||
role: GeminiRole = Field(
|
||||
...,
|
||||
description="The identity of the entity that creates the message. "
|
||||
"The following values are supported: "
|
||||
"user: This indicates that the message is sent by a real person, typically a user-generated message. "
|
||||
"model: This indicates that the message is generated by the model. "
|
||||
"The model value is used to insert messages from model into the conversation during multi-turn conversations. "
|
||||
"For non-multi-turn conversations, this field can be left blank or unset.",
|
||||
)
|
||||
role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.")
|
||||
|
||||
|
||||
class GeminiFunctionDeclaration(BaseModel):
|
||||
@ -113,14 +111,14 @@ class GeminiGenerationConfig(BaseModel):
|
||||
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
||||
seed: int | None = Field(None)
|
||||
stopSequences: list[str] | None = Field(None)
|
||||
temperature: float | None = Field(1, ge=0.0, le=2.0)
|
||||
topK: int | None = Field(40, ge=1)
|
||||
topP: float | None = Field(0.95, ge=0.0, le=1.0)
|
||||
temperature: float | None = Field(None, ge=0.0, le=2.0)
|
||||
topK: int | None = Field(None, ge=1)
|
||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class GeminiImageConfig(BaseModel):
|
||||
aspectRatio: str | None = Field(None)
|
||||
resolution: str | None = Field(None)
|
||||
imageSize: str | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||
@ -227,3 +225,4 @@ class GeminiGenerateContentResponse(BaseModel):
|
||||
candidates: list[GeminiCandidate] | None = Field(None)
|
||||
promptFeedback: GeminiPromptFeedback | None = Field(None)
|
||||
usageMetadata: GeminiUsageMetadata | None = Field(None)
|
||||
modelVersion: str | None = Field(None)
|
||||
|
||||
104
comfy_api_nodes/apis/kling_api.py
Normal file
104
comfy_api_nodes/apis/kling_api.py
Normal file
@ -0,0 +1,104 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OmniProText2VideoRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-video-o1")
|
||||
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
|
||||
|
||||
class OmniParamImage(BaseModel):
|
||||
image_url: str = Field(...)
|
||||
type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'")
|
||||
|
||||
|
||||
class OmniParamVideo(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
refer_type: str | None = Field(..., description="Can be 'base' or 'feature'")
|
||||
keep_original_sound: str = Field(..., description="'yes' or 'no'")
|
||||
|
||||
|
||||
class OmniProFirstLastFrameRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-video-o1")
|
||||
image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7)
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
|
||||
|
||||
class OmniProReferences2VideoRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-video-o1")
|
||||
aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
image_list: list[OmniParamImage] | None = Field(
|
||||
None, max_length=7, description="Max length 4 when video is present."
|
||||
)
|
||||
video_list: list[OmniParamVideo] | None = Field(None, max_length=1)
|
||||
duration: str | None = Field(..., description="From 3 to 10.")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
|
||||
|
||||
class TaskStatusVideoResult(BaseModel):
|
||||
duration: str | None = Field(None, description="Total video duration")
|
||||
id: str | None = Field(None, description="Generated video ID")
|
||||
url: str | None = Field(None, description="URL for generated video")
|
||||
|
||||
|
||||
class TaskStatusImageResult(BaseModel):
|
||||
index: int = Field(..., description="Image Number,0-9")
|
||||
url: str = Field(..., description="URL for generated image")
|
||||
|
||||
|
||||
class TaskStatusResults(BaseModel):
|
||||
videos: list[TaskStatusVideoResult] | None = Field(None)
|
||||
images: list[TaskStatusImageResult] | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusResponseData(BaseModel):
|
||||
created_at: int | None = Field(None, description="Task creation time")
|
||||
updated_at: int | None = Field(None, description="Task update time")
|
||||
task_status: str | None = None
|
||||
task_status_msg: str | None = Field(None, description="Additional failure reason. Only for polling endpoint.")
|
||||
task_id: str | None = Field(None, description="Task ID")
|
||||
task_result: TaskStatusResults | None = Field(None)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
code: int | None = Field(None, description="Error code")
|
||||
message: str | None = Field(None, description="Error message")
|
||||
request_id: str | None = Field(None, description="Request ID")
|
||||
data: TaskStatusResponseData | None = Field(None)
|
||||
|
||||
|
||||
class OmniImageParamImage(BaseModel):
|
||||
image: str = Field(...)
|
||||
|
||||
|
||||
class OmniProImageRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-image-o1")
|
||||
resolution: str = Field(..., description="'1k' or '2k'")
|
||||
aspect_ratio: str | None = Field(...)
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
n: int | None = Field(1, le=9)
|
||||
image_list: list[OmniImageParamImage] | None = Field(..., max_length=10)
|
||||
|
||||
|
||||
class TextToVideoWithAudioRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-v2-6")
|
||||
aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'")
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
|
||||
|
||||
class ImageToVideoWithAudioRequest(BaseModel):
|
||||
model_name: str = Field(..., description="kling-v2-6")
|
||||
image: str = Field(...)
|
||||
duration: str = Field(..., description="'5' or '10'")
|
||||
prompt: str = Field(...)
|
||||
mode: str = Field("pro")
|
||||
sound: str = Field(..., description="'on' or 'off'")
|
||||
133
comfy_api_nodes/apis/topaz_api.py
Normal file
133
comfy_api_nodes/apis/topaz_api.py
Normal file
@ -0,0 +1,133 @@
|
||||
from typing import Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImageEnhanceRequest(BaseModel):
|
||||
model: str = Field("Reimagine")
|
||||
output_format: str = Field("jpeg")
|
||||
subject_detection: str = Field("All")
|
||||
face_enhancement: bool = Field(True)
|
||||
face_enhancement_creativity: float = Field(0, description="Is ignored if face_enhancement is false")
|
||||
face_enhancement_strength: float = Field(0.8, description="Is ignored if face_enhancement is false")
|
||||
source_url: str = Field(...)
|
||||
output_width: Optional[int] = Field(None)
|
||||
output_height: Optional[int] = Field(None)
|
||||
crop_to_fill: bool = Field(False)
|
||||
prompt: Optional[str] = Field(None, description="Text prompt for creative upscaling guidance")
|
||||
creativity: int = Field(3, description="Creativity settings range from 1 to 9")
|
||||
face_preservation: str = Field("true", description="To preserve the identity of characters")
|
||||
color_preservation: str = Field("true", description="To preserve the original color")
|
||||
|
||||
|
||||
class ImageAsyncTaskResponse(BaseModel):
|
||||
process_id: str = Field(...)
|
||||
|
||||
|
||||
class ImageStatusResponse(BaseModel):
|
||||
process_id: str = Field(...)
|
||||
status: str = Field(...)
|
||||
progress: Optional[int] = Field(None)
|
||||
credits: int = Field(...)
|
||||
|
||||
|
||||
class ImageDownloadResponse(BaseModel):
|
||||
download_url: str = Field(...)
|
||||
expiry: int = Field(...)
|
||||
|
||||
|
||||
class Resolution(BaseModel):
|
||||
width: int = Field(...)
|
||||
height: int = Field(...)
|
||||
|
||||
|
||||
class CreateCreateVideoRequestSource(BaseModel):
|
||||
container: str = Field(...)
|
||||
size: int = Field(..., description="Size of the video file in bytes")
|
||||
duration: int = Field(..., description="Duration of the video file in seconds")
|
||||
frameCount: int = Field(..., description="Total number of frames in the video")
|
||||
frameRate: int = Field(...)
|
||||
resolution: Resolution = Field(...)
|
||||
|
||||
|
||||
class VideoFrameInterpolationFilter(BaseModel):
|
||||
model: str = Field(...)
|
||||
slowmo: Optional[int] = Field(None)
|
||||
fps: int = Field(...)
|
||||
duplicate: bool = Field(...)
|
||||
duplicate_threshold: float = Field(...)
|
||||
|
||||
|
||||
class VideoEnhancementFilter(BaseModel):
|
||||
model: str = Field(...)
|
||||
auto: Optional[str] = Field(None, description="Auto, Manual, Relative")
|
||||
focusFixLevel: Optional[str] = Field(None, description="Downscales video input for correction of blurred subjects")
|
||||
compression: Optional[float] = Field(None, description="Strength of compression recovery")
|
||||
details: Optional[float] = Field(None, description="Amount of detail reconstruction")
|
||||
prenoise: Optional[float] = Field(None, description="Amount of noise to add to input to reduce over-smoothing")
|
||||
noise: Optional[float] = Field(None, description="Amount of noise reduction")
|
||||
halo: Optional[float] = Field(None, description="Amount of halo reduction")
|
||||
preblur: Optional[float] = Field(None, description="Anti-aliasing and deblurring strength")
|
||||
blur: Optional[float] = Field(None, description="Amount of sharpness applied")
|
||||
grain: Optional[float] = Field(None, description="Grain after AI model processing")
|
||||
grainSize: Optional[float] = Field(None, description="Size of generated grain")
|
||||
recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video")
|
||||
creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only")
|
||||
isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only")
|
||||
|
||||
|
||||
class OutputInformationVideo(BaseModel):
|
||||
resolution: Resolution = Field(...)
|
||||
frameRate: int = Field(...)
|
||||
audioCodec: Optional[str] = Field(..., description="Required if audioTransfer is Copy or Convert")
|
||||
audioTransfer: str = Field(..., description="Copy, Convert, None")
|
||||
dynamicCompressionLevel: str = Field(..., description="Low, Mid, High")
|
||||
|
||||
|
||||
class Overrides(BaseModel):
|
||||
isPaidDiffusion: bool = Field(True)
|
||||
|
||||
|
||||
class CreateVideoRequest(BaseModel):
|
||||
source: CreateCreateVideoRequestSource = Field(...)
|
||||
filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...)
|
||||
output: OutputInformationVideo = Field(...)
|
||||
overrides: Overrides = Field(Overrides(isPaidDiffusion=True))
|
||||
|
||||
|
||||
class CreateVideoResponse(BaseModel):
|
||||
requestId: str = Field(...)
|
||||
|
||||
|
||||
class VideoAcceptResponse(BaseModel):
|
||||
uploadId: str = Field(...)
|
||||
urls: list[str] = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadRequestPart(BaseModel):
|
||||
partNum: int = Field(...)
|
||||
eTag: str = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadRequest(BaseModel):
|
||||
uploadResults: list[VideoCompleteUploadRequestPart] = Field(...)
|
||||
|
||||
|
||||
class VideoCompleteUploadResponse(BaseModel):
|
||||
message: str = Field(..., description="Confirmation message")
|
||||
|
||||
|
||||
class VideoStatusResponseEstimates(BaseModel):
|
||||
cost: list[int] = Field(...)
|
||||
|
||||
|
||||
class VideoStatusResponseDownloadUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class VideoStatusResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
estimates: Optional[VideoStatusResponseEstimates] = Field(None)
|
||||
progress: Optional[float] = Field(None)
|
||||
message: Optional[str] = Field("")
|
||||
download: Optional[VideoStatusResponseDownloadUrl] = Field(None)
|
||||
@ -1,34 +1,21 @@
|
||||
from typing import Optional, Union
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Image2(BaseModel):
|
||||
bytesBase64Encoded: str
|
||||
gcsUri: Optional[str] = None
|
||||
mimeType: Optional[str] = None
|
||||
class VeoRequestInstanceImage(BaseModel):
|
||||
bytesBase64Encoded: str | None = Field(None)
|
||||
gcsUri: str | None = Field(None)
|
||||
mimeType: str | None = Field(None)
|
||||
|
||||
|
||||
class Image3(BaseModel):
|
||||
bytesBase64Encoded: Optional[str] = None
|
||||
gcsUri: str
|
||||
mimeType: Optional[str] = None
|
||||
|
||||
|
||||
class Instance1(BaseModel):
|
||||
image: Optional[Union[Image2, Image3]] = Field(
|
||||
None, description='Optional image to guide video generation'
|
||||
)
|
||||
class VeoRequestInstance(BaseModel):
|
||||
image: VeoRequestInstanceImage | None = Field(None)
|
||||
lastFrame: VeoRequestInstanceImage | None = Field(None)
|
||||
prompt: str = Field(..., description='Text description of the video')
|
||||
|
||||
|
||||
class PersonGeneration1(str, Enum):
|
||||
ALLOW = 'ALLOW'
|
||||
BLOCK = 'BLOCK'
|
||||
|
||||
|
||||
class Parameters1(BaseModel):
|
||||
class VeoRequestParameters(BaseModel):
|
||||
aspectRatio: Optional[str] = Field(None, examples=['16:9'])
|
||||
durationSeconds: Optional[int] = None
|
||||
enhancePrompt: Optional[bool] = None
|
||||
@ -37,17 +24,18 @@ class Parameters1(BaseModel):
|
||||
description='Generate audio for the video. Only supported by veo 3 models.',
|
||||
)
|
||||
negativePrompt: Optional[str] = None
|
||||
personGeneration: Optional[PersonGeneration1] = None
|
||||
personGeneration: str | None = Field(None, description="ALLOW or BLOCK")
|
||||
sampleCount: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
storageUri: Optional[str] = Field(
|
||||
None, description='Optional Cloud Storage URI to upload the video'
|
||||
)
|
||||
resolution: str | None = Field(None)
|
||||
|
||||
|
||||
class VeoGenVidRequest(BaseModel):
|
||||
instances: Optional[list[Instance1]] = None
|
||||
parameters: Optional[Parameters1] = None
|
||||
instances: list[VeoRequestInstance] | None = Field(None)
|
||||
parameters: VeoRequestParameters | None = Field(None)
|
||||
|
||||
|
||||
class VeoGenVidResponse(BaseModel):
|
||||
@ -97,7 +85,7 @@ class Response1(BaseModel):
|
||||
raiMediaFilteredReasons: Optional[list[str]] = Field(
|
||||
None, description='Reasons why media was filtered by responsible AI policies'
|
||||
)
|
||||
videos: Optional[list[Video]] = None
|
||||
videos: Optional[list[Video]] = Field(None)
|
||||
|
||||
|
||||
class VeoGenVidPollResponse(BaseModel):
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from inspect import cleandoc
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
@ -9,15 +9,16 @@ from comfy_api_nodes.apis.bfl_api import (
|
||||
BFLFluxExpandImageRequest,
|
||||
BFLFluxFillImageRequest,
|
||||
BFLFluxKontextProGenerateRequest,
|
||||
BFLFluxProGenerateRequest,
|
||||
BFLFluxProGenerateResponse,
|
||||
BFLFluxProUltraGenerateRequest,
|
||||
BFLFluxStatusResponse,
|
||||
BFLStatus,
|
||||
Flux2ProGenerateRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
resize_mask_to_image,
|
||||
sync_op,
|
||||
@ -116,7 +117,7 @@ class FluxProUltraImageNode(IO.ComfyNode):
|
||||
prompt_upsampling: bool = False,
|
||||
raw: bool = False,
|
||||
seed: int = 0,
|
||||
image_prompt: Optional[torch.Tensor] = None,
|
||||
image_prompt: torch.Tensor | None = None,
|
||||
image_prompt_strength: float = 0.1,
|
||||
) -> IO.NodeOutput:
|
||||
if image_prompt is None:
|
||||
@ -230,7 +231,7 @@ class FluxKontextProImageNode(IO.ComfyNode):
|
||||
aspect_ratio: str,
|
||||
guidance: float,
|
||||
steps: int,
|
||||
input_image: Optional[torch.Tensor] = None,
|
||||
input_image: torch.Tensor | None = None,
|
||||
seed=0,
|
||||
prompt_upsampling=False,
|
||||
) -> IO.NodeOutput:
|
||||
@ -280,124 +281,6 @@ class FluxKontextMaxImageNode(FluxKontextProImageNode):
|
||||
DISPLAY_NAME = "Flux.1 Kontext [max] Image"
|
||||
|
||||
|
||||
class FluxProImageNode(IO.ComfyNode):
|
||||
"""
|
||||
Generates images synchronously based on prompt and resolution.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="FluxProImageNode",
|
||||
display_name="Flux 1.1 [pro] Image",
|
||||
category="api node/image/BFL",
|
||||
description=cleandoc(cls.__doc__ or ""),
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. "
|
||||
"If active, automatically modifies the prompt for more creative generation, "
|
||||
"but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"width",
|
||||
default=1024,
|
||||
min=256,
|
||||
max=1440,
|
||||
step=32,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"height",
|
||||
default=768,
|
||||
min=256,
|
||||
max=1440,
|
||||
step=32,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"image_prompt",
|
||||
optional=True,
|
||||
),
|
||||
# "image_prompt_strength": (
|
||||
# IO.FLOAT,
|
||||
# {
|
||||
# "default": 0.1,
|
||||
# "min": 0.0,
|
||||
# "max": 1.0,
|
||||
# "step": 0.01,
|
||||
# "tooltip": "Blend between the prompt and the image prompt.",
|
||||
# },
|
||||
# ),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
prompt_upsampling,
|
||||
width: int,
|
||||
height: int,
|
||||
seed=0,
|
||||
image_prompt=None,
|
||||
# image_prompt_strength=0.1,
|
||||
) -> IO.NodeOutput:
|
||||
image_prompt = image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path="/proxy/bfl/flux-pro-1.1/generate",
|
||||
method="POST",
|
||||
),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=BFLFluxProGenerateRequest(
|
||||
prompt=prompt,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
width=width,
|
||||
height=height,
|
||||
seed=seed,
|
||||
image_prompt=image_prompt,
|
||||
),
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class FluxProExpandNode(IO.ComfyNode):
|
||||
"""
|
||||
Outpaints image based on prompt.
|
||||
@ -640,16 +523,125 @@ class FluxProFillNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class Flux2ProImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="Flux2ProImageNode",
|
||||
display_name="Flux.2 [pro] Image",
|
||||
category="api node/image/BFL",
|
||||
description="Generates images synchronously based on prompt and resolution.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Prompt for the image generation or edit",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"width",
|
||||
default=1024,
|
||||
min=256,
|
||||
max=2048,
|
||||
step=32,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"height",
|
||||
default=768,
|
||||
min=256,
|
||||
max=2048,
|
||||
step=32,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"prompt_upsampling",
|
||||
default=False,
|
||||
tooltip="Whether to perform upsampling on the prompt. "
|
||||
"If active, automatically modifies the prompt for more creative generation, "
|
||||
"but results are nondeterministic (same seed will not produce exactly the same result).",
|
||||
),
|
||||
IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
width: int,
|
||||
height: int,
|
||||
seed: int,
|
||||
prompt_upsampling: bool,
|
||||
images: torch.Tensor | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
reference_images = {}
|
||||
if images is not None:
|
||||
if get_number_of_images(images) > 9:
|
||||
raise ValueError("The current maximum number of supported images is 9.")
|
||||
for image_index in range(images.shape[0]):
|
||||
key_name = f"input_image_{image_index + 1}" if image_index else "input_image"
|
||||
reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048)
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"),
|
||||
response_model=BFLFluxProGenerateResponse,
|
||||
data=Flux2ProGenerateRequest(
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
seed=seed,
|
||||
prompt_upsampling=prompt_upsampling,
|
||||
**reference_images,
|
||||
),
|
||||
)
|
||||
|
||||
def price_extractor(_r: BaseModel) -> float | None:
|
||||
return None if initial_response.cost is None else initial_response.cost / 100
|
||||
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(initial_response.polling_url),
|
||||
response_model=BFLFluxStatusResponse,
|
||||
status_extractor=lambda r: r.status,
|
||||
progress_extractor=lambda r: r.progress,
|
||||
price_extractor=price_extractor,
|
||||
completed_statuses=[BFLStatus.ready],
|
||||
failed_statuses=[
|
||||
BFLStatus.request_moderated,
|
||||
BFLStatus.content_moderated,
|
||||
BFLStatus.error,
|
||||
BFLStatus.task_not_found,
|
||||
],
|
||||
queued_statuses=[],
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"]))
|
||||
|
||||
|
||||
class BFLExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
FluxProUltraImageNode,
|
||||
# FluxProImageNode,
|
||||
FluxKontextProImageNode,
|
||||
FluxKontextMaxImageNode,
|
||||
FluxProExpandNode,
|
||||
FluxProFillNode,
|
||||
Flux2ProImageNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,13 +1,27 @@
|
||||
import logging
|
||||
import math
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.bytedance_api import (
|
||||
RECOMMENDED_PRESETS,
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4,
|
||||
VIDEO_TASKS_EXECUTION_TIME,
|
||||
Image2ImageTaskCreationRequest,
|
||||
Image2VideoTaskCreationRequest,
|
||||
ImageTaskCreationResponse,
|
||||
Seedream4Options,
|
||||
Seedream4TaskCreationRequest,
|
||||
TaskCreationResponse,
|
||||
TaskImageContent,
|
||||
TaskImageContentUrl,
|
||||
TaskStatusResponse,
|
||||
TaskTextContent,
|
||||
Text2ImageTaskCreationRequest,
|
||||
Text2VideoTaskCreationRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
@ -29,162 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||
|
||||
|
||||
class Text2ImageModelName(str, Enum):
|
||||
seedream_3 = "seedream-3-0-t2i-250415"
|
||||
|
||||
|
||||
class Image2ImageModelName(str, Enum):
|
||||
seededit_3 = "seededit-3-0-i2i-250628"
|
||||
|
||||
|
||||
class Text2VideoModelName(str, Enum):
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-t2v-250428"
|
||||
|
||||
|
||||
class Image2VideoModelName(str, Enum):
|
||||
"""note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757"""
|
||||
|
||||
seedance_1_pro = "seedance-1-0-pro-250528"
|
||||
seedance_1_lite = "seedance-1-0-lite-i2v-250428"
|
||||
|
||||
|
||||
class Text2ImageTaskCreationRequest(BaseModel):
|
||||
model: Text2ImageModelName = Text2ImageModelName.seedream_3
|
||||
prompt: str = Field(...)
|
||||
response_format: Optional[str] = Field("url")
|
||||
size: Optional[str] = Field(None)
|
||||
seed: Optional[int] = Field(0, ge=0, le=2147483647)
|
||||
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
|
||||
watermark: Optional[bool] = Field(True)
|
||||
|
||||
|
||||
class Image2ImageTaskCreationRequest(BaseModel):
|
||||
model: Image2ImageModelName = Image2ImageModelName.seededit_3
|
||||
prompt: str = Field(...)
|
||||
response_format: Optional[str] = Field("url")
|
||||
image: str = Field(..., description="Base64 encoded string or image URL")
|
||||
size: Optional[str] = Field("adaptive")
|
||||
seed: Optional[int] = Field(..., ge=0, le=2147483647)
|
||||
guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0)
|
||||
watermark: Optional[bool] = Field(True)
|
||||
|
||||
|
||||
class Seedream4Options(BaseModel):
|
||||
max_images: int = Field(15)
|
||||
|
||||
|
||||
class Seedream4TaskCreationRequest(BaseModel):
|
||||
model: str = Field("seedream-4-0-250828")
|
||||
prompt: str = Field(...)
|
||||
response_format: str = Field("url")
|
||||
image: Optional[list[str]] = Field(None, description="Image URLs")
|
||||
size: str = Field(...)
|
||||
seed: int = Field(..., ge=0, le=2147483647)
|
||||
sequential_image_generation: str = Field("disabled")
|
||||
sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15))
|
||||
watermark: bool = Field(True)
|
||||
|
||||
|
||||
class ImageTaskCreationResponse(BaseModel):
|
||||
model: str = Field(...)
|
||||
created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.")
|
||||
data: list = Field([], description="Contains information about the generated image(s).")
|
||||
error: dict = Field({}, description="Contains `code` and `message` fields in case of error.")
|
||||
|
||||
|
||||
class TaskTextContent(BaseModel):
|
||||
type: str = Field("text")
|
||||
text: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContentUrl(BaseModel):
|
||||
url: str = Field(...)
|
||||
|
||||
|
||||
class TaskImageContent(BaseModel):
|
||||
type: str = Field("image_url")
|
||||
image_url: TaskImageContentUrl = Field(...)
|
||||
role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None)
|
||||
|
||||
|
||||
class Text2VideoTaskCreationRequest(BaseModel):
|
||||
model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro
|
||||
content: list[TaskTextContent] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class Image2VideoTaskCreationRequest(BaseModel):
|
||||
model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro
|
||||
content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2)
|
||||
|
||||
|
||||
class TaskCreationResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusError(BaseModel):
|
||||
code: str = Field(...)
|
||||
message: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResult(BaseModel):
|
||||
video_url: str = Field(...)
|
||||
|
||||
|
||||
class TaskStatusResponse(BaseModel):
|
||||
id: str = Field(...)
|
||||
model: str = Field(...)
|
||||
status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...)
|
||||
error: Optional[TaskStatusError] = Field(None)
|
||||
content: Optional[TaskStatusResult] = Field(None)
|
||||
|
||||
|
||||
RECOMMENDED_PRESETS = [
|
||||
("1024x1024 (1:1)", 1024, 1024),
|
||||
("864x1152 (3:4)", 864, 1152),
|
||||
("1152x864 (4:3)", 1152, 864),
|
||||
("1280x720 (16:9)", 1280, 720),
|
||||
("720x1280 (9:16)", 720, 1280),
|
||||
("832x1248 (2:3)", 832, 1248),
|
||||
("1248x832 (3:2)", 1248, 832),
|
||||
("1512x648 (21:9)", 1512, 648),
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
RECOMMENDED_PRESETS_SEEDREAM_4 = [
|
||||
("2048x2048 (1:1)", 2048, 2048),
|
||||
("2304x1728 (4:3)", 2304, 1728),
|
||||
("1728x2304 (3:4)", 1728, 2304),
|
||||
("2560x1440 (16:9)", 2560, 1440),
|
||||
("1440x2560 (9:16)", 1440, 2560),
|
||||
("2496x1664 (3:2)", 2496, 1664),
|
||||
("1664x2496 (2:3)", 1664, 2496),
|
||||
("3024x1296 (21:9)", 3024, 1296),
|
||||
("4096x4096 (1:1)", 4096, 4096),
|
||||
("Custom", None, None),
|
||||
]
|
||||
|
||||
# The time in this dictionary are given for 10 seconds duration.
|
||||
VIDEO_TASKS_EXECUTION_TIME = {
|
||||
"seedance-1-0-lite-t2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-lite-i2v-250428": {
|
||||
"480p": 40,
|
||||
"720p": 60,
|
||||
"1080p": 90,
|
||||
},
|
||||
"seedance-1-0-pro-250528": {
|
||||
"480p": 70,
|
||||
"720p": 85,
|
||||
"1080p": 115,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
if response.error:
|
||||
error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}"
|
||||
@ -194,13 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||
return response.data[0]["url"]
|
||||
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if hasattr(response, "content") and response.content:
|
||||
return response.content.video_url
|
||||
return None
|
||||
|
||||
|
||||
class ByteDanceImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
@ -211,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
||||
category="api node/image/ByteDance",
|
||||
description="Generate images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Text2ImageModelName,
|
||||
default=Text2ImageModelName.seedream_3,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
@ -335,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
category="api node/image/ByteDance",
|
||||
description="Edit images using ByteDance models via api based on prompt",
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Image2ImageModelName,
|
||||
default=Image2ImageModelName.seededit_3,
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The base image to edit",
|
||||
@ -394,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
seed: int,
|
||||
guidance_scale: float,
|
||||
@ -434,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["seedream-4-0-250828"],
|
||||
options=["seedream-4-5-251128", "seedream-4-0-250828"],
|
||||
tooltip="Model name",
|
||||
),
|
||||
IO.String.Input(
|
||||
@ -459,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=64,
|
||||
step=8,
|
||||
tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
@ -468,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
default=2048,
|
||||
min=1024,
|
||||
max=4096,
|
||||
step=64,
|
||||
step=8,
|
||||
tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`",
|
||||
optional=True,
|
||||
),
|
||||
@ -532,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: torch.Tensor = None,
|
||||
image: Input.Image | None = None,
|
||||
size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0],
|
||||
width: int = 2048,
|
||||
height: int = 2048,
|
||||
@ -555,6 +396,18 @@ class ByteDanceSeedreamNode(IO.ComfyNode):
|
||||
raise ValueError(
|
||||
f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels."
|
||||
)
|
||||
out_num_pixels = w * h
|
||||
mp_provided = out_num_pixels / 1_000_000.0
|
||||
if "seedream-4-5" in model and out_num_pixels < 3686400:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
if "seedream-4-0" in model and out_num_pixels < 921600:
|
||||
raise ValueError(
|
||||
f"Minimum image resolution that the selected model can generate is 0.92MP, "
|
||||
f"but {mp_provided:.2f}MP provided."
|
||||
)
|
||||
n_input_images = get_number_of_images(image) if image is not None else 0
|
||||
if n_input_images > 10:
|
||||
raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.")
|
||||
@ -607,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Text2VideoModelName,
|
||||
default=Text2VideoModelName.seedance_1_pro,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -714,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=Image2VideoModelName,
|
||||
default=Image2VideoModelName.seedance_1_pro,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"],
|
||||
default="seedance-1-0-pro-fast-251015",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -787,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -833,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[model.value for model in Image2VideoModelName],
|
||||
default=Image2VideoModelName.seedance_1_lite.value,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
default="seedance-1-0-lite-i2v-250428",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -910,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
first_frame: torch.Tensor,
|
||||
last_frame: torch.Tensor,
|
||||
first_frame: Input.Image,
|
||||
last_frame: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -968,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=[Image2VideoModelName.seedance_1_lite.value],
|
||||
default=Image2VideoModelName.seedance_1_lite.value,
|
||||
tooltip="Model name",
|
||||
options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"],
|
||||
default="seedance-1-0-lite-i2v-250428",
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
@ -1034,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
cls,
|
||||
model: str,
|
||||
prompt: str,
|
||||
images: torch.Tensor,
|
||||
images: Input.Image,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
@ -1069,8 +918,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
||||
|
||||
async def process_video_task(
|
||||
cls: type[IO.ComfyNode],
|
||||
payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest],
|
||||
estimated_duration: Optional[int],
|
||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||
estimated_duration: int | None,
|
||||
) -> IO.NodeOutput:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
@ -1085,7 +934,7 @@ async def process_video_task(
|
||||
estimated_duration=estimated_duration,
|
||||
response_model=TaskStatusResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response)))
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.content.video_url))
|
||||
|
||||
|
||||
def raise_if_text_params(prompt: str, text_params: list[str]) -> None:
|
||||
|
||||
@ -4,10 +4,7 @@ See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/infer
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Literal
|
||||
@ -16,10 +13,10 @@ import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
GeminiGenerateContentRequest,
|
||||
GeminiGenerateContentResponse,
|
||||
GeminiImageConfig,
|
||||
@ -29,20 +26,32 @@ from comfy_api_nodes.apis.gemini_api import (
|
||||
GeminiMimeType,
|
||||
GeminiPart,
|
||||
GeminiRole,
|
||||
GeminiSystemInstructionContent,
|
||||
GeminiTextPart,
|
||||
Modality,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
audio_to_base64_string,
|
||||
bytesio_to_image_tensor,
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_images_to_comfyapi,
|
||||
validate_string,
|
||||
video_to_base64_string,
|
||||
)
|
||||
from server import PromptServer
|
||||
|
||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
GEMINI_IMAGE_SYS_PROMPT = (
|
||||
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
|
||||
"Interpret all user input—regardless of "
|
||||
"format, intent, or abstraction—as literal visual directives for image composition.\n"
|
||||
"If a prompt is conversational or lacks specific visual details, "
|
||||
"you must creatively invent a concrete visual scenario that depicts the concept.\n"
|
||||
"Prioritize generating the visual representation above any text, formatting, or conversational requests."
|
||||
)
|
||||
|
||||
|
||||
class GeminiModel(str, Enum):
|
||||
@ -66,24 +75,43 @@ class GeminiImageModel(str, Enum):
|
||||
gemini_2_5_flash_image = "gemini-2.5-flash-image"
|
||||
|
||||
|
||||
def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert image tensor input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
image_input: Batch of image tensors from ComfyUI.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded images.
|
||||
"""
|
||||
async def create_image_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: Input.Image,
|
||||
image_limit: int = 0,
|
||||
) -> list[GeminiPart]:
|
||||
image_parts: list[GeminiPart] = []
|
||||
for image_index in range(image_input.shape[0]):
|
||||
image_as_b64 = tensor_to_base64_string(image_input[image_index].unsqueeze(0))
|
||||
if image_limit < 0:
|
||||
raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.")
|
||||
total_images = get_number_of_images(images)
|
||||
if total_images <= 0:
|
||||
raise ValueError("No images provided to create_image_parts; at least one image is required.")
|
||||
|
||||
# If image_limit == 0 --> use all images; otherwise clamp to image_limit.
|
||||
effective_max = total_images if image_limit == 0 else min(total_images, image_limit)
|
||||
|
||||
# Number of images we'll send as URLs (fileData)
|
||||
num_url_images = min(effective_max, 10) # Vertex API max number of image links
|
||||
reference_images_urls = await upload_images_to_comfyapi(
|
||||
cls,
|
||||
images,
|
||||
max_images=num_url_images,
|
||||
)
|
||||
for reference_image_url in reference_images_urls:
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
fileData=GeminiFileData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
fileUri=reference_image_url,
|
||||
)
|
||||
)
|
||||
)
|
||||
for idx in range(num_url_images, effective_max):
|
||||
image_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.image_png,
|
||||
data=image_as_b64,
|
||||
data=tensor_to_base64_string(images[idx]),
|
||||
)
|
||||
)
|
||||
)
|
||||
@ -102,14 +130,14 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
|
||||
List of response parts matching the requested type.
|
||||
"""
|
||||
if response.candidates is None:
|
||||
if response.promptFeedback.blockReason:
|
||||
if response.promptFeedback and response.promptFeedback.blockReason:
|
||||
feedback = response.promptFeedback
|
||||
raise ValueError(
|
||||
f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
|
||||
)
|
||||
raise NotImplementedError(
|
||||
"Gemini returned no response candidates. "
|
||||
"Please report to ComfyUI repository with the example of workflow to reproduce this."
|
||||
raise ValueError(
|
||||
"Gemini API returned no response candidates. If you are using the `IMAGE` modality, "
|
||||
"try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed."
|
||||
)
|
||||
parts = []
|
||||
for part in response.candidates[0].content.parts:
|
||||
@ -135,8 +163,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
||||
return "\n".join([part.text for part in parts])
|
||||
|
||||
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
|
||||
image_tensors: list[torch.Tensor] = []
|
||||
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
||||
image_tensors: list[Input.Image] = []
|
||||
parts = get_parts_by_type(response, "image/png")
|
||||
for part in parts:
|
||||
image_data = base64.b64decode(part.inlineData.data)
|
||||
@ -147,6 +175,50 @@ def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Te
|
||||
return torch.cat(image_tensors, dim=0)
|
||||
|
||||
|
||||
def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None:
|
||||
if not response.modelVersion:
|
||||
return None
|
||||
# Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing
|
||||
if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"):
|
||||
input_tokens_price = 1.25
|
||||
output_text_tokens_price = 10.0
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion in (
|
||||
"gemini-2.5-flash-preview-04-17",
|
||||
"gemini-2.5-flash",
|
||||
):
|
||||
input_tokens_price = 0.30
|
||||
output_text_tokens_price = 2.50
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion in (
|
||||
"gemini-2.5-flash-image-preview",
|
||||
"gemini-2.5-flash-image",
|
||||
):
|
||||
input_tokens_price = 0.30
|
||||
output_text_tokens_price = 2.50
|
||||
output_image_tokens_price = 30.0
|
||||
elif response.modelVersion == "gemini-3-pro-preview":
|
||||
input_tokens_price = 2
|
||||
output_text_tokens_price = 12.0
|
||||
output_image_tokens_price = 0.0
|
||||
elif response.modelVersion == "gemini-3-pro-image-preview":
|
||||
input_tokens_price = 2
|
||||
output_text_tokens_price = 12.0
|
||||
output_image_tokens_price = 120.0
|
||||
else:
|
||||
return None
|
||||
final_price = response.usageMetadata.promptTokenCount * input_tokens_price
|
||||
if response.usageMetadata.candidatesTokensDetails:
|
||||
for i in response.usageMetadata.candidatesTokensDetails:
|
||||
if i.modality == Modality.IMAGE:
|
||||
final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models
|
||||
else:
|
||||
final_price += output_text_tokens_price * i.tokenCount
|
||||
if response.usageMetadata.thoughtsTokenCount:
|
||||
final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount
|
||||
return final_price / 1_000_000.0
|
||||
|
||||
|
||||
class GeminiNode(IO.ComfyNode):
|
||||
"""
|
||||
Node to generate text responses from a Gemini model.
|
||||
@ -214,6 +286,13 @@ class GeminiNode(IO.ComfyNode):
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(),
|
||||
@ -230,7 +309,9 @@ class GeminiNode(IO.ComfyNode):
|
||||
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert video input to Gemini API compatible parts."""
|
||||
|
||||
base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
@ -280,10 +361,11 @@ class GeminiNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
audio: Input.Audio | None = None,
|
||||
video: Input.Video | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
@ -292,8 +374,7 @@ class GeminiNode(IO.ComfyNode):
|
||||
|
||||
# Add other modal parts
|
||||
if images is not None:
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
parts.extend(await create_image_parts(cls, images))
|
||||
if audio is not None:
|
||||
parts.extend(cls.create_audio_parts(audio))
|
||||
if video is not None:
|
||||
@ -301,7 +382,10 @@ class GeminiNode(IO.ComfyNode):
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
# Create response
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
@ -311,35 +395,14 @@ class GeminiNode(IO.ComfyNode):
|
||||
role=GeminiRole.user,
|
||||
parts=parts,
|
||||
)
|
||||
]
|
||||
],
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
|
||||
output_text = get_text_from_response(response)
|
||||
if output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
|
||||
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||
|
||||
|
||||
@ -476,6 +539,20 @@ class GeminiImage(IO.ComfyNode):
|
||||
"or otherwise generates 1:1 squares.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"response_modalities",
|
||||
options=["IMAGE+TEXT", "IMAGE"],
|
||||
tooltip="Choose 'IMAGE' for image-only output, or "
|
||||
"'IMAGE+TEXT' to return both the generated image and a text response.",
|
||||
optional=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -495,9 +572,11 @@ class GeminiImage(IO.ComfyNode):
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
images: torch.Tensor | None = None,
|
||||
images: Input.Image | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
aspect_ratio: str = "auto",
|
||||
response_modalities: str = "IMAGE+TEXT",
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
@ -507,11 +586,14 @@ class GeminiImage(IO.ComfyNode):
|
||||
image_config = GeminiImageConfig(aspectRatio=aspect_ratio)
|
||||
|
||||
if images is not None:
|
||||
image_parts = create_image_parts(images)
|
||||
parts.extend(image_parts)
|
||||
parts.extend(await create_image_parts(cls, images))
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
@ -520,40 +602,148 @@ class GeminiImage(IO.ComfyNode):
|
||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=["TEXT", "IMAGE"],
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=None if aspect_ratio == "auto" else image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
|
||||
|
||||
|
||||
class GeminiImage2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiImage2Node",
|
||||
display_name="Nano Banana Pro (Google Gemini Image)",
|
||||
category="api node/image/Gemini",
|
||||
description="Generate or edit images synchronously via Google Vertex API.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="Text prompt describing the image to generate or the edits to apply. "
|
||||
"Include any constraints, styles, or details the model should follow.",
|
||||
default="",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["gemini-3-pro-image-preview"],
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=0xFFFFFFFFFFFFFFFF,
|
||||
control_after_generate=True,
|
||||
tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide "
|
||||
"the same response for repeated requests. Deterministic output isn't guaranteed. "
|
||||
"Also, changing the model or parameter settings, such as the temperature, "
|
||||
"can cause variations in the response even when you use the same seed value. "
|
||||
"By default, a random seed value is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"],
|
||||
default="auto",
|
||||
tooltip="If set to 'auto', matches your input image's aspect ratio; "
|
||||
"if no image is provided, a 16:9 square is usually generated.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"resolution",
|
||||
options=["1K", "2K", "4K"],
|
||||
tooltip="Target output resolution. For 2K/4K the native Gemini upscaler is used.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"response_modalities",
|
||||
options=["IMAGE+TEXT", "IMAGE"],
|
||||
tooltip="Choose 'IMAGE' for image-only output, or "
|
||||
"'IMAGE+TEXT' to return both the generated image and a text response.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"images",
|
||||
optional=True,
|
||||
tooltip="Optional reference image(s). "
|
||||
"To include multiple images, use the Batch Images node (up to 14).",
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Generate Content Input Files node.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default=GEMINI_IMAGE_SYS_PROMPT,
|
||||
optional=True,
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
output_image = get_image_from_response(response)
|
||||
output_text = get_text_from_response(response)
|
||||
if output_text:
|
||||
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
|
||||
render_spec = {
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: str,
|
||||
seed: int,
|
||||
aspect_ratio: str,
|
||||
resolution: str,
|
||||
response_modalities: str,
|
||||
images: Input.Image | None = None,
|
||||
files: list[GeminiPart] | None = None,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
|
||||
output_text = output_text or "Empty response from Gemini model..."
|
||||
return IO.NodeOutput(output_image, output_text)
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
if images is not None:
|
||||
if get_number_of_images(images) > 14:
|
||||
raise ValueError("The current maximum number of supported images is 14.")
|
||||
parts.extend(await create_image_parts(cls, images))
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
image_config = GeminiImageConfig(imageSize=resolution)
|
||||
if aspect_ratio != "auto":
|
||||
image_config.aspectRatio = aspect_ratio
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
|
||||
data=GeminiImageGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(role=GeminiRole.user, parts=parts),
|
||||
],
|
||||
generationConfig=GeminiImageGenerationConfig(
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response))
|
||||
|
||||
|
||||
class GeminiExtension(ComfyExtension):
|
||||
@ -562,6 +752,7 @@ class GeminiExtension(ComfyExtension):
|
||||
return [
|
||||
GeminiNode,
|
||||
GeminiImage,
|
||||
GeminiImage2,
|
||||
GeminiInputFiles,
|
||||
]
|
||||
|
||||
|
||||
@ -4,15 +4,14 @@ For source of truth on the allowed permutations of request fields, please refere
|
||||
- [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional, TypeVar
|
||||
import math
|
||||
import logging
|
||||
|
||||
from typing_extensions import override
|
||||
import math
|
||||
import re
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis import (
|
||||
KlingCameraControl,
|
||||
KlingCameraConfig,
|
||||
@ -50,25 +49,35 @@ from comfy_api_nodes.apis import (
|
||||
KlingCharacterEffectModelName,
|
||||
KlingSingleImageEffectModelName,
|
||||
)
|
||||
from comfy_api_nodes.apis.kling_api import (
|
||||
ImageToVideoWithAudioRequest,
|
||||
OmniImageParamImage,
|
||||
OmniParamImage,
|
||||
OmniParamVideo,
|
||||
OmniProFirstLastFrameRequest,
|
||||
OmniProImageRequest,
|
||||
OmniProReferences2VideoRequest,
|
||||
OmniProText2VideoRequest,
|
||||
TaskStatusResponse,
|
||||
TextToVideoWithAudioRequest,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
validate_image_dimensions,
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_image_aspect_ratio,
|
||||
validate_image_dimensions,
|
||||
validate_string,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
tensor_to_base64_string,
|
||||
validate_string,
|
||||
upload_audio_to_comfyapi,
|
||||
download_url_to_image_tensor,
|
||||
upload_video_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
sync_op,
|
||||
ApiEndpoint,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.input.basic_types import AudioInput
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
KLING_API_VERSION = "v1"
|
||||
PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
|
||||
@ -94,8 +103,6 @@ AVERAGE_DURATION_IMAGE_GEN = 32
|
||||
AVERAGE_DURATION_VIDEO_EFFECTS = 320
|
||||
AVERAGE_DURATION_VIDEO_EXTEND = 320
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
MODE_TEXT2VIDEO = {
|
||||
"standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
|
||||
@ -130,6 +137,8 @@ MODE_START_END_FRAME = {
|
||||
"pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"),
|
||||
"pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"),
|
||||
"pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"),
|
||||
"pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"),
|
||||
"pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"),
|
||||
}
|
||||
"""
|
||||
Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples.
|
||||
@ -206,6 +215,50 @@ VOICES_CONFIG = {
|
||||
}
|
||||
|
||||
|
||||
def normalize_omni_prompt_references(prompt: str) -> str:
|
||||
"""
|
||||
Rewrites Kling Omni-style placeholders used in the app, like:
|
||||
|
||||
@image, @image1, @image2, ... @imageN
|
||||
@video, @video1, @video2, ... @videoN
|
||||
|
||||
into the API-compatible form:
|
||||
|
||||
<<<image_1>>>, <<<image_2>>>, ...
|
||||
<<<video_1>>>, <<<video_2>>>, ...
|
||||
|
||||
This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app.
|
||||
"""
|
||||
if not prompt:
|
||||
return prompt
|
||||
|
||||
def _image_repl(match):
|
||||
return f"<<<image_{match.group('idx') or '1'}>>>"
|
||||
|
||||
def _video_repl(match):
|
||||
return f"<<<video_{match.group('idx') or '1'}>>>"
|
||||
|
||||
# (?<!\w) avoids matching e.g. "test@image.com"
|
||||
# (?!\w) makes sure we only match @image / @image<digits> and not @imageFoo
|
||||
prompt = re.sub(r"(?<!\w)@image(?P<idx>\d*)(?!\w)", _image_repl, prompt)
|
||||
return re.sub(r"(?<!\w)@video(?P<idx>\d*)(?!\w)", _video_repl, prompt)
|
||||
|
||||
|
||||
async def finish_omni_video_task(cls: type[IO.ComfyNode], response: TaskStatusResponse) -> IO.NodeOutput:
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
max_poll_attempts=160,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
def is_valid_camera_control_configs(configs: list[float]) -> bool:
|
||||
"""Verifies that at least one camera control configuration is non-zero."""
|
||||
return any(not math.isclose(value, 0.0) for value in configs)
|
||||
@ -296,7 +349,7 @@ def get_video_from_response(response) -> KlingVideoResult:
|
||||
return video
|
||||
|
||||
|
||||
def get_video_url_from_response(response) -> Optional[str]:
|
||||
def get_video_url_from_response(response) -> str | None:
|
||||
"""Returns the first video url from the Kling video generation task result.
|
||||
Will not raise an error if the response is not valid.
|
||||
"""
|
||||
@ -315,7 +368,7 @@ def get_images_from_response(response) -> list[KlingImageResult]:
|
||||
return images
|
||||
|
||||
|
||||
def get_images_urls_from_response(response) -> Optional[str]:
|
||||
def get_images_urls_from_response(response) -> str | None:
|
||||
"""Returns the list of image urls from the Kling image generation task result.
|
||||
Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls.
|
||||
"""
|
||||
@ -349,7 +402,7 @@ async def execute_text2video(
|
||||
model_mode: str,
|
||||
duration: str,
|
||||
aspect_ratio: str,
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
camera_control: KlingCameraControl | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
|
||||
task_creation_response = await sync_op(
|
||||
@ -394,8 +447,8 @@ async def execute_image2video(
|
||||
model_mode: str,
|
||||
aspect_ratio: str,
|
||||
duration: str,
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
end_frame: Optional[torch.Tensor] = None,
|
||||
camera_control: KlingCameraControl | None = None,
|
||||
end_frame: torch.Tensor | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
|
||||
validate_input_image(start_frame)
|
||||
@ -432,12 +485,12 @@ async def execute_image2video(
|
||||
task_id = task_creation_response.data.task_id
|
||||
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
|
||||
response_model=KlingImage2VideoResponse,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"),
|
||||
response_model=KlingImage2VideoResponse,
|
||||
estimated_duration=AVERAGE_DURATION_I2V,
|
||||
status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None),
|
||||
)
|
||||
validate_video_result_response(final_response)
|
||||
|
||||
video = get_video_from_response(final_response)
|
||||
@ -451,9 +504,9 @@ async def execute_video_effect(
|
||||
model_name: str,
|
||||
duration: KlingVideoGenDuration,
|
||||
image_1: torch.Tensor,
|
||||
image_2: Optional[torch.Tensor] = None,
|
||||
model_mode: Optional[KlingVideoGenMode] = None,
|
||||
) -> tuple[VideoFromFile, str, str]:
|
||||
image_2: torch.Tensor | None = None,
|
||||
model_mode: KlingVideoGenMode | None = None,
|
||||
) -> tuple[InputImpl.VideoFromFile, str, str]:
|
||||
if dual_character:
|
||||
request_input_field = KlingDualCharacterEffectInput(
|
||||
model_name=model_name,
|
||||
@ -499,13 +552,13 @@ async def execute_video_effect(
|
||||
|
||||
async def execute_lipsync(
|
||||
cls: type[IO.ComfyNode],
|
||||
video: VideoInput,
|
||||
audio: Optional[AudioInput] = None,
|
||||
voice_language: Optional[str] = None,
|
||||
model_mode: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
voice_speed: Optional[float] = None,
|
||||
voice_id: Optional[str] = None,
|
||||
video: Input.Video,
|
||||
audio: Input.Audio | None = None,
|
||||
voice_language: str | None = None,
|
||||
model_mode: str | None = None,
|
||||
text: str | None = None,
|
||||
voice_speed: float | None = None,
|
||||
voice_id: str | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if text:
|
||||
validate_string(text, field_name="Text", max_length=MAX_PROMPT_LENGTH_LIP_SYNC)
|
||||
@ -518,7 +571,9 @@ async def execute_lipsync(
|
||||
|
||||
# Upload the audio file to Comfy API and get download URL
|
||||
if audio:
|
||||
audio_url = await upload_audio_to_comfyapi(cls, audio)
|
||||
audio_url = await upload_audio_to_comfyapi(
|
||||
cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3"
|
||||
)
|
||||
logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
|
||||
else:
|
||||
audio_url = None
|
||||
@ -738,6 +793,474 @@ class KlingTextToVideoNode(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class OmniProTextToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProTextToVideoNode",
|
||||
display_name="Kling Omni Text to Video (Pro)",
|
||||
category="api node/video/Kling",
|
||||
description="Use text prompts to generate videos with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-video-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Combo.Input("duration", options=[5, 10]),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProText2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
|
||||
class OmniProFirstLastFrameNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProFirstLastFrameNode",
|
||||
display_name="Kling Omni First-Last-Frame to Video (Pro)",
|
||||
category="api node/video/Kling",
|
||||
description="Use a start frame, an optional end frame, or reference images with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-video-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Combo.Input("duration", options=["5", "10"]),
|
||||
IO.Image.Input("first_frame"),
|
||||
IO.Image.Input(
|
||||
"end_frame",
|
||||
optional=True,
|
||||
tooltip="An optional end frame for the video. "
|
||||
"This cannot be used simultaneously with 'reference_images'.",
|
||||
),
|
||||
IO.Image.Input(
|
||||
"reference_images",
|
||||
optional=True,
|
||||
tooltip="Up to 6 additional reference images.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
first_frame: Input.Image,
|
||||
end_frame: Input.Image | None = None,
|
||||
reference_images: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
if end_frame is not None and reference_images is not None:
|
||||
raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.")
|
||||
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1))
|
||||
image_list: list[OmniParamImage] = [
|
||||
OmniParamImage(
|
||||
image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0],
|
||||
type="first_frame",
|
||||
)
|
||||
]
|
||||
if end_frame is not None:
|
||||
validate_image_dimensions(end_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1))
|
||||
image_list.append(
|
||||
OmniParamImage(
|
||||
image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0],
|
||||
type="end_frame",
|
||||
)
|
||||
)
|
||||
if reference_images is not None:
|
||||
if get_number_of_images(reference_images) > 6:
|
||||
raise ValueError("The maximum number of reference images allowed is 6.")
|
||||
for i in reference_images:
|
||||
validate_image_dimensions(i, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference frame(s)"):
|
||||
image_list.append(OmniParamImage(image_url=i))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProFirstLastFrameRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
|
||||
class OmniProImageToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProImageToVideoNode",
|
||||
display_name="Kling Omni Image to Video (Pro)",
|
||||
category="api node/video/Kling",
|
||||
description="Use up to 7 reference images to generate a video with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-video-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
|
||||
IO.Image.Input(
|
||||
"reference_images",
|
||||
tooltip="Up to 7 reference images.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
reference_images: Input.Image,
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
if get_number_of_images(reference_images) > 7:
|
||||
raise ValueError("The maximum number of reference images is 7.")
|
||||
for i in reference_images:
|
||||
validate_image_dimensions(i, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||
image_list: list[OmniParamImage] = []
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
|
||||
image_list.append(OmniParamImage(image_url=i))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
image_list=image_list,
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
|
||||
class OmniProVideoToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProVideoToVideoNode",
|
||||
display_name="Kling Omni Video to Video (Pro)",
|
||||
category="api node/video/Kling",
|
||||
description="Use a video and up to 4 reference images to generate a video with the latest Kling model.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-video-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider),
|
||||
IO.Video.Input("reference_video", tooltip="Video to use as a reference."),
|
||||
IO.Boolean.Input("keep_original_sound", default=True),
|
||||
IO.Image.Input(
|
||||
"reference_images",
|
||||
tooltip="Up to 4 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
reference_video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
reference_images: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05)
|
||||
validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160)
|
||||
image_list: list[OmniParamImage] = []
|
||||
if reference_images is not None:
|
||||
if get_number_of_images(reference_images) > 4:
|
||||
raise ValueError("The maximum number of reference images allowed with a video input is 4.")
|
||||
for i in reference_images:
|
||||
validate_image_dimensions(i, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
|
||||
image_list.append(OmniParamImage(image_url=i))
|
||||
video_list = [
|
||||
OmniParamVideo(
|
||||
video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"),
|
||||
refer_type="feature",
|
||||
keep_original_sound="yes" if keep_original_sound else "no",
|
||||
)
|
||||
]
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
image_list=image_list if image_list else None,
|
||||
video_list=video_list,
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
|
||||
class OmniProEditVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProEditVideoNode",
|
||||
display_name="Kling Omni Edit Video (Pro)",
|
||||
category="api node/video/Kling",
|
||||
description="Edit an existing video with the latest model from Kling.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-video-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the video content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Video.Input("video", tooltip="Video for editing. The output video length will be the same."),
|
||||
IO.Boolean.Input("keep_original_sound", default=True),
|
||||
IO.Image.Input(
|
||||
"reference_images",
|
||||
tooltip="Up to 4 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
keep_original_sound: bool,
|
||||
reference_images: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
validate_video_duration(video, min_duration=3.0, max_duration=10.05)
|
||||
validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160)
|
||||
image_list: list[OmniParamImage] = []
|
||||
if reference_images is not None:
|
||||
if get_number_of_images(reference_images) > 4:
|
||||
raise ValueError("The maximum number of reference images allowed with a video input is 4.")
|
||||
for i in reference_images:
|
||||
validate_image_dimensions(i, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
|
||||
image_list.append(OmniParamImage(image_url=i))
|
||||
video_list = [
|
||||
OmniParamVideo(
|
||||
video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"),
|
||||
refer_type="base",
|
||||
keep_original_sound="yes" if keep_original_sound else "no",
|
||||
)
|
||||
]
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProReferences2VideoRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
aspect_ratio=None,
|
||||
duration=None,
|
||||
image_list=image_list if image_list else None,
|
||||
video_list=video_list,
|
||||
),
|
||||
)
|
||||
return await finish_omni_video_task(cls, response)
|
||||
|
||||
|
||||
class OmniProImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingOmniProImageNode",
|
||||
display_name="Kling Omni Image (Pro)",
|
||||
category="api node/image/Kling",
|
||||
description="Create or edit images with the latest model from Kling.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-image-o1"]),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
tooltip="A text prompt describing the image content. "
|
||||
"This can include both positive and negative descriptions.",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["1K", "2K"]),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"],
|
||||
),
|
||||
IO.Image.Input(
|
||||
"reference_images",
|
||||
tooltip="Up to 10 additional reference images.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
reference_images: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
prompt = normalize_omni_prompt_references(prompt)
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
image_list: list[OmniImageParamImage] = []
|
||||
if reference_images is not None:
|
||||
if get_number_of_images(reference_images) > 10:
|
||||
raise ValueError("The maximum number of reference images is 10.")
|
||||
for i in reference_images:
|
||||
validate_image_dimensions(i, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1))
|
||||
for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"):
|
||||
image_list.append(OmniImageParamImage(image=i))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=OmniProImageRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
resolution=resolution.lower(),
|
||||
aspect_ratio=aspect_ratio,
|
||||
image_list=image_list if image_list else None,
|
||||
),
|
||||
)
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url))
|
||||
|
||||
|
||||
class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||
"""
|
||||
Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera.
|
||||
@ -785,7 +1308,7 @@ class KlingCameraControlT2VNode(IO.ComfyNode):
|
||||
negative_prompt: str,
|
||||
cfg_scale: float,
|
||||
aspect_ratio: str,
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
camera_control: KlingCameraControl | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_text2video(
|
||||
cls,
|
||||
@ -807,9 +1330,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingImage2VideoNode",
|
||||
display_name="Kling Image to Video",
|
||||
display_name="Kling Image(First Frame) to Video",
|
||||
category="api node/video/Kling",
|
||||
description="Kling Image to Video Node",
|
||||
inputs=[
|
||||
IO.Image.Input("start_frame", tooltip="The reference image used to generate the video."),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
@ -852,8 +1374,8 @@ class KlingImage2VideoNode(IO.ComfyNode):
|
||||
mode: str,
|
||||
aspect_ratio: str,
|
||||
duration: str,
|
||||
camera_control: Optional[KlingCameraControl] = None,
|
||||
end_frame: Optional[torch.Tensor] = None,
|
||||
camera_control: KlingCameraControl | None = None,
|
||||
end_frame: torch.Tensor | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_image2video(
|
||||
cls,
|
||||
@ -963,15 +1485,11 @@ class KlingStartEndFrameNode(IO.ComfyNode):
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"),
|
||||
IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"),
|
||||
IO.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=[i.value for i in KlingVideoGenAspectRatio],
|
||||
default="16:9",
|
||||
),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Combo.Input(
|
||||
"mode",
|
||||
options=modes,
|
||||
default=modes[2],
|
||||
default=modes[8],
|
||||
tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.",
|
||||
),
|
||||
],
|
||||
@ -1168,7 +1686,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode):
|
||||
category="api node/video/Kling",
|
||||
description="Achieve different special effects when generating a video based on the effect_scene.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"),
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"effect_scene",
|
||||
options=[i.value for i in KlingSingleImageEffectsScene],
|
||||
@ -1252,8 +1773,8 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
audio: AudioInput,
|
||||
video: Input.Video,
|
||||
audio: Input.Audio,
|
||||
voice_language: str,
|
||||
) -> IO.NodeOutput:
|
||||
return await execute_lipsync(
|
||||
@ -1312,7 +1833,7 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: VideoInput,
|
||||
video: Input.Video,
|
||||
text: str,
|
||||
voice: str,
|
||||
voice_speed: float,
|
||||
@ -1469,7 +1990,7 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
human_fidelity: float,
|
||||
n: int,
|
||||
aspect_ratio: KlingImageGenAspectRatio,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
image: torch.Tensor | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
|
||||
validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN)
|
||||
@ -1514,6 +2035,136 @@ class KlingImageGenerationNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(await image_result_to_node_output(images))
|
||||
|
||||
|
||||
class TextToVideoWithAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingTextToVideoWithAudio",
|
||||
display_name="Kling Text to Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
|
||||
IO.Combo.Input("mode", options=["pro"]),
|
||||
IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
|
||||
IO.Combo.Input("duration", options=[5, 10]),
|
||||
IO.Boolean.Input("generate_audio", default=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
prompt: str,
|
||||
mode: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
generate_audio: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/text2video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=TextToVideoWithAudioRequest(
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
mode=mode,
|
||||
aspect_ratio=aspect_ratio,
|
||||
duration=str(duration),
|
||||
sound="on" if generate_audio else "off",
|
||||
),
|
||||
)
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/text2video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
class ImageToVideoWithAudio(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls) -> IO.Schema:
|
||||
return IO.Schema(
|
||||
node_id="KlingImageToVideoWithAudio",
|
||||
display_name="Kling Image(First Frame) to Video with Audio",
|
||||
category="api node/video/Kling",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v2-6"]),
|
||||
IO.Image.Input("start_frame"),
|
||||
IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt."),
|
||||
IO.Combo.Input("mode", options=["pro"]),
|
||||
IO.Combo.Input("duration", options=[5, 10]),
|
||||
IO.Boolean.Input("generate_audio", default=True),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model_name: str,
|
||||
start_frame: Input.Image,
|
||||
prompt: str,
|
||||
mode: str,
|
||||
duration: int,
|
||||
generate_audio: bool,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2500)
|
||||
validate_image_dimensions(start_frame, min_width=300, min_height=300)
|
||||
validate_image_aspect_ratio(start_frame, (1, 2.5), (2.5, 1))
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/kling/v1/videos/image2video", method="POST"),
|
||||
response_model=TaskStatusResponse,
|
||||
data=ImageToVideoWithAudioRequest(
|
||||
model_name=model_name,
|
||||
image=(await upload_images_to_comfyapi(cls, start_frame))[0],
|
||||
prompt=prompt,
|
||||
mode=mode,
|
||||
duration=str(duration),
|
||||
sound="on" if generate_audio else "off",
|
||||
),
|
||||
)
|
||||
if response.code:
|
||||
raise RuntimeError(
|
||||
f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/kling/v1/videos/image2video/{response.data.task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: (r.data.task_status if r.data else None),
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url))
|
||||
|
||||
|
||||
class KlingExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -1531,6 +2182,14 @@ class KlingExtension(ComfyExtension):
|
||||
KlingImageGenerationNode,
|
||||
KlingSingleImageVideoEffectNode,
|
||||
KlingDualCharacterVideoEffectNode,
|
||||
OmniProTextToVideoNode,
|
||||
OmniProFirstLastFrameNode,
|
||||
OmniProImageToVideoNode,
|
||||
OmniProVideoToVideoNode,
|
||||
OmniProEditVideoNode,
|
||||
OmniProImageNode,
|
||||
TextToVideoWithAudio,
|
||||
ImageToVideoWithAudio,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -1,12 +1,9 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
get_number_of_images,
|
||||
@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
|
||||
model: str = Field(...)
|
||||
duration: int = Field(...)
|
||||
resolution: str = Field(...)
|
||||
fps: Optional[int] = Field(25)
|
||||
generate_audio: Optional[bool] = Field(True)
|
||||
image_uri: Optional[str] = Field(None)
|
||||
fps: int | None = Field(25)
|
||||
generate_audio: bool | None = Field(True)
|
||||
image_uri: str | None = Field(None)
|
||||
|
||||
|
||||
class TextToVideoNode(IO.ComfyNode):
|
||||
@ -103,7 +100,7 @@ class TextToVideoNode(IO.ComfyNode):
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class ImageToVideoNode(IO.ComfyNode):
|
||||
@ -153,7 +150,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
model: str,
|
||||
prompt: str,
|
||||
duration: int,
|
||||
@ -183,7 +180,7 @@ class ImageToVideoNode(IO.ComfyNode):
|
||||
as_binary=True,
|
||||
max_retries=1,
|
||||
)
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))
|
||||
|
||||
|
||||
class LtxvApiExtension(ComfyExtension):
|
||||
|
||||
@ -1,11 +1,8 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input import VideoInput
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis import (
|
||||
MoonvalleyPromptResponse,
|
||||
MoonvalleyTextToVideoInferenceParams,
|
||||
@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
|
||||
def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
|
||||
"""
|
||||
Validates and processes video input for Moonvalley Video-to-Video generation.
|
||||
|
||||
@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
|
||||
return _validate_and_trim_duration(video)
|
||||
|
||||
|
||||
def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
|
||||
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
|
||||
"""Extracts video dimensions with error handling."""
|
||||
try:
|
||||
return video.get_dimensions()
|
||||
@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
|
||||
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")
|
||||
|
||||
|
||||
def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
|
||||
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
|
||||
"""Validates video duration and trims to 5 seconds if needed."""
|
||||
duration = video.get_duration()
|
||||
_validate_minimum_duration(duration)
|
||||
@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
|
||||
raise ValueError("Input video must be at least 5 seconds long.")
|
||||
|
||||
|
||||
def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
|
||||
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
|
||||
"""Trims video to 5 seconds if longer."""
|
||||
if duration > 5:
|
||||
return trim_video(video, 5)
|
||||
@ -241,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode):
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: torch.Tensor,
|
||||
image: Input.Image,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
resolution: str,
|
||||
@ -362,9 +359,9 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
seed: int,
|
||||
video: Optional[VideoInput] = None,
|
||||
video: Input.Video | None = None,
|
||||
control_type: str = "Motion Transfer",
|
||||
motion_intensity: Optional[int] = 100,
|
||||
motion_intensity: int | None = 100,
|
||||
steps=33,
|
||||
prompt_adherence=4.5,
|
||||
) -> IO.NodeOutput:
|
||||
|
||||
@ -1,15 +1,10 @@
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from inspect import cleandoc
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from server import PromptServer
|
||||
import folder_paths
|
||||
import base64
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
@ -587,11 +582,11 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
def create_input_message_contents(
|
||||
cls,
|
||||
prompt: str,
|
||||
image: Optional[torch.Tensor] = None,
|
||||
files: Optional[list[InputFileContent]] = None,
|
||||
image: torch.Tensor | None = None,
|
||||
files: list[InputFileContent] | None = None,
|
||||
) -> InputMessageContentList:
|
||||
"""Create a list of input message contents from prompt and optional image."""
|
||||
content_list: list[Union[InputContent, InputTextContent, InputImageContent, InputFileContent]] = [
|
||||
content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [
|
||||
InputTextContent(text=prompt, type="input_text"),
|
||||
]
|
||||
if image is not None:
|
||||
@ -617,9 +612,9 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
prompt: str,
|
||||
persist_context: bool = False,
|
||||
model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value,
|
||||
images: Optional[torch.Tensor] = None,
|
||||
files: Optional[list[InputFileContent]] = None,
|
||||
advanced_options: Optional[CreateModelResponseProperties] = None,
|
||||
images: torch.Tensor | None = None,
|
||||
files: list[InputFileContent] | None = None,
|
||||
advanced_options: CreateModelResponseProperties | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=False)
|
||||
|
||||
@ -660,30 +655,7 @@ class OpenAIChatNode(IO.ComfyNode):
|
||||
status_extractor=lambda response: response.status,
|
||||
completed_statuses=["incomplete", "completed"]
|
||||
)
|
||||
output_text = cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))
|
||||
|
||||
# Update history
|
||||
render_spec = {
|
||||
"node_id": cls.hidden.unique_id,
|
||||
"component": "ChatHistoryWidget",
|
||||
"props": {
|
||||
"history": json.dumps(
|
||||
[
|
||||
{
|
||||
"prompt": prompt,
|
||||
"response": output_text,
|
||||
"response_id": str(uuid.uuid4()),
|
||||
"timestamp": time.time(),
|
||||
}
|
||||
]
|
||||
),
|
||||
},
|
||||
}
|
||||
PromptServer.instance.send_sync(
|
||||
"display_component",
|
||||
render_spec,
|
||||
)
|
||||
return IO.NodeOutput(output_text)
|
||||
return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response)))
|
||||
|
||||
|
||||
class OpenAIInputFiles(IO.ComfyNode):
|
||||
@ -790,8 +762,8 @@ class OpenAIChatConfig(IO.ComfyNode):
|
||||
def execute(
|
||||
cls,
|
||||
truncation: bool,
|
||||
instructions: Optional[str] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
instructions: str | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
"""
|
||||
Configure advanced options for the OpenAI Chat Node.
|
||||
|
||||
@ -92,6 +92,7 @@ class PikaImageToVideo(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -152,6 +153,7 @@ class PikaTextToVideoNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -239,6 +241,7 @@ class PikaScenes(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -323,6 +326,7 @@ class PikAdditionsNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -399,6 +403,7 @@ class PikaSwapsNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -466,6 +471,7 @@ class PikaffectsNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -515,6 +521,7 @@ class PikaStartEndFrameNode(IO.ComfyNode):
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@ -11,12 +11,11 @@ User Guides:
|
||||
|
||||
"""
|
||||
|
||||
from typing import Union, Optional
|
||||
from typing_extensions import override
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis import (
|
||||
RunwayImageToVideoRequest,
|
||||
RunwayImageToVideoResponse,
|
||||
@ -44,8 +43,6 @@ from comfy_api_nodes.util import (
|
||||
sync_op,
|
||||
poll_op,
|
||||
)
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import ComfyExtension, IO
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
|
||||
field_1280_768 = "1280:768"
|
||||
|
||||
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the video URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
@ -89,13 +86,13 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> Union[float, None]:
|
||||
) -> float | None:
|
||||
if hasattr(response, "progress") and response.progress is not None:
|
||||
return response.progress * 100
|
||||
return None
|
||||
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
return response.output[0]
|
||||
@ -103,7 +100,7 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
|
||||
|
||||
|
||||
async def get_response(
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_op(
|
||||
@ -119,8 +116,8 @@ async def get_response(
|
||||
async def generate_video(
|
||||
cls: type[IO.ComfyNode],
|
||||
request: RunwayImageToVideoRequest,
|
||||
estimated_duration: Optional[int] = None,
|
||||
) -> VideoFromFile:
|
||||
estimated_duration: int | None = None,
|
||||
) -> InputImpl.VideoFromFile:
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
|
||||
@ -193,7 +190,7 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -283,7 +280,7 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -381,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
start_frame: torch.Tensor,
|
||||
end_frame: torch.Tensor,
|
||||
start_frame: Input.Image,
|
||||
end_frame: Input.Image,
|
||||
duration: str,
|
||||
ratio: str,
|
||||
seed: int,
|
||||
@ -467,7 +464,7 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
cls,
|
||||
prompt: str,
|
||||
ratio: str,
|
||||
reference_image: Optional[torch.Tensor] = None,
|
||||
reference_image: Input.Image | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1)
|
||||
|
||||
|
||||
418
comfy_api_nodes/nodes_topaz.py
Normal file
418
comfy_api_nodes/nodes_topaz.py
Normal file
@ -0,0 +1,418 @@
|
||||
import builtins
|
||||
from io import BytesIO
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis import topaz_api
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
download_url_to_image_tensor,
|
||||
download_url_to_video_output,
|
||||
get_fs_object_size,
|
||||
get_number_of_images,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_images_to_comfyapi,
|
||||
validate_container_format_is_mp4,
|
||||
)
|
||||
|
||||
UPSCALER_MODELS_MAP = {
|
||||
"Starlight (Astra) Fast": "slf-1",
|
||||
"Starlight (Astra) Creative": "slc-1",
|
||||
}
|
||||
UPSCALER_VALUES_MAP = {
|
||||
"FullHD (1080p)": 1920,
|
||||
"4K (2160p)": 3840,
|
||||
}
|
||||
|
||||
|
||||
class TopazImageEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazImageEnhance",
|
||||
display_name="Topaz Image Enhance",
|
||||
category="api node/image/Topaz",
|
||||
description="Industry-standard upscaling and image enhancement.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model", options=["Reimagine"]),
|
||||
IO.Image.Input("image"),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Optional text prompt for creative upscaling guidance.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"subject_detection",
|
||||
options=["All", "Foreground", "Background"],
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"face_enhancement",
|
||||
default=True,
|
||||
optional=True,
|
||||
tooltip="Enhance faces (if present) during processing.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"face_enhancement_creativity",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Set the creativity level for face enhancement.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"face_enhancement_strength",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Controls how sharp enhanced faces are relative to the background.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"crop_to_fill",
|
||||
default=False,
|
||||
optional=True,
|
||||
tooltip="By default, the image is letterboxed when the output aspect ratio differs. "
|
||||
"Enable to crop the image to fill the output dimensions.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"output_width",
|
||||
default=0,
|
||||
min=0,
|
||||
max=32000,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Zero value means to calculate automatically (usually it will be original size or output_height if specified).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"output_height",
|
||||
default=0,
|
||||
min=0,
|
||||
max=32000,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
optional=True,
|
||||
tooltip="Zero value means to output in the same height as original or output width.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"creativity",
|
||||
default=3,
|
||||
min=1,
|
||||
max=9,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"face_preservation",
|
||||
default=True,
|
||||
optional=True,
|
||||
tooltip="Preserve subjects' facial identity.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"color_preservation",
|
||||
default=True,
|
||||
optional=True,
|
||||
tooltip="Preserve the original colors.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
model: str,
|
||||
image: torch.Tensor,
|
||||
prompt: str = "",
|
||||
subject_detection: str = "All",
|
||||
face_enhancement: bool = True,
|
||||
face_enhancement_creativity: float = 1.0,
|
||||
face_enhancement_strength: float = 0.8,
|
||||
crop_to_fill: bool = False,
|
||||
output_width: int = 0,
|
||||
output_height: int = 0,
|
||||
creativity: int = 3,
|
||||
face_preservation: bool = True,
|
||||
color_preservation: bool = True,
|
||||
) -> IO.NodeOutput:
|
||||
if get_number_of_images(image) != 1:
|
||||
raise ValueError("Only one input image is supported.")
|
||||
download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png")
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"),
|
||||
response_model=topaz_api.ImageAsyncTaskResponse,
|
||||
data=topaz_api.ImageEnhanceRequest(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
subject_detection=subject_detection,
|
||||
face_enhancement=face_enhancement,
|
||||
face_enhancement_creativity=face_enhancement_creativity,
|
||||
face_enhancement_strength=face_enhancement_strength,
|
||||
crop_to_fill=crop_to_fill,
|
||||
output_width=output_width if output_width else None,
|
||||
output_height=output_height if output_height else None,
|
||||
creativity=creativity,
|
||||
face_preservation=str(face_preservation).lower(),
|
||||
color_preservation=str(color_preservation).lower(),
|
||||
source_url=download_url[0],
|
||||
output_format="png",
|
||||
),
|
||||
content_type="multipart/form-data",
|
||||
)
|
||||
|
||||
await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"),
|
||||
response_model=topaz_api.ImageStatusResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: x.credits * 0.08,
|
||||
poll_interval=8.0,
|
||||
max_poll_attempts=160,
|
||||
estimated_duration=60,
|
||||
)
|
||||
|
||||
results = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"),
|
||||
response_model=topaz_api.ImageDownloadResponse,
|
||||
monitor_progress=False,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(results.download_url))
|
||||
|
||||
|
||||
class TopazVideoEnhance(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TopazVideoEnhance",
|
||||
display_name="Topaz Video Enhance",
|
||||
category="api node/video/Topaz",
|
||||
description="Breathe new life into video with powerful upscaling and recovery technology.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Boolean.Input("upscaler_enabled", default=True),
|
||||
IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())),
|
||||
IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())),
|
||||
IO.Combo.Input(
|
||||
"upscaler_creativity",
|
||||
options=["low", "middle", "high"],
|
||||
default="low",
|
||||
tooltip="Creativity level (applies only to Starlight (Astra) Creative).",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input("interpolation_enabled", default=False, optional=True),
|
||||
IO.Combo.Input("interpolation_model", options=["apo-8"], default="apo-8", optional=True),
|
||||
IO.Int.Input(
|
||||
"interpolation_slowmo",
|
||||
default=1,
|
||||
min=1,
|
||||
max=16,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Slow-motion factor applied to the input video. "
|
||||
"For example, 2 makes the output twice as slow and doubles the duration.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"interpolation_frame_rate",
|
||||
default=60,
|
||||
min=15,
|
||||
max=240,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Output frame rate.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"interpolation_duplicate",
|
||||
default=False,
|
||||
tooltip="Analyze the input for duplicate frames and remove them.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"interpolation_duplicate_threshold",
|
||||
default=0.01,
|
||||
min=0.001,
|
||||
max=0.1,
|
||||
step=0.001,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Detection sensitivity for duplicate frames.",
|
||||
optional=True,
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"dynamic_compression_level",
|
||||
options=["Low", "Mid", "High"],
|
||||
default="Low",
|
||||
tooltip="CQP level.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
upscaler_enabled: bool,
|
||||
upscaler_model: str,
|
||||
upscaler_resolution: str,
|
||||
upscaler_creativity: str = "low",
|
||||
interpolation_enabled: bool = False,
|
||||
interpolation_model: str = "apo-8",
|
||||
interpolation_slowmo: int = 1,
|
||||
interpolation_frame_rate: int = 60,
|
||||
interpolation_duplicate: bool = False,
|
||||
interpolation_duplicate_threshold: float = 0.01,
|
||||
dynamic_compression_level: str = "Low",
|
||||
) -> IO.NodeOutput:
|
||||
if upscaler_enabled is False and interpolation_enabled is False:
|
||||
raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.")
|
||||
validate_container_format_is_mp4(video)
|
||||
src_width, src_height = video.get_dimensions()
|
||||
src_frame_rate = int(video.get_frame_rate())
|
||||
duration_sec = video.get_duration()
|
||||
src_video_stream = video.get_stream_source()
|
||||
target_width = src_width
|
||||
target_height = src_height
|
||||
target_frame_rate = src_frame_rate
|
||||
filters = []
|
||||
if upscaler_enabled:
|
||||
target_width = UPSCALER_VALUES_MAP[upscaler_resolution]
|
||||
target_height = UPSCALER_VALUES_MAP[upscaler_resolution]
|
||||
filters.append(
|
||||
topaz_api.VideoEnhancementFilter(
|
||||
model=UPSCALER_MODELS_MAP[upscaler_model],
|
||||
creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
|
||||
isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None),
|
||||
),
|
||||
)
|
||||
if interpolation_enabled:
|
||||
target_frame_rate = interpolation_frame_rate
|
||||
filters.append(
|
||||
topaz_api.VideoFrameInterpolationFilter(
|
||||
model=interpolation_model,
|
||||
slowmo=interpolation_slowmo,
|
||||
fps=interpolation_frame_rate,
|
||||
duplicate=interpolation_duplicate,
|
||||
duplicate_threshold=interpolation_duplicate_threshold,
|
||||
),
|
||||
)
|
||||
initial_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/topaz/video/", method="POST"),
|
||||
response_model=topaz_api.CreateVideoResponse,
|
||||
data=topaz_api.CreateVideoRequest(
|
||||
source=topaz_api.CreateCreateVideoRequestSource(
|
||||
container="mp4",
|
||||
size=get_fs_object_size(src_video_stream),
|
||||
duration=int(duration_sec),
|
||||
frameCount=video.get_frame_count(),
|
||||
frameRate=src_frame_rate,
|
||||
resolution=topaz_api.Resolution(width=src_width, height=src_height),
|
||||
),
|
||||
filters=filters,
|
||||
output=topaz_api.OutputInformationVideo(
|
||||
resolution=topaz_api.Resolution(width=target_width, height=target_height),
|
||||
frameRate=target_frame_rate,
|
||||
audioCodec="AAC",
|
||||
audioTransfer="Copy",
|
||||
dynamicCompressionLevel=dynamic_compression_level,
|
||||
),
|
||||
),
|
||||
wait_label="Creating task",
|
||||
final_label_on_success="Task created",
|
||||
)
|
||||
upload_res = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/accept",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=topaz_api.VideoAcceptResponse,
|
||||
wait_label="Preparing upload",
|
||||
final_label_on_success="Upload started",
|
||||
)
|
||||
if len(upload_res.urls) > 1:
|
||||
raise NotImplementedError(
|
||||
"Large files are not currently supported. Please open an issue in the ComfyUI repository."
|
||||
)
|
||||
async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session:
|
||||
if isinstance(src_video_stream, BytesIO):
|
||||
src_video_stream.seek(0)
|
||||
async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
else:
|
||||
with builtins.open(src_video_stream, "rb") as video_file:
|
||||
async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res:
|
||||
upload_etag = res.headers["Etag"]
|
||||
await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload",
|
||||
method="PATCH",
|
||||
),
|
||||
response_model=topaz_api.VideoCompleteUploadResponse,
|
||||
data=topaz_api.VideoCompleteUploadRequest(
|
||||
uploadResults=[
|
||||
topaz_api.VideoCompleteUploadRequestPart(
|
||||
partNum=1,
|
||||
eTag=upload_etag,
|
||||
),
|
||||
],
|
||||
),
|
||||
wait_label="Finalizing upload",
|
||||
final_label_on_success="Upload completed",
|
||||
)
|
||||
final_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"),
|
||||
response_model=topaz_api.VideoStatusResponse,
|
||||
status_extractor=lambda x: x.status,
|
||||
progress_extractor=lambda x: getattr(x, "progress", 0),
|
||||
price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None),
|
||||
poll_interval=10.0,
|
||||
max_poll_attempts=320,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(final_response.download.url))
|
||||
|
||||
|
||||
class TopazExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
TopazImageEnhance,
|
||||
TopazVideoEnhance,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> TopazExtension:
|
||||
return TopazExtension()
|
||||
@ -3,13 +3,15 @@ from io import BytesIO
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.input_impl.video_types import VideoFromFile
|
||||
from comfy_api.latest import IO, ComfyExtension
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
|
||||
from comfy_api_nodes.apis.veo_api import (
|
||||
VeoGenVidPollRequest,
|
||||
VeoGenVidPollResponse,
|
||||
VeoGenVidRequest,
|
||||
VeoGenVidResponse,
|
||||
VeoRequestInstance,
|
||||
VeoRequestInstanceImage,
|
||||
VeoRequestParameters,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
@ -228,7 +230,7 @@ class VeoVideoGenerationNode(IO.ComfyNode):
|
||||
|
||||
# Check if video is provided as base64 or URL
|
||||
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
|
||||
if hasattr(video, "gcsUri") and video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
@ -346,12 +348,163 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
|
||||
)
|
||||
|
||||
|
||||
class Veo3FirstLastFrameNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="Veo3FirstLastFrameNode",
|
||||
display_name="Google Veo 3 First-Last-Frame to Video",
|
||||
category="api node/video/Veo",
|
||||
description="Generate video using prompt and first and last frames.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the video",
|
||||
),
|
||||
IO.String.Input(
|
||||
"negative_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Negative text prompt to guide what to avoid in the video",
|
||||
),
|
||||
IO.Combo.Input("resolution", options=["720p", "1080p"]),
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=["16:9", "9:16"],
|
||||
default="16:9",
|
||||
tooltip="Aspect ratio of the output video",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"duration",
|
||||
default=8,
|
||||
min=4,
|
||||
max=8,
|
||||
step=2,
|
||||
display_mode=IO.NumberDisplay.slider,
|
||||
tooltip="Duration of the output video in seconds",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=0xFFFFFFFF,
|
||||
step=1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for video generation",
|
||||
),
|
||||
IO.Image.Input("first_frame", tooltip="Start frame"),
|
||||
IO.Image.Input("last_frame", tooltip="End frame"),
|
||||
IO.Combo.Input(
|
||||
"model",
|
||||
options=["veo-3.1-generate", "veo-3.1-fast-generate"],
|
||||
default="veo-3.1-fast-generate",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"generate_audio",
|
||||
default=True,
|
||||
tooltip="Generate audio for the video.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
resolution: str,
|
||||
aspect_ratio: str,
|
||||
duration: int,
|
||||
seed: int,
|
||||
first_frame: Input.Image,
|
||||
last_frame: Input.Image,
|
||||
model: str,
|
||||
generate_audio: bool,
|
||||
):
|
||||
model = MODELS_MAP[model]
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
|
||||
response_model=VeoGenVidResponse,
|
||||
data=VeoGenVidRequest(
|
||||
instances=[
|
||||
VeoRequestInstance(
|
||||
prompt=prompt,
|
||||
image=VeoRequestInstanceImage(
|
||||
bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png"
|
||||
),
|
||||
lastFrame=VeoRequestInstanceImage(
|
||||
bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png"
|
||||
),
|
||||
),
|
||||
],
|
||||
parameters=VeoRequestParameters(
|
||||
aspectRatio=aspect_ratio,
|
||||
personGeneration="ALLOW",
|
||||
durationSeconds=duration,
|
||||
enhancePrompt=True, # cannot be False for Veo3
|
||||
seed=seed,
|
||||
generateAudio=generate_audio,
|
||||
negativePrompt=negative_prompt,
|
||||
resolution=resolution,
|
||||
),
|
||||
),
|
||||
)
|
||||
poll_response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
|
||||
response_model=VeoGenVidPollResponse,
|
||||
status_extractor=lambda r: "completed" if r.done else "pending",
|
||||
data=VeoGenVidPollRequest(
|
||||
operationName=initial_response.name,
|
||||
),
|
||||
poll_interval=5.0,
|
||||
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
|
||||
)
|
||||
|
||||
if poll_response.error:
|
||||
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
|
||||
|
||||
response = poll_response.response
|
||||
filtered_count = response.raiMediaFilteredCount
|
||||
if filtered_count:
|
||||
reasons = response.raiMediaFilteredReasons or []
|
||||
reason_part = f": {reasons[0]}" if reasons else ""
|
||||
raise Exception(
|
||||
f"Content blocked by Google's Responsible AI filters{reason_part} "
|
||||
f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)."
|
||||
)
|
||||
|
||||
if response.videos:
|
||||
video = response.videos[0]
|
||||
if video.bytesBase64Encoded:
|
||||
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
|
||||
if video.gcsUri:
|
||||
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
|
||||
raise Exception("Video returned but no data or URL was provided")
|
||||
raise Exception("Video generation completed but no video was returned")
|
||||
|
||||
|
||||
class VeoExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
VeoVideoGenerationNode,
|
||||
Veo3VideoGenerationNode,
|
||||
Veo3FirstLastFrameNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -36,6 +36,7 @@ from .upload_helpers import (
|
||||
upload_video_to_comfyapi,
|
||||
)
|
||||
from .validation_utils import (
|
||||
get_image_dimensions,
|
||||
get_number_of_images,
|
||||
validate_aspect_ratio_string,
|
||||
validate_audio_duration,
|
||||
@ -46,6 +47,7 @@ from .validation_utils import (
|
||||
validate_string,
|
||||
validate_video_dimensions,
|
||||
validate_video_duration,
|
||||
validate_video_frame_count,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@ -82,6 +84,7 @@ __all__ = [
|
||||
"trim_video",
|
||||
"video_to_base64_string",
|
||||
# Validation utilities
|
||||
"get_image_dimensions",
|
||||
"get_number_of_images",
|
||||
"validate_aspect_ratio_string",
|
||||
"validate_audio_duration",
|
||||
@ -92,6 +95,7 @@ __all__ = [
|
||||
"validate_string",
|
||||
"validate_video_dimensions",
|
||||
"validate_video_duration",
|
||||
"validate_video_frame_count",
|
||||
# Misc functions
|
||||
"get_fs_object_size",
|
||||
]
|
||||
|
||||
@ -2,8 +2,8 @@ import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from io import BytesIO
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.model_management import processing_interrupted
|
||||
@ -35,12 +35,12 @@ def default_base_url() -> str:
|
||||
|
||||
async def sleep_with_interrupt(
|
||||
seconds: float,
|
||||
node_cls: Optional[type[IO.ComfyNode]],
|
||||
label: Optional[str] = None,
|
||||
start_ts: Optional[float] = None,
|
||||
estimated_total: Optional[int] = None,
|
||||
node_cls: type[IO.ComfyNode] | None,
|
||||
label: str | None = None,
|
||||
start_ts: float | None = None,
|
||||
estimated_total: int | None = None,
|
||||
*,
|
||||
display_callback: Optional[Callable[[type[IO.ComfyNode], str, int, Optional[int]], None]] = None,
|
||||
display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None,
|
||||
):
|
||||
"""
|
||||
Sleep in 1s slices while:
|
||||
@ -65,7 +65,7 @@ def mimetype_to_extension(mime_type: str) -> str:
|
||||
return mime_type.split("/")[-1].lower()
|
||||
|
||||
|
||||
def get_fs_object_size(path_or_object: Union[str, BytesIO]) -> int:
|
||||
def get_fs_object_size(path_or_object: str | BytesIO) -> int:
|
||||
if isinstance(path_or_object, str):
|
||||
return os.path.getsize(path_or_object)
|
||||
return len(path_or_object.getvalue())
|
||||
|
||||
@ -4,10 +4,11 @@ import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from typing import Any, Callable, Iterable, Literal, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Literal, TypeVar
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
@ -37,8 +38,8 @@ class ApiEndpoint:
|
||||
path: str,
|
||||
method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET",
|
||||
*,
|
||||
query_params: Optional[dict[str, Any]] = None,
|
||||
headers: Optional[dict[str, str]] = None,
|
||||
query_params: dict[str, Any] | None = None,
|
||||
headers: dict[str, str] | None = None,
|
||||
):
|
||||
self.path = path
|
||||
self.method = method
|
||||
@ -52,17 +53,18 @@ class _RequestConfig:
|
||||
endpoint: ApiEndpoint
|
||||
timeout: float
|
||||
content_type: str
|
||||
data: Optional[dict[str, Any]]
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]]
|
||||
multipart_parser: Optional[Callable]
|
||||
data: dict[str, Any] | None
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None
|
||||
multipart_parser: Callable | None
|
||||
max_retries: int
|
||||
retry_delay: float
|
||||
retry_backoff: float
|
||||
wait_label: str = "Waiting"
|
||||
monitor_progress: bool = True
|
||||
estimated_total: Optional[int] = None
|
||||
final_label_on_success: Optional[str] = "Completed"
|
||||
progress_origin_ts: Optional[float] = None
|
||||
estimated_total: int | None = None
|
||||
final_label_on_success: str | None = "Completed"
|
||||
progress_origin_ts: float | None = None
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -70,40 +72,42 @@ class _PollUIState:
|
||||
started: float
|
||||
status_label: str = "Queued"
|
||||
is_queued: bool = True
|
||||
price: Optional[float] = None
|
||||
estimated_duration: Optional[int] = None
|
||||
price: float | None = None
|
||||
estimated_duration: int | None = None
|
||||
base_processing_elapsed: float = 0.0 # sum of completed active intervals
|
||||
active_since: Optional[float] = None # start time of current active interval (None if queued)
|
||||
active_since: float | None = None # start time of current active interval (None if queued)
|
||||
|
||||
|
||||
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
|
||||
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
|
||||
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
|
||||
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
|
||||
|
||||
|
||||
async def sync_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: Type[M],
|
||||
data: Optional[BaseModel] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
response_model: type[M],
|
||||
price_extractor: Callable[[M | Any], float | None] | None = None,
|
||||
data: BaseModel | None = None,
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None = None,
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Optional[Callable] = None,
|
||||
multipart_parser: Callable | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: Optional[int] = None,
|
||||
final_label_on_success: Optional[str] = "Completed",
|
||||
progress_origin_ts: Optional[float] = None,
|
||||
estimated_duration: int | None = None,
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
) -> M:
|
||||
raw = await sync_op_raw(
|
||||
cls,
|
||||
endpoint,
|
||||
price_extractor=_wrap_model_extractor(response_model, price_extractor),
|
||||
data=data,
|
||||
files=files,
|
||||
content_type=content_type,
|
||||
@ -128,22 +132,22 @@ async def poll_op(
|
||||
cls: type[IO.ComfyNode],
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
response_model: Type[M],
|
||||
status_extractor: Callable[[M], Optional[Union[str, int]]],
|
||||
progress_extractor: Optional[Callable[[M], Optional[int]]] = None,
|
||||
price_extractor: Optional[Callable[[M], Optional[float]]] = None,
|
||||
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||
data: Optional[BaseModel] = None,
|
||||
response_model: type[M],
|
||||
status_extractor: Callable[[M | Any], str | int | None],
|
||||
progress_extractor: Callable[[M | Any], int | None] | None = None,
|
||||
price_extractor: Callable[[M | Any], float | None] | None = None,
|
||||
completed_statuses: list[str | int] | None = None,
|
||||
failed_statuses: list[str | int] | None = None,
|
||||
queued_statuses: list[str | int] | None = None,
|
||||
data: BaseModel | None = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: Optional[int] = None,
|
||||
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
) -> M:
|
||||
raw = await poll_op_raw(
|
||||
@ -175,21 +179,22 @@ async def sync_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
endpoint: ApiEndpoint,
|
||||
*,
|
||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]] = None,
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
|
||||
data: dict[str, Any] | BaseModel | None = None,
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None = None,
|
||||
content_type: str = "application/json",
|
||||
timeout: float = 3600.0,
|
||||
multipart_parser: Optional[Callable] = None,
|
||||
multipart_parser: Callable | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: str = "Waiting for server",
|
||||
estimated_duration: Optional[int] = None,
|
||||
estimated_duration: int | None = None,
|
||||
as_binary: bool = False,
|
||||
final_label_on_success: Optional[str] = "Completed",
|
||||
progress_origin_ts: Optional[float] = None,
|
||||
final_label_on_success: str | None = "Completed",
|
||||
progress_origin_ts: float | None = None,
|
||||
monitor_progress: bool = True,
|
||||
) -> Union[dict[str, Any], bytes]:
|
||||
) -> dict[str, Any] | bytes:
|
||||
"""
|
||||
Make a single network request.
|
||||
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||
@ -216,6 +221,7 @@ async def sync_op_raw(
|
||||
estimated_total=estimated_duration,
|
||||
final_label_on_success=final_label_on_success,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
price_extractor=price_extractor,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
@ -224,21 +230,21 @@ async def poll_op_raw(
|
||||
cls: type[IO.ComfyNode],
|
||||
poll_endpoint: ApiEndpoint,
|
||||
*,
|
||||
status_extractor: Callable[[dict[str, Any]], Optional[Union[str, int]]],
|
||||
progress_extractor: Optional[Callable[[dict[str, Any]], Optional[int]]] = None,
|
||||
price_extractor: Optional[Callable[[dict[str, Any]], Optional[float]]] = None,
|
||||
completed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
failed_statuses: Optional[list[Union[str, int]]] = None,
|
||||
queued_statuses: Optional[list[Union[str, int]]] = None,
|
||||
data: Optional[Union[dict[str, Any], BaseModel]] = None,
|
||||
status_extractor: Callable[[dict[str, Any]], str | int | None],
|
||||
progress_extractor: Callable[[dict[str, Any]], int | None] | None = None,
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None,
|
||||
completed_statuses: list[str | int] | None = None,
|
||||
failed_statuses: list[str | int] | None = None,
|
||||
queued_statuses: list[str | int] | None = None,
|
||||
data: dict[str, Any] | BaseModel | None = None,
|
||||
poll_interval: float = 5.0,
|
||||
max_poll_attempts: int = 120,
|
||||
timeout_per_poll: float = 120.0,
|
||||
max_retries_per_poll: int = 3,
|
||||
retry_delay_per_poll: float = 1.0,
|
||||
retry_backoff_per_poll: float = 2.0,
|
||||
estimated_duration: Optional[int] = None,
|
||||
cancel_endpoint: Optional[ApiEndpoint] = None,
|
||||
estimated_duration: int | None = None,
|
||||
cancel_endpoint: ApiEndpoint | None = None,
|
||||
cancel_timeout: float = 10.0,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
@ -256,7 +262,7 @@ async def poll_op_raw(
|
||||
consumed_attempts = 0 # counts only non-queued polls
|
||||
|
||||
progress_bar = utils.ProgressBar(100) if progress_extractor else None
|
||||
last_progress: Optional[int] = None
|
||||
last_progress: int | None = None
|
||||
|
||||
state = _PollUIState(started=started, estimated_duration=estimated_duration)
|
||||
stop_ticker = asyncio.Event()
|
||||
@ -415,16 +421,18 @@ async def poll_op_raw(
|
||||
|
||||
def _display_text(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
text: Optional[str],
|
||||
text: str | None,
|
||||
*,
|
||||
status: Optional[Union[str, int]] = None,
|
||||
price: Optional[float] = None,
|
||||
status: str | int | None = None,
|
||||
price: float | None = None,
|
||||
) -> None:
|
||||
display_lines: list[str] = []
|
||||
if status:
|
||||
display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}")
|
||||
if price is not None:
|
||||
display_lines.append(f"Price: ${float(price):,.4f}")
|
||||
p = f"{float(price):,.4f}".rstrip("0").rstrip(".")
|
||||
if p != "0":
|
||||
display_lines.append(f"Price: ${p}")
|
||||
if text is not None:
|
||||
display_lines.append(text)
|
||||
if display_lines:
|
||||
@ -433,13 +441,13 @@ def _display_text(
|
||||
|
||||
def _display_time_progress(
|
||||
node_cls: type[IO.ComfyNode],
|
||||
status: Optional[Union[str, int]],
|
||||
status: str | int | None,
|
||||
elapsed_seconds: int,
|
||||
estimated_total: Optional[int] = None,
|
||||
estimated_total: int | None = None,
|
||||
*,
|
||||
price: Optional[float] = None,
|
||||
is_queued: Optional[bool] = None,
|
||||
processing_elapsed_seconds: Optional[int] = None,
|
||||
price: float | None = None,
|
||||
is_queued: bool | None = None,
|
||||
processing_elapsed_seconds: int | None = None,
|
||||
) -> None:
|
||||
if estimated_total is not None and estimated_total > 0 and is_queued is False:
|
||||
pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds
|
||||
@ -481,7 +489,7 @@ def _unpack_tuple(t: tuple) -> tuple[str, Any, str]:
|
||||
raise ValueError("files tuple must be (filename, file[, content_type])")
|
||||
|
||||
|
||||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: Optional[dict[str, Any]]) -> dict[str, Any]:
|
||||
def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]:
|
||||
params = dict(endpoint_params or {})
|
||||
if method.upper() == "GET" and data:
|
||||
for k, v in data.items():
|
||||
@ -527,9 +535,9 @@ def _generate_operation_id(method: str, path: str, attempt: int) -> str:
|
||||
def _snapshot_request_body_for_logging(
|
||||
content_type: str,
|
||||
method: str,
|
||||
data: Optional[dict[str, Any]],
|
||||
files: Optional[Union[dict[str, Any], list[tuple[str, Any]]]],
|
||||
) -> Optional[Union[dict[str, Any], str]]:
|
||||
data: dict[str, Any] | None,
|
||||
files: dict[str, Any] | list[tuple[str, Any]] | None,
|
||||
) -> dict[str, Any] | str | None:
|
||||
if method.upper() == "GET":
|
||||
return None
|
||||
if content_type == "multipart/form-data":
|
||||
@ -579,12 +587,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
attempt = 0
|
||||
delay = cfg.retry_delay
|
||||
operation_succeeded: bool = False
|
||||
final_elapsed_seconds: Optional[int] = None
|
||||
final_elapsed_seconds: int | None = None
|
||||
extracted_price: float | None = None
|
||||
while True:
|
||||
attempt += 1
|
||||
stop_event = asyncio.Event()
|
||||
monitor_task: Optional[asyncio.Task] = None
|
||||
sess: Optional[aiohttp.ClientSession] = None
|
||||
monitor_task: asyncio.Task | None = None
|
||||
sess: aiohttp.ClientSession | None = None
|
||||
|
||||
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
|
||||
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
|
||||
@ -767,6 +776,8 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
except json.JSONDecodeError:
|
||||
payload = {"_raw": text}
|
||||
response_content_to_log = payload if isinstance(payload, dict) else text
|
||||
with contextlib.suppress(Exception):
|
||||
extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
try:
|
||||
@ -871,13 +882,13 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
else int(time.monotonic() - start_time)
|
||||
),
|
||||
estimated_total=cfg.estimated_total,
|
||||
price=None,
|
||||
price=extracted_price,
|
||||
is_queued=False,
|
||||
processing_elapsed_seconds=final_elapsed_seconds,
|
||||
)
|
||||
|
||||
|
||||
def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
||||
def _validate_or_raise(response_model: type[M], payload: Any) -> M:
|
||||
try:
|
||||
return response_model.model_validate(payload)
|
||||
except Exception as e:
|
||||
@ -892,9 +903,9 @@ def _validate_or_raise(response_model: Type[M], payload: Any) -> M:
|
||||
|
||||
|
||||
def _wrap_model_extractor(
|
||||
response_model: Type[M],
|
||||
extractor: Optional[Callable[[M], Any]],
|
||||
) -> Optional[Callable[[dict[str, Any]], Any]]:
|
||||
response_model: type[M],
|
||||
extractor: Callable[[M], Any] | None,
|
||||
) -> Callable[[dict[str, Any]], Any] | None:
|
||||
"""Wrap a typed extractor so it can be used by the dict-based poller.
|
||||
Validates the dict into `response_model` before invoking `extractor`.
|
||||
Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating
|
||||
@ -919,10 +930,10 @@ def _wrap_model_extractor(
|
||||
return _wrapped
|
||||
|
||||
|
||||
def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Union[str, int]]:
|
||||
def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]:
|
||||
if not values:
|
||||
return set()
|
||||
out: set[Union[str, int]] = set()
|
||||
out: set[str | int] = set()
|
||||
for v in values:
|
||||
nv = _normalize_status_value(v)
|
||||
if nv is not None:
|
||||
@ -930,7 +941,7 @@ def _normalize_statuses(values: Optional[Iterable[Union[str, int]]]) -> set[Unio
|
||||
return out
|
||||
|
||||
|
||||
def _normalize_status_value(val: Union[str, int, None]) -> Union[str, int, None]:
|
||||
def _normalize_status_value(val: str | int | None) -> str | int | None:
|
||||
if isinstance(val, str):
|
||||
return val.strip().lower()
|
||||
return val
|
||||
|
||||
@ -4,7 +4,6 @@ import math
|
||||
import mimetypes
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
@ -12,8 +11,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import Input, InputImpl
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api.latest import Input, InputImpl, Types
|
||||
|
||||
from ._helpers import mimetype_to_extension
|
||||
|
||||
@ -57,7 +55,7 @@ def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> to
|
||||
|
||||
def tensor_to_bytesio(
|
||||
image: torch.Tensor,
|
||||
name: Optional[str] = None,
|
||||
name: str | None = None,
|
||||
total_pixels: int = 2048 * 2048,
|
||||
mime_type: str = "image/png",
|
||||
) -> BytesIO:
|
||||
@ -177,8 +175,8 @@ def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", co
|
||||
|
||||
def video_to_base64_string(
|
||||
video: Input.Video,
|
||||
container_format: VideoContainer = None,
|
||||
codec: VideoCodec = None
|
||||
container_format: Types.VideoContainer | None = None,
|
||||
codec: Types.VideoCodec | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Converts a video input to a base64 string.
|
||||
@ -189,12 +187,11 @@ def video_to_base64_string(
|
||||
codec: Optional codec to use (defaults to video.codec if available)
|
||||
"""
|
||||
video_bytes_io = BytesIO()
|
||||
|
||||
# Use provided format/codec if specified, otherwise use video's own if available
|
||||
format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
|
||||
codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
|
||||
|
||||
video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
|
||||
video.save_to(
|
||||
video_bytes_io,
|
||||
format=container_format or getattr(video, "container", Types.VideoContainer.MP4),
|
||||
codec=codec or getattr(video, "codec", Types.VideoCodec.H264),
|
||||
)
|
||||
video_bytes_io.seek(0)
|
||||
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
|
||||
|
||||
|
||||
@ -3,15 +3,15 @@ import contextlib
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import IO, Optional, Union
|
||||
from typing import IO
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from aiohttp.client_exceptions import ClientError, ContentTypeError
|
||||
|
||||
from comfy_api.input_impl import VideoFromFile
|
||||
from comfy_api.latest import IO as COMFY_IO
|
||||
from comfy_api.latest import InputImpl
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
@ -29,9 +29,9 @@ _RETRY_STATUS = {408, 429, 500, 502, 503, 504}
|
||||
|
||||
async def download_url_to_bytesio(
|
||||
url: str,
|
||||
dest: Optional[Union[BytesIO, IO[bytes], str, Path]],
|
||||
dest: BytesIO | IO[bytes] | str | Path | None,
|
||||
*,
|
||||
timeout: Optional[float] = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 5,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
@ -71,10 +71,10 @@ async def download_url_to_bytesio(
|
||||
|
||||
is_path_sink = isinstance(dest, (str, Path))
|
||||
fhandle = None
|
||||
session: Optional[aiohttp.ClientSession] = None
|
||||
stop_evt: Optional[asyncio.Event] = None
|
||||
monitor_task: Optional[asyncio.Task] = None
|
||||
req_task: Optional[asyncio.Task] = None
|
||||
session: aiohttp.ClientSession | None = None
|
||||
stop_evt: asyncio.Event | None = None
|
||||
monitor_task: asyncio.Task | None = None
|
||||
req_task: asyncio.Task | None = None
|
||||
|
||||
try:
|
||||
with contextlib.suppress(Exception):
|
||||
@ -234,11 +234,11 @@ async def download_url_to_video_output(
|
||||
timeout: float = None,
|
||||
max_retries: int = 5,
|
||||
cls: type[COMFY_IO.ComfyNode] = None,
|
||||
) -> VideoFromFile:
|
||||
) -> InputImpl.VideoFromFile:
|
||||
"""Downloads a video from a URL and returns a `VIDEO` output."""
|
||||
result = BytesIO()
|
||||
await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls)
|
||||
return VideoFromFile(result)
|
||||
return InputImpl.VideoFromFile(result)
|
||||
|
||||
|
||||
async def download_url_as_bytesio(
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
@ -4,15 +4,13 @@ import logging
|
||||
import time
|
||||
import uuid
|
||||
from io import BytesIO
|
||||
from typing import Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import aiohttp
|
||||
import torch
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from comfy_api.latest import IO, Input
|
||||
from comfy_api.util import VideoCodec, VideoContainer
|
||||
from comfy_api.latest import IO, Input, Types
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import is_processing_interrupted, sleep_with_interrupt
|
||||
@ -32,7 +30,7 @@ from .conversions import (
|
||||
|
||||
class UploadRequest(BaseModel):
|
||||
file_name: str = Field(..., description="Filename to upload")
|
||||
content_type: Optional[str] = Field(
|
||||
content_type: str | None = Field(
|
||||
None,
|
||||
description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
|
||||
)
|
||||
@ -48,22 +46,30 @@ async def upload_images_to_comfyapi(
|
||||
image: torch.Tensor,
|
||||
*,
|
||||
max_images: int = 8,
|
||||
mime_type: Optional[str] = None,
|
||||
wait_label: Optional[str] = "Uploading",
|
||||
mime_type: str | None = None,
|
||||
wait_label: str | None = "Uploading",
|
||||
show_batch_index: bool = True,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Uploads images to ComfyUI API and returns download URLs.
|
||||
To upload multiple images, stack them in the batch dimension first.
|
||||
"""
|
||||
# if batch, try to upload each file if max_images is greater than 0
|
||||
# if batched, try to upload each file if max_images is greater than 0
|
||||
download_urls: list[str] = []
|
||||
is_batch = len(image.shape) > 3
|
||||
batch_len = image.shape[0] if is_batch else 1
|
||||
num_to_upload = min(batch_len, max_images)
|
||||
batch_start_ts = time.monotonic()
|
||||
|
||||
for idx in range(min(batch_len, max_images)):
|
||||
for idx in range(num_to_upload):
|
||||
tensor = image[idx] if is_batch else image
|
||||
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
|
||||
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, wait_label)
|
||||
|
||||
effective_label = wait_label
|
||||
if wait_label and show_batch_index and num_to_upload > 1:
|
||||
effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})"
|
||||
|
||||
url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts)
|
||||
download_urls.append(url)
|
||||
return download_urls
|
||||
|
||||
@ -92,9 +98,10 @@ async def upload_video_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
video: Input.Video,
|
||||
*,
|
||||
container: VideoContainer = VideoContainer.MP4,
|
||||
codec: VideoCodec = VideoCodec.H264,
|
||||
max_duration: Optional[int] = None,
|
||||
container: Types.VideoContainer = Types.VideoContainer.MP4,
|
||||
codec: Types.VideoCodec = Types.VideoCodec.H264,
|
||||
max_duration: int | None = None,
|
||||
wait_label: str | None = "Uploading",
|
||||
) -> str:
|
||||
"""
|
||||
Uploads a single video to ComfyUI API and returns its download URL.
|
||||
@ -119,15 +126,16 @@ async def upload_video_to_comfyapi(
|
||||
video.save_to(video_bytes_io, format=container, codec=codec)
|
||||
video_bytes_io.seek(0)
|
||||
|
||||
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type)
|
||||
return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label)
|
||||
|
||||
|
||||
async def upload_file_to_comfyapi(
|
||||
cls: type[IO.ComfyNode],
|
||||
file_bytes_io: BytesIO,
|
||||
filename: str,
|
||||
upload_mime_type: Optional[str],
|
||||
wait_label: Optional[str] = "Uploading",
|
||||
upload_mime_type: str | None,
|
||||
wait_label: str | None = "Uploading",
|
||||
progress_origin_ts: float | None = None,
|
||||
) -> str:
|
||||
"""Uploads a single file to ComfyUI API and returns its download URL."""
|
||||
if upload_mime_type is None:
|
||||
@ -148,6 +156,7 @@ async def upload_file_to_comfyapi(
|
||||
file_bytes_io,
|
||||
content_type=upload_mime_type,
|
||||
wait_label=wait_label,
|
||||
progress_origin_ts=progress_origin_ts,
|
||||
)
|
||||
return create_resp.download_url
|
||||
|
||||
@ -155,27 +164,18 @@ async def upload_file_to_comfyapi(
|
||||
async def upload_file(
|
||||
cls: type[IO.ComfyNode],
|
||||
upload_url: str,
|
||||
file: Union[BytesIO, str],
|
||||
file: BytesIO | str,
|
||||
*,
|
||||
content_type: Optional[str] = None,
|
||||
content_type: str | None = None,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 1.0,
|
||||
retry_backoff: float = 2.0,
|
||||
wait_label: Optional[str] = None,
|
||||
wait_label: str | None = None,
|
||||
progress_origin_ts: float | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption.
|
||||
|
||||
Args:
|
||||
cls: Node class (provides auth context + UI progress hooks).
|
||||
upload_url: Pre-signed PUT URL.
|
||||
file: BytesIO or path string.
|
||||
content_type: Explicit MIME type. If None, we *suppress* Content-Type.
|
||||
max_retries: Maximum retry attempts.
|
||||
retry_delay: Initial delay in seconds.
|
||||
retry_backoff: Exponential backoff factor.
|
||||
wait_label: Progress label shown in Comfy UI.
|
||||
|
||||
Raises:
|
||||
ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception
|
||||
"""
|
||||
@ -198,7 +198,7 @@ async def upload_file(
|
||||
|
||||
attempt = 0
|
||||
delay = retry_delay
|
||||
start_ts = time.monotonic()
|
||||
start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic()
|
||||
op_uuid = uuid.uuid4().hex[:8]
|
||||
while True:
|
||||
attempt += 1
|
||||
@ -218,7 +218,7 @@ async def upload_file(
|
||||
return
|
||||
|
||||
monitor_task = asyncio.create_task(_monitor())
|
||||
sess: Optional[aiohttp.ClientSession] = None
|
||||
sess: aiohttp.ClientSession | None = None
|
||||
try:
|
||||
try:
|
||||
request_logger.log_request_response(
|
||||
|
||||
@ -1,9 +1,7 @@
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from comfy_api.input.video_types import VideoInput
|
||||
from comfy_api.latest import Input
|
||||
|
||||
|
||||
@ -18,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
|
||||
|
||||
def validate_image_dimensions(
|
||||
image: torch.Tensor,
|
||||
min_width: Optional[int] = None,
|
||||
max_width: Optional[int] = None,
|
||||
min_height: Optional[int] = None,
|
||||
max_height: Optional[int] = None,
|
||||
min_width: int | None = None,
|
||||
max_width: int | None = None,
|
||||
min_height: int | None = None,
|
||||
max_height: int | None = None,
|
||||
):
|
||||
height, width = get_image_dimensions(image)
|
||||
|
||||
@ -37,8 +35,8 @@ def validate_image_dimensions(
|
||||
|
||||
def validate_image_aspect_ratio(
|
||||
image: torch.Tensor,
|
||||
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||
min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
|
||||
max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
|
||||
*,
|
||||
strict: bool = True, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
@ -54,8 +52,8 @@ def validate_image_aspect_ratio(
|
||||
def validate_images_aspect_ratio_closeness(
|
||||
first_image: torch.Tensor,
|
||||
second_image: torch.Tensor,
|
||||
min_rel: float, # e.g. 0.8
|
||||
max_rel: float, # e.g. 1.25
|
||||
min_rel: float, # e.g. 0.8
|
||||
max_rel: float, # e.g. 1.25
|
||||
*,
|
||||
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
@ -84,8 +82,8 @@ def validate_images_aspect_ratio_closeness(
|
||||
|
||||
def validate_aspect_ratio_string(
|
||||
aspect_ratio: str,
|
||||
min_ratio: Optional[tuple[float, float]] = None, # e.g. (1, 4)
|
||||
max_ratio: Optional[tuple[float, float]] = None, # e.g. (4, 1)
|
||||
min_ratio: tuple[float, float] | None = None, # e.g. (1, 4)
|
||||
max_ratio: tuple[float, float] | None = None, # e.g. (4, 1)
|
||||
*,
|
||||
strict: bool = False, # True -> (min, max); False -> [min, max]
|
||||
) -> float:
|
||||
@ -97,10 +95,10 @@ def validate_aspect_ratio_string(
|
||||
|
||||
def validate_video_dimensions(
|
||||
video: Input.Video,
|
||||
min_width: Optional[int] = None,
|
||||
max_width: Optional[int] = None,
|
||||
min_height: Optional[int] = None,
|
||||
max_height: Optional[int] = None,
|
||||
min_width: int | None = None,
|
||||
max_width: int | None = None,
|
||||
min_height: int | None = None,
|
||||
max_height: int | None = None,
|
||||
):
|
||||
try:
|
||||
width, height = video.get_dimensions()
|
||||
@ -120,8 +118,8 @@ def validate_video_dimensions(
|
||||
|
||||
def validate_video_duration(
|
||||
video: Input.Video,
|
||||
min_duration: Optional[float] = None,
|
||||
max_duration: Optional[float] = None,
|
||||
min_duration: float | None = None,
|
||||
max_duration: float | None = None,
|
||||
):
|
||||
try:
|
||||
duration = video.get_duration()
|
||||
@ -136,6 +134,23 @@ def validate_video_duration(
|
||||
raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s")
|
||||
|
||||
|
||||
def validate_video_frame_count(
|
||||
video: Input.Video,
|
||||
min_frame_count: int | None = None,
|
||||
max_frame_count: int | None = None,
|
||||
):
|
||||
try:
|
||||
frame_count = video.get_frame_count()
|
||||
except Exception as e:
|
||||
logging.error("Error getting frame count of video: %s", e)
|
||||
return
|
||||
|
||||
if min_frame_count is not None and min_frame_count > frame_count:
|
||||
raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}")
|
||||
if max_frame_count is not None and frame_count > max_frame_count:
|
||||
raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}")
|
||||
|
||||
|
||||
def get_number_of_images(images):
|
||||
if isinstance(images, torch.Tensor):
|
||||
return images.shape[0] if images.ndim >= 4 else 1
|
||||
@ -144,8 +159,8 @@ def get_number_of_images(images):
|
||||
|
||||
def validate_audio_duration(
|
||||
audio: Input.Audio,
|
||||
min_duration: Optional[float] = None,
|
||||
max_duration: Optional[float] = None,
|
||||
min_duration: float | None = None,
|
||||
max_duration: float | None = None,
|
||||
) -> None:
|
||||
sr = int(audio["sample_rate"])
|
||||
dur = int(audio["waveform"].shape[-1]) / sr
|
||||
@ -177,7 +192,7 @@ def validate_string(
|
||||
)
|
||||
|
||||
|
||||
def validate_container_format_is_mp4(video: VideoInput) -> None:
|
||||
def validate_container_format_is_mp4(video: Input.Video) -> None:
|
||||
"""Validates video container format is MP4."""
|
||||
container_format = video.get_container_format()
|
||||
if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]:
|
||||
@ -194,8 +209,8 @@ def _ratio_from_tuple(r: tuple[float, float]) -> float:
|
||||
def _assert_ratio_bounds(
|
||||
ar: float,
|
||||
*,
|
||||
min_ratio: Optional[tuple[float, float]] = None,
|
||||
max_ratio: Optional[tuple[float, float]] = None,
|
||||
min_ratio: tuple[float, float] | None = None,
|
||||
max_ratio: tuple[float, float] | None = None,
|
||||
strict: bool = True,
|
||||
) -> None:
|
||||
"""Validate a numeric aspect ratio against optional min/max ratio bounds."""
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from comfy_api.latest import IO
|
||||
|
||||
|
||||
def validate_node_input(
|
||||
@ -23,6 +24,11 @@ def validate_node_input(
|
||||
if not received_type != input_type:
|
||||
return True
|
||||
|
||||
# If the received type or input_type is a MatchType, we can return True immediately;
|
||||
# validation for this is handled by the frontend
|
||||
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
|
||||
return True
|
||||
|
||||
# Not equal, and not strings
|
||||
if not isinstance(received_type, str) or not isinstance(input_type, str):
|
||||
return False
|
||||
|
||||
@ -6,65 +6,80 @@ import torch
|
||||
import comfy.model_management
|
||||
import folder_paths
|
||||
import os
|
||||
import io
|
||||
import json
|
||||
import random
|
||||
import hashlib
|
||||
import node_helpers
|
||||
import logging
|
||||
from comfy.cli_args import args
|
||||
from comfy.comfy_types import FileLocator
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, IO, UI
|
||||
|
||||
class EmptyLatentAudio:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
class EmptyLatentAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyLatentAudio",
|
||||
display_name="Empty Latent Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
||||
IO.Int.Input(
|
||||
"batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."
|
||||
),
|
||||
],
|
||||
outputs=[IO.Latent.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}),
|
||||
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
|
||||
}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "generate"
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def generate(self, seconds, batch_size):
|
||||
def execute(cls, seconds, batch_size) -> IO.NodeOutput:
|
||||
length = round((seconds * 44100 / 2048) / 2) * 2
|
||||
latent = torch.zeros([batch_size, 64, length], device=self.device)
|
||||
return ({"samples":latent, "type": "audio"}, )
|
||||
latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device())
|
||||
return IO.NodeOutput({"samples":latent, "type": "audio"})
|
||||
|
||||
class ConditioningStableAudio:
|
||||
generate = execute # TODO: remove
|
||||
|
||||
|
||||
class ConditioningStableAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"positive": ("CONDITIONING", ),
|
||||
"negative": ("CONDITIONING", ),
|
||||
"seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
|
||||
"seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ConditioningStableAudio",
|
||||
category="conditioning",
|
||||
inputs=[
|
||||
IO.Conditioning.Input("positive"),
|
||||
IO.Conditioning.Input("negative"),
|
||||
IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1),
|
||||
IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1),
|
||||
],
|
||||
outputs=[
|
||||
IO.Conditioning.Output(display_name="positive"),
|
||||
IO.Conditioning.Output(display_name="negative"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("CONDITIONING","CONDITIONING")
|
||||
RETURN_NAMES = ("positive", "negative")
|
||||
|
||||
FUNCTION = "append"
|
||||
|
||||
CATEGORY = "conditioning"
|
||||
|
||||
def append(self, positive, negative, seconds_start, seconds_total):
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total})
|
||||
return (positive, negative)
|
||||
return IO.NodeOutput(positive, negative)
|
||||
|
||||
class VAEEncodeAudio:
|
||||
append = execute # TODO: remove
|
||||
|
||||
|
||||
class VAEEncodeAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ), "vae": ("VAE", )}}
|
||||
RETURN_TYPES = ("LATENT",)
|
||||
FUNCTION = "encode"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEEncodeAudio",
|
||||
display_name="VAE Encode Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Vae.Input("vae"),
|
||||
],
|
||||
outputs=[IO.Latent.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def encode(self, vae, audio):
|
||||
@classmethod
|
||||
def execute(cls, vae, audio) -> IO.NodeOutput:
|
||||
sample_rate = audio["sample_rate"]
|
||||
if 44100 != sample_rate:
|
||||
waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100)
|
||||
@ -72,213 +87,134 @@ class VAEEncodeAudio:
|
||||
waveform = audio["waveform"]
|
||||
|
||||
t = vae.encode(waveform.movedim(1, -1))
|
||||
return ({"samples":t}, )
|
||||
return IO.NodeOutput({"samples":t})
|
||||
|
||||
class VAEDecodeAudio:
|
||||
encode = execute # TODO: remove
|
||||
|
||||
|
||||
class VAEDecodeAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "samples": ("LATENT", ), "vae": ("VAE", )}}
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "decode"
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="VAEDecodeAudio",
|
||||
display_name="VAE Decode Audio",
|
||||
category="latent/audio",
|
||||
inputs=[
|
||||
IO.Latent.Input("samples"),
|
||||
IO.Vae.Input("vae"),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "latent/audio"
|
||||
|
||||
def decode(self, vae, samples):
|
||||
@classmethod
|
||||
def execute(cls, vae, samples) -> IO.NodeOutput:
|
||||
audio = vae.decode(samples["samples"]).movedim(-1, 1)
|
||||
std = torch.std(audio, dim=[1,2], keepdim=True) * 5.0
|
||||
std[std < 1.0] = 1.0
|
||||
audio /= std
|
||||
return ({"waveform": audio, "sample_rate": 44100}, )
|
||||
return IO.NodeOutput({"waveform": audio, "sample_rate": 44100})
|
||||
|
||||
decode = execute # TODO: remove
|
||||
|
||||
|
||||
def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
|
||||
|
||||
filename_prefix += self.prefix_append
|
||||
full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
|
||||
results: list[FileLocator] = []
|
||||
|
||||
# Prepare metadata dictionary
|
||||
metadata = {}
|
||||
if not args.disable_metadata:
|
||||
if prompt is not None:
|
||||
metadata["prompt"] = json.dumps(prompt)
|
||||
if extra_pnginfo is not None:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
# Opus supported sample rates
|
||||
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
|
||||
|
||||
for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||
output_path = os.path.join(full_output_folder, file)
|
||||
|
||||
# Use original sample rate initially
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
# Handle Opus sample rate requirements
|
||||
if format == "opus":
|
||||
if sample_rate > 48000:
|
||||
sample_rate = 48000
|
||||
elif sample_rate not in OPUS_RATES:
|
||||
# Find the next highest supported rate
|
||||
for rate in sorted(OPUS_RATES):
|
||||
if rate > sample_rate:
|
||||
sample_rate = rate
|
||||
break
|
||||
if sample_rate not in OPUS_RATES: # Fallback if still not supported
|
||||
sample_rate = 48000
|
||||
|
||||
# Resample if necessary
|
||||
if sample_rate != audio["sample_rate"]:
|
||||
waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
|
||||
|
||||
# Create output with specified format
|
||||
output_buffer = io.BytesIO()
|
||||
output_container = av.open(output_buffer, mode='w', format=format)
|
||||
|
||||
# Set metadata on the container
|
||||
for key, value in metadata.items():
|
||||
output_container.metadata[key] = value
|
||||
|
||||
layout = 'mono' if waveform.shape[0] == 1 else 'stereo'
|
||||
# Set up the output stream with appropriate properties
|
||||
if format == "opus":
|
||||
out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout)
|
||||
if quality == "64k":
|
||||
out_stream.bit_rate = 64000
|
||||
elif quality == "96k":
|
||||
out_stream.bit_rate = 96000
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "192k":
|
||||
out_stream.bit_rate = 192000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
elif format == "mp3":
|
||||
out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout)
|
||||
if quality == "V0":
|
||||
#TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
|
||||
out_stream.codec_context.qscale = 1
|
||||
elif quality == "128k":
|
||||
out_stream.bit_rate = 128000
|
||||
elif quality == "320k":
|
||||
out_stream.bit_rate = 320000
|
||||
else: #format == "flac":
|
||||
out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout)
|
||||
|
||||
frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout)
|
||||
frame.sample_rate = sample_rate
|
||||
frame.pts = 0
|
||||
output_container.mux(out_stream.encode(frame))
|
||||
|
||||
# Flush encoder
|
||||
output_container.mux(out_stream.encode(None))
|
||||
|
||||
# Close containers
|
||||
output_container.close()
|
||||
|
||||
# Write the output to file
|
||||
output_buffer.seek(0)
|
||||
with open(output_path, 'wb') as f:
|
||||
f.write(output_buffer.getbuffer())
|
||||
|
||||
results.append({
|
||||
"filename": file,
|
||||
"subfolder": subfolder,
|
||||
"type": self.type
|
||||
})
|
||||
counter += 1
|
||||
|
||||
return { "ui": { "audio": results } }
|
||||
|
||||
class SaveAudio:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
class SaveAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudio",
|
||||
display_name="Save Audio (FLAC)",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput:
|
||||
return IO.NodeOutput(
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format)
|
||||
)
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_flac"
|
||||
save_flac = execute # TODO: remove
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
|
||||
|
||||
class SaveAudioMP3:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
class SaveAudioMP3(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudioMP3",
|
||||
display_name="Save Audio (MP3)",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["V0", "128k", "320k"], {"default": "V0"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput:
|
||||
return IO.NodeOutput(
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
||||
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
||||
)
|
||||
)
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_mp3"
|
||||
save_mp3 = execute # TODO: remove
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
class SaveAudioOpus:
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_output_directory()
|
||||
self.type = "output"
|
||||
self.prefix_append = ""
|
||||
class SaveAudioOpus(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudioOpus",
|
||||
display_name="Save Audio (Opus)",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "audio": ("AUDIO", ),
|
||||
"filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
|
||||
"quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
|
||||
},
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput:
|
||||
return IO.NodeOutput(
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(
|
||||
audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality
|
||||
)
|
||||
)
|
||||
|
||||
RETURN_TYPES = ()
|
||||
FUNCTION = "save_opus"
|
||||
save_opus = execute # TODO: remove
|
||||
|
||||
OUTPUT_NODE = True
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
|
||||
return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
|
||||
|
||||
class PreviewAudio(SaveAudio):
|
||||
def __init__(self):
|
||||
self.output_dir = folder_paths.get_temp_directory()
|
||||
self.type = "temp"
|
||||
self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5))
|
||||
class PreviewAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="PreviewAudio",
|
||||
display_name="Preview Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required":
|
||||
{"audio": ("AUDIO", ), },
|
||||
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
|
||||
}
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls))
|
||||
|
||||
save_flac = execute # TODO: remove
|
||||
|
||||
|
||||
def f32_pcm(wav: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert audio to float 32 bits PCM format."""
|
||||
@ -316,26 +252,30 @@ def load(filepath: str) -> tuple[torch.Tensor, int]:
|
||||
wav = f32_pcm(wav)
|
||||
return wav, sr
|
||||
|
||||
class LoadAudio:
|
||||
class LoadAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
def define_schema(cls):
|
||||
input_dir = folder_paths.get_input_directory()
|
||||
files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"])
|
||||
return {"required": {"audio": (sorted(files), {"audio_upload": True})}}
|
||||
return IO.Schema(
|
||||
node_id="LoadAudio",
|
||||
display_name="Load Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
RETURN_TYPES = ("AUDIO", )
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, audio):
|
||||
@classmethod
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||
waveform, sample_rate = load(audio_path)
|
||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||
return (audio, )
|
||||
return IO.NodeOutput(audio)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, audio):
|
||||
def fingerprint_inputs(cls, audio):
|
||||
image_path = folder_paths.get_annotated_filepath(audio)
|
||||
m = hashlib.sha256()
|
||||
with open(image_path, 'rb') as f:
|
||||
@ -343,46 +283,69 @@ class LoadAudio:
|
||||
return m.digest().hex()
|
||||
|
||||
@classmethod
|
||||
def VALIDATE_INPUTS(s, audio):
|
||||
def validate_inputs(cls, audio):
|
||||
if not folder_paths.exists_annotated_filepath(audio):
|
||||
return "Invalid audio file: {}".format(audio)
|
||||
return True
|
||||
|
||||
class RecordAudio:
|
||||
load = execute # TODO: remove
|
||||
|
||||
|
||||
class RecordAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"audio": ("AUDIO_RECORD", {})}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RecordAudio",
|
||||
display_name="Record Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Custom("AUDIO_RECORD").Input("audio"),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
CATEGORY = "audio"
|
||||
|
||||
RETURN_TYPES = ("AUDIO", )
|
||||
FUNCTION = "load"
|
||||
|
||||
def load(self, audio):
|
||||
@classmethod
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
audio_path = folder_paths.get_annotated_filepath(audio)
|
||||
|
||||
waveform, sample_rate = load(audio_path)
|
||||
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
|
||||
return (audio, )
|
||||
return IO.NodeOutput(audio)
|
||||
|
||||
load = execute # TODO: remove
|
||||
|
||||
|
||||
class TrimAudioDuration:
|
||||
class TrimAudioDuration(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"audio": ("AUDIO",),
|
||||
"start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
|
||||
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
|
||||
},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TrimAudioDuration",
|
||||
display_name="Trim Audio Duration",
|
||||
description="Trim audio tensor into chosen time range.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Float.Input(
|
||||
"start_index",
|
||||
default=0.0,
|
||||
min=-0xffffffffffffffff,
|
||||
max=0xffffffffffffffff,
|
||||
step=0.01,
|
||||
tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"duration",
|
||||
default=60.0,
|
||||
min=0.0,
|
||||
step=0.01,
|
||||
tooltip="Duration in seconds",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
FUNCTION = "trim"
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Trim audio tensor into chosen time range."
|
||||
|
||||
def trim(self, audio, start_index, duration):
|
||||
@classmethod
|
||||
def execute(cls, audio, start_index, duration) -> IO.NodeOutput:
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
audio_length = waveform.shape[-1]
|
||||
@ -399,23 +362,30 @@ class TrimAudioDuration:
|
||||
if start_frame >= end_frame:
|
||||
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
|
||||
|
||||
return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
|
||||
return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate})
|
||||
|
||||
trim = execute # TODO: remove
|
||||
|
||||
|
||||
class SplitAudioChannels:
|
||||
class SplitAudioChannels(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"audio": ("AUDIO",),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SplitAudioChannels",
|
||||
display_name="Split Audio Channels",
|
||||
description="Separates the audio into left and right channels.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
],
|
||||
outputs=[
|
||||
IO.Audio.Output(display_name="left"),
|
||||
IO.Audio.Output(display_name="right"),
|
||||
],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("AUDIO", "AUDIO")
|
||||
RETURN_NAMES = ("left", "right")
|
||||
FUNCTION = "separate"
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Separates the audio into left and right channels."
|
||||
|
||||
def separate(self, audio):
|
||||
@classmethod
|
||||
def execute(cls, audio) -> IO.NodeOutput:
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
@ -425,7 +395,9 @@ class SplitAudioChannels:
|
||||
left_channel = waveform[..., 0:1, :]
|
||||
right_channel = waveform[..., 1:2, :]
|
||||
|
||||
return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
|
||||
return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
|
||||
|
||||
separate = execute # TODO: remove
|
||||
|
||||
|
||||
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
|
||||
@ -443,21 +415,29 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_
|
||||
return waveform_1, waveform_2, output_sample_rate
|
||||
|
||||
|
||||
class AudioConcat:
|
||||
class AudioConcat(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"audio1": ("AUDIO",),
|
||||
"audio2": ("AUDIO",),
|
||||
"direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="AudioConcat",
|
||||
display_name="Audio Concat",
|
||||
description="Concatenates the audio1 to audio2 in the specified direction.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio1"),
|
||||
IO.Audio.Input("audio2"),
|
||||
IO.Combo.Input(
|
||||
"direction",
|
||||
options=['after', 'before'],
|
||||
default="after",
|
||||
tooltip="Whether to append audio2 after or before audio1.",
|
||||
)
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "concat"
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
|
||||
|
||||
def concat(self, audio1, audio2, direction):
|
||||
@classmethod
|
||||
def execute(cls, audio1, audio2, direction) -> IO.NodeOutput:
|
||||
waveform_1 = audio1["waveform"]
|
||||
waveform_2 = audio2["waveform"]
|
||||
sample_rate_1 = audio1["sample_rate"]
|
||||
@ -477,26 +457,33 @@ class AudioConcat:
|
||||
elif direction == 'before':
|
||||
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
|
||||
|
||||
return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
|
||||
return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate})
|
||||
|
||||
concat = execute # TODO: remove
|
||||
|
||||
|
||||
class AudioMerge:
|
||||
class AudioMerge(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls):
|
||||
return {
|
||||
"required": {
|
||||
"audio1": ("AUDIO",),
|
||||
"audio2": ("AUDIO",),
|
||||
"merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
|
||||
},
|
||||
}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="AudioMerge",
|
||||
display_name="Audio Merge",
|
||||
description="Combine two audio tracks by overlaying their waveforms.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio1"),
|
||||
IO.Audio.Input("audio2"),
|
||||
IO.Combo.Input(
|
||||
"merge_method",
|
||||
options=["add", "mean", "subtract", "multiply"],
|
||||
tooltip="The method used to combine the audio waveforms.",
|
||||
)
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
FUNCTION = "merge"
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
CATEGORY = "audio"
|
||||
DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
|
||||
|
||||
def merge(self, audio1, audio2, merge_method):
|
||||
@classmethod
|
||||
def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput:
|
||||
waveform_1 = audio1["waveform"]
|
||||
waveform_2 = audio2["waveform"]
|
||||
sample_rate_1 = audio1["sample_rate"]
|
||||
@ -530,85 +517,110 @@ class AudioMerge:
|
||||
if max_val > 1.0:
|
||||
waveform = waveform / max_val
|
||||
|
||||
return ({"waveform": waveform, "sample_rate": output_sample_rate},)
|
||||
return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate})
|
||||
|
||||
merge = execute # TODO: remove
|
||||
|
||||
|
||||
class AudioAdjustVolume:
|
||||
class AudioAdjustVolume(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"audio": ("AUDIO",),
|
||||
"volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="AudioAdjustVolume",
|
||||
display_name="Audio Adjust Volume",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.Int.Input(
|
||||
"volume",
|
||||
default=1,
|
||||
min=-100,
|
||||
max=100,
|
||||
tooltip="Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc",
|
||||
)
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "adjust_volume"
|
||||
CATEGORY = "audio"
|
||||
|
||||
def adjust_volume(self, audio, volume):
|
||||
@classmethod
|
||||
def execute(cls, audio, volume) -> IO.NodeOutput:
|
||||
if volume == 0:
|
||||
return (audio,)
|
||||
return IO.NodeOutput(audio)
|
||||
waveform = audio["waveform"]
|
||||
sample_rate = audio["sample_rate"]
|
||||
|
||||
gain = 10 ** (volume / 20)
|
||||
waveform = waveform * gain
|
||||
|
||||
return ({"waveform": waveform, "sample_rate": sample_rate},)
|
||||
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
|
||||
|
||||
adjust_volume = execute # TODO: remove
|
||||
|
||||
|
||||
class EmptyAudio:
|
||||
class EmptyAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {
|
||||
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
|
||||
"sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
|
||||
"channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
|
||||
}}
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="EmptyAudio",
|
||||
display_name="Empty Audio",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Float.Input(
|
||||
"duration",
|
||||
default=60.0,
|
||||
min=0.0,
|
||||
max=0xffffffffffffffff,
|
||||
step=0.01,
|
||||
tooltip="Duration of the empty audio clip in seconds",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"sample_rate",
|
||||
default=44100,
|
||||
tooltip="Sample rate of the empty audio clip.",
|
||||
min=1,
|
||||
max=192000,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"channels",
|
||||
default=2,
|
||||
min=1,
|
||||
max=2,
|
||||
tooltip="Number of audio channels (1 for mono, 2 for stereo).",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Audio.Output()],
|
||||
)
|
||||
|
||||
RETURN_TYPES = ("AUDIO",)
|
||||
FUNCTION = "create_empty_audio"
|
||||
CATEGORY = "audio"
|
||||
|
||||
def create_empty_audio(self, duration, sample_rate, channels):
|
||||
@classmethod
|
||||
def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput:
|
||||
num_samples = int(round(duration * sample_rate))
|
||||
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
|
||||
return ({"waveform": waveform, "sample_rate": sample_rate},)
|
||||
return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate})
|
||||
|
||||
create_empty_audio = execute # TODO: remove
|
||||
|
||||
|
||||
NODE_CLASS_MAPPINGS = {
|
||||
"EmptyLatentAudio": EmptyLatentAudio,
|
||||
"VAEEncodeAudio": VAEEncodeAudio,
|
||||
"VAEDecodeAudio": VAEDecodeAudio,
|
||||
"SaveAudio": SaveAudio,
|
||||
"SaveAudioMP3": SaveAudioMP3,
|
||||
"SaveAudioOpus": SaveAudioOpus,
|
||||
"LoadAudio": LoadAudio,
|
||||
"PreviewAudio": PreviewAudio,
|
||||
"ConditioningStableAudio": ConditioningStableAudio,
|
||||
"RecordAudio": RecordAudio,
|
||||
"TrimAudioDuration": TrimAudioDuration,
|
||||
"SplitAudioChannels": SplitAudioChannels,
|
||||
"AudioConcat": AudioConcat,
|
||||
"AudioMerge": AudioMerge,
|
||||
"AudioAdjustVolume": AudioAdjustVolume,
|
||||
"EmptyAudio": EmptyAudio,
|
||||
}
|
||||
class AudioExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
EmptyLatentAudio,
|
||||
VAEEncodeAudio,
|
||||
VAEDecodeAudio,
|
||||
SaveAudio,
|
||||
SaveAudioMP3,
|
||||
SaveAudioOpus,
|
||||
LoadAudio,
|
||||
PreviewAudio,
|
||||
ConditioningStableAudio,
|
||||
RecordAudio,
|
||||
TrimAudioDuration,
|
||||
SplitAudioChannels,
|
||||
AudioConcat,
|
||||
AudioMerge,
|
||||
AudioAdjustVolume,
|
||||
EmptyAudio,
|
||||
]
|
||||
|
||||
NODE_DISPLAY_NAME_MAPPINGS = {
|
||||
"EmptyLatentAudio": "Empty Latent Audio",
|
||||
"VAEEncodeAudio": "VAE Encode Audio",
|
||||
"VAEDecodeAudio": "VAE Decode Audio",
|
||||
"PreviewAudio": "Preview Audio",
|
||||
"LoadAudio": "Load Audio",
|
||||
"SaveAudio": "Save Audio (FLAC)",
|
||||
"SaveAudioMP3": "Save Audio (MP3)",
|
||||
"SaveAudioOpus": "Save Audio (Opus)",
|
||||
"RecordAudio": "Record Audio",
|
||||
"TrimAudioDuration": "Trim Audio Duration",
|
||||
"SplitAudioChannels": "Split Audio Channels",
|
||||
"AudioConcat": "Audio Concat",
|
||||
"AudioMerge": "Audio Merge",
|
||||
"AudioAdjustVolume": "Audio Adjust Volume",
|
||||
"EmptyAudio": "Empty Audio",
|
||||
}
|
||||
async def comfy_entrypoint() -> AudioExtension:
|
||||
return AudioExtension()
|
||||
|
||||
@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||
@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model:
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool,
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||
model = model.clone()
|
||||
model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler(
|
||||
context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule),
|
||||
@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode):
|
||||
context_overlap=context_overlap,
|
||||
context_stride=context_stride,
|
||||
closed_loop=closed_loop,
|
||||
dim=dim)
|
||||
dim=dim,
|
||||
freenoise=freenoise,
|
||||
cond_retain_index_list=cond_retain_index_list,
|
||||
split_conds_to_windows=split_conds_to_windows
|
||||
)
|
||||
# make memory usage calculation only take into account the context window latents
|
||||
comfy.context_windows.create_prepare_sampling_wrapper(model)
|
||||
if freenoise: # no other use for this wrapper at this time
|
||||
comfy.context_windows.create_sampler_sample_wrapper(model)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||
@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode):
|
||||
io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."),
|
||||
io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."),
|
||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||
]
|
||||
return schema
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model:
|
||||
def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool,
|
||||
cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model:
|
||||
context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1
|
||||
context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0
|
||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2)
|
||||
return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows)
|
||||
|
||||
|
||||
class ContextWindowsExtension(ComfyExtension):
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user