# ComfyUI
-**The most powerful and modular diffusion model GUI and backend.**
+**The most powerful and modular visual AI engine and application.**
[![Website][website-shield]][website-url]
@@ -31,10 +31,23 @@

-This ui will let you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart based interface. For some workflow examples and see what ComfyUI can do you can check out:
-### [ComfyUI Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
+ComfyUI lets you design and execute advanced stable diffusion pipelines using a graph/nodes/flowchart-based interface. Available on Windows, Linux, and macOS.
-### [Installing ComfyUI](#installing)
+## Get Started
+
+#### [Desktop Application](https://www.comfy.org/download)
+- The easiest way to get started.
+- Available on Windows & macOS.
+
+#### [Windows Portable Package](#installing)
+- Gets the latest commits and is completely portable.
+- Available on Windows.
+
+#### [Manual Install](#manual-install-windows-linux)
+Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
+
+## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
+See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
@@ -47,12 +60,20 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
- [AuraFlow](https://comfyanonymous.github.io/ComfyUI_examples/aura_flow/)
- [HunyuanDiT](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_dit/)
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
+ - [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
+ - [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- Video Models
- [Stable Video Diffusion](https://comfyanonymous.github.io/ComfyUI_examples/video/)
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
-- [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
+ - [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/)
+ - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
+- Audio Models
+ - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
+ - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
+- 3D Models
+ - [Hunyuan3D 2.0](https://docs.comfy.org/tutorials/3d/hunyuan3D-2)
- Asynchronous Queue system
- Many optimizations: Only re-executes the parts of the workflow that change between executions.
- Smart memory management: can automatically run models on GPUs with as little as 1GB of VRAM.
@@ -79,6 +100,22 @@ This ui will let you design and execute advanced stable diffusion pipelines usin
Workflow examples can be found on the [Examples page](https://comfyanonymous.github.io/ComfyUI_examples/)
+## Release Process
+
+ComfyUI follows a weekly release cycle (every Friday), with three interconnected repositories:
+
+1. **[ComfyUI Core](https://github.com/comfyanonymous/ComfyUI)**
+ - Releases a new stable version (e.g., v0.7.0)
+ - Serves as the foundation for the desktop release
+
+2. **[ComfyUI Desktop](https://github.com/Comfy-Org/desktop)**
+ - Builds a new release using the latest stable core version
+
+3. **[ComfyUI Frontend](https://github.com/Comfy-Org/ComfyUI_frontend)**
+ - Weekly frontend updates are merged into the core repository
+ - Features are frozen for the upcoming core release
+ - Development continues for the next release cycle
+
## Shortcuts
| Keybind | Explanation |
@@ -119,7 +156,7 @@ Workflow examples can be found on the [Examples page](https://comfyanonymous.git
# Installing
-## Windows
+## Windows Portable
There is a portable standalone build for Windows on the [releases page](https://github.com/comfyanonymous/ComfyUI/releases) that should work for running on Nvidia GPUs or on your CPU only.
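+
+Once extracted, the portable build is typically started with one of the bundled batch files (file names assumed from the portable package layout; check your extracted folder):
+
+```bash
+# from the extracted ComfyUI_windows_portable folder
+run_nvidia_gpu.bat
+run_cpu.bat
+```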
@@ -137,9 +174,18 @@ See the [Config file](extra_model_paths.yaml.example) to set the search paths fo
To run it on services like Paperspace, Kaggle or Colab you can use my [Jupyter Notebook](notebooks/comfyui_colab.ipynb)
+
+## [comfy-cli](https://docs.comfy.org/comfy-cli/getting-started)
+
+You can install and start ComfyUI using comfy-cli:
+```bash
+pip install comfy-cli
+comfy install
+```
+
## Manual Install (Windows, Linux)
-Note that some dependencies do not yet support python 3.13 so using 3.12 is recommended.
+Python 3.13 is supported, but 3.12 is recommended because some custom nodes and their dependencies might not support 3.13 yet.
Git clone this repo.
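+
+A minimal sketch of that step:
+
+```bash
+git clone https://github.com/comfyanonymous/ComfyUI.git
+cd ComfyUI
+```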
@@ -151,11 +197,11 @@ Put your VAE in: models/vae
### AMD GPUs (Linux only)
AMD users can install ROCm and PyTorch with pip if they aren't already installed. This is the command to install the stable version:
-```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.2```
+```pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.3```
-This is the command to install the nightly with ROCm 6.2 which might have some performance improvements:
+This is the command to install the nightly with ROCm 6.4 which might have some performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.2.4```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/rocm6.4```
### Intel GPUs (Windows and Linux)
@@ -185,11 +231,11 @@ Additional discussion and help can be found [here](https://github.com/comfyanony
Nvidia users should install stable pytorch using this command:
-```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu124```
+```pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu128```
-This is the command to install pytorch nightly instead which might have performance improvements:
+This is the command to install pytorch nightly instead, which might have performance improvements:
-```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126```
+```pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128```
#### Troubleshooting
@@ -233,6 +279,13 @@ For models compatible with Ascend Extension for PyTorch (torch_npu). To get star
3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
4. Finally, adhere to the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
+#### Cambricon MLUs
+
+For models compatible with the Cambricon Extension for PyTorch (torch_mlu), here's a step-by-step guide tailored to your platform and installation method:
+
+1. Install the Cambricon CNToolkit by adhering to the platform-specific instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cntoolkit_3.7.2/cntoolkit_install_3.7.2/index.html) page.
+2. Next, install PyTorch (torch_mlu) by following the instructions on the [Installation](https://www.cambricon.com/docs/sdk_1.15.0/cambricon_pytorch_1.17.0/user_guide_1.9/index.html) page.
+3. Launch ComfyUI by running `python main.py`
# Running
@@ -248,7 +301,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
### AMD ROCm Tips
-You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI on RDNA3 and potentially other AMD GPUs using this command:
+You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command; it should already be enabled by default on RDNA3. If this improves speed for you on the latest pytorch on your GPU, please report it so that I can enable it by default.
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
@@ -289,6 +342,8 @@ Use `--tls-keyfile key.pem --tls-certfile cert.pem` to enable TLS/SSL, the app w
## Support and dev channel
+[Discord](https://comfy.org/discord): Try the #help or #feedback channels.
+
[Matrix space: #comfyui_space:matrix.org](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) (it's like discord but open source).
See also: [https://www.comfy.org/](https://www.comfy.org/)
@@ -305,7 +360,7 @@ For any bugs, issues, or feature requests related to the frontend, please use th
The new frontend is now the default for ComfyUI. However, please note:
-1. The frontend in the main ComfyUI repository is updated weekly.
+1. The frontend in the main ComfyUI repository is updated fortnightly.
2. Daily releases are available in the separate frontend repository.
To use the most up-to-date frontend version:
@@ -322,7 +377,7 @@ To use the most up-to-date frontend version:
--front-end-version Comfy-Org/ComfyUI_frontend@1.2.2
```
-This approach allows you to easily switch between the stable weekly release and the cutting-edge daily updates, or even specific versions for testing purposes.
+This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes.
### Accessing the Legacy Frontend
diff --git a/api_server/routes/internal/internal_routes.py b/api_server/routes/internal/internal_routes.py
index 8f74529ba..613b0f7c7 100644
--- a/api_server/routes/internal/internal_routes.py
+++ b/api_server/routes/internal/internal_routes.py
@@ -1,9 +1,9 @@
from aiohttp import web
from typing import Optional
-from folder_paths import models_dir, user_directory, output_directory, folder_names_and_paths
-from api_server.services.file_service import FileService
+from folder_paths import folder_names_and_paths, get_directory_by_type
from api_server.services.terminal_service import TerminalService
import app.logger
+import os
class InternalRoutes:
'''
@@ -15,26 +15,10 @@ class InternalRoutes:
def __init__(self, prompt_server):
self.routes: web.RouteTableDef = web.RouteTableDef()
self._app: Optional[web.Application] = None
- self.file_service = FileService({
- "models": models_dir,
- "user": user_directory,
- "output": output_directory
- })
self.prompt_server = prompt_server
self.terminal_service = TerminalService(prompt_server)
def setup_routes(self):
- @self.routes.get('/files')
- async def list_files(request):
- directory_key = request.query.get('directory', '')
- try:
- file_list = self.file_service.list_files(directory_key)
- return web.json_response({"files": file_list})
- except ValueError as e:
- return web.json_response({"error": str(e)}, status=400)
- except Exception as e:
- return web.json_response({"error": str(e)}, status=500)
-
@self.routes.get('/logs')
async def get_logs(request):
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
@@ -67,6 +51,20 @@ class InternalRoutes:
response[key] = folder_names_and_paths[key][0]
return web.json_response(response)
+ @self.routes.get('/files/{directory_type}')
+ async def get_files(request: web.Request) -> web.Response:
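+        # Example (illustrative): GET .../files/output -> ["newest.png", "older.png"], sorted newest first.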
+ directory_type = request.match_info['directory_type']
+ if directory_type not in ("output", "input", "temp"):
+ return web.json_response({"error": "Invalid directory type"}, status=400)
+
+ directory = get_directory_by_type(directory_type)
+ sorted_files = sorted(
+ (entry for entry in os.scandir(directory) if entry.is_file()),
+ key=lambda entry: -entry.stat().st_mtime
+ )
+ return web.json_response([entry.name for entry in sorted_files], status=200)
+
+
def get_app(self):
if self._app is None:
self._app = web.Application()
diff --git a/api_server/services/file_service.py b/api_server/services/file_service.py
deleted file mode 100644
index 115edccd3..000000000
--- a/api_server/services/file_service.py
+++ /dev/null
@@ -1,13 +0,0 @@
-from typing import Dict, List, Optional
-from api_server.utils.file_operations import FileSystemOperations, FileSystemItem
-
-class FileService:
- def __init__(self, allowed_directories: Dict[str, str], file_system_ops: Optional[FileSystemOperations] = None):
- self.allowed_directories: Dict[str, str] = allowed_directories
- self.file_system_ops: FileSystemOperations = file_system_ops or FileSystemOperations()
-
- def list_files(self, directory_key: str) -> List[FileSystemItem]:
- if directory_key not in self.allowed_directories:
- raise ValueError("Invalid directory key")
- directory_path: str = self.allowed_directories[directory_key]
- return self.file_system_ops.walk_directory(directory_path)
diff --git a/app/app_settings.py b/app/app_settings.py
index a545df92e..c7ac73bf6 100644
--- a/app/app_settings.py
+++ b/app/app_settings.py
@@ -9,8 +9,14 @@ class AppSettings():
self.user_manager = user_manager
def get_settings(self, request):
- file = self.user_manager.get_request_user_filepath(
- request, "comfy.settings.json")
+ try:
+ file = self.user_manager.get_request_user_filepath(
+ request,
+ "comfy.settings.json"
+ )
+ except KeyError as e:
+ logging.error("User settings not found.")
+ raise web.HTTPUnauthorized() from e
if os.path.isfile(file):
try:
with open(file) as f:
diff --git a/app/custom_node_manager.py b/app/custom_node_manager.py
index 7f9f645cd..281febca9 100644
--- a/app/custom_node_manager.py
+++ b/app/custom_node_manager.py
@@ -4,31 +4,142 @@ import os
import folder_paths
import glob
from aiohttp import web
+import json
+import logging
+from functools import lru_cache
+
+from utils.json_util import merge_json_recursive
+
+
+# Extra locale files to load into main.json
+EXTRA_LOCALE_FILES = [
+ "nodeDefs.json",
+ "commands.json",
+ "settings.json",
+]
+
+
+def safe_load_json_file(file_path: str) -> dict:
+ if not os.path.exists(file_path):
+ return {}
+
+ try:
+ with open(file_path, "r", encoding="utf-8") as f:
+ return json.load(f)
+ except json.JSONDecodeError:
+ logging.error(f"Error loading {file_path}")
+ return {}
+
class CustomNodeManager:
- """
- Placeholder to refactor the custom node management features from ComfyUI-Manager.
- Currently it only contains the custom workflow templates feature.
- """
+ @lru_cache(maxsize=1)
+ def build_translations(self):
+ """Load all custom nodes translations during initialization. Translations are
+ expected to be loaded from `locales/` folder.
+
+ The folder structure is expected to be the following:
+ - custom_nodes/
+ - custom_node_1/
+ - locales/
+ - en/
+ - main.json
+ - commands.json
+ - settings.json
+
+        Returned translations are expected to be in the following format:
+ {
+ "en": {
+ "nodeDefs": {...},
+ "commands": {...},
+ "settings": {...},
+ ...{other main.json keys}
+ }
+ }
+ """
+
+ translations = {}
+
+ for folder in folder_paths.get_folder_paths("custom_nodes"):
+ # Sort glob results for deterministic ordering
+ for custom_node_dir in sorted(glob.glob(os.path.join(folder, "*/"))):
+ locales_dir = os.path.join(custom_node_dir, "locales")
+ if not os.path.exists(locales_dir):
+ continue
+
+ for lang_dir in glob.glob(os.path.join(locales_dir, "*/")):
+ lang_code = os.path.basename(os.path.dirname(lang_dir))
+
+ if lang_code not in translations:
+ translations[lang_code] = {}
+
+ # Load main.json
+ main_file = os.path.join(lang_dir, "main.json")
+ node_translations = safe_load_json_file(main_file)
+
+ # Load extra locale files
+ for extra_file in EXTRA_LOCALE_FILES:
+ extra_file_path = os.path.join(lang_dir, extra_file)
+ key = extra_file.split(".")[0]
+ json_data = safe_load_json_file(extra_file_path)
+ if json_data:
+ node_translations[key] = json_data
+
+ if node_translations:
+ translations[lang_code] = merge_json_recursive(
+ translations[lang_code], node_translations
+ )
+
+ return translations
+
def add_routes(self, routes, webapp, loadedModules):
+ example_workflow_folder_names = ["example_workflows", "example", "examples", "workflow", "workflows"]
+
@routes.get("/workflow_templates")
async def get_workflow_templates(request):
"""Returns a web response that contains the map of custom_nodes names and their associated workflow templates. The ones without templates are omitted."""
- files = [
- file
- for folder in folder_paths.get_folder_paths("custom_nodes")
- for file in glob.glob(os.path.join(folder, '*/example_workflows/*.json'))
- ]
- workflow_templates_dict = {} # custom_nodes folder name -> example workflow names
+
+ files = []
+
+ for folder in folder_paths.get_folder_paths("custom_nodes"):
+ for folder_name in example_workflow_folder_names:
+ pattern = os.path.join(folder, f"*/{folder_name}/*.json")
+ matched_files = glob.glob(pattern)
+ files.extend(matched_files)
+
+        workflow_templates_dict = {}  # custom_nodes folder name -> example workflow names
for file in files:
- custom_nodes_name = os.path.basename(os.path.dirname(os.path.dirname(file)))
+ custom_nodes_name = os.path.basename(
+ os.path.dirname(os.path.dirname(file))
+ )
workflow_name = os.path.splitext(os.path.basename(file))[0]
- workflow_templates_dict.setdefault(custom_nodes_name, []).append(workflow_name)
+ workflow_templates_dict.setdefault(custom_nodes_name, []).append(
+ workflow_name
+ )
return web.json_response(workflow_templates_dict)
# Serve workflow templates from custom nodes.
for module_name, module_dir in loadedModules:
- workflows_dir = os.path.join(module_dir, 'example_workflows')
- if os.path.exists(workflows_dir):
- webapp.add_routes([web.static('/api/workflow_templates/' + module_name, workflows_dir)])
+ for folder_name in example_workflow_folder_names:
+ workflows_dir = os.path.join(module_dir, folder_name)
+
+ if os.path.exists(workflows_dir):
+ if folder_name != "example_workflows":
+ logging.debug(
+ "Found example workflow folder '%s' for custom node '%s', consider renaming it to 'example_workflows'",
+ folder_name, module_name)
+
+ webapp.add_routes(
+ [
+ web.static(
+ "/api/workflow_templates/" + module_name, workflows_dir
+ )
+ ]
+ )
+
+ @routes.get("/i18n")
+ async def get_i18n(request):
+ """Returns translations from all custom nodes' locales folders."""
+ return web.json_response(self.build_translations())
diff --git a/app/frontend_management.py b/app/frontend_management.py
index 6f20e439c..7b7923b79 100644
--- a/app/frontend_management.py
+++ b/app/frontend_management.py
@@ -3,16 +3,69 @@ import argparse
import logging
import os
import re
+import sys
import tempfile
import zipfile
+import importlib
from dataclasses import dataclass
from functools import cached_property
from pathlib import Path
from typing import TypedDict, Optional
+from importlib.metadata import version
import requests
from typing_extensions import NotRequired
+
from comfy.cli_args import DEFAULT_VERSION_STRING
+import app.logger
+
+# The path to the requirements.txt file
+req_path = Path(__file__).parents[1] / "requirements.txt"
+
+
+def frontend_install_warning_message():
+ """The warning message to display when the frontend version is not up to date."""
+
+ extra = ""
+ if sys.flags.no_user_site:
+ extra = "-s "
+ return f"""
+Please install the updated requirements.txt file by running:
+{sys.executable} {extra}-m pip install -r {req_path}
+
+This error is happening because the ComfyUI frontend is no longer shipped as part of the main repo but as a pip package instead.
+
+If you are on the portable package you can run: update\\update_comfyui.bat to solve this problem
+""".strip()
+
+
+def check_frontend_version():
+ """Check if the frontend version is up to date."""
+
+ def parse_version(version: str) -> tuple[int, int, int]:
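+        # e.g. "1.2.3" -> (1, 2, 3)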
+ return tuple(map(int, version.split(".")))
+
+ try:
+ frontend_version_str = version("comfyui-frontend-package")
+ frontend_version = parse_version(frontend_version_str)
+ with open(req_path, "r", encoding="utf-8") as f:
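+            # The first line of requirements.txt is expected to pin comfyui-frontend-package==x.y.z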
+ required_frontend = parse_version(f.readline().split("=")[-1])
+ if frontend_version < required_frontend:
+ app.logger.log_startup_warning(
+ f"""
+________________________________________________________________________
+WARNING WARNING WARNING WARNING WARNING
+
+Installed frontend version {".".join(map(str, frontend_version))} is lower than the recommended version {".".join(map(str, required_frontend))}.
+
+{frontend_install_warning_message()}
+________________________________________________________________________
+""".strip()
+ )
+ else:
+ logging.info("ComfyUI frontend version: {}".format(frontend_version_str))
+ except Exception as e:
+ logging.error(f"Failed to check frontend version: {e}")
REQUEST_TIMEOUT = 10 # seconds
@@ -109,9 +162,49 @@ def download_release_asset_zip(release: Release, destination_path: str) -> None:
class FrontendManager:
- DEFAULT_FRONTEND_PATH = str(Path(__file__).parents[1] / "web")
CUSTOM_FRONTENDS_ROOT = str(Path(__file__).parents[1] / "web_custom_versions")
+ @classmethod
+ def default_frontend_path(cls) -> str:
+ try:
+ import comfyui_frontend_package
+
+ return str(importlib.resources.files(comfyui_frontend_package) / "static")
+ except ImportError:
+ logging.error(
+ f"""
+********** ERROR ***********
+
+comfyui-frontend-package is not installed.
+
+{frontend_install_warning_message()}
+
+********** ERROR ***********
+""".strip()
+ )
+ sys.exit(-1)
+
+ @classmethod
+ def templates_path(cls) -> str:
+ try:
+ import comfyui_workflow_templates
+
+ return str(
+ importlib.resources.files(comfyui_workflow_templates) / "templates"
+ )
+ except ImportError:
+ logging.error(
+ f"""
+********** ERROR ***********
+
+comfyui-workflow-templates is not installed.
+
+{frontend_install_warning_message()}
+
+********** ERROR ***********
+""".strip()
+ )
+
@classmethod
def parse_version_string(cls, value: str) -> tuple[str, str, str]:
"""
@@ -132,7 +225,9 @@ class FrontendManager:
return match_result.group(1), match_result.group(2), match_result.group(3)
@classmethod
- def init_frontend_unsafe(cls, version_string: str, provider: Optional[FrontEndProvider] = None) -> str:
+ def init_frontend_unsafe(
+ cls, version_string: str, provider: Optional[FrontEndProvider] = None
+ ) -> str:
"""
Initializes the frontend for the specified version.
@@ -148,17 +243,26 @@ class FrontendManager:
main error source might be request timeout or invalid URL.
"""
if version_string == DEFAULT_VERSION_STRING:
- return cls.DEFAULT_FRONTEND_PATH
+ check_frontend_version()
+ return cls.default_frontend_path()
repo_owner, repo_name, version = cls.parse_version_string(version_string)
if version.startswith("v"):
- expected_path = str(Path(cls.CUSTOM_FRONTENDS_ROOT) / f"{repo_owner}_{repo_name}" / version.lstrip("v"))
+ expected_path = str(
+ Path(cls.CUSTOM_FRONTENDS_ROOT)
+ / f"{repo_owner}_{repo_name}"
+ / version.lstrip("v")
+ )
if os.path.exists(expected_path):
- logging.info(f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}")
+ logging.info(
+ f"Using existing copy of specific frontend version tag: {repo_owner}/{repo_name}@{version}"
+ )
return expected_path
- logging.info(f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub...")
+ logging.info(
+ f"Initializing frontend: {repo_owner}/{repo_name}@{version}, requesting version details from GitHub..."
+ )
provider = provider or FrontEndProvider(repo_owner, repo_name)
release = provider.get_release(version)
@@ -201,4 +305,5 @@ class FrontendManager:
except Exception as e:
logging.error("Failed to initialize frontend: %s", e)
logging.info("Falling back to the default frontend.")
- return cls.DEFAULT_FRONTEND_PATH
+ check_frontend_version()
+ return cls.default_frontend_path()
diff --git a/app/logger.py b/app/logger.py
index 9e9f84ccf..3d26d98fe 100644
--- a/app/logger.py
+++ b/app/logger.py
@@ -82,3 +82,17 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool
logger.addHandler(stdout_handler)
logger.addHandler(stream_handler)
+
+
+STARTUP_WARNINGS = []
+
+
+def log_startup_warning(msg):
+ logging.warning(msg)
+ STARTUP_WARNINGS.append(msg)
+
+
+def print_startup_warnings():
+ for s in STARTUP_WARNINGS:
+ logging.warning(s)
+ STARTUP_WARNINGS.clear()
diff --git a/app/user_manager.py b/app/user_manager.py
index e7381e621..d31da5b9b 100644
--- a/app/user_manager.py
+++ b/app/user_manager.py
@@ -197,6 +197,112 @@ class UserManager():
return web.json_response(results)
+ @routes.get("/v2/userdata")
+ async def list_userdata_v2(request):
+ """
+ List files and directories in a user's data directory.
+
+ This endpoint provides a structured listing of contents within a specified
+ subdirectory of the user's data storage.
+
+ Query Parameters:
+ - path (optional): The relative path within the user's data directory
+ to list. Defaults to the root ('').
+
+ Returns:
+ - 400: If the requested path is invalid, outside the user's data directory, or is not a directory.
+ - 404: If the requested path does not exist.
+ - 403: If the user is invalid.
+ - 500: If there is an error reading the directory contents.
+ - 200: JSON response containing a list of file and directory objects.
+ Each object includes:
+ - name: The name of the file or directory.
+ - type: 'file' or 'directory'.
+ - path: The relative path from the user's data root.
+ - size (for files): The size in bytes.
+ - modified (for files): The last modified timestamp (Unix epoch).
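+
+        Example item (illustrative):
+            {"name": "a.png", "path": "sub/a.png", "type": "file", "size": 123, "modified": 1700000000.0}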
+ """
+ requested_rel_path = request.rel_url.query.get('path', '')
+
+ # URL-decode the path parameter
+ try:
+ requested_rel_path = parse.unquote(requested_rel_path)
+ except Exception as e:
+ logging.warning(f"Failed to decode path parameter: {requested_rel_path}, Error: {e}")
+ return web.Response(status=400, text="Invalid characters in path parameter")
+
+
+ # Check user validity and get the absolute path for the requested directory
+ try:
+ base_user_path = self.get_request_user_filepath(request, None, create_dir=False)
+
+ if requested_rel_path:
+ target_abs_path = self.get_request_user_filepath(request, requested_rel_path, create_dir=False)
+ else:
+ target_abs_path = base_user_path
+
+ except KeyError as e:
+ # Invalid user detected by get_request_user_id inside get_request_user_filepath
+ logging.warning(f"Access denied for user: {e}")
+ return web.Response(status=403, text="Invalid user specified in request")
+
+
+ if not target_abs_path:
+ # Path traversal or other issue detected by get_request_user_filepath
+ return web.Response(status=400, text="Invalid path requested")
+
+ # Handle cases where the user directory or target path doesn't exist
+ if not os.path.exists(target_abs_path):
+ # Check if it's the base user directory that's missing (new user case)
+ if target_abs_path == base_user_path:
+ # It's okay if the base user directory doesn't exist yet, return empty list
+ return web.json_response([])
+ else:
+ # A specific subdirectory was requested but doesn't exist
+ return web.Response(status=404, text="Requested path not found")
+
+ if not os.path.isdir(target_abs_path):
+ return web.Response(status=400, text="Requested path is not a directory")
+
+ results = []
+ try:
+ for root, dirs, files in os.walk(target_abs_path, topdown=True):
+ # Process directories
+ for dir_name in dirs:
+ dir_path = os.path.join(root, dir_name)
+ rel_path = os.path.relpath(dir_path, base_user_path).replace(os.sep, '/')
+ results.append({
+ "name": dir_name,
+ "path": rel_path,
+ "type": "directory"
+ })
+
+ # Process files
+ for file_name in files:
+ file_path = os.path.join(root, file_name)
+ rel_path = os.path.relpath(file_path, base_user_path).replace(os.sep, '/')
+ entry_info = {
+ "name": file_name,
+ "path": rel_path,
+ "type": "file"
+ }
+ try:
+ stats = os.stat(file_path) # Use os.stat for potentially better performance with os.walk
+ entry_info["size"] = stats.st_size
+ entry_info["modified"] = stats.st_mtime
+ except OSError as stat_error:
+ logging.warning(f"Could not stat file {file_path}: {stat_error}")
+ pass # Include file with available info
+ results.append(entry_info)
+ except OSError as e:
+ logging.error(f"Error listing directory {target_abs_path}: {e}")
+ return web.Response(status=500, text="Error reading directory contents")
+
+ # Sort results alphabetically, directories first then files
+ results.sort(key=lambda x: (x['type'] != 'directory', x['name'].lower()))
+
+ return web.json_response(results)
+
def get_user_data_path(request, check_exists = False, param = "file"):
file = request.match_info.get(param, None)
if not file:
diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 812798bf8..4fb675f99 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -1,7 +1,6 @@
import argparse
import enum
import os
-from typing import Optional
import comfy.options
@@ -43,10 +42,11 @@ parser.add_argument("--tls-certfile", type=str, help="Path to TLS (SSL) certific
parser.add_argument("--enable-cors-header", type=str, default=None, metavar="ORIGIN", nargs="?", const="*", help="Enable CORS (Cross-Origin Resource Sharing) with optional origin or allow all with default '*'.")
parser.add_argument("--max-upload-size", type=float, default=100, help="Set the maximum upload size in MB.")
+parser.add_argument("--base-directory", type=str, default=None, help="Set the ComfyUI base directory for models, custom_nodes, input, output, temp, and user directories.")
parser.add_argument("--extra-model-paths-config", type=str, default=None, metavar="PATH", nargs='+', action='append', help="Load one or more extra_model_paths.yaml files.")
-parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory.")
-parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory).")
-parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory.")
+parser.add_argument("--output-directory", type=str, default=None, help="Set the ComfyUI output directory. Overrides --base-directory.")
+parser.add_argument("--temp-directory", type=str, default=None, help="Set the ComfyUI temp directory (default is in the ComfyUI directory). Overrides --base-directory.")
+parser.add_argument("--input-directory", type=str, default=None, help="Set the ComfyUI input directory. Overrides --base-directory.")
parser.add_argument("--auto-launch", action="store_true", help="Automatically launch ComfyUI in the default browser.")
parser.add_argument("--disable-auto-launch", action="store_true", help="Disable auto launching the browser.")
parser.add_argument("--cuda-device", type=int, default=None, metavar="DEVICE_ID", help="Set the id of the cuda device this instance will use.")
@@ -66,6 +66,7 @@ fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the diff
fpunet_group.add_argument("--fp16-unet", action="store_true", help="Run the diffusion model in fp16")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
+fpunet_group.add_argument("--fp8_e8m0fnu-unet", action="store_true", help="Store unet weights in fp8_e8m0fnu.")
fpvae_group = parser.add_mutually_exclusive_group()
fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
@@ -79,6 +80,7 @@ fpte_group.add_argument("--fp8_e4m3fn-text-enc", action="store_true", help="Stor
fpte_group.add_argument("--fp8_e5m2-text-enc", action="store_true", help="Store text encoder weights in fp8 (e5m2 variant).")
fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text encoder weights in fp16.")
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
+fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
@@ -86,6 +88,7 @@ parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE"
parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
+parser.add_argument("--supports-fp8-compute", action="store_true", help="ComfyUI will act like if the device supports fp8 compute.")
class LatentPreviewMethod(enum.Enum):
NoPreviews = "none"
@@ -100,12 +103,14 @@ parser.add_argument("--preview-size", type=int, default=512, help="Sets the maxi
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
+cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
@@ -124,12 +129,21 @@ vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for e
parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.")
+parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.")
parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.")
parser.add_argument("--disable-smart-memory", action="store_true", help="Force ComfyUI to agressively offload to regular ram instead of keeping models in vram when it can.")
parser.add_argument("--deterministic", action="store_true", help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.")
-parser.add_argument("--fast", action="store_true", help="Enable some untested and potentially quality deteriorating optimizations.")
+
+class PerformanceFeature(enum.Enum):
+ Fp16Accumulation = "fp16_accumulation"
+ Fp8MatrixMultiplication = "fp8_matrix_mult"
+ CublasOps = "cublas_ops"
+
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
+
+parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--dont-print-server", action="store_true", help="Don't print server output.")
parser.add_argument("--quick-test-for-ci", action="store_true", help="Quick test for CI.")
@@ -137,6 +151,7 @@ parser.add_argument("--windows-standalone-build", action="store_true", help="Win
parser.add_argument("--disable-metadata", action="store_true", help="Disable saving prompt metadata in files.")
parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.")
+parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.")
parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
@@ -160,13 +175,14 @@ parser.add_argument(
""",
)
-def is_valid_directory(path: Optional[str]) -> Optional[str]:
- """Validate if the given path is a directory."""
- if path is None:
- return None
-
+def is_valid_directory(path: str) -> str:
+ """Validate if the given path is a directory, and check permissions."""
+ if not os.path.exists(path):
+ raise argparse.ArgumentTypeError(f"The path '{path}' does not exist.")
if not os.path.isdir(path):
- raise argparse.ArgumentTypeError(f"{path} is not a valid directory.")
+ raise argparse.ArgumentTypeError(f"'{path}' is not a directory.")
+ if not os.access(path, os.R_OK):
+ raise argparse.ArgumentTypeError(f"You do not have read permissions for '{path}'.")
return path
parser.add_argument(
@@ -176,7 +192,16 @@ parser.add_argument(
help="The local filesystem path to the directory where the frontend is located. Overrides --front-end-version.",
)
-parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path.")
+parser.add_argument("--user-directory", type=is_valid_directory, default=None, help="Set the ComfyUI user directory with an absolute path. Overrides --base-directory.")
+
+parser.add_argument("--enable-compress-response-body", action="store_true", help="Enable compressing response body.")
+
+parser.add_argument(
+ "--comfy-api-base",
+ type=str,
+ default="https://api.comfy.org",
+ help="Set the base URL for the ComfyUI API. (default: https://api.comfy.org)",
+)
if comfy.options.args_parsing:
args = parser.parse_args()
@@ -188,3 +213,17 @@ if args.windows_standalone_build:
if args.disable_auto_launch:
args.auto_launch = False
+
+if args.force_fp16:
+ args.fp16_unet = True
+
+
+# If '--fast' is not provided, use an empty set
+if args.fast is None:
+    args.fast = set()
+# If '--fast' is provided with an empty list, enable all optimizations
+elif args.fast == []:
+    args.fast = set(PerformanceFeature)
+# If '--fast' is provided with a list of performance features, use that list
+else:
+    args.fast = set(args.fast)
diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 23ddea9c0..c8294d483 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -97,14 +97,19 @@ class CLIPTextModel_(torch.nn.Module):
self.encoder = CLIPEncoder(num_layers, embed_dim, heads, intermediate_size, intermediate_activation, dtype, device, operations)
self.final_layer_norm = operations.LayerNorm(embed_dim, dtype=dtype, device=device)
- def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
- x = self.embeddings(input_tokens, dtype=dtype)
+ def forward(self, input_tokens=None, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=torch.float32):
+ if embeds is not None:
+ x = embeds + comfy.ops.cast_to(self.embeddings.position_embedding.weight, dtype=dtype, device=embeds.device)
+ else:
+ x = self.embeddings(input_tokens, dtype=dtype)
+
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
- mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+ mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
+
-        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+        causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
if mask is not None:
mask += causal_mask
else:
@@ -115,7 +120,10 @@ class CLIPTextModel_(torch.nn.Module):
if i is not None and final_layer_norm_intermediate:
i = self.final_layer_norm(i)
- pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
+ if num_tokens is not None:
+ pooled_output = x[list(range(x.shape[0])), list(map(lambda a: a - 1, num_tokens))]
+ else:
+ pooled_output = x[torch.arange(x.shape[0], device=x.device), (torch.round(input_tokens).to(dtype=torch.int, device=x.device) == self.eos_token_id).int().argmax(dim=-1),]
return x, i, pooled_output
class CLIPTextModel(torch.nn.Module):
@@ -203,6 +211,15 @@ class CLIPVision(torch.nn.Module):
pooled_output = self.post_layernorm(x[:, 0, :])
return x, i, pooled_output
+class LlavaProjector(torch.nn.Module):
+ def __init__(self, in_dim, out_dim, dtype, device, operations):
+ super().__init__()
+ self.linear_1 = operations.Linear(in_dim, out_dim, bias=True, device=device, dtype=dtype)
+ self.linear_2 = operations.Linear(out_dim, out_dim, bias=True, device=device, dtype=dtype)
+
+ def forward(self, x):
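+        # x[:, 1:] drops the CLS token so only patch embeddings are projected.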
+ return self.linear_2(torch.nn.functional.gelu(self.linear_1(x[:, 1:])))
+
class CLIPVisionModelProjection(torch.nn.Module):
def __init__(self, config_dict, dtype, device, operations):
super().__init__()
@@ -212,7 +229,16 @@ class CLIPVisionModelProjection(torch.nn.Module):
else:
self.visual_projection = lambda a: a
+ if "llava3" == config_dict.get("projector_type", None):
+ self.multi_modal_projector = LlavaProjector(config_dict["hidden_size"], 4096, dtype, device, operations)
+ else:
+ self.multi_modal_projector = None
+
def forward(self, *args, **kwargs):
x = self.vision_model(*args, **kwargs)
out = self.visual_projection(x[2])
- return (x[0], x[1], out)
+ projected = None
+ if self.multi_modal_projector is not None:
+ projected = self.multi_modal_projector(x[1])
+
+ return (x[0], x[1], out, projected)
diff --git a/comfy/clip_vision.py b/comfy/clip_vision.py
index c9c82e9ad..00aab9164 100644
--- a/comfy/clip_vision.py
+++ b/comfy/clip_vision.py
@@ -9,6 +9,7 @@ import comfy.model_patcher
import comfy.model_management
import comfy.utils
import comfy.clip_model
+import comfy.image_encoders.dino2
class Output:
def __getitem__(self, key):
@@ -17,6 +18,7 @@ class Output:
setattr(self, key, item)
def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
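+    # Drop any extra channels (e.g. alpha) so only RGB is normalized below.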
+ image = image[:, :, :, :3] if image.shape[3] > 3 else image
mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
std = torch.tensor(std, device=image.device, dtype=image.dtype)
image = image.movedim(-1, 1)
@@ -34,6 +36,12 @@ def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], s
image = torch.clip((255. * image), 0, 255).round() / 255.0
return (image - mean.view([3,1,1])) / std.view([3,1,1])
+IMAGE_ENCODERS = {
+ "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
+ "siglip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
+ "dinov2": comfy.image_encoders.dino2.Dinov2Model,
+}
+
class ClipVisionModel():
def __init__(self, json_config):
with open(json_config) as f:
@@ -42,10 +50,11 @@ class ClipVisionModel():
self.image_size = config.get("image_size", 224)
self.image_mean = config.get("image_mean", [0.48145466, 0.4578275, 0.40821073])
self.image_std = config.get("image_std", [0.26862954, 0.26130258, 0.27577711])
+ model_class = IMAGE_ENCODERS.get(config.get("model_type", "clip_vision_model"))
self.load_device = comfy.model_management.text_encoder_device()
offload_device = comfy.model_management.text_encoder_offload_device()
self.dtype = comfy.model_management.text_encoder_dtype(self.load_device)
- self.model = comfy.clip_model.CLIPVisionModelProjection(config, self.dtype, offload_device, comfy.ops.manual_cast)
+ self.model = model_class(config, self.dtype, offload_device, comfy.ops.manual_cast)
self.model.eval()
self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
@@ -65,6 +74,7 @@ class ClipVisionModel():
outputs["last_hidden_state"] = out[0].to(comfy.model_management.intermediate_device())
outputs["image_embeds"] = out[2].to(comfy.model_management.intermediate_device())
outputs["penultimate_hidden_states"] = out[1].to(comfy.model_management.intermediate_device())
+ outputs["mm_projected"] = out[3]
return outputs
def convert_to_transformers(sd, prefix):
@@ -101,12 +111,21 @@ def load_clipvision_from_sd(sd, prefix="", convert_keys=False):
elif "vision_model.encoder.layers.30.layer_norm1.weight" in sd:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_h.json")
elif "vision_model.encoder.layers.22.layer_norm1.weight" in sd:
+ embed_shape = sd["vision_model.embeddings.position_embedding.weight"].shape[0]
if sd["vision_model.encoder.layers.0.layer_norm1.weight"].shape[0] == 1152:
- json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
- elif sd["vision_model.embeddings.position_embedding.weight"].shape[0] == 577:
- json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
+ if embed_shape == 729:
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_384.json")
+ elif embed_shape == 1024:
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_siglip_512.json")
+ elif embed_shape == 577:
+ if "multi_modal_projector.linear_1.bias" in sd:
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336_llava.json")
+ else:
+ json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl_336.json")
else:
json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_vision_config_vitl.json")
+ elif "embeddings.patch_embeddings.projection.weight" in sd:
+ json_config = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "image_encoders"), "dino2_giant.json")
else:
return None
diff --git a/comfy/clip_vision_config_vitl_336_llava.json b/comfy/clip_vision_config_vitl_336_llava.json
new file mode 100644
index 000000000..f23a50d8b
--- /dev/null
+++ b/comfy/clip_vision_config_vitl_336_llava.json
@@ -0,0 +1,19 @@
+{
+ "attention_dropout": 0.0,
+ "dropout": 0.0,
+ "hidden_act": "quick_gelu",
+ "hidden_size": 1024,
+ "image_size": 336,
+ "initializer_factor": 1.0,
+ "initializer_range": 0.02,
+ "intermediate_size": 4096,
+ "layer_norm_eps": 1e-5,
+ "model_type": "clip_vision_model",
+ "num_attention_heads": 16,
+ "num_channels": 3,
+ "num_hidden_layers": 24,
+ "patch_size": 14,
+ "projection_dim": 768,
+ "projector_type": "llava3",
+ "torch_dtype": "float32"
+}
diff --git a/comfy/clip_vision_siglip_512.json b/comfy/clip_vision_siglip_512.json
new file mode 100644
index 000000000..7fb93ce15
--- /dev/null
+++ b/comfy/clip_vision_siglip_512.json
@@ -0,0 +1,13 @@
+{
+ "num_channels": 3,
+ "hidden_act": "gelu_pytorch_tanh",
+ "hidden_size": 1152,
+ "image_size": 512,
+ "intermediate_size": 4304,
+ "model_type": "siglip_vision_model",
+ "num_attention_heads": 16,
+ "num_hidden_layers": 27,
+ "patch_size": 16,
+ "image_mean": [0.5, 0.5, 0.5],
+ "image_std": [0.5, 0.5, 0.5]
+}
diff --git a/comfy/comfy_types/__init__.py b/comfy/comfy_types/__init__.py
index 19ec33f98..7640fbe3f 100644
--- a/comfy/comfy_types/__init__.py
+++ b/comfy/comfy_types/__init__.py
@@ -1,6 +1,6 @@
import torch
from typing import Callable, Protocol, TypedDict, Optional, List
-from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin
+from .node_typing import IO, InputTypeDict, ComfyNodeABC, CheckLazyMixin, FileLocator
class UnetApplyFunction(Protocol):
@@ -42,4 +42,5 @@ __all__ = [
InputTypeDict.__name__,
ComfyNodeABC.__name__,
CheckLazyMixin.__name__,
+ FileLocator.__name__,
]
diff --git a/comfy/comfy_types/node_typing.py b/comfy/comfy_types/node_typing.py
index 056b1aa65..470eb9fdb 100644
--- a/comfy/comfy_types/node_typing.py
+++ b/comfy/comfy_types/node_typing.py
@@ -1,7 +1,8 @@
"""Comfy-specific type hinting"""
from __future__ import annotations
-from typing import Literal, TypedDict
+from typing import Literal, TypedDict, Optional
+from typing_extensions import NotRequired
from abc import ABC, abstractmethod
from enum import Enum
@@ -26,6 +27,7 @@ class IO(StrEnum):
BOOLEAN = "BOOLEAN"
INT = "INT"
FLOAT = "FLOAT"
+ COMBO = "COMBO"
CONDITIONING = "CONDITIONING"
SAMPLER = "SAMPLER"
SIGMAS = "SIGMAS"
@@ -46,6 +48,7 @@ class IO(StrEnum):
FACE_ANALYSIS = "FACE_ANALYSIS"
BBOX = "BBOX"
SEGS = "SEGS"
+ VIDEO = "VIDEO"
ANY = "*"
"""Always matches any type, but at a price.
@@ -67,90 +70,148 @@ class IO(StrEnum):
return not (b.issubset(a) or a.issubset(b))
+class RemoteInputOptions(TypedDict):
+ route: str
+ """The route to the remote source."""
+ refresh_button: bool
+ """Specifies whether to show a refresh button in the UI below the widget."""
+ control_after_refresh: Literal["first", "last"]
+ """Specifies the control after the refresh button is clicked. If "first", the first item will be automatically selected, and so on."""
+ timeout: int
+ """The maximum amount of time to wait for a response from the remote source in milliseconds."""
+ max_retries: int
+ """The maximum number of retries before aborting the request."""
+ refresh: int
+ """The TTL of the remote input's value in milliseconds. Specifies the interval at which the remote input's value is refreshed."""
+
+
+class MultiSelectOptions(TypedDict):
+ placeholder: NotRequired[str]
+ """The placeholder text to display in the multi-select widget when no items are selected."""
+ chip: NotRequired[bool]
+ """Specifies whether to use chips instead of comma separated values for the multi-select widget."""
+
+
class InputTypeOptions(TypedDict):
"""Provides type hinting for the return type of the INPUT_TYPES node function.
Due to IDE limitations with unions, for now all options are available for all types (e.g. `label_on` is hinted even when the type is not `IO.BOOLEAN`).
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_datatypes
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/datatypes
"""
- default: bool | str | float | int | list | tuple
+ default: NotRequired[bool | str | float | int | list | tuple]
"""The default value of the widget"""
- defaultInput: bool
- """Defaults to an input slot rather than a widget"""
- forceInput: bool
- """`defaultInput` and also don't allow converting to a widget"""
- lazy: bool
+ defaultInput: NotRequired[bool]
+ """@deprecated in v1.16 frontend. v1.16 frontend allows input socket and widget to co-exist.
+ - defaultInput on required inputs should be dropped.
+ - defaultInput on optional inputs should be replaced with forceInput.
+ Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3364
+ """
+ forceInput: NotRequired[bool]
+ """Forces the input to be an input slot rather than a widget even a widget is available for the input type."""
+ lazy: NotRequired[bool]
"""Declares that this input uses lazy evaluation"""
- rawLink: bool
+ rawLink: NotRequired[bool]
"""When a link exists, rather than receiving the evaluated value, you will receive the link (i.e. `["nodeId", ]`). Designed for node expansion."""
- tooltip: str
+ tooltip: NotRequired[str]
"""Tooltip for the input (or widget), shown on pointer hover"""
+ socketless: NotRequired[bool]
+ """All inputs (including widgets) have an input socket to connect links. When ``true``, if there is a widget for this input, no socket will be created.
+ Available from frontend v1.17.5
+ Ref: https://github.com/Comfy-Org/ComfyUI_frontend/pull/3548
+ """
+ widgetType: NotRequired[str]
+ """Specifies a type to be used for widget initialization if different from the input type.
+ Available from frontend v1.18.0
+ https://github.com/Comfy-Org/ComfyUI_frontend/pull/3550"""
# class InputTypeNumber(InputTypeOptions):
# default: float | int
- min: float
+ min: NotRequired[float]
"""The minimum value of a number (``FLOAT`` | ``INT``)"""
- max: float
+ max: NotRequired[float]
"""The maximum value of a number (``FLOAT`` | ``INT``)"""
- step: float
+ step: NotRequired[float]
"""The amount to increment or decrement a widget by when stepping up/down (``FLOAT`` | ``INT``)"""
- round: float
+ round: NotRequired[float]
"""Floats are rounded by this value (``FLOAT``)"""
# class InputTypeBoolean(InputTypeOptions):
# default: bool
- label_on: str
+ label_on: NotRequired[str]
"""The label to use in the UI when the bool is True (``BOOLEAN``)"""
- label_on: str
+ label_off: NotRequired[str]
"""The label to use in the UI when the bool is False (``BOOLEAN``)"""
# class InputTypeString(InputTypeOptions):
# default: str
- multiline: bool
+ multiline: NotRequired[bool]
"""Use a multiline text box (``STRING``)"""
- placeholder: str
+ placeholder: NotRequired[str]
"""Placeholder text to display in the UI when empty (``STRING``)"""
# Deprecated:
# defaultVal: str
- dynamicPrompts: bool
+ dynamicPrompts: NotRequired[bool]
"""Causes the front-end to evaluate dynamic prompts (``STRING``)"""
+ # class InputTypeCombo(InputTypeOptions):
+ image_upload: NotRequired[bool]
+ """Specifies whether the input should have an image upload button and image preview attached to it. Requires that the input's name is `image`."""
+ image_folder: NotRequired[Literal["input", "output", "temp"]]
+ """Specifies which folder to get preview images from if the input has the ``image_upload`` flag.
+ """
+ remote: NotRequired[RemoteInputOptions]
+ """Specifies the configuration for a remote input.
+ Available after ComfyUI frontend v1.9.7
+ https://github.com/Comfy-Org/ComfyUI_frontend/pull/2422"""
+ control_after_generate: NotRequired[bool]
+ """Specifies whether a control widget should be added to the input, adding options to automatically change the value after each prompt is queued. Currently only used for INT and COMBO types."""
+ options: NotRequired[list[str | int | float]]
+ """COMBO type only. Specifies the selectable options for the combo widget.
+ Prefer:
+ ["COMBO", {"options": ["Option 1", "Option 2", "Option 3"]}]
+ Over:
+ [["Option 1", "Option 2", "Option 3"]]
+ """
+ multi_select: NotRequired[MultiSelectOptions]
+ """COMBO type only. Specifies the configuration for a multi-select widget.
+ Available after ComfyUI frontend v1.13.4
+ https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
class HiddenInputTypeDict(TypedDict):
"""Provides type hinting for the hidden entry of node INPUT_TYPES."""
- node_id: Literal["UNIQUE_ID"]
+ node_id: NotRequired[Literal["UNIQUE_ID"]]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
- unique_id: Literal["UNIQUE_ID"]
+ unique_id: NotRequired[Literal["UNIQUE_ID"]]
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
- prompt: Literal["PROMPT"]
+ prompt: NotRequired[Literal["PROMPT"]]
"""PROMPT is the complete prompt sent by the client to the server. See the prompt object for a full description."""
- extra_pnginfo: Literal["EXTRA_PNGINFO"]
+ extra_pnginfo: NotRequired[Literal["EXTRA_PNGINFO"]]
"""EXTRA_PNGINFO is a dictionary that will be copied into the metadata of any .png files saved. Custom nodes can store additional information in this dictionary for saving (or as a way to communicate with a downstream node)."""
- dynprompt: Literal["DYNPROMPT"]
+ dynprompt: NotRequired[Literal["DYNPROMPT"]]
"""DYNPROMPT is an instance of comfy_execution.graph.DynamicPrompt. It differs from PROMPT in that it may mutate during the course of execution in response to Node Expansion."""
class InputTypeDict(TypedDict):
"""Provides type hinting for node INPUT_TYPES.
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs
"""
- required: dict[str, tuple[IO, InputTypeOptions]]
+ required: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
"""Describes all inputs that must be connected for the node to execute."""
- optional: dict[str, tuple[IO, InputTypeOptions]]
+ optional: NotRequired[dict[str, tuple[IO, InputTypeOptions]]]
"""Describes inputs which do not need to be connected."""
- hidden: HiddenInputTypeDict
+ hidden: NotRequired[HiddenInputTypeDict]
"""Offers advanced functionality and server-client communication.
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
"""
class ComfyNodeABC(ABC):
"""Abstract base class for Comfy nodes. Includes the names and expected types of attributes.
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview
"""
DESCRIPTION: str
@@ -167,12 +228,14 @@ class ComfyNodeABC(ABC):
CATEGORY: str
"""The category of the node, as per the "Add Node" menu.
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#category
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#category
"""
EXPERIMENTAL: bool
"""Flags a node as experimental, informing users that it may change or not work as expected."""
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
+ API_NODE: Optional[bool]
+ """Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""
@classmethod
@abstractmethod
@@ -181,9 +244,9 @@ class ComfyNodeABC(ABC):
* Must include the ``required`` key, which describes all inputs that must be connected for the node to execute.
* The ``optional`` key can be added to describe inputs which do not need to be connected.
- * The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/essentials/custom_node_more_on_inputs#hidden-inputs
+ * The ``hidden`` key offers some advanced functionality. More info at: https://docs.comfy.org/custom-nodes/backend/more_on_inputs#hidden-inputs
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#input-types
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#input-types
"""
return {"required": {}}
@@ -198,7 +261,7 @@ class ComfyNodeABC(ABC):
By default, a node is not considered an output. Set ``OUTPUT_NODE = True`` to specify that it is.
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#output-node
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#output-node
"""
INPUT_IS_LIST: bool
"""A flag indicating if this node implements the additional code necessary to deal with OUTPUT_IS_LIST nodes.
@@ -209,9 +272,9 @@ class ComfyNodeABC(ABC):
A node can also override the default input behaviour and receive the whole list in a single call. This is done by setting a class attribute `INPUT_IS_LIST` to ``True``.
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
- OUTPUT_IS_LIST: tuple[bool]
+ OUTPUT_IS_LIST: tuple[bool, ...]
"""A tuple indicating which node outputs are lists, but will be connected to nodes that expect individual items.
Connected nodes that do not implement `INPUT_IS_LIST` will be executed once for every item in the list.
@@ -227,29 +290,29 @@ class ComfyNodeABC(ABC):
the node should provide a class attribute `OUTPUT_IS_LIST`, which is a ``tuple[bool, ...]`` of the same length as `RETURN_TYPES`,
specifying which outputs should be so treated.
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_lists#list-processing
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lists#list-processing
"""
- RETURN_TYPES: tuple[IO]
+ RETURN_TYPES: tuple[IO, ...]
"""A tuple representing the outputs of this node.
Usage::
RETURN_TYPES = (IO.INT, "INT", "CUSTOM_TYPE")
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-types
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-types
"""
- RETURN_NAMES: tuple[str]
+ RETURN_NAMES: tuple[str, ...]
"""The output slot names for each item in `RETURN_TYPES`, e.g. ``RETURN_NAMES = ("count", "filter_string")``
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#return-names
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#return-names
"""
- OUTPUT_TOOLTIPS: tuple[str]
+ OUTPUT_TOOLTIPS: tuple[str, ...]
"""A tuple of strings to use as tooltips for node outputs, one for each item in `RETURN_TYPES`."""
FUNCTION: str
"""The name of the function to execute as a literal string, e.g. `FUNCTION = "execute"`
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_server_overview#function
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/server_overview#function
"""
@@ -267,8 +330,19 @@ class CheckLazyMixin:
Params should match the node's execution ``FUNCTION`` (self, and all inputs by name).
Will be executed repeatedly until it returns an empty list, or until all requested items have already been evaluated (and sent as params).
- Comfy Docs: https://docs.comfy.org/essentials/custom_node_lazy_evaluation#defining-check-lazy-status
+ Comfy Docs: https://docs.comfy.org/custom-nodes/backend/lazy_evaluation#defining-check-lazy-status
"""
need = [name for name in kwargs if kwargs[name] is None]
return need
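A hedged sketch of a node using this mixin's contract: a hypothetical switch whose two image inputs are declared lazy (via the `lazy` input option), so only the selected branch is ever evaluated:

```python
# Hypothetical switch node: only the selected branch needs evaluating.
class LazySwitch:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "use_b": ("BOOLEAN", {}),
                "input_a": ("IMAGE", {"lazy": True}),
                "input_b": ("IMAGE", {"lazy": True}),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "switch"

    def check_lazy_status(self, use_b, input_a, input_b):
        # Unevaluated lazy inputs arrive as None; request only the one we need.
        wanted, value = ("input_b", input_b) if use_b else ("input_a", input_a)
        return [wanted] if value is None else []

    def switch(self, use_b, input_a, input_b):
        return (input_b if use_b else input_a,)
```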
+
+
+class FileLocator(TypedDict):
+ """Provides type hinting for the file location"""
+
+ filename: str
+ """The filename of the file."""
+ subfolder: str
+ """The subfolder of the file."""
+ type: Literal["input", "output", "temp"]
+ """The root folder of the file."""
diff --git a/comfy/conds.py b/comfy/conds.py
index 660690af8..211fb8d57 100644
--- a/comfy/conds.py
+++ b/comfy/conds.py
@@ -3,9 +3,6 @@ import math
import comfy.utils
-def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9)
- return abs(a*b) // math.gcd(a, b)
-
class CONDRegular:
def __init__(self, cond):
self.cond = cond
@@ -46,7 +43,7 @@ class CONDCrossAttn(CONDRegular):
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
return False
- mult_min = lcm(s1[1], s2[1])
+ mult_min = math.lcm(s1[1], s2[1])
diff = mult_min // min(s1[1], s2[1])
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
return False
@@ -57,7 +54,7 @@ class CONDCrossAttn(CONDRegular):
crossattn_max_len = self.cond.shape[1]
for x in others:
c = x.cond
- crossattn_max_len = lcm(crossattn_max_len, c.shape[1])
+ crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
conds.append(c)
out = []
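`math.lcm` (stdlib since Python 3.9) is a drop-in replacement for the deleted helper. For intuition, the padding logic above behaves like this sketch with two CLIP-style token counts:

```python
import math

# Two cross-attention conds with different token counts (e.g. 77 vs 154).
s1, s2 = 77, 154
mult_min = math.lcm(s1, s2)       # 154
diff = mult_min // min(s1, s2)    # 2: the shorter cond is repeated twice
assert diff <= 4                  # the same padding limit enforced above
```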
diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index ee29251b9..11483e21d 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -418,10 +418,7 @@ def controlnet_config(sd, model_options={}):
weight_dtype = comfy.utils.weight_dtype(sd)
supported_inference_dtypes = list(model_config.supported_inference_dtypes)
- if weight_dtype is not None:
- supported_inference_dtypes.append(weight_dtype)
-
- unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
+ unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
load_device = comfy.model_management.get_torch_device()
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
@@ -689,10 +686,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
if supported_inference_dtypes is None:
supported_inference_dtypes = [comfy.model_management.unet_dtype()]
- if weight_dtype is not None:
- supported_inference_dtypes.append(weight_dtype)
-
- unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes)
+ unet_dtype = comfy.model_management.unet_dtype(model_params=-1, supported_dtypes=supported_inference_dtypes, weight_dtype=weight_dtype)
load_device = comfy.model_management.get_torch_device()
@@ -742,6 +736,7 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
return control
def load_controlnet(ckpt_path, model=None, model_options={}):
+ model_options = model_options.copy()
if "global_average_pooling" not in model_options:
filename = os.path.splitext(ckpt_path)[0]
if filename.endswith("_shuffle") or filename.endswith("_shuffle_fp16"): #TODO: smarter way of enabling global_average_pooling
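The added `model_options = model_options.copy()` guards against Python's shared-mutable-default pitfall: without it, keys written below (such as `global_average_pooling`) would leak into the caller's dict, and into every later call relying on the `{}` default. A minimal sketch with hypothetical names:

```python
def load(ckpt_path, model_options={}):
    # Without a copy, this write leaks into the caller's dict (and, for the
    # default argument, into every subsequent call).
    model_options["global_average_pooling"] = True
    return model_options

opts = {}
load("foo_shuffle.safetensors", opts)
assert "global_average_pooling" in opts   # caller's dict was mutated

def load_fixed(ckpt_path, model_options={}):
    model_options = model_options.copy()  # the one-line fix from the diff
    model_options["global_average_pooling"] = True
    return model_options

opts = {}
load_fixed("foo_shuffle.safetensors", opts)
assert "global_average_pooling" not in opts
```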
diff --git a/comfy/diffusers_convert.py b/comfy/diffusers_convert.py
index 26e8d96d5..fb9495348 100644
--- a/comfy/diffusers_convert.py
+++ b/comfy/diffusers_convert.py
@@ -4,105 +4,6 @@ import logging
# conversion code from https://github.com/huggingface/diffusers/blob/main/scripts/convert_diffusers_to_original_stable_diffusion.py
-# =================#
-# UNet Conversion #
-# =================#
-
-unet_conversion_map = [
- # (stable-diffusion, HF Diffusers)
- ("time_embed.0.weight", "time_embedding.linear_1.weight"),
- ("time_embed.0.bias", "time_embedding.linear_1.bias"),
- ("time_embed.2.weight", "time_embedding.linear_2.weight"),
- ("time_embed.2.bias", "time_embedding.linear_2.bias"),
- ("input_blocks.0.0.weight", "conv_in.weight"),
- ("input_blocks.0.0.bias", "conv_in.bias"),
- ("out.0.weight", "conv_norm_out.weight"),
- ("out.0.bias", "conv_norm_out.bias"),
- ("out.2.weight", "conv_out.weight"),
- ("out.2.bias", "conv_out.bias"),
-]
-
-unet_conversion_map_resnet = [
- # (stable-diffusion, HF Diffusers)
- ("in_layers.0", "norm1"),
- ("in_layers.2", "conv1"),
- ("out_layers.0", "norm2"),
- ("out_layers.3", "conv2"),
- ("emb_layers.1", "time_emb_proj"),
- ("skip_connection", "conv_shortcut"),
-]
-
-unet_conversion_map_layer = []
-# hardcoded number of downblocks and resnets/attentions...
-# would need smarter logic for other networks.
-for i in range(4):
- # loop over downblocks/upblocks
-
- for j in range(2):
- # loop over resnets/attentions for downblocks
- hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
- sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
- unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
-
- if i < 3:
- # no attention layers in down_blocks.3
- hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
- sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
- unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
-
- for j in range(3):
- # loop over resnets/attentions for upblocks
- hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
- sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
- unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
-
- if i > 0:
- # no attention layers in up_blocks.0
- hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
- sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
- unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
-
- if i < 3:
- # no downsample in down_blocks.3
- hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
- sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
- unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
-
- # no upsample in up_blocks.3
- hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
- sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
- unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
-
-hf_mid_atn_prefix = "mid_block.attentions.0."
-sd_mid_atn_prefix = "middle_block.1."
-unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
-
-for j in range(2):
- hf_mid_res_prefix = f"mid_block.resnets.{j}."
- sd_mid_res_prefix = f"middle_block.{2 * j}."
- unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
-
-
-def convert_unet_state_dict(unet_state_dict):
- # buyer beware: this is a *brittle* function,
- # and correct output requires that all of these pieces interact in
- # the exact order in which I have arranged them.
- mapping = {k: k for k in unet_state_dict.keys()}
- for sd_name, hf_name in unet_conversion_map:
- mapping[hf_name] = sd_name
- for k, v in mapping.items():
- if "resnets" in k:
- for sd_part, hf_part in unet_conversion_map_resnet:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- for k, v in mapping.items():
- for sd_part, hf_part in unet_conversion_map_layer:
- v = v.replace(hf_part, sd_part)
- mapping[k] = v
- new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
- return new_state_dict
-
-
# ================#
# VAE Conversion #
# ================#
@@ -213,6 +114,7 @@ textenc_pattern = re.compile("|".join(protected.keys()))
# Ordering is from https://github.com/pytorch/pytorch/blob/master/test/cpp/api/modules.cpp
code2idx = {"q": 0, "k": 1, "v": 2}
+
# This function exists because at the time of writing torch.cat can't do fp8 with cuda
def cat_tensors(tensors):
x = 0
@@ -229,6 +131,7 @@ def cat_tensors(tensors):
return out
+
def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
new_state_dict = {}
capture_qkv_weight = {}
@@ -284,5 +187,3 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
def convert_text_enc_state_dict(text_enc_dict):
return text_enc_dict
-
-
diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py
index 5b80a8aff..c57e081e4 100644
--- a/comfy/extra_samplers/uni_pc.py
+++ b/comfy/extra_samplers/uni_pc.py
@@ -661,7 +661,7 @@ class UniPC:
if x_t is None:
if use_predictor:
- pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
+ pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * pred_res
@@ -669,7 +669,7 @@ class UniPC:
if use_corrector:
model_t = self.model_fn(x_t, t)
if D1s is not None:
- corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
+ corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0])) # torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = (model_t - model_prev_0)
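The einsum-to-tensordot rewrite is a pure equivalence: `tensordot` contracts the `k` axis of `D1s` against the weight vector, producing the same `bchw` result. A quick check under random inputs:

```python
import torch

b, k, c, h, w = 2, 3, 4, 8, 8
D1s = torch.randn(b, k, c, h, w)
rhos_p = torch.randn(k)

ref = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
new = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
assert torch.allclose(ref, new, atol=1e-5)
```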
diff --git a/comfy/image_encoders/dino2.py b/comfy/image_encoders/dino2.py
new file mode 100644
index 000000000..976f98c65
--- /dev/null
+++ b/comfy/image_encoders/dino2.py
@@ -0,0 +1,141 @@
+import torch
+from comfy.text_encoders.bert import BertAttention
+import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention_for_device
+
+
+class Dino2AttentionOutput(torch.nn.Module):
+ def __init__(self, input_dim, output_dim, layer_norm_eps, dtype, device, operations):
+ super().__init__()
+ self.dense = operations.Linear(input_dim, output_dim, dtype=dtype, device=device)
+
+ def forward(self, x):
+ return self.dense(x)
+
+
+class Dino2AttentionBlock(torch.nn.Module):
+ def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
+ super().__init__()
+ self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
+ self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
+
+ def forward(self, x, mask, optimized_attention):
+ return self.output(self.attention(x, mask, optimized_attention))
+
+
+class LayerScale(torch.nn.Module):
+ def __init__(self, dim, dtype, device, operations):
+ super().__init__()
+ self.lambda1 = torch.nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
+
+ def forward(self, x):
+ return x * comfy.model_management.cast_to_device(self.lambda1, x.device, x.dtype)
+
+
+class SwiGLUFFN(torch.nn.Module):
+ def __init__(self, dim, dtype, device, operations):
+ super().__init__()
+ in_features = out_features = dim
+ hidden_features = int(dim * 4)
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+
+ self.weights_in = operations.Linear(in_features, 2 * hidden_features, bias=True, device=device, dtype=dtype)
+ self.weights_out = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
+
+ def forward(self, x):
+ x = self.weights_in(x)
+ x1, x2 = x.chunk(2, dim=-1)
+ x = torch.nn.functional.silu(x1) * x2
+ return self.weights_out(x)
+
+
+class Dino2Block(torch.nn.Module):
+ def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations):
+ super().__init__()
+ self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
+ self.layer_scale1 = LayerScale(dim, dtype, device, operations)
+ self.layer_scale2 = LayerScale(dim, dtype, device, operations)
+ self.mlp = SwiGLUFFN(dim, dtype, device, operations)
+ self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
+ self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
+
+ def forward(self, x, optimized_attention):
+ x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
+ x = x + self.layer_scale2(self.mlp(self.norm2(x)))
+ return x
+
+
+class Dino2Encoder(torch.nn.Module):
+ def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations):
+ super().__init__()
+ self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations) for _ in range(num_layers)])
+
+ def forward(self, x, intermediate_output=None):
+ optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
+
+ if intermediate_output is not None:
+ if intermediate_output < 0:
+ intermediate_output = len(self.layer) + intermediate_output
+
+ intermediate = None
+ for i, l in enumerate(self.layer):
+ x = l(x, optimized_attention)
+ if i == intermediate_output:
+ intermediate = x.clone()
+ return x, intermediate
+
+
+class Dino2PatchEmbeddings(torch.nn.Module):
+ def __init__(self, dim, num_channels=3, patch_size=14, image_size=518, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.projection = operations.Conv2d(
+ in_channels=num_channels,
+ out_channels=dim,
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=True,
+ dtype=dtype,
+ device=device
+ )
+
+ def forward(self, pixel_values):
+ return self.projection(pixel_values).flatten(2).transpose(1, 2)
+
+
+class Dino2Embeddings(torch.nn.Module):
+ def __init__(self, dim, dtype, device, operations):
+ super().__init__()
+ patch_size = 14
+ image_size = 518
+
+ self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
+ self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
+ self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device))
+ self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
+
+ def forward(self, pixel_values):
+ x = self.patch_embeddings(pixel_values)
+ # TODO: mask_token?
+ x = torch.cat((self.cls_token.to(device=x.device, dtype=x.dtype).expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + comfy.model_management.cast_to_device(self.position_embeddings, x.device, x.dtype)
+ return x
+
+
+class Dinov2Model(torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ num_layers = config_dict["num_hidden_layers"]
+ dim = config_dict["hidden_size"]
+ heads = config_dict["num_attention_heads"]
+ layer_norm_eps = config_dict["layer_norm_eps"]
+
+ self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
+ self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations)
+ self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
+
+ def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
+ x = self.embeddings(pixel_values)
+ x, i = self.encoder(x, intermediate_output=intermediate_output)
+ x = self.layernorm(x)
+ pooled_output = x[:, 0, :]
+ return x, i, pooled_output, None
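A hedged sketch of how this model might be instantiated for a shape check, assuming a ComfyUI checkout: `torch.nn` is used as a stand-in for comfy's `operations` factory (it provides `Linear`, `Conv2d`, and `LayerNorm` with compatible signatures), and the weights remain uninitialized, so only the output shapes are meaningful:

```python
import json
import torch

from comfy.image_encoders.dino2 import Dinov2Model  # assumes a ComfyUI checkout

with open("comfy/image_encoders/dino2_giant.json") as f:
    config = json.load(f)

# Weights come from torch.empty, so this is a shape check, not a usable encoder.
model = Dinov2Model(config, dtype=torch.float32, device="cpu", operations=torch.nn)

pixels = torch.randn(1, 3, 518, 518)           # 518 / 14 = 37 patches per side
tokens, intermediate, pooled, _ = model(pixels)
print(tokens.shape)                            # [1, 1370, 1536]: 37*37 patches + CLS
print(pooled.shape)                            # [1, 1536]
```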
diff --git a/comfy/image_encoders/dino2_giant.json b/comfy/image_encoders/dino2_giant.json
new file mode 100644
index 000000000..f6076a4dc
--- /dev/null
+++ b/comfy/image_encoders/dino2_giant.json
@@ -0,0 +1,21 @@
+{
+ "attention_probs_dropout_prob": 0.0,
+ "drop_path_rate": 0.0,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.0,
+ "hidden_size": 1536,
+ "image_size": 518,
+ "initializer_range": 0.02,
+ "layer_norm_eps": 1e-06,
+ "layerscale_value": 1.0,
+ "mlp_ratio": 4,
+ "model_type": "dinov2",
+ "num_attention_heads": 24,
+ "num_channels": 3,
+ "num_hidden_layers": 40,
+ "patch_size": 14,
+ "qkv_bias": true,
+ "use_swiglu_ffn": true,
+ "image_mean": [0.485, 0.456, 0.406],
+ "image_std": [0.229, 0.224, 0.225]
+}
diff --git a/comfy/k_diffusion/res.py b/comfy/k_diffusion/res.py
deleted file mode 100644
index 6caedec39..000000000
--- a/comfy/k_diffusion/res.py
+++ /dev/null
@@ -1,258 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-# SPDX-License-Identifier: Apache-2.0
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Copied from Nvidia Cosmos code.
-
-import torch
-from torch import Tensor
-from typing import Callable, List, Tuple, Optional, Any
-import math
-from tqdm.auto import trange
-
-
-def common_broadcast(x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
- ndims1 = x.ndim
- ndims2 = y.ndim
-
- if ndims1 < ndims2:
- x = x.reshape(x.shape + (1,) * (ndims2 - ndims1))
- elif ndims2 < ndims1:
- y = y.reshape(y.shape + (1,) * (ndims1 - ndims2))
-
- return x, y
-
-
-def batch_mul(x: Tensor, y: Tensor) -> Tensor:
- x, y = common_broadcast(x, y)
- return x * y
-
-
-def phi1(t: torch.Tensor) -> torch.Tensor:
- """
- Compute the first order phi function: (exp(t) - 1) / t.
-
- Args:
- t: Input tensor.
-
- Returns:
- Tensor: Result of phi1 function.
- """
- input_dtype = t.dtype
- t = t.to(dtype=torch.float32)
- return (torch.expm1(t) / t).to(dtype=input_dtype)
-
-
-def phi2(t: torch.Tensor) -> torch.Tensor:
- """
- Compute the second order phi function: (phi1(t) - 1) / t.
-
- Args:
- t: Input tensor.
-
- Returns:
- Tensor: Result of phi2 function.
- """
- input_dtype = t.dtype
- t = t.to(dtype=torch.float32)
- return ((phi1(t) - 1.0) / t).to(dtype=input_dtype)
-
-
-def res_x0_rk2_step(
- x_s: torch.Tensor,
- t: torch.Tensor,
- s: torch.Tensor,
- x0_s: torch.Tensor,
- s1: torch.Tensor,
- x0_s1: torch.Tensor,
-) -> torch.Tensor:
- """
- Perform a residual-based 2nd order Runge-Kutta step.
-
- Args:
- x_s: Current state tensor.
- t: Target time tensor.
- s: Current time tensor.
- x0_s: Prediction at current time.
- s1: Intermediate time tensor.
- x0_s1: Prediction at intermediate time.
-
- Returns:
- Tensor: Updated state tensor.
-
- Raises:
- AssertionError: If step size is too small.
- """
- s = -torch.log(s)
- t = -torch.log(t)
- m = -torch.log(s1)
-
- dt = t - s
- assert not torch.any(torch.isclose(dt, torch.zeros_like(dt), atol=1e-6)), "Step size is too small"
- assert not torch.any(torch.isclose(m - s, torch.zeros_like(dt), atol=1e-6)), "Step size is too small"
-
- c2 = (m - s) / dt
- phi1_val, phi2_val = phi1(-dt), phi2(-dt)
-
- # Handle edge case where t = s = m
- b1 = torch.nan_to_num(phi1_val - 1.0 / c2 * phi2_val, nan=0.0)
- b2 = torch.nan_to_num(1.0 / c2 * phi2_val, nan=0.0)
-
- return batch_mul(torch.exp(-dt), x_s) + batch_mul(dt, batch_mul(b1, x0_s) + batch_mul(b2, x0_s1))
-
-
-def reg_x0_euler_step(
- x_s: torch.Tensor,
- s: torch.Tensor,
- t: torch.Tensor,
- x0_s: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
- """
- Perform a regularized Euler step based on x0 prediction.
-
- Args:
- x_s: Current state tensor.
- s: Current time tensor.
- t: Target time tensor.
- x0_s: Prediction at current time.
-
- Returns:
- Tuple[Tensor, Tensor]: Updated state tensor and current prediction.
- """
- coef_x0 = (s - t) / s
- coef_xs = t / s
- return batch_mul(coef_x0, x0_s) + batch_mul(coef_xs, x_s), x0_s
-
-
-def order2_fn(
- x_s: torch.Tensor, s: torch.Tensor, t: torch.Tensor, x0_s: torch.Tensor, x0_preds: torch.Tensor
-) -> Tuple[torch.Tensor, List[torch.Tensor]]:
- """
- impl the second order multistep method in https://arxiv.org/pdf/2308.02157
- Adams Bashforth approach!
- """
- if x0_preds:
- x0_s1, s1 = x0_preds[0]
- x_t = res_x0_rk2_step(x_s, t, s, x0_s, s1, x0_s1)
- else:
- x_t = reg_x0_euler_step(x_s, s, t, x0_s)[0]
- return x_t, [(x0_s, s)]
-
-
-class SolverConfig:
- is_multi: bool = True
- rk: str = "2mid"
- multistep: str = "2ab"
- s_churn: float = 0.0
- s_t_max: float = float("inf")
- s_t_min: float = 0.0
- s_noise: float = 1.0
-
-
-def fori_loop(lower: int, upper: int, body_fun: Callable[[int, Any], Any], init_val: Any, disable=None) -> Any:
- """
- Implements a for loop with a function.
-
- Args:
- lower: Lower bound of the loop (inclusive).
- upper: Upper bound of the loop (exclusive).
- body_fun: Function to be applied in each iteration.
- init_val: Initial value for the loop.
-
- Returns:
- The final result after all iterations.
- """
- val = init_val
- for i in trange(lower, upper, disable=disable):
- val = body_fun(i, val)
- return val
-
-
-def differential_equation_solver(
- x0_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
- sigmas_L: torch.Tensor,
- solver_cfg: SolverConfig,
- noise_sampler,
- callback=None,
- disable=None,
-) -> Callable[[torch.Tensor], torch.Tensor]:
- """
- Creates a differential equation solver function.
-
- Args:
- x0_fn: Function to compute x0 prediction.
- sigmas_L: Tensor of sigma values with shape [L,].
- solver_cfg: Configuration for the solver.
-
- Returns:
- A function that solves the differential equation.
- """
- num_step = len(sigmas_L) - 1
-
- # if solver_cfg.is_multi:
- # update_step_fn = get_multi_step_fn(solver_cfg.multistep)
- # else:
- # update_step_fn = get_runge_kutta_fn(solver_cfg.rk)
- update_step_fn = order2_fn
-
- eta = min(solver_cfg.s_churn / (num_step + 1), math.sqrt(1.2) - 1)
-
- def sample_fn(input_xT_B_StateShape: torch.Tensor) -> torch.Tensor:
- """
- Samples from the differential equation.
-
- Args:
- input_xT_B_StateShape: Input tensor with shape [B, StateShape].
-
- Returns:
- Output tensor with shape [B, StateShape].
- """
- ones_B = torch.ones(input_xT_B_StateShape.size(0), device=input_xT_B_StateShape.device, dtype=torch.float32)
-
- def step_fn(
- i_th: int, state: Tuple[torch.Tensor, Optional[List[torch.Tensor]]]
- ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
- input_x_B_StateShape, x0_preds = state
- sigma_cur_0, sigma_next_0 = sigmas_L[i_th], sigmas_L[i_th + 1]
-
- if sigma_next_0 == 0:
- output_x_B_StateShape = x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B)
- else:
- # algorithm 2: line 4-6
- if solver_cfg.s_t_min < sigma_cur_0 < solver_cfg.s_t_max and eta > 0:
- hat_sigma_cur_0 = sigma_cur_0 + eta * sigma_cur_0
- input_x_B_StateShape = input_x_B_StateShape + (
- hat_sigma_cur_0**2 - sigma_cur_0**2
- ).sqrt() * solver_cfg.s_noise * noise_sampler(sigma_cur_0, sigma_next_0) # torch.randn_like(input_x_B_StateShape)
- sigma_cur_0 = hat_sigma_cur_0
-
- if solver_cfg.is_multi:
- x0_pred_B_StateShape = x0_fn(input_x_B_StateShape, sigma_cur_0 * ones_B)
- output_x_B_StateShape, x0_preds = update_step_fn(
- input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_pred_B_StateShape, x0_preds
- )
- else:
- output_x_B_StateShape, x0_preds = update_step_fn(
- input_x_B_StateShape, sigma_cur_0 * ones_B, sigma_next_0 * ones_B, x0_fn
- )
-
- if callback is not None:
- callback({'x': input_x_B_StateShape, 'i': i_th, 'sigma': sigma_cur_0, 'sigma_hat': sigma_cur_0, 'denoised': x0_pred_B_StateShape})
-
- return output_x_B_StateShape, x0_preds
-
- x_at_eps, _ = fori_loop(0, num_step, step_fn, [input_xT_B_StateShape, None], disable=disable)
- return x_at_eps
-
- return sample_fn
diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py
index 3a98e6a7c..fbdf6f554 100644
--- a/comfy/k_diffusion/sampling.py
+++ b/comfy/k_diffusion/sampling.py
@@ -8,7 +8,6 @@ from tqdm.auto import trange, tqdm
from . import utils
from . import deis
-from . import res
import comfy.model_patcher
import comfy.model_sampling
@@ -41,7 +40,7 @@ def get_sigmas_polyexponential(n, sigma_min, sigma_max, rho=1., device='cpu'):
def get_sigmas_vp(n, beta_d=19.9, beta_min=0.1, eps_s=1e-3, device='cpu'):
"""Constructs a continuous VP noise schedule."""
t = torch.linspace(1, eps_s, n, device=device)
- sigmas = torch.sqrt(torch.exp(beta_d * t ** 2 / 2 + beta_min * t) - 1)
+ sigmas = torch.sqrt(torch.special.expm1(beta_d * t ** 2 / 2 + beta_min * t))
return append_zero(sigmas)
@@ -689,10 +688,10 @@ def sample_dpmpp_sde(model, x, sigmas, extra_args=None, callback=None, disable=N
if len(sigmas) <= 1:
return x
+ extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
seed = extra_args.get("seed", None)
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
- extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
sigma_fn = lambda t: t.neg().exp()
t_fn = lambda sigma: sigma.log().neg()
@@ -763,10 +762,10 @@ def sample_dpmpp_2m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if solver_type not in {'heun', 'midpoint'}:
raise ValueError('solver_type must be \'heun\' or \'midpoint\'')
+ extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
- extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
old_denoised = None
@@ -809,10 +808,10 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
if len(sigmas) <= 1:
return x
+ extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=True) if noise_sampler is None else noise_sampler
- extra_args = {} if extra_args is None else extra_args
s_in = x.new_ones([x.shape[0]])
denoised_1, denoised_2 = None, None
@@ -859,7 +858,7 @@ def sample_dpmpp_3m_sde(model, x, sigmas, extra_args=None, callback=None, disabl
def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
if len(sigmas) <= 1:
return x
-
+ extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_3m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler)
@@ -868,7 +867,7 @@ def sample_dpmpp_3m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type='midpoint'):
if len(sigmas) <= 1:
return x
-
+ extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_2m_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, solver_type=solver_type)
@@ -877,7 +876,7 @@ def sample_dpmpp_2m_sde_gpu(model, x, sigmas, extra_args=None, callback=None, di
def sample_dpmpp_sde_gpu(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=1 / 2):
if len(sigmas) <= 1:
return x
-
+ extra_args = {} if extra_args is None else extra_args
sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=extra_args.get("seed", None), cpu=False) if noise_sampler is None else noise_sampler
return sample_dpmpp_sde(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=r)
@@ -1268,18 +1267,282 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
return x
@torch.no_grad()
-def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1., noise_sampler=None):
+def res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, eta=1., cfg_pp=False):
extra_args = {} if extra_args is None else extra_args
seed = extra_args.get("seed", None)
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+ s_in = x.new_ones([x.shape[0]])
+ sigma_fn = lambda t: t.neg().exp()
+ t_fn = lambda sigma: sigma.log().neg()
+ phi1_fn = lambda t: torch.expm1(t) / t
+ phi2_fn = lambda t: (phi1_fn(t) - 1.0) / t
- x0_func = lambda x, sigma: model(x, sigma, **extra_args)
+ old_sigma_down = None
+ old_denoised = None
+ uncond_denoised = None
+ def post_cfg_function(args):
+ nonlocal uncond_denoised
+ uncond_denoised = args["uncond_denoised"]
+ return args["denoised"]
- solver_cfg = res.SolverConfig()
- solver_cfg.s_churn = s_churn
- solver_cfg.s_t_max = s_tmax
- solver_cfg.s_t_min = s_tmin
- solver_cfg.s_noise = s_noise
+ if cfg_pp:
+ model_options = extra_args.get("model_options", {}).copy()
+ extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
- x = res.differential_equation_solver(x0_func, sigmas, solver_cfg, noise_sampler, callback=callback, disable=disable)(x)
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
+ if callback is not None:
+ callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
+ if sigma_down == 0 or old_denoised is None:
+ # Euler method
+ if cfg_pp:
+ d = to_d(x, sigmas[i], uncond_denoised)
+ x = denoised + d * sigma_down
+ else:
+ d = to_d(x, sigmas[i], denoised)
+ dt = sigma_down - sigmas[i]
+ x = x + d * dt
+ else:
+ # Second order multistep method in https://arxiv.org/pdf/2308.02157
+ t, t_old, t_next, t_prev = t_fn(sigmas[i]), t_fn(old_sigma_down), t_fn(sigma_down), t_fn(sigmas[i - 1])
+ h = t_next - t
+ c2 = (t_prev - t_old) / h
+
+ phi1_val, phi2_val = phi1_fn(-h), phi2_fn(-h)
+ b1 = torch.nan_to_num(phi1_val - phi2_val / c2, nan=0.0)
+ b2 = torch.nan_to_num(phi2_val / c2, nan=0.0)
+
+ if cfg_pp:
+ x = x + (denoised - uncond_denoised)
+ x = sigma_fn(h) * x + h * (b1 * uncond_denoised + b2 * old_denoised)
+ else:
+ x = sigma_fn(h) * x + h * (b1 * denoised + b2 * old_denoised)
+
+ # Noise addition
+ if sigmas[i + 1] > 0:
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
+
+ if cfg_pp:
+ old_denoised = uncond_denoised
+ else:
+ old_denoised = denoised
+ old_sigma_down = sigma_down
+ return x
+
+@torch.no_grad()
+def sample_res_multistep(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
+ return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=False)
+
+@torch.no_grad()
+def sample_res_multistep_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None):
+ return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=0., cfg_pp=True)
+
+@torch.no_grad()
+def sample_res_multistep_ancestral(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+ return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=False)
+
+@torch.no_grad()
+def sample_res_multistep_ancestral_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None):
+ return res_multistep(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, s_noise=s_noise, noise_sampler=noise_sampler, eta=eta, cfg_pp=True)
+
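The `res_multistep` sampler above inlines the exponential-integrator helpers (`phi1`, `phi2`) that previously lived in the deleted `res.py`. A quick numeric sketch of those helpers, using the same `expm1`-based definitions:

```python
import torch

phi1 = lambda t: torch.expm1(t) / t    # (exp(t) - 1) / t
phi2 = lambda t: (phi1(t) - 1.0) / t   # (phi1(t) - 1) / t

h = torch.tensor(0.5)
print(phi1(-h).item(), phi2(-h).item())   # ~0.7869, ~0.4261

# Both tend to their Taylor limits (1 and 1/2) as t -> 0; the sampler handles
# the degenerate c2 == 0 case separately via torch.nan_to_num.
print(phi1(torch.tensor(1e-6)).item())    # ~1.0
```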
+@torch.no_grad()
+def sample_gradient_estimation(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2., cfg_pp=False):
+ """Gradient-estimation sampler. Paper: https://openreview.net/pdf?id=o2ND9v0CeK"""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ old_d = None
+
+ uncond_denoised = None
+ def post_cfg_function(args):
+ nonlocal uncond_denoised
+ uncond_denoised = args["uncond_denoised"]
+ return args["denoised"]
+
+ if cfg_pp:
+ model_options = extra_args.get("model_options", {}).copy()
+ extra_args["model_options"] = comfy.model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ if cfg_pp:
+ d = to_d(x, sigmas[i], uncond_denoised)
+ else:
+ d = to_d(x, sigmas[i], denoised)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ dt = sigmas[i + 1] - sigmas[i]
+ if i == 0:
+ # Euler method
+ if cfg_pp:
+ x = denoised + d * sigmas[i + 1]
+ else:
+ x = x + d * dt
+ else:
+ # Gradient estimation
+ if cfg_pp:
+ d_bar = (ge_gamma - 1) * (d - old_d)
+ x = denoised + d * sigmas[i + 1] + d_bar * dt
+ else:
+ d_bar = ge_gamma * d + (1 - ge_gamma) * old_d
+ x = x + d_bar * dt
+ old_d = d
+ return x
+
+@torch.no_grad()
+def sample_gradient_estimation_cfg_pp(model, x, sigmas, extra_args=None, callback=None, disable=None, ge_gamma=2.):
+ return sample_gradient_estimation(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, ge_gamma=ge_gamma, cfg_pp=True)
+
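A side note on `ge_gamma`: with gamma = 1 the gradient-estimation update collapses to plain Euler, and larger values extrapolate along the change in direction. A tiny sketch of the non-cfg_pp update rule:

```python
import torch

d, old_d = torch.tensor(1.0), torch.tensor(0.8)
dt = torch.tensor(-0.1)
x = torch.tensor(0.0)

for gamma in (1.0, 2.0):
    d_bar = gamma * d + (1 - gamma) * old_d   # the update used above
    print(gamma, (x + d_bar * dt).item())     # gamma=1 reproduces Euler: x + d*dt
```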
+@torch.no_grad()
+def sample_er_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., noise_sampler=None, noise_scaler=None, max_stage=3):
+ """
+ Extended Reverse-Time SDE solver (VE ER-SDE-Solver-3). Arxiv: https://arxiv.org/abs/2309.06169.
+ Code reference: https://github.com/QinpengCui/ER-SDE-Solver/blob/main/er_sde_solver.py.
+ """
+ extra_args = {} if extra_args is None else extra_args
+ seed = extra_args.get("seed", None)
+ noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+ s_in = x.new_ones([x.shape[0]])
+
+ def default_noise_scaler(sigma):
+ return sigma * ((sigma ** 0.3).exp() + 10.0)
+ noise_scaler = default_noise_scaler if noise_scaler is None else noise_scaler
+ num_integration_points = 200.0
+ point_indice = torch.arange(0, num_integration_points, dtype=torch.float32, device=x.device)
+
+ old_denoised = None
+ old_denoised_d = None
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ stage_used = min(max_stage, i + 1)
+ if sigmas[i + 1] == 0:
+ x = denoised
+ elif stage_used == 1:
+ r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
+ x = r * x + (1 - r) * denoised
+ else:
+ r = noise_scaler(sigmas[i + 1]) / noise_scaler(sigmas[i])
+ x = r * x + (1 - r) * denoised
+
+ dt = sigmas[i + 1] - sigmas[i]
+ sigma_step_size = -dt / num_integration_points
+ sigma_pos = sigmas[i + 1] + point_indice * sigma_step_size
+ scaled_pos = noise_scaler(sigma_pos)
+
+ # Stage 2
+ s = torch.sum(1 / scaled_pos) * sigma_step_size
+ denoised_d = (denoised - old_denoised) / (sigmas[i] - sigmas[i - 1])
+ x = x + (dt + s * noise_scaler(sigmas[i + 1])) * denoised_d
+
+ if stage_used >= 3:
+ # Stage 3
+ s_u = torch.sum((sigma_pos - sigmas[i]) / scaled_pos) * sigma_step_size
+ denoised_u = (denoised_d - old_denoised_d) / ((sigmas[i] - sigmas[i - 2]) / 2)
+ x = x + ((dt ** 2) / 2 + s_u * noise_scaler(sigmas[i + 1])) * denoised_u
+ old_denoised_d = denoised_d
+
+ if s_noise != 0 and sigmas[i + 1] > 0:
+ x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * (sigmas[i + 1] ** 2 - sigmas[i] ** 2 * r ** 2).sqrt().nan_to_num(nan=0.0)
+ old_denoised = denoised
+ return x
+
+@torch.no_grad()
+def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
+ '''
+ SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 2
+ Arxiv: https://arxiv.org/abs/2305.14267
+ '''
+ extra_args = {} if extra_args is None else extra_args
+ seed = extra_args.get("seed", None)
+ noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+ s_in = x.new_ones([x.shape[0]])
+
+ inject_noise = eta > 0 and s_noise > 0
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ if sigmas[i + 1] == 0:
+ x = denoised
+ else:
+ t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
+ h = t_next - t
+ h_eta = h * (eta + 1)
+ s = t + r * h
+ fac = 1 / (2 * r)
+ sigma_s = s.neg().exp()
+
+ coeff_1, coeff_2 = (-r * h_eta).expm1(), (-h_eta).expm1()
+ if inject_noise:
+ noise_coeff_1 = (-2 * r * h * eta).expm1().neg().sqrt()
+ noise_coeff_2 = ((-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
+ noise_1, noise_2 = noise_sampler(sigmas[i], sigma_s), noise_sampler(sigma_s, sigmas[i + 1])
+
+ # Step 1
+ x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
+ if inject_noise:
+ x_2 = x_2 + sigma_s * (noise_coeff_1 * noise_1) * s_noise
+ denoised_2 = model(x_2, sigma_s * s_in, **extra_args)
+
+ # Step 2
+ denoised_d = (1 - fac) * denoised + fac * denoised_2
+ x = (coeff_2 + 1) * x - coeff_2 * denoised_d
+ if inject_noise:
+ x = x + sigmas[i + 1] * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
+ return x
+
+@torch.no_grad()
+def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
+ '''
+ SEEDS-3 - Stochastic Explicit Exponential Derivative-free Solvers (VE Data Prediction) stage 3
+ Arxiv: https://arxiv.org/abs/2305.14267
+ '''
+ extra_args = {} if extra_args is None else extra_args
+ seed = extra_args.get("seed", None)
+ noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
+ s_in = x.new_ones([x.shape[0]])
+
+ inject_noise = eta > 0 and s_noise > 0
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ if sigmas[i + 1] == 0:
+ x = denoised
+ else:
+ t, t_next = -sigmas[i].log(), -sigmas[i + 1].log()
+ h = t_next - t
+ h_eta = h * (eta + 1)
+ s_1 = t + r_1 * h
+ s_2 = t + r_2 * h
+ sigma_s_1, sigma_s_2 = s_1.neg().exp(), s_2.neg().exp()
+
+ coeff_1, coeff_2, coeff_3 = (-r_1 * h_eta).expm1(), (-r_2 * h_eta).expm1(), (-h_eta).expm1()
+ if inject_noise:
+ noise_coeff_1 = (-2 * r_1 * h * eta).expm1().neg().sqrt()
+ noise_coeff_2 = ((-2 * r_1 * h * eta).expm1() - (-2 * r_2 * h * eta).expm1()).sqrt()
+ noise_coeff_3 = ((-2 * r_2 * h * eta).expm1() - (-2 * h * eta).expm1()).sqrt()
+ noise_1, noise_2, noise_3 = noise_sampler(sigmas[i], sigma_s_1), noise_sampler(sigma_s_1, sigma_s_2), noise_sampler(sigma_s_2, sigmas[i + 1])
+
+ # Step 1
+ x_2 = (coeff_1 + 1) * x - coeff_1 * denoised
+ if inject_noise:
+ x_2 = x_2 + sigma_s_1 * (noise_coeff_1 * noise_1) * s_noise
+ denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
+
+ # Step 2
+ x_3 = (coeff_2 + 1) * x - coeff_2 * denoised + (r_2 / r_1) * (coeff_2 / (r_2 * h_eta) + 1) * (denoised_2 - denoised)
+ if inject_noise:
+ x_3 = x_3 + sigma_s_2 * (noise_coeff_2 * noise_1 + noise_coeff_1 * noise_2) * s_noise
+ denoised_3 = model(x_3, sigma_s_2 * s_in, **extra_args)
+
+ # Step 3
+ x = (coeff_3 + 1) * x - coeff_3 * denoised + (1. / r_2) * (coeff_3 / h_eta + 1) * (denoised_3 - denoised)
+ if inject_noise:
+ x = x + sigmas[i + 1] * (noise_coeff_3 * noise_1 + noise_coeff_2 * noise_2 + noise_coeff_1 * noise_3) * s_noise
return x
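The noise coefficients in the two SEEDS samplers are chosen so that the variances injected at the sub-steps compose to that of a single ancestral step (the noises are independent, so variances add). A quick check of that identity for the SEEDS-2 definitions:

```python
import torch

h, eta, r = torch.tensor(0.7), torch.tensor(1.0), torch.tensor(0.5)

var_1 = (-2 * r * h * eta).expm1().neg()                     # noise_coeff_1 ** 2
var_2 = (-2 * r * h * eta).expm1() - (-2 * h * eta).expm1()  # noise_coeff_2 ** 2

# Together the two injections carry the variance of one full step.
total = var_1 + var_2
expected = (-2 * h * eta).expm1().neg()                      # 1 - exp(-2*h*eta)
assert torch.allclose(total, expected)
```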
diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py
index e98982c94..82d9f9bb8 100644
--- a/comfy/latent_formats.py
+++ b/comfy/latent_formats.py
@@ -407,3 +407,66 @@ class Cosmos1CV8x8x8(LatentFormat):
]
latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976]
+
+class Wan21(LatentFormat):
+ latent_channels = 16
+ latent_dimensions = 3
+
+ latent_rgb_factors = [
+ [-0.1299, -0.1692, 0.2932],
+ [ 0.0671, 0.0406, 0.0442],
+ [ 0.3568, 0.2548, 0.1747],
+ [ 0.0372, 0.2344, 0.1420],
+ [ 0.0313, 0.0189, -0.0328],
+ [ 0.0296, -0.0956, -0.0665],
+ [-0.3477, -0.4059, -0.2925],
+ [ 0.0166, 0.1902, 0.1975],
+ [-0.0412, 0.0267, -0.1364],
+ [-0.1293, 0.0740, 0.1636],
+ [ 0.0680, 0.3019, 0.1128],
+ [ 0.0032, 0.0581, 0.0639],
+ [-0.1251, 0.0927, 0.1699],
+ [ 0.0060, -0.0633, 0.0005],
+ [ 0.3477, 0.2275, 0.2950],
+ [ 0.1984, 0.0913, 0.1861]
+ ]
+
+ latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360]
+
+ def __init__(self):
+ self.scale_factor = 1.0
+ self.latents_mean = torch.tensor([
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
+ ]).view(1, self.latent_channels, 1, 1, 1)
+ self.latents_std = torch.tensor([
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
+ ]).view(1, self.latent_channels, 1, 1, 1)
+
+
+ self.taesd_decoder_name = None #TODO
+
+ def process_in(self, latent):
+ latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+ latents_std = self.latents_std.to(latent.device, latent.dtype)
+ return (latent - latents_mean) * self.scale_factor / latents_std
+
+ def process_out(self, latent):
+ latents_mean = self.latents_mean.to(latent.device, latent.dtype)
+ latents_std = self.latents_std.to(latent.device, latent.dtype)
+ return latent * latents_std / self.scale_factor + latents_mean
+
+class Hunyuan3Dv2(LatentFormat):
+ latent_channels = 64
+ latent_dimensions = 1
+ scale_factor = 0.9990943042622529
+
+class Hunyuan3Dv2mini(LatentFormat):
+ latent_channels = 64
+ latent_dimensions = 1
+ scale_factor = 1.0188137142395404
+
+class ACEAudio(LatentFormat):
+ latent_channels = 8
+ latent_dimensions = 2
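`Wan21.process_in` and `process_out` are exact inverses; this standalone sketch replicates the per-channel normalization (with made-up statistics) to show the round trip:

```python
import torch

latent_channels = 16
mean = torch.randn(1, latent_channels, 1, 1, 1)       # stand-in for latents_mean
std = torch.rand(1, latent_channels, 1, 1, 1) + 0.5   # stand-in for latents_std
scale_factor = 1.0

latent = torch.randn(2, latent_channels, 4, 32, 32)   # [B, C, T, H, W]
normed = (latent - mean) * scale_factor / std         # process_in
restored = normed * std / scale_factor + mean         # process_out
assert torch.allclose(latent, restored, atol=1e-5)
```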
diff --git a/comfy/ldm/ace/attention.py b/comfy/ldm/ace/attention.py
new file mode 100644
index 000000000..f20a01669
--- /dev/null
+++ b/comfy/ldm/ace/attention.py
@@ -0,0 +1,761 @@
+# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/attention.py
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple, Union, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+import comfy.model_management
+from comfy.ldm.modules.attention import optimized_attention
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ kv_heads: Optional[int] = None,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias: bool = False,
+ qk_norm: Optional[str] = None,
+ added_kv_proj_dim: Optional[int] = None,
+ added_proj_bias: Optional[bool] = True,
+ out_bias: bool = True,
+ scale_qk: bool = True,
+ only_cross_attention: bool = False,
+ eps: float = 1e-5,
+ rescale_output_factor: float = 1.0,
+ residual_connection: bool = False,
+ processor=None,
+ out_dim: int = None,
+ out_context_dim: int = None,
+ context_pre_only=None,
+ pre_only=False,
+ elementwise_affine: bool = True,
+ is_causal: bool = False,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.inner_kv_dim = self.inner_dim if kv_heads is None else dim_head * kv_heads
+ self.query_dim = query_dim
+ self.use_bias = bias
+ self.is_cross_attention = cross_attention_dim is not None
+ self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.rescale_output_factor = rescale_output_factor
+ self.residual_connection = residual_connection
+ self.dropout = dropout
+ self.fused_projections = False
+ self.out_dim = out_dim if out_dim is not None else query_dim
+ self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
+ self.context_pre_only = context_pre_only
+ self.pre_only = pre_only
+ self.is_causal = is_causal
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+
+ self.added_kv_proj_dim = added_kv_proj_dim
+ self.only_cross_attention = only_cross_attention
+
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
+ raise ValueError(
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
+ )
+
+ self.group_norm = None
+ self.spatial_norm = None
+
+ self.norm_q = None
+ self.norm_k = None
+
+ self.norm_cross = None
+ self.to_q = operations.Linear(query_dim, self.inner_dim, bias=bias, dtype=dtype, device=device)
+
+ if not self.only_cross_attention:
+ # only relevant for the `AddedKVProcessor` classes
+ self.to_k = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+ self.to_v = operations.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias, dtype=dtype, device=device)
+ else:
+ self.to_k = None
+ self.to_v = None
+
+ self.added_proj_bias = added_proj_bias
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
+ self.add_v_proj = operations.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=added_proj_bias, dtype=dtype, device=device)
+ if self.context_pre_only is not None:
+ self.add_q_proj = operations.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias, dtype=dtype, device=device)
+ else:
+ self.add_q_proj = None
+ self.add_k_proj = None
+ self.add_v_proj = None
+
+ if not self.pre_only:
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(operations.Linear(self.inner_dim, self.out_dim, bias=out_bias, dtype=dtype, device=device))
+ self.to_out.append(nn.Dropout(dropout))
+ else:
+ self.to_out = None
+
+ if self.context_pre_only is not None and not self.context_pre_only:
+ self.to_add_out = operations.Linear(self.inner_dim, self.out_context_dim, bias=out_bias, dtype=dtype, device=device)
+ else:
+ self.to_add_out = None
+
+ self.norm_added_q = None
+ self.norm_added_k = None
+ self.processor = processor
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ **cross_attention_kwargs,
+ ) -> torch.Tensor:
+ return self.processor(
+ self,
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+
+
+class CustomLiteLAProcessor2_0:
+ """Attention processor used typically in processing the SD3-like self-attention projections. add rms norm for query and key and apply RoPE"""
+
+ def __init__(self):
+ self.kernel_func = nn.ReLU(inplace=False)
+ self.eps = 1e-15
+ self.pad_val = 1.0
+
+ def apply_rotary_emb(
+ self,
+ x: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    ) -> torch.Tensor:
+ """
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
+ tensors contain rotary embeddings and are returned as real tensors.
+
+ Args:
+ x (`torch.Tensor`):
+                Query or key tensor to which rotary embeddings are applied, with shape [B, H, S, D].
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
+
+ Returns:
+            torch.Tensor: The modified tensor with rotary embeddings applied.
+ """
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, None]
+ sin = sin[None, None]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+ return out
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+ rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ hidden_states_len = hidden_states.shape[1]
+
+ input_ndim = hidden_states.ndim
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+ if encoder_hidden_states is not None:
+ context_input_ndim = encoder_hidden_states.ndim
+ if context_input_ndim == 4:
+ batch_size, channel, height, width = encoder_hidden_states.shape
+ encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size = hidden_states.shape[0]
+
+ # `sample` projections.
+ dtype = hidden_states.dtype
+ query = attn.to_q(hidden_states)
+ key = attn.to_k(hidden_states)
+ value = attn.to_v(hidden_states)
+
+ # `context` projections.
+ has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
+ if encoder_hidden_states is not None and has_encoder_hidden_state_proj:
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+ # attention
+ if not attn.is_cross_attention:
+ query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+ key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+ value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+ else:
+ query = hidden_states
+ key = encoder_hidden_states
+ value = encoder_hidden_states
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
+ key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
+ value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
+
+        # RoPE expects input of shape [B, H, S, D].
+        # query is currently [B, H, D, S]; transpose it to [B, H, S, D] before applying RoPE.
+        query = query.permute(0, 1, 3, 2)  # [B, H, S, D] (from [B, H, D, S])
+
+ # Apply query and key normalization if needed
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if rotary_freqs_cis is not None:
+ query = self.apply_rotary_emb(query, rotary_freqs_cis)
+ if not attn.is_cross_attention:
+ key = self.apply_rotary_emb(key, rotary_freqs_cis)
+ elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
+ key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
+
+        # query is now [B, H, S, D]; restore it to [B, H, D, S].
+ query = query.permute(0, 1, 3, 2) # [B, H, D, S]
+
+ if attention_mask is not None:
+ # attention_mask: [B, S] -> [B, 1, S, 1]
+ attention_mask = attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S, 1]
+            query = query * attention_mask.permute(0, 1, 3, 2)  # query [B, H, D, S] * mask [B, 1, 1, S]
+            if not attn.is_cross_attention:
+                key = key * attention_mask  # key [B, h, S, D] * mask [B, 1, S, 1]
+                value = value * attention_mask.permute(0, 1, 3, 2)  # value [B, h, D, S] * mask [B, 1, 1, S]
+
+ if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
+ encoder_attention_mask = encoder_attention_mask[:, None, :, None].to(key.dtype) # [B, 1, S_enc, 1]
+            # at this point key: [B, h, S_enc, D], value: [B, h, D, S_enc]
+ key = key * encoder_attention_mask # [B, h, S_enc, D] * [B, 1, S_enc, 1]
+ value = value * encoder_attention_mask.permute(0, 1, 3, 2) # [B, h, D, S_enc] * [B, 1, 1, S_enc]
+
+ query = self.kernel_func(query)
+ key = self.kernel_func(key)
+
+ query, key, value = query.float(), key.float(), value.float()
+
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
+
+ vk = torch.matmul(value, key)
+
+ hidden_states = torch.matmul(vk, query)
+
+ if hidden_states.dtype in [torch.float16, torch.bfloat16]:
+ hidden_states = hidden_states.float()
+
+ hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
+
+ hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
+
+ hidden_states = hidden_states.to(dtype)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = encoder_hidden_states.to(dtype)
+
+ # Split the attention outputs.
+ if encoder_hidden_states is not None and not attn.is_cross_attention and has_encoder_hidden_state_proj:
+ hidden_states, encoder_hidden_states = (
+ hidden_states[:, : hidden_states_len],
+ hidden_states[:, hidden_states_len:],
+ )
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+ if encoder_hidden_states is not None and not attn.context_pre_only and not attn.is_cross_attention and hasattr(attn, "to_add_out"):
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+ if encoder_hidden_states is not None and context_input_ndim == 4:
+ encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if torch.get_autocast_gpu_dtype() == torch.float16:
+ hidden_states = hidden_states.clip(-65504, 65504)
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+ return hidden_states, encoder_hidden_states
+
+
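+# --- Editor's sketches below: illustrative, self-contained checks, not part of
+# --- the upstream ACE-Step code. They document the two tricks used by
+# --- CustomLiteLAProcessor2_0 above.
+
+# 1) Rotary embeddings: with cos=1/sin=0 the rotation is the identity; with
+# cos=0/sin=1 each (even, odd) channel pair (a, b) maps to (-b, a). Shapes
+# follow the docstring: x is [B, H, S, D], cos/sin are [S, D].
+def _rope_sanity_check():
+    B, H, S, D = 1, 2, 3, 4
+    x = torch.randn(B, H, S, D)
+    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
+    x_rot = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+    # identity rotation
+    out = x * torch.ones(S, D)[None, None] + x_rot * torch.zeros(S, D)[None, None]
+    assert torch.allclose(out, x)
+    # quarter-turn rotation: (a, b) -> (-b, a) per channel pair
+    out = x * torch.zeros(S, D)[None, None] + x_rot * torch.ones(S, D)[None, None]
+    assert torch.allclose(out.reshape(B, H, S, -1, 2)[..., 0],
+                          -x.reshape(B, H, S, -1, 2)[..., 1])
+
+# 2) Linear attention normalization: the processor pads the value tensor with a
+# constant row so one chain of matmuls yields both the numerator V K^T q_s and
+# the denominator 1^T K^T q_s of out_s = (V K^T q_s) / (1^T K^T q_s). This
+# sketch assumes kernel_func=ReLU and pad_val=1.0, which are set in the
+# processor's __init__ (outside this hunk).
+def _lite_la_sanity_check():
+    B, H, S, D = 1, 2, 5, 4
+    q = torch.relu(torch.randn(B, H, D, S))  # [B, H, D, S]
+    k = torch.relu(torch.randn(B, H, S, D))  # [B, H, S, D]
+    v = torch.randn(B, H, D, S)              # [B, H, D, S]
+    v_pad = F.pad(v, (0, 0, 0, 1), mode="constant", value=1.0)  # extra row of ones
+    out = torch.matmul(torch.matmul(v_pad, k), q)               # [B, H, D+1, S]
+    out = out[:, :, :-1] / (out[:, :, -1:] + 1e-6)
+    # explicit normalization for comparison
+    num = torch.matmul(torch.matmul(v, k), q)
+    den = torch.matmul(torch.matmul(torch.ones_like(v[:, :, :1]), k), q)
+    assert torch.allclose(out, num / (den + 1e-6), atol=1e-5)
+
+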
+class CustomerAttnProcessor2_0:
+ r"""
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+ """
+
+ def apply_rotary_emb(
+ self,
+ x: torch.Tensor,
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    ) -> torch.Tensor:
+ """
+        Apply rotary embeddings to the given query or key tensor `x` using the provided (cos, sin) frequency
+        tensors in `freqs_cis`. Channel pairs of `x` are treated as the real and imaginary parts of complex
+        numbers, rotated by the precomputed frequencies, and returned as a single real tensor.
+
+        Args:
+            x (`torch.Tensor`):
+                Query or key tensor to apply rotary embeddings to, shaped [B, H, S, D].
+            freqs_cis (`Tuple[torch.Tensor]`): Precomputed (cos, sin) frequency tensors, each shaped [S, D].
+
+        Returns:
+            `torch.Tensor`: The input tensor with rotary embeddings applied.
+ """
+ cos, sin = freqs_cis # [S, D]
+ cos = cos[None, None]
+ sin = sin[None, None]
+ cos, sin = cos.to(x.device), sin.to(x.device)
+
+        x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+ return out
+
+ def __call__(
+ self,
+ attn: Attention,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+ rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+
+ residual = hidden_states
+ input_ndim = hidden_states.ndim
+
+ if input_ndim == 4:
+ batch_size, channel, height, width = hidden_states.shape
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+ batch_size, sequence_length, _ = (
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+ )
+
+ has_encoder_hidden_state_proj = hasattr(attn, "add_q_proj") and hasattr(attn, "add_k_proj") and hasattr(attn, "add_v_proj")
+
+ if attn.group_norm is not None:
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = attn.to_q(hidden_states)
+
+ if encoder_hidden_states is None:
+ encoder_hidden_states = hidden_states
+ elif attn.norm_cross:
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+ key = attn.to_k(encoder_hidden_states)
+ value = attn.to_v(encoder_hidden_states)
+
+ inner_dim = key.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+ if attn.norm_q is not None:
+ query = attn.norm_q(query)
+ if attn.norm_k is not None:
+ key = attn.norm_k(key)
+
+ # Apply RoPE if needed
+ if rotary_freqs_cis is not None:
+ query = self.apply_rotary_emb(query, rotary_freqs_cis)
+ if not attn.is_cross_attention:
+ key = self.apply_rotary_emb(key, rotary_freqs_cis)
+ elif rotary_freqs_cis_cross is not None and has_encoder_hidden_state_proj:
+ key = self.apply_rotary_emb(key, rotary_freqs_cis_cross)
+
+ if attn.is_cross_attention and encoder_attention_mask is not None and has_encoder_hidden_state_proj:
+ # attention_mask: N x S1
+ # encoder_attention_mask: N x S2
+            # cross attention: merge attention_mask and encoder_attention_mask into one additive mask
+ combined_mask = attention_mask[:, :, None] * encoder_attention_mask[:, None, :]
+ attention_mask = torch.where(combined_mask == 1, 0.0, -torch.inf)
+ attention_mask = attention_mask[:, None, :, :].expand(-1, attn.heads, -1, -1).to(query.dtype)
+
+ elif not attn.is_cross_attention and attention_mask is not None:
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+ # scaled_dot_product_attention expects attention_mask shape to be
+ # (batch, heads, source_length, target_length)
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
+ hidden_states = optimized_attention(
+ query, key, value, heads=query.shape[1], mask=attention_mask, skip_reshape=True,
+ ).to(query.dtype)
+
+ # linear proj
+ hidden_states = attn.to_out[0](hidden_states)
+ # dropout
+ hidden_states = attn.to_out[1](hidden_states)
+
+ if input_ndim == 4:
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+ if attn.residual_connection:
+ hidden_states = hidden_states + residual
+
+ hidden_states = hidden_states / attn.rescale_output_factor
+
+ return hidden_states
+
+def val2list(x: list or tuple or any, repeat_time=1) -> list: # type: ignore
+ """Repeat `val` for `repeat_time` times and return the list or val if list/tuple."""
+ if isinstance(x, (list, tuple)):
+ return list(x)
+ return [x for _ in range(repeat_time)]
+
+
+def val2tuple(x: list or tuple or any, min_len: int = 1, idx_repeat: int = -1) -> tuple: # type: ignore
+ """Return tuple with min_len by repeating element at idx_repeat."""
+ # convert to list first
+ x = val2list(x)
+
+ # repeat elements if necessary
+ if len(x) > 0:
+ x[idx_repeat:idx_repeat] = [x[idx_repeat] for _ in range(min_len - len(x))]
+
+ return tuple(x)
+
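+# Editor's sketch (illustrative, not upstream code): these helpers normalize
+# scalar-or-sequence arguments, e.g. the per-stage use_bias/norm/act options
+# of GLUMBConv below.
+def _val_helpers_examples():
+    assert val2list(3) == [3]
+    assert val2list((1, 2)) == [1, 2]
+    assert val2tuple(True, 3) == (True, True, True)
+    assert val2tuple((1, 2), 3) == (1, 2, 2)  # element at idx_repeat=-1 is repeated
+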
+
+def t2i_modulate(x, shift, scale):
+ return x * (1 + scale) + shift
+
+
+def get_same_padding(kernel_size: Union[int, Tuple[int, ...]]) -> Union[int, Tuple[int, ...]]:
+ if isinstance(kernel_size, tuple):
+ return tuple([get_same_padding(ks) for ks in kernel_size])
+ else:
+        assert kernel_size % 2 > 0, f"kernel size {kernel_size} should be an odd number"
+ return kernel_size // 2
+
+class ConvLayer(nn.Module):
+ def __init__(
+ self,
+ in_dim: int,
+ out_dim: int,
+ kernel_size=3,
+ stride=1,
+ dilation=1,
+ groups=1,
+ padding: Union[int, None] = None,
+ use_bias=False,
+ norm=None,
+ act=None,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+ if padding is None:
+ padding = get_same_padding(kernel_size)
+ padding *= dilation
+
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.dilation = dilation
+ self.groups = groups
+ self.padding = padding
+ self.use_bias = use_bias
+
+ self.conv = operations.Conv1d(
+ in_dim,
+ out_dim,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=use_bias,
+ device=device,
+ dtype=dtype
+ )
+ if norm is not None:
+ self.norm = operations.RMSNorm(out_dim, elementwise_affine=False, dtype=dtype, device=device)
+ else:
+ self.norm = None
+ if act is not None:
+ self.act = nn.SiLU(inplace=True)
+ else:
+ self.act = None
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.conv(x)
+ if self.norm:
+ x = self.norm(x)
+ if self.act:
+ x = self.act(x)
+ return x
+
+
+class GLUMBConv(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: int,
+ out_feature=None,
+ kernel_size=3,
+ stride=1,
+ padding: Union[int, None] = None,
+ use_bias=False,
+ norm=(None, None, None),
+ act=("silu", "silu", None),
+ dilation=1,
+ dtype=None, device=None, operations=None
+ ):
+ out_feature = out_feature or in_features
+ super().__init__()
+ use_bias = val2tuple(use_bias, 3)
+ norm = val2tuple(norm, 3)
+ act = val2tuple(act, 3)
+
+ self.glu_act = nn.SiLU(inplace=False)
+ self.inverted_conv = ConvLayer(
+ in_features,
+ hidden_features * 2,
+ 1,
+ use_bias=use_bias[0],
+ norm=norm[0],
+ act=act[0],
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ self.depth_conv = ConvLayer(
+ hidden_features * 2,
+ hidden_features * 2,
+ kernel_size,
+ stride=stride,
+ groups=hidden_features * 2,
+ padding=padding,
+ use_bias=use_bias[1],
+ norm=norm[1],
+ act=None,
+ dilation=dilation,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ self.point_conv = ConvLayer(
+ hidden_features,
+ out_feature,
+ 1,
+ use_bias=use_bias[2],
+ norm=norm[2],
+ act=act[2],
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = x.transpose(1, 2)
+ x = self.inverted_conv(x)
+ x = self.depth_conv(x)
+
+ x, gate = torch.chunk(x, 2, dim=1)
+ gate = self.glu_act(gate)
+ x = x * gate
+
+ x = self.point_conv(x)
+ x = x.transpose(1, 2)
+
+ return x
+
+
+class LinearTransformerBlock(nn.Module):
+ """
+ A Sana block with global shared adaptive layer norm (adaLN-single) conditioning.
+ """
+ def __init__(
+ self,
+ dim,
+ num_attention_heads,
+ attention_head_dim,
+ use_adaln_single=True,
+ cross_attention_dim=None,
+ added_kv_proj_dim=None,
+ context_pre_only=False,
+ mlp_ratio=4.0,
+ add_cross_attention=False,
+ add_cross_attention_dim=None,
+ qk_norm=None,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+
+ self.norm1 = operations.RMSNorm(dim, elementwise_affine=False, eps=1e-6)
+ self.attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ added_kv_proj_dim=added_kv_proj_dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ bias=True,
+ qk_norm=qk_norm,
+ processor=CustomLiteLAProcessor2_0(),
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+
+ self.add_cross_attention = add_cross_attention
+ self.context_pre_only = context_pre_only
+
+ if add_cross_attention and add_cross_attention_dim is not None:
+ self.cross_attn = Attention(
+ query_dim=dim,
+ cross_attention_dim=add_cross_attention_dim,
+ added_kv_proj_dim=add_cross_attention_dim,
+ dim_head=attention_head_dim,
+ heads=num_attention_heads,
+ out_dim=dim,
+ context_pre_only=context_pre_only,
+ bias=True,
+ qk_norm=qk_norm,
+ processor=CustomerAttnProcessor2_0(),
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+
+ self.norm2 = operations.RMSNorm(dim, 1e-06, elementwise_affine=False)
+
+ self.ff = GLUMBConv(
+ in_features=dim,
+ hidden_features=int(dim * mlp_ratio),
+ use_bias=(True, True, False),
+ norm=(None, None, None),
+ act=("silu", "silu", None),
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+ self.use_adaln_single = use_adaln_single
+ if use_adaln_single:
+ self.scale_shift_table = nn.Parameter(torch.empty(6, dim, dtype=dtype, device=device))
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ attention_mask: torch.FloatTensor = None,
+ encoder_attention_mask: torch.FloatTensor = None,
+ rotary_freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+ rotary_freqs_cis_cross: Union[torch.Tensor, Tuple[torch.Tensor]] = None,
+ temb: torch.FloatTensor = None,
+ ):
+
+ N = hidden_states.shape[0]
+
+ # step 1: AdaLN single
+ if self.use_adaln_single:
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+ comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
+ ).chunk(6, dim=1)
+
+ norm_hidden_states = self.norm1(hidden_states)
+ if self.use_adaln_single:
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+
+ # step 2: attention
+ if not self.add_cross_attention:
+ attn_output, encoder_hidden_states = self.attn(
+ hidden_states=norm_hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ rotary_freqs_cis=rotary_freqs_cis,
+ rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+ )
+ else:
+ attn_output, _ = self.attn(
+ hidden_states=norm_hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ rotary_freqs_cis=rotary_freqs_cis,
+ rotary_freqs_cis_cross=None,
+ )
+
+ if self.use_adaln_single:
+ attn_output = gate_msa * attn_output
+ hidden_states = attn_output + hidden_states
+
+ if self.add_cross_attention:
+ attn_output = self.cross_attn(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ rotary_freqs_cis=rotary_freqs_cis,
+ rotary_freqs_cis_cross=rotary_freqs_cis_cross,
+ )
+ hidden_states = attn_output + hidden_states
+
+ # step 3: add norm
+ norm_hidden_states = self.norm2(hidden_states)
+ if self.use_adaln_single:
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+
+ # step 4: feed forward
+ ff_output = self.ff(norm_hidden_states)
+ if self.use_adaln_single:
+ ff_output = gate_mlp * ff_output
+
+ hidden_states = hidden_states + ff_output
+
+ return hidden_states
diff --git a/comfy/ldm/ace/lyric_encoder.py b/comfy/ldm/ace/lyric_encoder.py
new file mode 100644
index 000000000..ff4359b26
--- /dev/null
+++ b/comfy/ldm/ace/lyric_encoder.py
@@ -0,0 +1,1067 @@
+# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/lyrics_utils/lyric_encoder.py
+from typing import Optional, Tuple, Union
+import math
+import torch
+from torch import nn
+
+import comfy.model_management
+
+class ConvolutionModule(nn.Module):
+ """ConvolutionModule in Conformer model."""
+
+ def __init__(self,
+ channels: int,
+ kernel_size: int = 15,
+ activation: nn.Module = nn.ReLU(),
+ norm: str = "batch_norm",
+ causal: bool = False,
+ bias: bool = True,
+ dtype=None, device=None, operations=None):
+ """Construct an ConvolutionModule object.
+ Args:
+ channels (int): The number of channels of conv layers.
+ kernel_size (int): Kernel size of conv layers.
+ causal (int): Whether use causal convolution or not
+ """
+ super().__init__()
+
+ self.pointwise_conv1 = operations.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ dtype=dtype, device=device
+ )
+ # self.lorder is used to distinguish if it's a causal convolution,
+ # if self.lorder > 0: it's a causal convolution, the input will be
+ # padded with self.lorder frames on the left in forward.
+ # else: it's a symmetrical convolution
+ if causal:
+ padding = 0
+ self.lorder = kernel_size - 1
+ else:
+            # kernel_size should be an odd number for non-causal convolution
+ assert (kernel_size - 1) % 2 == 0
+ padding = (kernel_size - 1) // 2
+ self.lorder = 0
+ self.depthwise_conv = operations.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ bias=bias,
+ dtype=dtype, device=device
+ )
+
+ assert norm in ['batch_norm', 'layer_norm']
+ if norm == "batch_norm":
+ self.use_layer_norm = False
+ self.norm = nn.BatchNorm1d(channels)
+ else:
+ self.use_layer_norm = True
+ self.norm = operations.LayerNorm(channels, dtype=dtype, device=device)
+
+ self.pointwise_conv2 = operations.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=bias,
+ dtype=dtype, device=device
+ )
+ self.activation = activation
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ cache: torch.Tensor = torch.zeros((0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute convolution module.
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, channels).
+ mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
+ (0, 0, 0) means fake mask.
+ cache (torch.Tensor): left context cache, it is only
+ used in causal convolution (#batch, channels, cache_t),
+                (0, 0, 0) means fake cache.
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, channels).
+ """
+ # exchange the temporal dimension and the feature dimension
+ x = x.transpose(1, 2) # (#batch, channels, time)
+
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ if self.lorder > 0:
+ if cache.size(2) == 0: # cache_t == 0
+ x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
+ else:
+ assert cache.size(0) == x.size(0) # equal batch
+ assert cache.size(1) == x.size(1) # equal channel
+ x = torch.cat((cache, x), dim=2)
+ assert (x.size(2) > self.lorder)
+ new_cache = x[:, :, -self.lorder:]
+ else:
+ # It's better we just return None if no cache is required,
+ # However, for JIT export, here we just fake one tensor instead of
+ # None.
+ new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+
+ # GLU mechanism
+        x = self.pointwise_conv1(x)  # (batch, 2*channel, time)
+        x = nn.functional.glu(x, dim=1)  # (batch, channel, time)
+
+ # 1D Depthwise Conv
+ x = self.depthwise_conv(x)
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.activation(self.norm(x))
+ if self.use_layer_norm:
+ x = x.transpose(1, 2)
+ x = self.pointwise_conv2(x)
+ # mask batch padding
+ if mask_pad.size(2) > 0: # time > 0
+ x.masked_fill_(~mask_pad, 0.0)
+
+ return x.transpose(1, 2), new_cache
+
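+# Editor's sketch (illustrative, not upstream code): with causal=True the
+# module left-pads by kernel_size - 1 frames and returns that much trailing
+# context as the streaming cache. Plain `torch.nn` stands in for the
+# `operations` namespace that ComfyUI normally injects.
+def _conv_module_demo():
+    m = ConvolutionModule(channels=8, kernel_size=5, causal=True,
+                          operations=torch.nn).eval()
+    x = torch.randn(2, 6, 8)         # (batch, time, channels)
+    y, cache = m(x)
+    assert y.shape == (2, 6, 8)
+    assert cache.shape == (2, 8, 4)  # kernel_size - 1 left-context frames
+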
+class PositionwiseFeedForward(torch.nn.Module):
+ """Positionwise feed forward layer.
+
+    The feed-forward layer is applied at each position of the sequence.
+    The output dim is the same as the input dim.
+
+    Args:
+        idim (int): Input dimension.
+ hidden_units (int): The number of hidden units.
+ dropout_rate (float): Dropout rate.
+ activation (torch.nn.Module): Activation function
+ """
+
+ def __init__(
+ self,
+ idim: int,
+ hidden_units: int,
+ dropout_rate: float,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ dtype=None, device=None, operations=None
+ ):
+ """Construct a PositionwiseFeedForward object."""
+ super(PositionwiseFeedForward, self).__init__()
+ self.w_1 = operations.Linear(idim, hidden_units, dtype=dtype, device=device)
+ self.activation = activation
+ self.dropout = torch.nn.Dropout(dropout_rate)
+ self.w_2 = operations.Linear(hidden_units, idim, dtype=dtype, device=device)
+
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
+ """Forward function.
+
+ Args:
+ xs: input tensor (B, L, D)
+ Returns:
+ output tensor, (B, L, D)
+ """
+ return self.w_2(self.dropout(self.activation(self.w_1(xs))))
+
+class Swish(torch.nn.Module):
+ """Construct an Swish object."""
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Return Swish activation function."""
+ return x * torch.sigmoid(x)
+
+class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self,
+ n_head: int,
+ n_feat: int,
+ dropout_rate: float,
+ key_bias: bool = True,
+ dtype=None, device=None, operations=None):
+ """Construct an MultiHeadedAttention object."""
+ super().__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
+ self.linear_k = operations.Linear(n_feat, n_feat, bias=key_bias, dtype=dtype, device=device)
+ self.linear_v = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
+ self.linear_out = operations.Linear(n_feat, n_feat, dtype=dtype, device=device)
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(
+ self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor, size
+ (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor, size
+ (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor, size
+ (#batch, n_head, time2, d_k).
+
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
+ return q, k, v
+
+ def forward_attention(
+ self,
+ value: torch.Tensor,
+ scores: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
+ ) -> torch.Tensor:
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value, size
+ (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score, size
+ (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask, size (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+
+ if mask is not None and mask.size(2) > 0: # time2 > 0
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ # For last chunk, time2 might be larger than scores.size(-1)
+ mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
+ scores = scores.masked_fill(mask, -float('inf'))
+ attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0) # (batch, head, time1, time2)
+
+ else:
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+ self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+                1. When applying cross attention between decoder and encoder,
+                the batch padding mask for input is in (#batch, 1, T) shape.
+                2. When applying self attention of encoder,
+                the mask is in (#batch, T, T) shape.
+                3. When applying self attention of decoder,
+                the mask is in (#batch, L, L) shape.
+                4. If different positions in the decoder see different blocks
+                of the encoder, such as Mocha, the passed-in mask could be
+                in (#batch, L, T) shape. But there is no such case in current
+                CosyVoice.
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ if cache.size(0) > 0:
+ key_cache, value_cache = torch.split(cache,
+ cache.size(-1) // 2,
+ dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+ new_cache = torch.cat((k, v), dim=-1)
+
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, mask), new_cache
+
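+# Editor's sketch (illustrative, not upstream code): plain `torch.nn` stands in
+# for the `operations` namespace. The second return value is the running KV
+# cache, with keys and values concatenated on the last dimension.
+def _mha_demo():
+    mha = MultiHeadedAttention(n_head=4, n_feat=64, dropout_rate=0.0,
+                               operations=torch.nn)
+    x = torch.randn(2, 5, 64)
+    out, cache = mha(x, x, x)
+    assert out.shape == (2, 5, 64)
+    assert cache.shape == (2, 4, 5, 32)  # per head: d_k keys + d_k values
+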
+
+class RelPositionMultiHeadedAttention(MultiHeadedAttention):
+ """Multi-Head Attention layer with relative position encoding.
+ Paper: https://arxiv.org/abs/1901.02860
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+ """
+
+ def __init__(self,
+ n_head: int,
+ n_feat: int,
+ dropout_rate: float,
+ key_bias: bool = True,
+ dtype=None, device=None, operations=None):
+ """Construct an RelPositionMultiHeadedAttention object."""
+ super().__init__(n_head, n_feat, dropout_rate, key_bias, dtype=dtype, device=device, operations=operations)
+ # linear transformation for positional encoding
+ self.linear_pos = operations.Linear(n_feat, n_feat, bias=False, dtype=dtype, device=device)
+ # these two learnable bias are used in matrix c and matrix d
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ self.pos_bias_u = nn.Parameter(torch.empty(self.h, self.d_k, dtype=dtype, device=device))
+ self.pos_bias_v = nn.Parameter(torch.empty(self.h, self.d_k, dtype=dtype, device=device))
+ # torch.nn.init.xavier_uniform_(self.pos_bias_u)
+ # torch.nn.init.xavier_uniform_(self.pos_bias_v)
+
+ def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
+ """Compute relative positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
+ time1 means the length of query vector.
+
+ Returns:
+ torch.Tensor: Output tensor.
+
+ """
+ zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
+ device=x.device,
+ dtype=x.dtype)
+ x_padded = torch.cat([zero_pad, x], dim=-1)
+
+ x_padded = x_padded.view(x.size()[0],
+ x.size()[1],
+ x.size(3) + 1, x.size(2))
+ x = x_padded[:, :, 1:].view_as(x)[
+ :, :, :, : x.size(-1) // 2 + 1
+ ] # only keep the positions from 0 to time2
+ return x
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ pos_emb: torch.Tensor = torch.empty(0),
+ cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute 'Scaled Dot Product Attention' with rel. positional encoding.
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2), (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): Positional embedding tensor
+ (#batch, time2, size).
+ cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
+ where `cache_t == chunk_size * num_decoding_left_chunks`
+ and `head * d_k == size`
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ q = q.transpose(1, 2) # (batch, time1, head, d_k)
+
+ if cache.size(0) > 0:
+ key_cache, value_cache = torch.split(cache,
+ cache.size(-1) // 2,
+ dim=-1)
+ k = torch.cat([key_cache, k], dim=2)
+ v = torch.cat([value_cache, v], dim=2)
+ # NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
+ # non-trivial to calculate `next_cache_start` here.
+ new_cache = torch.cat((k, v), dim=-1)
+
+ n_batch_pos = pos_emb.size(0)
+ p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
+ p = p.transpose(1, 2) # (batch, head, time1, d_k)
+
+ # (batch, head, time1, d_k)
+ q_with_bias_u = (q + comfy.model_management.cast_to(self.pos_bias_u, dtype=q.dtype, device=q.device)).transpose(1, 2)
+ # (batch, head, time1, d_k)
+ q_with_bias_v = (q + comfy.model_management.cast_to(self.pos_bias_v, dtype=q.dtype, device=q.device)).transpose(1, 2)
+
+ # compute attention score
+ # first compute matrix a and matrix c
+ # as described in https://arxiv.org/abs/1901.02860 Section 3.3
+ # (batch, head, time1, time2)
+ matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
+
+ # compute matrix b and matrix d
+ # (batch, head, time1, time2)
+ matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
+ # NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
+ if matrix_ac.shape != matrix_bd.shape:
+ matrix_bd = self.rel_shift(matrix_bd)
+
+ scores = (matrix_ac + matrix_bd) / math.sqrt(
+ self.d_k) # (batch, head, time1, time2)
+
+ return self.forward_attention(v, scores, mask), new_cache
+
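+# Editor's sketch (illustrative, not upstream code): rel_shift realigns scores
+# computed against the 2*T-1 relative positions [T-1, ..., 0, ..., -(T-1)] so
+# that row i, column j holds the score for relative distance i - j.
+def _rel_shift_demo():
+    T = 3
+    rel = torch.arange(T - 1, -T, -1, dtype=torch.float32)  # [2, 1, 0, -1, -2]
+    x = rel.expand(1, 1, T, 2 * T - 1).clone()
+    shifted = RelPositionMultiHeadedAttention.rel_shift(None, x)  # self is unused
+    expected = torch.tensor([[0., -1., -2.],
+                             [1., 0., -1.],
+                             [2., 1., 0.]])
+    assert torch.allclose(shifted[0, 0], expected)
+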
+
+
+def subsequent_mask(
+ size: int,
+ device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size).
+
+ This mask is used only in decoder which works in an auto-regressive mode.
+ This means the current step could only do attention with its left steps.
+
+    In the encoder, full attention is used when streaming is not necessary and
+    the sequence is not long. In this case, no attention mask is needed.
+
+    When streaming is needed, chunk-based attention is used in the encoder. See
+    subsequent_chunk_mask for the chunk-based attention mask.
+
+    Args:
+        size (int): size of mask
+        device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_mask(3)
+ [[1, 0, 0],
+ [1, 1, 0],
+ [1, 1, 1]]
+ """
+ arange = torch.arange(size, device=device)
+ mask = arange.expand(size, size)
+ arange = arange.unsqueeze(-1)
+ mask = mask <= arange
+ return mask
+
+
+def subsequent_chunk_mask(
+ size: int,
+ chunk_size: int,
+ num_left_chunks: int = -1,
+ device: torch.device = torch.device("cpu"),
+ ) -> torch.Tensor:
+ """Create mask for subsequent steps (size, size) with chunk size,
+ this is for streaming encoder
+
+ Args:
+ size (int): size of mask
+ chunk_size (int): size of chunk
+ num_left_chunks (int): number of left chunks
+ <0: use full chunk
+ >=0: use num_left_chunks
+ device (torch.device): "cpu" or "cuda" or torch.Tensor.device
+
+ Returns:
+ torch.Tensor: mask
+
+ Examples:
+ >>> subsequent_chunk_mask(4, 2)
+ [[1, 1, 0, 0],
+ [1, 1, 0, 0],
+ [1, 1, 1, 1],
+ [1, 1, 1, 1]]
+ """
+ ret = torch.zeros(size, size, device=device, dtype=torch.bool)
+ for i in range(size):
+ if num_left_chunks < 0:
+ start = 0
+ else:
+ start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
+ ending = min((i // chunk_size + 1) * chunk_size, size)
+ ret[i, start:ending] = True
+ return ret
+
+def add_optional_chunk_mask(xs: torch.Tensor,
+ masks: torch.Tensor,
+ use_dynamic_chunk: bool,
+ use_dynamic_left_chunk: bool,
+ decoding_chunk_size: int,
+ static_chunk_size: int,
+ num_decoding_left_chunks: int,
+ enable_full_context: bool = True):
+ """ Apply optional mask for encoder.
+
+ Args:
+ xs (torch.Tensor): padded input, (B, L, D), L for max length
+        masks (torch.Tensor): mask for xs, (B, 1, L)
+ use_dynamic_chunk (bool): whether to use dynamic chunk or not
+ use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
+ training.
+ decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ static_chunk_size (int): chunk size for static chunk training/decoding
+ if it's greater than 0, if use_dynamic_chunk is true,
+ this parameter will be ignored
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ enable_full_context (bool):
+ True: chunk size is either [1, 25] or full context(max_len)
+ False: chunk size ~ U[1, 25]
+
+ Returns:
+ torch.Tensor: chunk mask of the input xs.
+ """
+ # Whether to use chunk mask or not
+ if use_dynamic_chunk:
+ max_len = xs.size(1)
+ if decoding_chunk_size < 0:
+ chunk_size = max_len
+ num_left_chunks = -1
+ elif decoding_chunk_size > 0:
+ chunk_size = decoding_chunk_size
+ num_left_chunks = num_decoding_left_chunks
+ else:
+ # chunk size is either [1, 25] or full context(max_len).
+ # Since we use 4 times subsampling and allow up to 1s(100 frames)
+ # delay, the maximum frame is 100 / 4 = 25.
+ chunk_size = torch.randint(1, max_len, (1, )).item()
+ num_left_chunks = -1
+ if chunk_size > max_len // 2 and enable_full_context:
+ chunk_size = max_len
+ else:
+ chunk_size = chunk_size % 25 + 1
+ if use_dynamic_left_chunk:
+ max_left_chunks = (max_len - 1) // chunk_size
+ num_left_chunks = torch.randint(0, max_left_chunks,
+ (1, )).item()
+ chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ elif static_chunk_size > 0:
+ num_left_chunks = num_decoding_left_chunks
+ chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
+ num_left_chunks,
+ xs.device) # (L, L)
+ chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
+ chunk_masks = masks & chunk_masks # (B, L, L)
+ else:
+ chunk_masks = masks
+ return chunk_masks
+
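+# Editor's sketch (illustrative, not upstream code): the lyric encoder uses
+# static chunk masking (static_chunk_size=1 gives a causal mask, 0 keeps only
+# the padding mask). The (B, 1, T) padding mask is ANDed with the (1, T, T)
+# chunk mask to give a per-example (B, T, T) attention mask.
+def _chunk_mask_demo():
+    xs = torch.zeros(1, 4, 8)                      # (B, T, D)
+    masks = torch.ones(1, 1, 4, dtype=torch.bool)  # no padding
+    chunk_masks = add_optional_chunk_mask(
+        xs, masks,
+        use_dynamic_chunk=False, use_dynamic_left_chunk=False,
+        decoding_chunk_size=0, static_chunk_size=1,
+        num_decoding_left_chunks=-1)
+    # static_chunk_size=1 with unlimited left context -> lower-triangular mask
+    assert torch.equal(chunk_masks[0], torch.ones(4, 4, dtype=torch.bool).tril())
+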
+
+class ConformerEncoderLayer(nn.Module):
+ """Encoder layer module.
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
+ instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ feed_forward_macaron (torch.nn.Module): Additional feed-forward module
+ instance.
+ `PositionwiseFeedForward` instance can be used as the argument.
+ conv_module (torch.nn.Module): Convolution module instance.
+            `ConvolutionModule` instance can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool):
+ True: use layer_norm before each sub-block.
+ False: use layer_norm after each sub-block.
+ """
+
+ def __init__(
+ self,
+ size: int,
+ self_attn: torch.nn.Module,
+ feed_forward: Optional[nn.Module] = None,
+ feed_forward_macaron: Optional[nn.Module] = None,
+ conv_module: Optional[nn.Module] = None,
+ dropout_rate: float = 0.1,
+ normalize_before: bool = True,
+ dtype=None, device=None, operations=None
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.conv_module = conv_module
+ self.norm_ff = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device) # for the FNN module
+ self.norm_mha = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device) # for the MHA module
+ if feed_forward_macaron is not None:
+ self.norm_ff_macaron = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device)
+ self.ff_scale = 0.5
+ else:
+ self.ff_scale = 1.0
+ if self.conv_module is not None:
+ self.norm_conv = operations.LayerNorm(size, eps=1e-5, dtype=dtype, device=device) # for the CNN module
+ self.norm_final = operations.LayerNorm(
+ size, eps=1e-5, dtype=dtype, device=device) # for the final output of the block
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ mask: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
+ att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Compute encoded features.
+
+ Args:
+ x (torch.Tensor): (#batch, time, size)
+            mask (torch.Tensor): Mask tensor for the input (#batch, time, time),
+ (0, 0, 0) means fake mask.
+ pos_emb (torch.Tensor): positional encoding, must not be None
+ for ConformerEncoderLayer.
+ mask_pad (torch.Tensor): batch padding mask used for conv module.
+                (#batch, 1, time), (0, 0, 0) means fake mask.
+ att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
+ (#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
+ cnn_cache (torch.Tensor): Convolution cache in conformer layer
+ (#batch=1, size, cache_t2)
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time, time).
+ torch.Tensor: att_cache tensor,
+ (#batch=1, head, cache_t1 + time, d_k * 2).
+            torch.Tensor: cnn_cache tensor (#batch, size, cache_t2).
+ """
+
+ # whether to use macaron style
+ if self.feed_forward_macaron is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff_macaron(x)
+ x = residual + self.ff_scale * self.dropout(
+ self.feed_forward_macaron(x))
+ if not self.normalize_before:
+ x = self.norm_ff_macaron(x)
+
+ # multi-headed self-attention module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_mha(x)
+ x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
+ att_cache)
+ x = residual + self.dropout(x_att)
+ if not self.normalize_before:
+ x = self.norm_mha(x)
+
+ # convolution module
+ # Fake new cnn cache here, and then change it in conv_module
+ new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
+ if self.conv_module is not None:
+ residual = x
+ if self.normalize_before:
+ x = self.norm_conv(x)
+ x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
+ x = residual + self.dropout(x)
+
+ if not self.normalize_before:
+ x = self.norm_conv(x)
+
+ # feed forward module
+ residual = x
+ if self.normalize_before:
+ x = self.norm_ff(x)
+
+ x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm_ff(x)
+
+ if self.conv_module is not None:
+ x = self.norm_final(x)
+
+ return x, mask, new_att_cache, new_cnn_cache
+
+
+
+class EspnetRelPositionalEncoding(torch.nn.Module):
+ """Relative positional encoding module (new implementation).
+
+ Details can be found in https://github.com/espnet/espnet/pull/2816.
+
+ See : Appendix B in https://arxiv.org/abs/1901.02860
+
+ Args:
+ d_model (int): Embedding dimension.
+ dropout_rate (float): Dropout rate.
+ max_len (int): Maximum input length.
+
+ """
+
+ def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
+ """Construct an PositionalEncoding object."""
+ super(EspnetRelPositionalEncoding, self).__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+ def extend_pe(self, x: torch.Tensor):
+ """Reset the positional encodings."""
+ if self.pe is not None:
+ # self.pe contains both positive and negative parts
+ # the length of self.pe is 2 * input_len - 1
+ if self.pe.size(1) >= x.size(1) * 2 - 1:
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+        # Suppose `i` means to the position of query vector and `j` means the
+        # position of key vector. We use positive relative positions when keys
+        # are to the left (i>j) and negative relative positions otherwise (i<j).
+        pe_positive = torch.zeros(x.size(1), self.d_model)
+        pe_negative = torch.zeros(x.size(1), self.d_model)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe_positive[:, 0::2] = torch.sin(position * div_term)
+        pe_positive[:, 1::2] = torch.cos(position * div_term)
+        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
+        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
+
+        # Reverse the order of positive indices and concat both positive and
+        # negative indices. This is used to support the shifting trick
+        # as in https://arxiv.org/abs/1901.02860
+        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
+        pe_negative = pe_negative[1:].unsqueeze(0)
+        pe = torch.cat([pe_positive, pe_negative], dim=1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Add positional encoding.
+
+ Args:
+ x (torch.Tensor): Input tensor (batch, time, `*`).
+
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+            torch.Tensor: Positional embedding tensor (1, 2*time-1, `*`).
+
+ """
+ self.extend_pe(x)
+ x = x * self.xscale
+ pos_emb = self.position_encoding(size=x.size(1), offset=offset)
+ return self.dropout(x), self.dropout(pos_emb)
+
+ def position_encoding(self,
+ offset: Union[int, torch.Tensor],
+ size: int) -> torch.Tensor:
+ """ For getting encoding in a streaming fashion
+
+        Attention!!!!!
+        we apply dropout only once at the whole utterance level in a
+        non-streaming way, but this function will be called several times
+        with increasing input size in a streaming scenario, so the dropout
+        will be applied several times.
+
+ Args:
+ offset (int or torch.tensor): start offset
+ size (int): required size of position encoding
+
+ Returns:
+ torch.Tensor: Corresponding encoding
+ """
+ pos_emb = self.pe[
+ :,
+ self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
+ ]
+ return pos_emb
+
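+# Editor's sketch (illustrative, not upstream code): the table stores both
+# positive and negative relative positions, so a length-T input yields a
+# (1, 2*T-1, D) positional embedding centered on relative position 0.
+def _rel_pos_demo():
+    enc = EspnetRelPositionalEncoding(d_model=16, dropout_rate=0.0, max_len=64)
+    x = torch.zeros(2, 10, 16)
+    x_scaled, pos_emb = enc(x)
+    assert x_scaled.shape == (2, 10, 16)
+    assert pos_emb.shape == (1, 19, 16)  # 2*10 - 1 relative positions
+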
+
+
+class LinearEmbed(torch.nn.Module):
+ """Linear transform the input without subsampling
+
+ Args:
+ idim (int): Input dimension.
+ odim (int): Output dimension.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, idim: int, odim: int, dropout_rate: float,
+ pos_enc_class: torch.nn.Module, dtype=None, device=None, operations=None):
+ """Construct an linear object."""
+ super().__init__()
+ self.out = torch.nn.Sequential(
+ operations.Linear(idim, odim, dtype=dtype, device=device),
+ operations.LayerNorm(odim, eps=1e-5, dtype=dtype, device=device),
+ torch.nn.Dropout(dropout_rate),
+ )
+ self.pos_enc = pos_enc_class #rel_pos_espnet
+
+ def position_encoding(self, offset: Union[int, torch.Tensor],
+ size: int) -> torch.Tensor:
+ return self.pos_enc.position_encoding(offset, size)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ offset: Union[int, torch.Tensor] = 0
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Embed the input.
+
+        Args:
+            x (torch.Tensor): Input tensor (#batch, time, idim).
+            offset (int or torch.Tensor): position offset.
+
+        Returns:
+            torch.Tensor: embedded tensor (#batch, time, odim).
+            torch.Tensor: positional embedding tensor.
+        """
+ x = self.out(x)
+ x, pos_emb = self.pos_enc(x, offset)
+ return x, pos_emb
+
+
+ATTENTION_CLASSES = {
+ "selfattn": MultiHeadedAttention,
+ "rel_selfattn": RelPositionMultiHeadedAttention,
+}
+
+ACTIVATION_CLASSES = {
+ "hardtanh": torch.nn.Hardtanh,
+ "tanh": torch.nn.Tanh,
+ "relu": torch.nn.ReLU,
+ "selu": torch.nn.SELU,
+ "swish": getattr(torch.nn, "SiLU", Swish),
+ "gelu": torch.nn.GELU,
+}
+
+
+def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
+ """Make mask tensor containing indices of padded part.
+
+ See description of make_non_pad_mask.
+
+ Args:
+ lengths (torch.Tensor): Batch of lengths (B,).
+ Returns:
+ torch.Tensor: Mask tensor containing indices of padded part.
+
+ Examples:
+ >>> lengths = [5, 3, 2]
+ >>> make_pad_mask(lengths)
+ masks = [[0, 0, 0, 0 ,0],
+ [0, 0, 0, 1, 1],
+ [0, 0, 1, 1, 1]]
+ """
+ batch_size = lengths.size(0)
+ max_len = max_len if max_len > 0 else lengths.max().item()
+ seq_range = torch.arange(0,
+ max_len,
+ dtype=torch.int64,
+ device=lengths.device)
+ seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
+ seq_length_expand = lengths.unsqueeze(-1)
+ mask = seq_range_expand >= seq_length_expand
+ return mask
+
+#https://github.com/FunAudioLLM/CosyVoice/blob/main/examples/magicdata-read/cosyvoice/conf/cosyvoice.yaml
+class ConformerEncoder(torch.nn.Module):
+ """Conformer encoder module."""
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 1024,
+ attention_heads: int = 16,
+ linear_units: int = 4096,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: str = 'linear',
+ pos_enc_layer_type: str = 'rel_pos_espnet',
+ normalize_before: bool = True,
+ static_chunk_size: int = 1, # 1: causal_mask; 0: full_mask
+ use_dynamic_chunk: bool = False,
+ use_dynamic_left_chunk: bool = False,
+ positionwise_conv_kernel_size: int = 1,
+ macaron_style: bool =False,
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = False,
+ cnn_module_kernel: int = 15,
+ causal: bool = False,
+ cnn_module_norm: str = "batch_norm",
+ key_bias: bool = True,
+ dtype=None, device=None, operations=None
+ ):
+ """Construct ConformerEncoder
+
+ Args:
+ input_size to use_dynamic_chunk, see in BaseEncoder
+ positionwise_conv_kernel_size (int): Kernel size of positionwise
+ conv1d layer.
+ macaron_style (bool): Whether to use macaron style for
+ positionwise layer.
+ selfattention_layer_type (str): Encoder attention layer type,
+ the parameter has no effect now, it's just for configure
+ compatibility. #'rel_selfattn'
+ activation_type (str): Encoder activation function type.
+ use_cnn_module (bool): Whether to use convolution module.
+ cnn_module_kernel (int): Kernel size of convolution module.
+ causal (bool): whether to use causal convolution or not.
+            key_bias: whether to use bias in attention.linear_k, False for whisper models.
+ """
+ super().__init__()
+ self.output_size = output_size
+ self.embed = LinearEmbed(input_size, output_size, dropout_rate,
+ EspnetRelPositionalEncoding(output_size, positional_dropout_rate), dtype=dtype, device=device, operations=operations)
+ self.normalize_before = normalize_before
+ self.after_norm = operations.LayerNorm(output_size, eps=1e-5, dtype=dtype, device=device)
+ self.use_dynamic_chunk = use_dynamic_chunk
+
+ self.static_chunk_size = static_chunk_size
+ self.use_dynamic_chunk = use_dynamic_chunk
+ self.use_dynamic_left_chunk = use_dynamic_left_chunk
+ activation = ACTIVATION_CLASSES[activation_type]()
+
+ # self-attention module definition
+ encoder_selfattn_layer_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ key_bias,
+ )
+ # feed-forward module definition
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ activation,
+ )
+ # convolution module definition
+ convolution_layer_args = (output_size, cnn_module_kernel, activation,
+ cnn_module_norm, causal)
+
+ self.encoders = torch.nn.ModuleList([
+ ConformerEncoderLayer(
+ output_size,
+ RelPositionMultiHeadedAttention(
+ *encoder_selfattn_layer_args, dtype=dtype, device=device, operations=operations),
+ PositionwiseFeedForward(*positionwise_layer_args, dtype=dtype, device=device, operations=operations),
+ PositionwiseFeedForward(
+ *positionwise_layer_args, dtype=dtype, device=device, operations=operations) if macaron_style else None,
+ ConvolutionModule(
+ *convolution_layer_args, dtype=dtype, device=device, operations=operations) if use_cnn_module else None,
+ dropout_rate,
+ normalize_before, dtype=dtype, device=device, operations=operations
+ ) for _ in range(num_blocks)
+ ])
+
+ def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
+ pos_emb: torch.Tensor,
+ mask_pad: torch.Tensor) -> torch.Tensor:
+ for layer in self.encoders:
+ xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
+ return xs
+
+ def forward(
+ self,
+ xs: torch.Tensor,
+ pad_mask: torch.Tensor,
+ decoding_chunk_size: int = 0,
+ num_decoding_left_chunks: int = -1,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Embed positions in tensor.
+
+ Args:
+ xs: padded input tensor (B, T, D)
+            pad_mask: batch padding mask (B, T), 1/True for valid positions
+ decoding_chunk_size: decoding chunk size for dynamic chunk
+ 0: default for training, use random dynamic chunk.
+ <0: for decoding, use full chunk.
+ >0: for decoding, use fixed chunk size as set.
+ num_decoding_left_chunks: number of left chunks, this is for decoding,
+ the chunk size is decoding_chunk_size.
+ >=0: use num_decoding_left_chunks
+ <0: use all left chunks
+ Returns:
+            encoder output tensor xs (B, T, D) and the batch padding mask
+            masks (B, 1, T); the linear embedding layer applies no subsampling.
+ """
+ masks = None
+ if pad_mask is not None:
+ masks = pad_mask.to(torch.bool).unsqueeze(1) # (B, 1, T)
+ xs, pos_emb = self.embed(xs)
+ mask_pad = masks # (B, 1, T/subsample_rate)
+ chunk_masks = add_optional_chunk_mask(xs, masks,
+ self.use_dynamic_chunk,
+ self.use_dynamic_left_chunk,
+ decoding_chunk_size,
+ self.static_chunk_size,
+ num_decoding_left_chunks)
+
+ xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
+ if self.normalize_before:
+ xs = self.after_norm(xs)
+ # Here we assume the mask is not changed in encoder layers, so just
+ # return the masks before encoder layers, and the masks will be used
+ # for cross attention with decoder later
+ return xs, masks
+
diff --git a/comfy/ldm/ace/model.py b/comfy/ldm/ace/model.py
new file mode 100644
index 000000000..12c524701
--- /dev/null
+++ b/comfy/ldm/ace/model.py
@@ -0,0 +1,385 @@
+# Original from: https://github.com/ace-step/ACE-Step/blob/main/models/ace_step_transformer.py
+
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional, List, Union
+
+import torch
+from torch import nn
+
+import comfy.model_management
+
+from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
+from .attention import LinearTransformerBlock, t2i_modulate
+from .lyric_encoder import ConformerEncoder as LyricEncoder
+
+
+def cross_norm(hidden_states, controlnet_input):
+ # input N x T x c
+ mean_hidden_states, std_hidden_states = hidden_states.mean(dim=(1,2), keepdim=True), hidden_states.std(dim=(1,2), keepdim=True)
+ mean_controlnet_input, std_controlnet_input = controlnet_input.mean(dim=(1,2), keepdim=True), controlnet_input.std(dim=(1,2), keepdim=True)
+ controlnet_input = (controlnet_input - mean_controlnet_input) * (std_hidden_states / (std_controlnet_input + 1e-12)) + mean_hidden_states
+ return controlnet_input
+
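+# Editor's sketch (illustrative, not upstream code): cross_norm re-standardizes
+# the controlnet stream to match the per-example mean/std of the hidden states,
+# so the two streams can be summed without a scale mismatch.
+def _cross_norm_demo():
+    hidden = torch.randn(2, 16, 8) * 3.0 + 1.0
+    ctrl = torch.randn(2, 16, 8) * 10.0 - 5.0
+    out = cross_norm(hidden, ctrl)
+    assert torch.allclose(out.mean(dim=(1, 2)), hidden.mean(dim=(1, 2)), atol=1e-4)
+    assert torch.allclose(out.std(dim=(1, 2)), hidden.std(dim=(1, 2)), atol=1e-4)
+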
+
+# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2
+class Qwen2RotaryEmbedding(nn.Module):
+ def __init__(self, dim, max_position_embeddings=2048, base=10000, dtype=None, device=None):
+ super().__init__()
+
+ self.dim = dim
+ self.max_position_embeddings = max_position_embeddings
+ self.base = base
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64, device=device).float() / self.dim))
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+ # Build here to make `torch.jit.trace` work.
+ self._set_cos_sin_cache(
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
+ )
+
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
+ self.max_seq_len_cached = seq_len
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
+ freqs = torch.outer(t, self.inv_freq)
+ # Different from paper, but it uses a different permutation in order to obtain the same calculation
+ emb = torch.cat((freqs, freqs), dim=-1)
+ self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+ self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+ def forward(self, x, seq_len=None):
+ # x: [bs, num_attention_heads, seq_len, head_size]
+ if seq_len > self.max_seq_len_cached:
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+ return (
+ self.cos_cached[:seq_len].to(dtype=x.dtype),
+ self.sin_cached[:seq_len].to(dtype=x.dtype),
+ )
+
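+# Editor's sketch (illustrative, not upstream code): the rotary table is built
+# once for max_position_embeddings and sliced per call; cos/sin come back as
+# [seq_len, dim] with the frequency pattern duplicated across both halves.
+def _rotary_demo():
+    rope = Qwen2RotaryEmbedding(dim=8, max_position_embeddings=32, base=10000)
+    x = torch.randn(1, 2, 10, 8)  # [bs, num_heads, seq_len, head_dim]
+    cos, sin = rope(x, seq_len=10)
+    assert cos.shape == (10, 8) and sin.shape == (10, 8)
+    assert torch.allclose(cos[:, :4], cos[:, 4:])  # halves share frequencies
+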
+
+class T2IFinalLayer(nn.Module):
+ """
+ The final layer of Sana.
+ """
+
+ def __init__(self, hidden_size, patch_size=[16, 1], out_channels=256, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.norm_final = operations.RMSNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.linear = operations.Linear(hidden_size, patch_size[0] * patch_size[1] * out_channels, bias=True, dtype=dtype, device=device)
+ self.scale_shift_table = nn.Parameter(torch.empty(2, hidden_size, dtype=dtype, device=device))
+ self.out_channels = out_channels
+ self.patch_size = patch_size
+
+    def unpatchify(
+ self,
+ hidden_states: torch.Tensor,
+ width: int,
+ ):
+        # unpatchify back to (N, C, H, W)
+ new_height, new_width = 1, hidden_states.size(1)
+ hidden_states = hidden_states.reshape(
+ shape=(hidden_states.shape[0], new_height, new_width, self.patch_size[0], self.patch_size[1], self.out_channels)
+ ).contiguous()
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(hidden_states.shape[0], self.out_channels, new_height * self.patch_size[0], new_width * self.patch_size[1])
+ ).contiguous()
+ if width > new_width:
+ output = torch.nn.functional.pad(output, (0, width - new_width, 0, 0), 'constant', 0)
+ elif width < new_width:
+ output = output[:, :, :, :width]
+ return output
+
+ def forward(self, x, t, output_length):
+ shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
+ x = t2i_modulate(self.norm_final(x), shift, scale)
+ x = self.linear(x)
+        # unpatchify
+        output = self.unpatchify(x, output_length)
+ return output
+
+
+class PatchEmbed(nn.Module):
+ """2D Image to Patch Embedding"""
+
+ def __init__(
+ self,
+ height=16,
+ width=4096,
+ patch_size=(16, 1),
+ in_channels=8,
+ embed_dim=1152,
+ bias=True,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+ patch_size_h, patch_size_w = patch_size
+ self.early_conv_layers = nn.Sequential(
+ operations.Conv2d(in_channels, in_channels*256, kernel_size=patch_size, stride=patch_size, padding=0, bias=bias, dtype=dtype, device=device),
+ operations.GroupNorm(num_groups=32, num_channels=in_channels*256, eps=1e-6, affine=True, dtype=dtype, device=device),
+ operations.Conv2d(in_channels*256, embed_dim, kernel_size=1, stride=1, padding=0, bias=bias, dtype=dtype, device=device)
+ )
+ self.patch_size = patch_size
+ self.height, self.width = height // patch_size_h, width // patch_size_w
+ self.base_size = self.width
+
+ def forward(self, latent):
+        # early convolutions: N x C x H x W -> N x embed_dim x H/patch_h x W/patch_w
+ latent = self.early_conv_layers(latent)
+ latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
+ return latent
+
+
+class ACEStepTransformer2DModel(nn.Module):
+ # _supports_gradient_checkpointing = True
+
+ def __init__(
+ self,
+ in_channels: Optional[int] = 8,
+ num_layers: int = 28,
+ inner_dim: int = 1536,
+ attention_head_dim: int = 64,
+ num_attention_heads: int = 24,
+ mlp_ratio: float = 4.0,
+ out_channels: int = 8,
+ max_position: int = 32768,
+ rope_theta: float = 1000000.0,
+ speaker_embedding_dim: int = 512,
+ text_embedding_dim: int = 768,
+ ssl_encoder_depths: List[int] = [9, 9],
+ ssl_names: List[str] = ["mert", "m-hubert"],
+ ssl_latent_dims: List[int] = [1024, 768],
+ lyric_encoder_vocab_size: int = 6681,
+ lyric_hidden_size: int = 1024,
+ patch_size: List[int] = [16, 1],
+ max_height: int = 16,
+ max_width: int = 4096,
+ audio_model=None,
+        dtype=None, device=None, operations=None
+    ):
+ super().__init__()
+
+ self.dtype = dtype
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+ self.inner_dim = inner_dim
+ self.out_channels = out_channels
+ self.max_position = max_position
+ self.patch_size = patch_size
+
+ self.rope_theta = rope_theta
+
+ self.rotary_emb = Qwen2RotaryEmbedding(
+ dim=self.attention_head_dim,
+ max_position_embeddings=self.max_position,
+ base=self.rope_theta,
+ dtype=dtype,
+ device=device,
+ )
+
+ # 2. Define input layers
+ self.in_channels = in_channels
+
+ self.num_layers = num_layers
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ LinearTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=self.num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ mlp_ratio=mlp_ratio,
+ add_cross_attention=True,
+ add_cross_attention_dim=self.inner_dim,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+                for _ in range(self.num_layers)
+ ]
+ )
+
+ self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=self.inner_dim, dtype=dtype, device=device, operations=operations)
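+        # t_block expands the timestep embedding into 6 * inner_dim modulation
+        # parameters (shift/scale/gate pairs) that every transformer block consumes.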
+ self.t_block = nn.Sequential(nn.SiLU(), operations.Linear(self.inner_dim, 6 * self.inner_dim, bias=True, dtype=dtype, device=device))
+
+ # speaker
+ self.speaker_embedder = operations.Linear(speaker_embedding_dim, self.inner_dim, dtype=dtype, device=device)
+
+ # genre
+ self.genre_embedder = operations.Linear(text_embedding_dim, self.inner_dim, dtype=dtype, device=device)
+
+ # lyric
+ self.lyric_embs = operations.Embedding(lyric_encoder_vocab_size, lyric_hidden_size, dtype=dtype, device=device)
+ self.lyric_encoder = LyricEncoder(input_size=lyric_hidden_size, static_chunk_size=0, dtype=dtype, device=device, operations=operations)
+ self.lyric_proj = operations.Linear(lyric_hidden_size, self.inner_dim, dtype=dtype, device=device)
+
+ projector_dim = 2 * self.inner_dim
+
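+        # Projection heads onto the SSL feature spaces (MERT / m-hubert). They are
+        # not used by the forward path below; they exist so the full ACE checkpoint
+        # loads cleanly (upstream uses them for training-time alignment losses).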
+ self.projectors = nn.ModuleList([
+ nn.Sequential(
+ operations.Linear(self.inner_dim, projector_dim, dtype=dtype, device=device),
+ nn.SiLU(),
+ operations.Linear(projector_dim, projector_dim, dtype=dtype, device=device),
+ nn.SiLU(),
+ operations.Linear(projector_dim, ssl_dim, dtype=dtype, device=device),
+ ) for ssl_dim in ssl_latent_dims
+ ])
+
+ self.proj_in = PatchEmbed(
+ height=max_height,
+ width=max_width,
+ patch_size=patch_size,
+ embed_dim=self.inner_dim,
+ bias=True,
+ dtype=dtype,
+ device=device,
+ operations=operations,
+ )
+
+ self.final_layer = T2IFinalLayer(self.inner_dim, patch_size=patch_size, out_channels=out_channels, dtype=dtype, device=device, operations=operations)
+
+ def forward_lyric_encoder(
+ self,
+ lyric_token_idx: Optional[torch.LongTensor] = None,
+ lyric_mask: Optional[torch.LongTensor] = None,
+ out_dtype=None,
+ ):
+ # N x T x D
+ lyric_embs = self.lyric_embs(lyric_token_idx, out_dtype=out_dtype)
+ prompt_prenet_out, _mask = self.lyric_encoder(lyric_embs, lyric_mask, decoding_chunk_size=1, num_decoding_left_chunks=-1)
+ prompt_prenet_out = self.lyric_proj(prompt_prenet_out)
+ return prompt_prenet_out
+
+ def encode(
+ self,
+ encoder_text_hidden_states: Optional[torch.Tensor] = None,
+ text_attention_mask: Optional[torch.LongTensor] = None,
+ speaker_embeds: Optional[torch.FloatTensor] = None,
+ lyric_token_idx: Optional[torch.LongTensor] = None,
+ lyric_mask: Optional[torch.LongTensor] = None,
+ lyrics_strength=1.0,
+ ):
+
+ bs = encoder_text_hidden_states.shape[0]
+ device = encoder_text_hidden_states.device
+
+ # speaker embedding
+ encoder_spk_hidden_states = self.speaker_embedder(speaker_embeds).unsqueeze(1)
+
+ # genre embedding
+ encoder_text_hidden_states = self.genre_embedder(encoder_text_hidden_states)
+
+ # lyric
+ encoder_lyric_hidden_states = self.forward_lyric_encoder(
+ lyric_token_idx=lyric_token_idx,
+ lyric_mask=lyric_mask,
+ out_dtype=encoder_text_hidden_states.dtype,
+ )
+
+ encoder_lyric_hidden_states *= lyrics_strength
+
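+        # Conditioning sequence layout: [speaker (1 token), genre/text tokens, lyric tokens].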
+ encoder_hidden_states = torch.cat([encoder_spk_hidden_states, encoder_text_hidden_states, encoder_lyric_hidden_states], dim=1)
+
+ encoder_hidden_mask = None
+ if text_attention_mask is not None:
+ speaker_mask = torch.ones(bs, 1, device=device)
+ encoder_hidden_mask = torch.cat([speaker_mask, text_attention_mask, lyric_mask], dim=1)
+
+ return encoder_hidden_states, encoder_hidden_mask
+
+ def decode(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ encoder_hidden_mask: torch.Tensor,
+ timestep: Optional[torch.Tensor],
+ output_length: int = 0,
+ block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
+ controlnet_scale: Union[float, torch.Tensor] = 1.0,
+ ):
+ embedded_timestep = self.timestep_embedder(self.time_proj(timestep).to(dtype=hidden_states.dtype))
+ temb = self.t_block(embedded_timestep)
+
+ hidden_states = self.proj_in(hidden_states)
+
+ # controlnet logic
+ if block_controlnet_hidden_states is not None:
+ control_condi = cross_norm(hidden_states, block_controlnet_hidden_states)
+ hidden_states = hidden_states + control_condi * controlnet_scale
+
+ # inner_hidden_states = []
+
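+        # Separate rotary tables for the self-attention (latent) and
+        # cross-attention (conditioning) sequence lengths.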
+ rotary_freqs_cis = self.rotary_emb(hidden_states, seq_len=hidden_states.shape[1])
+ encoder_rotary_freqs_cis = self.rotary_emb(encoder_hidden_states, seq_len=encoder_hidden_states.shape[1])
+
+ for index_block, block in enumerate(self.transformer_blocks):
+ hidden_states = block(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_hidden_mask,
+ rotary_freqs_cis=rotary_freqs_cis,
+ rotary_freqs_cis_cross=encoder_rotary_freqs_cis,
+ temb=temb,
+ )
+
+ output = self.final_layer(hidden_states, embedded_timestep, output_length)
+ return output
+
+ def forward(
+ self,
+ x,
+ timestep,
+ attention_mask=None,
+ context: Optional[torch.Tensor] = None,
+ text_attention_mask: Optional[torch.LongTensor] = None,
+ speaker_embeds: Optional[torch.FloatTensor] = None,
+ lyric_token_idx: Optional[torch.LongTensor] = None,
+ lyric_mask: Optional[torch.LongTensor] = None,
+ block_controlnet_hidden_states: Optional[Union[List[torch.Tensor], torch.Tensor]] = None,
+ controlnet_scale: Union[float, torch.Tensor] = 1.0,
+ lyrics_strength=1.0,
+ **kwargs
+ ):
+ hidden_states = x
+ encoder_text_hidden_states = context
+ encoder_hidden_states, encoder_hidden_mask = self.encode(
+ encoder_text_hidden_states=encoder_text_hidden_states,
+ text_attention_mask=text_attention_mask,
+ speaker_embeds=speaker_embeds,
+ lyric_token_idx=lyric_token_idx,
+ lyric_mask=lyric_mask,
+ lyrics_strength=lyrics_strength,
+ )
+
+ output_length = hidden_states.shape[-1]
+
+ output = self.decode(
+ hidden_states=hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_hidden_mask=encoder_hidden_mask,
+ timestep=timestep,
+ output_length=output_length,
+ block_controlnet_hidden_states=block_controlnet_hidden_states,
+ controlnet_scale=controlnet_scale,
+ )
+
+ return output
diff --git a/comfy/ldm/ace/vae/autoencoder_dc.py b/comfy/ldm/ace/vae/autoencoder_dc.py
new file mode 100644
index 000000000..e7b1d4801
--- /dev/null
+++ b/comfy/ldm/ace/vae/autoencoder_dc.py
@@ -0,0 +1,644 @@
+# Rewritten from diffusers
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Optional, Tuple, Union
+
+import comfy.model_management
+import comfy.ops
+ops = comfy.ops.disable_weight_init
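+# disable_weight_init provides layer variants that skip weight initialization,
+# since the weights are loaded from a checkpoint anyway.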
+
+
+class RMSNorm(ops.RMSNorm):
+ def __init__(self, dim, eps=1e-5, elementwise_affine=True, bias=False):
+ super().__init__(dim, eps=eps, elementwise_affine=elementwise_affine)
+ if elementwise_affine:
+ self.bias = nn.Parameter(torch.empty(dim)) if bias else None
+
+ def forward(self, x):
+ x = super().forward(x)
+ if self.elementwise_affine:
+ if self.bias is not None:
+ x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
+ return x
+
+
+def get_normalization(norm_type, num_features, num_groups=32, eps=1e-5):
+ if norm_type == "batch_norm":
+ return nn.BatchNorm2d(num_features)
+ elif norm_type == "group_norm":
+ return ops.GroupNorm(num_groups, num_features)
+ elif norm_type == "layer_norm":
+ return ops.LayerNorm(num_features)
+ elif norm_type == "rms_norm":
+ return RMSNorm(num_features, eps=eps, elementwise_affine=True, bias=True)
+ else:
+ raise ValueError(f"Unknown normalization type: {norm_type}")
+
+
+def get_activation(activation_type):
+ if activation_type == "relu":
+ return nn.ReLU()
+ elif activation_type == "relu6":
+ return nn.ReLU6()
+ elif activation_type == "silu":
+ return nn.SiLU()
+ elif activation_type == "leaky_relu":
+ return nn.LeakyReLU(0.2)
+ else:
+ raise ValueError(f"Unknown activation type: {activation_type}")
+
+
+class ResBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ norm_type: str = "batch_norm",
+ act_fn: str = "relu6",
+ ) -> None:
+ super().__init__()
+
+ self.norm_type = norm_type
+ self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
+ self.conv1 = ops.Conv2d(in_channels, in_channels, 3, 1, 1)
+ self.conv2 = ops.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
+ self.norm = get_normalization(norm_type, out_channels)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ residual = hidden_states
+ hidden_states = self.conv1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.norm_type == "rms_norm":
+ # move channel to the last dimension so we apply RMSnorm across channel dimension
+ hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
+ else:
+ hidden_states = self.norm(hidden_states)
+
+ return hidden_states + residual
+
+class SanaMultiscaleAttentionProjection(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ num_attention_heads: int,
+ kernel_size: int,
+ ) -> None:
+ super().__init__()
+
+ channels = 3 * in_channels
+ self.proj_in = ops.Conv2d(
+ channels,
+ channels,
+ kernel_size,
+ padding=kernel_size // 2,
+ groups=channels,
+ bias=False,
+ )
+ self.proj_out = ops.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.proj_in(hidden_states)
+ hidden_states = self.proj_out(hidden_states)
+ return hidden_states
+
+class SanaMultiscaleLinearAttention(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+        num_attention_heads: Optional[int] = None,
+ attention_head_dim: int = 8,
+ mult: float = 1.0,
+ norm_type: str = "batch_norm",
+ kernel_sizes: tuple = (5,),
+ eps: float = 1e-15,
+ residual_connection: bool = False,
+ ):
+ super().__init__()
+
+ self.eps = eps
+ self.attention_head_dim = attention_head_dim
+ self.norm_type = norm_type
+ self.residual_connection = residual_connection
+
+ num_attention_heads = (
+ int(in_channels // attention_head_dim * mult)
+ if num_attention_heads is None
+ else num_attention_heads
+ )
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.to_q = ops.Linear(in_channels, inner_dim, bias=False)
+ self.to_k = ops.Linear(in_channels, inner_dim, bias=False)
+ self.to_v = ops.Linear(in_channels, inner_dim, bias=False)
+
+ self.to_qkv_multiscale = nn.ModuleList()
+ for kernel_size in kernel_sizes:
+ self.to_qkv_multiscale.append(
+ SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
+ )
+
+ self.nonlinearity = nn.ReLU()
+ self.to_out = ops.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
+ self.norm_out = get_normalization(norm_type, out_channels)
+
+ def apply_linear_attention(self, query, key, value):
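+        # Linear attention: a row of ones appended to the values accumulates the
+        # normalizer, which is divided out from the last channel below.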
+ value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
+ scores = torch.matmul(value, key.transpose(-1, -2))
+ hidden_states = torch.matmul(scores, query)
+
+ hidden_states = hidden_states.to(dtype=torch.float32)
+ hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
+ return hidden_states
+
+ def apply_quadratic_attention(self, query, key, value):
+ scores = torch.matmul(key.transpose(-1, -2), query)
+ scores = scores.to(dtype=torch.float32)
+ scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
+ hidden_states = torch.matmul(value, scores.to(value.dtype))
+ return hidden_states
+
+ def forward(self, hidden_states):
+ height, width = hidden_states.shape[-2:]
+        use_linear_attention = height * width > self.attention_head_dim
+
+ residual = hidden_states
+
+ batch_size, _, height, width = list(hidden_states.size())
+ original_dtype = hidden_states.dtype
+
+ hidden_states = hidden_states.movedim(1, -1)
+ query = self.to_q(hidden_states)
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ hidden_states = torch.cat([query, key, value], dim=3)
+ hidden_states = hidden_states.movedim(-1, 1)
+
+ multi_scale_qkv = [hidden_states]
+ for block in self.to_qkv_multiscale:
+ multi_scale_qkv.append(block(hidden_states))
+
+ hidden_states = torch.cat(multi_scale_qkv, dim=1)
+
+ if use_linear_attention:
+ # for linear attention upcast hidden_states to float32
+ hidden_states = hidden_states.to(dtype=torch.float32)
+
+ hidden_states = hidden_states.reshape(batch_size, -1, 3 * self.attention_head_dim, height * width)
+
+ query, key, value = hidden_states.chunk(3, dim=2)
+ query = self.nonlinearity(query)
+ key = self.nonlinearity(key)
+
+ if use_linear_attention:
+ hidden_states = self.apply_linear_attention(query, key, value)
+ hidden_states = hidden_states.to(dtype=original_dtype)
+ else:
+ hidden_states = self.apply_quadratic_attention(query, key, value)
+
+ hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
+ hidden_states = self.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+ if self.norm_type == "rms_norm":
+ hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+ else:
+ hidden_states = self.norm_out(hidden_states)
+
+ if self.residual_connection:
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+class EfficientViTBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ mult: float = 1.0,
+ attention_head_dim: int = 32,
+ qkv_multiscales: tuple = (5,),
+ norm_type: str = "batch_norm",
+ ) -> None:
+ super().__init__()
+
+ self.attn = SanaMultiscaleLinearAttention(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ mult=mult,
+ attention_head_dim=attention_head_dim,
+ norm_type=norm_type,
+ kernel_sizes=qkv_multiscales,
+ residual_connection=True,
+ )
+
+ self.conv_out = GLUMBConv(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ norm_type="rms_norm",
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.attn(x)
+ x = self.conv_out(x)
+ return x
+
+
+class GLUMBConv(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ expand_ratio: float = 4,
+ norm_type: str = None,
+ residual_connection: bool = True,
+ ) -> None:
+ super().__init__()
+
+ hidden_channels = int(expand_ratio * in_channels)
+ self.norm_type = norm_type
+ self.residual_connection = residual_connection
+
+ self.nonlinearity = nn.SiLU()
+ self.conv_inverted = ops.Conv2d(in_channels, hidden_channels * 2, 1, 1, 0)
+ self.conv_depth = ops.Conv2d(hidden_channels * 2, hidden_channels * 2, 3, 1, 1, groups=hidden_channels * 2)
+ self.conv_point = ops.Conv2d(hidden_channels, out_channels, 1, 1, 0, bias=False)
+
+ self.norm = None
+ if norm_type == "rms_norm":
+ self.norm = RMSNorm(out_channels, eps=1e-5, elementwise_affine=True, bias=True)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.residual_connection:
+ residual = hidden_states
+
+ hidden_states = self.conv_inverted(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv_depth(hidden_states)
+ hidden_states, gate = torch.chunk(hidden_states, 2, dim=1)
+ hidden_states = hidden_states * self.nonlinearity(gate)
+
+ hidden_states = self.conv_point(hidden_states)
+
+ if self.norm_type == "rms_norm":
+ # move channel to the last dimension so we apply RMSnorm across channel dimension
+ hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+ if self.residual_connection:
+ hidden_states = hidden_states + residual
+
+ return hidden_states
+
+
+def get_block(
+ block_type: str,
+ in_channels: int,
+ out_channels: int,
+ attention_head_dim: int,
+ norm_type: str,
+ act_fn: str,
+    qkv_multiscales: tuple = (),
+):
+ if block_type == "ResBlock":
+ block = ResBlock(in_channels, out_channels, norm_type, act_fn)
+ elif block_type == "EfficientViTBlock":
+ block = EfficientViTBlock(
+ in_channels,
+ attention_head_dim=attention_head_dim,
+ norm_type=norm_type,
+            qkv_multiscales=qkv_multiscales
+ )
+ else:
+ raise ValueError(f"Block with {block_type=} is not supported.")
+
+ return block
+
+
+class DCDownBlock2d(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, downsample: bool = False, shortcut: bool = True) -> None:
+ super().__init__()
+
+ self.downsample = downsample
+ self.factor = 2
+ self.stride = 1 if downsample else 2
+ self.group_size = in_channels * self.factor**2 // out_channels
+ self.shortcut = shortcut
+
+ out_ratio = self.factor**2
+ if downsample:
+ assert out_channels % out_ratio == 0
+ out_channels = out_channels // out_ratio
+
+ self.conv = ops.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=self.stride,
+ padding=1,
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ x = self.conv(hidden_states)
+ if self.downsample:
+ x = F.pixel_unshuffle(x, self.factor)
+
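+        # Shortcut: space-to-depth via pixel_unshuffle, then average channel groups
+        # so the shortcut width matches the conv output.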
+ if self.shortcut:
+ y = F.pixel_unshuffle(hidden_states, self.factor)
+ y = y.unflatten(1, (-1, self.group_size))
+ y = y.mean(dim=2)
+ hidden_states = x + y
+ else:
+ hidden_states = x
+
+ return hidden_states
+
+
+class DCUpBlock2d(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ interpolate: bool = False,
+ shortcut: bool = True,
+ interpolation_mode: str = "nearest",
+ ) -> None:
+ super().__init__()
+
+ self.interpolate = interpolate
+ self.interpolation_mode = interpolation_mode
+ self.shortcut = shortcut
+ self.factor = 2
+ self.repeats = out_channels * self.factor**2 // in_channels
+
+ out_ratio = self.factor**2
+ if not interpolate:
+ out_channels = out_channels * out_ratio
+
+ self.conv = ops.Conv2d(in_channels, out_channels, 3, 1, 1)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.interpolate:
+ x = F.interpolate(hidden_states, scale_factor=self.factor, mode=self.interpolation_mode)
+ x = self.conv(x)
+ else:
+ x = self.conv(hidden_states)
+ x = F.pixel_shuffle(x, self.factor)
+
+ if self.shortcut:
+ y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
+ y = F.pixel_shuffle(y, self.factor)
+ hidden_states = x + y
+ else:
+ hidden_states = x
+
+ return hidden_states
+
+
+class Encoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ latent_channels: int,
+ attention_head_dim: int = 32,
+        block_type: Union[str, tuple] = "ResBlock",
+ block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
+ layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
+ qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
+ downsample_block_type: str = "pixel_unshuffle",
+ out_shortcut: bool = True,
+ ):
+ super().__init__()
+
+ num_blocks = len(block_out_channels)
+
+ if isinstance(block_type, str):
+ block_type = (block_type,) * num_blocks
+
+        if layers_per_block[0] > 0:
+            self.conv_in = ops.Conv2d(
+                in_channels,
+                block_out_channels[0],
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
+        else:
+            self.conv_in = DCDownBlock2d(
+                in_channels=in_channels,
+                out_channels=block_out_channels[1],
+                downsample=downsample_block_type == "pixel_unshuffle",
+                shortcut=False,
+            )
+
+ down_blocks = []
+ for i, (out_channel, num_layers) in enumerate(zip(block_out_channels, layers_per_block)):
+ down_block_list = []
+
+ for _ in range(num_layers):
+ block = get_block(
+ block_type[i],
+ out_channel,
+ out_channel,
+ attention_head_dim=attention_head_dim,
+ norm_type="rms_norm",
+ act_fn="silu",
+                    qkv_multiscales=qkv_multiscales[i],
+ )
+ down_block_list.append(block)
+
+ if i < num_blocks - 1 and num_layers > 0:
+ downsample_block = DCDownBlock2d(
+ in_channels=out_channel,
+ out_channels=block_out_channels[i + 1],
+ downsample=downsample_block_type == "pixel_unshuffle",
+ shortcut=True,
+ )
+ down_block_list.append(downsample_block)
+
+ down_blocks.append(nn.Sequential(*down_block_list))
+
+ self.down_blocks = nn.ModuleList(down_blocks)
+
+ self.conv_out = ops.Conv2d(block_out_channels[-1], latent_channels, 3, 1, 1)
+
+ self.out_shortcut = out_shortcut
+ if out_shortcut:
+ self.out_shortcut_average_group_size = block_out_channels[-1] // latent_channels
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.conv_in(hidden_states)
+ for down_block in self.down_blocks:
+ hidden_states = down_block(hidden_states)
+
+ if self.out_shortcut:
+ x = hidden_states.unflatten(1, (-1, self.out_shortcut_average_group_size))
+ x = x.mean(dim=2)
+ hidden_states = self.conv_out(hidden_states) + x
+ else:
+ hidden_states = self.conv_out(hidden_states)
+
+ return hidden_states
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ latent_channels: int,
+ attention_head_dim: int = 32,
+        block_type: Union[str, tuple] = "ResBlock",
+ block_out_channels: tuple = (128, 256, 512, 512, 1024, 1024),
+ layers_per_block: tuple = (2, 2, 2, 2, 2, 2),
+ qkv_multiscales: tuple = ((), (), (), (5,), (5,), (5,)),
+        norm_type: Union[str, tuple] = "rms_norm",
+        act_fn: Union[str, tuple] = "silu",
+ upsample_block_type: str = "pixel_shuffle",
+ in_shortcut: bool = True,
+ ):
+ super().__init__()
+
+ num_blocks = len(block_out_channels)
+
+ if isinstance(block_type, str):
+ block_type = (block_type,) * num_blocks
+ if isinstance(norm_type, str):
+ norm_type = (norm_type,) * num_blocks
+ if isinstance(act_fn, str):
+ act_fn = (act_fn,) * num_blocks
+
+ self.conv_in = ops.Conv2d(latent_channels, block_out_channels[-1], 3, 1, 1)
+
+ self.in_shortcut = in_shortcut
+ if in_shortcut:
+ self.in_shortcut_repeats = block_out_channels[-1] // latent_channels
+
+ up_blocks = []
+ for i, (out_channel, num_layers) in reversed(list(enumerate(zip(block_out_channels, layers_per_block)))):
+ up_block_list = []
+
+ if i < num_blocks - 1 and num_layers > 0:
+ upsample_block = DCUpBlock2d(
+ block_out_channels[i + 1],
+ out_channel,
+ interpolate=upsample_block_type == "interpolate",
+ shortcut=True,
+ )
+ up_block_list.append(upsample_block)
+
+ for _ in range(num_layers):
+ block = get_block(
+ block_type[i],
+ out_channel,
+ out_channel,
+ attention_head_dim=attention_head_dim,
+ norm_type=norm_type[i],
+ act_fn=act_fn[i],
+                    qkv_multiscales=qkv_multiscales[i],
+ )
+ up_block_list.append(block)
+
+ up_blocks.insert(0, nn.Sequential(*up_block_list))
+
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ channels = block_out_channels[0] if layers_per_block[0] > 0 else block_out_channels[1]
+
+ self.norm_out = RMSNorm(channels, 1e-5, elementwise_affine=True, bias=True)
+ self.conv_act = nn.ReLU()
+ self.conv_out = None
+
+ if layers_per_block[0] > 0:
+ self.conv_out = ops.Conv2d(channels, in_channels, 3, 1, 1)
+ else:
+ self.conv_out = DCUpBlock2d(
+ channels, in_channels, interpolate=upsample_block_type == "interpolate", shortcut=False
+ )
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ if self.in_shortcut:
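+            # Repeat latent channels so the shortcut matches conv_in's output width.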
+ x = hidden_states.repeat_interleave(
+ self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
+ )
+ hidden_states = self.conv_in(hidden_states) + x
+ else:
+ hidden_states = self.conv_in(hidden_states)
+
+ for up_block in reversed(self.up_blocks):
+ hidden_states = up_block(hidden_states)
+
+ hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+ hidden_states = self.conv_act(hidden_states)
+ hidden_states = self.conv_out(hidden_states)
+ return hidden_states
+
+
+class AutoencoderDC(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 2,
+ latent_channels: int = 8,
+ attention_head_dim: int = 32,
+ encoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
+ decoder_block_types: Union[str, Tuple[str]] = ["ResBlock", "ResBlock", "ResBlock", "EfficientViTBlock"],
+ encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
+ decoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 1024),
+ encoder_layers_per_block: Tuple[int] = (2, 2, 3, 3),
+ decoder_layers_per_block: Tuple[int] = (3, 3, 3, 3),
+ encoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
+ decoder_qkv_multiscales: Tuple[Tuple[int, ...], ...] = ((), (), (5,), (5,)),
+ upsample_block_type: str = "interpolate",
+ downsample_block_type: str = "Conv",
+ decoder_norm_types: Union[str, Tuple[str]] = "rms_norm",
+ decoder_act_fns: Union[str, Tuple[str]] = "silu",
+ scaling_factor: float = 0.41407,
+ ) -> None:
+ super().__init__()
+
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ latent_channels=latent_channels,
+ attention_head_dim=attention_head_dim,
+ block_type=encoder_block_types,
+ block_out_channels=encoder_block_out_channels,
+ layers_per_block=encoder_layers_per_block,
+ qkv_multiscales=encoder_qkv_multiscales,
+ downsample_block_type=downsample_block_type,
+ )
+
+ self.decoder = Decoder(
+ in_channels=in_channels,
+ latent_channels=latent_channels,
+ attention_head_dim=attention_head_dim,
+ block_type=decoder_block_types,
+ block_out_channels=decoder_block_out_channels,
+ layers_per_block=decoder_layers_per_block,
+ qkv_multiscales=decoder_qkv_multiscales,
+ norm_type=decoder_norm_types,
+ act_fn=decoder_act_fns,
+ upsample_block_type=upsample_block_type,
+ )
+
+ self.scaling_factor = scaling_factor
+ self.spatial_compression_ratio = 2 ** (len(encoder_block_out_channels) - 1)
+
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
+ """Internal encoding function."""
+ encoded = self.encoder(x)
+ return encoded * self.scaling_factor
+
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
+ # Scale the latents back
+ z = z / self.scaling_factor
+ decoded = self.decoder(z)
+ return decoded
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ z = self.encode(x)
+ return self.decode(z)
+
diff --git a/comfy/ldm/ace/vae/music_dcae_pipeline.py b/comfy/ldm/ace/vae/music_dcae_pipeline.py
new file mode 100644
index 000000000..af81280eb
--- /dev/null
+++ b/comfy/ldm/ace/vae/music_dcae_pipeline.py
@@ -0,0 +1,109 @@
+# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_dcae_pipeline.py
+import torch
+from .autoencoder_dc import AutoencoderDC
+import logging
+try:
+ import torchaudio
+except Exception:
+ logging.warning("torchaudio missing, ACE model will be broken")
+
+import torchvision.transforms as transforms
+from .music_vocoder import ADaMoSHiFiGANV1
+
+
+class MusicDCAE(torch.nn.Module):
+ def __init__(self, source_sample_rate=None, dcae_config={}, vocoder_config={}):
+        super().__init__()
+
+ self.dcae = AutoencoderDC(**dcae_config)
+ self.vocoder = ADaMoSHiFiGANV1(**vocoder_config)
+
+ if source_sample_rate is None:
+ self.source_sample_rate = 48000
+ else:
+ self.source_sample_rate = source_sample_rate
+
+ # self.resampler = torchaudio.transforms.Resample(source_sample_rate, 44100)
+
+ self.transform = transforms.Compose([
+ transforms.Normalize(0.5, 0.5),
+ ])
+ self.min_mel_value = -11.0
+ self.max_mel_value = 3.0
+ self.audio_chunk_size = int(round((1024 * 512 / 44100 * 48000)))
+ self.mel_chunk_size = 1024
+        self.time_dimension_multiple = 8
+        self.latent_chunk_size = self.mel_chunk_size // self.time_dimension_multiple
+ self.scale_factor = 0.1786
+ self.shift_factor = -1.9091
+
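+    # Rough usage sketch (assumes the default configs and a (B, 2, T) float waveform):
+    #   vae = MusicDCAE()
+    #   latents = vae.encode(audios)   # waveform -> normalized DCAE latents
+    #   wavs = vae.decode(latents)     # latents -> 44.1 kHz waveforms via the vocoder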
+ def load_audio(self, audio_path):
+ audio, sr = torchaudio.load(audio_path)
+ return audio, sr
+
+ def forward_mel(self, audios):
+ mels = []
+ for i in range(len(audios)):
+ image = self.vocoder.mel_transform(audios[i])
+ mels.append(image)
+ mels = torch.stack(mels)
+ return mels
+
+ @torch.no_grad()
+ def encode(self, audios, audio_lengths=None, sr=None):
+ if audio_lengths is None:
+ audio_lengths = torch.tensor([audios.shape[2]] * audios.shape[0])
+ audio_lengths = audio_lengths.to(audios.device)
+
+ if sr is None:
+ sr = self.source_sample_rate
+
+ if sr != 44100:
+ audios = torchaudio.functional.resample(audios, sr, 44100)
+
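+        # Pad to a multiple of 8 * 512 samples so the mel frames (hop 512) map to a
+        # whole number of latent frames (8x time compression).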
+ max_audio_len = audios.shape[-1]
+ if max_audio_len % (8 * 512) != 0:
+ audios = torch.nn.functional.pad(audios, (0, 8 * 512 - max_audio_len % (8 * 512)))
+
+ mels = self.forward_mel(audios)
+ mels = (mels - self.min_mel_value) / (self.max_mel_value - self.min_mel_value)
+ mels = self.transform(mels)
+ latents = []
+ for mel in mels:
+ latent = self.dcae.encoder(mel.unsqueeze(0))
+ latents.append(latent)
+ latents = torch.cat(latents, dim=0)
+        # latent_lengths = (audio_lengths / sr * 44100 / 512 / self.time_dimension_multiple).long()
+ latents = (latents - self.shift_factor) * self.scale_factor
+ return latents
+ # return latents, latent_lengths
+
+ @torch.no_grad()
+ def decode(self, latents, audio_lengths=None, sr=None):
+ latents = latents / self.scale_factor + self.shift_factor
+
+ pred_wavs = []
+
+ for latent in latents:
+ mels = self.dcae.decoder(latent.unsqueeze(0))
+ mels = mels * 0.5 + 0.5
+ mels = mels * (self.max_mel_value - self.min_mel_value) + self.min_mel_value
+ wav = self.vocoder.decode(mels[0]).squeeze(1)
+
+ if sr is not None:
+ # resampler = torchaudio.transforms.Resample(44100, sr).to(latents.device).to(latents.dtype)
+ wav = torchaudio.functional.resample(wav, 44100, sr)
+ # wav = resampler(wav)
+ else:
+ sr = 44100
+ pred_wavs.append(wav)
+
+ if audio_lengths is not None:
+ pred_wavs = [wav[:, :length].cpu() for wav, length in zip(pred_wavs, audio_lengths)]
+ return torch.stack(pred_wavs)
+ # return sr, pred_wavs
+
+    def forward(self, audios, audio_lengths=None, sr=None):
+        # encode() returns only the latents and decode() only the stacked waveforms
+        # (see above), so forward() mirrors those signatures rather than the
+        # commented-out upstream tuple returns.
+        latents = self.encode(audios=audios, audio_lengths=audio_lengths, sr=sr)
+        pred_wavs = self.decode(latents=latents, audio_lengths=audio_lengths, sr=sr)
+        return pred_wavs, latents
diff --git a/comfy/ldm/ace/vae/music_log_mel.py b/comfy/ldm/ace/vae/music_log_mel.py
new file mode 100755
index 000000000..9c584eb7f
--- /dev/null
+++ b/comfy/ldm/ace/vae/music_log_mel.py
@@ -0,0 +1,113 @@
+# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_log_mel.py
+import torch
+import torch.nn as nn
+from torch import Tensor
+import logging
+try:
+ from torchaudio.transforms import MelScale
+except Exception:
+ logging.warning("torchaudio missing, ACE model will be broken")
+
+import comfy.model_management
+
+class LinearSpectrogram(nn.Module):
+ def __init__(
+ self,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ center=False,
+ mode="pow2_sqrt",
+ ):
+ super().__init__()
+
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.mode = mode
+
+ self.register_buffer("window", torch.hann_window(win_length))
+
+ def forward(self, y: Tensor) -> Tensor:
+ if y.ndim == 3:
+ y = y.squeeze(1)
+
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1),
+ (
+ (self.win_length - self.hop_length) // 2,
+ (self.win_length - self.hop_length + 1) // 2,
+ ),
+ mode="reflect",
+ ).squeeze(1)
+ dtype = y.dtype
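+        # Run the STFT in float32 for numerical stability; cast back afterwards.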
+ spec = torch.stft(
+ y.float(),
+ self.n_fft,
+ hop_length=self.hop_length,
+ win_length=self.win_length,
+ window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
+ center=self.center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ spec = torch.view_as_real(spec)
+
+ if self.mode == "pow2_sqrt":
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
+ spec = spec.to(dtype)
+ return spec
+
+
+class LogMelSpectrogram(nn.Module):
+ def __init__(
+ self,
+ sample_rate=44100,
+ n_fft=2048,
+ win_length=2048,
+ hop_length=512,
+ n_mels=128,
+ center=False,
+ f_min=0.0,
+ f_max=None,
+ ):
+ super().__init__()
+
+ self.sample_rate = sample_rate
+ self.n_fft = n_fft
+ self.win_length = win_length
+ self.hop_length = hop_length
+ self.center = center
+ self.n_mels = n_mels
+ self.f_min = f_min
+ self.f_max = f_max or sample_rate // 2
+
+ self.spectrogram = LinearSpectrogram(n_fft, win_length, hop_length, center)
+ self.mel_scale = MelScale(
+ self.n_mels,
+ self.sample_rate,
+ self.f_min,
+ self.f_max,
+ self.n_fft // 2 + 1,
+ "slaney",
+ "slaney",
+ )
+
+ def compress(self, x: Tensor) -> Tensor:
+ return torch.log(torch.clamp(x, min=1e-5))
+
+ def decompress(self, x: Tensor) -> Tensor:
+ return torch.exp(x)
+
+ def forward(self, x: Tensor, return_linear: bool = False) -> Tensor:
+ linear = self.spectrogram(x)
+ x = self.mel_scale(linear)
+ x = self.compress(x)
+ # print(x.shape)
+ if return_linear:
+ return x, self.compress(linear)
+
+ return x
diff --git a/comfy/ldm/ace/vae/music_vocoder.py b/comfy/ldm/ace/vae/music_vocoder.py
new file mode 100755
index 000000000..2f989fa86
--- /dev/null
+++ b/comfy/ldm/ace/vae/music_vocoder.py
@@ -0,0 +1,538 @@
+# Original from: https://github.com/ace-step/ACE-Step/blob/main/music_dcae/music_vocoder.py
+import torch
+from torch import nn
+
+from functools import partial
+from math import prod
+from typing import Callable, Tuple, List
+
+import numpy as np
+import torch.nn.functional as F
+from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_norm
+
+from .music_log_mel import LogMelSpectrogram
+
+import comfy.model_management
+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
+
+def drop_path(
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
+ 'survival rate' as the argument.
+
+ """ # noqa: E501
+
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" # noqa: E501
+
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+        return f"drop_prob={self.drop_prob:0.3f}"
+
+
+class LayerNorm(nn.Module):
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
+ shape (batch_size, height, width, channels) while channels_first corresponds to inputs
+ with shape (batch_size, channels, height, width).
+ """ # noqa: E501
+
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape,)
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(
+ x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
+ )
+ elif self.data_format == "channels_first":
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
+ return x
+
+
+class ConvNeXtBlock(nn.Module):
+ r"""ConvNeXt Block. There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+ We use (2) as we find it slightly faster in PyTorch
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+ kernel_size (int): Kernel size for depthwise conv. Default: 7.
+ dilation (int): Dilation for depthwise conv. Default: 1.
+ """ # noqa: E501
+
+ def __init__(
+ self,
+ dim: int,
+ drop_path: float = 0.0,
+ layer_scale_init_value: float = 1e-6,
+ mlp_ratio: float = 4.0,
+ kernel_size: int = 7,
+ dilation: int = 1,
+ ):
+ super().__init__()
+
+ self.dwconv = ops.Conv1d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=int(dilation * (kernel_size - 1) / 2),
+ groups=dim,
+ ) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = ops.Linear(
+ dim, int(mlp_ratio * dim)
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.pwconv2 = ops.Linear(int(mlp_ratio * dim), dim)
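+        # ConvNeXt layer scale; kept as a frozen parameter since its values are
+        # loaded from the checkpoint.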
+ self.gamma = (
+            nn.Parameter(torch.empty(dim), requires_grad=False)
+ if layer_scale_init_value > 0
+ else None
+ )
+ self.drop_path = DropPath(
+ drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x, apply_residual: bool = True):
+ input = x
+
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 1) # (N, C, L) -> (N, L, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+
+ if self.gamma is not None:
+ x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
+
+ x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
+ x = self.drop_path(x)
+
+ if apply_residual:
+ x = input + x
+
+ return x
+
+
+class ParallelConvNeXtBlock(nn.Module):
+ def __init__(self, kernel_sizes: List[int], *args, **kwargs):
+ super().__init__()
+ self.blocks = nn.ModuleList(
+ [
+ ConvNeXtBlock(kernel_size=kernel_size, *args, **kwargs)
+ for kernel_size in kernel_sizes
+ ]
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return torch.stack(
+ [block(x, apply_residual=False) for block in self.blocks] + [x],
+ dim=1,
+ ).sum(dim=1)
+
+
+class ConvNeXtEncoder(nn.Module):
+ def __init__(
+ self,
+ input_channels=3,
+ depths=[3, 3, 9, 3],
+ dims=[96, 192, 384, 768],
+ drop_path_rate=0.0,
+ layer_scale_init_value=1e-6,
+ kernel_sizes: Tuple[int] = (7,),
+ ):
+ super().__init__()
+ assert len(depths) == len(dims)
+
+ self.channel_layers = nn.ModuleList()
+ stem = nn.Sequential(
+ ops.Conv1d(
+ input_channels,
+ dims[0],
+ kernel_size=7,
+ padding=3,
+ padding_mode="replicate",
+ ),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+ )
+ self.channel_layers.append(stem)
+
+ for i in range(len(depths) - 1):
+ mid_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ ops.Conv1d(dims[i], dims[i + 1], kernel_size=1),
+ )
+ self.channel_layers.append(mid_layer)
+
+ block_fn = (
+ partial(ConvNeXtBlock, kernel_size=kernel_sizes[0])
+ if len(kernel_sizes) == 1
+ else partial(ParallelConvNeXtBlock, kernel_sizes=kernel_sizes)
+ )
+
+ self.stages = nn.ModuleList()
+ drop_path_rates = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ]
+
+ cur = 0
+ for i in range(len(depths)):
+ stage = nn.Sequential(
+ *[
+ block_fn(
+ dim=dims[i],
+ drop_path=drop_path_rates[cur + j],
+ layer_scale_init_value=layer_scale_init_value,
+ )
+ for j in range(depths[i])
+ ]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ for channel_layer, stage in zip(self.channel_layers, self.stages):
+ x = channel_layer(x)
+ x = stage(x)
+
+ return self.norm(x)
+
+
+def get_padding(kernel_size, dilation=1):
+ return (kernel_size * dilation - dilation) // 2
+
+
+class ResBlock1(torch.nn.Module):
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
+ super().__init__()
+
+ self.convs1 = nn.ModuleList(
+ [
+ torch.nn.utils.parametrizations.weight_norm(
+ ops.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[0],
+ padding=get_padding(kernel_size, dilation[0]),
+ )
+ ),
+ torch.nn.utils.parametrizations.weight_norm(
+ ops.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[1],
+ padding=get_padding(kernel_size, dilation[1]),
+ )
+ ),
+ torch.nn.utils.parametrizations.weight_norm(
+ ops.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=dilation[2],
+ padding=get_padding(kernel_size, dilation[2]),
+ )
+ ),
+ ]
+ )
+
+ self.convs2 = nn.ModuleList(
+ [
+ torch.nn.utils.parametrizations.weight_norm(
+ ops.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ torch.nn.utils.parametrizations.weight_norm(
+ ops.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ torch.nn.utils.parametrizations.weight_norm(
+ ops.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ 1,
+ dilation=1,
+ padding=get_padding(kernel_size, 1),
+ )
+ ),
+ ]
+ )
+
+ def forward(self, x):
+ for c1, c2 in zip(self.convs1, self.convs2):
+ xt = F.silu(x)
+ xt = c1(xt)
+ xt = F.silu(xt)
+ xt = c2(xt)
+ x = xt + x
+ return x
+
+ def remove_weight_norm(self):
+ for conv in self.convs1:
+ remove_weight_norm(conv)
+ for conv in self.convs2:
+ remove_weight_norm(conv)
+
+
+class HiFiGANGenerator(nn.Module):
+ def __init__(
+ self,
+ *,
+ hop_length: int = 512,
+ upsample_rates: Tuple[int] = (8, 8, 2, 2, 2),
+ upsample_kernel_sizes: Tuple[int] = (16, 16, 8, 2, 2),
+ resblock_kernel_sizes: Tuple[int] = (3, 7, 11),
+ resblock_dilation_sizes: Tuple[Tuple[int]] = (
+ (1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ num_mels: int = 128,
+ upsample_initial_channel: int = 512,
+ use_template: bool = True,
+ pre_conv_kernel_size: int = 7,
+ post_conv_kernel_size: int = 7,
+ post_activation: Callable = partial(nn.SiLU, inplace=True),
+ ):
+ super().__init__()
+
+ assert (
+ prod(upsample_rates) == hop_length
+ ), f"hop_length must be {prod(upsample_rates)}"
+
+ self.conv_pre = torch.nn.utils.parametrizations.weight_norm(
+ ops.Conv1d(
+ num_mels,
+ upsample_initial_channel,
+ pre_conv_kernel_size,
+ 1,
+ padding=get_padding(pre_conv_kernel_size),
+ )
+ )
+
+ self.num_upsamples = len(upsample_rates)
+ self.num_kernels = len(resblock_kernel_sizes)
+
+ self.noise_convs = nn.ModuleList()
+ self.use_template = use_template
+ self.ups = nn.ModuleList()
+
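+        # Build the upsampling stack; noise_convs injects an optional template /
+        # excitation signal, though the ADaMoSHiFiGANV1 wrapper below disables it.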
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
+ self.ups.append(
+ torch.nn.utils.parametrizations.weight_norm(
+ ops.ConvTranspose1d(
+ upsample_initial_channel // (2**i),
+ upsample_initial_channel // (2 ** (i + 1)),
+ k,
+ u,
+ padding=(k - u) // 2,
+ )
+ )
+ )
+
+ if not use_template:
+ continue
+
+ if i + 1 < len(upsample_rates):
+ stride_f0 = np.prod(upsample_rates[i + 1:])
+ self.noise_convs.append(
+ ops.Conv1d(
+ 1,
+ c_cur,
+ kernel_size=stride_f0 * 2,
+ stride=stride_f0,
+ padding=stride_f0 // 2,
+ )
+ )
+ else:
+ self.noise_convs.append(ops.Conv1d(1, c_cur, kernel_size=1))
+
+ self.resblocks = nn.ModuleList()
+ for i in range(len(self.ups)):
+ ch = upsample_initial_channel // (2 ** (i + 1))
+ for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
+ self.resblocks.append(ResBlock1(ch, k, d))
+
+ self.activation_post = post_activation()
+ self.conv_post = torch.nn.utils.parametrizations.weight_norm(
+ ops.Conv1d(
+ ch,
+ 1,
+ post_conv_kernel_size,
+ 1,
+ padding=get_padding(post_conv_kernel_size),
+ )
+ )
+
+ def forward(self, x, template=None):
+ x = self.conv_pre(x)
+
+ for i in range(self.num_upsamples):
+ x = F.silu(x, inplace=True)
+ x = self.ups[i](x)
+
+ if self.use_template:
+ x = x + self.noise_convs[i](template)
+
+ xs = None
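+            # Sum the parallel multi-kernel resblocks, then average below.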
+
+ for j in range(self.num_kernels):
+ if xs is None:
+ xs = self.resblocks[i * self.num_kernels + j](x)
+ else:
+ xs += self.resblocks[i * self.num_kernels + j](x)
+
+ x = xs / self.num_kernels
+
+ x = self.activation_post(x)
+ x = self.conv_post(x)
+ x = torch.tanh(x)
+
+ return x
+
+ def remove_weight_norm(self):
+ for up in self.ups:
+ remove_weight_norm(up)
+ for block in self.resblocks:
+ block.remove_weight_norm()
+ remove_weight_norm(self.conv_pre)
+ remove_weight_norm(self.conv_post)
+
+
+class ADaMoSHiFiGANV1(nn.Module):
+ def __init__(
+ self,
+ input_channels: int = 128,
+ depths: List[int] = [3, 3, 9, 3],
+ dims: List[int] = [128, 256, 384, 512],
+ drop_path_rate: float = 0.0,
+ kernel_sizes: Tuple[int] = (7,),
+ upsample_rates: Tuple[int] = (4, 4, 2, 2, 2, 2, 2),
+ upsample_kernel_sizes: Tuple[int] = (8, 8, 4, 4, 4, 4, 4),
+ resblock_kernel_sizes: Tuple[int] = (3, 7, 11, 13),
+ resblock_dilation_sizes: Tuple[Tuple[int]] = (
+ (1, 3, 5), (1, 3, 5), (1, 3, 5), (1, 3, 5)),
+ num_mels: int = 512,
+ upsample_initial_channel: int = 1024,
+ use_template: bool = False,
+ pre_conv_kernel_size: int = 13,
+ post_conv_kernel_size: int = 13,
+ sampling_rate: int = 44100,
+ n_fft: int = 2048,
+ win_length: int = 2048,
+ hop_length: int = 512,
+ f_min: int = 40,
+ f_max: int = 16000,
+ n_mels: int = 128,
+ ):
+ super().__init__()
+
+ self.backbone = ConvNeXtEncoder(
+ input_channels=input_channels,
+ depths=depths,
+ dims=dims,
+ drop_path_rate=drop_path_rate,
+ kernel_sizes=kernel_sizes,
+ )
+
+ self.head = HiFiGANGenerator(
+ hop_length=hop_length,
+ upsample_rates=upsample_rates,
+ upsample_kernel_sizes=upsample_kernel_sizes,
+ resblock_kernel_sizes=resblock_kernel_sizes,
+ resblock_dilation_sizes=resblock_dilation_sizes,
+ num_mels=num_mels,
+ upsample_initial_channel=upsample_initial_channel,
+ use_template=use_template,
+ pre_conv_kernel_size=pre_conv_kernel_size,
+ post_conv_kernel_size=post_conv_kernel_size,
+ )
+ self.sampling_rate = sampling_rate
+ self.mel_transform = LogMelSpectrogram(
+ sample_rate=sampling_rate,
+ n_fft=n_fft,
+ win_length=win_length,
+ hop_length=hop_length,
+ f_min=f_min,
+ f_max=f_max,
+ n_mels=n_mels,
+ )
+ self.eval()
+
+ @torch.no_grad()
+ def decode(self, mel):
+ y = self.backbone(mel)
+ y = self.head(y)
+ return y
+
+ @torch.no_grad()
+ def encode(self, x):
+ return self.mel_transform(x)
+
+ def forward(self, mel):
+ y = self.backbone(mel)
+ y = self.head(y)
+ return y
diff --git a/comfy/ldm/audio/autoencoder.py b/comfy/ldm/audio/autoencoder.py
index 9e7e7c876..78ed6ffa6 100644
--- a/comfy/ldm/audio/autoencoder.py
+++ b/comfy/ldm/audio/autoencoder.py
@@ -75,16 +75,10 @@ class SnakeBeta(nn.Module):
return x
def WNConv1d(*args, **kwargs):
- try:
- return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
- except:
- return torch.nn.utils.weight_norm(ops.Conv1d(*args, **kwargs)) #support pytorch 2.1 and older
+ return torch.nn.utils.parametrizations.weight_norm(ops.Conv1d(*args, **kwargs))
def WNConvTranspose1d(*args, **kwargs):
- try:
- return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
- except:
- return torch.nn.utils.weight_norm(ops.ConvTranspose1d(*args, **kwargs)) #support pytorch 2.1 and older
+ return torch.nn.utils.parametrizations.weight_norm(ops.ConvTranspose1d(*args, **kwargs))
def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
if activation == "elu":
diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py
index ca8867eaf..145e6e69a 100644
--- a/comfy/ldm/cascade/stage_a.py
+++ b/comfy/ldm/cascade/stage_a.py
@@ -19,6 +19,10 @@
import torch
from torch import nn
from torch.autograd import Function
+import comfy.ops
+
+ops = comfy.ops.disable_weight_init
+
class vector_quantize(Function):
@staticmethod
@@ -121,15 +125,15 @@ class ResBlock(nn.Module):
self.norm1 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.depthwise = nn.Sequential(
nn.ReplicationPad2d(1),
- nn.Conv2d(c, c, kernel_size=3, groups=c)
+ ops.Conv2d(c, c, kernel_size=3, groups=c)
)
# channelwise
self.norm2 = nn.LayerNorm(c, elementwise_affine=False, eps=1e-6)
self.channelwise = nn.Sequential(
- nn.Linear(c, c_hidden),
+ ops.Linear(c, c_hidden),
nn.GELU(),
- nn.Linear(c_hidden, c),
+ ops.Linear(c_hidden, c),
)
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
@@ -171,16 +175,16 @@ class StageA(nn.Module):
# Encoder blocks
self.in_block = nn.Sequential(
nn.PixelUnshuffle(2),
- nn.Conv2d(3 * 4, c_levels[0], kernel_size=1)
+ ops.Conv2d(3 * 4, c_levels[0], kernel_size=1)
)
down_blocks = []
for i in range(levels):
if i > 0:
- down_blocks.append(nn.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
+ down_blocks.append(ops.Conv2d(c_levels[i - 1], c_levels[i], kernel_size=4, stride=2, padding=1))
block = ResBlock(c_levels[i], c_levels[i] * 4)
down_blocks.append(block)
down_blocks.append(nn.Sequential(
- nn.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
+ ops.Conv2d(c_levels[-1], c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent), # then normalize them to have mean 0 and std 1
))
self.down_blocks = nn.Sequential(*down_blocks)
@@ -191,7 +195,7 @@ class StageA(nn.Module):
# Decoder blocks
up_blocks = [nn.Sequential(
- nn.Conv2d(c_latent, c_levels[-1], kernel_size=1)
+ ops.Conv2d(c_latent, c_levels[-1], kernel_size=1)
)]
for i in range(levels):
for j in range(bottleneck_blocks if i == 0 else 1):
@@ -199,11 +203,11 @@ class StageA(nn.Module):
up_blocks.append(block)
if i < levels - 1:
up_blocks.append(
- nn.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
+ ops.ConvTranspose2d(c_levels[levels - 1 - i], c_levels[levels - 2 - i], kernel_size=4, stride=2,
padding=1))
self.up_blocks = nn.Sequential(*up_blocks)
self.out_block = nn.Sequential(
- nn.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
+ ops.Conv2d(c_levels[0], 3 * 4, kernel_size=1),
nn.PixelShuffle(2),
)
@@ -232,17 +236,17 @@ class Discriminator(nn.Module):
super().__init__()
d = max(depth - 3, 3)
layers = [
- nn.utils.spectral_norm(nn.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
+ nn.utils.spectral_norm(ops.Conv2d(c_in, c_hidden // (2 ** d), kernel_size=3, stride=2, padding=1)),
nn.LeakyReLU(0.2),
]
for i in range(depth - 1):
c_in = c_hidden // (2 ** max((d - i), 0))
c_out = c_hidden // (2 ** max((d - 1 - i), 0))
- layers.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
+ layers.append(nn.utils.spectral_norm(ops.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1)))
layers.append(nn.InstanceNorm2d(c_out))
layers.append(nn.LeakyReLU(0.2))
self.encoder = nn.Sequential(*layers)
- self.shuffle = nn.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
+ self.shuffle = ops.Conv2d((c_hidden + c_cond) if c_cond > 0 else c_hidden, 1, kernel_size=1)
self.logits = nn.Sigmoid()
def forward(self, x, cond=None):
diff --git a/comfy/ldm/cascade/stage_c_coder.py b/comfy/ldm/cascade/stage_c_coder.py
index 0cb7c49fc..b467a70a8 100644
--- a/comfy/ldm/cascade/stage_c_coder.py
+++ b/comfy/ldm/cascade/stage_c_coder.py
@@ -19,6 +19,9 @@ import torch
import torchvision
from torch import nn
+import comfy.ops
+
+ops = comfy.ops.disable_weight_init
# EfficientNet
class EfficientNetEncoder(nn.Module):
@@ -26,7 +29,7 @@ class EfficientNetEncoder(nn.Module):
super().__init__()
self.backbone = torchvision.models.efficientnet_v2_s().features.eval()
self.mapper = nn.Sequential(
- nn.Conv2d(1280, c_latent, kernel_size=1, bias=False),
+ ops.Conv2d(1280, c_latent, kernel_size=1, bias=False),
nn.BatchNorm2d(c_latent, affine=False), # then normalize them to have mean 0 and std 1
)
self.mean = nn.Parameter(torch.tensor([0.485, 0.456, 0.406]))
@@ -34,7 +37,7 @@ class EfficientNetEncoder(nn.Module):
def forward(self, x):
x = x * 0.5 + 0.5
- x = (x - self.mean.view([3,1,1])) / self.std.view([3,1,1])
+ x = (x - self.mean.view([3,1,1]).to(device=x.device, dtype=x.dtype)) / self.std.view([3,1,1]).to(device=x.device, dtype=x.dtype)
o = self.mapper(self.backbone(x))
return o
@@ -44,39 +47,39 @@ class Previewer(nn.Module):
def __init__(self, c_in=16, c_hidden=512, c_out=3):
super().__init__()
self.blocks = nn.Sequential(
- nn.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
+ ops.Conv2d(c_in, c_hidden, kernel_size=1), # 16 channels to 512 channels
nn.GELU(),
nn.BatchNorm2d(c_hidden),
- nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
+ ops.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden),
- nn.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
+ ops.ConvTranspose2d(c_hidden, c_hidden // 2, kernel_size=2, stride=2), # 16 -> 32
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
- nn.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
+ ops.Conv2d(c_hidden // 2, c_hidden // 2, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 2),
- nn.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
+ ops.ConvTranspose2d(c_hidden // 2, c_hidden // 4, kernel_size=2, stride=2), # 32 -> 64
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
- nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+ ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
- nn.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
+ ops.ConvTranspose2d(c_hidden // 4, c_hidden // 4, kernel_size=2, stride=2), # 64 -> 128
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
- nn.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
+ ops.Conv2d(c_hidden // 4, c_hidden // 4, kernel_size=3, padding=1),
nn.GELU(),
nn.BatchNorm2d(c_hidden // 4),
- nn.Conv2d(c_hidden // 4, c_out, kernel_size=1),
+ ops.Conv2d(c_hidden // 4, c_out, kernel_size=1),
)
def forward(self, x):
diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py
new file mode 100644
index 000000000..35da91ee2
--- /dev/null
+++ b/comfy/ldm/chroma/layers.py
@@ -0,0 +1,183 @@
+import torch
+from torch import Tensor, nn
+
+from comfy.ldm.flux.math import attention
+from comfy.ldm.flux.layers import (
+ MLPEmbedder,
+ RMSNorm,
+ QKNorm,
+ SelfAttention,
+ ModulationOut,
+)
+
+
+
+class ChromaModulationOut(ModulationOut):
+ @classmethod
+ def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
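+        # Slice a packed modulation tensor into a (shift, scale, gate) triple
+        # starting at the given row offset.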
+ return cls(
+ shift=tensor[:, offset : offset + 1, :],
+ scale=tensor[:, offset + 1 : offset + 2, :],
+ gate=tensor[:, offset + 2 : offset + 3, :],
+ )
+
+
+
+
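+# Approximator: a small residual MLP that maps a conditioning embedding to the
+# modulation vectors for all blocks, replacing Flux's per-block modulation layers.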
+class Approximator(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers: int = 5, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
+        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
+        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
+ self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)
+
+ @property
+ def device(self):
+ # Get the device of the module (assumes all parameters are on the same device)
+ return next(self.parameters()).device
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.in_proj(x)
+
+        for layer, norm in zip(self.layers, self.norms):
+            x = x + layer(norm(x))
+
+ x = self.out_proj(x)
+
+ return x
+
+
+class DoubleStreamBlock(nn.Module):
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
+ super().__init__()
+
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+
+ self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.img_mlp = nn.Sequential(
+ operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+ nn.GELU(approximate="tanh"),
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+ )
+
+ self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)
+
+ self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.txt_mlp = nn.Sequential(
+ operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
+ nn.GELU(approximate="tanh"),
+ operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
+ )
+ self.flipped_img_txt = flipped_img_txt
+
+ def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
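+        # unlike Flux, vec already holds precomputed ChromaModulationOut tuples
+        # (sliced from the Approximator output by Chroma.get_modulations)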
+ (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
+
+ # prepare image for attention
+ img_modulated = self.img_norm1(img)
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+ img_qkv = self.img_attn.qkv(img_modulated)
+ img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
+
+ # prepare txt for attention
+ txt_modulated = self.txt_norm1(txt)
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
+ txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
+
+ # run actual attention
+ attn = attention(torch.cat((txt_q, img_q), dim=2),
+ torch.cat((txt_k, img_k), dim=2),
+ torch.cat((txt_v, img_v), dim=2),
+ pe=pe, mask=attn_mask)
+
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
+
+        # calculate the img blocks
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+
+        # calculate the txt blocks
+ txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
+ txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+
+ if txt.dtype == torch.float16:
+ txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
+
+ return img, txt
+
+
+class SingleStreamBlock(nn.Module):
+ """
+ A DiT block with parallel linear layers as described in
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qk_scale: float = None,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ self.hidden_dim = hidden_size
+ self.num_heads = num_heads
+ head_dim = hidden_size // num_heads
+ self.scale = qk_scale or head_dim**-0.5
+
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ # qkv and mlp_in
+ self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
+ # proj and mlp_out
+ self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)
+
+ self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)
+
+ self.hidden_size = hidden_size
+ self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+
+ self.mlp_act = nn.GELU(approximate="tanh")
+
+ def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
+ mod = vec
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+
+ q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k = self.norm(q, k, v)
+
+ # compute attention
+ attn = attention(q, k, v, pe=pe, mask=attn_mask)
+ # compute activation in mlp stream, cat again and run second linear layer
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
+ x += mod.gate * output
+ if x.dtype == torch.float16:
+ x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
+ return x
+
+
+class LastLayer(nn.Module):
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+ self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)
+
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
+ shift, scale = vec
+ shift = shift.squeeze(1)
+ scale = scale.squeeze(1)
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+ x = self.linear(x)
+ return x
diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py
new file mode 100644
index 000000000..c75023a31
--- /dev/null
+++ b/comfy/ldm/chroma/model.py
@@ -0,0 +1,271 @@
+# Original code can be found at: https://github.com/black-forest-labs/flux
+
+from dataclasses import dataclass
+
+import torch
+from torch import Tensor, nn
+from einops import rearrange, repeat
+import comfy.ldm.common_dit
+
+from comfy.ldm.flux.layers import (
+ EmbedND,
+ timestep_embedding,
+)
+
+from .layers import (
+ DoubleStreamBlock,
+ LastLayer,
+ SingleStreamBlock,
+ Approximator,
+ ChromaModulationOut,
+)
+
+
+@dataclass
+class ChromaParams:
+ in_channels: int
+ out_channels: int
+ context_in_dim: int
+ hidden_size: int
+ mlp_ratio: float
+ num_heads: int
+ depth: int
+ depth_single_blocks: int
+ axes_dim: list
+ theta: int
+ patch_size: int
+ qkv_bias: bool
+ in_dim: int
+ out_dim: int
+ hidden_dim: int
+ n_layers: int
+
+
+
+
+class Chroma(nn.Module):
+ """
+ Transformer model for flow matching on sequences.
+ """
+
+ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs):
+ super().__init__()
+ self.dtype = dtype
+ params = ChromaParams(**kwargs)
+ self.params = params
+ self.patch_size = params.patch_size
+ self.in_channels = params.in_channels
+ self.out_channels = params.out_channels
+ if params.hidden_size % params.num_heads != 0:
+ raise ValueError(
+ f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
+ )
+ pe_dim = params.hidden_size // params.num_heads
+ if sum(params.axes_dim) != pe_dim:
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
+ self.hidden_size = params.hidden_size
+ self.num_heads = params.num_heads
+ self.in_dim = params.in_dim
+ self.out_dim = params.out_dim
+ self.hidden_dim = params.hidden_dim
+ self.n_layers = params.n_layers
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+ self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+ self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device)
+        # a small residual MLP that predicts every modulation vector at once,
+        # replacing the per-block Modulation layers used by Flux
+ self.distilled_guidance_layer = Approximator(
+ in_dim=self.in_dim,
+ hidden_dim=self.hidden_dim,
+ out_dim=self.out_dim,
+ n_layers=self.n_layers,
+ dtype=dtype, device=device, operations=operations
+ )
+
+
+ self.double_blocks = nn.ModuleList(
+ [
+ DoubleStreamBlock(
+ self.hidden_size,
+ self.num_heads,
+ mlp_ratio=params.mlp_ratio,
+ qkv_bias=params.qkv_bias,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(params.depth)
+ ]
+ )
+
+ self.single_blocks = nn.ModuleList(
+ [
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations)
+ for _ in range(params.depth_single_blocks)
+ ]
+ )
+
+ if final_layer:
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations)
+
+ self.skip_mmdit = []
+ self.skip_dit = []
+ self.lite = False
+
+ def get_modulations(self, tensor: torch.Tensor, block_type: str, *, idx: int = 0):
+ # This function slices up the modulations tensor which has the following layout:
+ # single : num_single_blocks * 3 elements
+ # double_img : num_double_blocks * 6 elements
+ # double_txt : num_double_blocks * 6 elements
+ # final : 2 elements
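+        # e.g. with Flux-style depths (19 double blocks, 38 single blocks, an
+        # assumption about the shipped config): 38*3 + 19*6 + 19*6 + 2 = 344,
+        # which matches the hard-coded mod_index_length in forward_orig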
+ if block_type == "final":
+ return (tensor[:, -2:-1, :], tensor[:, -1:, :])
+ single_block_count = self.params.depth_single_blocks
+ double_block_count = self.params.depth
+ offset = 3 * idx
+ if block_type == "single":
+ return ChromaModulationOut.from_offset(tensor, offset)
+ # Double block modulations are 6 elements so we double 3 * idx.
+ offset *= 2
+ if block_type in {"double_img", "double_txt"}:
+ # Advance past the single block modulations.
+ offset += 3 * single_block_count
+ if block_type == "double_txt":
+ # Advance past the double block img modulations.
+ offset += 6 * double_block_count
+ return (
+ ChromaModulationOut.from_offset(tensor, offset),
+ ChromaModulationOut.from_offset(tensor, offset + 3),
+ )
+ raise ValueError("Bad block_type")
+
+
+ def forward_orig(
+ self,
+ img: Tensor,
+ img_ids: Tensor,
+ txt: Tensor,
+ txt_ids: Tensor,
+ timesteps: Tensor,
+ guidance: Tensor = None,
+ control = None,
+ transformer_options={},
+ attn_mask: Tensor = None,
+ ) -> Tensor:
+ patches_replace = transformer_options.get("patches_replace", {})
+ if img.ndim != 3 or txt.ndim != 3:
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+ # running on sequences img
+ img = self.img_in(img)
+
+ # distilled vector guidance
+ mod_index_length = 344
+ distill_timestep = timestep_embedding(timesteps.detach().clone(), 16).to(img.device, img.dtype)
+ distil_guidance = timestep_embedding(guidance.detach().clone(), 16).to(img.device, img.dtype)
+
+        # embed every modulation index
+        modulation_index = timestep_embedding(torch.arange(mod_index_length, device=img.device), 32).to(img.device, img.dtype)
+        # broadcast the modulation index so every batch element carries the full index range
+        modulation_index = modulation_index.unsqueeze(0).repeat(img.shape[0], 1, 1).to(img.device, img.dtype)
+        # broadcast the timestep and guidance embeddings along the index dimension as well
+        timestep_guidance = torch.cat([distill_timestep, distil_guidance], dim=1).unsqueeze(1).repeat(1, mod_index_length, 1).to(img.device, img.dtype)
+        # finally, concatenate everything into the Approximator input
+        input_vec = torch.cat([timestep_guidance, modulation_index], dim=-1).to(img.device, img.dtype)
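+        # input_vec is (B, 344, 16 + 16 + 32) = (B, 344, 64): shared timestep and
+        # guidance embeddings paired with a unique embedding per modulation slot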
+
+ mod_vectors = self.distilled_guidance_layer(input_vec)
+
+ txt = self.txt_in(txt)
+
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ pe = self.pe_embedder(ids)
+
+ blocks_replace = patches_replace.get("dit", {})
+ for i, block in enumerate(self.double_blocks):
+ if i not in self.skip_mmdit:
+ double_mod = (
+ self.get_modulations(mod_vectors, "double_img", idx=i),
+ self.get_modulations(mod_vectors, "double_txt", idx=i),
+ )
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"], out["txt"] = block(img=args["img"],
+ txt=args["txt"],
+ vec=args["vec"],
+ pe=args["pe"],
+ attn_mask=args.get("attn_mask"))
+ return out
+
+ out = blocks_replace[("double_block", i)]({"img": img,
+ "txt": txt,
+ "vec": double_mod,
+ "pe": pe,
+ "attn_mask": attn_mask},
+ {"original_block": block_wrap})
+ txt = out["txt"]
+ img = out["img"]
+ else:
+ img, txt = block(img=img,
+ txt=txt,
+ vec=double_mod,
+ pe=pe,
+ attn_mask=attn_mask)
+
+ if control is not None: # Controlnet
+ control_i = control.get("input")
+ if i < len(control_i):
+ add = control_i[i]
+ if add is not None:
+ img += add
+
+ img = torch.cat((txt, img), 1)
+
+ for i, block in enumerate(self.single_blocks):
+ if i not in self.skip_dit:
+ single_mod = self.get_modulations(mod_vectors, "single", idx=i)
+ if ("single_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"],
+ vec=args["vec"],
+ pe=args["pe"],
+ attn_mask=args.get("attn_mask"))
+ return out
+
+ out = blocks_replace[("single_block", i)]({"img": img,
+ "vec": single_mod,
+ "pe": pe,
+ "attn_mask": attn_mask},
+ {"original_block": block_wrap})
+ img = out["img"]
+ else:
+ img = block(img, vec=single_mod, pe=pe, attn_mask=attn_mask)
+
+ if control is not None: # Controlnet
+ control_o = control.get("output")
+ if i < len(control_o):
+ add = control_o[i]
+ if add is not None:
+ img[:, txt.shape[1] :, ...] += add
+
+ img = img[:, txt.shape[1] :, ...]
+ final_mod = self.get_modulations(mod_vectors, "final")
+ img = self.final_layer(img, vec=final_mod) # (N, T, patch_size ** 2 * out_channels)
+ return img
+
+ def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
+ bs, c, h, w = x.shape
+ patch_size = 2
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+
+ img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+
+ h_len = ((h + (patch_size // 2)) // patch_size)
+ w_len = ((w + (patch_size // 2)) // patch_size)
+ img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+ img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
+ img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+ txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+ out = self.forward_orig(img, img_ids, context, txt_ids, timestep, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
+ return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
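
To make the new model's entry points concrete, here is a minimal smoke-test sketch. The sizes below are deliberately tiny, illustrative values (not the shipped Chroma config) chosen to satisfy the internal divisibility checks, and it assumes `comfy.ops.disable_weight_init` supplies the `Linear`/`LayerNorm`/`RMSNorm` operations:

```python
import torch
import comfy.ops
from comfy.ldm.chroma.model import Chroma

model = Chroma(
    in_channels=64, out_channels=64, context_in_dim=128,
    hidden_size=256, mlp_ratio=4.0, num_heads=4,
    depth=2, depth_single_blocks=4, axes_dim=[16, 24, 24],
    theta=10000, patch_size=2, qkv_bias=True,
    in_dim=64,       # 16 (timestep) + 16 (guidance) + 32 (modulation index)
    out_dim=256,     # must equal hidden_size: one modulation row per slot
    hidden_dim=512, n_layers=2,
    dtype=torch.float32, operations=comfy.ops.disable_weight_init,
)

x = torch.randn(1, 16, 32, 32)   # latent image (B, C, H, W)
t = torch.full((1,), 0.5)        # flow-matching timestep
ctx = torch.randn(1, 77, 128)    # text conditioning tokens
g = torch.zeros(1)               # distilled guidance strength
out = model(x, t, ctx, g)        # -> (1, 16, 32, 32)
```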
diff --git a/comfy/ldm/common_dit.py b/comfy/ldm/common_dit.py
index e0f3057f7..f7f56b72c 100644
--- a/comfy/ldm/common_dit.py
+++ b/comfy/ldm/common_dit.py
@@ -1,5 +1,6 @@
import torch
-import comfy.ops
+import comfy.rmsnorm
+
def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
@@ -11,20 +12,5 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
return torch.nn.functional.pad(img, pad, mode=padding_mode)
-try:
- rms_norm_torch = torch.nn.functional.rms_norm
-except:
- rms_norm_torch = None
-def rms_norm(x, weight=None, eps=1e-6):
- if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
- if weight is None:
- return rms_norm_torch(x, (x.shape[-1],), eps=eps)
- else:
- return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
- else:
- r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
- if weight is None:
- return r
- else:
- return r * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
+rms_norm = comfy.rmsnorm.rms_norm
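
For reference, the re-exported helper keeps the semantics of the removed fallback path; a self-contained sketch of that math:

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight=None, eps: float = 1e-6) -> torch.Tensor:
    # x / sqrt(mean(x^2) + eps), optionally scaled by a learned per-channel weight
    r = x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    return r if weight is None else r * weight.to(dtype=x.dtype, device=x.device)
```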
diff --git a/comfy/ldm/cosmos/blocks.py b/comfy/ldm/cosmos/blocks.py
index 3e9c6497a..a12f892d2 100644
--- a/comfy/ldm/cosmos/blocks.py
+++ b/comfy/ldm/cosmos/blocks.py
@@ -23,7 +23,6 @@ from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
from comfy.ldm.modules.attention import optimized_attention
@@ -37,11 +36,11 @@ def apply_rotary_pos_emb(
return t_out
-def get_normalization(name: str, channels: int, weight_args={}):
+def get_normalization(name: str, channels: int, weight_args={}, operations=None):
if name == "I":
return nn.Identity()
elif name == "R":
- return RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
+ return operations.RMSNorm(channels, elementwise_affine=True, eps=1e-6, **weight_args)
else:
raise ValueError(f"Normalization {name} not found")
@@ -120,15 +119,15 @@ class Attention(nn.Module):
self.to_q = nn.Sequential(
operations.Linear(query_dim, inner_dim, bias=qkv_bias, **weight_args),
- get_normalization(qkv_norm[0], norm_dim),
+ get_normalization(qkv_norm[0], norm_dim, weight_args=weight_args, operations=operations),
)
self.to_k = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
- get_normalization(qkv_norm[1], norm_dim),
+ get_normalization(qkv_norm[1], norm_dim, weight_args=weight_args, operations=operations),
)
self.to_v = nn.Sequential(
operations.Linear(context_dim, inner_dim, bias=qkv_bias, **weight_args),
- get_normalization(qkv_norm[2], norm_dim),
+ get_normalization(qkv_norm[2], norm_dim, weight_args=weight_args, operations=operations),
)
self.to_out = nn.Sequential(
@@ -168,14 +167,18 @@ class Attention(nn.Module):
k = self.to_k[1](k)
v = self.to_v[1](v)
if self.is_selfattn and rope_emb is not None: # only apply to self-attention!
- q = apply_rotary_pos_emb(q, rope_emb)
- k = apply_rotary_pos_emb(k, rope_emb)
- return q, k, v
+ # apply_rotary_pos_emb inlined
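+            # (presumably inlined so intermediates can be freed sooner; .to(x.dtype)
+            # also keeps q/k in the activation dtype)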
+ q_shape = q.shape
+ q = q.reshape(*q.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+ q = rope_emb[..., 0] * q[..., 0] + rope_emb[..., 1] * q[..., 1]
+ q = q.movedim(-1, -2).reshape(*q_shape).to(x.dtype)
- def cal_attn(self, q, k, v, mask=None):
- out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
- out = rearrange(out, " b n s c -> s b (n c)")
- return self.to_out(out)
+ # apply_rotary_pos_emb inlined
+ k_shape = k.shape
+ k = k.reshape(*k.shape[:-1], 2, -1).movedim(-2, -1).unsqueeze(-2)
+ k = rope_emb[..., 0] * k[..., 0] + rope_emb[..., 1] * k[..., 1]
+ k = k.movedim(-1, -2).reshape(*k_shape).to(x.dtype)
+ return q, k, v
def forward(
self,
@@ -191,7 +194,10 @@ class Attention(nn.Module):
context (Optional[Tensor]): The key tensor of shape [B, Mk, K] or use x as context [self attention] if None
"""
q, k, v = self.cal_qkv(x, context, mask, rope_emb=rope_emb, **kwargs)
- return self.cal_attn(q, k, v, mask)
+ out = optimized_attention(q, k, v, self.heads, skip_reshape=True, mask=mask, skip_output_reshape=True)
+ del q, k, v
+ out = rearrange(out, " b n s c -> s b (n c)")
+ return self.to_out(out)
class FeedForward(nn.Module):
@@ -788,10 +794,7 @@ class GeneralDITTransformerBlock(nn.Module):
crossattn_mask: Optional[torch.Tensor] = None,
rope_emb_L_1_1_D: Optional[torch.Tensor] = None,
adaln_lora_B_3D: Optional[torch.Tensor] = None,
- extra_per_block_pos_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
- if extra_per_block_pos_emb is not None:
- x = x + extra_per_block_pos_emb
for block in self.blocks:
x = block(
x,
diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
index 6149e53ec..9a3ebed6a 100644
--- a/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
+++ b/comfy/ldm/cosmos/cosmos_tokenizer/layers3d.py
@@ -30,6 +30,8 @@ import torch.nn as nn
import torch.nn.functional as F
import logging
+from comfy.ldm.modules.diffusionmodules.model import vae_attention
+
from .patching import (
Patcher,
Patcher3D,
@@ -400,6 +402,8 @@ class CausalAttnBlock(nn.Module):
in_channels, in_channels, kernel_size=1, stride=1, padding=0
)
+ self.optimized_attention = vae_attention()
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = x
h_ = self.norm(h_)
@@ -413,18 +417,7 @@ class CausalAttnBlock(nn.Module):
v, batch_size = time2batch(v)
b, c, h, w = q.shape
- q = q.reshape(b, c, h * w)
- q = q.permute(0, 2, 1)
- k = k.reshape(b, c, h * w)
- w_ = torch.bmm(q, k)
- w_ = w_ * (int(c) ** (-0.5))
- w_ = F.softmax(w_, dim=2)
-
- # attend to values
- v = v.reshape(b, c, h * w)
- w_ = w_.permute(0, 2, 1)
- h_ = torch.bmm(v, w_)
- h_ = h_.reshape(b, c, h, w)
+ h_ = self.optimized_attention(q, k, v)
h_ = batch2time(h_, batch_size)
h_ = self.proj_out(h_)
@@ -871,18 +864,16 @@ class EncoderFactorized(nn.Module):
x = self.patcher3d(x)
# downsampling
- hs = [self.conv_in(x)]
+ h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
- h = self.down[i_level].block[i_block](hs[-1])
+ h = self.down[i_level].block[i_block](h)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
- hs.append(h)
if i_level != self.num_resolutions - 1:
- hs.append(self.down[i_level].downsample(hs[-1]))
+ h = self.down[i_level].downsample(h)
# middle
- h = hs[-1]
h = self.mid.block_1(h)
h = self.mid.attn_1(h)
h = self.mid.block_2(h)
diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/patching.py b/comfy/ldm/cosmos/cosmos_tokenizer/patching.py
index 793f0da8a..87a53a1d9 100644
--- a/comfy/ldm/cosmos/cosmos_tokenizer/patching.py
+++ b/comfy/ldm/cosmos/cosmos_tokenizer/patching.py
@@ -281,54 +281,76 @@ class UnPatcher3D(UnPatcher):
hh = hh.to(dtype=dtype)
xlll, xllh, xlhl, xlhh, xhll, xhlh, xhhl, xhhh = torch.chunk(x, 8, dim=1)
+ del x
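+        # each sub-band is freed as soon as it is consumed, cutting peak memory during decode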
# Height height transposed convolutions.
xll = F.conv_transpose3d(
xlll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
+ del xlll
+
xll += F.conv_transpose3d(
xllh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
+ del xllh
xlh = F.conv_transpose3d(
xlhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
+ del xlhl
+
xlh += F.conv_transpose3d(
xlhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
+ del xlhh
xhl = F.conv_transpose3d(
xhll, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
+ del xhll
+
xhl += F.conv_transpose3d(
xhlh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
+ del xhlh
xhh = F.conv_transpose3d(
xhhl, hl.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
+ del xhhl
+
xhh += F.conv_transpose3d(
xhhh, hh.unsqueeze(2).unsqueeze(3), groups=g, stride=(1, 1, 2)
)
+ del xhhh
# Handles width transposed convolutions.
xl = F.conv_transpose3d(
xll, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
+ del xll
+
xl += F.conv_transpose3d(
xlh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
+ del xlh
+
xh = F.conv_transpose3d(
xhl, hl.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
+ del xhl
+
xh += F.conv_transpose3d(
xhh, hh.unsqueeze(2).unsqueeze(4), groups=g, stride=(1, 2, 1)
)
+ del xhh
# Handles time axis transposed convolutions.
x = F.conv_transpose3d(
xl, hl.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
+ del xl
+
x += F.conv_transpose3d(
xh, hh.unsqueeze(3).unsqueeze(4), groups=g, stride=(2, 1, 1)
)
diff --git a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py
index 64dd5e288..3af8d0d05 100644
--- a/comfy/ldm/cosmos/cosmos_tokenizer/utils.py
+++ b/comfy/ldm/cosmos/cosmos_tokenizer/utils.py
@@ -17,7 +17,7 @@
from typing import Any
import torch
-from einops import pack, rearrange, unpack
+from einops import rearrange
import comfy.ops
@@ -98,14 +98,6 @@ def default(*args):
return None
-def pack_one(t, pattern):
- return pack([t], pattern)
-
-
-def unpack_one(t, ps, pattern):
- return unpack(t, ps, pattern)[0]
-
-
def round_ste(z: torch.Tensor) -> torch.Tensor:
"""Round with straight through gradients."""
zhat = z.round()
diff --git a/comfy/ldm/cosmos/model.py b/comfy/ldm/cosmos/model.py
index 05dd38469..4836e0b69 100644
--- a/comfy/ldm/cosmos/model.py
+++ b/comfy/ldm/cosmos/model.py
@@ -27,8 +27,6 @@ from torchvision import transforms
from enum import Enum
import logging
-from comfy.ldm.modules.diffusionmodules.mmdit import RMSNorm
-
from .blocks import (
FinalLayer,
GeneralDITTransformerBlock,
@@ -168,7 +166,7 @@ class GeneralDIT(nn.Module):
operations=operations,
)
- self.build_pos_embed(device=device)
+ self.build_pos_embed(device=device, dtype=dtype)
self.block_x_format = block_x_format
self.use_adaln_lora = use_adaln_lora
self.adaln_lora_dim = adaln_lora_dim
@@ -195,7 +193,7 @@ class GeneralDIT(nn.Module):
if self.affline_emb_norm:
logging.debug("Building affine embedding normalization layer")
- self.affline_norm = RMSNorm(model_channels, elementwise_affine=True, eps=1e-6)
+ self.affline_norm = operations.RMSNorm(model_channels, elementwise_affine=True, eps=1e-6, device=device, dtype=dtype)
else:
self.affline_norm = nn.Identity()
@@ -210,7 +208,7 @@ class GeneralDIT(nn.Module):
operations=operations,
)
- def build_pos_embed(self, device=None):
+ def build_pos_embed(self, device=None, dtype=None):
if self.pos_emb_cls == "rope3d":
cls_type = VideoRopePosition3DEmb
else:
@@ -242,6 +240,7 @@ class GeneralDIT(nn.Module):
kwargs["w_extrapolation_ratio"] = self.extra_w_extrapolation_ratio
kwargs["t_extrapolation_ratio"] = self.extra_t_extrapolation_ratio
kwargs["device"] = device
+ kwargs["dtype"] = dtype
self.extra_pos_embedder = LearnablePosEmbAxis(
**kwargs,
)
@@ -292,7 +291,7 @@ class GeneralDIT(nn.Module):
x_B_T_H_W_D = self.x_embedder(x_B_C_T_H_W)
if self.extra_per_block_abs_pos_emb:
- extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device)
+ extra_pos_emb = self.extra_pos_embedder(x_B_T_H_W_D, fps=fps, device=x_B_C_T_H_W.device, dtype=x_B_C_T_H_W.dtype)
else:
extra_pos_emb = None
@@ -476,6 +475,8 @@ class GeneralDIT(nn.Module):
inputs["original_shape"],
)
extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D = inputs["extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D"].to(x.dtype)
+ del inputs
+
if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
assert (
x.shape == extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D.shape
@@ -486,6 +487,8 @@ class GeneralDIT(nn.Module):
self.blocks["block0"].x_format == block.x_format
), f"First block has x_format {self.blocks[0].x_format}, got {block.x_format}"
+ if extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D is not None:
+ x += extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D
x = block(
x,
affline_emb_B_D,
@@ -493,7 +496,6 @@ class GeneralDIT(nn.Module):
crossattn_mask,
rope_emb_L_1_1_D=rope_emb_L_1_1_D,
adaln_lora_B_3D=adaln_lora_B_3D,
- extra_per_block_pos_emb=extra_pos_emb_B_T_H_W_D_or_T_H_W_B_D,
)
x_B_T_H_W_D = rearrange(x, "T H W B D -> B T H W D")
diff --git a/comfy/ldm/cosmos/position_embedding.py b/comfy/ldm/cosmos/position_embedding.py
index dda752cb8..4d6a58dba 100644
--- a/comfy/ldm/cosmos/position_embedding.py
+++ b/comfy/ldm/cosmos/position_embedding.py
@@ -41,12 +41,12 @@ def normalize(x: torch.Tensor, dim: Optional[List[int]] = None, eps: float = 0)
class VideoPositionEmb(nn.Module):
- def forward(self, x_B_T_H_W_C: torch.Tensor, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
+    def forward(self, x_B_T_H_W_C: torch.Tensor, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
"""
It delegates the embedding generation to generate_embeddings function.
"""
B_T_H_W_C = x_B_T_H_W_C.shape
- embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device)
+ embeddings = self.generate_embeddings(B_T_H_W_C, fps=fps, device=device, dtype=dtype)
return embeddings
@@ -104,6 +104,7 @@ class VideoRopePosition3DEmb(VideoPositionEmb):
w_ntk_factor: Optional[float] = None,
t_ntk_factor: Optional[float] = None,
device=None,
+ dtype=None,
):
"""
Generate embeddings for the given input size.
@@ -173,6 +174,7 @@ class LearnablePosEmbAxis(VideoPositionEmb):
len_w: int,
len_t: int,
device=None,
+ dtype=None,
**kwargs,
):
"""
@@ -184,17 +186,16 @@ class LearnablePosEmbAxis(VideoPositionEmb):
self.interpolation = interpolation
assert self.interpolation in ["crop"], f"Unknown interpolation method {self.interpolation}"
- self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device))
- self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device))
- self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device))
+ self.pos_emb_h = nn.Parameter(torch.empty(len_h, model_channels, device=device, dtype=dtype))
+ self.pos_emb_w = nn.Parameter(torch.empty(len_w, model_channels, device=device, dtype=dtype))
+ self.pos_emb_t = nn.Parameter(torch.empty(len_t, model_channels, device=device, dtype=dtype))
-
- def generate_embeddings(self, B_T_H_W_C: torch.Size, fps=Optional[torch.Tensor], device=None) -> torch.Tensor:
+    def generate_embeddings(self, B_T_H_W_C: torch.Size, fps: Optional[torch.Tensor] = None, device=None, dtype=None) -> torch.Tensor:
B, T, H, W, _ = B_T_H_W_C
if self.interpolation == "crop":
- emb_h_H = self.pos_emb_h[:H].to(device=device)
- emb_w_W = self.pos_emb_w[:W].to(device=device)
- emb_t_T = self.pos_emb_t[:T].to(device=device)
+ emb_h_H = self.pos_emb_h[:H].to(device=device, dtype=dtype)
+ emb_w_W = self.pos_emb_w[:W].to(device=device, dtype=dtype)
+ emb_t_T = self.pos_emb_t[:T].to(device=device, dtype=dtype)
emb = (
repeat(emb_t_T, "t d-> b t h w d", b=B, h=H, w=W)
+ repeat(emb_h_H, "h d-> b t h w d", b=B, t=T, w=W)
diff --git a/comfy/ldm/cosmos/vae.py b/comfy/ldm/cosmos/vae.py
index 94fcc54ce..d64f292de 100644
--- a/comfy/ldm/cosmos/vae.py
+++ b/comfy/ldm/cosmos/vae.py
@@ -18,6 +18,7 @@ import logging
import torch
from torch import nn
from enum import Enum
+import math
from .cosmos_tokenizer.layers3d import (
EncoderFactorized,
@@ -89,8 +90,8 @@ class CausalContinuousVideoTokenizer(nn.Module):
self.distribution = IdentityDistribution() # ContinuousFormulation[formulation_name].value()
num_parameters = sum(param.numel() for param in self.parameters())
- logging.info(f"model={self.name}, num_parameters={num_parameters:,}")
- logging.info(
+ logging.debug(f"model={self.name}, num_parameters={num_parameters:,}")
+ logging.debug(
f"z_channels={z_channels}, latent_channels={self.latent_channels}."
)
@@ -105,17 +106,23 @@ class CausalContinuousVideoTokenizer(nn.Module):
z, posteriors = self.distribution(moments)
latent_ch = z.shape[1]
latent_t = z.shape[2]
- dtype = z.dtype
- mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
- std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=dtype, device=z.device)
+ in_dtype = z.dtype
+ mean = self.latent_mean.view(latent_ch, -1)
+ std = self.latent_std.view(latent_ch, -1)
+
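+        # tile the stored per-frame statistics so they cover latent_t frames, then crop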
+ mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
+ std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
return ((z - mean) / std) * self.sigma_data
def decode(self, z):
in_dtype = z.dtype
latent_ch = z.shape[1]
latent_t = z.shape[2]
- mean = self.latent_mean.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
- std = self.latent_std.view(latent_ch, -1)[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
+ mean = self.latent_mean.view(latent_ch, -1)
+ std = self.latent_std.view(latent_ch, -1)
+
+ mean = mean.repeat(1, math.ceil(latent_t / mean.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
+ std = std.repeat(1, math.ceil(latent_t / std.shape[-1]))[:, : latent_t].reshape([1, latent_ch, -1, 1, 1]).to(dtype=in_dtype, device=z.device)
z = z / self.sigma_data
z = z * std + mean
diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py
index 8e055151f..76af967e6 100644
--- a/comfy/ldm/flux/layers.py
+++ b/comfy/ldm/flux/layers.py
@@ -105,7 +105,9 @@ class Modulation(nn.Module):
self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device)
def forward(self, vec: Tensor) -> tuple:
- out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
+ if vec.ndim == 2:
+ vec = vec[:, None, :]
+ out = self.lin(nn.functional.silu(vec)).chunk(self.multiplier, dim=-1)
return (
ModulationOut(*out[:3]),
@@ -113,6 +115,20 @@ class Modulation(nn.Module):
)
+def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
+ if modulation_dims is None:
+ if m_add is not None:
+ return tensor * m_mult + m_add
+ else:
+ return tensor * m_mult
+ else:
+ for d in modulation_dims:
+ tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
+ if m_add is not None:
+ tensor[:, d[0]:d[1]] += m_add[:, d[2]]
+ return tensor
+
+
class DoubleStreamBlock(nn.Module):
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
super().__init__()
@@ -143,20 +159,20 @@ class DoubleStreamBlock(nn.Module):
)
self.flipped_img_txt = flipped_img_txt
- def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None):
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
- img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
+ img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
- txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
+ txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -179,12 +195,12 @@ class DoubleStreamBlock(nn.Module):
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
         # calculate the img blocks
- img = img + img_mod1.gate * self.img_attn.proj(img_attn)
- img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
+ img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+ img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
         # calculate the txt blocks
- txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
- txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
+ txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+ txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)
if txt.dtype == torch.float16:
txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
@@ -228,10 +244,9 @@ class SingleStreamBlock(nn.Module):
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations)
- def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None) -> Tensor:
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None) -> Tensor:
mod, _ = self.modulation(vec)
- x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
- qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
+ qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k = self.norm(q, k, v)
@@ -240,7 +255,7 @@ class SingleStreamBlock(nn.Module):
attn = attention(q, k, v, pe=pe, mask=attn_mask)
# compute activation in mlp stream, cat again and run second linear layer
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
- x += mod.gate * output
+ x += apply_mod(output, mod.gate, None, modulation_dims)
if x.dtype == torch.float16:
x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
return x
@@ -253,8 +268,11 @@ class LastLayer(nn.Module):
self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device))
- def forward(self, x: Tensor, vec: Tensor) -> Tensor:
- shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
- x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
+ def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor:
+ if vec.ndim == 2:
+ vec = vec[:, None, :]
+
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=-1)
+ x = apply_mod(self.norm_final(x), (1 + scale), shift, modulation_dims)
x = self.linear(x)
return x
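
The new `modulation_dims` argument lets one packed sequence apply different modulation rows to different token ranges. A toy sketch of the indexing contract (the ranges and sizes below are made up):

```python
import torch
from comfy.ldm.flux.layers import apply_mod

x = torch.ones(1, 6, 4)            # (batch, tokens, dim)
m_mult = torch.randn(1, 2, 4)      # two modulation rows
# tokens [0, 3) are scaled by row 0, tokens [3, 6) by row 1
dims = [(0, 3, 0), (3, 6, 1)]
y = apply_mod(x, m_mult, m_add=None, modulation_dims=dims)
assert y.shape == (1, 6, 4)
```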
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index b6549585a..3e0978176 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -5,8 +5,16 @@ from torch import Tensor
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
- q, k = apply_rope(q, k, pe)
+ q_shape = q.shape
+ k_shape = k.shape
+
+ if pe is not None:
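+        # inline rope application: rotate (even, odd) channel pairs by pe, then
+        # cast back to v's dtype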
+ q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
+ k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
+ q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
+ k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
@@ -15,7 +23,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
assert dim % 2 == 0
- if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
+ if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
device = torch.device("cpu")
else:
device = pos.device
@@ -29,8 +37,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
- xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
- xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
+ xq_ = xq.to(dtype=freqs_cis.dtype).reshape(*xq.shape[:-1], -1, 1, 2)
+ xk_ = xk.to(dtype=freqs_cis.dtype).reshape(*xk.shape[:-1], -1, 1, 2)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
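
For orientation, a shape-level sketch of the rotary helpers (the batch, head and token counts here are arbitrary):

```python
import torch
from comfy.ldm.flux.math import rope, apply_rope

pos = torch.arange(8, dtype=torch.float32)[None]  # (1, 8) token positions
pe = rope(pos, dim=16, theta=10000)               # (1, 8, 8, 2, 2) rotations
q = torch.randn(1, 4, 8, 16)                      # (B, heads, tokens, head_dim)
k = torch.randn(1, 4, 8, 16)
q_rot, k_rot = apply_rope(q, k, pe)               # same shapes, positions baked in
```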
diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py
index dead87de8..ef4ba4106 100644
--- a/comfy/ldm/flux/model.py
+++ b/comfy/ldm/flux/model.py
@@ -109,15 +109,17 @@ class Flux(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256).to(img.dtype))
if self.params.guidance_embed:
- if guidance is None:
- raise ValueError("Didn't get guidance strength for guidance distilled model.")
- vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
+ if guidance is not None:
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
vec = vec + self.vector_in(y[:,:self.params.vec_in_dim])
txt = self.txt_in(txt)
- ids = torch.cat((txt_ids, img_ids), dim=1)
- pe = self.pe_embedder(ids)
+ if img_ids is not None:
+ ids = torch.cat((txt_ids, img_ids), dim=1)
+ pe = self.pe_embedder(ids)
+ else:
+ pe = None
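+            # pe=None now makes attention() skip the rotary embedding entirely
+            # (see the corresponding change in comfy/ldm/flux/math.py)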
blocks_replace = patches_replace.get("dit", {})
for i, block in enumerate(self.double_blocks):
@@ -186,7 +188,7 @@ class Flux(nn.Module):
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
return img
- def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
+ def forward(self, x, timestep, context, y, guidance=None, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = self.patch_size
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
diff --git a/comfy/ldm/genmo/joint_model/asymm_models_joint.py b/comfy/ldm/genmo/joint_model/asymm_models_joint.py
index 2c46c24bf..366a8b713 100644
--- a/comfy/ldm/genmo/joint_model/asymm_models_joint.py
+++ b/comfy/ldm/genmo/joint_model/asymm_models_joint.py
@@ -13,7 +13,6 @@ from comfy.ldm.modules.attention import optimized_attention
from .layers import (
FeedForward,
PatchEmbed,
- RMSNorm,
TimestepEmbedder,
)
@@ -90,10 +89,10 @@ class AsymmetricAttention(nn.Module):
# Query and key normalization for stability.
assert qk_norm
- self.q_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
- self.k_norm_x = RMSNorm(self.head_dim, device=device, dtype=dtype)
- self.q_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
- self.k_norm_y = RMSNorm(self.head_dim, device=device, dtype=dtype)
+ self.q_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
+ self.k_norm_x = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
+ self.q_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
+ self.k_norm_y = operations.RMSNorm(self.head_dim, eps=1e-5, device=device, dtype=dtype)
# Output layers. y features go back down from dim_x -> dim_y.
self.proj_x = operations.Linear(dim_x, dim_x, bias=out_bias, device=device, dtype=dtype)
diff --git a/comfy/ldm/genmo/joint_model/layers.py b/comfy/ldm/genmo/joint_model/layers.py
index 51d979559..e310bd717 100644
--- a/comfy/ldm/genmo/joint_model/layers.py
+++ b/comfy/ldm/genmo/joint_model/layers.py
@@ -151,14 +151,3 @@ class PatchEmbed(nn.Module):
x = self.norm(x)
return x
-
-
-class RMSNorm(torch.nn.Module):
- def __init__(self, hidden_size, eps=1e-5, device=None, dtype=None):
- super().__init__()
- self.eps = eps
- self.weight = torch.nn.Parameter(torch.empty(hidden_size, device=device, dtype=dtype))
- self.register_parameter("bias", None)
-
- def forward(self, x):
- return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
diff --git a/comfy/ldm/hidream/model.py b/comfy/ldm/hidream/model.py
new file mode 100644
index 000000000..0305747bf
--- /dev/null
+++ b/comfy/ldm/hidream/model.py
@@ -0,0 +1,802 @@
+from typing import Optional, Tuple, List
+
+import torch
+import torch.nn as nn
+import einops
+from einops import repeat
+
+from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
+import torch.nn.functional as F
+
+from comfy.ldm.flux.math import apply_rope, rope
+from comfy.ldm.flux.layers import LastLayer
+
+from comfy.ldm.modules.attention import optimized_attention
+import comfy.model_management
+import comfy.ldm.common_dit
+
+
+# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
+class EmbedND(nn.Module):
+ def __init__(self, theta: int, axes_dim: List[int]):
+ super().__init__()
+ self.theta = theta
+ self.axes_dim = axes_dim
+
+ def forward(self, ids: torch.Tensor) -> torch.Tensor:
+ n_axes = ids.shape[-1]
+ emb = torch.cat(
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
+ dim=-3,
+ )
+ return emb.unsqueeze(2)
+
+
+class PatchEmbed(nn.Module):
+ def __init__(
+ self,
+ patch_size=2,
+ in_channels=4,
+ out_channels=1024,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+ self.patch_size = patch_size
+ self.out_channels = out_channels
+ self.proj = operations.Linear(in_channels * patch_size * patch_size, out_channels, bias=True, dtype=dtype, device=device)
+
+ def forward(self, latent):
+ latent = self.proj(latent)
+ return latent
+
+
+class PooledEmbed(nn.Module):
+ def __init__(self, text_emb_dim, hidden_size, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
+
+ def forward(self, pooled_embed):
+ return self.pooled_embedder(pooled_embed)
+
+
+class TimestepEmbed(nn.Module):
+ def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
+ self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size, dtype=dtype, device=device, operations=operations)
+
+ def forward(self, timesteps, wdtype):
+ t_emb = self.time_proj(timesteps).to(dtype=wdtype)
+ t_emb = self.timestep_embedder(t_emb)
+ return t_emb
+
+
+def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
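+    # flattens (B, L, heads, head_dim) -> (B, L, heads*head_dim) for optimized_attention,
+    # which re-splits the heads internally from its head-count argument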
+ return optimized_attention(query.view(query.shape[0], -1, query.shape[-1] * query.shape[-2]), key.view(key.shape[0], -1, key.shape[-1] * key.shape[-2]), value.view(value.shape[0], -1, value.shape[-1] * value.shape[-2]), query.shape[2])
+
+
+class HiDreamAttnProcessor_flashattn:
+ """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+ def __call__(
+ self,
+ attn,
+ image_tokens: torch.FloatTensor,
+ image_tokens_masks: Optional[torch.FloatTensor] = None,
+ text_tokens: Optional[torch.FloatTensor] = None,
+ rope: torch.FloatTensor = None,
+ *args,
+ **kwargs,
+ ) -> torch.FloatTensor:
+ dtype = image_tokens.dtype
+ batch_size = image_tokens.shape[0]
+
+ query_i = attn.q_rms_norm(attn.to_q(image_tokens)).to(dtype=dtype)
+ key_i = attn.k_rms_norm(attn.to_k(image_tokens)).to(dtype=dtype)
+ value_i = attn.to_v(image_tokens)
+
+ inner_dim = key_i.shape[-1]
+ head_dim = inner_dim // attn.heads
+
+ query_i = query_i.view(batch_size, -1, attn.heads, head_dim)
+ key_i = key_i.view(batch_size, -1, attn.heads, head_dim)
+ value_i = value_i.view(batch_size, -1, attn.heads, head_dim)
+ if image_tokens_masks is not None:
+ key_i = key_i * image_tokens_masks.view(batch_size, -1, 1, 1)
+
+ if not attn.single:
+ query_t = attn.q_rms_norm_t(attn.to_q_t(text_tokens)).to(dtype=dtype)
+ key_t = attn.k_rms_norm_t(attn.to_k_t(text_tokens)).to(dtype=dtype)
+ value_t = attn.to_v_t(text_tokens)
+
+ query_t = query_t.view(batch_size, -1, attn.heads, head_dim)
+ key_t = key_t.view(batch_size, -1, attn.heads, head_dim)
+ value_t = value_t.view(batch_size, -1, attn.heads, head_dim)
+
+ num_image_tokens = query_i.shape[1]
+ num_text_tokens = query_t.shape[1]
+ query = torch.cat([query_i, query_t], dim=1)
+ key = torch.cat([key_i, key_t], dim=1)
+ value = torch.cat([value_i, value_t], dim=1)
+ else:
+ query = query_i
+ key = key_i
+ value = value_i
+
+ if query.shape[-1] == rope.shape[-3] * 2:
+ query, key = apply_rope(query, key, rope)
+ else:
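+            # rope covers only part of head_dim here: rotate the first half and
+            # pass the second half through unchanged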
+ query_1, query_2 = query.chunk(2, dim=-1)
+ key_1, key_2 = key.chunk(2, dim=-1)
+ query_1, key_1 = apply_rope(query_1, key_1, rope)
+ query = torch.cat([query_1, query_2], dim=-1)
+ key = torch.cat([key_1, key_2], dim=-1)
+
+ hidden_states = attention(query, key, value)
+
+ if not attn.single:
+ hidden_states_i, hidden_states_t = torch.split(hidden_states, [num_image_tokens, num_text_tokens], dim=1)
+ hidden_states_i = attn.to_out(hidden_states_i)
+ hidden_states_t = attn.to_out_t(hidden_states_t)
+ return hidden_states_i, hidden_states_t
+ else:
+ hidden_states = attn.to_out(hidden_states)
+ return hidden_states
+
+class HiDreamAttention(nn.Module):
+ def __init__(
+ self,
+ query_dim: int,
+ heads: int = 8,
+ dim_head: int = 64,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ scale_qk: bool = True,
+ eps: float = 1e-5,
+ processor = None,
+ out_dim: int = None,
+ single: bool = False,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
+ self.query_dim = query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+ self.out_dim = out_dim if out_dim is not None else query_dim
+
+ self.scale_qk = scale_qk
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
+
+ self.heads = out_dim // dim_head if out_dim is not None else heads
+ self.sliceable_head_dim = heads
+ self.single = single
+
+ linear_cls = operations.Linear
+ self.linear_cls = linear_cls
+ self.to_q = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
+ self.to_k = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
+ self.to_v = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
+ self.to_out = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
+ self.q_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
+ self.k_rms_norm = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
+
+ if not single:
+ self.to_q_t = linear_cls(query_dim, self.inner_dim, dtype=dtype, device=device)
+ self.to_k_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
+ self.to_v_t = linear_cls(self.inner_dim, self.inner_dim, dtype=dtype, device=device)
+ self.to_out_t = linear_cls(self.inner_dim, self.out_dim, dtype=dtype, device=device)
+ self.q_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
+ self.k_rms_norm_t = operations.RMSNorm(self.inner_dim, eps, dtype=dtype, device=device)
+
+ self.processor = processor
+
+ def forward(
+ self,
+ norm_image_tokens: torch.FloatTensor,
+ image_tokens_masks: torch.FloatTensor = None,
+ norm_text_tokens: torch.FloatTensor = None,
+ rope: torch.FloatTensor = None,
+ ) -> torch.Tensor:
+ return self.processor(
+ self,
+ image_tokens = norm_image_tokens,
+ image_tokens_masks = image_tokens_masks,
+ text_tokens = norm_text_tokens,
+ rope = rope,
+ )
+
+
+class FeedForwardSwiGLU(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: Optional[float] = None,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+ hidden_dim = int(2 * hidden_dim / 3)
+ # custom dim factor multiplier
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * (
+ (hidden_dim + multiple_of - 1) // multiple_of
+ )
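+        # e.g. hidden_dim=4*dim=4096 -> 2/3 -> 2730 -> rounded up to 2816 (11 * 256)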
+
+ self.w1 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
+ self.w2 = operations.Linear(hidden_dim, dim, bias=False, dtype=dtype, device=device)
+ self.w3 = operations.Linear(dim, hidden_dim, bias=False, dtype=dtype, device=device)
+
+ def forward(self, x):
+ return self.w2(torch.nn.functional.silu(self.w1(x)) * self.w3(x))
+
+
+# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
+class MoEGate(nn.Module):
+ def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.top_k = num_activated_experts
+ self.n_routed_experts = num_routed_experts
+
+ self.scoring_func = 'softmax'
+ self.alpha = aux_loss_alpha
+ self.seq_aux = False
+
+ # topk selection algorithm
+ self.norm_topk_prob = False
+ self.gating_dim = embed_dim
+ self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim), dtype=dtype, device=device))
+ self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        # weights come from the checkpoint, so no random init is performed here
+        pass
+
+ def forward(self, hidden_states):
+ bsz, seq_len, h = hidden_states.shape
+
+ ### compute gating score
+ hidden_states = hidden_states.view(-1, h)
+ logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
+ if self.scoring_func == 'softmax':
+ scores = logits.softmax(dim=-1)
+ else:
+            raise NotImplementedError(f'unsupported scoring function for MoE gating: {self.scoring_func}')
+
+ ### select top-k experts
+ topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)
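+        # scores is (batch*seq, n_routed_experts); topk_idx / topk_weight are (batch*seq, top_k)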
+
+ ### norm gate to sum 1
+ if self.top_k > 1 and self.norm_topk_prob:
+ denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
+ topk_weight = topk_weight / denominator
+
+ aux_loss = None
+ return topk_idx, topk_weight, aux_loss
+
+
+# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
+class MOEFeedForwardSwiGLU(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ num_routed_experts: int,
+ num_activated_experts: int,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+ self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2, dtype=dtype, device=device, operations=operations)
+ self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim, dtype=dtype, device=device, operations=operations) for i in range(num_routed_experts)])
+ self.gate = MoEGate(
+ embed_dim = dim,
+ num_routed_experts = num_routed_experts,
+ num_activated_experts = num_activated_experts,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.num_activated_experts = num_activated_experts
+
+ def forward(self, x):
+ wtype = x.dtype
+ identity = x
+ orig_shape = x.shape
+ topk_idx, topk_weight, aux_loss = self.gate(x)
+ x = x.view(-1, x.shape[-1])
+ flat_topk_idx = topk_idx.view(-1)
+ if True: # self.training: # TODO: check which branch performs faster
+ x = x.repeat_interleave(self.num_activated_experts, dim=0)
+ y = torch.empty_like(x, dtype=wtype)
+ for i, expert in enumerate(self.experts):
+ y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
+ y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+ y = y.view(*orig_shape).to(dtype=wtype)
+ #y = AddAuxiliaryLoss.apply(y, aux_loss)
+ else:
+ y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
+ y = y + self.shared_experts(identity)
+ return y
+
+ @torch.no_grad()
+ def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
+ expert_cache = torch.zeros_like(x)
+ idxs = flat_expert_indices.argsort()
+ tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
+ token_idxs = idxs // self.num_activated_experts
+ for i, end_idx in enumerate(tokens_per_expert):
+ start_idx = 0 if i == 0 else tokens_per_expert[i-1]
+ if start_idx == end_idx:
+ continue
+ expert = self.experts[i]
+ exp_token_idx = token_idxs[start_idx:end_idx]
+ expert_tokens = x[exp_token_idx]
+ expert_out = expert(expert_tokens)
+ expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])
+
+ # for fp16 and other dtype
+ expert_cache = expert_cache.to(expert_out.dtype)
+ expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
+ return expert_cache
+
+
+class TextProjection(nn.Module):
+ def __init__(self, in_features, hidden_size, dtype=None, device=None, operations=None):
+ super().__init__()
+ self.linear = operations.Linear(in_features=in_features, out_features=hidden_size, bias=False, dtype=dtype, device=device)
+
+ def forward(self, caption):
+ hidden_states = self.linear(caption)
+ return hidden_states
+
+
+class BlockType:
+ TransformerBlock = 1
+ SingleTransformerBlock = 2
+
+
+class HiDreamImageSingleTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_routed_experts: int = 4,
+ num_activated_experts: int = 2,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(dim, 6 * dim, bias=True, dtype=dtype, device=device)
+ )
+
+ # 1. Attention
+ self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
+ self.attn1 = HiDreamAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ processor = HiDreamAttnProcessor_flashattn(),
+ single = True,
+ dtype=dtype, device=device, operations=operations
+ )
+
+ # 3. Feed-forward
+ self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
+ if num_routed_experts > 0:
+ self.ff_i = MOEFeedForwardSwiGLU(
+ dim = dim,
+ hidden_dim = 4 * dim,
+ num_routed_experts = num_routed_experts,
+ num_activated_experts = num_activated_experts,
+ dtype=dtype, device=device, operations=operations
+ )
+ else:
+ self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
+
+ def forward(
+ self,
+ image_tokens: torch.FloatTensor,
+ image_tokens_masks: Optional[torch.FloatTensor] = None,
+ text_tokens: Optional[torch.FloatTensor] = None,
+ adaln_input: Optional[torch.FloatTensor] = None,
+        rope: torch.FloatTensor = None,
+    ) -> torch.FloatTensor:
+ wtype = image_tokens.dtype
+ shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i = \
+ self.adaLN_modulation(adaln_input)[:,None].chunk(6, dim=-1)
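+        # adaLN modulation (editor's note): adaln_input is [B, dim]; the [:, None]
+        # inserts a token axis so each of the six chunks is [B, 1, dim] and
+        # broadcasts over the sequence. Each sub-layer below computes
+        # norm(x) * (1 + scale) + shift and is gated back into the residual
+        # stream as x + gate * sublayer(x).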
+
+ # 1. MM-Attention
+ norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
+ norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
+ attn_output_i = self.attn1(
+ norm_image_tokens,
+ image_tokens_masks,
+ rope = rope,
+ )
+ image_tokens = gate_msa_i * attn_output_i + image_tokens
+
+ # 2. Feed-forward
+ norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
+ norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
+ ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens.to(dtype=wtype))
+ image_tokens = ff_output_i + image_tokens
+ return image_tokens
+
+
+class HiDreamImageTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_routed_experts: int = 4,
+ num_activated_experts: int = 2,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ operations.Linear(dim, 12 * dim, bias=True, dtype=dtype, device=device)
+ )
+ # nn.init.zeros_(self.adaLN_modulation[1].weight)
+ # nn.init.zeros_(self.adaLN_modulation[1].bias)
+
+ # 1. Attention
+ self.norm1_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
+ self.norm1_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
+ self.attn1 = HiDreamAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ processor = HiDreamAttnProcessor_flashattn(),
+ single = False,
+ dtype=dtype, device=device, operations=operations
+ )
+
+ # 3. Feed-forward
+ self.norm3_i = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False, dtype=dtype, device=device)
+ if num_routed_experts > 0:
+ self.ff_i = MOEFeedForwardSwiGLU(
+ dim = dim,
+ hidden_dim = 4 * dim,
+ num_routed_experts = num_routed_experts,
+ num_activated_experts = num_activated_experts,
+ dtype=dtype, device=device, operations=operations
+ )
+ else:
+ self.ff_i = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
+ self.norm3_t = operations.LayerNorm(dim, eps = 1e-06, elementwise_affine = False)
+ self.ff_t = FeedForwardSwiGLU(dim = dim, hidden_dim = 4 * dim, dtype=dtype, device=device, operations=operations)
+
+ def forward(
+ self,
+ image_tokens: torch.FloatTensor,
+ image_tokens_masks: Optional[torch.FloatTensor] = None,
+ text_tokens: Optional[torch.FloatTensor] = None,
+ adaln_input: Optional[torch.FloatTensor] = None,
+ rope: torch.FloatTensor = None,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ wtype = image_tokens.dtype
+ shift_msa_i, scale_msa_i, gate_msa_i, shift_mlp_i, scale_mlp_i, gate_mlp_i, \
+ shift_msa_t, scale_msa_t, gate_msa_t, shift_mlp_t, scale_mlp_t, gate_mlp_t = \
+ self.adaLN_modulation(adaln_input)[:,None].chunk(12, dim=-1)
+
+ # 1. MM-Attention
+ norm_image_tokens = self.norm1_i(image_tokens).to(dtype=wtype)
+ norm_image_tokens = norm_image_tokens * (1 + scale_msa_i) + shift_msa_i
+ norm_text_tokens = self.norm1_t(text_tokens).to(dtype=wtype)
+ norm_text_tokens = norm_text_tokens * (1 + scale_msa_t) + shift_msa_t
+
+ attn_output_i, attn_output_t = self.attn1(
+ norm_image_tokens,
+ image_tokens_masks,
+ norm_text_tokens,
+ rope = rope,
+ )
+
+ image_tokens = gate_msa_i * attn_output_i + image_tokens
+ text_tokens = gate_msa_t * attn_output_t + text_tokens
+
+ # 2. Feed-forward
+ norm_image_tokens = self.norm3_i(image_tokens).to(dtype=wtype)
+ norm_image_tokens = norm_image_tokens * (1 + scale_mlp_i) + shift_mlp_i
+ norm_text_tokens = self.norm3_t(text_tokens).to(dtype=wtype)
+ norm_text_tokens = norm_text_tokens * (1 + scale_mlp_t) + shift_mlp_t
+
+ ff_output_i = gate_mlp_i * self.ff_i(norm_image_tokens)
+ ff_output_t = gate_mlp_t * self.ff_t(norm_text_tokens)
+ image_tokens = ff_output_i + image_tokens
+ text_tokens = ff_output_t + text_tokens
+ return image_tokens, text_tokens
+
+
+class HiDreamImageBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ num_routed_experts: int = 4,
+ num_activated_experts: int = 2,
+ block_type: BlockType = BlockType.TransformerBlock,
+ dtype=None, device=None, operations=None
+ ):
+ super().__init__()
+ block_classes = {
+ BlockType.TransformerBlock: HiDreamImageTransformerBlock,
+ BlockType.SingleTransformerBlock: HiDreamImageSingleTransformerBlock,
+ }
+ self.block = block_classes[block_type](
+ dim,
+ num_attention_heads,
+ attention_head_dim,
+ num_routed_experts,
+ num_activated_experts,
+ dtype=dtype, device=device, operations=operations
+ )
+
+ def forward(
+ self,
+ image_tokens: torch.FloatTensor,
+ image_tokens_masks: Optional[torch.FloatTensor] = None,
+ text_tokens: Optional[torch.FloatTensor] = None,
+ adaln_input: torch.FloatTensor = None,
+ rope: torch.FloatTensor = None,
+ ) -> torch.FloatTensor:
+ return self.block(
+ image_tokens,
+ image_tokens_masks,
+ text_tokens,
+ adaln_input,
+ rope,
+ )
+
+
+class HiDreamImageTransformer2DModel(nn.Module):
+ def __init__(
+ self,
+ patch_size: Optional[int] = None,
+ in_channels: int = 64,
+ out_channels: Optional[int] = None,
+ num_layers: int = 16,
+ num_single_layers: int = 32,
+ attention_head_dim: int = 128,
+ num_attention_heads: int = 20,
+ caption_channels: List[int] = None,
+ text_emb_dim: int = 2048,
+ num_routed_experts: int = 4,
+ num_activated_experts: int = 2,
+ axes_dims_rope: Tuple[int, int] = (32, 32),
+ max_resolution: Tuple[int, int] = (128, 128),
+ llama_layers: List[int] = None,
+ image_model=None,
+ dtype=None, device=None, operations=None
+ ):
+ self.patch_size = patch_size
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ self.num_layers = num_layers
+ self.num_single_layers = num_single_layers
+
+ self.gradient_checkpointing = False
+
+ super().__init__()
+ self.dtype = dtype
+ self.out_channels = out_channels or in_channels
+ self.inner_dim = self.num_attention_heads * self.attention_head_dim
+ self.llama_layers = llama_layers
+
+ self.t_embedder = TimestepEmbed(self.inner_dim, dtype=dtype, device=device, operations=operations)
+ self.p_embedder = PooledEmbed(text_emb_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
+ self.x_embedder = PatchEmbed(
+ patch_size = patch_size,
+ in_channels = in_channels,
+ out_channels = self.inner_dim,
+ dtype=dtype, device=device, operations=operations
+ )
+ self.pe_embedder = EmbedND(theta=10000, axes_dim=axes_dims_rope)
+
+ self.double_stream_blocks = nn.ModuleList(
+ [
+ HiDreamImageBlock(
+ dim = self.inner_dim,
+ num_attention_heads = self.num_attention_heads,
+ attention_head_dim = self.attention_head_dim,
+ num_routed_experts = num_routed_experts,
+ num_activated_experts = num_activated_experts,
+ block_type = BlockType.TransformerBlock,
+ dtype=dtype, device=device, operations=operations
+ )
+ for i in range(self.num_layers)
+ ]
+ )
+
+ self.single_stream_blocks = nn.ModuleList(
+ [
+ HiDreamImageBlock(
+ dim = self.inner_dim,
+ num_attention_heads = self.num_attention_heads,
+ attention_head_dim = self.attention_head_dim,
+ num_routed_experts = num_routed_experts,
+ num_activated_experts = num_activated_experts,
+ block_type = BlockType.SingleTransformerBlock,
+ dtype=dtype, device=device, operations=operations
+ )
+ for i in range(self.num_single_layers)
+ ]
+ )
+
+ self.final_layer = LastLayer(self.inner_dim, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations)
+
+ caption_channels = [caption_channels[1], ] * (num_layers + num_single_layers) + [caption_channels[0], ]
+ caption_projection = []
+ for caption_channel in caption_channels:
+ caption_projection.append(TextProjection(in_features=caption_channel, hidden_size=self.inner_dim, dtype=dtype, device=device, operations=operations))
+ self.caption_projection = nn.ModuleList(caption_projection)
+ self.max_seq = max_resolution[0] * max_resolution[1] // (patch_size * patch_size)
+
+ def expand_timesteps(self, timesteps, batch_size, device):
+ if not torch.is_tensor(timesteps):
+ is_mps = device.type == "mps"
+ if isinstance(timesteps, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(device)
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(batch_size)
+ return timesteps
+
+ def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]]) -> List[torch.Tensor]:
+ x_arr = []
+ for i, img_size in enumerate(img_sizes):
+ pH, pW = img_size
+ x_arr.append(
+ einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)',
+ p1=self.patch_size, p2=self.patch_size)
+ )
+ x = torch.cat(x_arr, dim=0)
+ return x
+
+ def patchify(self, x, max_seq, img_sizes=None):
+ pz2 = self.patch_size * self.patch_size
+ if isinstance(x, torch.Tensor):
+ B = x.shape[0]
+ device = x.device
+ dtype = x.dtype
+ else:
+ B = len(x)
+ device = x[0].device
+ dtype = x[0].dtype
+ x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
+
+ if img_sizes is not None:
+ for i, img_size in enumerate(img_sizes):
+ x_masks[i, 0:img_size[0] * img_size[1]] = 1
+ x = einops.rearrange(x, 'B C S p -> B S (p C)', p=pz2)
+ elif isinstance(x, torch.Tensor):
+ pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size
+ x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size)
+ img_sizes = [[pH, pW]] * B
+ x_masks = None
+ else:
+ raise NotImplementedError
+ return x, x_masks, img_sizes
+
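+    # Shape walkthrough (editor's sketch, illustrative numbers): for a padded
+    # latent x of [B, 16, 64, 64] with patch_size=2, the tensor branch above
+    # gives pH = pW = 32, so x becomes [B, 1024, 64] (2*2*16 values per patch),
+    # img_sizes = [[32, 32]] * B and x_masks = None; unpatchify() inverts this
+    # using the per-image (pH, pW).
+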
+ def forward(
+ self,
+ x: torch.Tensor,
+ t: torch.Tensor,
+ y: Optional[torch.Tensor] = None,
+ context: Optional[torch.Tensor] = None,
+ encoder_hidden_states_llama3=None,
+ image_cond=None,
+ control = None,
+ transformer_options = {},
+ ) -> torch.Tensor:
+ bs, c, h, w = x.shape
+ if image_cond is not None:
+ x = torch.cat([x, image_cond], dim=-1)
+ hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
+ timesteps = t
+ pooled_embeds = y
+ T5_encoder_hidden_states = context
+
+ img_sizes = None
+
+ # spatial forward
+ batch_size = hidden_states.shape[0]
+ hidden_states_type = hidden_states.dtype
+
+ # 0. time
+ timesteps = self.expand_timesteps(timesteps, batch_size, hidden_states.device)
+ timesteps = self.t_embedder(timesteps, hidden_states_type)
+ p_embedder = self.p_embedder(pooled_embeds)
+ adaln_input = timesteps + p_embedder
+
+ hidden_states, image_tokens_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
+ if image_tokens_masks is None:
+ pH, pW = img_sizes[0]
+ img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=batch_size)
+ hidden_states = self.x_embedder(hidden_states)
+
+ # T5_encoder_hidden_states = encoder_hidden_states[0]
+ encoder_hidden_states = encoder_hidden_states_llama3.movedim(1, 0)
+ encoder_hidden_states = [encoder_hidden_states[k] for k in self.llama_layers]
+
+ if self.caption_projection is not None:
+ new_encoder_hidden_states = []
+ for i, enc_hidden_state in enumerate(encoder_hidden_states):
+ enc_hidden_state = self.caption_projection[i](enc_hidden_state)
+ enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
+ new_encoder_hidden_states.append(enc_hidden_state)
+ encoder_hidden_states = new_encoder_hidden_states
+ T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
+ T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+ encoder_hidden_states.append(T5_encoder_hidden_states)
+
+ txt_ids = torch.zeros(
+ batch_size,
+ encoder_hidden_states[-1].shape[1] + encoder_hidden_states[-2].shape[1] + encoder_hidden_states[0].shape[1],
+ 3,
+ device=img_ids.device, dtype=img_ids.dtype
+ )
+ ids = torch.cat((img_ids, txt_ids), dim=1)
+ rope = self.pe_embedder(ids)
+
+ # 2. Blocks
+ block_id = 0
+ initial_encoder_hidden_states = torch.cat([encoder_hidden_states[-1], encoder_hidden_states[-2]], dim=1)
+ initial_encoder_hidden_states_seq_len = initial_encoder_hidden_states.shape[1]
+ for bid, block in enumerate(self.double_stream_blocks):
+ cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
+ cur_encoder_hidden_states = torch.cat([initial_encoder_hidden_states, cur_llama31_encoder_hidden_states], dim=1)
+ hidden_states, initial_encoder_hidden_states = block(
+ image_tokens = hidden_states,
+ image_tokens_masks = image_tokens_masks,
+ text_tokens = cur_encoder_hidden_states,
+ adaln_input = adaln_input,
+ rope = rope,
+ )
+ initial_encoder_hidden_states = initial_encoder_hidden_states[:, :initial_encoder_hidden_states_seq_len]
+ block_id += 1
+
+ image_tokens_seq_len = hidden_states.shape[1]
+ hidden_states = torch.cat([hidden_states, initial_encoder_hidden_states], dim=1)
+ hidden_states_seq_len = hidden_states.shape[1]
+ if image_tokens_masks is not None:
+ encoder_attention_mask_ones = torch.ones(
+ (batch_size, initial_encoder_hidden_states.shape[1] + cur_llama31_encoder_hidden_states.shape[1]),
+ device=image_tokens_masks.device, dtype=image_tokens_masks.dtype
+ )
+ image_tokens_masks = torch.cat([image_tokens_masks, encoder_attention_mask_ones], dim=1)
+
+ for bid, block in enumerate(self.single_stream_blocks):
+ cur_llama31_encoder_hidden_states = encoder_hidden_states[block_id]
+ hidden_states = torch.cat([hidden_states, cur_llama31_encoder_hidden_states], dim=1)
+ hidden_states = block(
+ image_tokens=hidden_states,
+ image_tokens_masks=image_tokens_masks,
+ text_tokens=None,
+ adaln_input=adaln_input,
+ rope=rope,
+ )
+ hidden_states = hidden_states[:, :hidden_states_seq_len]
+ block_id += 1
+
+ hidden_states = hidden_states[:, :image_tokens_seq_len, ...]
+ output = self.final_layer(hidden_states, adaln_input)
+ output = self.unpatchify(output, img_sizes)
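+        # Editor's note: the crop below removes any padding added by
+        # pad_to_patch_size, and the sign flip appears to match the model's
+        # prediction to the flow convention ComfyUI's samplers expect.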
+ return -output[:, :, :h, :w]
diff --git a/comfy/ldm/hunyuan3d/model.py b/comfy/ldm/hunyuan3d/model.py
new file mode 100644
index 000000000..4e18358f0
--- /dev/null
+++ b/comfy/ldm/hunyuan3d/model.py
@@ -0,0 +1,135 @@
+import torch
+from torch import nn
+from comfy.ldm.flux.layers import (
+ DoubleStreamBlock,
+ LastLayer,
+ MLPEmbedder,
+ SingleStreamBlock,
+ timestep_embedding,
+)
+
+
+class Hunyuan3Dv2(nn.Module):
+ def __init__(
+ self,
+ in_channels=64,
+ context_in_dim=1536,
+ hidden_size=1024,
+ mlp_ratio=4.0,
+ num_heads=16,
+ depth=16,
+ depth_single_blocks=32,
+ qkv_bias=True,
+ guidance_embed=False,
+ image_model=None,
+ dtype=None,
+ device=None,
+ operations=None
+ ):
+ super().__init__()
+ self.dtype = dtype
+
+ if hidden_size % num_heads != 0:
+ raise ValueError(
+ f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
+ )
+
+ self.max_period = 1000 # While reimplementing the model I noticed that they messed up. This 1000 value was meant to be the time_factor but they set the max_period instead
+ self.latent_in = operations.Linear(in_channels, hidden_size, bias=True, dtype=dtype, device=device)
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations)
+ self.guidance_in = (
+ MLPEmbedder(in_dim=256, hidden_dim=hidden_size, dtype=dtype, device=device, operations=operations) if guidance_embed else None
+ )
+ self.cond_in = operations.Linear(context_in_dim, hidden_size, dtype=dtype, device=device)
+ self.double_blocks = nn.ModuleList(
+ [
+ DoubleStreamBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(depth)
+ ]
+ )
+ self.single_blocks = nn.ModuleList(
+ [
+ SingleStreamBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ dtype=dtype, device=device, operations=operations
+ )
+ for _ in range(depth_single_blocks)
+ ]
+ )
+ self.final_layer = LastLayer(hidden_size, 1, in_channels, dtype=dtype, device=device, operations=operations)
+
+ def forward(self, x, timestep, context, guidance=None, transformer_options={}, **kwargs):
+ x = x.movedim(-1, -2)
+ timestep = 1.0 - timestep
+ txt = context
+ img = self.latent_in(x)
+
+ vec = self.time_in(timestep_embedding(timestep, 256, self.max_period).to(dtype=img.dtype))
+ if self.guidance_in is not None:
+ if guidance is not None:
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256, self.max_period).to(img.dtype))
+
+ txt = self.cond_in(txt)
+ pe = None
+ attn_mask = None
+
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ for i, block in enumerate(self.double_blocks):
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"], out["txt"] = block(img=args["img"],
+ txt=args["txt"],
+ vec=args["vec"],
+ pe=args["pe"],
+ attn_mask=args.get("attn_mask"))
+ return out
+
+ out = blocks_replace[("double_block", i)]({"img": img,
+ "txt": txt,
+ "vec": vec,
+ "pe": pe,
+ "attn_mask": attn_mask},
+ {"original_block": block_wrap})
+ txt = out["txt"]
+ img = out["img"]
+ else:
+ img, txt = block(img=img,
+ txt=txt,
+ vec=vec,
+ pe=pe,
+ attn_mask=attn_mask)
+
+ img = torch.cat((txt, img), 1)
+
+ for i, block in enumerate(self.single_blocks):
+ if ("single_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"],
+ vec=args["vec"],
+ pe=args["pe"],
+ attn_mask=args.get("attn_mask"))
+ return out
+
+ out = blocks_replace[("single_block", i)]({"img": img,
+ "vec": vec,
+ "pe": pe,
+ "attn_mask": attn_mask},
+ {"original_block": block_wrap})
+ img = out["img"]
+ else:
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+
+ img = img[:, txt.shape[1]:, ...]
+ img = self.final_layer(img, vec)
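+        # Editor's note: movedim undoes the transpose at the top of forward(),
+        # and the -1.0 factor (together with the 1.0 - timestep flip above)
+        # appears to adapt the checkpoint's flow direction to ComfyUI's
+        # sampling convention.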
+ return img.movedim(-2, -1) * (-1.0)
diff --git a/comfy/ldm/hunyuan3d/vae.py b/comfy/ldm/hunyuan3d/vae.py
new file mode 100644
index 000000000..5eb2c6548
--- /dev/null
+++ b/comfy/ldm/hunyuan3d/vae.py
@@ -0,0 +1,587 @@
+# Original: https://github.com/Tencent/Hunyuan3D-2/blob/main/hy3dgen/shapegen/models/autoencoders/model.py
+# Since the license header on their VAE source file was a bit confusing, we asked Tencent for permission to use this code under the GPL license used in ComfyUI.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+from typing import Union, Tuple, List, Callable, Optional
+
+import numpy as np
+from einops import repeat, rearrange
+from tqdm import tqdm
+import logging
+
+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
+def generate_dense_grid_points(
+ bbox_min: np.ndarray,
+ bbox_max: np.ndarray,
+ octree_resolution: int,
+ indexing: str = "ij",
+):
+ length = bbox_max - bbox_min
+ num_cells = octree_resolution
+
+ x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
+ y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
+ z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
+ [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
+ xyz = np.stack((xs, ys, zs), axis=-1)
+ grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
+
+ return xyz, grid_size, length
+
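+# Worked example (editor's sketch): with the default bounds of 1.01 and
+# octree_resolution=256, each axis gets 257 sample points, so xyz has shape
+# (257, 257, 257, 3) -- roughly 17M query points -- grid_size is
+# [257, 257, 257], and length = bbox_max - bbox_min = [2.02, 2.02, 2.02].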
+
+class VanillaVolumeDecoder:
+ @torch.no_grad()
+ def __call__(
+ self,
+ latents: torch.FloatTensor,
+ geo_decoder: Callable,
+ bounds: Union[Tuple[float], List[float], float] = 1.01,
+ num_chunks: int = 10000,
+ octree_resolution: int = None,
+ enable_pbar: bool = True,
+ **kwargs,
+ ):
+ device = latents.device
+ dtype = latents.dtype
+ batch_size = latents.shape[0]
+
+ # 1. generate query points
+ if isinstance(bounds, float):
+ bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
+
+ bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
+ xyz_samples, grid_size, length = generate_dense_grid_points(
+ bbox_min=bbox_min,
+ bbox_max=bbox_max,
+ octree_resolution=octree_resolution,
+ indexing="ij"
+ )
+ xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
+
+ # 2. latents to 3d volume
+ batch_logits = []
+ for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
+ disable=not enable_pbar):
+ chunk_queries = xyz_samples[start: start + num_chunks, :]
+ chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
+ logits = geo_decoder(queries=chunk_queries, latents=latents)
+ batch_logits.append(logits)
+
+ grid_logits = torch.cat(batch_logits, dim=1)
+ grid_logits = grid_logits.view((batch_size, *grid_size)).float()
+
+ return grid_logits
+
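+# Editor's note: with num_chunks=10000 a 257^3 grid is decoded in ~1,700
+# chunks; each chunk of query points is broadcast across the batch, scored by
+# geo_decoder, and the concatenated logits are reshaped into a dense occupancy
+# volume of shape [batch_size, 257, 257, 257] for mesh extraction downstream.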
+
+class FourierEmbedder(nn.Module):
+    """The sin/cosine positional embedding. Given an input tensor `x` of shape [n_batch, ..., c_dim], it converts
+    each feature dimension of `x[..., i]` into:
+    [
+        x[..., i],  # only present if include_input is True
+        sin(f_1 * x[..., i]),
+        sin(f_2 * x[..., i]),
+        ...
+        sin(f_N * x[..., i]),
+        cos(f_1 * x[..., i]),
+        cos(f_2 * x[..., i]),
+        ...
+        cos(f_N * x[..., i]),
+    ], where f_1, ..., f_N are the num_freqs frequencies; note that forward() concatenates the raw
+    input first, before the sines and cosines.
+
+    If logspace is True, the frequencies are powers of two: f_i = 2^(i - 1) for i = 1, ..., num_freqs;
+    otherwise, the frequencies are linearly spaced between 1.0 and 2^(num_freqs - 1).
+    In both cases the frequencies are additionally multiplied by pi when include_pi is True.
+
+    Args:
+        num_freqs (int): the number of frequencies, default is 6;
+        logspace (bool): if True, the frequencies are powers of two (f_i = 2^(i - 1)),
+            otherwise they are linearly spaced between 1.0 and 2^(num_freqs - 1);
+        input_dim (int): the input dimension, default is 3;
+        include_input (bool): include the input tensor or not, default is True.
+
+    Attributes:
+        frequencies (torch.Tensor): the frequency tensor f_1, ..., f_N described above;
+
+ out_dim (int): the embedding size, if include_input is True, it is input_dim * (num_freqs * 2 + 1),
+ otherwise, it is input_dim * num_freqs * 2.
+
+ """
+
+ def __init__(self,
+ num_freqs: int = 6,
+ logspace: bool = True,
+ input_dim: int = 3,
+ include_input: bool = True,
+ include_pi: bool = True) -> None:
+
+ """The initialization"""
+
+ super().__init__()
+
+ if logspace:
+ frequencies = 2.0 ** torch.arange(
+ num_freqs,
+ dtype=torch.float32
+ )
+ else:
+ frequencies = torch.linspace(
+ 1.0,
+ 2.0 ** (num_freqs - 1),
+ num_freqs,
+ dtype=torch.float32
+ )
+
+ if include_pi:
+ frequencies *= torch.pi
+
+ self.register_buffer("frequencies", frequencies, persistent=False)
+ self.include_input = include_input
+ self.num_freqs = num_freqs
+
+ self.out_dim = self.get_dims(input_dim)
+
+ def get_dims(self, input_dim):
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
+
+ return out_dim
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """ Forward process.
+
+ Args:
+ x: tensor of shape [..., dim]
+
+ Returns:
+ embedding: an embedding of `x` of shape [..., dim * (num_freqs * 2 + temp)]
+ where temp is 1 if include_input is True and 0 otherwise.
+ """
+
+ if self.num_freqs > 0:
+ embed = (x[..., None].contiguous() * self.frequencies.to(device=x.device, dtype=x.dtype)).view(*x.shape[:-1], -1)
+ if self.include_input:
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
+ else:
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
+ else:
+ return x
+
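+# Worked example (editor's sketch): with the ShapeVAE defaults below
+# (num_freqs=8, input_dim=3, include_input=True),
+# out_dim = 3 * (8 * 2 + 1) = 51: each xyz query expands to 24 sines,
+# 24 cosines, plus the 3 raw coordinates.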
+
+class CrossAttentionProcessor:
+ def __call__(self, attn, q, k, v):
+ out = F.scaled_dot_product_attention(q, k, v)
+ return out
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+ """
+
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+        """Drop paths (Stochastic Depth) per sample, applied in the main path of residual blocks.
+
+        This is the same mechanism as the DropConnect implementation used for EfficientNet-style
+        networks, but that name is misleading: 'DropConnect' is a different form of dropout from a
+        separate paper (see https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956).
+        The layer and argument are therefore named 'drop path' rather than reusing DropConnect as a
+        layer name with 'survival rate' as the argument.
+        """
+ if self.drop_prob == 0. or not self.training:
+ return x
+ keep_prob = 1 - self.drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and self.scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+ def extra_repr(self):
+ return f'drop_prob={round(self.drop_prob, 3):0.3f}'
+
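+# Editor's sketch: DropPath(0.1) zeroes the entire residual branch for ~10% of
+# samples during training and rescales survivors by 1/0.9, keeping the expected
+# activation unchanged; outside training (or with drop_prob == 0.0) it is a
+# no-op.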
+
+class MLP(nn.Module):
+ def __init__(
+ self, *,
+ width: int,
+ expand_ratio: int = 4,
+ output_width: int = None,
+ drop_path_rate: float = 0.0
+ ):
+ super().__init__()
+ self.width = width
+ self.c_fc = ops.Linear(width, width * expand_ratio)
+ self.c_proj = ops.Linear(width * expand_ratio, output_width if output_width is not None else width)
+ self.gelu = nn.GELU()
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ def forward(self, x):
+ return self.drop_path(self.c_proj(self.gelu(self.c_fc(x))))
+
+
+class QKVMultiheadCrossAttention(nn.Module):
+ def __init__(
+ self,
+ *,
+ heads: int,
+ width=None,
+ qk_norm=False,
+ norm_layer=ops.LayerNorm
+ ):
+ super().__init__()
+ self.heads = heads
+ self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+
+ self.attn_processor = CrossAttentionProcessor()
+
+ def forward(self, q, kv):
+ _, n_ctx, _ = q.shape
+ bs, n_data, width = kv.shape
+ attn_ch = width // self.heads // 2
+ q = q.view(bs, n_ctx, self.heads, -1)
+ kv = kv.view(bs, n_data, self.heads, -1)
+ k, v = torch.split(kv, attn_ch, dim=-1)
+
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+ q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
+ out = self.attn_processor(self, q, k, v)
+ out = out.transpose(1, 2).reshape(bs, n_ctx, -1)
+ return out
+
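+# Shape sketch (editor's note, illustrative numbers): for width=1024 and
+# heads=16, q arrives as [B, n_ctx, 1024] and kv as [B, n_data, 2048], so
+# attn_ch = 2048 // 16 // 2 = 64 and k, v are each [B, n_data, 16, 64] before
+# the rearrange to [B, h, n, d] for scaled_dot_product_attention.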
+
+class MultiheadCrossAttention(nn.Module):
+ def __init__(
+ self,
+ *,
+ width: int,
+ heads: int,
+ qkv_bias: bool = True,
+ data_width: Optional[int] = None,
+ norm_layer=ops.LayerNorm,
+ qk_norm: bool = False,
+ kv_cache: bool = False,
+ ):
+ super().__init__()
+ self.width = width
+ self.heads = heads
+ self.data_width = width if data_width is None else data_width
+ self.c_q = ops.Linear(width, width, bias=qkv_bias)
+ self.c_kv = ops.Linear(self.data_width, width * 2, bias=qkv_bias)
+ self.c_proj = ops.Linear(width, width)
+ self.attention = QKVMultiheadCrossAttention(
+ heads=heads,
+ width=width,
+ norm_layer=norm_layer,
+ qk_norm=qk_norm
+ )
+ self.kv_cache = kv_cache
+ self.data = None
+
+ def forward(self, x, data):
+ x = self.c_q(x)
+ if self.kv_cache:
+ if self.data is None:
+ self.data = self.c_kv(data)
+                logging.info('Saving kv cache; this should happen only once per mesh')
+ data = self.data
+ else:
+ data = self.c_kv(data)
+ x = self.attention(x, data)
+ x = self.c_proj(x)
+ return x
+
+
+class ResidualCrossAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ width: int,
+ heads: int,
+ mlp_expand_ratio: int = 4,
+ data_width: Optional[int] = None,
+ qkv_bias: bool = True,
+ norm_layer=ops.LayerNorm,
+ qk_norm: bool = False
+ ):
+ super().__init__()
+
+ if data_width is None:
+ data_width = width
+
+ self.attn = MultiheadCrossAttention(
+ width=width,
+ heads=heads,
+ data_width=data_width,
+ qkv_bias=qkv_bias,
+ norm_layer=norm_layer,
+ qk_norm=qk_norm
+ )
+ self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
+ self.ln_2 = norm_layer(data_width, elementwise_affine=True, eps=1e-6)
+ self.ln_3 = norm_layer(width, elementwise_affine=True, eps=1e-6)
+ self.mlp = MLP(width=width, expand_ratio=mlp_expand_ratio)
+
+ def forward(self, x: torch.Tensor, data: torch.Tensor):
+ x = x + self.attn(self.ln_1(x), self.ln_2(data))
+ x = x + self.mlp(self.ln_3(x))
+ return x
+
+
+class QKVMultiheadAttention(nn.Module):
+ def __init__(
+ self,
+ *,
+ heads: int,
+ width=None,
+ qk_norm=False,
+ norm_layer=ops.LayerNorm
+ ):
+ super().__init__()
+ self.heads = heads
+ self.q_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+ self.k_norm = norm_layer(width // heads, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity()
+
+ def forward(self, qkv):
+ bs, n_ctx, width = qkv.shape
+ attn_ch = width // self.heads // 3
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
+
+ q = self.q_norm(q)
+ k = self.k_norm(k)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n h d -> b h n d', h=self.heads), (q, k, v))
+ out = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(bs, n_ctx, -1)
+ return out
+
+
+class MultiheadAttention(nn.Module):
+ def __init__(
+ self,
+ *,
+ width: int,
+ heads: int,
+ qkv_bias: bool,
+ norm_layer=ops.LayerNorm,
+ qk_norm: bool = False,
+ drop_path_rate: float = 0.0
+ ):
+ super().__init__()
+ self.width = width
+ self.heads = heads
+ self.c_qkv = ops.Linear(width, width * 3, bias=qkv_bias)
+ self.c_proj = ops.Linear(width, width)
+ self.attention = QKVMultiheadAttention(
+ heads=heads,
+ width=width,
+ norm_layer=norm_layer,
+ qk_norm=qk_norm
+ )
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
+
+ def forward(self, x):
+ x = self.c_qkv(x)
+ x = self.attention(x)
+ x = self.drop_path(self.c_proj(x))
+ return x
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(
+ self,
+ *,
+ width: int,
+ heads: int,
+ qkv_bias: bool = True,
+ norm_layer=ops.LayerNorm,
+ qk_norm: bool = False,
+ drop_path_rate: float = 0.0,
+ ):
+ super().__init__()
+ self.attn = MultiheadAttention(
+ width=width,
+ heads=heads,
+ qkv_bias=qkv_bias,
+ norm_layer=norm_layer,
+ qk_norm=qk_norm,
+ drop_path_rate=drop_path_rate
+ )
+ self.ln_1 = norm_layer(width, elementwise_affine=True, eps=1e-6)
+ self.mlp = MLP(width=width, drop_path_rate=drop_path_rate)
+ self.ln_2 = norm_layer(width, elementwise_affine=True, eps=1e-6)
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attn(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ *,
+ width: int,
+ layers: int,
+ heads: int,
+ qkv_bias: bool = True,
+ norm_layer=ops.LayerNorm,
+ qk_norm: bool = False,
+ drop_path_rate: float = 0.0
+ ):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.ModuleList(
+ [
+ ResidualAttentionBlock(
+ width=width,
+ heads=heads,
+ qkv_bias=qkv_bias,
+ norm_layer=norm_layer,
+ qk_norm=qk_norm,
+ drop_path_rate=drop_path_rate
+ )
+ for _ in range(layers)
+ ]
+ )
+
+ def forward(self, x: torch.Tensor):
+ for block in self.resblocks:
+ x = block(x)
+ return x
+
+
+class CrossAttentionDecoder(nn.Module):
+
+ def __init__(
+ self,
+ *,
+ out_channels: int,
+ fourier_embedder: FourierEmbedder,
+ width: int,
+ heads: int,
+ mlp_expand_ratio: int = 4,
+ downsample_ratio: int = 1,
+ enable_ln_post: bool = True,
+ qkv_bias: bool = True,
+ qk_norm: bool = False,
+ label_type: str = "binary"
+ ):
+ super().__init__()
+
+ self.enable_ln_post = enable_ln_post
+ self.fourier_embedder = fourier_embedder
+ self.downsample_ratio = downsample_ratio
+ self.query_proj = ops.Linear(self.fourier_embedder.out_dim, width)
+ if self.downsample_ratio != 1:
+ self.latents_proj = ops.Linear(width * downsample_ratio, width)
+        if not self.enable_ln_post:
+ qk_norm = False
+ self.cross_attn_decoder = ResidualCrossAttentionBlock(
+ width=width,
+ mlp_expand_ratio=mlp_expand_ratio,
+ heads=heads,
+ qkv_bias=qkv_bias,
+ qk_norm=qk_norm
+ )
+
+ if self.enable_ln_post:
+ self.ln_post = ops.LayerNorm(width)
+ self.output_proj = ops.Linear(width, out_channels)
+ self.label_type = label_type
+ self.count = 0
+
+ def forward(self, queries=None, query_embeddings=None, latents=None):
+ if query_embeddings is None:
+ query_embeddings = self.query_proj(self.fourier_embedder(queries).to(latents.dtype))
+ self.count += query_embeddings.shape[1]
+ if self.downsample_ratio != 1:
+ latents = self.latents_proj(latents)
+ x = self.cross_attn_decoder(query_embeddings, latents)
+ if self.enable_ln_post:
+ x = self.ln_post(x)
+ occ = self.output_proj(x)
+ return occ
+
+
+class ShapeVAE(nn.Module):
+ def __init__(
+ self,
+ *,
+ embed_dim: int,
+ width: int,
+ heads: int,
+ num_decoder_layers: int,
+ geo_decoder_downsample_ratio: int = 1,
+ geo_decoder_mlp_expand_ratio: int = 4,
+ geo_decoder_ln_post: bool = True,
+ num_freqs: int = 8,
+ include_pi: bool = True,
+ qkv_bias: bool = True,
+ qk_norm: bool = False,
+ label_type: str = "binary",
+ drop_path_rate: float = 0.0,
+ scale_factor: float = 1.0,
+ ):
+ super().__init__()
+ self.geo_decoder_ln_post = geo_decoder_ln_post
+
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
+
+ self.post_kl = ops.Linear(embed_dim, width)
+
+ self.transformer = Transformer(
+ width=width,
+ layers=num_decoder_layers,
+ heads=heads,
+ qkv_bias=qkv_bias,
+ qk_norm=qk_norm,
+ drop_path_rate=drop_path_rate
+ )
+
+ self.geo_decoder = CrossAttentionDecoder(
+ fourier_embedder=self.fourier_embedder,
+ out_channels=1,
+ mlp_expand_ratio=geo_decoder_mlp_expand_ratio,
+ downsample_ratio=geo_decoder_downsample_ratio,
+ enable_ln_post=self.geo_decoder_ln_post,
+ width=width // geo_decoder_downsample_ratio,
+ heads=heads // geo_decoder_downsample_ratio,
+ qkv_bias=qkv_bias,
+ qk_norm=qk_norm,
+ label_type=label_type,
+ )
+
+ self.volume_decoder = VanillaVolumeDecoder()
+ self.scale_factor = scale_factor
+
+ def decode(self, latents, **kwargs):
+ latents = self.post_kl(latents.movedim(-2, -1))
+ latents = self.transformer(latents)
+
+ bounds = kwargs.get("bounds", 1.01)
+ num_chunks = kwargs.get("num_chunks", 8000)
+ octree_resolution = kwargs.get("octree_resolution", 256)
+ enable_pbar = kwargs.get("enable_pbar", True)
+
+ grid_logits = self.volume_decoder(latents, self.geo_decoder, bounds=bounds, num_chunks=num_chunks, octree_resolution=octree_resolution, enable_pbar=enable_pbar)
+ return grid_logits.movedim(-2, -1)
+
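+    # Editor's note: encode() below appears to be intentionally a stub; only
+    # the shape decoder is exercised in this integration, since latents come
+    # from the diffusion model rather than from encoding an input mesh.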
+ def encode(self, x):
+ return None
diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py
index d6d854089..fbd8d4196 100644
--- a/comfy/ldm/hunyuan_video/model.py
+++ b/comfy/ldm/hunyuan_video/model.py
@@ -227,6 +227,8 @@ class HunyuanVideo(nn.Module):
timesteps: Tensor,
y: Tensor,
guidance: Tensor = None,
+ guiding_frame_index=None,
+ ref_latent=None,
control=None,
transformer_options={},
) -> Tensor:
@@ -237,12 +239,29 @@ class HunyuanVideo(nn.Module):
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))
- vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+ if ref_latent is not None:
+ ref_latent_ids = self.img_ids(ref_latent)
+ ref_latent = self.img_in(ref_latent)
+ img = torch.cat([ref_latent, img], dim=-2)
+ ref_latent_ids[..., 0] = -1
+ ref_latent_ids[..., 2] += (initial_shape[-1] // self.patch_size[-1])
+ img_ids = torch.cat([ref_latent_ids, img_ids], dim=-2)
+
+ if guiding_frame_index is not None:
+ token_replace_vec = self.time_in(timestep_embedding(guiding_frame_index, 256, time_factor=1.0))
+ vec_ = self.vector_in(y[:, :self.params.vec_in_dim])
+ vec = torch.cat([(vec_ + token_replace_vec).unsqueeze(1), (vec_ + vec).unsqueeze(1)], dim=1)
+ frame_tokens = (initial_shape[-1] // self.patch_size[-1]) * (initial_shape[-2] // self.patch_size[-2])
+ modulation_dims = [(0, frame_tokens, 0), (frame_tokens, None, 1)]
+ modulation_dims_txt = [(0, None, 1)]
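+            # Editor's note: each modulation_dims entry reads as
+            # (token_start, token_end, vec_row): the first video frame's tokens
+            # get the guiding-frame modulation vector (row 0) while all
+            # remaining tokens use the regular one (row 1); text tokens always
+            # use row 1.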
+ else:
+ vec = vec + self.vector_in(y[:, :self.params.vec_in_dim])
+ modulation_dims = None
+ modulation_dims_txt = None
if self.params.guidance_embed:
- if guidance is None:
- raise ValueError("Didn't get guidance strength for guidance distilled model.")
- vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
+ if guidance is not None:
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype))
if txt_mask is not None and not torch.is_floating_point(txt_mask):
txt_mask = (txt_mask - 1).to(img.dtype) * torch.finfo(img.dtype).max
@@ -265,14 +284,14 @@ class HunyuanVideo(nn.Module):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+ out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims_img=args["modulation_dims_img"], modulation_dims_txt=args["modulation_dims_txt"])
return out
- out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+ out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims_img': modulation_dims, 'modulation_dims_txt': modulation_dims_txt}, {"original_block": block_wrap})
txt = out["txt"]
img = out["img"]
else:
- img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask)
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims_img=modulation_dims, modulation_dims_txt=modulation_dims_txt)
if control is not None: # Controlnet
control_i = control.get("input")
@@ -287,13 +306,13 @@ class HunyuanVideo(nn.Module):
if ("single_block", i) in blocks_replace:
def block_wrap(args):
out = {}
- out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"])
+ out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"], attn_mask=args["attention_mask"], modulation_dims=args["modulation_dims"])
return out
- out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask}, {"original_block": block_wrap})
+ out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe, "attention_mask": attn_mask, 'modulation_dims': modulation_dims}, {"original_block": block_wrap})
img = out["img"]
else:
- img = block(img, vec=vec, pe=pe, attn_mask=attn_mask)
+ img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, modulation_dims=modulation_dims)
if control is not None: # Controlnet
control_o = control.get("output")
@@ -303,18 +322,20 @@ class HunyuanVideo(nn.Module):
img[:, : img_len] += add
img = img[:, : img_len]
+ if ref_latent is not None:
+ img = img[:, ref_latent.shape[1]:]
- img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
+ img = self.final_layer(img, vec, modulation_dims=modulation_dims) # (N, T, patch_size ** 2 * out_channels)
shape = initial_shape[-3:]
for i in range(len(shape)):
shape[i] = shape[i] // self.patch_size[i]
img = img.reshape([img.shape[0]] + shape + [self.out_channels] + self.patch_size)
img = img.permute(0, 4, 1, 5, 2, 6, 3, 7)
- img = img.reshape(initial_shape)
+ img = img.reshape(initial_shape[0], self.out_channels, initial_shape[2], initial_shape[3], initial_shape[4])
return img
- def forward(self, x, timestep, context, y, guidance, attention_mask=None, control=None, transformer_options={}, **kwargs):
+ def img_ids(self, x):
bs, c, t, h, w = x.shape
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@@ -324,7 +345,11 @@ class HunyuanVideo(nn.Module):
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
- img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+ return repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+
+ def forward(self, x, timestep, context, y, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, control=None, transformer_options={}, **kwargs):
+ bs, c, t, h, w = x.shape
+ img_ids = self.img_ids(x)
txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
- out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, control, transformer_options)
+ out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, guidance, guiding_frame_index, ref_latent, control=control, transformer_options=transformer_options)
return out
diff --git a/comfy/ldm/hydit/models.py b/comfy/ldm/hydit/models.py
index 359f6a965..5ba2b76e0 100644
--- a/comfy/ldm/hydit/models.py
+++ b/comfy/ldm/hydit/models.py
@@ -3,7 +3,7 @@ import torch
import torch.nn as nn
import comfy.ops
-from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed, RMSNorm
+from comfy.ldm.modules.diffusionmodules.mmdit import Mlp, TimestepEmbedder, PatchEmbed
from comfy.ldm.modules.diffusionmodules.util import timestep_embedding
from torch.utils import checkpoint
@@ -51,7 +51,7 @@ class HunYuanDiTBlock(nn.Module):
if norm_type == "layer":
norm_layer = operations.LayerNorm
elif norm_type == "rms":
- norm_layer = RMSNorm
+ norm_layer = operations.RMSNorm
else:
raise ValueError(f"Unknown norm_type: {norm_type}")
diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py
index 2a02acd65..056e101a4 100644
--- a/comfy/ldm/lightricks/model.py
+++ b/comfy/ldm/lightricks/model.py
@@ -1,13 +1,12 @@
import torch
from torch import nn
import comfy.ldm.modules.attention
-from comfy.ldm.genmo.joint_model.layers import RMSNorm
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple
-from .symmetric_patchifier import SymmetricPatchifier
+from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
def get_timestep_embedding(
@@ -262,8 +261,8 @@ class CrossAttention(nn.Module):
self.heads = heads
self.dim_head = dim_head
- self.q_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
- self.k_norm = RMSNorm(inner_dim, dtype=dtype, device=device)
+ self.q_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
+ self.k_norm = operations.RMSNorm(inner_dim, dtype=dtype, device=device)
self.to_q = operations.Linear(query_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
@@ -377,12 +376,16 @@ class LTXVModel(torch.nn.Module):
positional_embedding_theta=10000.0,
positional_embedding_max_pos=[20, 2048, 2048],
+ causal_temporal_positioning=False,
+ vae_scale_factors=(8, 32, 32),
dtype=None, device=None, operations=None, **kwargs):
super().__init__()
self.generator = None
+ self.vae_scale_factors = vae_scale_factors
self.dtype = dtype
self.out_channels = in_channels
self.inner_dim = num_attention_heads * attention_head_dim
+ self.causal_temporal_positioning = causal_temporal_positioning
self.patchify_proj = operations.Linear(in_channels, self.inner_dim, bias=True, dtype=dtype, device=device)
@@ -416,42 +419,23 @@ class LTXVModel(torch.nn.Module):
self.patchifier = SymmetricPatchifier(1)
- def forward(self, x, timestep, context, attention_mask, frame_rate=25, guiding_latent=None, guiding_latent_noise_scale=0, transformer_options={}, **kwargs):
+ def forward(self, x, timestep, context, attention_mask, frame_rate=25, transformer_options={}, keyframe_idxs=None, **kwargs):
patches_replace = transformer_options.get("patches_replace", {})
- indices_grid = self.patchifier.get_grid(
- orig_num_frames=x.shape[2],
- orig_height=x.shape[3],
- orig_width=x.shape[4],
- batch_size=x.shape[0],
- scale_grid=((1 / frame_rate) * 8, 32, 32),
- device=x.device,
- )
-
- if guiding_latent is not None:
- ts = torch.ones([x.shape[0], 1, x.shape[2], x.shape[3], x.shape[4]], device=x.device, dtype=x.dtype)
- input_ts = timestep.view([timestep.shape[0]] + [1] * (x.ndim - 1))
- ts *= input_ts
- ts[:, :, 0] = guiding_latent_noise_scale * (input_ts[:, :, 0] ** 2)
- timestep = self.patchifier.patchify(ts)
- input_x = x.clone()
- x[:, :, 0] = guiding_latent[:, :, 0]
- if guiding_latent_noise_scale > 0:
- if self.generator is None:
- self.generator = torch.Generator(device=x.device).manual_seed(42)
- elif self.generator.device != x.device:
- self.generator = torch.Generator(device=x.device).set_state(self.generator.get_state())
-
- noise_shape = [guiding_latent.shape[0], guiding_latent.shape[1], 1, guiding_latent.shape[3], guiding_latent.shape[4]]
- scale = guiding_latent_noise_scale * (input_ts ** 2)
- guiding_noise = scale * torch.randn(size=noise_shape, device=x.device, generator=self.generator)
-
- x[:, :, 0] = guiding_noise[:, :, 0] + x[:, :, 0] * (1.0 - scale[:, :, 0])
-
-
orig_shape = list(x.shape)
- x = self.patchifier.patchify(x)
+ x, latent_coords = self.patchifier.patchify(x)
+ pixel_coords = latent_to_pixel_coords(
+ latent_coords=latent_coords,
+ scale_factors=self.vae_scale_factors,
+ causal_fix=self.causal_temporal_positioning,
+ )
+
+ if keyframe_idxs is not None:
+ pixel_coords[:, :, -keyframe_idxs.shape[2]:] = keyframe_idxs
+
+ fractional_coords = pixel_coords.to(torch.float32)
+ fractional_coords[:, 0] = fractional_coords[:, 0] * (1.0 / frame_rate)
x = self.patchify_proj(x)
timestep = timestep * 1000.0
@@ -459,7 +443,7 @@ class LTXVModel(torch.nn.Module):
if attention_mask is not None and not torch.is_floating_point(attention_mask):
attention_mask = (attention_mask - 1).to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])) * torch.finfo(x.dtype).max
- pe = precompute_freqs_cis(indices_grid, dim=self.inner_dim, out_dtype=x.dtype)
+ pe = precompute_freqs_cis(fractional_coords, dim=self.inner_dim, out_dtype=x.dtype)
batch_size = x.shape[0]
timestep, embedded_timestep = self.adaln_single(
@@ -519,8 +503,4 @@ class LTXVModel(torch.nn.Module):
out_channels=orig_shape[1] // math.prod(self.patchifier.patch_size),
)
- if guiding_latent is not None:
- x[:, :, 0] = (input_x[:, :, 0] - guiding_latent[:, :, 0]) / input_ts[:, :, 0]
-
- # print("res", x)
return x
diff --git a/comfy/ldm/lightricks/symmetric_patchifier.py b/comfy/ldm/lightricks/symmetric_patchifier.py
index c58dfb20b..4b9972b9f 100644
--- a/comfy/ldm/lightricks/symmetric_patchifier.py
+++ b/comfy/ldm/lightricks/symmetric_patchifier.py
@@ -6,16 +6,29 @@ from einops import rearrange
from torch import Tensor
-def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
- """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
- dims_to_append = target_dims - x.ndim
- if dims_to_append < 0:
- raise ValueError(
- f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
- )
- elif dims_to_append == 0:
- return x
- return x[(...,) + (None,) * dims_to_append]
+def latent_to_pixel_coords(
+ latent_coords: Tensor, scale_factors: Tuple[int, int, int], causal_fix: bool = False
+) -> Tensor:
+ """
+ Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
+ configuration.
+ Args:
+ latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
+ containing the latent corner coordinates of each token.
+ scale_factors (Tuple[int, int, int]): The scale factors of the VAE's latent space.
+ causal_fix (bool): Whether to take into account the different temporal scale
+ of the first frame. Default = False for backwards compatibility.
+ Returns:
+ Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
+ """
+ pixel_coords = (
+ latent_coords
+ * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
+ )
+ if causal_fix:
+ # Fix temporal scale for first frame to 1 due to causality
+ pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
+ return pixel_coords
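+# Worked example (editor's sketch): with a temporal scale factor of 8, latent
+# frame indices [0, 1, 2] map to pixel frames [0, 8, 16]; with causal_fix=True
+# they become [0, 1, 9], reflecting a causal VAE whose first latent frame
+# encodes one pixel frame while each later latent frame encodes eight.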
class Patchifier(ABC):
@@ -44,29 +57,26 @@ class Patchifier(ABC):
def patch_size(self):
return self._patch_size
- def get_grid(
- self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device
+ def get_latent_coords(
+ self, latent_num_frames, latent_height, latent_width, batch_size, device
):
- f = orig_num_frames // self._patch_size[0]
- h = orig_height // self._patch_size[1]
- w = orig_width // self._patch_size[2]
- grid_h = torch.arange(h, dtype=torch.float32, device=device)
- grid_w = torch.arange(w, dtype=torch.float32, device=device)
- grid_f = torch.arange(f, dtype=torch.float32, device=device)
- grid = torch.meshgrid(grid_f, grid_h, grid_w, indexing='ij')
- grid = torch.stack(grid, dim=0)
- grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
-
- if scale_grid is not None:
- for i in range(3):
- if isinstance(scale_grid[i], Tensor):
- scale = append_dims(scale_grid[i], grid.ndim - 1)
- else:
- scale = scale_grid[i]
- grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i]
-
- grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size)
- return grid
+ """
+ Return a tensor of shape [batch_size, 3, num_patches] containing the
+ top-left corner latent coordinates of each latent patch.
+ The tensor is repeated for each batch element.
+ """
+ latent_sample_coords = torch.meshgrid(
+ torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
+ torch.arange(0, latent_height, self._patch_size[1], device=device),
+ torch.arange(0, latent_width, self._patch_size[2], device=device),
+ indexing="ij",
+ )
+ latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
+ latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
+ latent_coords = rearrange(
+ latent_coords, "b c f h w -> b c (f h w)", b=batch_size
+ )
+ return latent_coords
class SymmetricPatchifier(Patchifier):
@@ -74,6 +84,8 @@ class SymmetricPatchifier(Patchifier):
self,
latents: Tensor,
) -> Tuple[Tensor, Tensor]:
+ b, _, f, h, w = latents.shape
+ latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
latents = rearrange(
latents,
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
@@ -81,7 +93,7 @@ class SymmetricPatchifier(Patchifier):
p2=self._patch_size[1],
p3=self._patch_size[2],
)
- return latents
+ return latents, latent_coords
def unpatchify(
self,
diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py
index c572e7e86..70d612e86 100644
--- a/comfy/ldm/lightricks/vae/causal_conv3d.py
+++ b/comfy/ldm/lightricks/vae/causal_conv3d.py
@@ -15,6 +15,7 @@ class CausalConv3d(nn.Module):
stride: Union[int, Tuple[int]] = 1,
dilation: int = 1,
groups: int = 1,
+ spatial_padding_mode: str = "zeros",
**kwargs,
):
super().__init__()
@@ -38,7 +39,7 @@ class CausalConv3d(nn.Module):
stride=stride,
dilation=dilation,
padding=padding,
- padding_mode="zeros",
+ padding_mode=spatial_padding_mode,
groups=groups,
)
diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
index e0344deec..f91870d71 100644
--- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
+++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py
@@ -1,13 +1,15 @@
+from __future__ import annotations
import torch
from torch import nn
from functools import partial
import math
from einops import rearrange
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
from .conv_nd_factory import make_conv_nd, make_linear_nd
from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops
+
ops = comfy.ops.disable_weight_init
class Encoder(nn.Module):
@@ -32,7 +34,7 @@ class Encoder(nn.Module):
norm_layer (`str`, *optional*, defaults to `group_norm`):
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
latent_log_var (`str`, *optional*, defaults to `per_channel`):
- The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
+ The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
"""
def __init__(
@@ -40,12 +42,13 @@ class Encoder(nn.Module):
dims: Union[int, Tuple[int, int]] = 3,
in_channels: int = 3,
out_channels: int = 3,
- blocks=[("res_x", 1)],
+ blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128,
norm_num_groups: int = 32,
patch_size: Union[int, Tuple[int]] = 1,
norm_layer: str = "group_norm", # group_norm, pixel_norm
latent_log_var: str = "per_channel",
+ spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
@@ -65,6 +68,7 @@ class Encoder(nn.Module):
stride=1,
padding=1,
causal=True,
+ spatial_padding_mode=spatial_padding_mode,
)
self.down_blocks = nn.ModuleList([])
@@ -82,6 +86,7 @@ class Encoder(nn.Module):
resnet_eps=1e-6,
resnet_groups=norm_num_groups,
norm_layer=norm_layer,
+ spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
@@ -92,6 +97,7 @@ class Encoder(nn.Module):
eps=1e-6,
groups=norm_num_groups,
norm_layer=norm_layer,
+ spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = make_conv_nd(
@@ -101,6 +107,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 1, 1),
causal=True,
+ spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = make_conv_nd(
@@ -110,6 +117,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(1, 2, 2),
causal=True,
+ spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
block = make_conv_nd(
@@ -119,6 +127,7 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 2, 2),
causal=True,
+ spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all_x_y":
output_channel = block_params.get("multiplier", 2) * output_channel
@@ -129,6 +138,34 @@ class Encoder(nn.Module):
kernel_size=3,
stride=(2, 2, 2),
causal=True,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ elif block_name == "compress_all_res":
+ output_channel = block_params.get("multiplier", 2) * output_channel
+ block = SpaceToDepthDownsample(
+ dims=dims,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ stride=(2, 2, 2),
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ elif block_name == "compress_space_res":
+ output_channel = block_params.get("multiplier", 2) * output_channel
+ block = SpaceToDepthDownsample(
+ dims=dims,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ stride=(1, 2, 2),
+ spatial_padding_mode=spatial_padding_mode,
+ )
+ elif block_name == "compress_time_res":
+ output_channel = block_params.get("multiplier", 2) * output_channel
+ block = SpaceToDepthDownsample(
+ dims=dims,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ stride=(2, 1, 1),
+ spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown block: {block_name}")
@@ -152,10 +189,18 @@ class Encoder(nn.Module):
conv_out_channels *= 2
elif latent_log_var == "uniform":
conv_out_channels += 1
+ elif latent_log_var == "constant":
+ conv_out_channels += 1
elif latent_log_var != "none":
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
self.conv_out = make_conv_nd(
- dims, output_channel, conv_out_channels, 3, padding=1, causal=True
+ dims,
+ output_channel,
+ conv_out_channels,
+ 3,
+ padding=1,
+ causal=True,
+ spatial_padding_mode=spatial_padding_mode,
)
self.gradient_checkpointing = False
@@ -197,6 +242,15 @@ class Encoder(nn.Module):
sample = torch.cat([sample, repeated_last_channel], dim=1)
else:
raise ValueError(f"Invalid input shape: {sample.shape}")
+ elif self.latent_log_var == "constant":
+ sample = sample[:, :-1, ...]
+ approx_ln_0 = -30 # the minimal clamp value in DiagonalGaussianDistribution objects
+ sample = torch.cat(
+ [sample, torch.ones_like(sample) * approx_ln_0],
+ dim=1,
+ )
return sample
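
The `constant` branch above pins the log-variance to -30 rather than predicting it, which makes the latent distribution effectively deterministic: the resulting standard deviation is exp(-15), about 3e-7. A minimal standalone sketch (not part of this diff):

```python
import torch

# Why a constant log-variance of -30 yields an effectively deterministic
# latent: the sampling noise it allows is ~3e-7, i.e. negligible.
mean = torch.randn(1, 4, 8, 8)
logvar = torch.full_like(mean, -30.0)           # the approx_ln_0 constant above
std = torch.exp(0.5 * logvar)                   # ~3.1e-7
sample = mean + std * torch.randn_like(mean)
print(torch.allclose(sample, mean, atol=1e-5))  # True: the noise is negligible
```
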
@@ -231,7 +285,7 @@ class Decoder(nn.Module):
dims,
in_channels: int = 3,
out_channels: int = 3,
- blocks=[("res_x", 1)],
+ blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
base_channels: int = 128,
layers_per_block: int = 2,
norm_num_groups: int = 32,
@@ -239,6 +293,7 @@ class Decoder(nn.Module):
norm_layer: str = "group_norm",
causal: bool = True,
timestep_conditioning: bool = False,
+ spatial_padding_mode: str = "zeros",
):
super().__init__()
self.patch_size = patch_size
@@ -264,6 +319,7 @@ class Decoder(nn.Module):
stride=1,
padding=1,
causal=True,
+ spatial_padding_mode=spatial_padding_mode,
)
self.up_blocks = nn.ModuleList([])
@@ -283,6 +339,7 @@ class Decoder(nn.Module):
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
+ spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "attn_res_x":
block = UNetMidBlock3D(
@@ -294,6 +351,7 @@ class Decoder(nn.Module):
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=timestep_conditioning,
attention_head_dim=block_params["attention_head_dim"],
+ spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "res_x_y":
output_channel = output_channel // block_params.get("multiplier", 2)
@@ -306,14 +364,21 @@ class Decoder(nn.Module):
norm_layer=norm_layer,
inject_noise=block_params.get("inject_noise", False),
timestep_conditioning=False,
+ spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
block = DepthToSpaceUpsample(
- dims=dims, in_channels=input_channel, stride=(2, 1, 1)
+ dims=dims,
+ in_channels=input_channel,
+ stride=(2, 1, 1),
+ spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
block = DepthToSpaceUpsample(
- dims=dims, in_channels=input_channel, stride=(1, 2, 2)
+ dims=dims,
+ in_channels=input_channel,
+ stride=(1, 2, 2),
+ spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
output_channel = output_channel // block_params.get("multiplier", 1)
@@ -323,6 +388,7 @@ class Decoder(nn.Module):
stride=(2, 2, 2),
residual=block_params.get("residual", False),
out_channels_reduction_factor=block_params.get("multiplier", 1),
+ spatial_padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unknown layer: {block_name}")
@@ -340,7 +406,13 @@ class Decoder(nn.Module):
self.conv_act = nn.SiLU()
self.conv_out = make_conv_nd(
- dims, output_channel, out_channels, 3, padding=1, causal=True
+ dims,
+ output_channel,
+ out_channels,
+ 3,
+ padding=1,
+ causal=True,
+ spatial_padding_mode=spatial_padding_mode,
)
self.gradient_checkpointing = False
@@ -433,6 +505,12 @@ class UNetMidBlock3D(nn.Module):
resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
resnet_groups (`int`, *optional*, defaults to 32):
The number of groups to use in the group normalization layers of the resnet blocks.
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
+ inject_noise (`bool`, *optional*, defaults to `False`):
+ Whether to inject noise into the hidden states.
+ timestep_conditioning (`bool`, *optional*, defaults to `False`):
+ Whether to condition the hidden states on the timestep.
Returns:
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
@@ -451,6 +529,7 @@ class UNetMidBlock3D(nn.Module):
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
+ spatial_padding_mode: str = "zeros",
):
super().__init__()
resnet_groups = (
@@ -476,13 +555,17 @@ class UNetMidBlock3D(nn.Module):
norm_layer=norm_layer,
inject_noise=inject_noise,
timestep_conditioning=timestep_conditioning,
+ spatial_padding_mode=spatial_padding_mode,
)
for _ in range(num_layers)
]
)
def forward(
- self, hidden_states: torch.FloatTensor, causal: bool = True, timestep: Optional[torch.Tensor] = None
+ self,
+ hidden_states: torch.FloatTensor,
+ causal: bool = True,
+ timestep: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
timestep_embed = None
if self.timestep_conditioning:
@@ -507,9 +590,62 @@ class UNetMidBlock3D(nn.Module):
return hidden_states
+class SpaceToDepthDownsample(nn.Module):
+ def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
+ super().__init__()
+ self.stride = stride
+ self.group_size = in_channels * math.prod(stride) // out_channels
+ self.conv = make_conv_nd(
+ dims=dims,
+ in_channels=in_channels,
+ out_channels=out_channels // math.prod(stride),
+ kernel_size=3,
+ stride=1,
+ causal=True,
+ spatial_padding_mode=spatial_padding_mode,
+ )
+
+ def forward(self, x, causal: bool = True):
+ if self.stride[0] == 2:
+ x = torch.cat(
+ [x[:, :, :1, :, :], x], dim=2
+ ) # duplicate first frames for padding
+
+ # skip connection
+ x_in = rearrange(
+ x,
+ "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
+ p1=self.stride[0],
+ p2=self.stride[1],
+ p3=self.stride[2],
+ )
+ x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
+ x_in = x_in.mean(dim=2)
+
+ # conv
+ x = self.conv(x, causal=causal)
+ x = rearrange(
+ x,
+ "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
+ p1=self.stride[0],
+ p2=self.stride[1],
+ p3=self.stride[2],
+ )
+
+ x = x + x_in
+
+ return x
+
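
`SpaceToDepthDownsample` packs each stride-sized spatio-temporal patch into channels (a 3D pixel-unshuffle) and adds a parameter-free skip path that averages groups of the packed channels down to `out_channels`. A shape-bookkeeping sketch with illustrative sizes (not a real model config):

```python
import torch
from einops import rearrange

# Channel bookkeeping for the skip path above: stride (1, 2, 2), 64 -> 128.
b, c_in, d, h, w = 1, 64, 4, 8, 8
stride, c_out = (1, 2, 2), 128
x = torch.randn(b, c_in, d, h, w)
x = rearrange(x, "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
              p1=stride[0], p2=stride[1], p3=stride[2])
print(x.shape)  # (1, 256, 4, 4, 4): c_in * prod(stride) packed channels
group_size = c_in * stride[0] * stride[1] * stride[2] // c_out  # 2
x = rearrange(x, "b (c g) d h w -> b c g d h w", g=group_size).mean(dim=2)
print(x.shape)  # (1, 128, 4, 4, 4): matches out_channels
```
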
+
class DepthToSpaceUpsample(nn.Module):
def __init__(
- self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1
+ self,
+ dims,
+ in_channels,
+ stride,
+ residual=False,
+ out_channels_reduction_factor=1,
+ spatial_padding_mode="zeros",
):
super().__init__()
self.stride = stride
@@ -523,6 +659,7 @@ class DepthToSpaceUpsample(nn.Module):
kernel_size=3,
stride=1,
causal=True,
+ spatial_padding_mode=spatial_padding_mode,
)
self.residual = residual
self.out_channels_reduction_factor = out_channels_reduction_factor
@@ -558,7 +695,7 @@ class DepthToSpaceUpsample(nn.Module):
class LayerNorm(nn.Module):
def __init__(self, dim, eps, elementwise_affine=True) -> None:
super().__init__()
- self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
+ self.norm = ops.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(self, x):
x = rearrange(x, "b c d h w -> b d h w c")
@@ -591,6 +728,7 @@ class ResnetBlock3D(nn.Module):
norm_layer: str = "group_norm",
inject_noise: bool = False,
timestep_conditioning: bool = False,
+ spatial_padding_mode: str = "zeros",
):
super().__init__()
self.in_channels = in_channels
@@ -617,6 +755,7 @@ class ResnetBlock3D(nn.Module):
stride=1,
padding=1,
causal=True,
+ spatial_padding_mode=spatial_padding_mode,
)
if inject_noise:
@@ -641,6 +780,7 @@ class ResnetBlock3D(nn.Module):
stride=1,
padding=1,
causal=True,
+ spatial_padding_mode=spatial_padding_mode,
)
if inject_noise:
@@ -801,9 +941,44 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module):
- def __init__(self, version=0):
+ def __init__(self, version=0, config=None):
super().__init__()
+ if config is None:
+ config = self.guess_config(version)
+
+ self.timestep_conditioning = config.get("timestep_conditioning", False)
+ double_z = config.get("double_z", True)
+ latent_log_var = config.get(
+ "latent_log_var", "per_channel" if double_z else "none"
+ )
+
+ self.encoder = Encoder(
+ dims=config["dims"],
+ in_channels=config.get("in_channels", 3),
+ out_channels=config["latent_channels"],
+ blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
+ patch_size=config.get("patch_size", 1),
+ latent_log_var=latent_log_var,
+ norm_layer=config.get("norm_layer", "group_norm"),
+ spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
+ )
+
+ self.decoder = Decoder(
+ dims=config["dims"],
+ in_channels=config["latent_channels"],
+ out_channels=config.get("out_channels", 3),
+ blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
+ patch_size=config.get("patch_size", 1),
+ norm_layer=config.get("norm_layer", "group_norm"),
+ causal=config.get("causal_decoder", False),
+ timestep_conditioning=self.timestep_conditioning,
+ spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
+ )
+
+ self.per_channel_statistics = processor()
+
+ def guess_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
@@ -830,7 +1005,7 @@ class VideoVAE(nn.Module):
"use_quant_conv": False,
"causal_decoder": False,
}
- else:
+ elif version == 1:
config = {
"_class_name": "CausalVideoAutoencoder",
"dims": 3,
@@ -866,37 +1041,47 @@ class VideoVAE(nn.Module):
"causal_decoder": False,
"timestep_conditioning": True,
}
-
- double_z = config.get("double_z", True)
- latent_log_var = config.get(
- "latent_log_var", "per_channel" if double_z else "none"
- )
-
- self.encoder = Encoder(
- dims=config["dims"],
- in_channels=config.get("in_channels", 3),
- out_channels=config["latent_channels"],
- blocks=config.get("encoder_blocks", config.get("encoder_blocks", config.get("blocks"))),
- patch_size=config.get("patch_size", 1),
- latent_log_var=latent_log_var,
- norm_layer=config.get("norm_layer", "group_norm"),
- )
-
- self.decoder = Decoder(
- dims=config["dims"],
- in_channels=config["latent_channels"],
- out_channels=config.get("out_channels", 3),
- blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
- patch_size=config.get("patch_size", 1),
- norm_layer=config.get("norm_layer", "group_norm"),
- causal=config.get("causal_decoder", False),
- timestep_conditioning=config.get("timestep_conditioning", False),
- )
-
- self.timestep_conditioning = config.get("timestep_conditioning", False)
- self.per_channel_statistics = processor()
+ else:
+ config = {
+ "_class_name": "CausalVideoAutoencoder",
+ "dims": 3,
+ "in_channels": 3,
+ "out_channels": 3,
+ "latent_channels": 128,
+ "encoder_blocks": [
+ ["res_x", {"num_layers": 4}],
+ ["compress_space_res", {"multiplier": 2}],
+ ["res_x", {"num_layers": 6}],
+ ["compress_time_res", {"multiplier": 2}],
+ ["res_x", {"num_layers": 6}],
+ ["compress_all_res", {"multiplier": 2}],
+ ["res_x", {"num_layers": 2}],
+ ["compress_all_res", {"multiplier": 2}],
+ ["res_x", {"num_layers": 2}]
+ ],
+ "decoder_blocks": [
+ ["res_x", {"num_layers": 5, "inject_noise": False}],
+ ["compress_all", {"residual": True, "multiplier": 2}],
+ ["res_x", {"num_layers": 5, "inject_noise": False}],
+ ["compress_all", {"residual": True, "multiplier": 2}],
+ ["res_x", {"num_layers": 5, "inject_noise": False}],
+ ["compress_all", {"residual": True, "multiplier": 2}],
+ ["res_x", {"num_layers": 5, "inject_noise": False}]
+ ],
+ "scaling_factor": 1.0,
+ "norm_layer": "pixel_norm",
+ "patch_size": 4,
+ "latent_log_var": "uniform",
+ "use_quant_conv": False,
+ "causal_decoder": False,
+ "timestep_conditioning": True
+ }
+ return config
def encode(self, x):
+ frames_count = x.shape[2]
+ if ((frames_count - 1) % 8) != 0:
+ raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
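
The new check in `encode` rejects inputs whose frame count is not of the form 1 + 8k; we read the factor 8 as the product of the encoder's 2x temporal compressions (2^3) on top of the causally padded first frame, though the exact block layout depends on the config. A tiny sketch of the rule:

```python
# The frame-count rule enforced by VideoVAE.encode above.
def is_valid_frame_count(frames: int, temporal_factor: int = 8) -> bool:
    return (frames - 1) % temporal_factor == 0

print([n for n in range(1, 30) if is_valid_frame_count(n)])  # [1, 9, 17, 25]
```
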
diff --git a/comfy/ldm/lightricks/vae/conv_nd_factory.py b/comfy/ldm/lightricks/vae/conv_nd_factory.py
index 52df4ee22..b4026b14f 100644
--- a/comfy/ldm/lightricks/vae/conv_nd_factory.py
+++ b/comfy/ldm/lightricks/vae/conv_nd_factory.py
@@ -17,7 +17,11 @@ def make_conv_nd(
groups=1,
bias=True,
causal=False,
+ spatial_padding_mode="zeros",
+ temporal_padding_mode="zeros",
):
+ if not (spatial_padding_mode == temporal_padding_mode or causal):
+ raise NotImplementedError("spatial and temporal padding modes must be equal")
if dims == 2:
return ops.Conv2d(
in_channels=in_channels,
@@ -28,6 +32,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
+ padding_mode=spatial_padding_mode,
)
elif dims == 3:
if causal:
@@ -40,6 +45,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
+ spatial_padding_mode=spatial_padding_mode,
)
return ops.Conv3d(
in_channels=in_channels,
@@ -50,6 +56,7 @@ def make_conv_nd(
dilation=dilation,
groups=groups,
bias=bias,
+ padding_mode=spatial_padding_mode,
)
elif dims == (2, 1):
return DualConv3d(
@@ -59,6 +66,7 @@ def make_conv_nd(
stride=stride,
padding=padding,
bias=bias,
+ padding_mode=spatial_padding_mode,
)
else:
raise ValueError(f"unsupported dimensions: {dims}")
diff --git a/comfy/ldm/lightricks/vae/dual_conv3d.py b/comfy/ldm/lightricks/vae/dual_conv3d.py
index 6bd54c0a6..dcf889296 100644
--- a/comfy/ldm/lightricks/vae/dual_conv3d.py
+++ b/comfy/ldm/lightricks/vae/dual_conv3d.py
@@ -18,11 +18,13 @@ class DualConv3d(nn.Module):
dilation: Union[int, Tuple[int, int, int]] = 1,
groups=1,
bias=True,
+ padding_mode="zeros",
):
super(DualConv3d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
+ self.padding_mode = padding_mode
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
@@ -108,6 +110,7 @@ class DualConv3d(nn.Module):
self.padding1,
self.dilation1,
self.groups,
+ # NOTE: F.conv3d has no padding_mode argument; non-"zeros" modes would
+ # need an explicit F.pad here, as done in the 2D path below.
)
if skip_time_conv:
@@ -122,6 +125,7 @@ class DualConv3d(nn.Module):
self.padding2,
self.dilation2,
self.groups,
+ # NOTE: as above, F.conv3d does not accept padding_mode.
)
return x
@@ -137,7 +141,16 @@ class DualConv3d(nn.Module):
stride1 = (self.stride1[1], self.stride1[2])
padding1 = (self.padding1[1], self.padding1[2])
dilation1 = (self.dilation1[1], self.dilation1[2])
- x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups)
+ if self.padding_mode != "zeros":
+ # F.conv2d has no padding_mode argument, so non-zero modes are
+ # emulated with an explicit F.pad before the convolution.
+ x = F.pad(
+ x,
+ (padding1[1], padding1[1], padding1[0], padding1[0]),
+ mode=self.padding_mode,
+ )
+ padding1 = (0, 0)
+ x = F.conv2d(
+ x,
+ weight1,
+ self.bias1,
+ stride1,
+ padding1,
+ dilation1,
+ self.groups,
+ )
_, _, h, w = x.shape
@@ -154,7 +167,16 @@ class DualConv3d(nn.Module):
stride2 = self.stride2[0]
padding2 = self.padding2[0]
dilation2 = self.dilation2[0]
- x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups)
+ if self.padding_mode != "zeros":
+ # F.conv1d likewise has no padding_mode argument; pad explicitly.
+ x = F.pad(x, (padding2, padding2), mode=self.padding_mode)
+ padding2 = 0
+ x = F.conv1d(
+ x,
+ weight2,
+ self.bias2,
+ stride2,
+ padding2,
+ dilation2,
+ self.groups,
+ )
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
return x
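
Unlike the `nn.ConvNd` modules, the functional `F.conv1d`/`F.conv2d`/`F.conv3d` calls accept no `padding_mode` keyword, which is why the rewritten calls above emulate non-zero modes with an explicit `F.pad`. A standalone check that the emulation matches the module behavior:

```python
import torch
import torch.nn.functional as F

# Padding with the requested mode and convolving with padding=0 reproduces
# nn.Conv2d(..., padding_mode="replicate") exactly.
torch.manual_seed(0)
conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, padding_mode="replicate")
x = torch.randn(1, 3, 16, 16)
x_pad = F.pad(x, (1, 1, 1, 1), mode="replicate")  # (left, right, top, bottom)
out = F.conv2d(x_pad, conv.weight, conv.bias, stride=1, padding=0)
print(torch.allclose(out, conv(x)))  # True
```
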
diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py
new file mode 100644
index 000000000..f8dc4d7db
--- /dev/null
+++ b/comfy/ldm/lumina/model.py
@@ -0,0 +1,622 @@
+# Code from: https://github.com/Alpha-VLLM/Lumina-Image-2.0/blob/main/models/model.py
+from __future__ import annotations
+
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import comfy.ldm.common_dit
+
+from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder
+from comfy.ldm.modules.attention import optimized_attention_masked
+from comfy.ldm.flux.layers import EmbedND
+
+
+def modulate(x, scale):
+ return x * (1 + scale.unsqueeze(1))
+
+#############################################################################
+# Core NextDiT Model #
+#############################################################################
+
+
+class JointAttention(nn.Module):
+ """Multi-head attention module."""
+
+ def __init__(
+ self,
+ dim: int,
+ n_heads: int,
+ n_kv_heads: Optional[int],
+ qk_norm: bool,
+ operation_settings={},
+ ):
+ """
+ Initialize the Attention module.
+
+ Args:
+ dim (int): Number of input dimensions.
+ n_heads (int): Number of heads.
+ n_kv_heads (Optional[int]): Number of kv heads, if using GQA.
+
+ """
+ super().__init__()
+ self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
+ self.n_local_heads = n_heads
+ self.n_local_kv_heads = self.n_kv_heads
+ self.n_rep = self.n_local_heads // self.n_local_kv_heads
+ self.head_dim = dim // n_heads
+
+ self.qkv = operation_settings.get("operations").Linear(
+ dim,
+ (n_heads + self.n_kv_heads + self.n_kv_heads) * self.head_dim,
+ bias=False,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ )
+ self.out = operation_settings.get("operations").Linear(
+ n_heads * self.head_dim,
+ dim,
+ bias=False,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ )
+
+ if qk_norm:
+ self.q_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.k_norm = operation_settings.get("operations").RMSNorm(self.head_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ else:
+ self.q_norm = self.k_norm = nn.Identity()
+
+ @staticmethod
+ def apply_rotary_emb(
+ x_in: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ Apply rotary embeddings to input tensors using the given frequency
+ tensor.
+
+ This function applies rotary embeddings to the given query or key
+ tensor 'x_in' using the provided frequency tensor 'freqs_cis'. The
+ input features are viewed as pairs of real components and rotated by
+ the precomputed cosine/sine frequencies; the result is returned as a
+ real tensor with the original shape.
+
+ Args:
+ x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings.
+ freqs_cis (torch.Tensor): Precomputed frequency tensor for complex
+ exponentials.
+
+ Returns:
+ torch.Tensor: The input tensor with rotary embeddings applied.
+ """
+
+ t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2)
+ t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
+ return t_out.reshape(*x_in.shape)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+
+ Args:
+ x:
+ x_mask:
+ freqs_cis:
+
+ Returns:
+
+ """
+ bsz, seqlen, _ = x.shape
+
+ xq, xk, xv = torch.split(
+ self.qkv(x),
+ [
+ self.n_local_heads * self.head_dim,
+ self.n_local_kv_heads * self.head_dim,
+ self.n_local_kv_heads * self.head_dim,
+ ],
+ dim=-1,
+ )
+ xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+ xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+ xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+ xq = self.q_norm(xq)
+ xk = self.k_norm(xk)
+
+ xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
+ xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
+
+ n_rep = self.n_local_heads // self.n_local_kv_heads
+ if n_rep >= 1:
+ xk = xk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+ xv = xv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
+ output = optimized_attention_masked(xq.movedim(1, 2), xk.movedim(1, 2), xv.movedim(1, 2), self.n_local_heads, x_mask, skip_reshape=True)
+
+ return self.out(output)
+
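
`apply_rotary_emb` uses the real-pair RoPE formulation shared with the flux layers: `freqs_cis` carries one 2x2 rotation matrix per feature pair, laid out as `[..., d/2, 2, 2]`, and the two indexed slices select its columns. A one-pair sketch (layout inferred from `EmbedND`; hedged, not part of this diff):

```python
import math
import torch

# Rotating the pair (1, 0) by theta should give (cos(theta), sin(theta)).
theta = 0.3
cos, sin = math.cos(theta), math.sin(theta)
freqs_cis = torch.tensor([[[cos, -sin], [sin, cos]]])  # (d/2=1, 2, 2)

x = torch.tensor([1.0, 0.0])             # a single feature pair
t_ = x.reshape(-1, 1, 2)                 # (d/2, 1, 2), as in apply_rotary_emb
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
print(t_out.reshape(2))                  # tensor([0.9553, 0.2955])
```
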
+
+class FeedForward(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ hidden_dim: int,
+ multiple_of: int,
+ ffn_dim_multiplier: Optional[float],
+ operation_settings={},
+ ):
+ """
+ Initialize the FeedForward module.
+
+ Args:
+ dim (int): Input dimension.
+ hidden_dim (int): Hidden dimension of the feedforward layer.
+ multiple_of (int): Value to ensure hidden dimension is a multiple
+ of this value.
+ ffn_dim_multiplier (float, optional): Custom multiplier for hidden
+ dimension. Defaults to None.
+
+ """
+ super().__init__()
+ # custom dim factor multiplier
+ if ffn_dim_multiplier is not None:
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+ self.w1 = operation_settings.get("operations").Linear(
+ dim,
+ hidden_dim,
+ bias=False,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ )
+ self.w2 = operation_settings.get("operations").Linear(
+ hidden_dim,
+ dim,
+ bias=False,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ )
+ self.w3 = operation_settings.get("operations").Linear(
+ dim,
+ hidden_dim,
+ bias=False,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ )
+
+ # @torch.compile
+ def _forward_silu_gating(self, x1, x3):
+ return F.silu(x1) * x3
+
+ def forward(self, x):
+ return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
+
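
`FeedForward` is a SwiGLU MLP, `w2(silu(w1(x)) * w3(x))`, with the hidden width rounded up to a multiple of `multiple_of`. The rounding in isolation:

```python
# Hidden-width rounding used by FeedForward above.
def round_up(hidden_dim: int, multiple_of: int) -> int:
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(round_up(7000, 256))  # 7168
print(round_up(9216, 256))  # 9216 (already a multiple)
```
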
+
+class JointTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ layer_id: int,
+ dim: int,
+ n_heads: int,
+ n_kv_heads: int,
+ multiple_of: int,
+ ffn_dim_multiplier: float,
+ norm_eps: float,
+ qk_norm: bool,
+ modulation=True,
+ operation_settings={},
+ ) -> None:
+ """
+ Initialize a TransformerBlock.
+
+ Args:
+ layer_id (int): Identifier for the layer.
+ dim (int): Embedding dimension of the input features.
+ n_heads (int): Number of attention heads.
+ n_kv_heads (Optional[int]): Number of attention heads in key and
+ value features (if using GQA), or set to None for the same as
+ query.
+ multiple_of (int):
+ ffn_dim_multiplier (float):
+ norm_eps (float):
+
+ """
+ super().__init__()
+ self.dim = dim
+ self.head_dim = dim // n_heads
+ self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings)
+ self.feed_forward = FeedForward(
+ dim=dim,
+ hidden_dim=4 * dim,
+ multiple_of=multiple_of,
+ ffn_dim_multiplier=ffn_dim_multiplier,
+ operation_settings=operation_settings,
+ )
+ self.layer_id = layer_id
+ self.attention_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.ffn_norm1 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ self.attention_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.ffn_norm2 = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ self.modulation = modulation
+ if modulation:
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ operation_settings.get("operations").Linear(
+ min(dim, 1024),
+ 4 * dim,
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ ),
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_mask: torch.Tensor,
+ freqs_cis: torch.Tensor,
+ adaln_input: Optional[torch.Tensor]=None,
+ ):
+ """
+ Perform a forward pass through the TransformerBlock.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+ Returns:
+ torch.Tensor: Output tensor after applying attention and
+ feedforward layers.
+
+ """
+ if self.modulation:
+ assert adaln_input is not None
+ scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
+
+ x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2(
+ self.attention(
+ modulate(self.attention_norm1(x), scale_msa),
+ x_mask,
+ freqs_cis,
+ )
+ )
+ x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
+ self.feed_forward(
+ modulate(self.ffn_norm1(x), scale_mlp),
+ )
+ )
+ else:
+ assert adaln_input is None
+ x = x + self.attention_norm2(
+ self.attention(
+ self.attention_norm1(x),
+ x_mask,
+ freqs_cis,
+ )
+ )
+ x = x + self.ffn_norm2(
+ self.feed_forward(
+ self.ffn_norm1(x),
+ )
+ )
+ return x
+
+
+class FinalLayer(nn.Module):
+ """
+ The final layer of NextDiT.
+ """
+
+ def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}):
+ super().__init__()
+ self.norm_final = operation_settings.get("operations").LayerNorm(
+ hidden_size,
+ elementwise_affine=False,
+ eps=1e-6,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ )
+ self.linear = operation_settings.get("operations").Linear(
+ hidden_size,
+ patch_size * patch_size * out_channels,
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ )
+
+ self.adaLN_modulation = nn.Sequential(
+ nn.SiLU(),
+ operation_settings.get("operations").Linear(
+ min(hidden_size, 1024),
+ hidden_size,
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ ),
+ )
+
+ def forward(self, x, c):
+ scale = self.adaLN_modulation(c)
+ x = modulate(self.norm_final(x), scale)
+ x = self.linear(x)
+ return x
+
+
+class NextDiT(nn.Module):
+ """
+ Diffusion model with a Transformer backbone.
+ """
+
+ def __init__(
+ self,
+ patch_size: int = 2,
+ in_channels: int = 4,
+ dim: int = 4096,
+ n_layers: int = 32,
+ n_refiner_layers: int = 2,
+ n_heads: int = 32,
+ n_kv_heads: Optional[int] = None,
+ multiple_of: int = 256,
+ ffn_dim_multiplier: Optional[float] = None,
+ norm_eps: float = 1e-5,
+ qk_norm: bool = False,
+ cap_feat_dim: int = 5120,
+ axes_dims: List[int] = (16, 56, 56),
+ axes_lens: List[int] = (1, 512, 512),
+ image_model=None,
+ device=None,
+ dtype=None,
+ operations=None,
+ ) -> None:
+ super().__init__()
+ self.dtype = dtype
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+ self.in_channels = in_channels
+ self.out_channels = in_channels
+ self.patch_size = patch_size
+
+ self.x_embedder = operation_settings.get("operations").Linear(
+ in_features=patch_size * patch_size * in_channels,
+ out_features=dim,
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ )
+
+ self.noise_refiner = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ modulation=True,
+ operation_settings=operation_settings,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+ self.context_refiner = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ modulation=False,
+ operation_settings=operation_settings,
+ )
+ for layer_id in range(n_refiner_layers)
+ ]
+ )
+
+ self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings)
+ self.cap_embedder = nn.Sequential(
+ operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
+ operation_settings.get("operations").Linear(
+ cap_feat_dim,
+ dim,
+ bias=True,
+ device=operation_settings.get("device"),
+ dtype=operation_settings.get("dtype"),
+ ),
+ )
+
+ self.layers = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ layer_id,
+ dim,
+ n_heads,
+ n_kv_heads,
+ multiple_of,
+ ffn_dim_multiplier,
+ norm_eps,
+ qk_norm,
+ operation_settings=operation_settings,
+ )
+ for layer_id in range(n_layers)
+ ]
+ )
+ self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings)
+
+ assert (dim // n_heads) == sum(axes_dims)
+ self.axes_dims = axes_dims
+ self.axes_lens = axes_lens
+ self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims)
+ self.dim = dim
+ self.n_heads = n_heads
+
+ def unpatchify(
+ self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False
+ ) -> List[torch.Tensor]:
+ """
+ x: (N, T, patch_size**2 * C)
+ imgs: (N, H, W, C)
+ """
+ pH = pW = self.patch_size
+ imgs = []
+ for i in range(x.size(0)):
+ H, W = img_size[i]
+ begin = cap_size[i]
+ end = begin + (H // pH) * (W // pW)
+ imgs.append(
+ x[i][begin:end]
+ .view(H // pH, W // pW, pH, pW, self.out_channels)
+ .permute(4, 0, 2, 1, 3)
+ .flatten(3, 4)
+ .flatten(1, 2)
+ )
+
+ if return_tensor:
+ imgs = torch.stack(imgs, dim=0)
+ return imgs
+
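
`unpatchify` inverts the view/permute/flatten pipeline used by `patchify_and_embed` below. A round-trip sketch with toy sizes confirming the two are exact inverses:

```python
import torch

# Patchify (C, H, W) into (H/p * W/p, p*p*C) tokens and invert it.
C, H, W, p = 4, 8, 8, 2
img = torch.randn(C, H, W)
tokens = (img.view(C, H // p, p, W // p, p)
             .permute(1, 3, 2, 4, 0)       # (H/p, W/p, p, p, C)
             .flatten(2)                   # (H/p, W/p, p*p*C)
             .flatten(0, 1))               # (H/p * W/p, p*p*C)
restored = (tokens.view(H // p, W // p, p, p, C)
                  .permute(4, 0, 2, 1, 3)  # (C, H/p, p, W/p, p)
                  .flatten(3, 4)
                  .flatten(1, 2))          # (C, H, W)
print(torch.equal(restored, img))  # True
```
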
+ def patchify_and_embed(
+ self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens
+ ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]:
+ bsz = len(x)
+ pH = pW = self.patch_size
+ device = x[0].device
+ dtype = x[0].dtype
+
+ if cap_mask is not None:
+ l_effective_cap_len = cap_mask.sum(dim=1).tolist()
+ else:
+ l_effective_cap_len = [num_tokens] * bsz
+
+ if cap_mask is not None and not torch.is_floating_point(cap_mask):
+ cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max
+
+ img_sizes = [(img.size(1), img.size(2)) for img in x]
+ l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes]
+
+ max_seq_len = max(
+ (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len))
+ )
+ max_cap_len = max(l_effective_cap_len)
+ max_img_len = max(l_effective_img_len)
+
+ position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
+
+ for i in range(bsz):
+ cap_len = l_effective_cap_len[i]
+ img_len = l_effective_img_len[i]
+ H, W = img_sizes[i]
+ H_tokens, W_tokens = H // pH, W // pW
+ assert H_tokens * W_tokens == img_len
+
+ position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
+ position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
+ row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+ col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+ position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
+ position_ids[i, cap_len:cap_len+img_len, 2] = col_ids
+
+ freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype)
+
+ # build freqs_cis for cap and image individually
+ cap_freqs_cis_shape = list(freqs_cis.shape)
+ # cap_freqs_cis_shape[1] = max_cap_len
+ cap_freqs_cis_shape[1] = cap_feats.shape[1]
+ cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
+
+ img_freqs_cis_shape = list(freqs_cis.shape)
+ img_freqs_cis_shape[1] = max_img_len
+ img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
+
+ for i in range(bsz):
+ cap_len = l_effective_cap_len[i]
+ img_len = l_effective_img_len[i]
+ cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
+ img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len]
+
+ # refine context
+ for layer in self.context_refiner:
+ cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis)
+
+ # refine image
+ flat_x = []
+ for i in range(bsz):
+ img = x[i]
+ C, H, W = img.size()
+ img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
+ flat_x.append(img)
+ x = flat_x
+ padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype)
+ padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device)
+ for i in range(bsz):
+ padded_img_embed[i, :l_effective_img_len[i]] = x[i]
+ padded_img_mask[i, l_effective_img_len[i]:] = -torch.finfo(dtype).max
+
+ padded_img_embed = self.x_embedder(padded_img_embed)
+ padded_img_mask = padded_img_mask.unsqueeze(1)
+ for layer in self.noise_refiner:
+ padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t)
+
+ if cap_mask is not None:
+ mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device)
+ mask[:, :max_cap_len] = cap_mask[:, :max_cap_len]
+ else:
+ mask = None
+
+ padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype)
+ for i in range(bsz):
+ cap_len = l_effective_cap_len[i]
+ img_len = l_effective_img_len[i]
+
+ padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len]
+ padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len]
+
+ return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis
+
+ # def forward(self, x, t, cap_feats, cap_mask):
+ def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
+ t = 1.0 - timesteps
+ cap_feats = context
+ cap_mask = attention_mask
+ bs, c, h, w = x.shape
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
+ """
+ Forward pass of NextDiT.
+ t: (N,) tensor of diffusion timesteps
+ y: (N,) tensor of text tokens/features
+ """
+
+ t = self.t_embedder(t, dtype=x.dtype) # (N, D)
+ adaln_input = t
+
+ cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
+
+ x_is_tensor = isinstance(x, torch.Tensor)
+ x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens)
+ freqs_cis = freqs_cis.to(x.device)
+
+ for layer in self.layers:
+ x = layer(x, mask, freqs_cis, adaln_input)
+
+ x = self.final_layer(x, adaln_input)
+ x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w]
+
+ return -x
+
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 44aec59a6..2cb77d85d 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -1,4 +1,6 @@
import math
+import sys
+
import torch
import torch.nn.functional as F
from torch import nn, einsum
@@ -16,7 +18,21 @@ if model_management.xformers_enabled():
import xformers.ops
if model_management.sage_attention_enabled():
- from sageattention import sageattn
+ try:
+ from sageattention import sageattn
+ except ModuleNotFoundError as e:
+ if e.name == "sageattention":
+ logging.error(f"\n\nTo use the `--use-sage-attention` feature, the `sageattention` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install sageattention")
+ else:
+ raise e
+ exit(-1)
+
+if model_management.flash_attention_enabled():
+ try:
+ from flash_attn import flash_attn_func
+ except ModuleNotFoundError as e:
+ if e.name == "flash_attn":
+ logging.error(f"\n\nTo use the `--use-flash-attention` feature, the `flash-attn` package must be installed first.\ncommand:\n\t{sys.executable} -m pip install flash-attn")
+ else:
+ raise e
+ exit(-1)
from comfy.cli_args import args
import comfy.ops
@@ -24,38 +40,24 @@ ops = comfy.ops.disable_weight_init
FORCE_UPCAST_ATTENTION_DTYPE = model_management.force_upcast_attention_dtype()
-def get_attn_precision(attn_precision):
+def get_attn_precision(attn_precision, current_dtype):
if args.dont_upcast_attention:
return None
- if FORCE_UPCAST_ATTENTION_DTYPE is not None:
- return FORCE_UPCAST_ATTENTION_DTYPE
+
+ if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
+ return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
return attn_precision
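
`FORCE_UPCAST_ATTENTION_DTYPE` is now a per-dtype mapping rather than a single dtype, so upcasting applies only to dtypes listed in the table. A sketch of the new rule (the mapping value is an illustrative assumption; the real table lives in `model_management`):

```python
import torch

FORCE_UPCAST_ATTENTION_DTYPE = {torch.float16: torch.float32}  # assumed example

def get_attn_precision(attn_precision, current_dtype):
    if FORCE_UPCAST_ATTENTION_DTYPE is not None and current_dtype in FORCE_UPCAST_ATTENTION_DTYPE:
        return FORCE_UPCAST_ATTENTION_DTYPE[current_dtype]
    return attn_precision

print(get_attn_precision(None, torch.float16))   # torch.float32 (upcast)
print(get_attn_precision(None, torch.bfloat16))  # None (left untouched)
```
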
def exists(val):
return val is not None
-def uniq(arr):
- return{el: True for el in arr}.keys()
-
-
def default(val, d):
if exists(val):
return val
return d
-def max_neg_value(t):
- return -torch.finfo(t.dtype).max
-
-
-def init_(tensor):
- dim = tensor.shape[-1]
- std = 1 / math.sqrt(dim)
- tensor.uniform_(-std, std)
- return tensor
-
-
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=ops):
@@ -90,7 +92,7 @@ def Normalize(in_channels, dtype=None, device=None):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
- attn_precision = get_attn_precision(attn_precision)
+ attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
b, _, _, dim_head = q.shape
@@ -159,7 +161,7 @@ def attention_basic(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
- attn_precision = get_attn_precision(attn_precision)
+ attn_precision = get_attn_precision(attn_precision, query.dtype)
if skip_reshape:
b, _, _, dim_head = query.shape
@@ -229,7 +231,7 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
return hidden_states
def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
- attn_precision = get_attn_precision(attn_precision)
+ attn_precision = get_attn_precision(attn_precision, q.dtype)
if skip_reshape:
b, _, _, dim_head = q.shape
@@ -472,7 +474,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
if skip_reshape:
b, _, _, dim_head = q.shape
- tensor_layout="HND"
+ tensor_layout = "HND"
else:
b, _, dim_head = q.shape
dim_head //= heads
@@ -480,7 +482,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
lambda t: t.view(b, -1, heads, dim_head),
(q, k, v),
)
- tensor_layout="NHD"
+ tensor_layout = "NHD"
if mask is not None:
# add a batch dimension if there isn't already one
@@ -490,7 +492,17 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
if mask.ndim == 3:
mask = mask.unsqueeze(1)
- out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+ try:
+ out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout)
+ except Exception as e:
+ logging.error("Error running sage attention: {}, using pytorch attention instead.".format(e))
+ if tensor_layout == "NHD":
+ q, k, v = map(
+ lambda t: t.transpose(1, 2),
+ (q, k, v),
+ )
+ return attention_pytorch(q, k, v, heads, mask=mask, skip_reshape=True, skip_output_reshape=skip_output_reshape)
+
if tensor_layout == "HND":
if not skip_output_reshape:
out = (
@@ -504,6 +516,63 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
return out
+try:
+ @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
+ def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+ dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+ return flash_attn_func(q, k, v, dropout_p=dropout_p, causal=causal)
+
+
+ @flash_attn_wrapper.register_fake
+ def flash_attn_fake(q, k, v, dropout_p=0.0, causal=False):
+ # Output shape is the same as q
+ return q.new_empty(q.shape)
+except AttributeError as error:
+ FLASH_ATTN_ERROR = error
+
+ def flash_attn_wrapper(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
+ dropout_p: float = 0.0, causal: bool = False) -> torch.Tensor:
+ assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"
+
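
The `torch.library.custom_op` wrapper (PyTorch 2.4+) makes `flash_attn_func` traceable by `torch.compile`; `register_fake` supplies the meta implementation that reports only the output shape, and the `AttributeError` fallback covers older torch versions without this API. The pattern on a toy op (standalone sketch, not the flash-attn wrapper itself):

```python
import torch

@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor

@scale.register_fake
def _(x, factor):
    # Shape/dtype only; never executed with real data during tracing.
    return x.new_empty(x.shape)

print(scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])
```
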
+
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False):
+ if skip_reshape:
+ b, _, _, dim_head = q.shape
+ else:
+ b, _, dim_head = q.shape
+ dim_head //= heads
+ q, k, v = map(
+ lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+ (q, k, v),
+ )
+
+ if mask is not None:
+ # add a batch dimension if there isn't already one
+ if mask.ndim == 2:
+ mask = mask.unsqueeze(0)
+ # add a heads dimension if there isn't already one
+ if mask.ndim == 3:
+ mask = mask.unsqueeze(1)
+
+ try:
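+ # flash_attn_func takes no attention mask, so the assert below deliberately
+ # routes masked calls into the SDPA fallback in the except branch.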
+ assert mask is None
+ out = flash_attn_wrapper(
+ q.transpose(1, 2),
+ k.transpose(1, 2),
+ v.transpose(1, 2),
+ dropout_p=0.0,
+ causal=False,
+ ).transpose(1, 2)
+ except Exception as e:
+ logging.warning(f"Flash Attention failed, using default SDPA: {e}")
+ out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
+ if not skip_output_reshape:
+ out = (
+ out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+ )
+ return out
+
+
optimized_attention = attention_basic
if model_management.sage_attention_enabled():
@@ -512,6 +581,9 @@ if model_management.sage_attention_enabled():
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
+elif model_management.flash_attention_enabled():
+ logging.info("Using Flash Attention")
+ optimized_attention = attention_flash
elif model_management.pytorch_attention_enabled():
logging.info("Using pytorch attention")
optimized_attention = attention_pytorch
@@ -778,6 +850,7 @@ class SpatialTransformer(nn.Module):
if not isinstance(context, list):
context = [context] * len(self.transformer_blocks)
b, c, h, w = x.shape
+ transformer_options["activations_shape"] = list(x.shape)
x_in = x
x = self.norm(x)
if not self.use_linear:
@@ -893,6 +966,7 @@ class SpatialVideoTransformer(SpatialTransformer):
transformer_options={}
) -> torch.Tensor:
_, _, h, w = x.shape
+ transformer_options["activations_shape"] = list(x.shape)
x_in = x
spatial_context = None
if exists(context):
diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py
index e70f4431f..eaf3e73a4 100644
--- a/comfy/ldm/modules/diffusionmodules/mmdit.py
+++ b/comfy/ldm/modules/diffusionmodules/mmdit.py
@@ -321,7 +321,7 @@ class SelfAttention(nn.Module):
class RMSNorm(torch.nn.Module):
def __init__(
- self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None
+ self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6, device=None, dtype=None, **kwargs
):
"""
Initialize the RMSNorm normalization layer.
diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py
index ed1e88212..8162742cf 100644
--- a/comfy/ldm/modules/diffusionmodules/model.py
+++ b/comfy/ldm/modules/diffusionmodules/model.py
@@ -293,6 +293,17 @@ def pytorch_attention(q, k, v):
return out
+def vae_attention():
+ if model_management.xformers_enabled_vae():
+ logging.info("Using xformers attention in VAE")
+ return xformers_attention
+ elif model_management.pytorch_attention_enabled_vae():
+ logging.info("Using pytorch attention in VAE")
+ return pytorch_attention
+ else:
+ logging.info("Using split attention in VAE")
+ return normal_attention
+
class AttnBlock(nn.Module):
def __init__(self, in_channels, conv_op=ops.Conv2d):
super().__init__()
@@ -320,15 +331,7 @@ class AttnBlock(nn.Module):
stride=1,
padding=0)
- if model_management.xformers_enabled_vae():
- logging.info("Using xformers attention in VAE")
- self.optimized_attention = xformers_attention
- elif model_management.pytorch_attention_enabled():
- logging.info("Using pytorch attention in VAE")
- self.optimized_attention = pytorch_attention
- else:
- logging.info("Using split attention in VAE")
- self.optimized_attention = normal_attention
+ self.optimized_attention = vae_attention()
def forward(self, x):
h_ = x
@@ -699,9 +702,6 @@ class Decoder(nn.Module):
padding=1)
def forward(self, z, **kwargs):
- #assert z.shape[1:] == self.z_shape[1:]
- self.last_z_shape = z.shape
-
# timestep embedding
temb = None
diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py
new file mode 100644
index 000000000..1b51a4e4a
--- /dev/null
+++ b/comfy/ldm/wan/model.py
@@ -0,0 +1,786 @@
+# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+import math
+
+import torch
+import torch.nn as nn
+from einops import repeat
+
+from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.flux.layers import EmbedND
+from comfy.ldm.flux.math import apply_rope
+import comfy.ldm.common_dit
+import comfy.model_management
+
+
+def sinusoidal_embedding_1d(dim, position):
+ # preprocess
+ assert dim % 2 == 0
+ half = dim // 2
+ position = position.type(torch.float32)
+
+ # calculation
+ sinusoid = torch.outer(
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
+ return x
+
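
`sinusoidal_embedding_1d` concatenates the cosine half before the sine half, so position 0 maps to `[1, ..., 1, 0, ..., 0]`. A quick check with `dim=8`:

```python
import torch

dim, half = 8, 4
position = torch.tensor([0.0, 1.0])
sinusoid = torch.outer(position, torch.pow(10000, -torch.arange(half) / half))
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
print(x.shape)  # torch.Size([2, 8])
print(x[0])     # tensor([1., 1., 1., 1., 0., 0., 0., 0.])
```
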
+
+class WanSelfAttention(nn.Module):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ eps=1e-6, operation_settings={}):
+ assert dim % num_heads == 0
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.head_dim = dim // num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.eps = eps
+
+ # layers
+ self.q = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.k = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.v = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.o = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.norm_q = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+ self.norm_k = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+
+ def forward(self, x, freqs):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
+
+ # query, key, value function
+ def qkv_fn(x):
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
+ v = self.v(x).view(b, s, n * d)
+ return q, k, v
+
+ q, k, v = qkv_fn(x)
+ q, k = apply_rope(q, k, freqs)
+
+ x = optimized_attention(
+ q.view(b, s, n * d),
+ k.view(b, s, n * d),
+ v,
+ heads=self.num_heads,
+ )
+
+ x = self.o(x)
+ return x
+
+
+class WanT2VCrossAttention(WanSelfAttention):
+
+ def forward(self, x, context, **kwargs):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ """
+ # compute query, key, value
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(context))
+ v = self.v(context)
+
+ # compute attention
+ x = optimized_attention(q, k, v, heads=self.num_heads)
+
+ x = self.o(x)
+ return x
+
+
+class WanI2VCrossAttention(WanSelfAttention):
+
+ def __init__(self,
+ dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ eps=1e-6, operation_settings={}):
+ super().__init__(dim, num_heads, window_size, qk_norm, eps, operation_settings=operation_settings)
+
+ self.k_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.v_img = operation_settings.get("operations").Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ # self.alpha = nn.Parameter(torch.zeros((1, )))
+ self.norm_k_img = operation_settings.get("operations").RMSNorm(dim, eps=eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if qk_norm else nn.Identity()
+
+ def forward(self, x, context, context_img_len):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ context(Tensor): Shape [B, L2, C]
+ """
+ context_img = context[:, :context_img_len]
+ context = context[:, context_img_len:]
+
+ # compute query, key, value
+ q = self.norm_q(self.q(x))
+ k = self.norm_k(self.k(context))
+ v = self.v(context)
+ k_img = self.norm_k_img(self.k_img(context_img))
+ v_img = self.v_img(context_img)
+ img_x = optimized_attention(q, k_img, v_img, heads=self.num_heads)
+ # compute attention
+ x = optimized_attention(q, k, v, heads=self.num_heads)
+
+ # output
+ x = x + img_x
+ x = self.o(x)
+ return x
+
+
+WAN_CROSSATTENTION_CLASSES = {
+ 't2v_cross_attn': WanT2VCrossAttention,
+ 'i2v_cross_attn': WanI2VCrossAttention,
+}
+
+
+class WanAttentionBlock(nn.Module):
+
+ def __init__(self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6, operation_settings={}):
+ super().__init__()
+ self.dim = dim
+ self.ffn_dim = ffn_dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+
+ # layers
+ self.norm1 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
+ eps, operation_settings=operation_settings)
+ self.norm3 = operation_settings.get("operations").LayerNorm(
+ dim, eps,
+ elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) if cross_attn_norm else nn.Identity()
+ self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
+ num_heads,
+ (-1, -1),
+ qk_norm,
+ eps, operation_settings=operation_settings)
+ self.norm2 = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.ffn = nn.Sequential(
+ operation_settings.get("operations").Linear(dim, ffn_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
+ operation_settings.get("operations").Linear(ffn_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
+
+ # modulation
+ self.modulation = nn.Parameter(torch.empty(1, 6, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
+
+ def forward(
+ self,
+ x,
+ e,
+ freqs,
+ context,
+ context_img_len=257,
+ ):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L, C]
+ e(Tensor): Shape [B, 6, C]
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
+ """
+ # assert e.dtype == torch.float32
+
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+ # assert e[0].dtype == torch.float32
+
+ # self-attention
+ y = self.self_attn(
+ self.norm1(x) * (1 + e[1]) + e[0],
+ freqs)
+
+ x = x + y * e[2]
+
+ # cross-attention & ffn
+ x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
+ y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
+ x = x + y * e[5]
+ return x
+
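
The learned `(1, 6, dim)` modulation table is added to the per-timestep projection `e` and chunked into six slots: (shift, scale, gate) around self-attention, then (shift, scale, gate) around the FFN, matching the indices used in `forward` above. Slot layout with illustrative sizes:

```python
import torch

dim = 16
modulation = torch.zeros(1, 6, dim)  # stands in for the learned parameter
e = torch.randn(2, 6, dim)           # per-timestep projection for a batch of 2
shift_sa, scale_sa, gate_sa, shift_ff, scale_ff, gate_ff = (
    modulation + e                   # broadcasts over the batch dimension
).chunk(6, dim=1)
print(shift_sa.shape)                # torch.Size([2, 1, 16])
```
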
+
+class VaceWanAttentionBlock(WanAttentionBlock):
+ def __init__(
+ self,
+ cross_attn_type,
+ dim,
+ ffn_dim,
+ num_heads,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=False,
+ eps=1e-6,
+ block_id=0,
+ operation_settings={}
+ ):
+ super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
+ self.block_id = block_id
+ if block_id == 0:
+ self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, c, x, **kwargs):
+ if self.block_id == 0:
+ c = self.before_proj(c) + x
+ c = super().forward(c, **kwargs)
+ c_skip = self.after_proj(c)
+ return c_skip, c
+
+
+class WanCamAdapter(nn.Module):
+ def __init__(self, in_dim, out_dim, kernel_size, stride, num_residual_blocks=1, operation_settings={}):
+ super(WanCamAdapter, self).__init__()
+
+ # Pixel Unshuffle: reduce spatial dimensions by a factor of 8
+ self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=8)
+
+ # Convolution: reduce spatial dimensions by a factor
+ # of 2 (without overlap)
+ self.conv = operation_settings.get("operations").Conv2d(in_dim * 64, out_dim, kernel_size=kernel_size, stride=stride, padding=0, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ # Residual blocks for feature extraction
+ self.residual_blocks = nn.Sequential(
+ *[WanCamResidualBlock(out_dim, operation_settings = operation_settings) for _ in range(num_residual_blocks)]
+ )
+
+ def forward(self, x):
+ # Reshape to merge the frame dimension into batch
+ bs, c, f, h, w = x.size()
+ x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
+
+ # Pixel Unshuffle operation
+ x_unshuffled = self.pixel_unshuffle(x)
+
+ # Convolution operation
+ x_conv = self.conv(x_unshuffled)
+
+ # Feature extraction with residual blocks
+ out = self.residual_blocks(x_conv)
+
+ # Reshape to restore the original batch and frame dimensions
+ out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
+
+ # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
+ out = out.permute(0, 2, 1, 3, 4)
+
+ return out
+
+
+class WanCamResidualBlock(nn.Module):
+ def __init__(self, dim, operation_settings={}):
+ super(WanCamResidualBlock, self).__init__()
+ self.conv1 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.relu = nn.ReLU(inplace=True)
+ self.conv2 = operation_settings.get("operations").Conv2d(dim, dim, kernel_size=3, padding=1, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ def forward(self, x):
+ residual = x
+ out = self.relu(self.conv1(x))
+ out = self.conv2(out)
+ out += residual
+ return out
+
+
+class Head(nn.Module):
+
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6, operation_settings={}):
+ super().__init__()
+ self.dim = dim
+ self.out_dim = out_dim
+ self.patch_size = patch_size
+ self.eps = eps
+
+ # layers
+ out_dim = math.prod(patch_size) * out_dim
+ self.norm = operation_settings.get("operations").LayerNorm(dim, eps, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+ self.head = operation_settings.get("operations").Linear(dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"))
+
+ # modulation
+ self.modulation = nn.Parameter(torch.empty(1, 2, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
+
+ def forward(self, x, e):
+ r"""
+ Args:
+ x(Tensor): Shape [B, L1, C]
+ e(Tensor): Shape [B, C]
+ """
+ # assert e.dtype == torch.float32
+ e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
+ x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
+ return x
+
+
+class MLPProj(torch.nn.Module):
+
+ def __init__(self, in_dim, out_dim, flf_pos_embed_token_number=None, operation_settings={}):
+ super().__init__()
+
+ self.proj = torch.nn.Sequential(
+ operation_settings.get("operations").LayerNorm(in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear(in_dim, in_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
+ torch.nn.GELU(), operation_settings.get("operations").Linear(in_dim, out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
+ operation_settings.get("operations").LayerNorm(out_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
+
+ if flf_pos_embed_token_number is not None:
+ self.emb_pos = nn.Parameter(torch.empty((1, flf_pos_embed_token_number, in_dim), device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
+ else:
+ self.emb_pos = None
+
+ def forward(self, image_embeds):
+ if self.emb_pos is not None:
+ image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
+
+ clip_extra_context_tokens = self.proj(image_embeds)
+ return clip_extra_context_tokens
+
+
+class WanModel(torch.nn.Module):
+ r"""
+ Wan diffusion backbone supporting both text-to-video and image-to-video.
+ """
+
+ def __init__(self,
+ model_type='t2v',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ flf_pos_embed_token_number=None,
+ image_model=None,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+ r"""
+ Initialize the diffusion model backbone.
+
+ Args:
+ model_type (`str`, *optional*, defaults to 't2v'):
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
+ text_len (`int`, *optional*, defaults to 512):
+ Fixed length for text embeddings
+ in_dim (`int`, *optional*, defaults to 16):
+ Input video channels (C_in)
+ dim (`int`, *optional*, defaults to 2048):
+ Hidden dimension of the transformer
+ ffn_dim (`int`, *optional*, defaults to 8192):
+ Intermediate dimension in feed-forward network
+ freq_dim (`int`, *optional*, defaults to 256):
+ Dimension for sinusoidal time embeddings
+ text_dim (`int`, *optional*, defaults to 4096):
+ Input dimension for text embeddings
+ out_dim (`int`, *optional*, defaults to 16):
+ Output video channels (C_out)
+ num_heads (`int`, *optional*, defaults to 16):
+ Number of attention heads
+ num_layers (`int`, *optional*, defaults to 32):
+ Number of transformer blocks
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
+ Window size for local attention (-1 indicates global attention)
+ qk_norm (`bool`, *optional*, defaults to True):
+ Enable query/key normalization
+            cross_attn_norm (`bool`, *optional*, defaults to True):
+                Enable cross-attention normalization
+ eps (`float`, *optional*, defaults to 1e-6):
+ Epsilon value for normalization layers
+ """
+
+ super().__init__()
+ self.dtype = dtype
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+ assert model_type in ['t2v', 'i2v']
+ self.model_type = model_type
+
+ self.patch_size = patch_size
+ self.text_len = text_len
+ self.in_dim = in_dim
+ self.dim = dim
+ self.ffn_dim = ffn_dim
+ self.freq_dim = freq_dim
+ self.text_dim = text_dim
+ self.out_dim = out_dim
+ self.num_heads = num_heads
+ self.num_layers = num_layers
+ self.window_size = window_size
+ self.qk_norm = qk_norm
+ self.cross_attn_norm = cross_attn_norm
+ self.eps = eps
+
+ # embeddings
+ self.patch_embedding = operations.Conv3d(
+ in_dim, dim, kernel_size=patch_size, stride=patch_size, device=operation_settings.get("device"), dtype=torch.float32)
+ self.text_embedding = nn.Sequential(
+ operations.Linear(text_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.GELU(approximate='tanh'),
+ operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
+
+ self.time_embedding = nn.Sequential(
+ operations.Linear(freq_dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), nn.SiLU(), operations.Linear(dim, dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
+ self.time_projection = nn.Sequential(nn.SiLU(), operations.Linear(dim, dim * 6, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")))
+
+ # blocks
+ cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
+ self.blocks = nn.ModuleList([
+ WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
+ window_size, qk_norm, cross_attn_norm, eps, operation_settings=operation_settings)
+ for _ in range(num_layers)
+ ])
+
+ # head
+ self.head = Head(dim, out_dim, patch_size, eps, operation_settings=operation_settings)
+
+ d = dim // num_heads
+ self.rope_embedder = EmbedND(dim=d, theta=10000.0, axes_dim=[d - 4 * (d // 6), 2 * (d // 6), 2 * (d // 6)])
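+        # the per-head dim d is split across (t, h, w) rotary axes; e.g. with the
+        # defaults dim=2048, num_heads=16: d=128 and axes_dim=[44, 42, 42], which
+        # sums back to d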
+
+ if model_type == 'i2v':
+ self.img_emb = MLPProj(1280, dim, flf_pos_embed_token_number=flf_pos_embed_token_number, operation_settings=operation_settings)
+ else:
+ self.img_emb = None
+
+ def forward_orig(
+ self,
+ x,
+ t,
+ context,
+ clip_fea=None,
+ freqs=None,
+ transformer_options={},
+ **kwargs,
+ ):
+        r"""
+        Forward pass through the diffusion model.
+
+        Args:
+            x (Tensor):
+                Input video tensor with shape [B, C_in, F, H, W]
+            t (Tensor):
+                Diffusion timesteps tensor of shape [B]
+            context (Tensor):
+                Text embeddings with shape [B, L, C]
+            clip_fea (Tensor, *optional*):
+                CLIP image features for image-to-video mode
+            freqs (Tensor, *optional*):
+                Precomputed rotary position frequencies for the flattened token grid
+
+        Returns:
+            Tensor:
+                Denoised latent video tensor with shape [B, C_out, F, H, W]
+        """
+ # embeddings
+ x = self.patch_embedding(x.float()).to(x.dtype)
+ grid_sizes = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+
+ # time embeddings
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+
+ # context
+ context = self.text_embedding(context)
+
+ context_img_len = None
+ if clip_fea is not None:
+ if self.img_emb is not None:
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+ context = torch.concat([context_clip, context], dim=1)
+ context_img_len = clip_fea.shape[-2]
+
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ for i, block in enumerate(self.blocks):
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+ return out
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+ x = out["img"]
+ else:
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return x
+
+ def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
+ bs, c, t, h, w = x.shape
+ x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+ patch_size = self.patch_size
+ t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
+ h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
+ w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
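+        # token counts per axis after padding; with the default patch_size
+        # (1, 2, 2) this is t_len = t, h_len = ceil(h / 2), w_len = ceil(w / 2)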
+ img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
+ img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
+ img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)
+ img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).reshape(1, 1, -1)
+ img_ids = repeat(img_ids, "t h w c -> b (t h w) c", b=bs)
+
+ freqs = self.rope_embedder(img_ids).movedim(1, 2)
+ return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
+
+ def unpatchify(self, x, grid_sizes):
+        r"""
+        Reconstruct video tensors from patch embeddings.
+
+        Args:
+            x (Tensor):
+                Patchified features with shape [B, L, C_out * prod(patch_size)]
+            grid_sizes (tuple):
+                Spatial-temporal grid dimensions before patching,
+                (F_patches, H_patches, W_patches)
+
+        Returns:
+            Tensor:
+                Reconstructed latent video tensor with shape [B, C_out, F, H, W]
+        """
+
+ c = self.out_dim
+ u = x
+ b = u.shape[0]
+ u = u[:, :math.prod(grid_sizes)].view(b, *grid_sizes, *self.patch_size, c)
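+        # the permute 'bfhwpqrc->bcfphqwr' interleaves each grid axis with its
+        # patch axis so the reshape below yields contiguous (F*p_t, H*p_h, W*p_w)
+        # dimensions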
+ u = torch.einsum('bfhwpqrc->bcfphqwr', u)
+ u = u.reshape(b, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
+ return u
+
+
+class VaceWanModel(WanModel):
+    r"""
+    Wan diffusion backbone extended with VACE control blocks, which inject
+    control-video conditioning into the main transformer layers.
+    """
+
+ def __init__(self,
+ model_type='vace',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ flf_pos_embed_token_number=None,
+ image_model=None,
+ vace_layers=None,
+ vace_in_dim=None,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+
+ super().__init__(model_type='t2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+ # Vace
+ if vace_layers is not None:
+ self.vace_layers = vace_layers
+ self.vace_in_dim = vace_in_dim
+ # vace blocks
+ self.vace_blocks = nn.ModuleList([
+ VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm, self.cross_attn_norm, self.eps, block_id=i, operation_settings=operation_settings)
+ for i in range(self.vace_layers)
+ ])
+
+ self.vace_layers_mapping = {i: n for n, i in enumerate(range(0, self.num_layers, self.num_layers // self.vace_layers))}
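+            # maps core layer index -> vace block index; e.g. num_layers=32 and
+            # vace_layers=8 gives {0: 0, 4: 1, ..., 28: 7}, injecting a control
+            # skip every num_layers // vace_layers layers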
+ # vace patch embeddings
+ self.vace_patch_embedding = operations.Conv3d(
+ self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size, device=device, dtype=torch.float32
+ )
+
+ def forward_orig(
+ self,
+ x,
+ t,
+ context,
+ vace_context,
+ vace_strength,
+ clip_fea=None,
+ freqs=None,
+ transformer_options={},
+ **kwargs,
+ ):
+ # embeddings
+ x = self.patch_embedding(x.float()).to(x.dtype)
+ grid_sizes = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+
+ # time embeddings
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+
+ # context
+ context = self.text_embedding(context)
+
+ context_img_len = None
+ if clip_fea is not None:
+ if self.img_emb is not None:
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+ context = torch.concat([context_clip, context], dim=1)
+ context_img_len = clip_fea.shape[-2]
+
+ orig_shape = list(vace_context.shape)
+ vace_context = vace_context.movedim(0, 1).reshape([-1] + orig_shape[2:])
+ c = self.vace_patch_embedding(vace_context.float()).to(vace_context.dtype)
+ c = c.flatten(2).transpose(1, 2)
+ c = list(c.split(orig_shape[0], dim=0))
+
+ # arguments
+ x_orig = x
+
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ for i, block in enumerate(self.blocks):
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+ return out
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+ x = out["img"]
+ else:
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+
+ ii = self.vace_layers_mapping.get(i, None)
+ if ii is not None:
+ for iii in range(len(c)):
+ c_skip, c[iii] = self.vace_blocks[ii](c[iii], x=x_orig, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+ x += c_skip * vace_strength[iii]
+ del c_skip
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return x
+
+class CameraWanModel(WanModel):
+    r"""
+    Wan diffusion backbone with a camera-control adapter that adds camera
+    conditioning to the patch embeddings.
+    """
+
+ def __init__(self,
+ model_type='camera',
+ patch_size=(1, 2, 2),
+ text_len=512,
+ in_dim=16,
+ dim=2048,
+ ffn_dim=8192,
+ freq_dim=256,
+ text_dim=4096,
+ out_dim=16,
+ num_heads=16,
+ num_layers=32,
+ window_size=(-1, -1),
+ qk_norm=True,
+ cross_attn_norm=True,
+ eps=1e-6,
+ flf_pos_embed_token_number=None,
+ image_model=None,
+ in_dim_control_adapter=24,
+ device=None,
+ dtype=None,
+ operations=None,
+ ):
+
+ super().__init__(model_type='i2v', patch_size=patch_size, text_len=text_len, in_dim=in_dim, dim=dim, ffn_dim=ffn_dim, freq_dim=freq_dim, text_dim=text_dim, out_dim=out_dim, num_heads=num_heads, num_layers=num_layers, window_size=window_size, qk_norm=qk_norm, cross_attn_norm=cross_attn_norm, eps=eps, flf_pos_embed_token_number=flf_pos_embed_token_number, image_model=image_model, device=device, dtype=dtype, operations=operations)
+ operation_settings = {"operations": operations, "device": device, "dtype": dtype}
+
+ self.control_adapter = WanCamAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], operation_settings=operation_settings)
+
+
+ def forward_orig(
+ self,
+ x,
+ t,
+ context,
+ clip_fea=None,
+ freqs=None,
+ camera_conditions = None,
+ transformer_options={},
+ **kwargs,
+ ):
+ # embeddings
+ x = self.patch_embedding(x.float()).to(x.dtype)
+ if self.control_adapter is not None and camera_conditions is not None:
+ x_camera = self.control_adapter(camera_conditions).to(x.dtype)
+ x = x + x_camera
+ grid_sizes = x.shape[2:]
+ x = x.flatten(2).transpose(1, 2)
+
+ # time embeddings
+ e = self.time_embedding(
+ sinusoidal_embedding_1d(self.freq_dim, t).to(dtype=x[0].dtype))
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
+
+ # context
+ context = self.text_embedding(context)
+
+ context_img_len = None
+ if clip_fea is not None:
+ if self.img_emb is not None:
+ context_clip = self.img_emb(clip_fea) # bs x 257 x dim
+ context = torch.concat([context_clip, context], dim=1)
+ context_img_len = clip_fea.shape[-2]
+
+ patches_replace = transformer_options.get("patches_replace", {})
+ blocks_replace = patches_replace.get("dit", {})
+ for i, block in enumerate(self.blocks):
+ if ("double_block", i) in blocks_replace:
+ def block_wrap(args):
+ out = {}
+ out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len)
+ return out
+ out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+ x = out["img"]
+ else:
+ x = block(x, e=e0, freqs=freqs, context=context, context_img_len=context_img_len)
+
+ # head
+ x = self.head(x, e)
+
+ # unpatchify
+ x = self.unpatchify(x, grid_sizes)
+ return x
diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py
new file mode 100644
index 000000000..a8ebc5ec6
--- /dev/null
+++ b/comfy/ldm/wan/vae.py
@@ -0,0 +1,567 @@
+# original version: https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
+# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from comfy.ldm.modules.diffusionmodules.model import vae_attention
+
+import comfy.ops
+ops = comfy.ops.disable_weight_init
+
+CACHE_T = 2
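+# number of trailing frames each causal layer caches between chunked calls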
+
+
+class CausalConv3d(ops.Conv3d):
+    """
+    Causal 3D convolution.
+    """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._padding = (self.padding[2], self.padding[2], self.padding[1],
+ self.padding[1], 2 * self.padding[0], 0)
+ self.padding = (0, 0, 0)
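+        # F.pad's order for 5D input is (W_l, W_r, H_t, H_b, T_front, T_back);
+        # all temporal padding (2 * pad_t) goes in front and none after, making
+        # the convolution causal along time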
+
+ def forward(self, x, cache_x=None):
+ padding = list(self._padding)
+ if cache_x is not None and self._padding[4] > 0:
+ cache_x = cache_x.to(x.device)
+ x = torch.cat([cache_x, x], dim=2)
+ padding[4] -= cache_x.shape[2]
+ x = F.pad(x, padding)
+
+ return super().forward(x)
+
+
+class RMS_norm(nn.Module):
+
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
+ super().__init__()
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
+
+ self.channel_first = channel_first
+ self.scale = dim**0.5
+ self.gamma = nn.Parameter(torch.ones(shape))
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else None
+
+ def forward(self, x):
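+        # F.normalize divides by the L2 norm over the feature dim; multiplying
+        # by scale = sqrt(dim) turns this into x / RMS(x) before the learned
+        # gamma (and optional bias) rescale it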
+ return F.normalize(
+ x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma.to(x) + (self.bias.to(x) if self.bias is not None else 0)
+
+
+class Upsample(nn.Upsample):
+
+ def forward(self, x):
+ """
+ Fix bfloat16 support for nearest neighbor interpolation.
+ """
+ return super().forward(x.float()).type_as(x)
+
+
+class Resample(nn.Module):
+
+ def __init__(self, dim, mode):
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
+ 'downsample3d')
+ super().__init__()
+ self.dim = dim
+ self.mode = mode
+
+ # layers
+ if mode == 'upsample2d':
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ ops.Conv2d(dim, dim // 2, 3, padding=1))
+ elif mode == 'upsample3d':
+ self.resample = nn.Sequential(
+ Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
+ ops.Conv2d(dim, dim // 2, 3, padding=1))
+ self.time_conv = CausalConv3d(
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
+
+ elif mode == 'downsample2d':
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ ops.Conv2d(dim, dim, 3, stride=(2, 2)))
+ elif mode == 'downsample3d':
+ self.resample = nn.Sequential(
+ nn.ZeroPad2d((0, 1, 0, 1)),
+ ops.Conv2d(dim, dim, 3, stride=(2, 2)))
+ self.time_conv = CausalConv3d(
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
+
+ else:
+ self.resample = nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ b, c, t, h, w = x.size()
+ if self.mode == 'upsample3d':
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = 'Rep'
+ feat_idx[0] += 1
+ else:
+
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[
+ idx] is not None and feat_cache[idx] != 'Rep':
+                        # chunk shorter than CACHE_T: prepend the previous chunk's last frame
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ if cache_x.shape[2] < 2 and feat_cache[
+ idx] is not None and feat_cache[idx] == 'Rep':
+ cache_x = torch.cat([
+ torch.zeros_like(cache_x).to(cache_x.device),
+ cache_x
+ ],
+ dim=2)
+ if feat_cache[idx] == 'Rep':
+ x = self.time_conv(x)
+ else:
+ x = self.time_conv(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+
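+                    # time_conv doubled the channels to 2c; split them into two
+                    # frame sets and interleave along time, doubling t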
+ x = x.reshape(b, 2, c, t, h, w)
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
+ 3)
+ x = x.reshape(b, c, t * 2, h, w)
+ t = x.shape[2]
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+ x = self.resample(x)
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
+
+ if self.mode == 'downsample3d':
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ if feat_cache[idx] is None:
+ feat_cache[idx] = x.clone()
+ feat_idx[0] += 1
+ else:
+
+ cache_x = x[:, :, -1:, :, :].clone()
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
+ # # cache last frame of last two chunk
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
+
+ x = self.time_conv(
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ return x
+
+ def init_weight(self, conv):
+ conv_weight = conv.weight
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ one_matrix = torch.eye(c1, c2)
+ init_matrix = one_matrix
+ nn.init.zeros_(conv_weight)
+ #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+ def init_weight2(self, conv):
+ conv_weight = conv.weight.data
+ nn.init.zeros_(conv_weight)
+ c1, c2, t, h, w = conv_weight.size()
+ init_matrix = torch.eye(c1 // 2, c2)
+ #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
+ conv.weight.data.copy_(conv_weight)
+ nn.init.zeros_(conv.bias.data)
+
+
+class ResidualBlock(nn.Module):
+
+ def __init__(self, in_dim, out_dim, dropout=0.0):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+
+ # layers
+ self.residual = nn.Sequential(
+ RMS_norm(in_dim, images=False), nn.SiLU(),
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
+ if in_dim != out_dim else nn.Identity()
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ h = self.shortcut(x)
+ for layer in self.residual:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # chunk shorter than CACHE_T: prepend the previous chunk's last frame
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x + h
+
+
+class AttentionBlock(nn.Module):
+    """
+    Single-head self-attention over the spatial dimensions of each frame.
+    """
+
+ def __init__(self, dim):
+ super().__init__()
+ self.dim = dim
+
+ # layers
+ self.norm = RMS_norm(dim)
+ self.to_qkv = ops.Conv2d(dim, dim * 3, 1)
+ self.proj = ops.Conv2d(dim, dim, 1)
+ self.optimized_attention = vae_attention()
+
+ def forward(self, x):
+ identity = x
+ b, c, t, h, w = x.size()
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
+ x = self.norm(x)
+ # compute query, key, value
+
+ q, k, v = self.to_qkv(x).chunk(3, dim=1)
+ x = self.optimized_attention(q, k, v)
+
+ # output
+ x = self.proj(x)
+ x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
+ return x + identity
+
+
+class Encoder3d(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+
+ # dimensions
+ dims = [dim * u for u in [1] + dim_mult]
+ scale = 1.0
+
+ # init block
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
+
+ # downsample blocks
+ downsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+ for _ in range(num_res_blocks):
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ downsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # downsample block
+ if i != len(dim_mult) - 1:
+ mode = 'downsample3d' if temperal_downsample[
+ i] else 'downsample2d'
+ downsamples.append(Resample(out_dim, mode=mode))
+ scale /= 2.0
+ self.downsamples = nn.Sequential(*downsamples)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
+ ResidualBlock(out_dim, out_dim, dropout))
+
+ # output blocks
+ self.head = nn.Sequential(
+ RMS_norm(out_dim, images=False), nn.SiLU(),
+ CausalConv3d(out_dim, z_dim, 3, padding=1))
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # chunk shorter than CACHE_T: prepend the previous chunk's last frame
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## downsamples
+ for layer in self.downsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## middle
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # chunk shorter than CACHE_T: prepend the previous chunk's last frame
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+class Decoder3d(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_upsample=[False, True, True],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_upsample = temperal_upsample
+
+ # dimensions
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
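+        # e.g. dim=128, dim_mult=[1, 2, 4, 4] gives dims=[512, 512, 512, 256, 128];
+        # in_dim is halved below for i >= 1 because each upsampling Resample
+        # already halves the channel count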
+ scale = 1.0 / 2**(len(dim_mult) - 2)
+
+ # init block
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
+
+ # middle blocks
+ self.middle = nn.Sequential(
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
+ ResidualBlock(dims[0], dims[0], dropout))
+
+ # upsample blocks
+ upsamples = []
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
+ # residual (+attention) blocks
+            if i in (1, 2, 3):
+ in_dim = in_dim // 2
+ for _ in range(num_res_blocks + 1):
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
+ if scale in attn_scales:
+ upsamples.append(AttentionBlock(out_dim))
+ in_dim = out_dim
+
+ # upsample block
+ if i != len(dim_mult) - 1:
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
+ upsamples.append(Resample(out_dim, mode=mode))
+ scale *= 2.0
+ self.upsamples = nn.Sequential(*upsamples)
+
+ # output blocks
+ self.head = nn.Sequential(
+ RMS_norm(out_dim, images=False), nn.SiLU(),
+ CausalConv3d(out_dim, 3, 3, padding=1))
+
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
+ ## conv1
+ if feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                # chunk shorter than CACHE_T: prepend the previous chunk's last frame
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = self.conv1(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = self.conv1(x)
+
+ ## middle
+ for layer in self.middle:
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## upsamples
+ for layer in self.upsamples:
+ if feat_cache is not None:
+ x = layer(x, feat_cache, feat_idx)
+ else:
+ x = layer(x)
+
+ ## head
+ for layer in self.head:
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
+ idx = feat_idx[0]
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
+                    # chunk shorter than CACHE_T: prepend the previous chunk's last frame
+ cache_x = torch.cat([
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
+ cache_x.device), cache_x
+ ],
+ dim=2)
+ x = layer(x, feat_cache[idx])
+ feat_cache[idx] = cache_x
+ feat_idx[0] += 1
+ else:
+ x = layer(x)
+ return x
+
+
+def count_conv3d(model):
+ count = 0
+ for m in model.modules():
+ if isinstance(m, CausalConv3d):
+ count += 1
+ return count
+
+
+class WanVAE(nn.Module):
+
+ def __init__(self,
+ dim=128,
+ z_dim=4,
+ dim_mult=[1, 2, 4, 4],
+ num_res_blocks=2,
+ attn_scales=[],
+ temperal_downsample=[True, True, False],
+ dropout=0.0):
+ super().__init__()
+ self.dim = dim
+ self.z_dim = z_dim
+ self.dim_mult = dim_mult
+ self.num_res_blocks = num_res_blocks
+ self.attn_scales = attn_scales
+ self.temperal_downsample = temperal_downsample
+ self.temperal_upsample = temperal_downsample[::-1]
+
+ # modules
+ self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
+ attn_scales, self.temperal_downsample, dropout)
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
+ self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
+ attn_scales, self.temperal_upsample, dropout)
+
+    def forward(self, x):
+        # NOTE: encode() returns only the posterior mean, so this path
+        # reconstructs deterministically from the mean.
+        mu = self.encode(x)
+        x_recon = self.decode(mu)
+        return x_recon, mu
+
+ def encode(self, x):
+ self.clear_cache()
+ ## cache
+ t = x.shape[2]
+ iter_ = 1 + (t - 1) // 4
+        ## split the encoder input x along the time axis into chunks of 1, 4, 4, 4, ... frames
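+        # e.g. t=17 gives iter_=5 chunks of 1, 4, 4, 4, 4 frames, matching the
+        # encoder's 4x temporal downsampling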
+ for i in range(iter_):
+ self._enc_conv_idx = [0]
+ if i == 0:
+ out = self.encoder(
+ x[:, :, :1, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx)
+ else:
+ out_ = self.encoder(
+ x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
+ feat_cache=self._enc_feat_map,
+ feat_idx=self._enc_conv_idx)
+ out = torch.cat([out, out_], 2)
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
+ self.clear_cache()
+ return mu
+
+ def decode(self, z):
+ self.clear_cache()
+ # z: [b,c,t,h,w]
+
+ iter_ = z.shape[2]
+ x = self.conv2(z)
+ for i in range(iter_):
+ self._conv_idx = [0]
+ if i == 0:
+ out = self.decoder(
+ x[:, :, i:i + 1, :, :],
+ feat_cache=self._feat_map,
+ feat_idx=self._conv_idx)
+ else:
+ out_ = self.decoder(
+ x[:, :, i:i + 1, :, :],
+ feat_cache=self._feat_map,
+ feat_idx=self._conv_idx)
+ out = torch.cat([out, out_], 2)
+ self.clear_cache()
+ return out
+
+ def reparameterize(self, mu, log_var):
+ std = torch.exp(0.5 * log_var)
+ eps = torch.randn_like(std)
+ return eps * std + mu
+
+    def sample(self, imgs, deterministic=False):
+        # NOTE: encode() returns only the posterior mean; without log_var the
+        # stochastic branch cannot be evaluated, so the mean is returned.
+        return self.encode(imgs)
+
+ def clear_cache(self):
+ self._conv_num = count_conv3d(self.decoder)
+ self._conv_idx = [0]
+ self._feat_map = [None] * self._conv_num
+ #cache encode
+ self._enc_conv_num = count_conv3d(self.encoder)
+ self._enc_conv_idx = [0]
+ self._enc_feat_map = [None] * self._enc_conv_num
diff --git a/comfy/lora.py b/comfy/lora.py
index ec3da6f4c..ef110c164 100644
--- a/comfy/lora.py
+++ b/comfy/lora.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import comfy.utils
import comfy.model_management
import comfy.model_base
+import comfy.weight_adapter as weight_adapter
import logging
import torch
@@ -49,139 +50,12 @@ def load_lora(lora, to_load, log_missing=True):
dora_scale = lora[dora_scale_name]
loaded_keys.add(dora_scale_name)
- reshape_name = "{}.reshape_weight".format(x)
- reshape = None
- if reshape_name in lora.keys():
- try:
- reshape = lora[reshape_name].tolist()
- loaded_keys.add(reshape_name)
- except:
- pass
-
- regular_lora = "{}.lora_up.weight".format(x)
- diffusers_lora = "{}_lora.up.weight".format(x)
- diffusers2_lora = "{}.lora_B.weight".format(x)
- diffusers3_lora = "{}.lora.up.weight".format(x)
- mochi_lora = "{}.lora_B".format(x)
- transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
- A_name = None
-
- if regular_lora in lora.keys():
- A_name = regular_lora
- B_name = "{}.lora_down.weight".format(x)
- mid_name = "{}.lora_mid.weight".format(x)
- elif diffusers_lora in lora.keys():
- A_name = diffusers_lora
- B_name = "{}_lora.down.weight".format(x)
- mid_name = None
- elif diffusers2_lora in lora.keys():
- A_name = diffusers2_lora
- B_name = "{}.lora_A.weight".format(x)
- mid_name = None
- elif diffusers3_lora in lora.keys():
- A_name = diffusers3_lora
- B_name = "{}.lora.down.weight".format(x)
- mid_name = None
- elif mochi_lora in lora.keys():
- A_name = mochi_lora
- B_name = "{}.lora_A".format(x)
- mid_name = None
- elif transformers_lora in lora.keys():
- A_name = transformers_lora
- B_name ="{}.lora_linear_layer.down.weight".format(x)
- mid_name = None
-
- if A_name is not None:
- mid = None
- if mid_name is not None and mid_name in lora.keys():
- mid = lora[mid_name]
- loaded_keys.add(mid_name)
- patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape))
- loaded_keys.add(A_name)
- loaded_keys.add(B_name)
-
-
- ######## loha
- hada_w1_a_name = "{}.hada_w1_a".format(x)
- hada_w1_b_name = "{}.hada_w1_b".format(x)
- hada_w2_a_name = "{}.hada_w2_a".format(x)
- hada_w2_b_name = "{}.hada_w2_b".format(x)
- hada_t1_name = "{}.hada_t1".format(x)
- hada_t2_name = "{}.hada_t2".format(x)
- if hada_w1_a_name in lora.keys():
- hada_t1 = None
- hada_t2 = None
- if hada_t1_name in lora.keys():
- hada_t1 = lora[hada_t1_name]
- hada_t2 = lora[hada_t2_name]
- loaded_keys.add(hada_t1_name)
- loaded_keys.add(hada_t2_name)
-
- patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale))
- loaded_keys.add(hada_w1_a_name)
- loaded_keys.add(hada_w1_b_name)
- loaded_keys.add(hada_w2_a_name)
- loaded_keys.add(hada_w2_b_name)
-
-
- ######## lokr
- lokr_w1_name = "{}.lokr_w1".format(x)
- lokr_w2_name = "{}.lokr_w2".format(x)
- lokr_w1_a_name = "{}.lokr_w1_a".format(x)
- lokr_w1_b_name = "{}.lokr_w1_b".format(x)
- lokr_t2_name = "{}.lokr_t2".format(x)
- lokr_w2_a_name = "{}.lokr_w2_a".format(x)
- lokr_w2_b_name = "{}.lokr_w2_b".format(x)
-
- lokr_w1 = None
- if lokr_w1_name in lora.keys():
- lokr_w1 = lora[lokr_w1_name]
- loaded_keys.add(lokr_w1_name)
-
- lokr_w2 = None
- if lokr_w2_name in lora.keys():
- lokr_w2 = lora[lokr_w2_name]
- loaded_keys.add(lokr_w2_name)
-
- lokr_w1_a = None
- if lokr_w1_a_name in lora.keys():
- lokr_w1_a = lora[lokr_w1_a_name]
- loaded_keys.add(lokr_w1_a_name)
-
- lokr_w1_b = None
- if lokr_w1_b_name in lora.keys():
- lokr_w1_b = lora[lokr_w1_b_name]
- loaded_keys.add(lokr_w1_b_name)
-
- lokr_w2_a = None
- if lokr_w2_a_name in lora.keys():
- lokr_w2_a = lora[lokr_w2_a_name]
- loaded_keys.add(lokr_w2_a_name)
-
- lokr_w2_b = None
- if lokr_w2_b_name in lora.keys():
- lokr_w2_b = lora[lokr_w2_b_name]
- loaded_keys.add(lokr_w2_b_name)
-
- lokr_t2 = None
- if lokr_t2_name in lora.keys():
- lokr_t2 = lora[lokr_t2_name]
- loaded_keys.add(lokr_t2_name)
-
- if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
- patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale))
-
- #glora
- a1_name = "{}.a1.weight".format(x)
- a2_name = "{}.a2.weight".format(x)
- b1_name = "{}.b1.weight".format(x)
- b2_name = "{}.b2.weight".format(x)
- if a1_name in lora:
- patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale))
- loaded_keys.add(a1_name)
- loaded_keys.add(a2_name)
- loaded_keys.add(b1_name)
- loaded_keys.add(b2_name)
+ for adapter_cls in weight_adapter.adapters:
+ adapter = adapter_cls.load(x, lora, alpha, dora_scale, loaded_keys)
+ if adapter is not None:
+ patch_dict[to_load[x]] = adapter
+ loaded_keys.update(adapter.loaded_keys)
+ continue
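+        # each adapter class checks the state dict for its own key pattern
+        # (lora/loha/lokr/glora/...) and returns None when the pattern is absent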
w_norm_name = "{}.w_norm".format(x)
b_norm_name = "{}.b_norm".format(x)
@@ -307,7 +181,6 @@ def model_lora_keys_unet(model, key_map={}):
if k.endswith(".weight"):
key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
key_map["lora_unet_{}".format(key_lora)] = k
- key_map["lora_prior_unet_{}".format(key_lora)] = k #cascade lora: TODO put lora key prefix in the model config
key_map["{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
else:
key_map["{}".format(k)] = k #generic lora format for not .weight without any weird key names
@@ -327,6 +200,13 @@ def model_lora_keys_unet(model, key_map={}):
diffusers_lora_key = diffusers_lora_key[:-2]
key_map[diffusers_lora_key] = unet_key
+ if isinstance(model, comfy.model_base.StableCascade_C):
+ for k in sdk:
+ if k.startswith("diffusion_model."):
+ if k.endswith(".weight"):
+ key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
+ key_map["lora_prior_unet_{}".format(key_lora)] = k
+
if isinstance(model, comfy.model_base.SD3): #Diffusers lora SD3
diffusers_keys = comfy.utils.mmdit_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.")
for k in diffusers_keys:
@@ -399,29 +279,22 @@ def model_lora_keys_unet(model, key_map={}):
key_map["transformer.{}".format(key_lora)] = k
key_map["diffusion_model.{}".format(key_lora)] = k # Old loras
+ if isinstance(model, comfy.model_base.HiDream):
+ for k in sdk:
+ if k.startswith("diffusion_model."):
+ if k.endswith(".weight"):
+ key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
+ key_map["lycoris_{}".format(key_lora)] = k #SimpleTuner lycoris format
+
+ if isinstance(model, comfy.model_base.ACEStep):
+ for k in sdk:
+ if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official ACE step lora format
+ key_lora = k[len("diffusion_model."):-len(".weight")]
+ key_map["{}".format(key_lora)] = k
+
return key_map
-def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
- dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
- lora_diff *= alpha
- weight_calc = weight + function(lora_diff).type(weight.dtype)
- weight_norm = (
- weight_calc.transpose(0, 1)
- .reshape(weight_calc.shape[1], -1)
- .norm(dim=1, keepdim=True)
- .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
- .transpose(0, 1)
- )
-
- weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
- if strength != 1.0:
- weight_calc -= weight
- weight += strength * (weight_calc)
- else:
- weight[:] = weight_calc
- return weight
-
def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
"""
Pad a tensor to a new shape with zeros.
@@ -476,6 +349,16 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
if isinstance(v, list):
v = (calculate_weight(v[1:], v[0][1](comfy.model_management.cast_to_device(v[0][0], weight.device, intermediate_dtype, copy=True), inplace=True), key, intermediate_dtype=intermediate_dtype), )
+ if isinstance(v, weight_adapter.WeightAdapterBase):
+ output = v.calculate_weight(weight, key, strength, strength_model, offset, function, intermediate_dtype, original_weights)
+ if output is None:
+ logging.warning("Calculate Weight Failed: {} {}".format(v.name, key))
+ else:
+ weight = output
+ if old_weight is not None:
+ weight = old_weight
+ continue
+
if len(v) == 1:
patch_type = "diff"
elif len(v) == 2:
@@ -502,157 +385,6 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
diff_weight = comfy.model_management.cast_to_device(target_weight, weight.device, intermediate_dtype) - \
comfy.model_management.cast_to_device(original_weights[key][0][0], weight.device, intermediate_dtype)
weight += function(strength * comfy.model_management.cast_to_device(diff_weight, weight.device, weight.dtype))
- elif patch_type == "lora": #lora/locon
- mat1 = comfy.model_management.cast_to_device(v[0], weight.device, intermediate_dtype)
- mat2 = comfy.model_management.cast_to_device(v[1], weight.device, intermediate_dtype)
- dora_scale = v[4]
- reshape = v[5]
-
- if reshape is not None:
- weight = pad_tensor_to_shape(weight, reshape)
-
- if v[2] is not None:
- alpha = v[2] / mat2.shape[0]
- else:
- alpha = 1.0
-
- if v[3] is not None:
- #locon mid weights, hopefully the math is fine because I didn't properly test it
- mat3 = comfy.model_management.cast_to_device(v[3], weight.device, intermediate_dtype)
- final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
- mat2 = torch.mm(mat2.transpose(0, 1).flatten(start_dim=1), mat3.transpose(0, 1).flatten(start_dim=1)).reshape(final_shape).transpose(0, 1)
- try:
- lora_diff = torch.mm(mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)).reshape(weight.shape)
- if dora_scale is not None:
- weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
- else:
- weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
- except Exception as e:
- logging.error("ERROR {} {} {}".format(patch_type, key, e))
- elif patch_type == "lokr":
- w1 = v[0]
- w2 = v[1]
- w1_a = v[3]
- w1_b = v[4]
- w2_a = v[5]
- w2_b = v[6]
- t2 = v[7]
- dora_scale = v[8]
- dim = None
-
- if w1 is None:
- dim = w1_b.shape[0]
- w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
- comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
- else:
- w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
-
- if w2 is None:
- dim = w2_b.shape[0]
- if t2 is None:
- w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
- comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
- else:
- w2 = torch.einsum('i j k l, j r, i p -> p r k l',
- comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
- comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
- comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
- else:
- w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
-
- if len(w2.shape) == 4:
- w1 = w1.unsqueeze(2).unsqueeze(2)
- if v[2] is not None and dim is not None:
- alpha = v[2] / dim
- else:
- alpha = 1.0
-
- try:
- lora_diff = torch.kron(w1, w2).reshape(weight.shape)
- if dora_scale is not None:
- weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
- else:
- weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
- except Exception as e:
- logging.error("ERROR {} {} {}".format(patch_type, key, e))
- elif patch_type == "loha":
- w1a = v[0]
- w1b = v[1]
- if v[2] is not None:
- alpha = v[2] / w1b.shape[0]
- else:
- alpha = 1.0
-
- w2a = v[3]
- w2b = v[4]
- dora_scale = v[7]
- if v[5] is not None: #cp decomposition
- t1 = v[5]
- t2 = v[6]
- m1 = torch.einsum('i j k l, j r, i p -> p r k l',
- comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
- comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
- comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
-
- m2 = torch.einsum('i j k l, j r, i p -> p r k l',
- comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
- comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
- comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
- else:
- m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
- comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
- m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
- comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
-
- try:
- lora_diff = (m1 * m2).reshape(weight.shape)
- if dora_scale is not None:
- weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
- else:
- weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
- except Exception as e:
- logging.error("ERROR {} {} {}".format(patch_type, key, e))
- elif patch_type == "glora":
- dora_scale = v[5]
-
- old_glora = False
- if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
- rank = v[0].shape[0]
- old_glora = True
-
- if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
- if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
- pass
- else:
- old_glora = False
- rank = v[1].shape[0]
-
- a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
- a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
- b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
- b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
-
- if v[4] is not None:
- alpha = v[4] / rank
- else:
- alpha = 1.0
-
- try:
- if old_glora:
- lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
- else:
- if weight.dim() > 2:
- lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
- else:
- lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
- lora_diff += torch.mm(b1, b2).reshape(weight.shape)
-
- if dora_scale is not None:
- weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
- else:
- weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
- except Exception as e:
- logging.error("ERROR {} {} {}".format(patch_type, key, e))
else:
logging.warning("patch type not recognized {} {}".format(patch_type, key))
diff --git a/comfy/lora_convert.py b/comfy/lora_convert.py
index 05032c690..3e00b63db 100644
--- a/comfy/lora_convert.py
+++ b/comfy/lora_convert.py
@@ -1,4 +1,5 @@
import torch
+import comfy.utils
def convert_lora_bfl_control(sd): #BFL loras for Flux
@@ -11,7 +12,13 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
return sd_out
+def convert_lora_wan_fun(sd): #Wan Fun loras
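+    # Wan Fun checkpoints use a doubled underscore after "lora_unet"
+    # (e.g. "lora_unet__blocks_0_cross_attn_k"); collapse it so keys match
+    # comfy's expected "lora_unet_" prefix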
+ return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
+
+
def convert_lora(sd):
if "img_in.lora_A.weight" in sd and "single_blocks.0.norm.key_norm.scale" in sd:
return convert_lora_bfl_control(sd)
+ if "lora_unet__blocks_0_cross_attn_k.lora_down.weight" in sd:
+ return convert_lora_wan_fun(sd)
return sd
diff --git a/comfy/model_base.py b/comfy/model_base.py
index a67504cbb..fb4724690 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -34,6 +34,12 @@ import comfy.ldm.flux.model
import comfy.ldm.lightricks.model
import comfy.ldm.hunyuan_video.model
import comfy.ldm.cosmos.model
+import comfy.ldm.lumina.model
+import comfy.ldm.wan.model
+import comfy.ldm.hunyuan3d.model
+import comfy.ldm.hidream.model
+import comfy.ldm.chroma.model
+import comfy.ldm.ace.model
import comfy.model_management
import comfy.patcher_extension
@@ -56,6 +62,7 @@ class ModelType(Enum):
FLOW = 6
V_PREDICTION_CONTINUOUS = 7
FLUX = 8
+ IMG_TO_IMG = 9
from comfy.model_sampling import EPS, V_PREDICTION, EDM, ModelSamplingDiscrete, ModelSamplingContinuousEDM, StableCascadeSampling, ModelSamplingContinuousV
@@ -86,6 +93,8 @@ def model_sampling(model_config, model_type):
elif model_type == ModelType.FLUX:
c = comfy.model_sampling.CONST
s = comfy.model_sampling.ModelSamplingFlux
+ elif model_type == ModelType.IMG_TO_IMG:
+ c = comfy.model_sampling.IMG_TO_IMG
class ModelSampling(s, c):
pass
@@ -106,7 +115,7 @@ class BaseModel(torch.nn.Module):
if not unet_config.get("disable_unet_model_creation", False):
if model_config.custom_operations is None:
- fp8 = model_config.optimizations.get("fp8", model_config.scaled_fp8 is not None)
+ fp8 = model_config.optimizations.get("fp8", False)
operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8)
else:
operations = model_config.custom_operations
@@ -137,6 +146,7 @@ class BaseModel(torch.nn.Module):
def _apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
xc = self.model_sampling.calculate_input(sigma, x)
+
if c_concat is not None:
xc = torch.cat([xc] + [c_concat], dim=1)
@@ -148,7 +158,9 @@ class BaseModel(torch.nn.Module):
xc = xc.to(dtype)
t = self.model_sampling.timestep(t).float()
- context = context.to(dtype)
+ if context is not None:
+ context = context.to(dtype)
+
extra_conds = {}
for o in kwargs:
extra = kwargs[o]
@@ -157,15 +169,16 @@ class BaseModel(torch.nn.Module):
extra = extra.to(dtype)
extra_conds[o] = extra
+ t = self.process_timestep(t, x=x, **extra_conds)
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
return self.model_sampling.calculate_denoised(sigma, model_output, x)
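+    # hook for subclasses (e.g. LTXV below) that need per-token or otherwise
+    # transformed timesteps; the default passes the timestep through unchanged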
+ def process_timestep(self, timestep, **kwargs):
+ return timestep
+
def get_dtype(self):
return self.diffusion_model.dtype
- def is_adm(self):
- return self.adm_channels > 0
-
def encode_adm(self, **kwargs):
return None
@@ -184,14 +197,20 @@ class BaseModel(torch.nn.Module):
if concat_latent_image.shape[1:] != noise.shape[1:]:
concat_latent_image = utils.common_upscale(concat_latent_image, noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ if noise.ndim == 5:
+ if concat_latent_image.shape[-3] < noise.shape[-3]:
+ concat_latent_image = torch.nn.functional.pad(concat_latent_image, (0, 0, 0, 0, 0, noise.shape[-3] - concat_latent_image.shape[-3]), "constant", 0)
+ else:
+ concat_latent_image = concat_latent_image[:, :, :noise.shape[-3]]
concat_latent_image = utils.resize_to_batch_size(concat_latent_image, noise.shape[0])
if denoise_mask is not None:
if len(denoise_mask.shape) == len(noise.shape):
- denoise_mask = denoise_mask[:,:1]
+ denoise_mask = denoise_mask[:, :1]
- denoise_mask = denoise_mask.reshape((-1, 1, denoise_mask.shape[-2], denoise_mask.shape[-1]))
+ num_dim = noise.ndim - 2
+ denoise_mask = denoise_mask.reshape((-1, 1) + tuple(denoise_mask.shape[-num_dim:]))
if denoise_mask.shape[-2:] != noise.shape[-2:]:
denoise_mask = utils.common_upscale(denoise_mask, noise.shape[-1], noise.shape[-2], "bilinear", "center")
denoise_mask = utils.resize_to_batch_size(denoise_mask.round(), noise.shape[0])
@@ -201,12 +220,21 @@ class BaseModel(torch.nn.Module):
if ck == "mask":
cond_concat.append(denoise_mask.to(device))
elif ck == "masked_image":
- cond_concat.append(concat_latent_image.to(device)) #NOTE: the latent_image should be masked by the mask in pixel space
+ cond_concat.append(concat_latent_image.to(device)) # NOTE: the latent_image should be masked by the mask in pixel space
+ elif ck == "mask_inverted":
+ cond_concat.append(1.0 - denoise_mask.to(device))
else:
if ck == "mask":
- cond_concat.append(torch.ones_like(noise)[:,:1])
+ cond_concat.append(torch.ones_like(noise)[:, :1])
elif ck == "masked_image":
cond_concat.append(self.blank_inpaint_image_like(noise))
+ elif ck == "mask_inverted":
+ cond_concat.append(torch.zeros_like(noise)[:, :1])
+ if ck == "concat_image":
+ if concat_latent_image is not None:
+ cond_concat.append(concat_latent_image.to(device))
+ else:
+ cond_concat.append(torch.zeros_like(noise))
data = torch.cat(cond_concat, dim=1)
return data
return None
@@ -294,6 +322,9 @@ class BaseModel(torch.nn.Module):
return blank_image
self.blank_inpaint_image_like = blank_inpaint_image_like
+ def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+ return self.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1)), noise, latent_image)
+
def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
@@ -541,6 +572,10 @@ class SD_X4Upscaler(BaseModel):
out['c_concat'] = comfy.conds.CONDNoiseShape(image)
out['y'] = comfy.conds.CONDRegular(noise_level)
+
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
return out
class IP2P:
@@ -573,6 +608,19 @@ class SDXL_instructpix2pix(IP2P, SDXL):
else:
self.process_ip2p_image_in = lambda image: image #diffusers ip2p
+class Lotus(BaseModel):
+ def extra_conds(self, **kwargs):
+ out = {}
+ cross_attn = kwargs.get("cross_attn", None)
+ out['c_crossattn'] = comfy.conds.CONDCrossAttn(cross_attn)
+ device = kwargs["device"]
+ task_emb = torch.tensor([1, 0]).float().to(device)
+ task_emb = torch.cat([torch.sin(task_emb), torch.cos(task_emb)]).unsqueeze(0)
+ out['y'] = comfy.conds.CONDRegular(task_emb)
+ return out
+
+ def __init__(self, model_config, model_type=ModelType.IMG_TO_IMG, device=None):
+ super().__init__(model_config, model_type, device=device)
class StableCascade_C(BaseModel):
def __init__(self, model_config, model_type=ModelType.STABLE_CASCADE, device=None):
@@ -740,8 +788,8 @@ class PixArt(BaseModel):
return out
class Flux(BaseModel):
- def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
- super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
+ def __init__(self, model_config, model_type=ModelType.FLUX, device=None, unet_model=comfy.ldm.flux.model.Flux):
+ super().__init__(model_config, model_type, device=device, unet_model=unet_model)
def concat_cond(self, **kwargs):
try:
@@ -798,7 +846,10 @@ class Flux(BaseModel):
(h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size))
attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok))
out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
- out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
+
+ guidance = kwargs.get("guidance", 3.5)
+ if guidance is not None:
+ out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
return out
class GenmoMochi(BaseModel):
@@ -829,17 +880,26 @@ class LTXV(BaseModel):
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
- guiding_latent = kwargs.get("guiding_latent", None)
- if guiding_latent is not None:
- out['guiding_latent'] = comfy.conds.CONDRegular(guiding_latent)
-
- guiding_latent_noise_scale = kwargs.get("guiding_latent_noise_scale", None)
- if guiding_latent_noise_scale is not None:
- out["guiding_latent_noise_scale"] = comfy.conds.CONDConstant(guiding_latent_noise_scale)
-
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
+
+ denoise_mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+ if denoise_mask is not None:
+ out["denoise_mask"] = comfy.conds.CONDRegular(denoise_mask)
+
+ keyframe_idxs = kwargs.get("keyframe_idxs", None)
+ if keyframe_idxs is not None:
+ out['keyframe_idxs'] = comfy.conds.CONDRegular(keyframe_idxs)
+
return out
+ def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
+ if denoise_mask is None:
+ return timestep
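+        # broadcast the scalar timestep over the denoise mask and patchify it
+        # exactly like the latent, giving each token its own timestep (masked
+        # tokens keep t, protected tokens get 0)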
+ return self.diffusion_model.patchifier.patchify(((denoise_mask) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1)))[:, :1])[0]
+
+ def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+ return latent_image
+
class HunyuanVideo(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan_video.model.HunyuanVideo)
@@ -855,12 +915,46 @@ class HunyuanVideo(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
- out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 6.0)]))
+
+ guidance = kwargs.get("guidance", 6.0)
+ if guidance is not None:
+ out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+
+ guiding_frame_index = kwargs.get("guiding_frame_index", None)
+ if guiding_frame_index is not None:
+ out['guiding_frame_index'] = comfy.conds.CONDRegular(torch.FloatTensor([guiding_frame_index]))
+
+ ref_latent = kwargs.get("ref_latent", None)
+ if ref_latent is not None:
+ out['ref_latent'] = comfy.conds.CONDRegular(self.process_latent_in(ref_latent))
+
return out
+ def scale_latent_inpaint(self, latent_image, **kwargs):
+ return latent_image
+
+class HunyuanVideoI2V(HunyuanVideo):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device)
+ self.concat_keys = ("concat_image", "mask_inverted")
+
+ def scale_latent_inpaint(self, latent_image, **kwargs):
+ return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
+
+class HunyuanVideoSkyreelsI2V(HunyuanVideo):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device)
+ self.concat_keys = ("concat_image",)
+
+ def scale_latent_inpaint(self, latent_image, **kwargs):
+ return super().scale_latent_inpaint(latent_image=latent_image, **kwargs)
+
class CosmosVideo(BaseModel):
- def __init__(self, model_config, model_type=ModelType.EDM, device=None):
+ def __init__(self, model_config, model_type=ModelType.EDM, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cosmos.model.GeneralDIT)
+ self.image_to_video = image_to_video
+ if self.image_to_video:
+ self.concat_keys = ("mask_inverted",)
def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
@@ -873,3 +967,197 @@ class CosmosVideo(BaseModel):
out['fps'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", None))
return out
+
+ def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
+ sigma = sigma.reshape([sigma.shape[0]] + [1] * (len(noise.shape) - 1))
+ sigma_noise_augmentation = 0 #TODO
+ if sigma_noise_augmentation != 0:
+ latent_image = latent_image + noise
+ latent_image = self.model_sampling.calculate_input(torch.tensor([sigma_noise_augmentation], device=latent_image.device, dtype=latent_image.dtype), latent_image)
+ return latent_image * ((sigma ** 2 + self.model_sampling.sigma_data ** 2) ** 0.5)
+
+class Lumina2(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiT)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ attention_mask = kwargs.get("attention_mask", None)
+ if attention_mask is not None:
+ if torch.numel(attention_mask) != attention_mask.sum():
+ out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
+ out['num_tokens'] = comfy.conds.CONDConstant(max(1, torch.sum(attention_mask).item()))
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+ return out
+
+class WAN21(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
+ self.image_to_video = image_to_video
+
+ def concat_cond(self, **kwargs):
+ noise = kwargs.get("noise", None)
+ extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
+ if extra_channels == 0:
+ return None
+
+ image = kwargs.get("concat_latent_image", None)
+ device = kwargs["device"]
+
+ if image is None:
+ shape_image = list(noise.shape)
+ shape_image[1] = extra_channels
+ image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device)
+ else:
+ image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ for i in range(0, image.shape[1], 16):
+ image[:, i: i + 16] = self.process_latent_in(image[:, i: i + 16])
+ image = utils.resize_to_batch_size(image, noise.shape[0])
+
+ if not self.image_to_video or extra_channels == image.shape[1]:
+ return image
+
+ if image.shape[1] > (extra_channels - 4):
+ image = image[:, :(extra_channels - 4)]
+
+ mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
+ if mask is None:
+ mask = torch.zeros_like(noise)[:, :4]
+ else:
+ if mask.shape[1] != 4:
+ mask = torch.mean(mask, dim=1, keepdim=True)
+ mask = 1.0 - mask
+ mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
+ if mask.shape[-3] < noise.shape[-3]:
+ mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
+ if mask.shape[1] == 1:
+ mask = mask.repeat(1, 4, 1, 1, 1)
+ mask = utils.resize_to_batch_size(mask, noise.shape[0])
+
+ return torch.cat((mask, image), dim=1)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ clip_vision_output = kwargs.get("clip_vision_output", None)
+ if clip_vision_output is not None:
+ out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
+ return out
+
+
+class WAN21_Vace(WAN21):
+ def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+ super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.VaceWanModel)
+ self.image_to_video = image_to_video
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ noise = kwargs.get("noise", None)
+ noise_shape = list(noise.shape)
+ vace_frames = kwargs.get("vace_frames", None)
+ if vace_frames is None:
+ noise_shape[1] = 32
+ vace_frames = [torch.zeros(noise_shape, device=noise.device, dtype=noise.dtype)]
+
+ mask = kwargs.get("vace_mask", None)
+ if mask is None:
+ noise_shape[1] = 64
+ mask = [torch.ones(noise_shape, device=noise.device, dtype=noise.dtype)] * len(vace_frames)
+
+ vace_frames_out = []
+ for j in range(len(vace_frames)):
+ vf = vace_frames[j].clone()
+ for i in range(0, vf.shape[1], 16):
+ vf[:, i:i + 16] = self.process_latent_in(vf[:, i:i + 16])
+ vf = torch.cat([vf, mask[j]], dim=1)
+ vace_frames_out.append(vf)
+
+ vace_frames = torch.stack(vace_frames_out, dim=1)
+ out['vace_context'] = comfy.conds.CONDRegular(vace_frames)
+
+ vace_strength = kwargs.get("vace_strength", [1.0] * len(vace_frames_out))
+ out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
+ return out
+
+class WAN21_Camera(WAN21):
+ def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
+ super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
+ self.image_to_video = image_to_video
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ camera_conditions = kwargs.get("camera_conditions", None)
+ if camera_conditions is not None:
+ out['camera_conditions'] = comfy.conds.CONDRegular(camera_conditions)
+ return out
+
+class Hunyuan3Dv2(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hunyuan3d.model.Hunyuan3Dv2)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ guidance = kwargs.get("guidance", 5.0)
+ if guidance is not None:
+ out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+ return out
+
+class HiDream(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.hidream.model.HiDreamImageTransformer2DModel)
+
+ def encode_adm(self, **kwargs):
+ return kwargs["pooled_output"]
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+ conditioning_llama3 = kwargs.get("conditioning_llama3", None)
+ if conditioning_llama3 is not None:
+ out['encoder_hidden_states_llama3'] = comfy.conds.CONDRegular(conditioning_llama3)
+ image_cond = kwargs.get("concat_latent_image", None)
+ if image_cond is not None:
+ out['image_cond'] = comfy.conds.CONDNoiseShape(self.process_latent_in(image_cond))
+ return out
+
+class Chroma(Flux):
+ def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.chroma.model.Chroma)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+
+ guidance = kwargs.get("guidance", 0)
+ if guidance is not None:
+ out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([guidance]))
+ return out
+
+class ACEStep(BaseModel):
+ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
+ super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ace.model.ACEStepTransformer2DModel)
+
+ def extra_conds(self, **kwargs):
+ out = super().extra_conds(**kwargs)
+ noise = kwargs.get("noise", None)
+
+ cross_attn = kwargs.get("cross_attn", None)
+ if cross_attn is not None:
+ out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
+
+ conditioning_lyrics = kwargs.get("conditioning_lyrics", None)
+ if conditioning_lyrics is not None:
+ out['lyric_token_idx'] = comfy.conds.CONDRegular(conditioning_lyrics)
+ out['speaker_embeds'] = comfy.conds.CONDRegular(torch.zeros(noise.shape[0], 512, device=noise.device, dtype=noise.dtype))
+ out['lyrics_strength'] = comfy.conds.CONDConstant(kwargs.get("lyrics_strength", 1.0))
+ return out
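Every `extra_conds` override in this file follows the same pattern: pull an optional value out of `kwargs` and, only if it was supplied, wrap it in a cond object the sampler can batch. A stripped-down sketch of that pattern, with hypothetical stand-ins for the `comfy.conds` wrappers (the real classes also implement batching and compatibility checks):

```python
import torch

class CONDRegular:            # hypothetical stand-in: wraps a tensor condition
    def __init__(self, cond):
        self.cond = cond

class CONDConstant:           # hypothetical stand-in: wraps a non-tensor constant
    def __init__(self, cond):
        self.cond = cond

def extra_conds(**kwargs):
    out = {}
    cross_attn = kwargs.get("cross_attn", None)
    if cross_attn is not None:                  # only emit conds that were supplied
        out["c_crossattn"] = CONDRegular(cross_attn)
    out["frame_rate"] = CONDConstant(kwargs.get("frame_rate", 25))
    return out

print(extra_conds(cross_attn=torch.zeros(1, 77, 768)).keys())
```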
diff --git a/comfy/model_detection.py b/comfy/model_detection.py
index 20cd6bb86..74f539598 100644
--- a/comfy/model_detection.py
+++ b/comfy/model_detection.py
@@ -1,3 +1,4 @@
+import json
import comfy.supported_models
import comfy.supported_models_base
import comfy.utils
@@ -33,7 +34,7 @@ def calculate_transformer_depth(prefix, state_dict_keys, state_dict):
return last_transformer_depth, context_dim, use_linear_in_transformer, time_stack, time_stack_cross
return None
-def detect_unet_config(state_dict, key_prefix):
+def detect_unet_config(state_dict, key_prefix, metadata=None):
state_dict_keys = list(state_dict.keys())
if '{}joint_blocks.0.context_block.attn.qkv.weight'.format(key_prefix) in state_dict_keys: #mmdit model
@@ -136,7 +137,7 @@ def detect_unet_config(state_dict, key_prefix):
if '{}txt_in.individual_token_refiner.blocks.0.norm1.weight'.format(key_prefix) in state_dict_keys: #Hunyuan Video
dit_config = {}
dit_config["image_model"] = "hunyuan_video"
- dit_config["in_channels"] = 16
+ dit_config["in_channels"] = state_dict['{}img_in.proj.weight'.format(key_prefix)].shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = [1, 2, 2]
dit_config["out_channels"] = 16
dit_config["vec_in_dim"] = 768
@@ -153,7 +154,7 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["guidance_embed"] = len(guidance_keys) > 0
return dit_config
- if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys: #Flux
+ if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and '{}img_in.weight'.format(key_prefix) in state_dict_keys: #Flux
dit_config = {}
dit_config["image_model"] = "flux"
dit_config["in_channels"] = 16
@@ -163,7 +164,9 @@ def detect_unet_config(state_dict, key_prefix):
if in_key in state_dict_keys:
dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size)
dit_config["out_channels"] = 16
- dit_config["vec_in_dim"] = 768
+ vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix)
+ if vec_in_key in state_dict_keys:
+ dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1]
dit_config["context_in_dim"] = 4096
dit_config["hidden_size"] = 3072
dit_config["mlp_ratio"] = 4.0
@@ -173,7 +176,16 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["axes_dim"] = [16, 56, 56]
dit_config["theta"] = 10000
dit_config["qkv_bias"] = True
- dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+ if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
+ dit_config["image_model"] = "chroma"
+ dit_config["in_channels"] = 64
+ dit_config["out_channels"] = 64
+ dit_config["in_dim"] = 64
+ dit_config["out_dim"] = 3072
+ dit_config["hidden_dim"] = 5120
+ dit_config["n_layers"] = 5
+ else:
+ dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
return dit_config
if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: #Genmo mochi preview
@@ -210,6 +222,37 @@ def detect_unet_config(state_dict, key_prefix):
if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv
dit_config = {}
dit_config["image_model"] = "ltxv"
+ dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
+ shape = state_dict['{}transformer_blocks.0.attn2.to_k.weight'.format(key_prefix)].shape
+ dit_config["attention_head_dim"] = shape[0] // 32
+ dit_config["cross_attention_dim"] = shape[1]
+ if metadata is not None and "config" in metadata:
+ dit_config.update(json.loads(metadata["config"]).get("transformer", {}))
+ return dit_config
+
+ if '{}genre_embedder.weight'.format(key_prefix) in state_dict_keys: #ACE-Step model
+ dit_config = {}
+ dit_config["audio_model"] = "ace"
+ dit_config["attention_head_dim"] = 128
+ dit_config["in_channels"] = 8
+ dit_config["inner_dim"] = 2560
+ dit_config["max_height"] = 16
+ dit_config["max_position"] = 32768
+ dit_config["max_width"] = 32768
+ dit_config["mlp_ratio"] = 2.5
+ dit_config["num_attention_heads"] = 20
+ dit_config["num_layers"] = 24
+ dit_config["out_channels"] = 8
+ dit_config["patch_size"] = [16, 1]
+ dit_config["rope_theta"] = 1000000.0
+ dit_config["speaker_embedding_dim"] = 512
+ dit_config["text_embedding_dim"] = 768
+
+ dit_config["ssl_encoder_depths"] = [8, 8]
+ dit_config["ssl_latent_dims"] = [1024, 768]
+ dit_config["ssl_names"] = ["mert", "m-hubert"]
+ dit_config["lyric_encoder_vocab_size"] = 6693
+ dit_config["lyric_hidden_size"] = 1024
return dit_config
if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt
@@ -239,19 +282,20 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["micro_condition"] = False
return dit_config
- if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys:
+ if '{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix) in state_dict_keys: # Cosmos
dit_config = {}
dit_config["image_model"] = "cosmos"
dit_config["max_img_h"] = 240
dit_config["max_img_w"] = 240
dit_config["max_frames"] = 128
- dit_config["in_channels"] = 16
+ concat_padding_mask = True
+ dit_config["in_channels"] = (state_dict['{}x_embedder.proj.1.weight'.format(key_prefix)].shape[1] // 4) - int(concat_padding_mask)
dit_config["out_channels"] = 16
dit_config["patch_spatial"] = 2
dit_config["patch_temporal"] = 1
dit_config["model_channels"] = state_dict['{}blocks.block0.blocks.0.block.attn.to_q.0.weight'.format(key_prefix)].shape[0]
dit_config["block_config"] = "FA-CA-MLP"
- dit_config["concat_padding_mask"] = True
+ dit_config["concat_padding_mask"] = concat_padding_mask
dit_config["pos_emb_cls"] = "rope3d"
dit_config["pos_emb_learnable"] = False
dit_config["pos_emb_interpolation"] = "crop"
@@ -283,6 +327,86 @@ def detect_unet_config(state_dict, key_prefix):
dit_config["extra_per_block_abs_pos_emb_type"] = "learnable"
return dit_config
+ if '{}cap_embedder.1.weight'.format(key_prefix) in state_dict_keys: # Lumina 2
+ dit_config = {}
+ dit_config["image_model"] = "lumina2"
+ dit_config["patch_size"] = 2
+ dit_config["in_channels"] = 16
+ dit_config["dim"] = 2304
+ dit_config["cap_feat_dim"] = 2304
+ dit_config["n_layers"] = 26
+ dit_config["n_heads"] = 24
+ dit_config["n_kv_heads"] = 8
+ dit_config["qk_norm"] = True
+ dit_config["axes_dims"] = [32, 32, 32]
+ dit_config["axes_lens"] = [300, 512, 512]
+ return dit_config
+
+ if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
+ dit_config = {}
+ dit_config["image_model"] = "wan2.1"
+ dim = state_dict['{}head.modulation'.format(key_prefix)].shape[-1]
+ dit_config["dim"] = dim
+ dit_config["num_heads"] = dim // 128
+ dit_config["ffn_dim"] = state_dict['{}blocks.0.ffn.0.weight'.format(key_prefix)].shape[0]
+ dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')
+ dit_config["patch_size"] = (1, 2, 2)
+ dit_config["freq_dim"] = 256
+ dit_config["window_size"] = (-1, -1)
+ dit_config["qk_norm"] = True
+ dit_config["cross_attn_norm"] = True
+ dit_config["eps"] = 1e-6
+ dit_config["in_dim"] = state_dict['{}patch_embedding.weight'.format(key_prefix)].shape[1]
+ if '{}vace_patch_embedding.weight'.format(key_prefix) in state_dict_keys:
+ dit_config["model_type"] = "vace"
+ dit_config["vace_in_dim"] = state_dict['{}vace_patch_embedding.weight'.format(key_prefix)].shape[1]
+ dit_config["vace_layers"] = count_blocks(state_dict_keys, '{}vace_blocks.'.format(key_prefix) + '{}.')
+ elif '{}control_adapter.conv.weight'.format(key_prefix) in state_dict_keys:
+ dit_config["model_type"] = "camera"
+ else:
+ if '{}img_emb.proj.0.bias'.format(key_prefix) in state_dict_keys:
+ dit_config["model_type"] = "i2v"
+ else:
+ dit_config["model_type"] = "t2v"
+ flf_weight = state_dict.get('{}img_emb.emb_pos'.format(key_prefix))
+ if flf_weight is not None:
+ dit_config["flf_pos_embed_token_number"] = flf_weight.shape[1]
+ return dit_config
+
+ if '{}latent_in.weight'.format(key_prefix) in state_dict_keys: # Hunyuan 3D
+ in_shape = state_dict['{}latent_in.weight'.format(key_prefix)].shape
+ dit_config = {}
+ dit_config["image_model"] = "hunyuan3d2"
+ dit_config["in_channels"] = in_shape[1]
+ dit_config["context_in_dim"] = state_dict['{}cond_in.weight'.format(key_prefix)].shape[1]
+ dit_config["hidden_size"] = in_shape[0]
+ dit_config["mlp_ratio"] = 4.0
+ dit_config["num_heads"] = 16
+ dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
+ dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
+ dit_config["qkv_bias"] = True
+ dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
+ return dit_config
+
+ if '{}caption_projection.0.linear.weight'.format(key_prefix) in state_dict_keys: # HiDream
+ dit_config = {}
+ dit_config["image_model"] = "hidream"
+ dit_config["attention_head_dim"] = 128
+ dit_config["axes_dims_rope"] = [64, 32, 32]
+ dit_config["caption_channels"] = [4096, 4096]
+ dit_config["max_resolution"] = [128, 128]
+ dit_config["in_channels"] = 16
+ dit_config["llama_layers"] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31]
+ dit_config["num_attention_heads"] = 20
+ dit_config["num_routed_experts"] = 4
+ dit_config["num_activated_experts"] = 2
+ dit_config["num_layers"] = 16
+ dit_config["num_single_layers"] = 32
+ dit_config["out_channels"] = 16
+ dit_config["patch_size"] = 2
+ dit_config["text_emb_dim"] = 2048
+ return dit_config
+
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
return None
@@ -417,8 +541,8 @@ def model_config_from_unet_config(unet_config, state_dict=None):
logging.error("no match {}".format(unet_config))
return None
-def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False):
- unet_config = detect_unet_config(state_dict, unet_key_prefix)
+def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=False, metadata=None):
+ unet_config = detect_unet_config(state_dict, unet_key_prefix, metadata=metadata)
if unet_config is None:
return None
model_config = model_config_from_unet_config(unet_config, state_dict)
@@ -431,6 +555,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
model_config.scaled_fp8 = scaled_fp8_weight.dtype
if model_config.scaled_fp8 == torch.float32:
model_config.scaled_fp8 = torch.float8_e4m3fn
+ if scaled_fp8_weight.nelement() == 2:
+ model_config.optimizations["fp8"] = False
+ else:
+ model_config.optimizations["fp8"] = True
return model_config
@@ -492,6 +620,9 @@ def convert_config(unet_config):
def unet_config_from_diffusers_unet(state_dict, dtype=None):
+ if "conv_in.weight" not in state_dict:
+ return None
+
match = {}
transformer_depth = []
@@ -623,8 +754,13 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
'use_temporal_attention': False, 'use_temporal_resblock': False}
+ LotusD = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': 4,
+ 'dtype': dtype, 'in_channels': 4, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
+ 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': True, 'context_dim': 1024, 'num_heads': 8,
+ 'transformer_depth_output': [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+ 'use_temporal_attention': False, 'use_temporal_resblock': False}
- supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
+ supported_models = [LotusD, SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B, SD09_XS, SD_XS, SDXL_diffusers_ip2p, SD15_diffusers_inpaint]
for unet_config in supported_models:
matches = True
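`detect_unet_config` fingerprints an architecture purely from state-dict key names and weight shapes, no sidecar config needed. A toy sketch of the idea, reusing two of the signature keys from above (the rest of the real heuristics are omitted):

```python
import torch

def detect_arch(state_dict, prefix=""):
    keys = set(state_dict.keys())
    # Presence of a signature key picks the model family...
    if f"{prefix}genre_embedder.weight" in keys:
        return {"audio_model": "ace"}
    if f"{prefix}head.modulation" in keys:
        # ...and weight shapes fill in the size hyperparameters.
        dim = state_dict[f"{prefix}head.modulation"].shape[-1]
        return {"image_model": "wan2.1", "dim": dim, "num_heads": dim // 128}
    return None

sd = {"head.modulation": torch.zeros(1, 6, 1536)}
print(detect_arch(sd))   # {'image_model': 'wan2.1', 'dim': 1536, 'num_heads': 12}
```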
diff --git a/comfy/model_management.py b/comfy/model_management.py
index f6dfc18b0..f5b37e68e 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -19,7 +19,7 @@
import psutil
import logging
from enum import Enum
-from comfy.cli_args import args
+from comfy.cli_args import args, PerformanceFeature
import torch
import sys
import platform
@@ -46,11 +46,39 @@ cpu_state = CPUState.GPU
total_vram = 0
+def get_supported_float8_types():
+ float8_types = []
+ try:
+ float8_types.append(torch.float8_e4m3fn)
+ except:
+ pass
+ try:
+ float8_types.append(torch.float8_e4m3fnuz)
+ except:
+ pass
+ try:
+ float8_types.append(torch.float8_e5m2)
+ except:
+ pass
+ try:
+ float8_types.append(torch.float8_e5m2fnuz)
+ except:
+ pass
+ try:
+ float8_types.append(torch.float8_e8m0fnu)
+ except:
+ pass
+ return float8_types
+
+FLOAT8_TYPES = get_supported_float8_types()
+
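The try/except ladder above could equivalently be written with `getattr`, since older PyTorch builds simply lack some of the float8 attributes; a sketch of that alternative, shown only for comparison:

```python
import torch

# Equivalent feature detection via getattr: collect whichever float8 dtypes
# this torch build actually exposes.
_FP8_NAMES = ("float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2",
              "float8_e5m2fnuz", "float8_e8m0fnu")
FLOAT8_TYPES = [getattr(torch, name) for name in _FP8_NAMES if hasattr(torch, name)]
print(FLOAT8_TYPES)
```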
xpu_available = False
torch_version = ""
try:
torch_version = torch.version.__version__
- xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()
+ temp = torch_version.split(".")
+ torch_version_numeric = (int(temp[0]), int(temp[1]))
+ xpu_available = (torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] <= 4)) and torch.xpu.is_available()
except:
pass
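Splitting the version string is the actual fix here: indexing it character-by-character (`torch_version[2]`) misreads versions such as `2.10.0` and local builds like `2.4.1+cu121`. A quick check of the split-based parsing:

```python
def parse_torch_version(v):
    # Take only the leading "major.minor", ignoring suffixes like "+cu121".
    parts = v.split(".")
    return (int(parts[0]), int(parts[1]))

assert parse_torch_version("2.4.1+cu121") == (2, 4)
assert parse_torch_version("2.10.0") == (2, 10)  # "2.10.0"[2] would have read "1"
```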
@@ -93,6 +121,13 @@ try:
except:
npu_available = False
+try:
+ import torch_mlu # noqa: F401
+ _ = torch.mlu.device_count()
+ mlu_available = torch.mlu.is_available()
+except:
+ mlu_available = False
+
if args.cpu:
cpu_state = CPUState.CPU
@@ -110,6 +145,12 @@ def is_ascend_npu():
return True
return False
+def is_mlu():
+ global mlu_available
+ if mlu_available:
+ return True
+ return False
+
def get_torch_device():
global directml_enabled
global cpu_state
@@ -125,6 +166,8 @@ def get_torch_device():
return torch.device("xpu", torch.xpu.current_device())
elif is_ascend_npu():
return torch.device("npu", torch.npu.current_device())
+ elif is_mlu():
+ return torch.device("mlu", torch.mlu.current_device())
else:
return torch.device(torch.cuda.current_device())
@@ -151,6 +194,12 @@ def get_total_memory(dev=None, torch_total_too=False):
_, mem_total_npu = torch.npu.mem_get_info(dev)
mem_total_torch = mem_reserved
mem_total = mem_total_npu
+ elif is_mlu():
+ stats = torch.mlu.memory_stats(dev)
+ mem_reserved = stats['reserved_bytes.all.current']
+ _, mem_total_mlu = torch.mlu.mem_get_info(dev)
+ mem_total_torch = mem_reserved
+ mem_total = mem_total_mlu
else:
stats = torch.cuda.memory_stats(dev)
mem_reserved = stats['reserved_bytes.all.current']
@@ -163,12 +212,21 @@ def get_total_memory(dev=None, torch_total_too=False):
else:
return mem_total
+def mac_version():
+ try:
+ return tuple(int(n) for n in platform.mac_ver()[0].split("."))
+ except:
+ return None
+
total_vram = get_total_memory(get_torch_device()) / (1024 * 1024)
total_ram = psutil.virtual_memory().total / (1024 * 1024)
logging.info("Total VRAM {:0.0f} MB, total RAM {:0.0f} MB".format(total_vram, total_ram))
try:
logging.info("pytorch version: {}".format(torch_version))
+ mac_ver = mac_version()
+ if mac_ver is not None:
+ logging.info("Mac Version {}".format(mac_ver))
except:
pass
@@ -218,7 +276,7 @@ def is_amd():
MIN_WEIGHT_MEMORY_RATIO = 0.4
if is_nvidia():
- MIN_WEIGHT_MEMORY_RATIO = 0.2
+ MIN_WEIGHT_MEMORY_RATIO = 0.0
ENABLE_PYTORCH_ATTENTION = False
if args.use_pytorch_cross_attention:
@@ -227,22 +285,45 @@ if args.use_pytorch_cross_attention:
try:
if is_nvidia():
- if int(torch_version[0]) >= 2:
+ if torch_version_numeric[0] >= 2:
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
- if is_intel_xpu() or is_ascend_npu():
+ if is_intel_xpu() or is_ascend_npu() or is_mlu():
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
ENABLE_PYTORCH_ATTENTION = True
except:
pass
+
+try:
+ if is_amd():
+ arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
+ logging.info("AMD arch: {}".format(arch))
+ if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
+ if torch_version_numeric[0] >= 2 and torch_version_numeric[1] >= 7: # works on 2.6 but doesn't actually seem to improve much
+ if any((a in arch) for a in ["gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches
+ ENABLE_PYTORCH_ATTENTION = True
+except:
+ pass
+
+
if ENABLE_PYTORCH_ATTENTION:
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
+
+PRIORITIZE_FP16 = False # TODO: remove and replace with something that shows exactly which dtype is faster than the other
try:
- if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
+ if is_nvidia() and PerformanceFeature.Fp16Accumulation in args.fast:
+ torch.backends.cuda.matmul.allow_fp16_accumulation = True
+ PRIORITIZE_FP16 = True # TODO: limit to cards where it actually boosts performance
+ logging.info("Enabled fp16 accumulation.")
+except:
+ pass
+
+try:
+ if torch_version_numeric[0] == 2 and torch_version_numeric[1] >= 5:
torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
except:
logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
@@ -256,15 +337,10 @@ elif args.highvram or args.gpu_only:
vram_state = VRAMState.HIGH_VRAM
FORCE_FP32 = False
-FORCE_FP16 = False
if args.force_fp32:
logging.info("Forcing FP32, if this improves things please report it.")
FORCE_FP32 = True
-if args.force_fp16:
- logging.info("Forcing FP16.")
- FORCE_FP16 = True
-
if lowvram_available:
if set_vram_to in (VRAMState.LOW_VRAM, VRAMState.NO_VRAM):
vram_state = set_vram_to
@@ -297,6 +373,8 @@ def get_torch_device_name(device):
return "{} {}".format(device, torch.xpu.get_device_name(device))
elif is_ascend_npu():
return "{} {}".format(device, torch.npu.get_device_name(device))
+ elif is_mlu():
+ return "{} {}".format(device, torch.mlu.get_device_name(device))
else:
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -535,14 +613,11 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
vram_set_state = vram_state
lowvram_model_memory = 0
if lowvram_available and (vram_set_state == VRAMState.LOW_VRAM or vram_set_state == VRAMState.NORMAL_VRAM) and not force_full_load:
- model_size = loaded_model.model_memory_required(torch_dev)
loaded_memory = loaded_model.model_loaded_memory()
current_free_mem = get_free_memory(torch_dev) + loaded_memory
- lowvram_model_memory = max(64 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
+ lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory)
- if model_size <= lowvram_model_memory: #only switch to lowvram if really necessary
- lowvram_model_memory = 0
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 0.1
@@ -620,7 +695,7 @@ def unet_inital_load_device(parameters, dtype):
return torch_dev
cpu_dev = torch.device("cpu")
- if DISABLE_SMART_MEMORY:
+ if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
return cpu_dev
model_size = dtype_size(dtype) * parameters
@@ -635,7 +710,7 @@ def unet_inital_load_device(parameters, dtype):
def maximum_vram_for_weights(device=None):
return (get_total_memory(device) * 0.88 - minimum_inference_memory())
-def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32]):
+def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, torch.bfloat16, torch.float32], weight_dtype=None):
if model_params < 0:
model_params = 1000000000000000000000
if args.fp32_unet:
@@ -650,15 +725,12 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
return torch.float8_e4m3fn
if args.fp8_e5m2_unet:
return torch.float8_e5m2
+ if args.fp8_e8m0fnu_unet:
+ return torch.float8_e8m0fnu
fp8_dtype = None
- try:
- for dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
- if dtype in supported_dtypes:
- fp8_dtype = dtype
- break
- except:
- pass
+ if weight_dtype in FLOAT8_TYPES:
+ fp8_dtype = weight_dtype
if fp8_dtype is not None:
if supports_fp8_compute(device): #if fp8 compute is supported the casting is most likely not expensive
@@ -668,6 +740,10 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
if model_params * 2 > free_model_memory:
return fp8_dtype
+ if PRIORITIZE_FP16 or weight_dtype == torch.float16:
+ if torch.float16 in supported_dtypes and should_use_fp16(device=device, model_params=model_params):
+ return torch.float16
+
for dt in supported_dtypes:
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
if torch.float16 in supported_dtypes:
@@ -700,6 +776,9 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
return None
fp16_supported = should_use_fp16(inference_device, prioritize_performance=True)
+ if PRIORITIZE_FP16 and fp16_supported and torch.float16 in supported_dtypes:
+ return torch.float16
+
for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported:
return torch.float16
@@ -746,6 +825,8 @@ def text_encoder_dtype(device=None):
return torch.float8_e5m2
elif args.fp16_text_enc:
return torch.float16
+ elif args.bf16_text_enc:
+ return torch.bfloat16
elif args.fp32_text_enc:
return torch.float32
@@ -858,15 +939,61 @@ def force_channels_last():
#TODO
return False
-def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False):
+
+STREAMS = {}
+NUM_STREAMS = 1
+if args.async_offload:
+ NUM_STREAMS = 2
+ logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
+
+stream_counters = {}
+def get_offload_stream(device):
+ stream_counter = stream_counters.get(device, 0)
+ if NUM_STREAMS <= 1:
+ return None
+
+ if device in STREAMS:
+ ss = STREAMS[device]
+ s = ss[stream_counter]
+ stream_counter = (stream_counter + 1) % len(ss)
+ if is_device_cuda(device):
+ ss[stream_counter].wait_stream(torch.cuda.current_stream())
+ stream_counters[device] = stream_counter
+ return s
+ elif is_device_cuda(device):
+ ss = []
+ for k in range(NUM_STREAMS):
+ ss.append(torch.cuda.Stream(device=device, priority=0))
+ STREAMS[device] = ss
+ s = ss[stream_counter]
+ stream_counter = (stream_counter + 1) % len(ss)
+ stream_counters[device] = stream_counter
+ return s
+ return None
+
+def sync_stream(device, stream):
+ if stream is None:
+ return
+ if is_device_cuda(device):
+ torch.cuda.current_stream().wait_stream(stream)
+
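`get_offload_stream` hands out streams from a small round-robin pool so the next weight copy can be issued while compute is still running, and `sync_stream` makes the compute stream wait before consuming the result. A minimal self-contained sketch of that pattern, assuming a CUDA device and using `torch.cuda.stream()` rather than entering the stream object directly:

```python
import itertools
import torch

streams = [torch.cuda.Stream() for _ in range(2)]   # small round-robin pool
rr = itertools.cycle(streams)

def async_copy(weight, device):
    s = next(rr)
    s.wait_stream(torch.cuda.current_stream())      # copy can't start before prior compute
    with torch.cuda.stream(s):
        r = torch.empty_like(weight, device=device)
        r.copy_(weight, non_blocking=True)
    torch.cuda.current_stream().wait_stream(s)      # compute can't read before copy lands
    return r

if torch.cuda.is_available():
    w = torch.randn(1024, 1024, pin_memory=True)    # pinned memory enables async H2D copy
    print(async_copy(w, "cuda").device)
```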
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None):
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
return weight
+ if stream is not None:
+ with stream:
+ return weight.to(dtype=dtype, copy=copy)
return weight.to(dtype=dtype, copy=copy)
- r = torch.empty_like(weight, dtype=dtype, device=device)
- r.copy_(weight, non_blocking=non_blocking)
+ if stream is not None:
+ with stream:
+ r = torch.empty_like(weight, dtype=dtype, device=device)
+ r.copy_(weight, non_blocking=non_blocking)
+ else:
+ r = torch.empty_like(weight, dtype=dtype, device=device)
+ r.copy_(weight, non_blocking=non_blocking)
return r
def cast_to_device(tensor, device, dtype, copy=False):
@@ -876,6 +1003,9 @@ def cast_to_device(tensor, device, dtype, copy=False):
def sage_attention_enabled():
return args.use_sage_attention
+def flash_attention_enabled():
+ return args.use_flash_attention
+
def xformers_enabled():
global directml_enabled
global cpu_state
@@ -885,6 +1015,8 @@ def xformers_enabled():
return False
if is_ascend_npu():
return False
+ if is_mlu():
+ return False
if directml_enabled:
return False
return XFORMERS_IS_AVAILABLE
@@ -901,6 +1033,11 @@ def pytorch_attention_enabled():
global ENABLE_PYTORCH_ATTENTION
return ENABLE_PYTORCH_ATTENTION
+def pytorch_attention_enabled_vae():
+ if is_amd():
+ return False # enabling pytorch attention on AMD currently causes crash when doing high res
+ return pytorch_attention_enabled()
+
def pytorch_attention_flash_attention():
global ENABLE_PYTORCH_ATTENTION
if ENABLE_PYTORCH_ATTENTION:
@@ -911,23 +1048,21 @@ def pytorch_attention_flash_attention():
return True
if is_ascend_npu():
return True
+ if is_mlu():
+ return True
+ if is_amd():
+ return True #if you have pytorch attention enabled on AMD it probably supports at least mem efficient attention
return False
-def mac_version():
- try:
- return tuple(int(n) for n in platform.mac_ver()[0].split("."))
- except:
- return None
-
def force_upcast_attention_dtype():
upcast = args.force_upcast_attention
macos_version = mac_version()
- if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
+ if macos_version is not None and ((14, 5) <= macos_version < (16,)): # black image bug on recent versions of macOS
upcast = True
if upcast:
- return torch.float32
+ return {torch.float16: torch.float32}
else:
return None
@@ -957,6 +1092,13 @@ def get_free_memory(dev=None, torch_free_too=False):
mem_free_npu, _ = torch.npu.mem_get_info(dev)
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_npu + mem_free_torch
+ elif is_mlu():
+ stats = torch.mlu.memory_stats(dev)
+ mem_active = stats['active_bytes.all.current']
+ mem_reserved = stats['reserved_bytes.all.current']
+ mem_free_mlu, _ = torch.mlu.mem_get_info(dev)
+ mem_free_torch = mem_reserved - mem_active
+ mem_free_total = mem_free_mlu + mem_free_torch
else:
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
@@ -993,21 +1135,26 @@ def is_device_mps(device):
def is_device_cuda(device):
return is_device_type(device, 'cuda')
-def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
+def is_directml_enabled():
global directml_enabled
+ if directml_enabled:
+ return True
+ return False
+
+def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
if device is not None:
if is_device_cpu(device):
return False
- if FORCE_FP16:
+ if args.force_fp16:
return True
if FORCE_FP32:
return False
- if directml_enabled:
- return False
+ if is_directml_enabled():
+ return True
if (device is not None and is_device_mps(device)) or mps_mode():
return True
@@ -1021,6 +1168,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if is_ascend_npu():
return True
+ if is_mlu():
+ return True
+
if torch.version.hip:
return True
@@ -1078,13 +1228,28 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_intel_xpu():
return True
+ if is_ascend_npu():
+ return True
+
+ if is_amd():
+ arch = torch.cuda.get_device_properties(device).gcnArchName
+ if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
+ if manual_cast:
+ return True
+ return False
+
props = torch.cuda.get_device_properties(device)
+
+ if is_mlu():
+ if props.major > 3:
+ return True
+
if props.major >= 8:
return True
bf16_works = torch.cuda.is_bf16_supported()
- if bf16_works or manual_cast:
+ if bf16_works and manual_cast:
free_model_memory = maximum_vram_for_weights(device)
if (not prioritize_performance) or model_params * 4 > free_model_memory:
return True
@@ -1092,6 +1257,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
return False
def supports_fp8_compute(device=None):
+ if args.supports_fp8_compute:
+ return True
+
if not is_nvidia():
return False
@@ -1103,11 +1271,11 @@ def supports_fp8_compute(device=None):
if props.minor < 9:
return False
- if int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) < 3):
+ if torch_version_numeric[0] < 2 or (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 3):
return False
if WINDOWS:
- if (int(torch_version[0]) == 2 and int(torch_version[2]) < 4):
+ if (torch_version_numeric[0] == 2 and torch_version_numeric[1] < 4):
return False
return True
@@ -1120,6 +1288,8 @@ def soft_empty_cache(force=False):
torch.xpu.empty_cache()
elif is_ascend_npu():
torch.npu.empty_cache()
+ elif is_mlu():
+ torch.mlu.empty_cache()
elif torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
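The MLU support added throughout this file follows its established convention: each accelerator gets an `is_*()` probe and every device-touching helper dispatches across those probes in a fixed order. A condensed sketch of the shape of that dispatch (the `npu`/`mlu` attributes only exist once `torch_npu`/`torch_mlu` are imported):

```python
import torch

def soft_empty_cache_sketch():
    # Probe order mirrors the real helper: specialized backends first, CUDA last.
    if getattr(torch, "xpu", None) and torch.xpu.is_available():
        torch.xpu.empty_cache()
    elif getattr(torch, "npu", None) and torch.npu.is_available():   # Ascend, needs torch_npu
        torch.npu.empty_cache()
    elif getattr(torch, "mlu", None) and torch.mlu.is_available():   # Cambricon, needs torch_mlu
        torch.mlu.empty_cache()
    elif torch.cuda.is_available():
        torch.cuda.empty_cache()
```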
diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 0501f7b38..b7cb12dfc 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -96,8 +96,28 @@ def wipe_lowvram_weight(m):
if hasattr(m, "prev_comfy_cast_weights"):
m.comfy_cast_weights = m.prev_comfy_cast_weights
del m.prev_comfy_cast_weights
- m.weight_function = None
- m.bias_function = None
+
+ if hasattr(m, "weight_function"):
+ m.weight_function = []
+
+ if hasattr(m, "bias_function"):
+ m.bias_function = []
+
+def move_weight_functions(m, device):
+ if device is None:
+ return 0
+
+ memory = 0
+ if hasattr(m, "weight_function"):
+ for f in m.weight_function:
+ if hasattr(f, "move_to"):
+ memory += f.move_to(device=device)
+
+ if hasattr(m, "bias_function"):
+ for f in m.bias_function:
+ if hasattr(f, "move_to"):
+ memory += f.move_to(device=device)
+ return memory
class LowVramPatch:
def __init__(self, key, patches):
@@ -192,11 +212,13 @@ class ModelPatcher:
self.backup = {}
self.object_patches = {}
self.object_patches_backup = {}
+ self.weight_wrapper_patches = {}
self.model_options = {"transformer_options":{}}
self.model_size()
self.load_device = load_device
self.offload_device = offload_device
self.weight_inplace_update = weight_inplace_update
+ self.force_cast_weights = False
self.patches_uuid = uuid.uuid4()
self.parent = None
@@ -250,11 +272,14 @@ class ModelPatcher:
n.patches_uuid = self.patches_uuid
n.object_patches = self.object_patches.copy()
+ n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
n.model_options = copy.deepcopy(self.model_options)
n.backup = self.backup
n.object_patches_backup = self.object_patches_backup
n.parent = self
+ n.force_cast_weights = self.force_cast_weights
+
# attachments
n.attachments = {}
for k in self.attachments:
@@ -402,6 +427,16 @@ class ModelPatcher:
def add_object_patch(self, name, obj):
self.object_patches[name] = obj
+ def set_model_compute_dtype(self, dtype):
+ self.add_object_patch("manual_cast_dtype", dtype)
+ if dtype is not None:
+ self.force_cast_weights = True
+ self.patches_uuid = uuid.uuid4() #TODO: optimize by preventing a full model reload for this
+
+ def add_weight_wrapper(self, name, function):
+ self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
+ self.patches_uuid = uuid.uuid4()
+
def get_model_object(self, name: str) -> torch.nn.Module:
"""Retrieves a nested attribute from an object using dot notation considering
object patches.
@@ -566,6 +601,9 @@ class ModelPatcher:
lowvram_weight = False
+ weight_key = "{}.weight".format(n)
+ bias_key = "{}.bias".format(n)
+
if not full_load and hasattr(m, "comfy_cast_weights"):
if mem_counter + module_mem >= lowvram_model_memory:
lowvram_weight = True
@@ -573,34 +611,46 @@ class ModelPatcher:
if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
continue
- weight_key = "{}.weight".format(n)
- bias_key = "{}.bias".format(n)
-
+ cast_weight = self.force_cast_weights
if lowvram_weight:
+ if hasattr(m, "comfy_cast_weights"):
+ m.weight_function = []
+ m.bias_function = []
+
if weight_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(weight_key)
else:
- m.weight_function = LowVramPatch(weight_key, self.patches)
+ m.weight_function = [LowVramPatch(weight_key, self.patches)]
patch_counter += 1
if bias_key in self.patches:
if force_patch_weights:
self.patch_weight_to_device(bias_key)
else:
- m.bias_function = LowVramPatch(bias_key, self.patches)
+ m.bias_function = [LowVramPatch(bias_key, self.patches)]
patch_counter += 1
- m.prev_comfy_cast_weights = m.comfy_cast_weights
- m.comfy_cast_weights = True
+ cast_weight = True
else:
if hasattr(m, "comfy_cast_weights"):
- if m.comfy_cast_weights:
- wipe_lowvram_weight(m)
+ wipe_lowvram_weight(m)
if full_load or mem_counter + module_mem < lowvram_model_memory:
mem_counter += module_mem
load_completely.append((module_mem, n, m, params))
+ if cast_weight and hasattr(m, "comfy_cast_weights"):
+ m.prev_comfy_cast_weights = m.comfy_cast_weights
+ m.comfy_cast_weights = True
+
+ if weight_key in self.weight_wrapper_patches:
+ m.weight_function.extend(self.weight_wrapper_patches[weight_key])
+
+ if bias_key in self.weight_wrapper_patches:
+ m.bias_function.extend(self.weight_wrapper_patches[bias_key])
+
+ mem_counter += move_weight_functions(m, device_to)
+
load_completely.sort(reverse=True)
for x in load_completely:
n = x[1]
@@ -662,6 +712,7 @@ class ModelPatcher:
self.unpatch_hooks()
if self.model.model_lowvram:
for m in self.model.modules():
+ move_weight_functions(m, device_to)
wipe_lowvram_weight(m)
self.model.model_lowvram = False
@@ -696,6 +747,7 @@ class ModelPatcher:
def partially_unload(self, device_to, memory_to_free=0):
with self.use_ejected():
+ hooks_unpatched = False
memory_freed = 0
patch_counter = 0
unload_list = self._load_list()
@@ -719,6 +771,10 @@ class ModelPatcher:
move_weight = False
break
+ if not hooks_unpatched:
+ self.unpatch_hooks()
+ hooks_unpatched = True
+
if bk.inplace_update:
comfy.utils.copy_to_param(self.model, key, bk.weight)
else:
@@ -728,15 +784,19 @@ class ModelPatcher:
weight_key = "{}.weight".format(n)
bias_key = "{}.bias".format(n)
if move_weight:
+ cast_weight = self.force_cast_weights
m.to(device_to)
+ module_mem += move_weight_functions(m, device_to)
if lowvram_possible:
if weight_key in self.patches:
- m.weight_function = LowVramPatch(weight_key, self.patches)
+ m.weight_function.append(LowVramPatch(weight_key, self.patches))
patch_counter += 1
if bias_key in self.patches:
- m.bias_function = LowVramPatch(bias_key, self.patches)
+ m.bias_function.append(LowVramPatch(bias_key, self.patches))
patch_counter += 1
+ cast_weight = True
+ if cast_weight:
m.prev_comfy_cast_weights = m.comfy_cast_weights
m.comfy_cast_weights = True
m.comfy_patched_weights = False
@@ -1034,7 +1094,6 @@ class ModelPatcher:
def patch_hooks(self, hooks: comfy.hooks.HookGroup):
with self.use_ejected():
- self.unpatch_hooks()
if hooks is not None:
model_sd_keys = list(self.model_state_dict().keys())
memory_counter = None
@@ -1045,12 +1104,16 @@ class ModelPatcher:
# if have cached weights for hooks, use it
cached_weights = self.cached_hook_patches.get(hooks, None)
if cached_weights is not None:
+ model_sd_keys_set = set(model_sd_keys)
for key in cached_weights:
if key not in model_sd_keys:
logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}")
continue
self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter)
+ model_sd_keys_set.remove(key)
+ self.unpatch_hooks(model_sd_keys_set)
else:
+ self.unpatch_hooks()
relevant_patches = self.get_combined_hook_patches(hooks=hooks)
original_weights = None
if len(relevant_patches) > 0:
@@ -1061,6 +1124,8 @@ class ModelPatcher:
continue
self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights,
memory_counter=memory_counter)
+ else:
+ self.unpatch_hooks()
self.current_hooks = hooks
def patch_cached_hook_weights(self, cached_weights: dict, key: str, memory_counter: MemoryCounter):
@@ -1117,17 +1182,23 @@ class ModelPatcher:
del out_weight
del weight
- def unpatch_hooks(self) -> None:
+ def unpatch_hooks(self, whitelist_keys_set: set[str]=None) -> None:
with self.use_ejected():
if len(self.hook_backup) == 0:
self.current_hooks = None
return
keys = list(self.hook_backup.keys())
- for k in keys:
- comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+ if whitelist_keys_set:
+ for k in keys:
+ if k in whitelist_keys_set:
+ comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
+ self.hook_backup.pop(k)
+ else:
+ for k in keys:
+ comfy.utils.copy_to_param(self.model, k, self.hook_backup[k][0].to(device=self.hook_backup[k][1]))
- self.hook_backup.clear()
- self.current_hooks = None
+ self.hook_backup.clear()
+ self.current_hooks = None
def clean_hooks(self):
self.unpatch_hooks()
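The new `whitelist_keys_set` parameter lets `patch_hooks` restore only the weights that the incoming cached hook patches will not overwrite anyway, instead of unpatching everything and re-patching most of it. The underlying set arithmetic, with hypothetical key names:

```python
model_keys = {"a.weight", "b.weight", "c.weight"}
cached_hook_keys = {"a.weight", "b.weight"}   # weights the new hooks patch directly

# Only weights NOT covered by the cached patches must be restored from backup.
restore_only = model_keys - cached_hook_keys
print(restore_only)                           # {'c.weight'}
```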
diff --git a/comfy/model_sampling.py b/comfy/model_sampling.py
index 4370516b9..7e7291476 100644
--- a/comfy/model_sampling.py
+++ b/comfy/model_sampling.py
@@ -31,6 +31,7 @@ class EPS:
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
@@ -61,11 +62,22 @@ class CONST:
return model_input - model_output * sigma
def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))
return sigma * noise + (1.0 - sigma) * latent_image
def inverse_noise_scaling(self, sigma, latent):
+ sigma = sigma.view(sigma.shape[:1] + (1,) * (latent.ndim - 1))
return latent / (1.0 - sigma)
+class X0(EPS):
+ def calculate_denoised(self, sigma, model_output, model_input):
+ return model_output
+
+class IMG_TO_IMG(X0):
+ def calculate_input(self, sigma, noise):
+ return noise
+
+
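The added `sigma.view(...)` lines make the scaling functions safe for batched sigmas against latents of any rank: a shape-`(B,)` sigma becomes `(B, 1, 1, ...)` so it broadcasts per batch element. A quick demonstration using the CONST (flow) scaling formula:

```python
import torch

sigma = torch.tensor([0.2, 0.8])                 # one sigma per batch element
noise = torch.randn(2, 4, 8, 8)
latent = torch.randn(2, 4, 8, 8)

sigma = sigma.view(sigma.shape[:1] + (1,) * (noise.ndim - 1))   # -> shape (2, 1, 1, 1)
scaled = sigma * noise + (1.0 - sigma) * latent  # per-item noise scaling
print(scaled.shape)                              # torch.Size([2, 4, 8, 8])
```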
class ModelSamplingDiscrete(torch.nn.Module):
def __init__(self, model_config=None, zsnr=None):
super().__init__()
@@ -99,13 +111,14 @@ class ModelSamplingDiscrete(torch.nn.Module):
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
+ self.zsnr = zsnr
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
- if zsnr:
+ if self.zsnr:
sigmas = rescale_zero_terminal_snr_sigmas(sigmas)
self.set_sigmas(sigmas)
diff --git a/comfy/ops.py b/comfy/ops.py
index 06be6b48b..431c8f89d 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -17,9 +17,12 @@
"""
import torch
+import logging
import comfy.model_management
-from comfy.cli_args import args
+from comfy.cli_args import args, PerformanceFeature
import comfy.float
+import comfy.rmsnorm
+import contextlib
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
@@ -35,24 +38,37 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
if device is None:
device = input.device
+ offload_stream = comfy.model_management.get_offload_stream(device)
+ if offload_stream is not None:
+ wf_context = offload_stream
+ else:
+ wf_context = contextlib.nullcontext()
+
bias = None
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if s.bias is not None:
- has_function = s.bias_function is not None
- bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
- if has_function:
- bias = s.bias_function(bias)
+ has_function = len(s.bias_function) > 0
+ bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
- has_function = s.weight_function is not None
- weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
+ if has_function:
+ with wf_context:
+ for f in s.bias_function:
+ bias = f(bias)
+
+ has_function = len(s.weight_function) > 0
+ weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream)
if has_function:
- weight = s.weight_function(weight)
+ with wf_context:
+ for f in s.weight_function:
+ weight = f(weight)
+
+ comfy.model_management.sync_stream(device, offload_stream)
return weight, bias
class CastWeightBiasOp:
comfy_cast_weights = False
- weight_function = None
- bias_function = None
+ weight_function = []
+ bias_function = []
class disable_weight_init:
class Linear(torch.nn.Linear, CastWeightBiasOp):
@@ -64,7 +80,7 @@ class disable_weight_init:
return torch.nn.functional.linear(input, weight, bias)
def forward(self, *args, **kwargs):
- if self.comfy_cast_weights:
+ if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -78,7 +94,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
- if self.comfy_cast_weights:
+ if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -92,7 +108,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
- if self.comfy_cast_weights:
+ if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -106,7 +122,7 @@ class disable_weight_init:
return self._conv_forward(input, weight, bias)
def forward(self, *args, **kwargs):
- if self.comfy_cast_weights:
+ if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -120,12 +136,11 @@ class disable_weight_init:
return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
def forward(self, *args, **kwargs):
- if self.comfy_cast_weights:
+ if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
-
class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
def reset_parameters(self):
return None
@@ -139,7 +154,26 @@ class disable_weight_init:
return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
def forward(self, *args, **kwargs):
- if self.comfy_cast_weights:
+ if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
+ return self.forward_comfy_cast_weights(*args, **kwargs)
+ else:
+ return super().forward(*args, **kwargs)
+
+ class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
+ def reset_parameters(self):
+ self.bias = None
+ return None
+
+ def forward_comfy_cast_weights(self, input):
+ if self.weight is not None:
+ weight, bias = cast_bias_weight(self, input)
+ else:
+ weight = None
+ return comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
+ # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
+
+ def forward(self, *args, **kwargs):
+ if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -160,7 +194,7 @@ class disable_weight_init:
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
- if self.comfy_cast_weights:
+ if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -181,7 +215,7 @@ class disable_weight_init:
output_padding, self.groups, self.dilation)
def forward(self, *args, **kwargs):
- if self.comfy_cast_weights:
+ if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
@@ -199,7 +233,7 @@ class disable_weight_init:
return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)
def forward(self, *args, **kwargs):
- if self.comfy_cast_weights:
+ if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
if "out_dtype" in kwargs:
@@ -241,6 +275,9 @@ class manual_cast(disable_weight_init):
class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
comfy_cast_weights = True
+ class RMSNorm(disable_weight_init.RMSNorm):
+ comfy_cast_weights = True
+
class Embedding(disable_weight_init.Embedding):
comfy_cast_weights = True
@@ -271,10 +308,10 @@ def fp8_linear(self, input):
if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
input = torch.clamp(input, min=-448, max=448, out=input)
- input = input.reshape(-1, input_shape[2]).to(dtype)
+ input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
else:
scale_input = scale_input.to(input.device)
- input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
+ input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()
if bias is not None:
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
@@ -307,6 +344,7 @@ class fp8_ops(manual_cast):
return torch.nn.functional.linear(input, weight, bias)
def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
+ logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))
class scaled_fp8_op(manual_cast):
class Linear(manual_cast.Linear):
def __init__(self, *args, **kwargs):
@@ -354,14 +392,46 @@ def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None
return scaled_fp8_op
+CUBLAS_IS_AVAILABLE = False
+try:
+ from cublas_ops import CublasLinear
+ CUBLAS_IS_AVAILABLE = True
+except ImportError:
+ pass
+
+if CUBLAS_IS_AVAILABLE:
+ class cublas_ops(disable_weight_init):
+ class Linear(CublasLinear, disable_weight_init.Linear):
+ def reset_parameters(self):
+ return None
+
+ def forward_comfy_cast_weights(self, input):
+ return super().forward(input)
+
+ def forward(self, *args, **kwargs):
+ return super().forward(*args, **kwargs)
+
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
if scaled_fp8 is not None:
- return scaled_fp8_ops(fp8_matrix_mult=fp8_compute, scale_input=True, override_dtype=scaled_fp8)
+ return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)
- if fp8_compute and (fp8_optimizations or args.fast) and not disable_fast_fp8:
+ if (
+ fp8_compute and
+ (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
+ not disable_fast_fp8
+ ):
return fp8_ops
+ if (
+ PerformanceFeature.CublasOps in args.fast and
+ CUBLAS_IS_AVAILABLE and
+ weight_dtype == torch.float16 and
+ (compute_dtype == torch.float16 or compute_dtype is None)
+ ):
+ logging.info("Using cublas ops")
+ return cublas_ops
+
if compute_dtype is None or weight_dtype == compute_dtype:
return disable_weight_init
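`pick_operations` now resolves in a fixed priority: scaled fp8 (with fp8 matrix mult only when the checkpoint opts in), fast fp8 matmul, cublas fp16 ops, plain weights when dtypes already match, and manual casting as the fallback. A condensed sketch of that ladder with the flags simplified to plain values:

```python
def pick_ops_sketch(fp8_compute, fp8_optimizations, scaled_fp8, cublas_ok,
                    weight_dtype, compute_dtype):
    if scaled_fp8 is not None:
        return "scaled_fp8_ops"        # matrix mult gated on fp8_compute and opt-in
    if fp8_compute and fp8_optimizations:
        return "fp8_ops"
    if cublas_ok and weight_dtype == "fp16" and compute_dtype in ("fp16", None):
        return "cublas_ops"
    if compute_dtype is None or weight_dtype == compute_dtype:
        return "disable_weight_init"   # weights usable as-is, no casting layer
    return "manual_cast"

print(pick_ops_sketch(True, False, None, False, "fp16", "fp16"))  # disable_weight_init
```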
diff --git a/comfy/patcher_extension.py b/comfy/patcher_extension.py
index 859758244..965958f4c 100644
--- a/comfy/patcher_extension.py
+++ b/comfy/patcher_extension.py
@@ -48,6 +48,7 @@ def get_all_callbacks(call_type: str, transformer_options: dict, is_model_option
class WrappersMP:
OUTER_SAMPLE = "outer_sample"
+ PREPARE_SAMPLING = "prepare_sampling"
SAMPLER_SAMPLE = "sampler_sample"
CALC_COND_BATCH = "calc_cond_batch"
APPLY_MODEL = "apply_model"
diff --git a/comfy/rmsnorm.py b/comfy/rmsnorm.py
new file mode 100644
index 000000000..66ae8321d
--- /dev/null
+++ b/comfy/rmsnorm.py
@@ -0,0 +1,55 @@
+import torch
+import comfy.model_management
+import numbers
+
+RMSNorm = None
+
+try:
+ rms_norm_torch = torch.nn.functional.rms_norm
+ RMSNorm = torch.nn.RMSNorm
+except:
+ rms_norm_torch = None
+
+
+def rms_norm(x, weight=None, eps=1e-6):
+ if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
+ if weight is None:
+ return rms_norm_torch(x, (x.shape[-1],), eps=eps)
+ else:
+ return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+ else:
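+ # Manual fallback: x * rsqrt(mean(x^2, dim=-1) + eps), then the elementwise weight.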
+ r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+ if weight is None:
+ return r
+ else:
+ return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
+
+
+if RMSNorm is None:
+ class RMSNorm(torch.nn.Module):
+ def __init__(
+ self,
+ normalized_shape,
+ eps=1e-6,
+ elementwise_affine=True,
+ device=None,
+ dtype=None,
+ ):
+ factory_kwargs = {"device": device, "dtype": dtype}
+ super().__init__()
+ if isinstance(normalized_shape, numbers.Integral):
+ # mypy error: incompatible types in assignment
+ normalized_shape = (normalized_shape,) # type: ignore[assignment]
+ self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
+ self.eps = eps
+ self.elementwise_affine = elementwise_affine
+ if self.elementwise_affine:
+ self.weight = torch.nn.Parameter(
+ torch.empty(self.normalized_shape, **factory_kwargs)
+ )
+ else:
+ self.register_parameter("weight", None)
+ self.bias = None
+
+ def forward(self, x):
+ return rms_norm(x, self.weight, self.eps)
diff --git a/comfy/sampler_helpers.py b/comfy/sampler_helpers.py
index b70e5e636..96a3040a1 100644
--- a/comfy/sampler_helpers.py
+++ b/comfy/sampler_helpers.py
@@ -58,7 +58,6 @@ def convert_cond(cond):
temp = c[1].copy()
model_conds = temp.get("model_conds", {})
if c[0] is not None:
- model_conds["c_crossattn"] = comfy.conds.CONDCrossAttn(c[0]) #TODO: remove
temp["cross_attn"] = c[0]
temp["model_conds"] = model_conds
temp["uuid"] = uuid.uuid4()
@@ -107,6 +106,13 @@ def cleanup_additional_models(models):
def prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
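+ # Route through any PREPARE_SAMPLING wrappers registered in model_options,
+ # with _prepare_sampling as the innermost call.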
+ executor = comfy.patcher_extension.WrapperExecutor.new_executor(
+ _prepare_sampling,
+ comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.PREPARE_SAMPLING, model_options, is_model_options=True)
+ )
+ return executor.execute(model, noise_shape, conds, model_options=model_options)
+
+def _prepare_sampling(model: ModelPatcher, noise_shape, conds, model_options=None):
real_model: BaseModel = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += get_additional_models_from_model_options(model_options)
diff --git a/comfy/samplers.py b/comfy/samplers.py
index fa176c6de..67ae09a25 100644
--- a/comfy/samplers.py
+++ b/comfy/samplers.py
@@ -12,7 +12,6 @@ import collections
from comfy import model_management
import math
import logging
-import comfy.samplers
import comfy.sampler_helpers
import comfy.model_patcher
import comfy.patcher_extension
@@ -20,6 +19,12 @@ import comfy.hooks
import scipy.stats
import numpy
+
+def add_area_dims(area, num_dims):
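+ # Pad an area spec [sizes..., offsets...] with full-extent leading dims
+ # (size 2**31, offset 0) until it describes num_dims dimensions.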
+ while (len(area) // 2) < num_dims:
+ area = [2147483648] + area[:len(area) // 2] + [0] + area[len(area) // 2:]
+ return area
+
def get_area_and_mult(conds, x_in, timestep_in):
dims = tuple(x_in.shape[2:])
area = None
@@ -35,6 +40,10 @@ def get_area_and_mult(conds, x_in, timestep_in):
return None
if 'area' in conds:
area = list(conds['area'])
+ area = add_area_dims(area, len(dims))
+ if (len(area) // 2) > len(dims):
+ area = area[:len(dims)] + area[len(area) // 2:(len(area) // 2) + len(dims)]
+
if 'strength' in conds:
strength = conds['strength']
@@ -51,7 +60,7 @@ def get_area_and_mult(conds, x_in, timestep_in):
if "mask_strength" in conds:
mask_strength = conds["mask_strength"]
mask = conds['mask']
- assert(mask.shape[1:] == x_in.shape[2:])
+ assert (mask.shape[1:] == x_in.shape[2:])
mask = mask[:input_x.shape[0]]
if area is not None:
@@ -65,16 +74,17 @@ def get_area_and_mult(conds, x_in, timestep_in):
mult = mask * strength
if 'mask' not in conds and area is not None:
- rr = 8
+ fuzz = 8
for i in range(len(dims)):
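+ # Clamp the feather width to a quarter of this dimension so small areas keep some weight.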
+ rr = min(fuzz, mult.shape[2 + i] // 4)
if area[len(dims) + i] != 0:
for t in range(rr):
m = mult.narrow(i + 2, t, 1)
- m *= ((1.0/rr) * (t + 1))
+ m *= ((1.0 / rr) * (t + 1))
if (area[i] + area[len(dims) + i]) < x_in.shape[i + 2]:
for t in range(rr):
m = mult.narrow(i + 2, area[i] - 1 - t, 1)
- m *= ((1.0/rr) * (t + 1))
+ m *= ((1.0 / rr) * (t + 1))
conditioning = {}
model_conds = conds["model_conds"]
@@ -178,7 +188,7 @@ def finalize_default_conds(model: 'BaseModel', hooked_to_run: dict[comfy.hooks.H
cond = default_conds[i]
for x in cond:
# do get_area_and_mult to get all the expected values
- p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
+ p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
# replace p's mult with calculated mult
@@ -215,7 +225,7 @@ def _calc_cond_batch(model: 'BaseModel', conds: list[list[dict]], x_in: torch.Te
default_c.append(x)
has_default_conds = True
continue
- p = comfy.samplers.get_area_and_mult(x, x_in, timestep)
+ p = get_area_and_mult(x, x_in, timestep)
if p is None:
continue
if p.hooks is not None:
@@ -376,7 +386,7 @@ class KSamplerX0Inpaint:
if "denoise_mask_function" in model_options:
denoise_mask = model_options["denoise_mask_function"](sigma, denoise_mask, extra_options={"model": self.inner_model, "sigmas": self.sigmas})
latent_mask = 1. - denoise_mask
- x = x * denoise_mask + self.inner_model.inner_model.model_sampling.noise_scaling(sigma.reshape([sigma.shape[0]] + [1] * (len(self.noise.shape) - 1)), self.noise, self.latent_image) * latent_mask
+ x = x * denoise_mask + self.inner_model.inner_model.scale_latent_inpaint(x=x, sigma=sigma, noise=self.noise, latent_image=self.latent_image) * latent_mask
out = self.inner_model(x, sigma, model_options=model_options, seed=seed)
if denoise_mask is not None:
out = out * denoise_mask + self.latent_image * latent_mask
@@ -549,25 +559,37 @@ def resolve_areas_and_cond_masks(conditions, h, w, device):
logging.warning("WARNING: The comfy.samplers.resolve_areas_and_cond_masks function is deprecated please use the resolve_areas_and_cond_masks_multidim one instead.")
return resolve_areas_and_cond_masks_multidim(conditions, [h, w], device)
-def create_cond_with_same_area_if_none(conds, c): #TODO: handle dim != 2
+def create_cond_with_same_area_if_none(conds, c):
if 'area' not in c:
return
+ def area_inside(a, area_cmp):
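+ # True when area `a` lies fully inside `area_cmp`; both are
+ # [sizes..., offsets...] specs padded to a common number of dims.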
+ a = add_area_dims(a, len(area_cmp) // 2)
+ area_cmp = add_area_dims(area_cmp, len(a) // 2)
+
+ a_l = len(a) // 2
+ area_cmp_l = len(area_cmp) // 2
+ for i in range(min(a_l, area_cmp_l)):
+ if a[a_l + i] < area_cmp[area_cmp_l + i]:
+ return False
+ for i in range(min(a_l, area_cmp_l)):
+ if (a[i] + a[a_l + i]) > (area_cmp[i] + area_cmp[area_cmp_l + i]):
+ return False
+ return True
+
c_area = c['area']
smallest = None
for x in conds:
if 'area' in x:
a = x['area']
- if c_area[2] >= a[2] and c_area[3] >= a[3]:
- if a[0] + a[2] >= c_area[0] + c_area[2]:
- if a[1] + a[3] >= c_area[1] + c_area[3]:
- if smallest is None:
- smallest = x
- elif 'area' not in smallest:
- smallest = x
- else:
- if smallest['area'][0] * smallest['area'][1] > a[0] * a[1]:
- smallest = x
+ if area_inside(c_area, a):
+ if smallest is None:
+ smallest = x
+ elif 'area' not in smallest:
+ smallest = x
+ else:
+ if math.prod(smallest['area'][:len(smallest['area']) // 2]) > math.prod(a[:len(a) // 2]):
+ smallest = x
else:
if smallest is None:
smallest = x
@@ -687,7 +709,8 @@ class Sampler:
KSAMPLER_NAMES = ["euler", "euler_cfg_pp", "euler_ancestral", "euler_ancestral_cfg_pp", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_2s_ancestral_cfg_pp", "dpmpp_sde", "dpmpp_sde_gpu",
"dpmpp_2m", "dpmpp_2m_cfg_pp", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm",
- "ipndm", "ipndm_v", "deis", "res_multistep"]
+ "ipndm", "ipndm_v", "deis", "res_multistep", "res_multistep_cfg_pp", "res_multistep_ancestral", "res_multistep_ancestral_cfg_pp",
+ "gradient_estimation", "gradient_estimation_cfg_pp", "er_sde", "seeds_2", "seeds_3"]
class KSAMPLER(Sampler):
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
diff --git a/comfy/sd.py b/comfy/sd.py
index 7db1c2d60..e98a3aa87 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -1,4 +1,5 @@
from __future__ import annotations
+import json
import torch
from enum import Enum
import logging
@@ -12,6 +13,9 @@ from .ldm.audio.autoencoder import AudioOobleckVAE
import comfy.ldm.genmo.vae.model
import comfy.ldm.lightricks.vae.causal_video_autoencoder
import comfy.ldm.cosmos.vae
+import comfy.ldm.wan.vae
+import comfy.ldm.hunyuan3d.vae
+import comfy.ldm.ace.vae.music_dcae_pipeline
import yaml
import math
@@ -36,6 +40,10 @@ import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
+import comfy.text_encoders.lumina2
+import comfy.text_encoders.wan
+import comfy.text_encoders.hidream
+import comfy.text_encoders.ace
import comfy.model_patcher
import comfy.lora
@@ -114,6 +122,7 @@ class CLIP:
self.layer_idx = None
self.use_clip_schedule = False
logging.info("CLIP/text encoder model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
+ self.tokenizer_options = {}
def clone(self):
n = CLIP(no_init=True)
@@ -121,6 +130,7 @@ class CLIP:
n.cond_stage_model = self.cond_stage_model
n.tokenizer = self.tokenizer
n.layer_idx = self.layer_idx
+ n.tokenizer_options = self.tokenizer_options.copy()
n.use_clip_schedule = self.use_clip_schedule
n.apply_hooks_to_conds = self.apply_hooks_to_conds
return n
@@ -128,11 +138,19 @@ class CLIP:
def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
return self.patcher.add_patches(patches, strength_patch, strength_model)
+ def set_tokenizer_option(self, option_name, value):
+ self.tokenizer_options[option_name] = value
+
def clip_layer(self, layer_idx):
self.layer_idx = layer_idx
- def tokenize(self, text, return_word_ids=False):
- return self.tokenizer.tokenize_with_weights(text, return_word_ids)
+ def tokenize(self, text, return_word_ids=False, **kwargs):
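+ # Per-call tokenizer_options take precedence over options set on this CLIP object.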
+ tokenizer_options = kwargs.get("tokenizer_options", {})
+ if len(self.tokenizer_options) > 0:
+ tokenizer_options = {**self.tokenizer_options, **tokenizer_options}
+ if len(tokenizer_options) > 0:
+ kwargs["tokenizer_options"] = tokenizer_options
+ return self.tokenizer.tokenize_with_weights(text, return_word_ids, **kwargs)
def add_hooks_to_dict(self, pooled_dict: dict[str]):
if self.apply_hooks_to_conds:
@@ -246,7 +264,7 @@ class CLIP:
return self.patcher.get_key_patches()
class VAE:
- def __init__(self, sd=None, device=None, config=None, dtype=None):
+ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
sd = diffusers_convert.convert_vae_state_dict(sd)
@@ -260,9 +278,11 @@ class VAE:
self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
self.working_dtypes = [torch.bfloat16, torch.float32]
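+ # When True the VAE is always fully loaded instead of partially offloaded (needed by the audio VAEs).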
+ self.disable_offload = False
self.downscale_index_formula = None
self.upscale_index_formula = None
+ self.extra_1d_channel = None
if config is None:
if "decoder.mid.block_1.mix_factor" in sd:
@@ -332,6 +352,7 @@ class VAE:
self.process_output = lambda audio: audio
self.process_input = lambda audio: audio
self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+ self.disable_offload = True
elif "blocks.2.blocks.3.stack.5.weight" in sd or "decoder.blocks.2.blocks.3.stack.5.weight" in sd or "layers.4.layers.1.attn_block.attn.qkv.weight" in sd or "encoder.layers.4.layers.1.attn_block.attn.qkv.weight" in sd: #genmo mochi vae
if "blocks.2.blocks.3.stack.5.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"": "decoder."})
@@ -354,7 +375,12 @@ class VAE:
version = 0
elif tensor_conv1.shape[0] == 1024:
version = 1
- self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version)
+ if "encoder.down_blocks.1.conv.conv.bias" in sd:
+ version = 2
+ vae_config = None
+ if metadata is not None and "config" in metadata:
+ vae_config = json.loads(metadata["config"]).get("vae", None)
+ self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
self.latent_channels = 128
self.latent_dim = 3
self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
@@ -388,9 +414,46 @@ class VAE:
ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
#TODO: these values are a bit off because this is not a standard VAE
- self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
- self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
+ self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float32]
+ elif "decoder.middle.0.residual.0.gamma" in sd:
+ self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+ self.upscale_index_formula = (4, 8, 8)
+ self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+ self.downscale_index_formula = (4, 8, 8)
+ self.latent_dim = 3
+ self.latent_channels = 16
+ ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0}
+ self.first_stage_model = comfy.ldm.wan.vae.WanVAE(**ddconfig)
+ self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+ self.memory_used_encode = lambda shape, dtype: 6000 * shape[3] * shape[4] * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: 7000 * shape[3] * shape[4] * (8 * 8) * model_management.dtype_size(dtype)
+ elif "geo_decoder.cross_attn_decoder.ln_1.bias" in sd:
+ self.latent_dim = 1
+ ln_post = "geo_decoder.ln_post.weight" in sd
+ inner_size = sd["geo_decoder.output_proj.weight"].shape[1]
+ downsample_ratio = sd["post_kl.weight"].shape[0] // inner_size
+ mlp_expand = sd["geo_decoder.cross_attn_decoder.mlp.c_fc.weight"].shape[0] // inner_size
+ self.memory_used_encode = lambda shape, dtype: (1000 * shape[2]) * model_management.dtype_size(dtype) # TODO
+ self.memory_used_decode = lambda shape, dtype: (1024 * 1024 * 1024 * 2.0) * model_management.dtype_size(dtype) # TODO
+ ddconfig = {"embed_dim": 64, "num_freqs": 8, "include_pi": False, "heads": 16, "width": 1024, "num_decoder_layers": 16, "qkv_bias": False, "qk_norm": True, "geo_decoder_mlp_expand_ratio": mlp_expand, "geo_decoder_downsample_ratio": downsample_ratio, "geo_decoder_ln_post": ln_post}
+ self.first_stage_model = comfy.ldm.hunyuan3d.vae.ShapeVAE(**ddconfig)
+ self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+ elif "vocoder.backbone.channel_layers.0.0.bias" in sd: #Ace Step Audio
+ self.first_stage_model = comfy.ldm.ace.vae.music_dcae_pipeline.MusicDCAE(source_sample_rate=44100)
+ self.memory_used_encode = lambda shape, dtype: (shape[2] * 330) * model_management.dtype_size(dtype)
+ self.memory_used_decode = lambda shape, dtype: (shape[2] * shape[3] * 87000) * model_management.dtype_size(dtype)
+ self.latent_channels = 8
+ self.output_channels = 2
+ self.upscale_ratio = 4096
+ self.downscale_ratio = 4096
+ self.latent_dim = 2
+ self.process_output = lambda audio: audio
+ self.process_input = lambda audio: audio
+ self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+ self.disable_offload = True
+ self.extra_1d_channel = 16
else:
logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
self.first_stage_model = None
@@ -419,6 +482,10 @@ class VAE:
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
+ def throw_exception_if_invalid(self):
+ if self.first_stage_model is None:
+ raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.")
+
def vae_encode_crop_pixels(self, pixels):
downscale_ratio = self.spacial_compression_encode()
@@ -445,7 +512,13 @@ class VAE:
return output
def decode_tiled_1d(self, samples, tile_x=128, overlap=32):
- decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+ if samples.ndim == 3:
+ decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
+ else:
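+ # Latents with an extra channel dim (e.g. ACE Step audio, [B, C, extra, T]) are
+ # flattened to [B, C*extra, T] for 1D tiling; decode_fn restores the 4D shape per tile.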
+ og_shape = samples.shape
+ samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
+ decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
+
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
@@ -465,33 +538,49 @@ class VAE:
samples /= 3.0
return samples
- def encode_tiled_1d(self, samples, tile_x=128 * 2048, overlap=32 * 2048):
- encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
- return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=(1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device)
+ def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
+ if self.latent_dim == 1:
+ encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
+ out_channels = self.latent_channels
+ upscale_amount = 1 / self.downscale_ratio
+ else:
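+ # Fold the extra channel dim into the channel axis for 1D tiling, then
+ # unfold it again after encoding.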
+ extra_channel_size = self.extra_1d_channel
+ out_channels = self.latent_channels * extra_channel_size
+ tile_x = tile_x // extra_channel_size
+ overlap = overlap // extra_channel_size
+ upscale_amount = 1 / self.downscale_ratio
+ encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
+
+ out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
+ if self.latent_dim == 1:
+ return out
+ else:
+ return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
- def decode(self, samples_in):
+ def decode(self, samples_in, vae_options={}):
+ self.throw_exception_if_invalid()
pixel_samples = None
try:
memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype)
- model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+ model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number)
for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
- out = self.process_output(self.first_stage_model.decode(samples).to(self.output_device).float())
+ out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
dims = samples_in.ndim - 2
- if dims == 1:
+ if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in)
elif dims == 2:
pixel_samples = self.decode_tiled_(samples_in)
@@ -504,8 +593,9 @@ class VAE:
return pixel_samples
def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
+ self.throw_exception_if_invalid()
memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) #TODO: calculate mem required for tile
- model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+ model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
dims = samples.ndim - 2
args = {}
if tile_x is not None:
@@ -532,13 +622,14 @@ class VAE:
return output.movedim(1, -1)
def encode(self, pixel_samples):
+ self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
pixel_samples = pixel_samples.movedim(-1, 1)
- if self.latent_dim == 3:
+ if self.latent_dim == 3 and pixel_samples.ndim < 5:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
try:
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype)
- model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+ model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
free_memory = model_management.get_free_memory(self.device)
batch_number = int(free_memory / max(1, memory_used))
batch_number = max(1, batch_number)
@@ -556,7 +647,7 @@ class VAE:
tile = 256
overlap = tile // 4
samples = self.encode_tiled_3d(pixel_samples, tile_x=tile, tile_y=tile, overlap=(1, overlap, overlap))
- elif self.latent_dim == 1:
+ elif self.latent_dim == 1 or self.extra_1d_channel is not None:
samples = self.encode_tiled_1d(pixel_samples)
else:
samples = self.encode_tiled_(pixel_samples)
@@ -564,6 +655,7 @@ class VAE:
return samples
def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
+ self.throw_exception_if_invalid()
pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
dims = self.latent_dim
pixel_samples = pixel_samples.movedim(-1, 1)
@@ -571,7 +663,7 @@ class VAE:
pixel_samples = pixel_samples.movedim(1, 0).unsqueeze(0)
memory_used = self.memory_used_encode(pixel_samples.shape, self.vae_dtype) # TODO: calculate mem required for tile
- model_management.load_models_gpu([self.patcher], memory_required=memory_used)
+ model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload)
args = {}
if tile_x is not None:
@@ -657,6 +749,11 @@ class CLIPType(Enum):
HUNYUAN_VIDEO = 9
PIXART = 10
COSMOS = 11
+ LUMINA2 = 12
+ WAN = 13
+ HIDREAM = 14
+ CHROMA = 15
+ ACE = 16
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
@@ -675,6 +772,7 @@ class TEModel(Enum):
T5_BASE = 6
LLAMA3_8 = 7
T5_XXL_OLD = 8
+ GEMMA_2_2B = 9
def detect_te_model(sd):
if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
@@ -693,6 +791,8 @@ def detect_te_model(sd):
return TEModel.T5_XXL_OLD
if "encoder.block.0.layer.0.SelfAttention.k.weight" in sd:
return TEModel.T5_BASE
+ if 'model.layers.0.post_feedforward_layernorm.weight' in sd:
+ return TEModel.GEMMA_2_2B
if "model.layers.0.post_attention_layernorm.weight" in sd:
return TEModel.LLAMA3_8
return None
@@ -730,6 +830,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
if "text_projection" in clip_data[i]:
clip_data[i]["text_projection.weight"] = clip_data[i]["text_projection"].transpose(0, 1) #old models saved with the CLIPSave node
+ tokenizer_data = {}
clip_target = EmptyClass()
clip_target.params = {}
if len(clip_data) == 1:
@@ -741,6 +842,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+ elif clip_type == CLIPType.HIDREAM:
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+ clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sdxl_clip.SDXLRefinerClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
@@ -754,9 +858,17 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer
- elif clip_type == CLIPType.PIXART:
+ elif clip_type == CLIPType.PIXART or clip_type == CLIPType.CHROMA:
clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer
+ elif clip_type == CLIPType.WAN:
+ clip_target.clip = comfy.text_encoders.wan.te(**t5xxl_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.wan.WanT5Tokenizer
+ tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+ elif clip_type == CLIPType.HIDREAM:
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data),
+ clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None)
+ clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else: #CLIPType.MOCHI
clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer
@@ -767,12 +879,29 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.aura_t5.AuraT5Model
clip_target.tokenizer = comfy.text_encoders.aura_t5.AuraT5Tokenizer
elif te_model == TEModel.T5_BASE:
- clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
- clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
+ if clip_type == CLIPType.ACE or "spiece_model" in clip_data[0]:
+ clip_target.clip = comfy.text_encoders.ace.AceT5Model
+ clip_target.tokenizer = comfy.text_encoders.ace.AceT5Tokenizer
+ tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+ else:
+ clip_target.clip = comfy.text_encoders.sa_t5.SAT5Model
+ clip_target.tokenizer = comfy.text_encoders.sa_t5.SAT5Tokenizer
+ elif te_model == TEModel.GEMMA_2_2B:
+ clip_target.clip = comfy.text_encoders.lumina2.te(**llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.lumina2.LuminaTokenizer
+ tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
+ elif te_model == TEModel.LLAMA3_8:
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data),
+ clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None)
+ clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
+ # clip_l
if clip_type == CLIPType.SD3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False)
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+ elif clip_type == CLIPType.HIDREAM:
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None)
+ clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sd1_clip.SD1ClipModel
clip_target.tokenizer = sd1_clip.SD1Tokenizer
@@ -790,15 +919,35 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
elif clip_type == CLIPType.HUNYUAN_VIDEO:
clip_target.clip = comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer
+ elif clip_type == CLIPType.HIDREAM:
+ # Detect
+ hidream_dualclip_classes = []
+ for hidream_te in clip_data:
+ te_model = detect_te_model(hidream_te)
+ hidream_dualclip_classes.append(te_model)
+
+ clip_l = TEModel.CLIP_L in hidream_dualclip_classes
+ clip_g = TEModel.CLIP_G in hidream_dualclip_classes
+ t5 = TEModel.T5_XXL in hidream_dualclip_classes
+ llama = TEModel.LLAMA3_8 in hidream_dualclip_classes
+
+ # Initialize t5xxl_detect and llama_detect kwargs if needed
+ t5_kwargs = t5xxl_detect(clip_data) if t5 else {}
+ llama_kwargs = llama_detect(clip_data) if llama else {}
+
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, **t5_kwargs, **llama_kwargs)
+ clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
else:
clip_target.clip = sdxl_clip.SDXLClipModel
clip_target.tokenizer = sdxl_clip.SDXLTokenizer
elif len(clip_data) == 3:
clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(**t5xxl_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer
+ elif len(clip_data) == 4:
+ clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), **llama_detect(clip_data))
+ clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer
parameters = 0
- tokenizer_data = {}
for c in clip_data:
parameters += comfy.utils.calculate_parameters(c)
tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options)
@@ -845,13 +994,13 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
return (model, clip, vae)
def load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
- sd = comfy.utils.load_torch_file(ckpt_path)
- out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options)
+ sd, metadata = comfy.utils.load_torch_file(ckpt_path, return_metadata=True)
+ out = load_state_dict_guess_config(sd, output_vae, output_clip, output_clipvision, embedding_directory, output_model, model_options, te_model_options=te_model_options, metadata=metadata)
if out is None:
raise RuntimeError("ERROR: Could not detect model type of: {}".format(ckpt_path))
return out
-def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}):
+def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_clipvision=False, embedding_directory=None, output_model=True, model_options={}, te_model_options={}, metadata=None):
clip = None
clipvision = None
vae = None
@@ -863,19 +1012,24 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix)
load_device = model_management.get_torch_device()
- model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix)
+ model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata)
if model_config is None:
- return None
+ logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.")
+ diffusion_model = load_diffusion_model_state_dict(sd, model_options={})
+ if diffusion_model is None:
+ return None
+ return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used
+
unet_weight_dtype = list(model_config.supported_inference_dtypes)
- if weight_dtype is not None and model_config.scaled_fp8 is None:
- unet_weight_dtype.append(weight_dtype)
+ if model_config.scaled_fp8 is not None:
+ weight_dtype = None
model_config.custom_operations = model_options.get("custom_operations", None)
unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None))
if unet_dtype is None:
- unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+ unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes)
model_config.set_inference_dtype(unet_dtype, manual_cast_dtype)
@@ -892,7 +1046,7 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c
if output_vae:
vae_sd = comfy.utils.state_dict_prefix_replace(sd, {k: "" for k in model_config.vae_key_prefix}, filter_keys=True)
vae_sd = model_config.process_vae_state_dict(vae_sd)
- vae = VAE(sd=vae_sd)
+ vae = VAE(sd=vae_sd, metadata=metadata)
if output_clip:
clip_target = model_config.clip_target(state_dict=sd)
@@ -966,11 +1120,11 @@ def load_diffusion_model_state_dict(sd, model_options={}): #load unet in diffuse
offload_device = model_management.unet_offload_device()
unet_weight_dtype = list(model_config.supported_inference_dtypes)
- if weight_dtype is not None and model_config.scaled_fp8 is None:
- unet_weight_dtype.append(weight_dtype)
+ if model_config.scaled_fp8 is not None:
+ weight_dtype = None
if dtype is None:
- unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype)
+ unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype)
else:
unet_dtype = dtype
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 95d41c30f..ac61babe9 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -82,7 +82,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
LAYERS = [
"last",
"pooled",
- "hidden"
+ "hidden",
+ "all"
]
def __init__(self, device="cpu", max_length=77,
freeze=True, layer="last", layer_idx=None, textmodel_json_config=None, dtype=None, model_class=comfy.clip_model.CLIPTextModel,
@@ -93,6 +94,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
if textmodel_json_config is None:
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_clip_config.json")
+ if "model_name" not in model_options:
+ model_options = {**model_options, "model_name": "clip_l"}
if isinstance(textmodel_json_config, dict):
config = textmodel_json_config
@@ -100,6 +103,10 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
with open(textmodel_json_config) as f:
config = json.load(f)
+ te_model_options = model_options.get("{}_model_config".format(model_options.get("model_name", "")), {})
+ for k, v in te_model_options.items():
+ config[k] = v
+
operations = model_options.get("custom_operations", None)
scaled_fp8 = None
@@ -147,7 +154,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
def set_clip_options(self, options):
layer_idx = options.get("layer", self.layer_idx)
self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled)
- if layer_idx is None or abs(layer_idx) > self.num_layers:
+ if self.layer == "all":
+ pass
+ elif layer_idx is None or abs(layer_idx) > self.num_layers:
self.layer = "last"
else:
self.layer = "hidden"
@@ -158,71 +167,98 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
self.layer_idx = self.options_default[1]
self.return_projected_pooled = self.options_default[2]
- def set_up_textual_embeddings(self, tokens, current_embeds):
- out_tokens = []
- next_new_token = token_dict_size = current_embeds.weight.shape[0]
- embedding_weights = []
+ def process_tokens(self, tokens, device):
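+ # Build the input embeddings and attention mask directly from mixed
+ # token/embedding lists; pre-computed embeddings are spliced in without
+ # modifying the embedding table.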
+ end_token = self.special_tokens.get("end", None)
+ if end_token is None:
+ cmp_token = self.special_tokens.get("pad", -1)
+ else:
+ cmp_token = end_token
+
+ embeds_out = []
+ attention_masks = []
+ num_tokens = []
for x in tokens:
+ attention_mask = []
tokens_temp = []
+ other_embeds = []
+ eos = False
+ index = 0
for y in x:
if isinstance(y, numbers.Integral):
- tokens_temp += [int(y)]
- else:
- if y.shape[0] == current_embeds.weight.shape[1]:
- embedding_weights += [y]
- tokens_temp += [next_new_token]
- next_new_token += 1
+ if eos:
+ attention_mask.append(0)
else:
- logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(y.shape[0], current_embeds.weight.shape[1]))
- while len(tokens_temp) < len(x):
- tokens_temp += [self.special_tokens["pad"]]
- out_tokens += [tokens_temp]
+ attention_mask.append(1)
+ token = int(y)
+ tokens_temp += [token]
+ if not eos and token == cmp_token:
+ if end_token is None:
+ attention_mask[-1] = 0
+ eos = True
+ else:
+ other_embeds.append((index, y))
+ index += 1
- n = token_dict_size
- if len(embedding_weights) > 0:
- new_embedding = self.operations.Embedding(next_new_token + 1, current_embeds.weight.shape[1], device=current_embeds.weight.device, dtype=current_embeds.weight.dtype)
- new_embedding.weight[:token_dict_size] = current_embeds.weight
- for x in embedding_weights:
- new_embedding.weight[n] = x
- n += 1
- self.transformer.set_input_embeddings(new_embedding)
+ tokens_embed = torch.tensor([tokens_temp], device=device, dtype=torch.long)
+ tokens_embed = self.transformer.get_input_embeddings()(tokens_embed, out_dtype=torch.float32)
+ index = 0
+ pad_extra = 0
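+ # Splice pre-computed embeddings (e.g. textual inversion) into the token embedding sequence.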
+ for o in other_embeds:
+ emb = o[1]
+ if torch.is_tensor(emb):
+ emb = {"type": "embedding", "data": emb}
- processed_tokens = []
- for x in out_tokens:
- processed_tokens += [list(map(lambda a: n if a == -1 else a, x))] #The EOS token should always be the largest one
+ emb_type = emb.get("type", None)
+ if emb_type == "embedding":
+ emb = emb.get("data", None)
+ else:
+ if hasattr(self.transformer, "preprocess_embed"):
+ emb = self.transformer.preprocess_embed(emb, device=device)
+ else:
+ emb = None
- return processed_tokens
+ if emb is None:
+ index += -1
+ continue
+
+ ind = index + o[0]
+ emb = emb.view(1, -1, emb.shape[-1]).to(device=device, dtype=torch.float32)
+ emb_shape = emb.shape[1]
+ if emb.shape[-1] == tokens_embed.shape[-1]:
+ tokens_embed = torch.cat([tokens_embed[:, :ind], emb, tokens_embed[:, ind:]], dim=1)
+ attention_mask = attention_mask[:ind] + [1] * emb_shape + attention_mask[ind:]
+ index += emb_shape - 1
+ else:
+ index += -1
+ pad_extra += emb_shape
+ logging.warning("WARNING: shape mismatch when trying to apply embedding, embedding will be ignored {} != {}".format(emb.shape[-1], tokens_embed.shape[-1]))
+
+ if pad_extra > 0:
+ padd_embed = self.transformer.get_input_embeddings()(torch.tensor([[self.special_tokens["pad"]] * pad_extra], device=device, dtype=torch.long), out_dtype=torch.float32)
+ tokens_embed = torch.cat([tokens_embed, padd_embed], dim=1)
+ attention_mask = attention_mask + [0] * pad_extra
+
+ embeds_out.append(tokens_embed)
+ attention_masks.append(attention_mask)
+ num_tokens.append(sum(attention_mask))
+
+ return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens
def forward(self, tokens):
- backup_embeds = self.transformer.get_input_embeddings()
- device = backup_embeds.weight.device
- tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
- tokens = torch.LongTensor(tokens).to(device)
-
- attention_mask = None
- if self.enable_attention_masks or self.zero_out_masked or self.return_attention_masks:
- attention_mask = torch.zeros_like(tokens)
- end_token = self.special_tokens.get("end", None)
- if end_token is None:
- cmp_token = self.special_tokens.get("pad", -1)
- else:
- cmp_token = end_token
-
- for x in range(attention_mask.shape[0]):
- for y in range(attention_mask.shape[1]):
- attention_mask[x, y] = 1
- if tokens[x, y] == cmp_token:
- if end_token is None:
- attention_mask[x, y] = 0
- break
+ device = self.transformer.get_input_embeddings().weight.device
+ embeds, attention_mask, num_tokens = self.process_tokens(tokens, device)
attention_mask_model = None
if self.enable_attention_masks:
attention_mask_model = attention_mask
- outputs = self.transformer(tokens, attention_mask_model, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
- self.transformer.set_input_embeddings(backup_embeds)
+ if self.layer == "all":
+ intermediate_output = "all"
+ else:
+ intermediate_output = self.layer_idx
+
+ outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32)
if self.layer == "last":
z = outputs[0].float()
@@ -388,13 +424,10 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
import safetensors.torch
embed = safetensors.torch.load_file(embed_path, device="cpu")
else:
- if 'weights_only' in torch.load.__code__.co_varnames:
- try:
- embed = torch.load(embed_path, weights_only=True, map_location="cpu")
- except:
- embed_out = safe_load_embed_zip(embed_path)
- else:
- embed = torch.load(embed_path, map_location="cpu")
+ try:
+ embed = torch.load(embed_path, weights_only=True, map_location="cpu")
+ except:
+ embed_out = safe_load_embed_zip(embed_path)
except Exception:
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
return None
@@ -424,13 +457,14 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
return embed_out
class SDTokenizer:
- def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, tokenizer_data={}):
+ def __init__(self, tokenizer_path=None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data={}, tokenizer_args={}):
if tokenizer_path is None:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
- self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path)
- self.max_length = max_length
+ self.tokenizer = tokenizer_class.from_pretrained(tokenizer_path, **tokenizer_args)
+ self.max_length = tokenizer_data.get("{}_max_length".format(embedding_key), max_length)
self.min_length = min_length
self.end_token = None
+ self.min_padding = min_padding
empty = self.tokenizer('')["input_ids"]
self.tokenizer_adds_end_token = has_end_token
@@ -485,13 +519,15 @@ class SDTokenizer:
return (embed, leftover)
- def tokenize_with_weights(self, text:str, return_word_ids=False):
+ def tokenize_with_weights(self, text:str, return_word_ids=False, tokenizer_options={}, **kwargs):
'''
Takes a prompt and converts it to a list of (token, weight, word id) elements.
Tokens can be both integer tokens and pre-computed CLIP tensors.
Word id values are unique per word and embedding, where the id 0 is reserved for non-word tokens.
The returned list has dimensions NxM, where M is the input size of CLIP.
'''
+ min_length = tokenizer_options.get("{}_min_length".format(self.embedding_key), self.min_length)
+ min_padding = tokenizer_options.get("{}_min_padding".format(self.embedding_key), self.min_padding)
text = escape_important(text)
parsed_weights = token_weights(text, 1.0)
@@ -570,10 +606,12 @@ class SDTokenizer:
#fill last batch
if self.end_token is not None:
batch.append((self.end_token, 1.0, 0))
- if self.pad_to_max_length:
+ if min_padding is not None:
+ batch.extend([(self.pad_token, 1.0, 0)] * min_padding)
+ if self.pad_to_max_length and len(batch) < self.max_length:
batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch)))
- if self.min_length is not None and len(batch) < self.min_length:
- batch.extend([(self.pad_token, 1.0, 0)] * (self.min_length - len(batch)))
+ if min_length is not None and len(batch) < min_length:
+ batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch)))
if not return_word_ids:
batched_tokens = [[(t, w) for t, w,_ in x] for x in batched_tokens]
@@ -588,22 +626,27 @@ class SDTokenizer:
return {}
class SD1Tokenizer:
- def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer):
- self.clip_name = clip_name
- self.clip = "clip_{}".format(self.clip_name)
+ def __init__(self, embedding_directory=None, tokenizer_data={}, clip_name="l", tokenizer=SDTokenizer, name=None):
+ if name is not None:
+ self.clip_name = name
+ self.clip = "{}".format(self.clip_name)
+ else:
+ self.clip_name = clip_name
+ self.clip = "clip_{}".format(self.clip_name)
+
tokenizer = tokenizer_data.get("{}_tokenizer_class".format(self.clip), tokenizer)
setattr(self, self.clip, tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data))
- def tokenize_with_weights(self, text:str, return_word_ids=False):
+ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
- out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids)
+ out[self.clip_name] = getattr(self, self.clip).tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
return getattr(self, self.clip).untokenize(token_weight_pair)
def state_dict(self):
- return {}
+ return getattr(self, self.clip).state_dict()
class SD1CheckpointClipModel(SDClipModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
@@ -621,6 +664,7 @@ class SD1ClipModel(torch.nn.Module):
self.clip = "clip_{}".format(self.clip_name)
clip_model = model_options.get("{}_class".format(self.clip), clip_model)
+ model_options = {**model_options, "model_name": self.clip}
setattr(self, self.clip, clip_model(device=device, dtype=dtype, model_options=model_options, **kwargs))
self.dtypes = set()
diff --git a/comfy/sdxl_clip.py b/comfy/sdxl_clip.py
index 4d0a4e8e7..c8cef14e4 100644
--- a/comfy/sdxl_clip.py
+++ b/comfy/sdxl_clip.py
@@ -9,6 +9,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
layer_idx=-2
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
+ model_options = {**model_options, "model_name": "clip_g"}
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False, return_projected_pooled=True, model_options=model_options)
@@ -17,19 +18,18 @@ class SDXLClipG(sd1_clip.SDClipModel):
class SDXLClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
- super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
class SDXLTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
- clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
- self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
- self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory)
+ self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.clip_g = SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
- def tokenize_with_weights(self, text:str, return_word_ids=False):
+ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
- out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
- out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
+ out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+ out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
@@ -41,8 +41,7 @@ class SDXLTokenizer:
class SDXLClipModel(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}):
super().__init__()
- clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
- self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
+ self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, model_options=model_options)
self.clip_g = SDXLClipG(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set([dtype])
@@ -75,7 +74,7 @@ class SDXLRefinerClipModel(sd1_clip.SD1ClipModel):
class StableCascadeClipGTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
- super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g')
+ super().__init__(tokenizer_path, pad_with_end=True, embedding_directory=embedding_directory, embedding_size=1280, embedding_key='clip_g', tokenizer_data=tokenizer_data)
class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -84,6 +83,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "clip_config_bigg.json")
+ model_options = {**model_options, "model_name": "clip_g"}
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True, return_projected_pooled=True, model_options=model_options)
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 31de1ae9e..efe2e6b8f 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -15,6 +15,9 @@ import comfy.text_encoders.genmo
import comfy.text_encoders.lt
import comfy.text_encoders.hunyuan_video
import comfy.text_encoders.cosmos
+import comfy.text_encoders.lumina2
+import comfy.text_encoders.wan
+import comfy.text_encoders.ace
from . import supported_models_base
from . import latent_formats
@@ -504,6 +507,22 @@ class SDXL_instructpix2pix(SDXL):
def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXL_instructpix2pix(self, model_type=self.model_type(state_dict, prefix), device=device)
+class LotusD(SD20):
+ unet_config = {
+ "model_channels": 320,
+ "use_linear_in_transformer": True,
+ "use_temporal_attention": False,
+ "adm_in_channels": 4,
+ "in_channels": 4,
+ }
+
+ unet_extra_config = {
+ "num_classes": 'sequential'
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ return model_base.Lotus(self, device=device)
+
class SD3(supported_models_base.BASE):
unet_config = {
"in_channels": 16,
@@ -760,13 +779,17 @@ class LTXV(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.LTXV
- memory_usage_factor = 2.7
+ memory_usage_factor = 5.5 # TODO: img2vid is about 2x vs txt2vid
supported_inference_dtypes = [torch.bfloat16, torch.float32]
vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]
+ def __init__(self, unet_config):
+ super().__init__(unet_config)
+ self.memory_usage_factor = (unet_config.get("cross_attention_dim", 2048) / 2048) * 5.5
+
def get_model(self, state_dict, prefix="", device=None):
out = model_base.LTXV(self, device=device)
return out
@@ -788,7 +811,7 @@ class HunyuanVideo(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.HunyuanVideo
- memory_usage_factor = 2.0 #TODO
+ memory_usage_factor = 1.8 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float32]
@@ -824,9 +847,30 @@ class HunyuanVideo(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect))
-class Cosmos(supported_models_base.BASE):
+class HunyuanVideoI2V(HunyuanVideo):
+ unet_config = {
+ "image_model": "hunyuan_video",
+ "in_channels": 33,
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.HunyuanVideoI2V(self, device=device)
+ return out
+
+class HunyuanVideoSkyreelsI2V(HunyuanVideo):
+ unet_config = {
+ "image_model": "hunyuan_video",
+ "in_channels": 32,
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.HunyuanVideoSkyreelsI2V(self, device=device)
+ return out
+
+class CosmosT2V(supported_models_base.BASE):
unet_config = {
"image_model": "cosmos",
+ "in_channels": 16,
}
sampling_settings = {
@@ -838,7 +882,7 @@ class Cosmos(supported_models_base.BASE):
unet_extra_config = {}
latent_format = latent_formats.Cosmos1CV8x8x8
- memory_usage_factor = 2.4 #TODO
+ memory_usage_factor = 1.6 #TODO
supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO
@@ -854,7 +898,247 @@ class Cosmos(supported_models_base.BASE):
t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.cosmos.CosmosT5Tokenizer, comfy.text_encoders.cosmos.te(**t5_detect))
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, Cosmos]
+class CosmosI2V(CosmosT2V):
+ unet_config = {
+ "image_model": "cosmos",
+ "in_channels": 17,
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.CosmosVideo(self, image_to_video=True, device=device)
+ return out
+
+class Lumina2(supported_models_base.BASE):
+ unet_config = {
+ "image_model": "lumina2",
+ }
+
+ sampling_settings = {
+ "multiplier": 1.0,
+ "shift": 6.0,
+ }
+
+ memory_usage_factor = 1.2
+
+ unet_extra_config = {}
+ latent_format = latent_formats.Flux
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Lumina2(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}gemma2_2b.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.lumina2.LuminaTokenizer, comfy.text_encoders.lumina2.te(**hunyuan_detect))
+
+class WAN21_T2V(supported_models_base.BASE):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "t2v",
+ }
+
+ sampling_settings = {
+ "shift": 8.0,
+ }
+
+ unet_extra_config = {}
+ latent_format = latent_formats.Wan21
+
+ memory_usage_factor = 1.0
+
+ supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def __init__(self, unet_config):
+ super().__init__(unet_config)
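+ # scale the memory estimate with the transformer hidden size ("dim"); dim 2000 maps to factor 1.0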
+ self.memory_usage_factor = self.unet_config.get("dim", 2000) / 2000
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN21(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
+ t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}umt5xxl.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.wan.WanT5Tokenizer, comfy.text_encoders.wan.te(**t5_detect))
+
+class WAN21_I2V(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "i2v",
+ "in_dim": 36,
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN21(self, image_to_video=True, device=device)
+ return out
+
+class WAN21_FunControl2V(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "i2v",
+ "in_dim": 48,
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN21(self, image_to_video=False, device=device)
+ return out
+
+class WAN21_Camera(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "camera",
+ "in_dim": 32,
+ }
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN21_Camera(self, image_to_video=False, device=device)
+ return out
+
+class WAN21_Vace(WAN21_T2V):
+ unet_config = {
+ "image_model": "wan2.1",
+ "model_type": "vace",
+ }
+
+ def __init__(self, unet_config):
+ super().__init__(unet_config)
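+ # the VACE conditioning path needs roughly 20% more memory than the equivalent T2V model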
+ self.memory_usage_factor = 1.2 * self.memory_usage_factor
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.WAN21_Vace(self, image_to_video=False, device=device)
+ return out
+
+class Hunyuan3Dv2(supported_models_base.BASE):
+ unet_config = {
+ "image_model": "hunyuan3d2",
+ }
+
+ unet_extra_config = {}
+
+ sampling_settings = {
+ "multiplier": 1.0,
+ "shift": 1.0,
+ }
+
+ memory_usage_factor = 3.5
+
+ clip_vision_prefix = "conditioner.main_image_encoder.model."
+ vae_key_prefix = ["vae."]
+
+ latent_format = latent_formats.Hunyuan3Dv2
+
+ def process_unet_state_dict_for_saving(self, state_dict):
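+ # prepend "model." to every key so saved checkpoints match the original layout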
+ replace_prefix = {"": "model."}
+ return utils.state_dict_prefix_replace(state_dict, replace_prefix)
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Hunyuan3Dv2(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ return None
+
+class Hunyuan3Dv2mini(Hunyuan3Dv2):
+ unet_config = {
+ "image_model": "hunyuan3d2",
+ "depth": 8,
+ }
+
+ latent_format = latent_formats.Hunyuan3Dv2mini
+
+class HiDream(supported_models_base.BASE):
+ unet_config = {
+ "image_model": "hidream",
+ }
+
+ sampling_settings = {
+ "shift": 3.0,
+ }
+
+ # memory_usage_factor = 1.2 # TODO
+
+ unet_extra_config = {}
+ latent_format = latent_formats.Flux
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.HiDream(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ return None # TODO
+
+class Chroma(supported_models_base.BASE):
+ unet_config = {
+ "image_model": "chroma",
+ }
+
+ unet_extra_config = {
+ }
+
+ sampling_settings = {
+ "multiplier": 1.0,
+ }
+
+ latent_format = comfy.latent_formats.Flux
+
+ memory_usage_factor = 3.2
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.Chroma(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ pref = self.text_encoder_key_prefix[0]
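+ # Chroma reuses the PixArt T5-XXL tokenizer and text-encoder wrapper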
+ t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
+ return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.pixart_te(**t5_detect))
+
+class ACEStep(supported_models_base.BASE):
+ unet_config = {
+ "audio_model": "ace",
+ }
+
+ unet_extra_config = {
+ }
+
+ sampling_settings = {
+ "shift": 3.0,
+ }
+
+ latent_format = comfy.latent_formats.ACEAudio
+
+ memory_usage_factor = 0.5
+
+ supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+ vae_key_prefix = ["vae."]
+ text_encoder_key_prefix = ["text_encoders."]
+
+ def get_model(self, state_dict, prefix="", device=None):
+ out = model_base.ACEStep(self, device=device)
+ return out
+
+ def clip_target(self, state_dict={}):
+ return supported_models_base.ClipTarget(comfy.text_encoders.ace.AceT5Tokenizer, comfy.text_encoders.ace.AceT5Model)
+
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo, Cosmos]
+models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, Lumina2, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, Hunyuan3Dv2mini, Hunyuan3Dv2, HiDream, Chroma, ACEStep]
models += [SVD_img2vid]
diff --git a/comfy/text_encoders/ace.py b/comfy/text_encoders/ace.py
new file mode 100644
index 000000000..d650bb10d
--- /dev/null
+++ b/comfy/text_encoders/ace.py
@@ -0,0 +1,153 @@
+from comfy import sd1_clip
+from .spiece_tokenizer import SPieceTokenizer
+import comfy.text_encoders.t5
+import os
+import re
+import torch
+import logging
+
+from tokenizers import Tokenizer
+from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji
+
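+# maps language codes to the token ids of their tags in ace_lyrics_tokenizer/vocab.json (e.g. "[en]" -> 259)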
+SUPPORT_LANGUAGES = {
+ "en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
+ "pt": 286, "pl": 294, "tr": 295, "ru": 267, "cs": 293,
+ "nl": 297, "ar": 5022, "zh": 5023, "ja": 5412, "hu": 5753,
+ "ko": 6152, "hi": 6680
+}
+
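+# matches bracketed structure tags like [verse] or [chorus]; such lines are always tokenized as English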
+structure_pattern = re.compile(r"\[.*?\]")
+
+DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
+
+
+class VoiceBpeTokenizer:
+ def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
+ self.tokenizer = None
+ if vocab_file is not None:
+ self.tokenizer = Tokenizer.from_file(vocab_file)
+
+ def preprocess_text(self, txt, lang):
+ txt = multilingual_cleaners(txt, lang)
+ return txt
+
+ def encode(self, txt, lang='en'):
+ # lang = lang.split("-")[0] # remove the region
+ # self.check_input_length(txt, lang)
+ txt = self.preprocess_text(txt, lang)
+ lang = "zh-cn" if lang == "zh" else lang
+ txt = f"[{lang}]{txt}"
+ txt = txt.replace(" ", "[SPACE]")
+ return self.tokenizer.encode(txt).ids
+
+ def get_lang(self, line):
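+ # strip an optional two-letter language tag such as [en] from the start of the line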
+ if line.startswith("[") and line[3:4] == ']':
+ lang = line[1:3].lower()
+ if lang in SUPPORT_LANGUAGES:
+ return lang, line[4:]
+ return "en", line
+
+ def __call__(self, string):
+ lines = string.split("\n")
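+ # 261 is the [START] token id in the lyrics vocab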
+ lyric_token_idx = [261]
+ for line in lines:
+ line = line.strip()
+ if not line:
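+ # blank lines become a [SPACE] separator (token id 2)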
+ lyric_token_idx += [2]
+ continue
+
+ lang, line = self.get_lang(line)
+
+ if lang not in SUPPORT_LANGUAGES:
+ lang = "en"
+ if "zh" in lang:
+ lang = "zh"
+ if "spa" in lang:
+ lang = "es"
+
+ try:
+ line_out = japanese_to_romaji(line)
+ if line_out != line:
+ lang = "ja"
+ line = line_out
+ except Exception:
+ pass
+
+ try:
+ if structure_pattern.match(line):
+ token_idx = self.encode(line, "en")
+ else:
+ token_idx = self.encode(line, lang)
+ lyric_token_idx = lyric_token_idx + token_idx + [2]
+ except Exception as e:
+ logging.warning("tokenize error {} for line {} major_language {}".format(e, line, lang))
+ return {"input_ids": lyric_token_idx}
+
+ @staticmethod
+ def from_pretrained(path, **kwargs):
+ return VoiceBpeTokenizer(path, **kwargs)
+
+ def get_vocab(self):
+ return {}
+
+
+class UMT5BaseModel(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
+ textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
+
+class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
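+ # the SentencePiece model comes embedded in tokenizer_data rather than as a file on disk (see state_dict below)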
+ tokenizer = tokenizer_data.get("spiece_model", None)
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)
+
+ def state_dict(self):
+ return {"spiece_model": self.tokenizer.serialize_model()}
+
+class LyricsTokenizer(sd1_clip.SDTokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)
+
+class AceT5Tokenizer:
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+
+ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+ out = {}
+ out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
+ out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
+ return out
+
+ def untokenize(self, token_weight_pair):
+ return self.umt5base.untokenize(token_weight_pair)
+
+ def state_dict(self):
+ return self.umt5base.state_dict()
+
+class AceT5Model(torch.nn.Module):
+ def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
+ super().__init__()
+ self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
+ self.dtypes = set()
+ if dtype is not None:
+ self.dtypes.add(dtype)
+
+ def set_clip_options(self, options):
+ self.umt5base.set_clip_options(options)
+
+ def reset_clip_options(self):
+ self.umt5base.reset_clip_options()
+
+ def encode_token_weights(self, token_weight_pairs):
+ token_weight_pairs_umt5base = token_weight_pairs["umt5base"]
+ token_weight_pairs_lyrics = token_weight_pairs["lyrics"]
+
+ t5_out, t5_pooled = self.umt5base.encode_token_weights(token_weight_pairs_umt5base)
+
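+ # lyrics token ids are not run through the text encoder; they are passed through as raw conditioning for the model to embed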
+ lyrics_embeds = torch.tensor(list(map(lambda a: a[0], token_weight_pairs_lyrics[0]))).unsqueeze(0)
+ return t5_out, None, {"conditioning_lyrics": lyrics_embeds}
+
+ def load_sd(self, sd):
+ return self.umt5base.load_sd(sd)
diff --git a/comfy/text_encoders/ace_lyrics_tokenizer/vocab.json b/comfy/text_encoders/ace_lyrics_tokenizer/vocab.json
new file mode 100644
index 000000000..519ed340c
--- /dev/null
+++ b/comfy/text_encoders/ace_lyrics_tokenizer/vocab.json
@@ -0,0 +1,15535 @@
+{
+ "version": "1.0",
+ "truncation": null,
+ "padding": null,
+ "added_tokens": [
+ {
+ "id": 0,
+ "special": true,
+ "content": "[STOP]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 1,
+ "special": true,
+ "content": "[UNK]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 2,
+ "special": true,
+ "content": "[SPACE]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 259,
+ "special": true,
+ "content": "[en]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 260,
+ "special": true,
+ "content": "[de]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 261,
+ "special": true,
+ "content": "[START]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 262,
+ "special": true,
+ "content": "[fr]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 284,
+ "special": true,
+ "content": "[es]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 285,
+ "special": true,
+ "content": "[it]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 286,
+ "special": true,
+ "content": "[pt]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 294,
+ "special": true,
+ "content": "[pl]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 295,
+ "special": true,
+ "content": "[tr]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 267,
+ "special": true,
+ "content": "[ru]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 293,
+ "special": true,
+ "content": "[cs]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 297,
+ "special": true,
+ "content": "[nl]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 5022,
+ "special": true,
+ "content": "[ar]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 5023,
+ "special": true,
+ "content": "[zh-cn]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 5412,
+ "special": true,
+ "content": "[ja]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 5753,
+ "special": true,
+ "content": "[hu]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6152,
+ "special": true,
+ "content": "[ko]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6680,
+ "special": true,
+ "content": "[hi]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6681,
+ "special": true,
+ "content": "[start]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6682,
+ "special": true,
+ "content": "[intro]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6683,
+ "special": true,
+ "content": "[verse]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6684,
+ "special": true,
+ "content": "[chorus]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6685,
+ "special": true,
+ "content": "[bridge]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6686,
+ "special": true,
+ "content": "[outro]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6687,
+ "special": true,
+ "content": "[end]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6688,
+ "special": true,
+ "content": "[inst]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6689,
+ "special": true,
+ "content": "[solo]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6690,
+ "special": true,
+ "content": "[hook]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6691,
+ "special": true,
+ "content": "[pre-chorus]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 6692,
+ "special": true,
+ "content": "[break]",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ }
+ ],
+ "normalizer": null,
+ "pre_tokenizer": {
+ "type": "Whitespace"
+ },
+ "post_processor": null,
+ "decoder": null,
+ "model": {
+ "type": "BPE",
+ "dropout": null,
+ "unk_token": "[UNK]",
+ "continuing_subword_prefix": null,
+ "end_of_word_suffix": null,
+ "fuse_unk": false,
+ "vocab": {
+ "[STOP]": 0,
+ "[UNK]": 1,
+ "[SPACE]": 2,
+ "!": 3,
+ "'": 4,
+ "(": 5,
+ ")": 6,
+ ",": 7,
+ "-": 8,
+ ".": 9,
+ "/": 10,
+ ":": 11,
+ ";": 12,
+ "?": 13,
+ "a": 14,
+ "b": 15,
+ "c": 16,
+ "d": 17,
+ "e": 18,
+ "f": 19,
+ "g": 20,
+ "h": 21,
+ "i": 22,
+ "j": 23,
+ "k": 24,
+ "l": 25,
+ "m": 26,
+ "n": 27,
+ "o": 28,
+ "p": 29,
+ "q": 30,
+ "r": 31,
+ "s": 32,
+ "t": 33,
+ "u": 34,
+ "v": 35,
+ "w": 36,
+ "x": 37,
+ "y": 38,
+ "z": 39,
+ "th": 40,
+ "in": 41,
+ "the": 42,
+ "an": 43,
+ "er": 44,
+ "ou": 45,
+ "re": 46,
+ "on": 47,
+ "at": 48,
+ "ed": 49,
+ "en": 50,
+ "to": 51,
+ "ing": 52,
+ "and": 53,
+ "is": 54,
+ "as": 55,
+ "al": 56,
+ "or": 57,
+ "of": 58,
+ "ar": 59,
+ "it": 60,
+ "es": 61,
+ "he": 62,
+ "st": 63,
+ "le": 64,
+ "om": 65,
+ "se": 66,
+ "be": 67,
+ "ad": 68,
+ "ow": 69,
+ "ly": 70,
+ "ch": 71,
+ "wh": 72,
+ "that": 73,
+ "you": 74,
+ "li": 75,
+ "ve": 76,
+ "ac": 77,
+ "ti": 78,
+ "ld": 79,
+ "me": 80,
+ "was": 81,
+ "gh": 82,
+ "id": 83,
+ "ll": 84,
+ "wi": 85,
+ "ent": 86,
+ "for": 87,
+ "ay": 88,
+ "ro": 89,
+ "ver": 90,
+ "ic": 91,
+ "her": 92,
+ "ke": 93,
+ "his": 94,
+ "no": 95,
+ "ut": 96,
+ "un": 97,
+ "ir": 98,
+ "lo": 99,
+ "we": 100,
+ "ri": 101,
+ "ha": 102,
+ "with": 103,
+ "ght": 104,
+ "out": 105,
+ "im": 106,
+ "ion": 107,
+ "all": 108,
+ "ab": 109,
+ "one": 110,
+ "ne": 111,
+ "ge": 112,
+ "ould": 113,
+ "ter": 114,
+ "mo": 115,
+ "had": 116,
+ "ce": 117,
+ "she": 118,
+ "go": 119,
+ "sh": 120,
+ "ur": 121,
+ "am": 122,
+ "so": 123,
+ "pe": 124,
+ "my": 125,
+ "de": 126,
+ "are": 127,
+ "but": 128,
+ "ome": 129,
+ "fr": 130,
+ "ther": 131,
+ "fe": 132,
+ "su": 133,
+ "do": 134,
+ "con": 135,
+ "te": 136,
+ "ain": 137,
+ "ere": 138,
+ "po": 139,
+ "if": 140,
+ "they": 141,
+ "us": 142,
+ "ag": 143,
+ "tr": 144,
+ "now": 145,
+ "oun": 146,
+ "this": 147,
+ "have": 148,
+ "not": 149,
+ "sa": 150,
+ "il": 151,
+ "up": 152,
+ "thing": 153,
+ "from": 154,
+ "ap": 155,
+ "him": 156,
+ "ack": 157,
+ "ation": 158,
+ "ant": 159,
+ "our": 160,
+ "op": 161,
+ "like": 162,
+ "ust": 163,
+ "ess": 164,
+ "bo": 165,
+ "ok": 166,
+ "ul": 167,
+ "ind": 168,
+ "ex": 169,
+ "com": 170,
+ "some": 171,
+ "there": 172,
+ "ers": 173,
+ "co": 174,
+ "res": 175,
+ "man": 176,
+ "ard": 177,
+ "pl": 178,
+ "wor": 179,
+ "way": 180,
+ "tion": 181,
+ "fo": 182,
+ "ca": 183,
+ "were": 184,
+ "by": 185,
+ "ate": 186,
+ "pro": 187,
+ "ted": 188,
+ "ound": 189,
+ "own": 190,
+ "would": 191,
+ "ts": 192,
+ "what": 193,
+ "qu": 194,
+ "ally": 195,
+ "ight": 196,
+ "ck": 197,
+ "gr": 198,
+ "when": 199,
+ "ven": 200,
+ "can": 201,
+ "ough": 202,
+ "ine": 203,
+ "end": 204,
+ "per": 205,
+ "ous": 206,
+ "od": 207,
+ "ide": 208,
+ "know": 209,
+ "ty": 210,
+ "very": 211,
+ "si": 212,
+ "ak": 213,
+ "who": 214,
+ "about": 215,
+ "ill": 216,
+ "them": 217,
+ "est": 218,
+ "red": 219,
+ "ye": 220,
+ "could": 221,
+ "ong": 222,
+ "your": 223,
+ "their": 224,
+ "em": 225,
+ "just": 226,
+ "other": 227,
+ "into": 228,
+ "any": 229,
+ "whi": 230,
+ "um": 231,
+ "tw": 232,
+ "ast": 233,
+ "der": 234,
+ "did": 235,
+ "ie": 236,
+ "been": 237,
+ "ace": 238,
+ "ink": 239,
+ "ity": 240,
+ "back": 241,
+ "ting": 242,
+ "br": 243,
+ "more": 244,
+ "ake": 245,
+ "pp": 246,
+ "then": 247,
+ "sp": 248,
+ "el": 249,
+ "use": 250,
+ "bl": 251,
+ "said": 252,
+ "over": 253,
+ "get": 254,
+ "ß": 255,
+ "ä": 256,
+ "ö": 257,
+ "ü": 258,
+ "[en]": 259,
+ "[de]": 260,
+ "[START]": 261,
+ "[fr]": 262,
+ "œ": 263,
+ "ï": 264,
+ "ê": 265,
+ "â": 266,
+ "[ru]": 267,
+ "ÿ": 268,
+ "è": 269,
+ "à": 270,
+ "ë": 271,
+ "ù": 272,
+ "î": 273,
+ "ç": 274,
+ "æ": 275,
+ "ô": 276,
+ "û": 277,
+ "á": 278,
+ "é": 279,
+ "í": 280,
+ "ó": 281,
+ "ú": 282,
+ "ñ": 283,
+ "[es]": 284,
+ "[it]": 285,
+ "[pt]": 286,
+ "ń": 287,
+ "ś": 288,
+ "ę": 289,
+ "ą": 290,
+ "ż": 291,
+ "ć": 292,
+ "[cs]": 293,
+ "[pl]": 294,
+ "[tr]": 295,
+ "ã": 296,
+ "[nl]": 297,
+ "ş": 298,
+ "ğ": 299,
+ "ı": 300,
+ "ò": 301,
+ "ì": 302,
+ "¿": 303,
+ "…": 304,
+ "i̇": 305,
+ "õ": 306,
+ "\"": 307,
+ "´": 308,
+ "ø": 309,
+ "č": 310,
+ "ō": 311,
+ "š": 312,
+ "ž": 313,
+ "̇": 314,
+ "ei": 315,
+ "ich": 316,
+ "ein": 317,
+ "au": 318,
+ "sch": 319,
+ "und": 320,
+ "die": 321,
+ "da": 322,
+ "den": 323,
+ "gen": 324,
+ "zu": 325,
+ "hr": 326,
+ "ten": 327,
+ "mi": 328,
+ "sie": 329,
+ "das": 330,
+ "eine": 331,
+ "icht": 332,
+ "ber": 333,
+ "ach": 334,
+ "auf": 335,
+ "lich": 336,
+ "nicht": 337,
+ "mm": 338,
+ "ben": 339,
+ "war": 340,
+ "mit": 341,
+ "sich": 342,
+ "ig": 343,
+ "aus": 344,
+ "ist": 345,
+ "wie": 346,
+ "och": 347,
+ "ung": 348,
+ "ann": 349,
+ "ür": 350,
+ "hn": 351,
+ "ihr": 352,
+ "sen": 353,
+ "tz": 354,
+ "dem": 355,
+ "eit": 356,
+ "hat": 357,
+ "wir": 358,
+ "von": 359,
+ "wei": 360,
+ "ier": 361,
+ "ra": 362,
+ "einen": 363,
+ "vor": 364,
+ "als": 365,
+ "wo": 366,
+ "rei": 367,
+ "ste": 368,
+ "lie": 369,
+ "auch": 370,
+ "du": 371,
+ "des": 372,
+ "ko": 373,
+ "über": 374,
+ "bei": 375,
+ "hen": 376,
+ "hm": 377,
+ "lei": 378,
+ "aber": 379,
+ "wen": 380,
+ "hl": 381,
+ "ger": 382,
+ "nach": 383,
+ "ft": 384,
+ "imm": 385,
+ "je": 386,
+ "schen": 387,
+ "wer": 388,
+ "ser": 389,
+ "än": 390,
+ "sein": 391,
+ "ol": 392,
+ "cht": 393,
+ "für": 394,
+ "kl": 395,
+ "ff": 396,
+ "einem": 397,
+ "nen": 398,
+ "ja": 399,
+ "noch": 400,
+ "hatte": 401,
+ "pf": 402,
+ "hin": 403,
+ "di": 404,
+ "chen": 405,
+ "rü": 406,
+ "iel": 407,
+ "sel": 408,
+ "dass": 409,
+ "ihn": 410,
+ "mir": 411,
+ "schl": 412,
+ "ön": 413,
+ "gan": 414,
+ "gt": 415,
+ "einer": 416,
+ "sten": 417,
+ "mich": 418,
+ "wenn": 419,
+ "ell": 420,
+ "gte": 421,
+ "mal": 422,
+ "gel": 423,
+ "ken": 424,
+ "nur": 425,
+ "mmen": 426,
+ "fü": 427,
+ "ern": 428,
+ "ör": 429,
+ "unter": 430,
+ "ander": 431,
+ "dur": 432,
+ "uch": 433,
+ "ta": 434,
+ "men": 435,
+ "mach": 436,
+ "doch": 437,
+ "durch": 438,
+ "os": 439,
+ "gl": 440,
+ "hal": 441,
+ "ihre": 442,
+ "wä": 443,
+ "immer": 444,
+ "ihm": 445,
+ "kann": 446,
+ "ort": 447,
+ "dann": 448,
+ "lan": 449,
+ "tzt": 450,
+ "oder": 451,
+ "hren": 452,
+ "et": 453,
+ "kön": 454,
+ "ick": 455,
+ "fa": 456,
+ "wieder": 457,
+ "daß": 458,
+ "mein": 459,
+ "fen": 460,
+ "ganz": 461,
+ "diese": 462,
+ "ster": 463,
+ "dar": 464,
+ "wa": 465,
+ "ges": 466,
+ "na": 467,
+ "fl": 468,
+ "igen": 469,
+ "sche": 470,
+ "ungen": 471,
+ "mehr": 472,
+ "ßen": 473,
+ "ot": 474,
+ "kon": 475,
+ "gew": 476,
+ "haben": 477,
+ "geh": 478,
+ "ät": 479,
+ "sind": 480,
+ "dr": 481,
+ "wel": 482,
+ "uns": 483,
+ "vo": 484,
+ "ma": 485,
+ "ute": 486,
+ "schon": 487,
+ "bes": 488,
+ "gesch": 489,
+ "bt": 490,
+ "che": 491,
+ "son": 492,
+ "ob": 493,
+ "la": 494,
+ "rück": 495,
+ "seine": 496,
+ "kr": 497,
+ "fre": 498,
+ "eil": 499,
+ "zum": 500,
+ "hier": 501,
+ "kt": 502,
+ "ige": 503,
+ "spr": 504,
+ "leben": 505,
+ "bst": 506,
+ "zeit": 507,
+ "gro": 508,
+ "denn": 509,
+ "ho": 510,
+ "scha": 511,
+ "bar": 512,
+ "alle": 513,
+ "gegen": 514,
+ "wür": 515,
+ "mü": 516,
+ "ze": 517,
+ "werden": 518,
+ "jetzt": 519,
+ "kommen": 520,
+ "nie": 521,
+ "sei": 522,
+ "heit": 523,
+ "soll": 524,
+ "glei": 525,
+ "meine": 526,
+ "woll": 527,
+ "ner": 528,
+ "habe": 529,
+ "wur": 530,
+ "lichen": 531,
+ "assen": 532,
+ "nte": 533,
+ "sehen": 534,
+ "wird": 535,
+ "bis": 536,
+ "gar": 537,
+ "ien": 538,
+ "mus": 539,
+ "uß": 540,
+ "är": 541,
+ "stell": 542,
+ "keit": 543,
+ "zwei": 544,
+ "selbst": 545,
+ "sta": 546,
+ "pa": 547,
+ "sagte": 548,
+ "tet": 549,
+ "kam": 550,
+ "ssen": 551,
+ "viel": 552,
+ "ug": 553,
+ "zen": 554,
+ "hei": 555,
+ "mann": 556,
+ "will": 557,
+ "geb": 558,
+ "waren": 559,
+ "ück": 560,
+ "äch": 561,
+ "mer": 562,
+ "ru": 563,
+ "hau": 564,
+ "eigen": 565,
+ "ang": 566,
+ "weg": 567,
+ "blick": 568,
+ "fra": 569,
+ "alles": 570,
+ "ka": 571,
+ "augen": 572,
+ "fin": 573,
+ "liche": 574,
+ "unser": 575,
+ "dern": 576,
+ "herr": 577,
+ "nun": 578,
+ "vie": 579,
+ "chte": 580,
+ "wohl": 581,
+ "fall": 582,
+ "ht": 583,
+ "ün": 584,
+ "etwas": 585,
+ "stand": 586,
+ "äu": 587,
+ "mö": 588,
+ "tel": 589,
+ "rie": 590,
+ "dich": 591,
+ "dies": 592,
+ "hand": 593,
+ "bin": 594,
+ "ffen": 595,
+ "nichts": 596,
+ "dan": 597,
+ "hne": 598,
+ "ihnen": 599,
+ "esen": 600,
+ "dieser": 601,
+ "frau": 602,
+ "art": 603,
+ "dir": 604,
+ "isch": 605,
+ "erst": 606,
+ "gleich": 607,
+ "komm": 608,
+ "hör": 609,
+ "ße": 610,
+ "dig": 611,
+ "sehr": 612,
+ "zei": 613,
+ "sam": 614,
+ "aum": 615,
+ "hät": 616,
+ "ingen": 617,
+ "gut": 618,
+ "mut": 619,
+ "cken": 620,
+ "konnte": 621,
+ "stimm": 622,
+ "zur": 623,
+ "itz": 624,
+ "weil": 625,
+ "würde": 626,
+ "fä": 627,
+ "können": 628,
+ "keine": 629,
+ "fer": 630,
+ "ischen": 631,
+ "voll": 632,
+ "eines": 633,
+ "setz": 634,
+ "zie": 635,
+ "del": 636,
+ "tete": 637,
+ "seiner": 638,
+ "ieren": 639,
+ "gest": 640,
+ "zurück": 641,
+ "wurde": 642,
+ "schn": 643,
+ "pr": 644,
+ "ließ": 645,
+ "tra": 646,
+ "mä": 647,
+ "gend": 648,
+ "fol": 649,
+ "ik": 650,
+ "schla": 651,
+ "schaft": 652,
+ "ater": 653,
+ "weiß": 654,
+ "seinen": 655,
+ "lassen": 656,
+ "lu": 657,
+ "unden": 658,
+ "teil": 659,
+ "neu": 660,
+ "iert": 661,
+ "menschen": 662,
+ "hmen": 663,
+ "str": 664,
+ "gi": 665,
+ "sah": 666,
+ "ihren": 667,
+ "eln": 668,
+ "weiter": 669,
+ "gehen": 670,
+ "iger": 671,
+ "macht": 672,
+ "tag": 673,
+ "also": 674,
+ "halten": 675,
+ "nis": 676,
+ "acht": 677,
+ "geben": 678,
+ "og": 679,
+ "nat": 680,
+ "mar": 681,
+ "det": 682,
+ "ohne": 683,
+ "haus": 684,
+ "tro": 685,
+ "ange": 686,
+ "lau": 687,
+ "spiel": 688,
+ "tre": 689,
+ "schr": 690,
+ "inn": 691,
+ "los": 692,
+ "machen": 693,
+ "hätte": 694,
+ "beg": 695,
+ "wirk": 696,
+ "alt": 697,
+ "glich": 698,
+ "tes": 699,
+ "richt": 700,
+ "freund": 701,
+ "ihrer": 702,
+ "fel": 703,
+ "bel": 704,
+ "sol": 705,
+ "einmal": 706,
+ "eben": 707,
+ "hol": 708,
+ "hän": 709,
+ "tern": 710,
+ "hö": 711,
+ "schw": 712,
+ "recht": 713,
+ "wahr": 714,
+ "seinem": 715,
+ "stehen": 716,
+ "hlen": 717,
+ "ins": 718,
+ "ging": 719,
+ "wollte": 720,
+ "wissen": 721,
+ "ungs": 722,
+ "ald": 723,
+ "ass": 724,
+ "jahr": 725,
+ "mor": 726,
+ "welt": 727,
+ "under": 728,
+ "zusa": 729,
+ "kopf": 730,
+ "lang": 731,
+ "hinter": 732,
+ "atz": 733,
+ "stra": 734,
+ "angen": 735,
+ "ank": 736,
+ "ade": 737,
+ "glau": 738,
+ "fach": 739,
+ "hatten": 740,
+ "fort": 741,
+ "eicht": 742,
+ "iff": 743,
+ "ler": 744,
+ "mei": 745,
+ "diesem": 746,
+ "kein": 747,
+ "frei": 748,
+ "führ": 749,
+ "vom": 750,
+ "β": 751,
+ "ai": 752,
+ "ait": 753,
+ "que": 754,
+ "les": 755,
+ "av": 756,
+ "ais": 757,
+ "oi": 758,
+ "eu": 759,
+ "lle": 760,
+ "par": 761,
+ "ans": 762,
+ "ment": 763,
+ "ét": 764,
+ "une": 765,
+ "pas": 766,
+ "qui": 767,
+ "elle": 768,
+ "dé": 769,
+ "pour": 770,
+ "dans": 771,
+ "ré": 772,
+ "tou": 773,
+ "vous": 774,
+ "vi": 775,
+ "ouv": 776,
+ "mon": 777,
+ "sur": 778,
+ "ci": 779,
+ "plu": 780,
+ "ère": 781,
+ "mais": 782,
+ "ois": 783,
+ "plus": 784,
+ "ée": 785,
+ "aient": 786,
+ "mp": 787,
+ "lui": 788,
+ "ave": 789,
+ "était": 790,
+ "ses": 791,
+ "tout": 792,
+ "oir": 793,
+ "avait": 794,
+ "és": 795,
+ "mes": 796,
+ "nous": 797,
+ "eux": 798,
+ "bi": 799,
+ "ons": 800,
+ "pu": 801,
+ "ces": 802,
+ "tu": 803,
+ "leur": 804,
+ "don": 805,
+ "eur": 806,
+ "ette": 807,
+ "aire": 808,
+ "avec": 809,
+ "dit": 810,
+ "té": 811,
+ "ille": 812,
+ "comme": 813,
+ "cr": 814,
+ "ux": 815,
+ "ès": 816,
+ "aux": 817,
+ "jour": 818,
+ "ils": 819,
+ "bien": 820,
+ "cou": 821,
+ "quel": 822,
+ "peu": 823,
+ "cette": 824,
+ "cu": 825,
+ "mê": 826,
+ "fait": 827,
+ "gu": 828,
+ "être": 829,
+ "ité": 830,
+ "ens": 831,
+ "ni": 832,
+ "lé": 833,
+ "dis": 834,
+ "ble": 835,
+ "né": 836,
+ "puis": 837,
+ "même": 838,
+ "ques": 839,
+ "fi": 840,
+ "age": 841,
+ "moi": 842,
+ "ence": 843,
+ "ont": 844,
+ "main": 845,
+ "ors": 846,
+ "aut": 847,
+ "ance": 848,
+ "mé": 849,
+ "sans": 850,
+ "sé": 851,
+ "lon": 852,
+ "hom": 853,
+ "car": 854,
+ "able": 855,
+ "cher": 856,
+ "deux": 857,
+ "enf": 858,
+ "où": 859,
+ "ph": 860,
+ "ure": 861,
+ "temp": 862,
+ "pos": 863,
+ "rent": 864,
+ "pé": 865,
+ "faire": 866,
+ "pi": 867,
+ "tres": 868,
+ "ça": 869,
+ "endre": 870,
+ "bon": 871,
+ "sou": 872,
+ "int": 873,
+ "pré": 874,
+ "sent": 875,
+ "tant": 876,
+ "cer": 877,
+ "là": 878,
+ "lais": 879,
+ "près": 880,
+ "bre": 881,
+ "cour": 882,
+ "pet": 883,
+ "comp": 884,
+ "lait": 885,
+ "trouv": 886,
+ "entre": 887,
+ "sont": 888,
+ "dev": 889,
+ "nu": 890,
+ "temps": 891,
+ "dou": 892,
+ "rait": 893,
+ "bou": 894,
+ "quand": 895,
+ "jours": 896,
+ "avoir": 897,
+ "été": 898,
+ "ale": 899,
+ "pre": 900,
+ "fois": 901,
+ "orte": 902,
+ "vé": 903,
+ "non": 904,
+ "tous": 905,
+ "jus": 906,
+ "coup": 907,
+ "homme": 908,
+ "ête": 909,
+ "aussi": 910,
+ "urs": 911,
+ "seu": 912,
+ "ord": 913,
+ "min": 914,
+ "gé": 915,
+ "core": 916,
+ "va": 917,
+ "vre": 918,
+ "encore": 919,
+ "sem": 920,
+ "ite": 921,
+ "autre": 922,
+ "pris": 923,
+ "peut": 924,
+ "ue": 925,
+ "ante": 926,
+ "gn": 927,
+ "rép": 928,
+ "hu": 929,
+ "sion": 930,
+ "votre": 931,
+ "dire": 932,
+ "ez": 933,
+ "fem": 934,
+ "leurs": 935,
+ "met": 936,
+ "cri": 937,
+ "mis": 938,
+ "tour": 939,
+ "rai": 940,
+ "jam": 941,
+ "regar": 942,
+ "rien": 943,
+ "vers": 944,
+ "suis": 945,
+ "pouv": 946,
+ "vis": 947,
+ "grand": 948,
+ "ants": 949,
+ "cor": 950,
+ "rer": 951,
+ "cé": 952,
+ "tent": 953,
+ "pres": 954,
+ "vou": 955,
+ "alors": 956,
+ "sieur": 957,
+ "aine": 958,
+ "quoi": 959,
+ "fon": 960,
+ "endant": 961,
+ "arri": 962,
+ "eure": 963,
+ "après": 964,
+ "donc": 965,
+ "itu": 966,
+ "lè": 967,
+ "sait": 968,
+ "toi": 969,
+ "cha": 970,
+ "ail": 971,
+ "asse": 972,
+ "imp": 973,
+ "voy": 974,
+ "conn": 975,
+ "pla": 976,
+ "petit": 977,
+ "avant": 978,
+ "nom": 979,
+ "tin": 980,
+ "dont": 981,
+ "sous": 982,
+ "emp": 983,
+ "person": 984,
+ "elles": 985,
+ "beau": 986,
+ "parti": 987,
+ "cho": 988,
+ "prit": 989,
+ "toujours": 990,
+ "rais": 991,
+ "jamais": 992,
+ "trav": 993,
+ "tions": 994,
+ "très": 995,
+ "voi": 996,
+ "ren": 997,
+ "yeux": 998,
+ "voir": 999,
+ "premi": 1000,
+ "gne": 1001,
+ "heure": 1002,
+ "rou": 1003,
+ "eff": 1004,
+ "notre": 1005,
+ "ments": 1006,
+ "ton": 1007,
+ "fais": 1008,
+ "cela": 1009,
+ "répon": 1010,
+ "cons": 1011,
+ "air": 1012,
+ "ôt": 1013,
+ "pendant": 1014,
+ "ici": 1015,
+ "toute": 1016,
+ "jet": 1017,
+ "port": 1018,
+ "étaient": 1019,
+ "pen": 1020,
+ "hé": 1021,
+ "autres": 1022,
+ "père": 1023,
+ "oc": 1024,
+ "quelques": 1025,
+ "ique": 1026,
+ "lis": 1027,
+ "femme": 1028,
+ "jou": 1029,
+ "teur": 1030,
+ "monde": 1031,
+ "nes": 1032,
+ "dre": 1033,
+ "aff": 1034,
+ "rap": 1035,
+ "part": 1036,
+ "lement": 1037,
+ "cla": 1038,
+ "fut": 1039,
+ "quelque": 1040,
+ "prendre": 1041,
+ "rê": 1042,
+ "aille": 1043,
+ "sais": 1044,
+ "ches": 1045,
+ "let": 1046,
+ "char": 1047,
+ "ères": 1048,
+ "ents": 1049,
+ "moins": 1050,
+ "eau": 1051,
+ "aî": 1052,
+ "jeu": 1053,
+ "heur": 1054,
+ "ées": 1055,
+ "tri": 1056,
+ "point": 1057,
+ "mom": 1058,
+ "vent": 1059,
+ "nouv": 1060,
+ "gran": 1061,
+ "trois": 1062,
+ "sant": 1063,
+ "toutes": 1064,
+ "contre": 1065,
+ "èrent": 1066,
+ "chez": 1067,
+ "avez": 1068,
+ "ût": 1069,
+ "att": 1070,
+ "pau": 1071,
+ "porte": 1072,
+ "ouver": 1073,
+ "lit": 1074,
+ "prés": 1075,
+ "chose": 1076,
+ "vit": 1077,
+ "monsieur": 1078,
+ "hab": 1079,
+ "tête": 1080,
+ "ju": 1081,
+ "tement": 1082,
+ "ction": 1083,
+ "vrai": 1084,
+ "lar": 1085,
+ "cet": 1086,
+ "regard": 1087,
+ "lant": 1088,
+ "som": 1089,
+ "moment": 1090,
+ "illes": 1091,
+ "ple": 1092,
+ "ps": 1093,
+ "mère": 1094,
+ "cl": 1095,
+ "sour": 1096,
+ "ys": 1097,
+ "trop": 1098,
+ "enne": 1099,
+ "jusqu": 1100,
+ "avaient": 1101,
+ "avais": 1102,
+ "jeune": 1103,
+ "depuis": 1104,
+ "personne": 1105,
+ "fit": 1106,
+ "cert": 1107,
+ "jo": 1108,
+ "oui": 1109,
+ "rest": 1110,
+ "semb": 1111,
+ "cap": 1112,
+ "mat": 1113,
+ "mu": 1114,
+ "long": 1115,
+ "fran": 1116,
+ "faut": 1117,
+ "iti": 1118,
+ "bli": 1119,
+ "chev": 1120,
+ "pri": 1121,
+ "ente": 1122,
+ "ainsi": 1123,
+ "cham": 1124,
+ "lors": 1125,
+ "cas": 1126,
+ "ili": 1127,
+ "bé": 1128,
+ "nos": 1129,
+ "sui": 1130,
+ "rit": 1131,
+ "cro": 1132,
+ "gue": 1133,
+ "ía": 1134,
+ "por": 1135,
+ "las": 1136,
+ "ón": 1137,
+ "una": 1138,
+ "aba": 1139,
+ "dos": 1140,
+ "era": 1141,
+ "mb": 1142,
+ "para": 1143,
+ "ás": 1144,
+ "mos": 1145,
+ "ando": 1146,
+ "como": 1147,
+ "más": 1148,
+ "ción": 1149,
+ "tan": 1150,
+ "dad": 1151,
+ "ado": 1152,
+ "fu": 1153,
+ "cia": 1154,
+ "mente": 1155,
+ "sus": 1156,
+ "tar": 1157,
+ "za": 1158,
+ "ba": 1159,
+ "pero": 1160,
+ "sin": 1161,
+ "lla": 1162,
+ "án": 1163,
+ "ia": 1164,
+ "ran": 1165,
+ "ga": 1166,
+ "yo": 1167,
+ "tos": 1168,
+ "cos": 1169,
+ "ya": 1170,
+ "ones": 1171,
+ "había": 1172,
+ "hi": 1173,
+ "esta": 1174,
+ "mas": 1175,
+ "tor": 1176,
+ "aban": 1177,
+ "dor": 1178,
+ "ían": 1179,
+ "tas": 1180,
+ "én": 1181,
+ "endo": 1182,
+ "aque": 1183,
+ "ero": 1184,
+ "io": 1185,
+ "qué": 1186,
+ "cab": 1187,
+ "tal": 1188,
+ "señ": 1189,
+ "ora": 1190,
+ "todo": 1191,
+ "sal": 1192,
+ "cuando": 1193,
+ "gun": 1194,
+ "bu": 1195,
+ "ras": 1196,
+ "esto": 1197,
+ "pare": 1198,
+ "él": 1199,
+ "tras": 1200,
+ "jos": 1201,
+ "mien": 1202,
+ "pue": 1203,
+ "cre": 1204,
+ "pon": 1205,
+ "día": 1206,
+ "tros": 1207,
+ "sab": 1208,
+ "sobre": 1209,
+ "ese": 1210,
+ "mbre": 1211,
+ "eron": 1212,
+ "añ": 1213,
+ "ido": 1214,
+ "porque": 1215,
+ "ella": 1216,
+ "cen": 1217,
+ "muy": 1218,
+ "cal": 1219,
+ "este": 1220,
+ "has": 1221,
+ "có": 1222,
+ "gra": 1223,
+ "ros": 1224,
+ "aquel": 1225,
+ "dijo": 1226,
+ "cía": 1227,
+ "zo": 1228,
+ "ciones": 1229,
+ "mbi": 1230,
+ "elo": 1231,
+ "tó": 1232,
+ "ina": 1233,
+ "todos": 1234,
+ "tien": 1235,
+ "estaba": 1236,
+ "deci": 1237,
+ "cio": 1238,
+ "ño": 1239,
+ "lor": 1240,
+ "nues": 1241,
+ "medi": 1242,
+ "len": 1243,
+ "vida": 1244,
+ "ali": 1245,
+ "pues": 1246,
+ "ales": 1247,
+ "vol": 1248,
+ "mí": 1249,
+ "rar": 1250,
+ "cion": 1251,
+ "hasta": 1252,
+ "señor": 1253,
+ "cono": 1254,
+ "ah": 1255,
+ "dios": 1256,
+ "esa": 1257,
+ "ún": 1258,
+ "var": 1259,
+ "san": 1260,
+ "gui": 1261,
+ "otros": 1262,
+ "tado": 1263,
+ "buen": 1264,
+ "ña": 1265,
+ "tiemp": 1266,
+ "hacer": 1267,
+ "jer": 1268,
+ "vu": 1269,
+ "ana": 1270,
+ "así": 1271,
+ "antes": 1272,
+ "vez": 1273,
+ "miento": 1274,
+ "jar": 1275,
+ "lab": 1276,
+ "casa": 1277,
+ "eso": 1278,
+ "ego": 1279,
+ "dió": 1280,
+ "está": 1281,
+ "encia": 1282,
+ "eli": 1283,
+ "ías": 1284,
+ "tiempo": 1285,
+ "zar": 1286,
+ "van": 1287,
+ "mun": 1288,
+ "erta": 1289,
+ "tambi": 1290,
+ "sí": 1291,
+ "aun": 1292,
+ "mismo": 1293,
+ "entes": 1294,
+ "mano": 1295,
+ "ele": 1296,
+ "nada": 1297,
+ "segu": 1298,
+ "mej": 1299,
+ "erra": 1300,
+ "tir": 1301,
+ "uno": 1302,
+ "donde": 1303,
+ "toda": 1304,
+ "desde": 1305,
+ "también": 1306,
+ "cuer": 1307,
+ "hombre": 1308,
+ "otro": 1309,
+ "lib": 1310,
+ "trar": 1311,
+ "cual": 1312,
+ "hay": 1313,
+ "cada": 1314,
+ "taba": 1315,
+ "mento": 1316,
+ "tenía": 1317,
+ "quer": 1318,
+ "eran": 1319,
+ "siemp": 1320,
+ "siempre": 1321,
+ "erto": 1322,
+ "quí": 1323,
+ "gos": 1324,
+ "pués": 1325,
+ "ellos": 1326,
+ "después": 1327,
+ "nue": 1328,
+ "llo": 1329,
+ "inter": 1330,
+ "cómo": 1331,
+ "ahora": 1332,
+ "uste": 1333,
+ "traba": 1334,
+ "lado": 1335,
+ "ino": 1336,
+ "poco": 1337,
+ "erte": 1338,
+ "mujer": 1339,
+ "quier": 1340,
+ "algun": 1341,
+ "fue": 1342,
+ "ojos": 1343,
+ "enton": 1344,
+ "vos": 1345,
+ "esper": 1346,
+ "much": 1347,
+ "otra": 1348,
+ "az": 1349,
+ "eza": 1350,
+ "aquí": 1351,
+ "cias": 1352,
+ "gua": 1353,
+ "mucho": 1354,
+ "decir": 1355,
+ "esti": 1356,
+ "idad": 1357,
+ "algo": 1358,
+ "ocu": 1359,
+ "entonces": 1360,
+ "dido": 1361,
+ "entos": 1362,
+ "gri": 1363,
+ "dado": 1364,
+ "ios": 1365,
+ "dose": 1366,
+ "usted": 1367,
+ "quien": 1368,
+ "ami": 1369,
+ "unto": 1370,
+ "mejor": 1371,
+ "bas": 1372,
+ "solo": 1373,
+ "pregun": 1374,
+ "tur": 1375,
+ "alg": 1376,
+ "todas": 1377,
+ "parte": 1378,
+ "emb": 1379,
+ "cto": 1380,
+ "mundo": 1381,
+ "tiene": 1382,
+ "tante": 1383,
+ "palab": 1384,
+ "tran": 1385,
+ "aquella": 1386,
+ "cios": 1387,
+ "aunque": 1388,
+ "cuen": 1389,
+ "tener": 1390,
+ "fun": 1391,
+ "respon": 1392,
+ "allí": 1393,
+ "xi": 1394,
+ "han": 1395,
+ "pens": 1396,
+ "contra": 1397,
+ "tura": 1398,
+ "val": 1399,
+ "dio": 1400,
+ "tanto": 1401,
+ "camin": 1402,
+ "mó": 1403,
+ "esp": 1404,
+ "ada": 1405,
+ "ío": 1406,
+ "hacia": 1407,
+ "dej": 1408,
+ "estar": 1409,
+ "ión": 1410,
+ "gas": 1411,
+ "vas": 1412,
+ "noche": 1413,
+ "ér": 1414,
+ "años": 1415,
+ "padre": 1416,
+ "gus": 1417,
+ "ár": 1418,
+ "sino": 1419,
+ "manos": 1420,
+ "cido": 1421,
+ "estu": 1422,
+ "hubi": 1423,
+ "vir": 1424,
+ "bri": 1425,
+ "raz": 1426,
+ "chi": 1427,
+ "puede": 1428,
+ "menos": 1429,
+ "habi": 1430,
+ "homb": 1431,
+ "neces": 1432,
+ "may": 1433,
+ "eros": 1434,
+ "ría": 1435,
+ "hecho": 1436,
+ "escu": 1437,
+ "lti": 1438,
+ "ándo": 1439,
+ "bus": 1440,
+ "cosas": 1441,
+ "tú": 1442,
+ "espa": 1443,
+ "reci": 1444,
+ "ctor": 1445,
+ "prim": 1446,
+ "dia": 1447,
+ "dese": 1448,
+ "mientras": 1449,
+ "hor": 1450,
+ "fuer": 1451,
+ "ida": 1452,
+ "posi": 1453,
+ "lante": 1454,
+ "ano": 1455,
+ "estas": 1456,
+ "pli": 1457,
+ "luego": 1458,
+ "sión": 1459,
+ "cin": 1460,
+ "tierra": 1461,
+ "guar": 1462,
+ "cado": 1463,
+ "encon": 1464,
+ "pren": 1465,
+ "mayor": 1466,
+ "fal": 1467,
+ "ð": 1468,
+ "ħ": 1469,
+ "ň": 1470,
+ "ə": 1471,
+ "θ": 1472,
+ "’": 1473,
+ "“": 1474,
+ "”": 1475,
+ "zi": 1476,
+ "gli": 1477,
+ "tto": 1478,
+ "ono": 1479,
+ "nel": 1480,
+ "tti": 1481,
+ "della": 1482,
+ "zione": 1483,
+ "tta": 1484,
+ "tà": 1485,
+ "uo": 1486,
+ "come": 1487,
+ "alla": 1488,
+ "oni": 1489,
+ "ggi": 1490,
+ "ssi": 1491,
+ "più": 1492,
+ "ini": 1493,
+ "bb": 1494,
+ "sto": 1495,
+ "sono": 1496,
+ "eri": 1497,
+ "sse": 1498,
+ "sc": 1499,
+ "sul": 1500,
+ "vano": 1501,
+ "sti": 1502,
+ "suo": 1503,
+ "cchi": 1504,
+ "zza": 1505,
+ "anche": 1506,
+ "tte": 1507,
+ "sci": 1508,
+ "col": 1509,
+ "sso": 1510,
+ "ssa": 1511,
+ "dei": 1512,
+ "aveva": 1513,
+ "zz": 1514,
+ "amo": 1515,
+ "gno": 1516,
+ "sua": 1517,
+ "ria": 1518,
+ "sì": 1519,
+ "ché": 1520,
+ "dal": 1521,
+ "ona": 1522,
+ "spe": 1523,
+ "gni": 1524,
+ "tt": 1525,
+ "delle": 1526,
+ "questo": 1527,
+ "nella": 1528,
+ "dere": 1529,
+ "anno": 1530,
+ "dell": 1531,
+ "uni": 1532,
+ "bbe": 1533,
+ "anti": 1534,
+ "ene": 1535,
+ "gio": 1536,
+ "uto": 1537,
+ "qual": 1538,
+ "glia": 1539,
+ "quando": 1540,
+ "tutto": 1541,
+ "glio": 1542,
+ "zioni": 1543,
+ "cam": 1544,
+ "esso": 1545,
+ "ss": 1546,
+ "mol": 1547,
+ "loro": 1548,
+ "perché": 1549,
+ "cosa": 1550,
+ "due": 1551,
+ "poi": 1552,
+ "sco": 1553,
+ "cco": 1554,
+ "gna": 1555,
+ "tem": 1556,
+ "prima": 1557,
+ "così": 1558,
+ "essere": 1559,
+ "ani": 1560,
+ "bra": 1561,
+ "rio": 1562,
+ "anco": 1563,
+ "cui": 1564,
+ "spi": 1565,
+ "via": 1566,
+ "gior": 1567,
+ "bile": 1568,
+ "ggio": 1569,
+ "mai": 1570,
+ "tare": 1571,
+ "indi": 1572,
+ "rebbe": 1573,
+ "senza": 1574,
+ "zio": 1575,
+ "tutti": 1576,
+ "stato": 1577,
+ "zia": 1578,
+ "dalla": 1579,
+ "mia": 1580,
+ "vita": 1581,
+ "quella": 1582,
+ "qua": 1583,
+ "dove": 1584,
+ "allo": 1585,
+ "sempre": 1586,
+ "zzo": 1587,
+ "sia": 1588,
+ "dopo": 1589,
+ "porta": 1590,
+ "ccia": 1591,
+ "erano": 1592,
+ "anni": 1593,
+ "chia": 1594,
+ "enza": 1595,
+ "propri": 1596,
+ "anda": 1597,
+ "cca": 1598,
+ "occhi": 1599,
+ "questa": 1600,
+ "ffi": 1601,
+ "ron": 1602,
+ "mio": 1603,
+ "ris": 1604,
+ "ogni": 1605,
+ "rin": 1606,
+ "far": 1607,
+ "menti": 1608,
+ "ancora": 1609,
+ "fatto": 1610,
+ "mani": 1611,
+ "senti": 1612,
+ "pra": 1613,
+ "tempo": 1614,
+ "essi": 1615,
+ "bbi": 1616,
+ "lare": 1617,
+ "pers": 1618,
+ "sor": 1619,
+ "anza": 1620,
+ "pie": 1621,
+ "verso": 1622,
+ "altro": 1623,
+ "tato": 1624,
+ "cato": 1625,
+ "ato": 1626,
+ "volta": 1627,
+ "cc": 1628,
+ "fare": 1629,
+ "ciò": 1630,
+ "bili": 1631,
+ "nuo": 1632,
+ "quello": 1633,
+ "colo": 1634,
+ "ppo": 1635,
+ "trova": 1636,
+ "ore": 1637,
+ "rono": 1638,
+ "molto": 1639,
+ "almente": 1640,
+ "sca": 1641,
+ "vole": 1642,
+ "tali": 1643,
+ "sulla": 1644,
+ "sce": 1645,
+ "meno": 1646,
+ "anto": 1647,
+ "pun": 1648,
+ "stu": 1649,
+ "capi": 1650,
+ "giu": 1651,
+ "mini": 1652,
+ "pia": 1653,
+ "lavo": 1654,
+ "vero": 1655,
+ "rsi": 1656,
+ "altri": 1657,
+ "scia": 1658,
+ "suoi": 1659,
+ "glie": 1660,
+ "sotto": 1661,
+ "bene": 1662,
+ "scri": 1663,
+ "tale": 1664,
+ "degli": 1665,
+ "alc": 1666,
+ "uomo": 1667,
+ "pel": 1668,
+ "pote": 1669,
+ "essa": 1670,
+ "scu": 1671,
+ "signo": 1672,
+ "stro": 1673,
+ "uti": 1674,
+ "sione": 1675,
+ "gre": 1676,
+ "fini": 1677,
+ "lun": 1678,
+ "esi": 1679,
+ "passa": 1680,
+ "rà": 1681,
+ "mentre": 1682,
+ "hanno": 1683,
+ "usci": 1684,
+ "gia": 1685,
+ "già": 1686,
+ "mina": 1687,
+ "tica": 1688,
+ "giorno": 1689,
+ "esse": 1690,
+ "modo": 1691,
+ "spa": 1692,
+ "proprio": 1693,
+ "ori": 1694,
+ "contro": 1695,
+ "stru": 1696,
+ "diven": 1697,
+ "disse": 1698,
+ "rato": 1699,
+ "noi": 1700,
+ "vere": 1701,
+ "può": 1702,
+ "dice": 1703,
+ "cci": 1704,
+ "secon": 1705,
+ "ccio": 1706,
+ "qualche": 1707,
+ "tutta": 1708,
+ "gg": 1709,
+ "mondo": 1710,
+ "forma": 1711,
+ "mma": 1712,
+ "pensa": 1713,
+ "deva": 1714,
+ "fosse": 1715,
+ "sopra": 1716,
+ "tamente": 1717,
+ "ness": 1718,
+ "quanto": 1719,
+ "raga": 1720,
+ "unque": 1721,
+ "care": 1722,
+ "stre": 1723,
+ "grande": 1724,
+ "picco": 1725,
+ "guarda": 1726,
+ "nell": 1727,
+ "possi": 1728,
+ "presen": 1729,
+ "rò": 1730,
+ "paro": 1731,
+ "tua": 1732,
+ "vin": 1733,
+ "ane": 1734,
+ "stesso": 1735,
+ "dav": 1736,
+ "nei": 1737,
+ "nelle": 1738,
+ "ghi": 1739,
+ "pio": 1740,
+ "lato": 1741,
+ "sid": 1742,
+ "fine": 1743,
+ "fuo": 1744,
+ "quasi": 1745,
+ "ulti": 1746,
+ "ito": 1747,
+ "sue": 1748,
+ "fil": 1749,
+ "allora": 1750,
+ "veni": 1751,
+ "tano": 1752,
+ "ello": 1753,
+ "ão": 1754,
+ "não": 1755,
+ "uma": 1756,
+ "ela": 1757,
+ "lh": 1758,
+ "ção": 1759,
+ "cê": 1760,
+ "inha": 1761,
+ "você": 1762,
+ "ec": 1763,
+ "dade": 1764,
+ "ao": 1765,
+ "ram": 1766,
+ "vel": 1767,
+ "ém": 1768,
+ "pode": 1769,
+ "estava": 1770,
+ "isso": 1771,
+ "mui": 1772,
+ "faz": 1773,
+ "ões": 1774,
+ "pes": 1775,
+ "ix": 1776,
+ "sim": 1777,
+ "olh": 1778,
+ "isa": 1779,
+ "ên": 1780,
+ "tinha": 1781,
+ "meu": 1782,
+ "são": 1783,
+ "minha": 1784,
+ "muito": 1785,
+ "foi": 1786,
+ "bem": 1787,
+ "diz": 1788,
+ "parec": 1789,
+ "ço": 1790,
+ "pesso": 1791,
+ "pois": 1792,
+ "mesmo": 1793,
+ "ções": 1794,
+ "seus": 1795,
+ "até": 1796,
+ "ência": 1797,
+ "lhe": 1798,
+ "tiv": 1799,
+ "mã": 1800,
+ "só": 1801,
+ "tão": 1802,
+ "tudo": 1803,
+ "então": 1804,
+ "inda": 1805,
+ "bal": 1806,
+ "indo": 1807,
+ "ndo": 1808,
+ "já": 1809,
+ "vam": 1810,
+ "eito": 1811,
+ "depois": 1812,
+ "mel": 1813,
+ "lha": 1814,
+ "ainda": 1815,
+ "fazer": 1816,
+ "pou": 1817,
+ "pergun": 1818,
+ "deix": 1819,
+ "tamb": 1820,
+ "ala": 1821,
+ "pelo": 1822,
+ "também": 1823,
+ "fica": 1824,
+ "prec": 1825,
+ "eles": 1826,
+ "havia": 1827,
+ "lá": 1828,
+ "nas": 1829,
+ "gem": 1830,
+ "mem": 1831,
+ "ós": 1832,
+ "deu": 1833,
+ "eiro": 1834,
+ "..": 1835,
+ "assim": 1836,
+ "ior": 1837,
+ "har": 1838,
+ "aqui": 1839,
+ "cul": 1840,
+ "sar": 1841,
+ "outra": 1842,
+ "olhos": 1843,
+ "ima": 1844,
+ "mim": 1845,
+ "ago": 1846,
+ "pessoas": 1847,
+ "eram": 1848,
+ "eira": 1849,
+ "pela": 1850,
+ "coisa": 1851,
+ "mão": 1852,
+ "conh": 1853,
+ "agora": 1854,
+ "iam": 1855,
+ "há": 1856,
+ "suas": 1857,
+ "guém": 1858,
+ "cabe": 1859,
+ "nem": 1860,
+ "ível": 1861,
+ "consegu": 1862,
+ "trabal": 1863,
+ "lev": 1864,
+ "lem": 1865,
+ "vai": 1866,
+ "tei": 1867,
+ "pró": 1868,
+ "quem": 1869,
+ "onde": 1870,
+ "cabeça": 1871,
+ "nunca": 1872,
+ "mentos": 1873,
+ "hum": 1874,
+ "dele": 1875,
+ "verdade": 1876,
+ "tá": 1877,
+ "hos": 1878,
+ "algum": 1879,
+ "dizer": 1880,
+ "penas": 1881,
+ "nós": 1882,
+ "enquanto": 1883,
+ "outro": 1884,
+ "lho": 1885,
+ "melhor": 1886,
+ "primei": 1887,
+ "iu": 1888,
+ "apenas": 1889,
+ "estou": 1890,
+ "conte": 1891,
+ "homem": 1892,
+ "dois": 1893,
+ "ças": 1894,
+ "pouco": 1895,
+ "senhor": 1896,
+ "tando": 1897,
+ "espera": 1898,
+ "pai": 1899,
+ "rios": 1900,
+ "baix": 1901,
+ "ase": 1902,
+ "isas": 1903,
+ "hora": 1904,
+ "ficar": 1905,
+ "seja": 1906,
+ "ân": 1907,
+ "clar": 1908,
+ "inc": 1909,
+ "fos": 1910,
+ "ouvi": 1911,
+ "vem": 1912,
+ "tava": 1913,
+ "ário": 1914,
+ "sos": 1915,
+ "inho": 1916,
+ "rando": 1917,
+ "ês": 1918,
+ "coisas": 1919,
+ "aconte": 1920,
+ "lher": 1921,
+ "anos": 1922,
+ "talvez": 1923,
+ "estão": 1924,
+ "liv": 1925,
+ "outros": 1926,
+ "qualquer": 1927,
+ "gou": 1928,
+ "lí": 1929,
+ "tivesse": 1930,
+ "rado": 1931,
+ "precisa": 1932,
+ "mãe": 1933,
+ "dela": 1934,
+ "entra": 1935,
+ "maior": 1936,
+ "noite": 1937,
+ "tiva": 1938,
+ "pala": 1939,
+ "ração": 1940,
+ "deus": 1941,
+ "sas": 1942,
+ "inte": 1943,
+ "fei": 1944,
+ "palav": 1945,
+ "trás": 1946,
+ "cidade": 1947,
+ "lugar": 1948,
+ "vezes": 1949,
+ "encontra": 1950,
+ "tru": 1951,
+ "eci": 1952,
+ "ın": 1953,
+ "bir": 1954,
+ "yor": 1955,
+ "ek": 1956,
+ "dı": 1957,
+ "ey": 1958,
+ "tı": 1959,
+ "mı": 1960,
+ "iz": 1961,
+ "ır": 1962,
+ "gö": 1963,
+ "sı": 1964,
+ "bil": 1965,
+ "lı": 1966,
+ "üz": 1967,
+ "iç": 1968,
+ "iy": 1969,
+ "ım": 1970,
+ "uz": 1971,
+ "cak": 1972,
+ "iş": 1973,
+ "ını": 1974,
+ "iyor": 1975,
+ "baş": 1976,
+ "dü": 1977,
+ "değ": 1978,
+ "kar": 1979,
+ "ev": 1980,
+ "öy": 1981,
+ "bun": 1982,
+ "yap": 1983,
+ "sun": 1984,
+ "gör": 1985,
+ "yı": 1986,
+ "ki": 1987,
+ "ara": 1988,
+ "alı": 1989,
+ "onu": 1990,
+ "çı": 1991,
+ "şey": 1992,
+ "sın": 1993,
+ "kı": 1994,
+ "kad": 1995,
+ "ağ": 1996,
+ "değil": 1997,
+ "ük": 1998,
+ "çok": 1999,
+ "şı": 2000,
+ "ül": 2001,
+ "için": 2002,
+ "eye": 2003,
+ "oldu": 2004,
+ "mış": 2005,
+ "kal": 2006,
+ "mek": 2007,
+ "öyle": 2008,
+ "yordu": 2009,
+ "yüz": 2010,
+ "miş": 2011,
+ "mak": 2012,
+ "ola": 2013,
+ "yan": 2014,
+ "cek": 2015,
+ "yorum": 2016,
+ "bak": 2017,
+ "üm": 2018,
+ "ları": 2019,
+ "oğ": 2020,
+ "kadar": 2021,
+ "arı": 2022,
+ "ında": 2023,
+ "gün": 2024,
+ "yok": 2025,
+ "yer": 2026,
+ "dım": 2027,
+ "daha": 2028,
+ "ına": 2029,
+ "dim": 2030,
+ "bilir": 2031,
+ "iki": 2032,
+ "siz": 2033,
+ "diğ": 2034,
+ "bü": 2035,
+ "düş": 2036,
+ "üç": 2037,
+ "unu": 2038,
+ "aman": 2039,
+ "fak": 2040,
+ "ede": 2041,
+ "sonra": 2042,
+ "hiç": 2043,
+ "aki": 2044,
+ "ğı": 2045,
+ "bul": 2046,
+ "maz": 2047,
+ "anla": 2048,
+ "bura": 2049,
+ "geç": 2050,
+ "maya": 2051,
+ "konu": 2052,
+ "din": 2053,
+ "tek": 2054,
+ "zaman": 2055,
+ "eler": 2056,
+ "öz": 2057,
+ "dır": 2058,
+ "gibi": 2059,
+ "şa": 2060,
+ "leri": 2061,
+ "kim": 2062,
+ "ku": 2063,
+ "fakat": 2064,
+ "yar": 2065,
+ "göz": 2066,
+ "cı": 2067,
+ "yorsun": 2068,
+ "bek": 2069,
+ "inde": 2070,
+ "pek": 2071,
+ "bunu": 2072,
+ "lik": 2073,
+ "iler": 2074,
+ "edi": 2075,
+ "öl": 2076,
+ "sür": 2077,
+ "sır": 2078,
+ "çık": 2079,
+ "sıl": 2080,
+ "alar": 2081,
+ "kes": 2082,
+ "yak": 2083,
+ "çek": 2084,
+ "yıl": 2085,
+ "ecek": 2086,
+ "ız": 2087,
+ "git": 2088,
+ "kap": 2089,
+ "ama": 2090,
+ "ıl": 2091,
+ "ların": 2092,
+ "biz": 2093,
+ "tır": 2094,
+ "oy": 2095,
+ "ancak": 2096,
+ "doğ": 2097,
+ "bana": 2098,
+ "şim": 2099,
+ "başla": 2100,
+ "lü": 2101,
+ "madı": 2102,
+ "beni": 2103,
+ "yük": 2104,
+ "lık": 2105,
+ "beş": 2106,
+ "nasıl": 2107,
+ "tık": 2108,
+ "tür": 2109,
+ "daki": 2110,
+ "ceğ": 2111,
+ "zı": 2112,
+ "iyi": 2113,
+ "dok": 2114,
+ "benim": 2115,
+ "cağ": 2116,
+ "yen": 2117,
+ "şu": 2118,
+ "mez": 2119,
+ "düşün": 2120,
+ "kendi": 2121,
+ "şimdi": 2122,
+ "yol": 2123,
+ "yu": 2124,
+ "iste": 2125,
+ "sek": 2126,
+ "mam": 2127,
+ "söyle": 2128,
+ "dik": 2129,
+ "kur": 2130,
+ "olduğ": 2131,
+ "sını": 2132,
+ "biliyor": 2133,
+ "kan": 2134,
+ "yal": 2135,
+ "meye": 2136,
+ "muş": 2137,
+ "kaç": 2138,
+ "iye": 2139,
+ "tü": 2140,
+ "ef": 2141,
+ "tım": 2142,
+ "evet": 2143,
+ "yet": 2144,
+ "burada": 2145,
+ "tim": 2146,
+ "biraz": 2147,
+ "kor": 2148,
+ "doğru": 2149,
+ "inin": 2150,
+ "kız": 2151,
+ "diye": 2152,
+ "dör": 2153,
+ "etti": 2154,
+ "onun": 2155,
+ "isti": 2156,
+ "ği": 2157,
+ "sana": 2158,
+ "üş": 2159,
+ "arka": 2160,
+ "hayır": 2161,
+ "karşı": 2162,
+ "ile": 2163,
+ "hak": 2164,
+ "ıyor": 2165,
+ "neden": 2166,
+ "sev": 2167,
+ "sız": 2168,
+ "çocu": 2169,
+ "çalı": 2170,
+ "olur": 2171,
+ "bır": 2172,
+ "gir": 2173,
+ "ise": 2174,
+ "ih": 2175,
+ "kır": 2176,
+ "dön": 2177,
+ "böyle": 2178,
+ "seni": 2179,
+ "!\"": 2180,
+ "dört": 2181,
+ "söy": 2182,
+ "oş": 2183,
+ "musun": 2184,
+ "laş": 2185,
+ "ip": 2186,
+ "kay": 2187,
+ "hem": 2188,
+ "büyük": 2189,
+ "aç": 2190,
+ "bırak": 2191,
+ "misin": 2192,
+ "söz": 2193,
+ "değiş": 2194,
+ "ünü": 2195,
+ "gül": 2196,
+ "kö": 2197,
+ "karı": 2198,
+ "tamam": 2199,
+ "olu": 2200,
+ "yeni": 2201,
+ "lam": 2202,
+ "mıştı": 2203,
+ "yaş": 2204,
+ "iniz": 2205,
+ "kadın": 2206,
+ "bunun": 2207,
+ "mey": 2208,
+ "altı": 2209,
+ "yi": 2210,
+ "inden": 2211,
+ "senin": 2212,
+ "yat": 2213,
+ "top": 2214,
+ "isi": 2215,
+ "dün": 2216,
+ "hiçbir": 2217,
+ "yon": 2218,
+ "dın": 2219,
+ "tün": 2220,
+ "başka": 2221,
+ "hep": 2222,
+ "irmi": 2223,
+ "devam": 2224,
+ "olacak": 2225,
+ "artık": 2226,
+ "durum": 2227,
+ "imiz": 2228,
+ "üzel": 2229,
+ "lerini": 2230,
+ "sağ": 2231,
+ "gerek": 2232,
+ "yirmi": 2233,
+ "şek": 2234,
+ "bağ": 2235,
+ "lara": 2236,
+ "yür": 2237,
+ "ması": 2238,
+ "katı": 2239,
+ "dedi": 2240,
+ "gü": 2241,
+ "sorun": 2242,
+ "üne": 2243,
+ "mız": 2244,
+ "yapı": 2245,
+ "mil": 2246,
+ "ğını": 2247,
+ "tara": 2248,
+ "vardı": 2249,
+ "konuş": 2250,
+ "arak": 2251,
+ "larak": 2252,
+ "çocuk": 2253,
+ "bütün": 2254,
+ "ley": 2255,
+ "dür": 2256,
+ "güzel": 2257,
+ "ayı": 2258,
+ "yapa": 2259,
+ "nı": 2260,
+ "ayr": 2261,
+ "öne": 2262,
+ "yordum": 2263,
+ "ban": 2264,
+ "i̇ş": 2265,
+ "dum": 2266,
+ "yorlar": 2267,
+ "larını": 2268,
+ "çıkar": 2269,
+ "zan": 2270,
+ "seç": 2271,
+ "liyor": 2272,
+ "tak": 2273,
+ "şık": 2274,
+ "tekrar": 2275,
+ "aş": 2276,
+ "eş": 2277,
+ "mişti": 2278,
+ "kin": 2279,
+ "imi": 2280,
+ "eğ": 2281,
+ "gidi": 2282,
+ "leş": 2283,
+ "başladı": 2284,
+ "gide": 2285,
+ "otur": 2286,
+ "dde": 2287,
+ "ından": 2288,
+ "üzer": 2289,
+ "ının": 2290,
+ "nız": 2291,
+ "uy": 2292,
+ "yedi": 2293,
+ "kat": 2294,
+ "olarak": 2295,
+ "ladı": 2296,
+ "yalnız": 2297,
+ "bah": 2298,
+ "iyet": 2299,
+ "sak": 2300,
+ "açık": 2301,
+ "sında": 2302,
+ "...": 2303,
+ "insan": 2304,
+ "aynı": 2305,
+ "eder": 2306,
+ "istan": 2307,
+ "uzun": 2308,
+ "geri": 2309,
+ "erek": 2310,
+ "olan": 2311,
+ "gerçek": 2312,
+ "alan": 2313,
+ "dış": 2314,
+ "alık": 2315,
+ "fark": 2316,
+ "üst": 2317,
+ "sade": 2318,
+ "kiş": 2319,
+ "ldı": 2320,
+ "zor": 2321,
+ "etir": 2322,
+ "herkes": 2323,
+ "ömer": 2324,
+ "unda": 2325,
+ "haf": 2326,
+ "buna": 2327,
+ "ydı": 2328,
+ "peki": 2329,
+ "adam": 2330,
+ "haz": 2331,
+ "sına": 2332,
+ "kapı": 2333,
+ "görüş": 2334,
+ "sadece": 2335,
+ "aldı": 2336,
+ "geldi": 2337,
+ "rz": 2338,
+ "sz": 2339,
+ "cz": 2340,
+ "ię": 2341,
+ "dz": 2342,
+ "ał": 2343,
+ "się": 2344,
+ "rze": 2345,
+ "że": 2346,
+ "wy": 2347,
+ "rzy": 2348,
+ "ła": 2349,
+ "ło": 2350,
+ "ny": 2351,
+ "dzie": 2352,
+ "dzi": 2353,
+ "czy": 2354,
+ "cie": 2355,
+ "prze": 2356,
+ "dy": 2357,
+ "kie": 2358,
+ "ry": 2359,
+ "ją": 2360,
+ "ów": 2361,
+ "przy": 2362,
+ "mie": 2363,
+ "szy": 2364,
+ "cze": 2365,
+ "bie": 2366,
+ "cy": 2367,
+ "nia": 2368,
+ "ści": 2369,
+ "sze": 2370,
+ "jest": 2371,
+ "ży": 2372,
+ "ną": 2373,
+ "któ": 2374,
+ "ała": 2375,
+ "mnie": 2376,
+ "ły": 2377,
+ "cza": 2378,
+ "jak": 2379,
+ "roz": 2380,
+ "ró": 2381,
+ "zna": 2382,
+ "łu": 2383,
+ "ść": 2384,
+ "wia": 2385,
+ "wszy": 2386,
+ "spo": 2387,
+ "gdy": 2388,
+ "wał": 2389,
+ "wię": 2390,
+ "łem": 2391,
+ "ję": 2392,
+ "sk": 2393,
+ "rę": 2394,
+ "dob": 2395,
+ "już": 2396,
+ "bę": 2397,
+ "ałem": 2398,
+ "sza": 2399,
+ "pod": 2400,
+ "dla": 2401,
+ "pan": 2402,
+ "nę": 2403,
+ "może": 2404,
+ "śli": 2405,
+ "ało": 2406,
+ "lko": 2407,
+ "nych": 2408,
+ "powie": 2409,
+ "cię": 2410,
+ "tylko": 2411,
+ "naj": 2412,
+ "tego": 2413,
+ "ski": 2414,
+ "nego": 2415,
+ "wszyst": 2416,
+ "szcze": 2417,
+ "jed": 2418,
+ "jej": 2419,
+ "two": 2420,
+ "ąd": 2421,
+ "śmy": 2422,
+ "czę": 2423,
+ "wać": 2424,
+ "jego": 2425,
+ "ża": 2426,
+ "sy": 2427,
+ "praw": 2428,
+ "tym": 2429,
+ "który": 2430,
+ "ały": 2431,
+ "trze": 2432,
+ "niej": 2433,
+ "nym": 2434,
+ "gło": 2435,
+ "jąc": 2436,
+ "mówi": 2437,
+ "ska": 2438,
+ "nej": 2439,
+ "słu": 2440,
+ "wła": 2441,
+ "będzie": 2442,
+ "dę": 2443,
+ "pó": 2444,
+ "bez": 2445,
+ "nic": 2446,
+ "pła": 2447,
+ "ście": 2448,
+ "są": 2449,
+ "trzy": 2450,
+ "kiem": 2451,
+ "był": 2452,
+ "mog": 2453,
+ "robi": 2454,
+ "tam": 2455,
+ "mię": 2456,
+ "zy": 2457,
+ "pew": 2458,
+ "myś": 2459,
+ "przed": 2460,
+ "sko": 2461,
+ "które": 2462,
+ "lę": 2463,
+ "wsze": 2464,
+ "ąc": 2465,
+ "było": 2466,
+ "sobie": 2467,
+ "py": 2468,
+ "cią": 2469,
+ "jeszcze": 2470,
+ "tę": 2471,
+ "czas": 2472,
+ "szę": 2473,
+ "gł": 2474,
+ "kę": 2475,
+ "czu": 2476,
+ "przez": 2477,
+ "sło": 2478,
+ "wz": 2479,
+ "kto": 2480,
+ "ków": 2481,
+ "czo": 2482,
+ "liśmy": 2483,
+ "więc": 2484,
+ "rą": 2485,
+ "wó": 2486,
+ "rza": 2487,
+ "ności": 2488,
+ "wet": 2489,
+ "nął": 2490,
+ "śmie": 2491,
+ "nawet": 2492,
+ "musi": 2493,
+ "swo": 2494,
+ "tej": 2495,
+ "wą": 2496,
+ "wu": 2497,
+ "wią": 2498,
+ "niu": 2499,
+ "czą": 2500,
+ "dzo": 2501,
+ "skie": 2502,
+ "jeśli": 2503,
+ "czego": 2504,
+ "chy": 2505,
+ "dł": 2506,
+ "tych": 2507,
+ "bym": 2508,
+ "żo": 2509,
+ "eś": 2510,
+ "sią": 2511,
+ "kiedy": 2512,
+ "wró": 2513,
+ "dze": 2514,
+ "dro": 2515,
+ "rów": 2516,
+ "pani": 2517,
+ "kul": 2518,
+ "nad": 2519,
+ "chwi": 2520,
+ "nim": 2521,
+ "być": 2522,
+ "chodzi": 2523,
+ "nio": 2524,
+ "dobrze": 2525,
+ "teraz": 2526,
+ "wokul": 2527,
+ "coś": 2528,
+ "kł": 2529,
+ "pier": 2530,
+ "gdzie": 2531,
+ "dzy": 2532,
+ "pię": 2533,
+ "dź": 2534,
+ "ką": 2535,
+ "gó": 2536,
+ "zda": 2537,
+ "chce": 2538,
+ "stę": 2539,
+ "świa": 2540,
+ "wszystko": 2541,
+ "peł": 2542,
+ "wiem": 2543,
+ "wiel": 2544,
+ "każ": 2545,
+ "rzu": 2546,
+ "sły": 2547,
+ "jedna": 2548,
+ "myśl": 2549,
+ "mój": 2550,
+ "jestem": 2551,
+ "óż": 2552,
+ "miej": 2553,
+ "moż": 2554,
+ "kła": 2555,
+ "resz": 2556,
+ "dłu": 2557,
+ "stwo": 2558,
+ "nię": 2559,
+ "masz": 2560,
+ "żeby": 2561,
+ "niem": 2562,
+ "jakie": 2563,
+ "sty": 2564,
+ "nią": 2565,
+ "wej": 2566,
+ "oj": 2567,
+ "sła": 2568,
+ "ność": 2569,
+ "zło": 2570,
+ "szczę": 2571,
+ "lej": 2572,
+ "wego": 2573,
+ "cał": 2574,
+ "dział": 2575,
+ "kich": 2576,
+ "dza": 2577,
+ "dzię": 2578,
+ "oczy": 2579,
+ "zosta": 2580,
+ "czło": 2581,
+ "nam": 2582,
+ "kil": 2583,
+ "szu": 2584,
+ "wę": 2585,
+ "miał": 2586,
+ "strze": 2587,
+ "cej": 2588,
+ "ej": 2589,
+ "znaj": 2590,
+ "dać": 2591,
+ "miejs": 2592,
+ "kró": 2593,
+ "kry": 2594,
+ "bardzo": 2595,
+ "śnie": 2596,
+ "lą": 2597,
+ "gie": 2598,
+ "ciebie": 2599,
+ "dni": 2600,
+ "potrze": 2601,
+ "wokulski": 2602,
+ "uwa": 2603,
+ "umie": 2604,
+ "jednak": 2605,
+ "kra": 2606,
+ "wróci": 2607,
+ "człowie": 2608,
+ "czyć": 2609,
+ "była": 2610,
+ "żeli": 2611,
+ "mę": 2612,
+ "cę": 2613,
+ "zrobi": 2614,
+ "mogę": 2615,
+ "prowa": 2616,
+ "rem": 2617,
+ "niech": 2618,
+ "cznie": 2619,
+ "kro": 2620,
+ "tą": 2621,
+ "chci": 2622,
+ "bro": 2623,
+ "dzieć": 2624,
+ "szą": 2625,
+ "pad": 2626,
+ "trz": 2627,
+ "jem": 2628,
+ "tów": 2629,
+ "dru": 2630,
+ "taj": 2631,
+ "rzekł": 2632,
+ "niego": 2633,
+ "takie": 2634,
+ "wała": 2635,
+ "towa": 2636,
+ "kapła": 2637,
+ "widzi": 2638,
+ "podob": 2639,
+ "dzę": 2640,
+ "tał": 2641,
+ "stęp": 2642,
+ "bą": 2643,
+ "poko": 2644,
+ "wem": 2645,
+ "gę": 2646,
+ "aby": 2647,
+ "albo": 2648,
+ "spra": 2649,
+ "zno": 2650,
+ "smo": 2651,
+ "jesz": 2652,
+ "księ": 2653,
+ "jesteś": 2654,
+ "poz": 2655,
+ "nigdy": 2656,
+ "ksią": 2657,
+ "cóż": 2658,
+ "ws": 2659,
+ "pow": 2660,
+ "tka": 2661,
+ "świe": 2662,
+ "szka": 2663,
+ "samo": 2664,
+ "sł": 2665,
+ "rzę": 2666,
+ "nale": 2667,
+ "chcesz": 2668,
+ "nik": 2669,
+ "pę": 2670,
+ "chyba": 2671,
+ "ciąg": 2672,
+ "jący": 2673,
+ "woj": 2674,
+ "nasze": 2675,
+ "mniej": 2676,
+ "więcej": 2677,
+ "zwy": 2678,
+ "osta": 2679,
+ "waż": 2680,
+ "śmier": 2681,
+ "wier": 2682,
+ "dzą": 2683,
+ "zaś": 2684,
+ "gdyby": 2685,
+ "jaki": 2686,
+ "wol": 2687,
+ "win": 2688,
+ "dą": 2689,
+ "ścia": 2690,
+ "rozma": 2691,
+ "wal": 2692,
+ "panie": 2693,
+ "star": 2694,
+ "kaz": 2695,
+ "jeżeli": 2696,
+ "wra": 2697,
+ "koń": 2698,
+ "siebie": 2699,
+ "znowu": 2700,
+ "czem": 2701,
+ "stwa": 2702,
+ "isto": 2703,
+ "pół": 2704,
+ "dał": 2705,
+ "kobie": 2706,
+ "ałam": 2707,
+ "wych": 2708,
+ "cesa": 2709,
+ "nich": 2710,
+ "zawsze": 2711,
+ "dzić": 2712,
+ "też": 2713,
+ "lepie": 2714,
+ "proszę": 2715,
+ "kre": 2716,
+ "twa": 2717,
+ "łą": 2718,
+ "chu": 2719,
+ "cą": 2720,
+ "prz": 2721,
+ "łe": 2722,
+ "szedł": 2723,
+ "odpowie": 2724,
+ "myśli": 2725,
+ "świą": 2726,
+ "ź": 2727,
+ "ł": 2728,
+ "&": 2729,
+ "=": 2730,
+ "ă": 2731,
+ "đ": 2732,
+ "ţ": 2733,
+ "–": 2734,
+ "‘": 2735,
+ "ij": 2736,
+ "aa": 2737,
+ "een": 2738,
+ "het": 2739,
+ "aar": 2740,
+ "oor": 2741,
+ "ijn": 2742,
+ "dat": 2743,
+ "oe": 2744,
+ "ijk": 2745,
+ "aan": 2746,
+ "voor": 2747,
+ "iet": 2748,
+ "zijn": 2749,
+ "niet": 2750,
+ "oo": 2751,
+ "moet": 2752,
+ "heb": 2753,
+ "uit": 2754,
+ "wij": 2755,
+ "aat": 2756,
+ "lijk": 2757,
+ "sl": 2758,
+ "daar": 2759,
+ "deze": 2760,
+ "worden": 2761,
+ "moeten": 2762,
+ "onder": 2763,
+ "hebben": 2764,
+ "ook": 2765,
+ "ct": 2766,
+ "nog": 2767,
+ "aal": 2768,
+ "eer": 2769,
+ "bij": 2770,
+ "mijn": 2771,
+ "kom": 2772,
+ "atie": 2773,
+ "eft": 2774,
+ "kel": 2775,
+ "rij": 2776,
+ "heid": 2777,
+ "af": 2778,
+ "stel": 2779,
+ "maar": 2780,
+ "wee": 2781,
+ "heeft": 2782,
+ "waar": 2783,
+ "eren": 2784,
+ "wat": 2785,
+ "wil": 2786,
+ "aag": 2787,
+ "bet": 2788,
+ "hij": 2789,
+ "kun": 2790,
+ "uw": 2791,
+ "dt": 2792,
+ "door": 2793,
+ "tij": 2794,
+ "ond": 2795,
+ "geen": 2796,
+ "gev": 2797,
+ "veel": 2798,
+ "naar": 2799,
+ "aten": 2800,
+ "kunnen": 2801,
+ "echt": 2802,
+ "goe": 2803,
+ "twee": 2804,
+ "delijk": 2805,
+ "uur": 2806,
+ "toe": 2807,
+ "meer": 2808,
+ "onze": 2809,
+ "tijd": 2810,
+ "hoe": 2811,
+ "tot": 2812,
+ "zou": 2813,
+ "aak": 2814,
+ "amen": 2815,
+ "woor": 2816,
+ "wordt": 2817,
+ "gelijk": 2818,
+ "gaan": 2819,
+ "ker": 2820,
+ "eld": 2821,
+ "hou": 2822,
+ "zel": 2823,
+ "tegen": 2824,
+ "komen": 2825,
+ "werk": 2826,
+ "goed": 2827,
+ "zal": 2828,
+ "zij": 2829,
+ "slag": 2830,
+ "zien": 2831,
+ "echter": 2832,
+ "itie": 2833,
+ "tie": 2834,
+ "elijk": 2835,
+ "ische": 2836,
+ "belan": 2837,
+ "haar": 2838,
+ "vr": 2839,
+ "grijk": 2840,
+ "doen": 2841,
+ "land": 2842,
+ "belangrijk": 2843,
+ "open": 2844,
+ "ctie": 2845,
+ "zelf": 2846,
+ "mij": 2847,
+ "iteit": 2848,
+ "stem": 2849,
+ "mee": 2850,
+ "aren": 2851,
+ "dien": 2852,
+ "gaat": 2853,
+ "prob": 2854,
+ "moe": 2855,
+ "ullen": 2856,
+ "zich": 2857,
+ "daarom": 2858,
+ "orm": 2859,
+ "staat": 2860,
+ "zit": 2861,
+ "dui": 2862,
+ "dus": 2863,
+ "ds": 2864,
+ "verslag": 2865,
+ "kelijk": 2866,
+ "proble": 2867,
+ "schap": 2868,
+ "gd": 2869,
+ "hun": 2870,
+ "erd": 2871,
+ "zet": 2872,
+ "staan": 2873,
+ "maal": 2874,
+ "inder": 2875,
+ "eid": 2876,
+ "kken": 2877,
+ "ged": 2878,
+ "zullen": 2879,
+ "mensen": 2880,
+ "jaar": 2881,
+ "regel": 2882,
+ "ieder": 2883,
+ "volgen": 2884,
+ "geven": 2885,
+ "even": 2886,
+ "blij": 2887,
+ "ië": 2888,
+ "uwe": 2889,
+ "maken": 2890,
+ "oek": 2891,
+ "nieuwe": 2892,
+ "baar": 2893,
+ "andere": 2894,
+ "ruik": 2895,
+ "agen": 2896,
+ "ouw": 2897,
+ "willen": 2898,
+ "aakt": 2899,
+ "hoo": 2900,
+ "anden": 2901,
+ "lig": 2902,
+ "samen": 2903,
+ "zeer": 2904,
+ "duidelijk": 2905,
+ "antwoor": 2906,
+ "heel": 2907,
+ "punt": 2908,
+ "houden": 2909,
+ "vraag": 2910,
+ "gele": 2911,
+ "eens": 2912,
+ "besch": 2913,
+ "omen": 2914,
+ "erg": 2915,
+ "doel": 2916,
+ "dag": 2917,
+ "uren": 2918,
+ "ings": 2919,
+ "oren": 2920,
+ "delen": 2921,
+ "steun": 2922,
+ "innen": 2923,
+ "pol": 2924,
+ "oon": 2925,
+ "sn": 2926,
+ "zonder": 2927,
+ "nodig": 2928,
+ "alleen": 2929,
+ "mid": 2930,
+ "ragen": 2931,
+ "iets": 2932,
+ "versch": 2933,
+ "gebruik": 2934,
+ "rouw": 2935,
+ "stellen": 2936,
+ "menten": 2937,
+ "eerste": 2938,
+ "laat": 2939,
+ "groot": 2940,
+ "ood": 2941,
+ "toch": 2942,
+ "laten": 2943,
+ "aard": 2944,
+ "sle": 2945,
+ "deel": 2946,
+ "plaat": 2947,
+ "ree": 2948,
+ "betre": 2949,
+ "lid": 2950,
+ "uiten": 2951,
+ "racht": 2952,
+ "beleid": 2953,
+ "stie": 2954,
+ "staten": 2955,
+ "ggen": 2956,
+ "reken": 2957,
+ "alen": 2958,
+ "ming": 2959,
+ "mogelijk": 2960,
+ "grote": 2961,
+ "altijd": 2962,
+ "enkel": 2963,
+ "wik": 2964,
+ "politie": 2965,
+ "elk": 2966,
+ "handel": 2967,
+ "kwe": 2968,
+ "maat": 2969,
+ "elen": 2970,
+ "vrij": 2971,
+ "jes": 2972,
+ "aam": 2973,
+ "huis": 2974,
+ "weer": 2975,
+ "lidstaten": 2976,
+ "king": 2977,
+ "kle": 2978,
+ "bed": 2979,
+ "geval": 2980,
+ "wikkel": 2981,
+ "kwestie": 2982,
+ "stee": 2983,
+ "hel": 2984,
+ "komst": 2985,
+ "iden": 2986,
+ "eerd": 2987,
+ "tweede": 2988,
+ "probleem": 2989,
+ "ussen": 2990,
+ "snel": 2991,
+ "tig": 2992,
+ "ult": 2993,
+ "nemen": 2994,
+ "commis": 2995,
+ "verschil": 2996,
+ "zoek": 2997,
+ "krij": 2998,
+ "graag": 2999,
+ "denk": 3000,
+ "landen": 3001,
+ "reden": 3002,
+ "besl": 3003,
+ "oeg": 3004,
+ "beter": 3005,
+ "heden": 3006,
+ "mag": 3007,
+ "boven": 3008,
+ "cont": 3009,
+ "fd": 3010,
+ "hele": 3011,
+ "vier": 3012,
+ "gez": 3013,
+ "kw": 3014,
+ "aas": 3015,
+ "ontwikkel": 3016,
+ "drie": 3017,
+ "vaak": 3018,
+ "plaats": 3019,
+ "gang": 3020,
+ "ijf": 3021,
+ "natuur": 3022,
+ "tussen": 3023,
+ "bat": 3024,
+ "komt": 3025,
+ "wacht": 3026,
+ "aad": 3027,
+ "achter": 3028,
+ "gebie": 3029,
+ "verk": 3030,
+ "ligt": 3031,
+ "nieuw": 3032,
+ "vand": 3033,
+ "ý": 3034,
+ "ď": 3035,
+ "ě": 3036,
+ "ř": 3037,
+ "ť": 3038,
+ "ů": 3039,
+ "„": 3040,
+ "ní": 3041,
+ "ně": 3042,
+ "ře": 3043,
+ "ná": 3044,
+ "vě": 3045,
+ "vá": 3046,
+ "rá": 3047,
+ "vy": 3048,
+ "mě": 3049,
+ "ři": 3050,
+ "ří": 3051,
+ "že": 3052,
+ "jí": 3053,
+ "vý": 3054,
+ "ji": 3055,
+ "dě": 3056,
+ "če": 3057,
+ "tě": 3058,
+ "ky": 3059,
+ "še": 3060,
+ "ké": 3061,
+ "ší": 3062,
+ "pře": 3063,
+ "ví": 3064,
+ "ný": 3065,
+ "ži": 3066,
+ "má": 3067,
+ "cí": 3068,
+ "zá": 3069,
+ "ské": 3070,
+ "dá": 3071,
+ "byl": 3072,
+ "tí": 3073,
+ "pří": 3074,
+ "při": 3075,
+ "či": 3076,
+ "vní": 3077,
+ "ča": 3078,
+ "dí": 3079,
+ "dní": 3080,
+ "ká": 3081,
+ "nou": 3082,
+ "vět": 3083,
+ "pě": 3084,
+ "kou": 3085,
+ "ých": 3086,
+ "bě": 3087,
+ "prá": 3088,
+ "jako": 3089,
+ "ží": 3090,
+ "zí": 3091,
+ "jsou": 3092,
+ "jsem": 3093,
+ "lní": 3094,
+ "cké": 3095,
+ "vat": 3096,
+ "před": 3097,
+ "hla": 3098,
+ "stá": 3099,
+ "čí": 3100,
+ "ši": 3101,
+ "kla": 3102,
+ "ště": 3103,
+ "lou": 3104,
+ "mů": 3105,
+ "chá": 3106,
+ "pů": 3107,
+ "také": 3108,
+ "dů": 3109,
+ "nost": 3110,
+ "tře": 3111,
+ "sku": 3112,
+ "vše": 3113,
+ "tní": 3114,
+ "byla": 3115,
+ "ční": 3116,
+ "jeho": 3117,
+ "bý": 3118,
+ "vání": 3119,
+ "ných": 3120,
+ "tři": 3121,
+ "vz": 3122,
+ "stře": 3123,
+ "dva": 3124,
+ "hle": 3125,
+ "čá": 3126,
+ "nosti": 3127,
+ "vš": 3128,
+ "hra": 3129,
+ "jen": 3130,
+ "slo": 3131,
+ "však": 3132,
+ "kdy": 3133,
+ "bylo": 3134,
+ "bude": 3135,
+ "jší": 3136,
+ "vých": 3137,
+ "ním": 3138,
+ "sm": 3139,
+ "koli": 3140,
+ "rů": 3141,
+ "může": 3142,
+ "není": 3143,
+ "hod": 3144,
+ "bí": 3145,
+ "tý": 3146,
+ "stě": 3147,
+ "uje": 3148,
+ "sá": 3149,
+ "pět": 3150,
+ "krá": 3151,
+ "tom": 3152,
+ "ství": 3153,
+ "vně": 3154,
+ "sed": 3155,
+ "své": 3156,
+ "pí": 3157,
+ "musí": 3158,
+ "už": 3159,
+ "tím": 3160,
+ "jící": 3161,
+ "jedno": 3162,
+ "čas": 3163,
+ "čty": 3164,
+ "ský": 3165,
+ "evro": 3166,
+ "toho": 3167,
+ "hy": 3168,
+ "kter": 3169,
+ "rní": 3170,
+ "stí": 3171,
+ "svě": 3172,
+ "pak": 3173,
+ "všech": 3174,
+ "ků": 3175,
+ "ng": 3176,
+ "ád": 3177,
+ "chází": 3178,
+ "být": 3179,
+ "první": 3180,
+ "mno": 3181,
+ "ského": 3182,
+ "pá": 3183,
+ "nebo": 3184,
+ "kem": 3185,
+ "sla": 3186,
+ "ného": 3187,
+ "zde": 3188,
+ "další": 3189,
+ "řa": 3190,
+ "čtyři": 3191,
+ "hrá": 3192,
+ "druh": 3193,
+ "lně": 3194,
+ "vla": 3195,
+ "ských": 3196,
+ "ško": 3197,
+ "půso": 3198,
+ "proto": 3199,
+ "vů": 3200,
+ "ská": 3201,
+ "šest": 3202,
+ "dně": 3203,
+ "ještě": 3204,
+ "mezi": 3205,
+ "několi": 3206,
+ "již": 3207,
+ "čně": 3208,
+ "slu": 3209,
+ "zná": 3210,
+ "sedm": 3211,
+ "vlá": 3212,
+ "osm": 3213,
+ "byly": 3214,
+ "vám": 3215,
+ "cký": 3216,
+ "tech": 3217,
+ "ději": 3218,
+ "velmi": 3219,
+ "leži": 3220,
+ "vala": 3221,
+ "lý": 3222,
+ "tvo": 3223,
+ "spole": 3224,
+ "stup": 3225,
+ "mož": 3226,
+ "evrop": 3227,
+ "stal": 3228,
+ "jde": 3229,
+ "rodi": 3230,
+ "její": 3231,
+ "poli": 3232,
+ "devět": 3233,
+ "sme": 3234,
+ "až": 3235,
+ "této": 3236,
+ "tento": 3237,
+ "kaž": 3238,
+ "nula": 3239,
+ "bych": 3240,
+ "moc": 3241,
+ "stou": 3242,
+ "kdo": 3243,
+ "zd": 3244,
+ "praco": 3245,
+ "tomu": 3246,
+ "ným": 3247,
+ "živo": 3248,
+ "zem": 3249,
+ "násle": 3250,
+ "sky": 3251,
+ "jich": 3252,
+ "měl": 3253,
+ "děla": 3254,
+ "jsme": 3255,
+ "nice": 3256,
+ "stej": 3257,
+ "stní": 3258,
+ "náro": 3259,
+ "nit": 3260,
+ "později": 3261,
+ "tako": 3262,
+ "nce": 3263,
+ "čer": 3264,
+ "ším": 3265,
+ "něco": 3266,
+ "vál": 3267,
+ "řej": 3268,
+ "krát": 3269,
+ "ální": 3270,
+ "asi": 3271,
+ "které": 3272,
+ "stav": 3273,
+ "mají": 3274,
+ "mys": 3275,
+ "době": 3276,
+ "sně": 3277,
+ "zku": 3278,
+ "tů": 3279,
+ "chod": 3280,
+ "spě": 3281,
+ "jejich": 3282,
+ "součas": 3283,
+ "vali": 3284,
+ "kte": 3285,
+ "prů": 3286,
+ "zení": 3287,
+ "pat": 3288,
+ "potře": 3289,
+ "dnes": 3290,
+ "zemí": 3291,
+ "znam": 3292,
+ "mám": 3293,
+ "tedy": 3294,
+ "hlavní": 3295,
+ "použí": 3296,
+ "bní": 3297,
+ "vede": 3298,
+ "lep": 3299,
+ "jek": 3300,
+ "prav": 3301,
+ "politi": 3302,
+ "dne": 3303,
+ "čení": 3304,
+ "než": 3305,
+ "děl": 3306,
+ "čo": 3307,
+ "cích": 3308,
+ "sté": 3309,
+ "dlou": 3310,
+ "několik": 3311,
+ "vyu": 3312,
+ "ckých": 3313,
+ "nové": 3314,
+ "čin": 3315,
+ "dělá": 3316,
+ "ký": 3317,
+ "obla": 3318,
+ "podle": 3319,
+ "důleži": 3320,
+ "poku": 3321,
+ "kone": 3322,
+ "dý": 3323,
+ "dvě": 3324,
+ "žád": 3325,
+ "nout": 3326,
+ "tku": 3327,
+ "tvr": 3328,
+ "ckého": 3329,
+ "rov": 3330,
+ "tele": 3331,
+ "psa": 3332,
+ "svět": 3333,
+ "tivní": 3334,
+ "dosta": 3335,
+ "šel": 3336,
+ "druhé": 3337,
+ "skou": 3338,
+ "žo": 3339,
+ "jedná": 3340,
+ "význam": 3341,
+ "problé": 3342,
+ "publi": 3343,
+ "ván": 3344,
+ "odpo": 3345,
+ "podpo": 3346,
+ "dle": 3347,
+ "jaké": 3348,
+ "šení": 3349,
+ "vím": 3350,
+ "během": 3351,
+ "nachází": 3352,
+ "slou": 3353,
+ "pouze": 3354,
+ "otá": 3355,
+ "plo": 3356,
+ "tové": 3357,
+ "větši": 3358,
+ "komi": 3359,
+ "vají": 3360,
+ "tyto": 3361,
+ "zápa": 3362,
+ "změ": 3363,
+ "moh": 3364,
+ "více": 3365,
+ "společ": 3366,
+ "auto": 3367,
+ "proti": 3368,
+ "dět": 3369,
+ "cháze": 3370,
+ "žel": 3371,
+ "«": 3372,
+ "»": 3373,
+ "а": 3374,
+ "б": 3375,
+ "в": 3376,
+ "г": 3377,
+ "д": 3378,
+ "е": 3379,
+ "ж": 3380,
+ "з": 3381,
+ "и": 3382,
+ "й": 3383,
+ "к": 3384,
+ "л": 3385,
+ "м": 3386,
+ "н": 3387,
+ "о": 3388,
+ "п": 3389,
+ "р": 3390,
+ "с": 3391,
+ "т": 3392,
+ "у": 3393,
+ "ф": 3394,
+ "х": 3395,
+ "ц": 3396,
+ "ч": 3397,
+ "ш": 3398,
+ "щ": 3399,
+ "ъ": 3400,
+ "ы": 3401,
+ "ь": 3402,
+ "э": 3403,
+ "ю": 3404,
+ "я": 3405,
+ "ё": 3406,
+ "‑": 3407,
+ "−": 3408,
+ "ст": 3409,
+ "ен": 3410,
+ "но": 3411,
+ "на": 3412,
+ "пр": 3413,
+ "то": 3414,
+ "по": 3415,
+ "ра": 3416,
+ "го": 3417,
+ "ко": 3418,
+ "не": 3419,
+ "во": 3420,
+ "ва": 3421,
+ "ет": 3422,
+ "ер": 3423,
+ "ни": 3424,
+ "ел": 3425,
+ "ит": 3426,
+ "ны": 3427,
+ "за": 3428,
+ "ро": 3429,
+ "ени": 3430,
+ "ка": 3431,
+ "ли": 3432,
+ "ем": 3433,
+ "да": 3434,
+ "об": 3435,
+ "ла": 3436,
+ "до": 3437,
+ "ся": 3438,
+ "ть": 3439,
+ "от": 3440,
+ "ло": 3441,
+ "ль": 3442,
+ "ед": 3443,
+ "со": 3444,
+ "ми": 3445,
+ "ре": 3446,
+ "мо": 3447,
+ "ци": 3448,
+ "про": 3449,
+ "та": 3450,
+ "это": 3451,
+ "ки": 3452,
+ "ру": 3453,
+ "при": 3454,
+ "ти": 3455,
+ "се": 3456,
+ "ста": 3457,
+ "вы": 3458,
+ "мы": 3459,
+ "ви": 3460,
+ "бы": 3461,
+ "ма": 3462,
+ "ес": 3463,
+ "ля": 3464,
+ "сти": 3465,
+ "ле": 3466,
+ "что": 3467,
+ "ме": 3468,
+ "ри": 3469,
+ "ча": 3470,
+ "од": 3471,
+ "ей": 3472,
+ "ель": 3473,
+ "ения": 3474,
+ "га": 3475,
+ "ну": 3476,
+ "си": 3477,
+ "па": 3478,
+ "раз": 3479,
+ "бо": 3480,
+ "сто": 3481,
+ "су": 3482,
+ "са": 3483,
+ "ду": 3484,
+ "его": 3485,
+ "ест": 3486,
+ "ин": 3487,
+ "ить": 3488,
+ "из": 3489,
+ "же": 3490,
+ "му": 3491,
+ "пер": 3492,
+ "под": 3493,
+ "ение": 3494,
+ "сь": 3495,
+ "ку": 3496,
+ "пред": 3497,
+ "ного": 3498,
+ "ных": 3499,
+ "вер": 3500,
+ "те": 3501,
+ "ной": 3502,
+ "ции": 3503,
+ "де": 3504,
+ "ры": 3505,
+ "дел": 3506,
+ "лю": 3507,
+ "ве": 3508,
+ "он": 3509,
+ "мен": 3510,
+ "ги": 3511,
+ "ня": 3512,
+ "бу": 3513,
+ "пра": 3514,
+ "все": 3515,
+ "ется": 3516,
+ "сть": 3517,
+ "жа": 3518,
+ "дол": 3519,
+ "жи": 3520,
+ "бе": 3521,
+ "кон": 3522,
+ "сл": 3523,
+ "ши": 3524,
+ "ди": 3525,
+ "ств": 3526,
+ "ско": 3527,
+ "ные": 3528,
+ "чи": 3529,
+ "ют": 3530,
+ "дер": 3531,
+ "стра": 3532,
+ "ты": 3533,
+ "ход": 3534,
+ "щи": 3535,
+ "зо": 3536,
+ "зна": 3537,
+ "ности": 3538,
+ "чес": 3539,
+ "вля": 3540,
+ "вать": 3541,
+ "ор": 3542,
+ "пол": 3543,
+ "вет": 3544,
+ "так": 3545,
+ "ша": 3546,
+ "ту": 3547,
+ "сво": 3548,
+ "пре": 3549,
+ "она": 3550,
+ "итель": 3551,
+ "ный": 3552,
+ "сло": 3553,
+ "как": 3554,
+ "вл": 3555,
+ "ность": 3556,
+ "хо": 3557,
+ "мож": 3558,
+ "пе": 3559,
+ "для": 3560,
+ "ния": 3561,
+ "ное": 3562,
+ "рас": 3563,
+ "долж": 3564,
+ "дар": 3565,
+ "тель": 3566,
+ "ска": 3567,
+ "пу": 3568,
+ "ство": 3569,
+ "кото": 3570,
+ "раб": 3571,
+ "ее": 3572,
+ "род": 3573,
+ "эти": 3574,
+ "соб": 3575,
+ "ору": 3576,
+ "жен": 3577,
+ "ным": 3578,
+ "ити": 3579,
+ "ние": 3580,
+ "ком": 3581,
+ "дет": 3582,
+ "сту": 3583,
+ "гу": 3584,
+ "пи": 3585,
+ "меж": 3586,
+ "ению": 3587,
+ "тер": 3588,
+ "работ": 3589,
+ "воз": 3590,
+ "ция": 3591,
+ "кой": 3592,
+ "щест": 3593,
+ "гра": 3594,
+ "зи": 3595,
+ "ря": 3596,
+ "между": 3597,
+ "ства": 3598,
+ "вс": 3599,
+ "ело": 3600,
+ "ше": 3601,
+ "мер": 3602,
+ "ба": 3603,
+ "зы": 3604,
+ "лу": 3605,
+ "аль": 3606,
+ "дей": 3607,
+ "гла": 3608,
+ "народ": 3609,
+ "кти": 3610,
+ "предста": 3611,
+ "лся": 3612,
+ "явля": 3613,
+ "ски": 3614,
+ "нов": 3615,
+ "един": 3616,
+ "ров": 3617,
+ "ис": 3618,
+ "нима": 3619,
+ "рем": 3620,
+ "ходи": 3621,
+ "также": 3622,
+ "дру": 3623,
+ "ать": 3624,
+ "след": 3625,
+ "гово": 3626,
+ "ная": 3627,
+ "ющи": 3628,
+ "ень": 3629,
+ "которы": 3630,
+ "хот": 3631,
+ "ву": 3632,
+ "их": 3633,
+ "ему": 3634,
+ "чит": 3635,
+ "важ": 3636,
+ "орга": 3637,
+ "чески": 3638,
+ "ще": 3639,
+ "ке": 3640,
+ "ха": 3641,
+ "пос": 3642,
+ "том": 3643,
+ "боль": 3644,
+ "мне": 3645,
+ "пас": 3646,
+ "объ": 3647,
+ "прав": 3648,
+ "конф": 3649,
+ "слу": 3650,
+ "поддер": 3651,
+ "стви": 3652,
+ "наш": 3653,
+ "лько": 3654,
+ "стоя": 3655,
+ "ную": 3656,
+ "лем": 3657,
+ "енных": 3658,
+ "кра": 3659,
+ "ды": 3660,
+ "международ": 3661,
+ "гда": 3662,
+ "необ": 3663,
+ "госу": 3664,
+ "ству": 3665,
+ "ении": 3666,
+ "государ": 3667,
+ "кто": 3668,
+ "им": 3669,
+ "чест": 3670,
+ "рет": 3671,
+ "вопро": 3672,
+ "лен": 3673,
+ "ели": 3674,
+ "рова": 3675,
+ "ций": 3676,
+ "нам": 3677,
+ "этой": 3678,
+ "жения": 3679,
+ "необходи": 3680,
+ "меня": 3681,
+ "было": 3682,
+ "сили": 3683,
+ "фи": 3684,
+ "вя": 3685,
+ "шь": 3686,
+ "этого": 3687,
+ "они": 3688,
+ "органи": 3689,
+ "безо": 3690,
+ "проб": 3691,
+ "име": 3692,
+ "реш": 3693,
+ "би": 3694,
+ "безопас": 3695,
+ "ются": 3696,
+ "оста": 3697,
+ "енно": 3698,
+ "год": 3699,
+ "ела": 3700,
+ "представ": 3701,
+ "ться": 3702,
+ "слово": 3703,
+ "организа": 3704,
+ "должны": 3705,
+ "этом": 3706,
+ "бла": 3707,
+ "че": 3708,
+ "чу": 3709,
+ "благо": 3710,
+ "этому": 3711,
+ "врем": 3712,
+ "спе": 3713,
+ "ном": 3714,
+ "ений": 3715,
+ "спо": 3716,
+ "нас": 3717,
+ "нет": 3718,
+ "зу": 3719,
+ "вед": 3720,
+ "еще": 3721,
+ "сказа": 3722,
+ "сей": 3723,
+ "ерен": 3724,
+ "дан": 3725,
+ "сам": 3726,
+ "еля": 3727,
+ "ран": 3728,
+ "зыва": 3729,
+ "является": 3730,
+ "будет": 3731,
+ "ктив": 3732,
+ "тре": 3733,
+ "деле": 3734,
+ "мот": 3735,
+ "конферен": 3736,
+ "лась": 3737,
+ "час": 3738,
+ "сторо": 3739,
+ "кого": 3740,
+ "ез": 3741,
+ "ней": 3742,
+ "ос": 3743,
+ "лись": 3744,
+ "разору": 3745,
+ "пере": 3746,
+ "сси": 3747,
+ "ными": 3748,
+ "проц": 3749,
+ "голо": 3750,
+ "чело": 3751,
+ "боле": 3752,
+ "челове": 3753,
+ "сер": 3754,
+ "пл": 3755,
+ "чет": 3756,
+ "стран": 3757,
+ "пя": 3758,
+ "был": 3759,
+ "кла": 3760,
+ "тов": 3761,
+ "жд": 3762,
+ "дела": 3763,
+ "ера": 3764,
+ "уже": 3765,
+ "совет": 3766,
+ "ген": 3767,
+ "безопасности": 3768,
+ "ца": 3769,
+ "седа": 3770,
+ "поз": 3771,
+ "ответ": 3772,
+ "проблем": 3773,
+ "нако": 3774,
+ "тем": 3775,
+ "доста": 3776,
+ "пы": 3777,
+ "ща": 3778,
+ "вой": 3779,
+ "сущест": 3780,
+ "необходимо": 3781,
+ "быть": 3782,
+ "может": 3783,
+ "дем": 3784,
+ "чтобы": 3785,
+ "ек": 3786,
+ "чер": 3787,
+ "усили": 3788,
+ "рес": 3789,
+ "руд": 3790,
+ "единенных": 3791,
+ "доб": 3792,
+ "дости": 3793,
+ "ствен": 3794,
+ "ядер": 3795,
+ "годня": 3796,
+ "каза": 3797,
+ "сегодня": 3798,
+ "сейчас": 3799,
+ "только": 3800,
+ "вод": 3801,
+ "есь": 3802,
+ "много": 3803,
+ "буду": 3804,
+ "ев": 3805,
+ "есть": 3806,
+ "три": 3807,
+ "общест": 3808,
+ "явл": 3809,
+ "высту": 3810,
+ "ред": 3811,
+ "счит": 3812,
+ "сит": 3813,
+ "делега": 3814,
+ "лож": 3815,
+ "этот": 3816,
+ "фор": 3817,
+ "клю": 3818,
+ "возмож": 3819,
+ "вания": 3820,
+ "бли": 3821,
+ "или": 3822,
+ "вз": 3823,
+ "наций": 3824,
+ "ского": 3825,
+ "приня": 3826,
+ "пла": 3827,
+ "оч": 3828,
+ "иться": 3829,
+ "сте": 3830,
+ "наши": 3831,
+ "которые": 3832,
+ "ар": 3833,
+ "имеет": 3834,
+ "сот": 3835,
+ "знач": 3836,
+ "перь": 3837,
+ "следу": 3838,
+ "ены": 3839,
+ "таки": 3840,
+ "объединенных": 3841,
+ "стро": 3842,
+ "теперь": 3843,
+ "бле": 3844,
+ "благодар": 3845,
+ "разв": 3846,
+ "ан": 3847,
+ "жива": 3848,
+ "очень": 3849,
+ "ят": 3850,
+ "без": 3851,
+ "обес": 3852,
+ "гро": 3853,
+ "лось": 3854,
+ "сы": 3855,
+ "организации": 3856,
+ "член": 3857,
+ "того": 3858,
+ "ональ": 3859,
+ "жда": 3860,
+ "всех": 3861,
+ "свя": 3862,
+ "более": 3863,
+ "сов": 3864,
+ "когда": 3865,
+ "вот": 3866,
+ "кре": 3867,
+ "кры": 3868,
+ "поэтому": 3869,
+ "воль": 3870,
+ "ой": 3871,
+ "генера": 3872,
+ "чем": 3873,
+ "лы": 3874,
+ "полити": 3875,
+ "вен": 3876,
+ "конференции": 3877,
+ "процес": 3878,
+ "бя": 3879,
+ "ите": 3880,
+ "отно": 3881,
+ "развити": 3882,
+ "аф": 3883,
+ "ющ": 3884,
+ "вно": 3885,
+ "мир": 3886,
+ "нии": 3887,
+ "кая": 3888,
+ "ас": 3889,
+ "ительно": 3890,
+ "вто": 3891,
+ "ением": 3892,
+ "генераль": 3893,
+ "прот": 3894,
+ "всем": 3895,
+ "самбле": 3896,
+ "ассамбле": 3897,
+ "ом": 3898,
+ "зд": 3899,
+ "смот": 3900,
+ "реги": 3901,
+ "чего": 3902,
+ "однако": 3903,
+ "усилия": 3904,
+ "действи": 3905,
+ "чно": 3906,
+ "уча": 3907,
+ "образ": 3908,
+ "вос": 3909,
+ "эта": 3910,
+ "перего": 3911,
+ "говор": 3912,
+ "вам": 3913,
+ "моло": 3914,
+ "время": 3915,
+ "дь": 3916,
+ "хотел": 3917,
+ "гру": 3918,
+ "заявл": 3919,
+ "предоста": 3920,
+ "поль": 3921,
+ "нее": 3922,
+ "резо": 3923,
+ "перегово": 3924,
+ "резолю": 3925,
+ "крет": 3926,
+ "поддерж": 3927,
+ "обеспе": 3928,
+ "него": 3929,
+ "представит": 3930,
+ "наде": 3931,
+ "кри": 3932,
+ "чь": 3933,
+ "проек": 3934,
+ "лет": 3935,
+ "други": 3936,
+ "_": 3937,
+ "،": 3938,
+ "؛": 3939,
+ "؟": 3940,
+ "ء": 3941,
+ "آ": 3942,
+ "أ": 3943,
+ "ؤ": 3944,
+ "إ": 3945,
+ "ئ": 3946,
+ "ا": 3947,
+ "ب": 3948,
+ "ة": 3949,
+ "ت": 3950,
+ "ث": 3951,
+ "ج": 3952,
+ "ح": 3953,
+ "خ": 3954,
+ "د": 3955,
+ "ذ": 3956,
+ "ر": 3957,
+ "ز": 3958,
+ "س": 3959,
+ "ش": 3960,
+ "ص": 3961,
+ "ض": 3962,
+ "ط": 3963,
+ "ظ": 3964,
+ "ع": 3965,
+ "غ": 3966,
+ "ـ": 3967,
+ "ف": 3968,
+ "ق": 3969,
+ "ك": 3970,
+ "ل": 3971,
+ "م": 3972,
+ "ن": 3973,
+ "ه": 3974,
+ "و": 3975,
+ "ى": 3976,
+ "ي": 3977,
+ "ً": 3978,
+ "ٌ": 3979,
+ "ٍ": 3980,
+ "َ": 3981,
+ "ُ": 3982,
+ "ِ": 3983,
+ "ّ": 3984,
+ "ْ": 3985,
+ "ٰ": 3986,
+ "چ": 3987,
+ "ڨ": 3988,
+ "ک": 3989,
+ "ھ": 3990,
+ "ی": 3991,
+ "ۖ": 3992,
+ "ۗ": 3993,
+ "ۘ": 3994,
+ "ۚ": 3995,
+ "ۛ": 3996,
+ "—": 3997,
+ "☭": 3998,
+ "ﺃ": 3999,
+ "ﻻ": 4000,
+ "ال": 4001,
+ "َا": 4002,
+ "وَ": 4003,
+ "َّ": 4004,
+ "ِي": 4005,
+ "أَ": 4006,
+ "لَ": 4007,
+ "نَ": 4008,
+ "الْ": 4009,
+ "هُ": 4010,
+ "ُو": 4011,
+ "ما": 4012,
+ "نْ": 4013,
+ "من": 4014,
+ "عَ": 4015,
+ "نا": 4016,
+ "لا": 4017,
+ "مَ": 4018,
+ "تَ": 4019,
+ "فَ": 4020,
+ "أن": 4021,
+ "لي": 4022,
+ "مِ": 4023,
+ "ان": 4024,
+ "في": 4025,
+ "رَ": 4026,
+ "يَ": 4027,
+ "هِ": 4028,
+ "مْ": 4029,
+ "قَ": 4030,
+ "بِ": 4031,
+ "لى": 4032,
+ "ين": 4033,
+ "إِ": 4034,
+ "لِ": 4035,
+ "وا": 4036,
+ "كَ": 4037,
+ "ها": 4038,
+ "ًا": 4039,
+ "مُ": 4040,
+ "ون": 4041,
+ "الم": 4042,
+ "بَ": 4043,
+ "يا": 4044,
+ "ذا": 4045,
+ "سا": 4046,
+ "الل": 4047,
+ "مي": 4048,
+ "يْ": 4049,
+ "را": 4050,
+ "ري": 4051,
+ "لك": 4052,
+ "مَا": 4053,
+ "نَّ": 4054,
+ "لم": 4055,
+ "إن": 4056,
+ "ست": 4057,
+ "وم": 4058,
+ "َّا": 4059,
+ "لَا": 4060,
+ "هم": 4061,
+ "ِّ": 4062,
+ "كُ": 4063,
+ "كان": 4064,
+ "سَ": 4065,
+ "با": 4066,
+ "دي": 4067,
+ "حَ": 4068,
+ "عْ": 4069,
+ "بي": 4070,
+ "الأ": 4071,
+ "ول": 4072,
+ "فِي": 4073,
+ "رِ": 4074,
+ "دا": 4075,
+ "مِنْ": 4076,
+ "ُونَ": 4077,
+ "وْ": 4078,
+ "هَا": 4079,
+ "ُّ": 4080,
+ "الس": 4081,
+ "الَ": 4082,
+ "ني": 4083,
+ "لْ": 4084,
+ "تُ": 4085,
+ "هل": 4086,
+ "رة": 4087,
+ "دَ": 4088,
+ "سْ": 4089,
+ "تِ": 4090,
+ "نَا": 4091,
+ "رْ": 4092,
+ "اللَّ": 4093,
+ "سامي": 4094,
+ "كن": 4095,
+ "كل": 4096,
+ "هَ": 4097,
+ "عَلَ": 4098,
+ "على": 4099,
+ "مع": 4100,
+ "إلى": 4101,
+ "قد": 4102,
+ "الر": 4103,
+ "ُوا": 4104,
+ "ير": 4105,
+ "عن": 4106,
+ "يُ": 4107,
+ "نِ": 4108,
+ "بْ": 4109,
+ "الح": 4110,
+ "هُمْ": 4111,
+ "قا": 4112,
+ "ذه": 4113,
+ "الت": 4114,
+ "ِينَ": 4115,
+ "جَ": 4116,
+ "هذا": 4117,
+ "عد": 4118,
+ "الع": 4119,
+ "دْ": 4120,
+ "قَالَ": 4121,
+ "رُ": 4122,
+ "يم": 4123,
+ "ية": 4124,
+ "نُ": 4125,
+ "خَ": 4126,
+ "رب": 4127,
+ "الك": 4128,
+ "وَا": 4129,
+ "أنا": 4130,
+ "ةِ": 4131,
+ "الن": 4132,
+ "حد": 4133,
+ "عِ": 4134,
+ "تا": 4135,
+ "هو": 4136,
+ "فا": 4137,
+ "عا": 4138,
+ "الش": 4139,
+ "لُ": 4140,
+ "يت": 4141,
+ "ذَا": 4142,
+ "يع": 4143,
+ "الذ": 4144,
+ "حْ": 4145,
+ "الص": 4146,
+ "إِنَّ": 4147,
+ "جا": 4148,
+ "علي": 4149,
+ "كَا": 4150,
+ "بُ": 4151,
+ "تع": 4152,
+ "وق": 4153,
+ "مل": 4154,
+ "لَّ": 4155,
+ "يد": 4156,
+ "أخ": 4157,
+ "رف": 4158,
+ "تي": 4159,
+ "الِ": 4160,
+ "ّا": 4161,
+ "ذلك": 4162,
+ "أَنْ": 4163,
+ "سِ": 4164,
+ "توم": 4165,
+ "مر": 4166,
+ "مَنْ": 4167,
+ "بل": 4168,
+ "الق": 4169,
+ "الله": 4170,
+ "ِيَ": 4171,
+ "كم": 4172,
+ "ذَ": 4173,
+ "عل": 4174,
+ "حب": 4175,
+ "سي": 4176,
+ "عُ": 4177,
+ "الج": 4178,
+ "الد": 4179,
+ "شَ": 4180,
+ "تك": 4181,
+ "فْ": 4182,
+ "صَ": 4183,
+ "لل": 4184,
+ "دِ": 4185,
+ "بر": 4186,
+ "فِ": 4187,
+ "ته": 4188,
+ "أع": 4189,
+ "تْ": 4190,
+ "قْ": 4191,
+ "الْأَ": 4192,
+ "ئِ": 4193,
+ "عَنْ": 4194,
+ "ور": 4195,
+ "حا": 4196,
+ "الَّ": 4197,
+ "مت": 4198,
+ "فر": 4199,
+ "دُ": 4200,
+ "هنا": 4201,
+ "وَأَ": 4202,
+ "تب": 4203,
+ "ةُ": 4204,
+ "أي": 4205,
+ "سب": 4206,
+ "ريد": 4207,
+ "وج": 4208,
+ "كُمْ": 4209,
+ "حِ": 4210,
+ "كْ": 4211,
+ "در": 4212,
+ "َاء": 4213,
+ "هذه": 4214,
+ "الط": 4215,
+ "الْمُ": 4216,
+ "دة": 4217,
+ "قل": 4218,
+ "غَ": 4219,
+ "يوم": 4220,
+ "الَّذ": 4221,
+ "كر": 4222,
+ "تر": 4223,
+ "كِ": 4224,
+ "كي": 4225,
+ "عَلَى": 4226,
+ "رَب": 4227,
+ "عة": 4228,
+ "قُ": 4229,
+ "جْ": 4230,
+ "فض": 4231,
+ "لة": 4232,
+ "هْ": 4233,
+ "رَا": 4234,
+ "وَلَ": 4235,
+ "الْمَ": 4236,
+ "أَنَّ": 4237,
+ "يَا": 4238,
+ "أُ": 4239,
+ "شي": 4240,
+ "اللَّهُ": 4241,
+ "لَى": 4242,
+ "قِ": 4243,
+ "أت": 4244,
+ "عَلَيْ": 4245,
+ "اللَّهِ": 4246,
+ "الب": 4247,
+ "ضَ": 4248,
+ "ةً": 4249,
+ "قي": 4250,
+ "ار": 4251,
+ "بد": 4252,
+ "خْ": 4253,
+ "سْتَ": 4254,
+ "طَ": 4255,
+ "قَدْ": 4256,
+ "ذهب": 4257,
+ "أم": 4258,
+ "ماذا": 4259,
+ "وَإِ": 4260,
+ "ةٌ": 4261,
+ "ونَ": 4262,
+ "ليلى": 4263,
+ "ولا": 4264,
+ "حُ": 4265,
+ "هي": 4266,
+ "صل": 4267,
+ "الخ": 4268,
+ "ود": 4269,
+ "ليس": 4270,
+ "لدي": 4271,
+ "قال": 4272,
+ "كَانَ": 4273,
+ "مَّ": 4274,
+ "حي": 4275,
+ "تم": 4276,
+ "لن": 4277,
+ "وَلَا": 4278,
+ "بع": 4279,
+ "يمكن": 4280,
+ "سُ": 4281,
+ "ةَ": 4282,
+ "حت": 4283,
+ "رًا": 4284,
+ "كا": 4285,
+ "شا": 4286,
+ "هِمْ": 4287,
+ "لَهُ": 4288,
+ "زَ": 4289,
+ "داً": 4290,
+ "مس": 4291,
+ "كث": 4292,
+ "الْعَ": 4293,
+ "جِ": 4294,
+ "صْ": 4295,
+ "فَا": 4296,
+ "له": 4297,
+ "وي": 4298,
+ "عَا": 4299,
+ "هُوَ": 4300,
+ "بِي": 4301,
+ "بَا": 4302,
+ "أس": 4303,
+ "ثَ": 4304,
+ "لِي": 4305,
+ "رض": 4306,
+ "الرَّ": 4307,
+ "لِكَ": 4308,
+ "تَّ": 4309,
+ "فُ": 4310,
+ "قة": 4311,
+ "فعل": 4312,
+ "مِن": 4313,
+ "الآ": 4314,
+ "ثُ": 4315,
+ "سم": 4316,
+ "مَّا": 4317,
+ "بِهِ": 4318,
+ "تق": 4319,
+ "خر": 4320,
+ "لقد": 4321,
+ "خل": 4322,
+ "شر": 4323,
+ "أنت": 4324,
+ "لَّا": 4325,
+ "سن": 4326,
+ "السَّ": 4327,
+ "الذي": 4328,
+ "سَا": 4329,
+ "وما": 4330,
+ "زل": 4331,
+ "وب": 4332,
+ "أْ": 4333,
+ "إذا": 4334,
+ "رِي": 4335,
+ "حة": 4336,
+ "نِي": 4337,
+ "الْحَ": 4338,
+ "وَقَالَ": 4339,
+ "به": 4340,
+ "ةٍ": 4341,
+ "سأ": 4342,
+ "رٌ": 4343,
+ "بال": 4344,
+ "مة": 4345,
+ "شْ": 4346,
+ "وت": 4347,
+ "عند": 4348,
+ "فس": 4349,
+ "بَعْ": 4350,
+ "هر": 4351,
+ "قط": 4352,
+ "أح": 4353,
+ "إنه": 4354,
+ "وع": 4355,
+ "فت": 4356,
+ "غا": 4357,
+ "هناك": 4358,
+ "بت": 4359,
+ "مِنَ": 4360,
+ "سر": 4361,
+ "ذَلِكَ": 4362,
+ "رس": 4363,
+ "حدث": 4364,
+ "غْ": 4365,
+ "ِّي": 4366,
+ "الإ": 4367,
+ "وَيَ": 4368,
+ "جل": 4369,
+ "است": 4370,
+ "قِي": 4371,
+ "عب": 4372,
+ "وس": 4373,
+ "يش": 4374,
+ "الَّذِينَ": 4375,
+ "تاب": 4376,
+ "دِي": 4377,
+ "جب": 4378,
+ "كون": 4379,
+ "بن": 4380,
+ "الث": 4381,
+ "لَيْ": 4382,
+ "بعد": 4383,
+ "وَالْ": 4384,
+ "فَأَ": 4385,
+ "عم": 4386,
+ "هُم": 4387,
+ "تن": 4388,
+ "ذْ": 4389,
+ "أص": 4390,
+ "أين": 4391,
+ "رَبِّ": 4392,
+ "الذين": 4393,
+ "إِن": 4394,
+ "بين": 4395,
+ "جُ": 4396,
+ "عَلَيْهِ": 4397,
+ "حَا": 4398,
+ "لو": 4399,
+ "ستط": 4400,
+ "ظر": 4401,
+ "لَمْ": 4402,
+ "ءِ": 4403,
+ "كُل": 4404,
+ "طل": 4405,
+ "تَا": 4406,
+ "ضُ": 4407,
+ "كنت": 4408,
+ "لًا": 4409,
+ "مٌ": 4410,
+ "قبل": 4411,
+ "ــ": 4412,
+ "ذِ": 4413,
+ "قَوْ": 4414,
+ "صِ": 4415,
+ "مًا": 4416,
+ "كانت": 4417,
+ "صا": 4418,
+ "يق": 4419,
+ "الف": 4420,
+ "النا": 4421,
+ "مٍ": 4422,
+ "إِنْ": 4423,
+ "النَّ": 4424,
+ "جد": 4425,
+ "وَمَا": 4426,
+ "تت": 4427,
+ "بح": 4428,
+ "مكان": 4429,
+ "كيف": 4430,
+ "ّة": 4431,
+ "الا": 4432,
+ "جَا": 4433,
+ "أو": 4434,
+ "ساعد": 4435,
+ "ضِ": 4436,
+ "إلا": 4437,
+ "راً": 4438,
+ "قَا": 4439,
+ "رأ": 4440,
+ "عت": 4441,
+ "أحد": 4442,
+ "هد": 4443,
+ "ضا": 4444,
+ "طر": 4445,
+ "أق": 4446,
+ "ماء": 4447,
+ "دَّ": 4448,
+ "البا": 4449,
+ "مُو": 4450,
+ "أَوْ": 4451,
+ "طا": 4452,
+ "قُو": 4453,
+ "خِ": 4454,
+ "تل": 4455,
+ "ستطيع": 4456,
+ "دَا": 4457,
+ "النَّا": 4458,
+ "إلَى": 4459,
+ "وَتَ": 4460,
+ "هَذَا": 4461,
+ "بة": 4462,
+ "عليك": 4463,
+ "جر": 4464,
+ "المن": 4465,
+ "زا": 4466,
+ "رٍ": 4467,
+ "دع": 4468,
+ "ًّا": 4469,
+ "سة": 4470,
+ "ثُمَّ": 4471,
+ "شيء": 4472,
+ "الغ": 4473,
+ "تح": 4474,
+ "رُونَ": 4475,
+ "اليوم": 4476,
+ "مِي": 4477,
+ "نُوا": 4478,
+ "أر": 4479,
+ "تُمْ": 4480,
+ "عر": 4481,
+ "يف": 4482,
+ "أب": 4483,
+ "دًا": 4484,
+ "صَا": 4485,
+ "التَّ": 4486,
+ "أريد": 4487,
+ "الز": 4488,
+ "يَوْ": 4489,
+ "إلي": 4490,
+ "جي": 4491,
+ "يَعْ": 4492,
+ "فضل": 4493,
+ "الإن": 4494,
+ "أنه": 4495,
+ "1": 4496,
+ "2": 4497,
+ "3": 4498,
+ "4": 4499,
+ "5": 4500,
+ "·": 4501,
+ "×": 4502,
+ "̃": 4503,
+ "̌": 4504,
+ "ε": 4505,
+ "λ": 4506,
+ "μ": 4507,
+ "•": 4508,
+ "‧": 4509,
+ "─": 4510,
+ "□": 4511,
+ "、": 4512,
+ "。": 4513,
+ "〈": 4514,
+ "〉": 4515,
+ "《": 4516,
+ "》": 4517,
+ "「": 4518,
+ "」": 4519,
+ "『": 4520,
+ "』": 4521,
+ "ア": 4522,
+ "オ": 4523,
+ "カ": 4524,
+ "チ": 4525,
+ "ド": 4526,
+ "ベ": 4527,
+ "ャ": 4528,
+ "ヤ": 4529,
+ "ン": 4530,
+ "・": 4531,
+ "ー": 4532,
+ "ㄟ": 4533,
+ "!": 4534,
+ "(": 4535,
+ ")": 4536,
+ ",": 4537,
+ "-": 4538,
+ "/": 4539,
+ ":": 4540,
+ ";": 4541,
+ "?": 4542,
+ "p": 4543,
+ "i4": 4544,
+ "zh": 4545,
+ "i2": 4546,
+ "ng1": 4547,
+ "u4": 4548,
+ "i1": 4549,
+ "ng2": 4550,
+ "u3": 4551,
+ "de5": 4552,
+ "e4": 4553,
+ "i3": 4554,
+ "ng4": 4555,
+ "an4": 4556,
+ "shi4": 4557,
+ "an2": 4558,
+ "u2": 4559,
+ "u1": 4560,
+ "ng3": 4561,
+ "a1": 4562,
+ "an1": 4563,
+ "e2": 4564,
+ "a4": 4565,
+ "ei4": 4566,
+ "ong1": 4567,
+ "ai4": 4568,
+ "ao4": 4569,
+ "ang1": 4570,
+ "an3": 4571,
+ "wei4": 4572,
+ "uo2": 4573,
+ "n1": 4574,
+ "en2": 4575,
+ "ao3": 4576,
+ "e1": 4577,
+ "qi": 4578,
+ "eng2": 4579,
+ "zho": 4580,
+ "ang3": 4581,
+ "ang4": 4582,
+ "ang2": 4583,
+ "uo4": 4584,
+ "ge4": 4585,
+ "yi1": 4586,
+ "guo2": 4587,
+ "a3": 4588,
+ "he2": 4589,
+ "e3": 4590,
+ "yi2": 4591,
+ "di4": 4592,
+ "zhong1": 4593,
+ "bu4": 4594,
+ "ai2": 4595,
+ "n2": 4596,
+ "zai4": 4597,
+ "shi2": 4598,
+ "eng1": 4599,
+ "ren2": 4600,
+ "ong2": 4601,
+ "xian4": 4602,
+ "xu": 4603,
+ "n4": 4604,
+ "li4": 4605,
+ "en4": 4606,
+ "yu2": 4607,
+ "ei2": 4608,
+ "yi2ge4": 4609,
+ "ou4": 4610,
+ "ei3": 4611,
+ "ui4": 4612,
+ "a2": 4613,
+ "you3": 4614,
+ "ao1": 4615,
+ "da4": 4616,
+ "cheng2": 4617,
+ "en1": 4618,
+ "eng4": 4619,
+ "yi4": 4620,
+ "si1": 4621,
+ "zhi4": 4622,
+ "jia1": 4623,
+ "yuan2": 4624,
+ "ta1": 4625,
+ "de5yi2ge4": 4626,
+ "ke1": 4627,
+ "shu3": 4628,
+ "xi1": 4629,
+ "ji2": 4630,
+ "ao2": 4631,
+ "ou3": 4632,
+ "ong4": 4633,
+ "xia4": 4634,
+ "ai1": 4635,
+ "gong1": 4636,
+ "zhi1": 4637,
+ "en3": 4638,
+ "wei2": 4639,
+ "xue2": 4640,
+ "qu1": 4641,
+ "zhou1": 4642,
+ "er3": 4643,
+ "ming2": 4644,
+ "zhong3": 4645,
+ "li3": 4646,
+ "wu4": 4647,
+ "yi3": 4648,
+ "uo1": 4649,
+ "e5": 4650,
+ "ji4": 4651,
+ "xing2": 4652,
+ "jian4": 4653,
+ "hua4": 4654,
+ "yu3": 4655,
+ "uo3": 4656,
+ "ji1": 4657,
+ "ai3": 4658,
+ "zuo4": 4659,
+ "hou4": 4660,
+ "hui4": 4661,
+ "ei1": 4662,
+ "nian2": 4663,
+ "qi2": 4664,
+ "dao4": 4665,
+ "sheng1": 4666,
+ "de2": 4667,
+ "dai4": 4668,
+ "uan2": 4669,
+ "zhe4": 4670,
+ "zheng4": 4671,
+ "ben3": 4672,
+ "shang4": 4673,
+ "zhu3": 4674,
+ "bei4": 4675,
+ "ye4": 4676,
+ "chu1": 4677,
+ "zhan4": 4678,
+ "le5": 4679,
+ "lai2": 4680,
+ "shi3": 4681,
+ "nan2": 4682,
+ "ren4": 4683,
+ "you2": 4684,
+ "ke4": 4685,
+ "ba1": 4686,
+ "fu4": 4687,
+ "dui4": 4688,
+ "ya4": 4689,
+ "mei3": 4690,
+ "zi4": 4691,
+ "xin1": 4692,
+ "jing1": 4693,
+ "zhu": 4694,
+ "n3": 4695,
+ "yong4": 4696,
+ "mu4": 4697,
+ "jiao4": 4698,
+ "ye3": 4699,
+ "jin4": 4700,
+ "bian4": 4701,
+ "lu4": 4702,
+ "qi1": 4703,
+ "she4": 4704,
+ "xiang1": 4705,
+ "ong3": 4706,
+ "shu4": 4707,
+ "dong4": 4708,
+ "suo3": 4709,
+ "guan1": 4710,
+ "san1": 4711,
+ "te4": 4712,
+ "duo1": 4713,
+ "fu2": 4714,
+ "min2": 4715,
+ "la1": 4716,
+ "zhi2": 4717,
+ "zhen4": 4718,
+ "ou1": 4719,
+ "wu3": 4720,
+ "ma3": 4721,
+ "i5": 4722,
+ "zi5": 4723,
+ "ju4": 4724,
+ "er4": 4725,
+ "yao4": 4726,
+ "xia4de5yi2ge4": 4727,
+ "si4": 4728,
+ "tu2": 4729,
+ "shan1": 4730,
+ "zui4": 4731,
+ "yin1": 4732,
+ "er2": 4733,
+ "tong2": 4734,
+ "dong1": 4735,
+ "yu4": 4736,
+ "yan2": 4737,
+ "qian2": 4738,
+ "shu3xia4de5yi2ge4": 4739,
+ "jun1": 4740,
+ "ke3": 4741,
+ "wen2": 4742,
+ "fa3": 4743,
+ "luo2": 4744,
+ "zhu4": 4745,
+ "xi4": 4746,
+ "kou3": 4747,
+ "bei3": 4748,
+ "jian1": 4749,
+ "fa1": 4750,
+ "dian4": 4751,
+ "jiang1": 4752,
+ "wei4yu2": 4753,
+ "xiang4": 4754,
+ "zhi3": 4755,
+ "eng3": 4756,
+ "fang1": 4757,
+ "lan2": 4758,
+ "shu": 4759,
+ "ri4": 4760,
+ "lian2": 4761,
+ "shou3": 4762,
+ "qiu2": 4763,
+ "jin1": 4764,
+ "huo4": 4765,
+ "shu3xia4de5yi2ge4zhong3": 4766,
+ "fen1": 4767,
+ "nei4": 4768,
+ "gai1": 4769,
+ "mei3guo2": 4770,
+ "un2": 4771,
+ "ge2": 4772,
+ "bao3": 4773,
+ "qing1": 4774,
+ "gao1": 4775,
+ "tai2": 4776,
+ "xiao3": 4777,
+ "jie2": 4778,
+ "tian1": 4779,
+ "chang2": 4780,
+ "quan2": 4781,
+ "lie4": 4782,
+ "hai3": 4783,
+ "fei1": 4784,
+ "ti3": 4785,
+ "jue2": 4786,
+ "ou2": 4787,
+ "ci3": 4788,
+ "zu2": 4789,
+ "ni2": 4790,
+ "biao3": 4791,
+ "zhong1guo2": 4792,
+ "du4": 4793,
+ "yue4": 4794,
+ "xing4": 4795,
+ "sheng4": 4796,
+ "che1": 4797,
+ "dan1": 4798,
+ "jie1": 4799,
+ "lin2": 4800,
+ "ping2": 4801,
+ "fu3": 4802,
+ "gu3": 4803,
+ "jie4": 4804,
+ "v3": 4805,
+ "sheng3": 4806,
+ "na4": 4807,
+ "yuan4": 4808,
+ "zhang3": 4809,
+ "guan3": 4810,
+ "dao3": 4811,
+ "zu3": 4812,
+ "ding4": 4813,
+ "dian3": 4814,
+ "ceng2": 4815,
+ "ren2kou3": 4816,
+ "tai4": 4817,
+ "tong1": 4818,
+ "guo4": 4819,
+ "neng2": 4820,
+ "chang3": 4821,
+ "hua2": 4822,
+ "liu2": 4823,
+ "ying1": 4824,
+ "xiao4": 4825,
+ "ci4": 4826,
+ "bian4hua4": 4827,
+ "liang3": 4828,
+ "gong4": 4829,
+ "zhong4": 4830,
+ "de5yi1": 4831,
+ "se4": 4832,
+ "kai1": 4833,
+ "wang2": 4834,
+ "jiu4": 4835,
+ "shi1": 4836,
+ "shou4": 4837,
+ "mei2": 4838,
+ "feng1": 4839,
+ "ze2": 4840,
+ "tu2shi4": 4841,
+ "ti2": 4842,
+ "qi4": 4843,
+ "jiu3": 4844,
+ "shen1": 4845,
+ "zhe3": 4846,
+ "ren2kou3bian4hua4": 4847,
+ "ren2kou3bian4hua4tu2shi4": 4848,
+ "di4qu1": 4849,
+ "yang2": 4850,
+ "men5": 4851,
+ "long2": 4852,
+ "bing4": 4853,
+ "chan3": 4854,
+ "zhu1": 4855,
+ "wei3": 4856,
+ "wai4": 4857,
+ "xing1": 4858,
+ "bo1": 4859,
+ "bi3": 4860,
+ "tang2": 4861,
+ "hua1": 4862,
+ "bo2": 4863,
+ "shui3": 4864,
+ "shu1": 4865,
+ "dou1": 4866,
+ "sai4": 4867,
+ "chao2": 4868,
+ "bi4": 4869,
+ "ling2": 4870,
+ "lei4": 4871,
+ "da4xue2": 4872,
+ "fen4": 4873,
+ "shu3de5": 4874,
+ "mu3": 4875,
+ "jiao1": 4876,
+ "dang1": 4877,
+ "cheng1": 4878,
+ "tong3": 4879,
+ "nv3": 4880,
+ "qi3": 4881,
+ "yan3": 4882,
+ "mian4": 4883,
+ "luo4": 4884,
+ "jing4": 4885,
+ "ge1": 4886,
+ "ru4": 4887,
+ "dan4": 4888,
+ "ri4ben3": 4889,
+ "pu3": 4890,
+ "yun4": 4891,
+ "huang2": 4892,
+ "wo3": 4893,
+ "lv": 4894,
+ "hai2": 4895,
+ "shi4yi1": 4896,
+ "xie1": 4897,
+ "ying3": 4898,
+ "wu2": 4899,
+ "shen2": 4900,
+ "wang3": 4901,
+ "guang3": 4902,
+ "liu4": 4903,
+ "su4": 4904,
+ "shi4zhen4": 4905,
+ "can1": 4906,
+ "cao3": 4907,
+ "xia2": 4908,
+ "ka3": 4909,
+ "da2": 4910,
+ "hu4": 4911,
+ "ban4": 4912,
+ "dang3": 4913,
+ "hu2": 4914,
+ "zong3": 4915,
+ "deng3": 4916,
+ "de5yi2ge4shi4zhen4": 4917,
+ "chuan2": 4918,
+ "mo4": 4919,
+ "zhang1": 4920,
+ "ban1": 4921,
+ "mo2": 4922,
+ "cha2": 4923,
+ "ce4": 4924,
+ "zhu3yao4": 4925,
+ "tou2": 4926,
+ "ju2": 4927,
+ "shi4wei4yu2": 4928,
+ "sa4": 4929,
+ "un1": 4930,
+ "ke3yi3": 4931,
+ "du1": 4932,
+ "han4": 4933,
+ "liang4": 4934,
+ "sha1": 4935,
+ "jia3": 4936,
+ "zi1": 4937,
+ "lv4": 4938,
+ "fu1": 4939,
+ "xian1": 4940,
+ "xu4": 4941,
+ "guang1": 4942,
+ "meng2": 4943,
+ "bao4": 4944,
+ "you4": 4945,
+ "rong2": 4946,
+ "zhi1yi1": 4947,
+ "wei1": 4948,
+ "mao2": 4949,
+ "guo2jia1": 4950,
+ "cong2": 4951,
+ "gou4": 4952,
+ "tie3": 4953,
+ "zhen1": 4954,
+ "du2": 4955,
+ "bian1": 4956,
+ "ci2": 4957,
+ "qu3": 4958,
+ "fan4": 4959,
+ "xiang3": 4960,
+ "men2": 4961,
+ "ju1": 4962,
+ "hong2": 4963,
+ "zi3": 4964,
+ "ta1men5": 4965,
+ "ji3": 4966,
+ "zong1": 4967,
+ "zhou1de5yi2ge4shi4zhen4": 4968,
+ "tuan2": 4969,
+ "jing3": 4970,
+ "gong1si1": 4971,
+ "xie4": 4972,
+ "li2": 4973,
+ "li4shi3": 4974,
+ "bao1": 4975,
+ "gang3": 4976,
+ "gui1": 4977,
+ "zheng1": 4978,
+ "zhi2wu4": 4979,
+ "ta1de5": 4980,
+ "pin3": 4981,
+ "zhuan1": 4982,
+ "chong2": 4983,
+ "shi3yong4": 4984,
+ "wa3": 4985,
+ "shuo1": 4986,
+ "chuan1": 4987,
+ "lei2": 4988,
+ "wan1": 4989,
+ "huo2": 4990,
+ "su1": 4991,
+ "zao3": 4992,
+ "gai3": 4993,
+ "qu4": 4994,
+ "gu4": 4995,
+ "xi2": 4996,
+ "hang2": 4997,
+ "ying4": 4998,
+ "cun1": 4999,
+ "gen1": 5000,
+ "ying2": 5001,
+ "ting2": 5002,
+ "cheng2shi4": 5003,
+ "jiang3": 5004,
+ "ling3": 5005,
+ "lun2": 5006,
+ "bu4fen4": 5007,
+ "deng1": 5008,
+ "xuan3": 5009,
+ "dong4wu4": 5010,
+ "de2guo2": 5011,
+ "xian3": 5012,
+ "fan3": 5013,
+ "zhe5": 5014,
+ "han2": 5015,
+ "hao4": 5016,
+ "mi4": 5017,
+ "ran2": 5018,
+ "qin1": 5019,
+ "tiao2": 5020,
+ "zhan3": 5021,
+ "[ar]": 5022,
+ "[zh-cn]": 5023,
+ "¡": 5024,
+ "é": 5025,
+ "shi": 5026,
+ "tsu": 5027,
+ "teki": 5028,
+ "nai": 5029,
+ "aru": 5030,
+ "uu": 5031,
+ "kai": 5032,
+ "shite": 5033,
+ "mono": 5034,
+ "koto": 5035,
+ "kara": 5036,
+ "shita": 5037,
+ "suru": 5038,
+ "masu": 5039,
+ "tai": 5040,
+ "ware": 5041,
+ "shin": 5042,
+ "oku": 5043,
+ "yuu": 5044,
+ "iru": 5045,
+ "jiko": 5046,
+ "desu": 5047,
+ "rare": 5048,
+ "shou": 5049,
+ "sha": 5050,
+ "sekai": 5051,
+ "kyou": 5052,
+ "mashita": 5053,
+ "nara": 5054,
+ "kei": 5055,
+ "ita": 5056,
+ "ari": 5057,
+ "itsu": 5058,
+ "kono": 5059,
+ "naka": 5060,
+ "chou": 5061,
+ "sore": 5062,
+ "naru": 5063,
+ "gaku": 5064,
+ "reba": 5065,
+ "hito": 5066,
+ "sai": 5067,
+ "nan": 5068,
+ "dai": 5069,
+ "tsuku": 5070,
+ "shiki": 5071,
+ "sare": 5072,
+ "naku": 5073,
+ "jun": 5074,
+ "kaku": 5075,
+ "zai": 5076,
+ "wata": 5077,
+ "shuu": 5078,
+ "ii": 5079,
+ "kare": 5080,
+ "shii": 5081,
+ "made": 5082,
+ "sho": 5083,
+ "kereba": 5084,
+ "shika": 5085,
+ "ichi": 5086,
+ "deki": 5087,
+ "nin": 5088,
+ "wareware": 5089,
+ "nakereba": 5090,
+ "oite": 5091,
+ "yaku": 5092,
+ "mujun": 5093,
+ "yoku": 5094,
+ "butsu": 5095,
+ "omo": 5096,
+ "gae": 5097,
+ "naranai": 5098,
+ "tachi": 5099,
+ "chuu": 5100,
+ "kangae": 5101,
+ "toki": 5102,
+ "koro": 5103,
+ "mujunteki": 5104,
+ "naga": 5105,
+ "jin": 5106,
+ "shima": 5107,
+ "iku": 5108,
+ "imasu": 5109,
+ "hon": 5110,
+ "kae": 5111,
+ "kore": 5112,
+ "kita": 5113,
+ "datta": 5114,
+ "jitsu": 5115,
+ "mae": 5116,
+ "toku": 5117,
+ "douitsu": 5118,
+ "ritsu": 5119,
+ "kyuu": 5120,
+ "hyou": 5121,
+ "rareta": 5122,
+ "keisei": 5123,
+ "kkan": 5124,
+ "rareru": 5125,
+ "mou": 5126,
+ "doko": 5127,
+ "ryou": 5128,
+ "dake": 5129,
+ "nakatta": 5130,
+ "soko": 5131,
+ "tabe": 5132,
+ "hana": 5133,
+ "fuku": 5134,
+ "yasu": 5135,
+ "wataku": 5136,
+ "yama": 5137,
+ "kyo": 5138,
+ "genzai": 5139,
+ "boku": 5140,
+ "ata": 5141,
+ "kawa": 5142,
+ "masen": 5143,
+ "juu": 5144,
+ "natte": 5145,
+ "watakushi": 5146,
+ "yotte": 5147,
+ "hai": 5148,
+ "jishin": 5149,
+ "rete": 5150,
+ "oka": 5151,
+ "kagaku": 5152,
+ "natta": 5153,
+ "karu": 5154,
+ "nari": 5155,
+ "mata": 5156,
+ "kuru": 5157,
+ "gai": 5158,
+ "kari": 5159,
+ "shakai": 5160,
+ "koui": 5161,
+ "yori": 5162,
+ "setsu": 5163,
+ "reru": 5164,
+ "tokoro": 5165,
+ "jutsu": 5166,
+ "saku": 5167,
+ "ttai": 5168,
+ "ningen": 5169,
+ "tame": 5170,
+ "kankyou": 5171,
+ "ooku": 5172,
+ "watashi": 5173,
+ "tsukuru": 5174,
+ "sugi": 5175,
+ "jibun": 5176,
+ "shitsu": 5177,
+ "keru": 5178,
+ "kishi": 5179,
+ "shikashi": 5180,
+ "moto": 5181,
+ "mari": 5182,
+ "itte": 5183,
+ "deshita": 5184,
+ "nde": 5185,
+ "arimasu": 5186,
+ "koe": 5187,
+ "zettai": 5188,
+ "kkanteki": 5189,
+ "rekishi": 5190,
+ "dekiru": 5191,
+ "tsuka": 5192,
+ "itta": 5193,
+ "kobutsu": 5194,
+ "miru": 5195,
+ "shoku": 5196,
+ "shimasu": 5197,
+ "gijutsu": 5198,
+ "gyou": 5199,
+ "joushiki": 5200,
+ "atta": 5201,
+ "hodo": 5202,
+ "koko": 5203,
+ "tsukurareta": 5204,
+ "zoku": 5205,
+ "hitei": 5206,
+ "koku": 5207,
+ "rekishiteki": 5208,
+ "kete": 5209,
+ "kako": 5210,
+ "nagara": 5211,
+ "kakaru": 5212,
+ "shutai": 5213,
+ "haji": 5214,
+ "taku": 5215,
+ "douitsuteki": 5216,
+ "mete": 5217,
+ "tsuu": 5218,
+ "sarete": 5219,
+ "genjitsu": 5220,
+ "bai": 5221,
+ "nawa": 5222,
+ "jikan": 5223,
+ "waru": 5224,
+ "rt": 5225,
+ "atsu": 5226,
+ "soku": 5227,
+ "kouiteki": 5228,
+ "kata": 5229,
+ "tetsu": 5230,
+ "gawa": 5231,
+ "kedo": 5232,
+ "reta": 5233,
+ "sayou": 5234,
+ "tteru": 5235,
+ "tori": 5236,
+ "kimi": 5237,
+ "mura": 5238,
+ "sareru": 5239,
+ "machi": 5240,
+ "kya": 5241,
+ "osa": 5242,
+ "konna": 5243,
+ "aku": 5244,
+ "sareta": 5245,
+ "ipp": 5246,
+ "shiku": 5247,
+ "uchi": 5248,
+ "hitotsu": 5249,
+ "hatara": 5250,
+ "tachiba": 5251,
+ "shiro": 5252,
+ "katachi": 5253,
+ "tomo": 5254,
+ "ete": 5255,
+ "meru": 5256,
+ "nichi": 5257,
+ "dare": 5258,
+ "katta": 5259,
+ "eru": 5260,
+ "suki": 5261,
+ "ooki": 5262,
+ "maru": 5263,
+ "moku": 5264,
+ "oko": 5265,
+ "kangaerareru": 5266,
+ "oto": 5267,
+ "tanni": 5268,
+ "tada": 5269,
+ "taiteki": 5270,
+ "motte": 5271,
+ "kinou": 5272,
+ "shinai": 5273,
+ "kki": 5274,
+ "tari": 5275,
+ "ranai": 5276,
+ "kkou": 5277,
+ "mirai": 5278,
+ "ppon": 5279,
+ "goto": 5280,
+ "hitsu": 5281,
+ "teru": 5282,
+ "mochi": 5283,
+ "katsu": 5284,
+ "nyuu": 5285,
+ "zuka": 5286,
+ "tsuite": 5287,
+ "nomi": 5288,
+ "sugu": 5289,
+ "kuda": 5290,
+ "tetsugaku": 5291,
+ "ika": 5292,
+ "ronri": 5293,
+ "oki": 5294,
+ "nippon": 5295,
+ "shimashita": 5296,
+ "chishiki": 5297,
+ "chokkanteki": 5298,
+ "suko": 5299,
+ "kuu": 5300,
+ "arou": 5301,
+ "katte": 5302,
+ "kuri": 5303,
+ "inai": 5304,
+ "hyougen": 5305,
+ "ishiki": 5306,
+ "doku": 5307,
+ "atte": 5308,
+ "atara": 5309,
+ "wari": 5310,
+ "kao": 5311,
+ "seisan": 5312,
+ "hanashi": 5313,
+ "kake": 5314,
+ "naji": 5315,
+ "sunawa": 5316,
+ "sunawachi": 5317,
+ "ugo": 5318,
+ "suu": 5319,
+ "bara": 5320,
+ "hiro": 5321,
+ "iwa": 5322,
+ "betsu": 5323,
+ "yoi": 5324,
+ "seru": 5325,
+ "shiteru": 5326,
+ "rarete": 5327,
+ "toshi": 5328,
+ "seki": 5329,
+ "tairitsu": 5330,
+ "wakara": 5331,
+ "tokyo": 5332,
+ "kka": 5333,
+ "kyoku": 5334,
+ "iro": 5335,
+ "mite": 5336,
+ "saki": 5337,
+ "kanji": 5338,
+ "mita": 5339,
+ "sube": 5340,
+ "ryoku": 5341,
+ "matta": 5342,
+ "kudasai": 5343,
+ "omoi": 5344,
+ "wareru": 5345,
+ "hitsuyou": 5346,
+ "kashi": 5347,
+ "renai": 5348,
+ "kankei": 5349,
+ "gatte": 5350,
+ "ochi": 5351,
+ "motsu": 5352,
+ "sonzai": 5353,
+ "taishite": 5354,
+ "ame": 5355,
+ "seimei": 5356,
+ "kano": 5357,
+ "giri": 5358,
+ "kangaeru": 5359,
+ "yue": 5360,
+ "asa": 5361,
+ "onaji": 5362,
+ "yoru": 5363,
+ "niku": 5364,
+ "osaka": 5365,
+ "sukoshi": 5366,
+ "tama": 5367,
+ "kanojo": 5368,
+ "kite": 5369,
+ "mondai": 5370,
+ "amari": 5371,
+ "eki": 5372,
+ "kojin": 5373,
+ "haya": 5374,
+ "dete": 5375,
+ "atarashii": 5376,
+ "awa": 5377,
+ "gakkou": 5378,
+ "tsuzu": 5379,
+ "shukan": 5380,
+ "imashita": 5381,
+ "atae": 5382,
+ "darou": 5383,
+ "hataraku": 5384,
+ "gata": 5385,
+ "dachi": 5386,
+ "matsu": 5387,
+ "arimasen": 5388,
+ "seibutsu": 5389,
+ "mitsu": 5390,
+ "heya": 5391,
+ "yasui": 5392,
+ "deni": 5393,
+ "noko": 5394,
+ "haha": 5395,
+ "domo": 5396,
+ "kami": 5397,
+ "sudeni": 5398,
+ "nao": 5399,
+ "raku": 5400,
+ "ike": 5401,
+ "meta": 5402,
+ "kodomo": 5403,
+ "soshite": 5404,
+ "game": 5405,
+ "bakari": 5406,
+ "tote": 5407,
+ "hatsu": 5408,
+ "mise": 5409,
+ "mokuteki": 5410,
+ "dakara": 5411,
+ "[ja]": 5412,
+ "ő": 5413,
+ "ű": 5414,
+ "そ": 5415,
+ "な": 5416,
+ "ん": 5417,
+ "포": 5418,
+ "�": 5419,
+ "gy": 5420,
+ "eg": 5421,
+ "cs": 5422,
+ "ál": 5423,
+ "egy": 5424,
+ "át": 5425,
+ "ott": 5426,
+ "ett": 5427,
+ "meg": 5428,
+ "hogy": 5429,
+ "ég": 5430,
+ "ól": 5431,
+ "nek": 5432,
+ "volt": 5433,
+ "ág": 5434,
+ "nk": 5435,
+ "ék": 5436,
+ "ít": 5437,
+ "ák": 5438,
+ "ud": 5439,
+ "szer": 5440,
+ "mind": 5441,
+ "oz": 5442,
+ "ép": 5443,
+ "ért": 5444,
+ "mond": 5445,
+ "szt": 5446,
+ "nak": 5447,
+ "ől": 5448,
+ "csak": 5449,
+ "oly": 5450,
+ "áll": 5451,
+ "ány": 5452,
+ "mint": 5453,
+ "már": 5454,
+ "ött": 5455,
+ "nagy": 5456,
+ "ész": 5457,
+ "azt": 5458,
+ "elő": 5459,
+ "tud": 5460,
+ "ény": 5461,
+ "áz": 5462,
+ "még": 5463,
+ "köz": 5464,
+ "ely": 5465,
+ "ség": 5466,
+ "hoz": 5467,
+ "uk": 5468,
+ "kez": 5469,
+ "ám": 5470,
+ "aj": 5471,
+ "unk": 5472,
+ "vagy": 5473,
+ "szem": 5474,
+ "ember": 5475,
+ "fog": 5476,
+ "mert": 5477,
+ "ös": 5478,
+ "ság": 5479,
+ "leg": 5480,
+ "ünk": 5481,
+ "hát": 5482,
+ "ony": 5483,
+ "ezt": 5484,
+ "minden": 5485,
+ "ült": 5486,
+ "jó": 5487,
+ "kis": 5488,
+ "áj": 5489,
+ "úgy": 5490,
+ "most": 5491,
+ "ír": 5492,
+ "itt": 5493,
+ "elt": 5494,
+ "mondta": 5495,
+ "kell": 5496,
+ "ált": 5497,
+ "érd": 5498,
+ "tö": 5499,
+ "vár": 5500,
+ "lát": 5501,
+ "ők": 5502,
+ "vet": 5503,
+ "után": 5504,
+ "két": 5505,
+ "nap": 5506,
+ "ív": 5507,
+ "ály": 5508,
+ "vég": 5509,
+ "ök": 5510,
+ "dul": 5511,
+ "néz": 5512,
+ "ában": 5513,
+ "kül": 5514,
+ "akkor": 5515,
+ "szél": 5516,
+ "új": 5517,
+ "olyan": 5518,
+ "ked": 5519,
+ "hely": 5520,
+ "tör": 5521,
+ "ból": 5522,
+ "elm": 5523,
+ "ára": 5524,
+ "ló": 5525,
+ "volna": 5526,
+ "lehet": 5527,
+ "ebb": 5528,
+ "sok": 5529,
+ "olt": 5530,
+ "eket": 5531,
+ "bor": 5532,
+ "fej": 5533,
+ "gond": 5534,
+ "akar": 5535,
+ "fél": 5536,
+ "úl": 5537,
+ "otta": 5538,
+ "valami": 5539,
+ "jel": 5540,
+ "éd": 5541,
+ "arc": 5542,
+ "hall": 5543,
+ "föl": 5544,
+ "ába": 5545,
+ "olg": 5546,
+ "kir": 5547,
+ "old": 5548,
+ "kérd": 5549,
+ "jár": 5550,
+ "úr": 5551,
+ "zs": 5552,
+ "élet": 5553,
+ "ját": 5554,
+ "ov": 5555,
+ "éz": 5556,
+ "vil": 5557,
+ "őr": 5558,
+ "ög": 5559,
+ "lesz": 5560,
+ "koz": 5561,
+ "ább": 5562,
+ "király": 5563,
+ "eng": 5564,
+ "igaz": 5565,
+ "haj": 5566,
+ "kod": 5567,
+ "ról": 5568,
+ "több": 5569,
+ "szó": 5570,
+ "ében": 5571,
+ "öt": 5572,
+ "nyi": 5573,
+ "szól": 5574,
+ "gondol": 5575,
+ "egész": 5576,
+ "így": 5577,
+ "ős": 5578,
+ "obb": 5579,
+ "osan": 5580,
+ "ből": 5581,
+ "abb": 5582,
+ "őt": 5583,
+ "nál": 5584,
+ "kép": 5585,
+ "aztán": 5586,
+ "tart": 5587,
+ "beszél": 5588,
+ "előtt": 5589,
+ "aszt": 5590,
+ "maj": 5591,
+ "kör": 5592,
+ "hang": 5593,
+ "íz": 5594,
+ "incs": 5595,
+ "év": 5596,
+ "ód": 5597,
+ "ók": 5598,
+ "hozz": 5599,
+ "okat": 5600,
+ "nagyon": 5601,
+ "ház": 5602,
+ "ped": 5603,
+ "ezte": 5604,
+ "etlen": 5605,
+ "neki": 5606,
+ "majd": 5607,
+ "szony": 5608,
+ "ának": 5609,
+ "felé": 5610,
+ "egyszer": 5611,
+ "adt": 5612,
+ "gyer": 5613,
+ "amikor": 5614,
+ "foly": 5615,
+ "szak": 5616,
+ "őd": 5617,
+ "hú": 5618,
+ "ász": 5619,
+ "amely": 5620,
+ "ére": 5621,
+ "ilyen": 5622,
+ "oda": 5623,
+ "ják": 5624,
+ "tár": 5625,
+ "ával": 5626,
+ "lak": 5627,
+ "gyan": 5628,
+ "ély": 5629,
+ "út": 5630,
+ "kezd": 5631,
+ "mell": 5632,
+ "mikor": 5633,
+ "hez": 5634,
+ "való": 5635,
+ "szeret": 5636,
+ "rend": 5637,
+ "vissza": 5638,
+ "fő": 5639,
+ "asszony": 5640,
+ "ről": 5641,
+ "pedig": 5642,
+ "szép": 5643,
+ "ták": 5644,
+ "öv": 5645,
+ "világ": 5646,
+ "maga": 5647,
+ "szik": 5648,
+ "éj": 5649,
+ "ént": 5650,
+ "jött": 5651,
+ "szí": 5652,
+ "gat": 5653,
+ "ettem": 5654,
+ "hány": 5655,
+ "ást": 5656,
+ "ahol": 5657,
+ "őket": 5658,
+ "hár": 5659,
+ "nő": 5660,
+ "csi": 5661,
+ "talál": 5662,
+ "elte": 5663,
+ "látt": 5664,
+ "tört": 5665,
+ "hagy": 5666,
+ "esz": 5667,
+ "nél": 5668,
+ "kut": 5669,
+ "lány": 5670,
+ "amit": 5671,
+ "ső": 5672,
+ "ellen": 5673,
+ "magát": 5674,
+ "ugyan": 5675,
+ "külön": 5676,
+ "asz": 5677,
+ "mindig": 5678,
+ "lép": 5679,
+ "talán": 5680,
+ "szor": 5681,
+ "illan": 5682,
+ "nincs": 5683,
+ "vagyok": 5684,
+ "telen": 5685,
+ "ismer": 5686,
+ "isten": 5687,
+ "ított": 5688,
+ "jobb": 5689,
+ "ves": 5690,
+ "dult": 5691,
+ "juk": 5692,
+ "szen": 5693,
+ "öm": 5694,
+ "lett": 5695,
+ "egyik": 5696,
+ "bár": 5697,
+ "szi": 5698,
+ "szív": 5699,
+ "azon": 5700,
+ "eszt": 5701,
+ "föld": 5702,
+ "kuty": 5703,
+ "pillan": 5704,
+ "fér": 5705,
+ "től": 5706,
+ "tű": 5707,
+ "ébe": 5708,
+ "tött": 5709,
+ "barát": 5710,
+ "íg": 5711,
+ "ahogy": 5712,
+ "eh": 5713,
+ "ep": 5714,
+ "jelent": 5715,
+ "tat": 5716,
+ "szeg": 5717,
+ "mintha": 5718,
+ "egyen": 5719,
+ "szab": 5720,
+ "bizony": 5721,
+ "jon": 5722,
+ "öreg": 5723,
+ "dolg": 5724,
+ "csap": 5725,
+ "tiszt": 5726,
+ "állt": 5727,
+ "ancs": 5728,
+ "idő": 5729,
+ "ügy": 5730,
+ "miért": 5731,
+ "ót": 5732,
+ "csin": 5733,
+ "ének": 5734,
+ "vér": 5735,
+ "jól": 5736,
+ "alatt": 5737,
+ "mely": 5738,
+ "semmi": 5739,
+ "nyug": 5740,
+ "vág": 5741,
+ "követ": 5742,
+ "össze": 5743,
+ "mad": 5744,
+ "acs": 5745,
+ "fiú": 5746,
+ "másik": 5747,
+ "jön": 5748,
+ "szám": 5749,
+ "rész": 5750,
+ "kér": 5751,
+ "ével": 5752,
+ "[hu]": 5753,
+ "%": 5754,
+ "0": 5755,
+ "6": 5756,
+ "7": 5757,
+ "8": 5758,
+ "9": 5759,
+ "A": 5760,
+ "B": 5761,
+ "C": 5762,
+ "D": 5763,
+ "E": 5764,
+ "F": 5765,
+ "G": 5766,
+ "H": 5767,
+ "I": 5768,
+ "J": 5769,
+ "K": 5770,
+ "L": 5771,
+ "M": 5772,
+ "N": 5773,
+ "O": 5774,
+ "P": 5775,
+ "Q": 5776,
+ "R": 5777,
+ "S": 5778,
+ "T": 5779,
+ "U": 5780,
+ "V": 5781,
+ "W": 5782,
+ "X": 5783,
+ "Y": 5784,
+ "Z": 5785,
+ "Ł": 5786,
+ "α": 5787,
+ "ς": 5788,
+ "♥": 5789,
+ "か": 5790,
+ "ズ": 5791,
+ "因": 5792,
+ "国": 5793,
+ "怎": 5794,
+ "抱": 5795,
+ "推": 5796,
+ "有": 5797,
+ "樣": 5798,
+ "為": 5799,
+ "群": 5800,
+ "麼": 5801,
+ "eo": 5802,
+ "eul": 5803,
+ "eun": 5804,
+ "eon": 5805,
+ "ae": 5806,
+ "yeon": 5807,
+ "yeo": 5808,
+ "ui": 5809,
+ "hae": 5810,
+ "geo": 5811,
+ "neun": 5812,
+ "ssda": 5813,
+ "seo": 5814,
+ "eong": 5815,
+ "kk": 5816,
+ "jeo": 5817,
+ "deul": 5818,
+ "eum": 5819,
+ "yeong": 5820,
+ "geos": 5821,
+ "hag": 5822,
+ "aneun": 5823,
+ "iss": 5824,
+ "dae": 5825,
+ "eob": 5826,
+ "eol": 5827,
+ "geu": 5828,
+ "jeong": 5829,
+ "sae": 5830,
+ "doe": 5831,
+ "geul": 5832,
+ "eulo": 5833,
+ "bn": 5834,
+ "sang": 5835,
+ "bnida": 5836,
+ "haneun": 5837,
+ "jeog": 5838,
+ "saeng": 5839,
+ "ineun": 5840,
+ "anh": 5841,
+ "salam": 5842,
+ "eom": 5843,
+ "nae": 5844,
+ "gwa": 5845,
+ "yeol": 5846,
+ "eseo": 5847,
+ "myeon": 5848,
+ "ttae": 5849,
+ "hw": 5850,
+ "eobs": 5851,
+ "jang": 5852,
+ "gw": 5853,
+ "ileul": 5854,
+ "yeog": 5855,
+ "jeon": 5856,
+ "sig": 5857,
+ "jag": 5858,
+ "hago": 5859,
+ "deun": 5860,
+ "seong": 5861,
+ "gag": 5862,
+ "ham": 5863,
+ "dang": 5864,
+ "leul": 5865,
+ "sil": 5866,
+ "dong": 5867,
+ "handa": 5868,
+ "eossda": 5869,
+ "aeg": 5870,
+ "seon": 5871,
+ "haessda": 5872,
+ "issda": 5873,
+ "ege": 5874,
+ "mul": 5875,
+ "jung": 5876,
+ "jig": 5877,
+ "issneun": 5878,
+ "geun": 5879,
+ "seubnida": 5880,
+ "won": 5881,
+ "daneun": 5882,
+ "eoh": 5883,
+ "deo": 5884,
+ "gam": 5885,
+ "jal": 5886,
+ "haeng": 5887,
+ "yang": 5888,
+ "bang": 5889,
+ "jae": 5890,
+ "saenggag": 5891,
+ "hage": 5892,
+ "sog": 5893,
+ "eoss": 5894,
+ "jasin": 5895,
+ "jil": 5896,
+ "eog": 5897,
+ "gyeong": 5898,
+ "gong": 5899,
+ "deon": 5900,
+ "haess": 5901,
+ "eung": 5902,
+ "joh": 5903,
+ "nal": 5904,
+ "myeong": 5905,
+ "eona": 5906,
+ "igo": 5907,
+ "gyeol": 5908,
+ "yag": 5909,
+ "gwan": 5910,
+ "uli": 5911,
+ "yong": 5912,
+ "lyeo": 5913,
+ "jog": 5914,
+ "eohge": 5915,
+ "bog": 5916,
+ "tong": 5917,
+ "manh": 5918,
+ "jeol": 5919,
+ "geol": 5920,
+ "aga": 5921,
+ "naneun": 5922,
+ "uneun": 5923,
+ "cheol": 5924,
+ "dol": 5925,
+ "bad": 5926,
+ "hamyeon": 5927,
+ "yeossda": 5928,
+ "ibnida": 5929,
+ "gye": 5930,
+ "eos": 5931,
+ "hwal": 5932,
+ "salamdeul": 5933,
+ "jiman": 5934,
+ "dangsin": 5935,
+ "jib": 5936,
+ "ttaemun": 5937,
+ "ib": 5938,
+ "eneun": 5939,
+ "eug": 5940,
+ "jeom": 5941,
+ "geuleon": 5942,
+ "hwa": 5943,
+ "assda": 5944,
+ "beob": 5945,
+ "bae": 5946,
+ "yeoss": 5947,
+ "chin": 5948,
+ "chaeg": 5949,
+ "geon": 5950,
+ "naega": 5951,
+ "iga": 5952,
+ "sigan": 5953,
+ "gil": 5954,
+ "hyeon": 5955,
+ "lyeog": 5956,
+ "gug": 5957,
+ "pyeon": 5958,
+ "wae": 5959,
+ "jul": 5960,
+ "seul": 5961,
+ "deung": 5962,
+ "hajiman": 5963,
+ "eumyeon": 5964,
+ "pil": 5965,
+ "nyeon": 5966,
+ "tae": 5967,
+ "pyo": 5968,
+ "jineun": 5969,
+ "beon": 5970,
+ "hada": 5971,
+ "seol": 5972,
+ "sip": 5973,
+ "daleun": 5974,
+ "salm": 5975,
+ "gyo": 5976,
+ "cheon": 5977,
+ "hagi": 5978,
+ "cheoleom": 5979,
+ "gal": 5980,
+ "ila": 5981,
+ "kkaji": 5982,
+ "anhneun": 5983,
+ "habnida": 5984,
+ "tteon": 5985,
+ "haeseo": 5986,
+ "doenda": 5987,
+ "ttal": 5988,
+ "ilo": 5989,
+ "seub": 5990,
+ "byeon": 5991,
+ "myeo": 5992,
+ "beol": 5993,
+ "jeung": 5994,
+ "chim": 5995,
+ "hwang": 5996,
+ "euneun": 5997,
+ "jong": 5998,
+ "boda": 5999,
+ "nol": 6000,
+ "neom": 6001,
+ "buteo": 6002,
+ "jigeum": 6003,
+ "eobsda": 6004,
+ "daelo": 6005,
+ "yul": 6006,
+ "pyeong": 6007,
+ "seoneun": 6008,
+ "salang": 6009,
+ "seut": 6010,
+ "heom": 6011,
+ "hyang": 6012,
+ "gwang": 6013,
+ "eobsneun": 6014,
+ "hwag": 6015,
+ "gess": 6016,
+ "jagi": 6017,
+ "ileon": 6018,
+ "wihae": 6019,
+ "daehan": 6020,
+ "gaji": 6021,
+ "meog": 6022,
+ "jyeo": 6023,
+ "chaj": 6024,
+ "byeong": 6025,
+ "eod": 6026,
+ "gyeo": 6027,
+ "eoji": 6028,
+ "gul": 6029,
+ "modeun": 6030,
+ "insaeng": 6031,
+ "geulae": 6032,
+ "sasil": 6033,
+ "sib": 6034,
+ "chal": 6035,
+ "ilago": 6036,
+ "geum": 6037,
+ "doeneun": 6038,
+ "bol": 6039,
+ "gajang": 6040,
+ "geuligo": 6041,
+ "hyeong": 6042,
+ "haengbog": 6043,
+ "chul": 6044,
+ "chae": 6045,
+ "mang": 6046,
+ "dam": 6047,
+ "choe": 6048,
+ "sijag": 6049,
+ "cheong": 6050,
+ "ilaneun": 6051,
+ "ulineun": 6052,
+ "aen": 6053,
+ "kke": 6054,
+ "munje": 6055,
+ "teu": 6056,
+ "geuneun": 6057,
+ "bge": 6058,
+ "cheo": 6059,
+ "baeg": 6060,
+ "jug": 6061,
+ "sangdae": 6062,
+ "geugeos": 6063,
+ "dog": 6064,
+ "eus": 6065,
+ "jab": 6066,
+ "hyeo": 6067,
+ "tteohge": 6068,
+ "chil": 6069,
+ "swi": 6070,
+ "jileul": 6071,
+ "chang": 6072,
+ "ganeun": 6073,
+ "iji": 6074,
+ "dago": 6075,
+ "yohan": 6076,
+ "teug": 6077,
+ "ppun": 6078,
+ "aleul": 6079,
+ "haengdong": 6080,
+ "sesang": 6081,
+ "edo": 6082,
+ "mandeul": 6083,
+ "amyeon": 6084,
+ "kkae": 6085,
+ "bag": 6086,
+ "ideul": 6087,
+ "pum": 6088,
+ "meol": 6089,
+ "neul": 6090,
+ "hamkke": 6091,
+ "chung": 6092,
+ "dab": 6093,
+ "yug": 6094,
+ "sag": 6095,
+ "gwangye": 6096,
+ "ileohge": 6097,
+ "balo": 6098,
+ "neunde": 6099,
+ "hamyeo": 6100,
+ "geuleoh": 6101,
+ "anila": 6102,
+ "bangbeob": 6103,
+ "dasi": 6104,
+ "byeol": 6105,
+ "gyeon": 6106,
+ "gamjeong": 6107,
+ "oneul": 6108,
+ "janeun": 6109,
+ "yeom": 6110,
+ "lago": 6111,
+ "igi": 6112,
+ "hwan": 6113,
+ "teul": 6114,
+ "eoseo": 6115,
+ "sik": 6116,
+ "jaga": 6117,
+ "geuleom": 6118,
+ "geuleona": 6119,
+ "jeongdo": 6120,
+ "gyeog": 6121,
+ "geuleohge": 6122,
+ "geudeul": 6123,
+ "eut": 6124,
+ "imyeon": 6125,
+ "jjae": 6126,
+ "keun": 6127,
+ "isang": 6128,
+ "malhaessda": 6129,
+ "euge": 6130,
+ "nop": 6131,
+ "ingan": 6132,
+ "bomyeon": 6133,
+ "taeg": 6134,
+ "dwi": 6135,
+ "saneun": 6136,
+ "wan": 6137,
+ "anhgo": 6138,
+ "nugu": 6139,
+ "sung": 6140,
+ "damyeon": 6141,
+ "adeul": 6142,
+ "peul": 6143,
+ "ttala": 6144,
+ "geosdo": 6145,
+ "aji": 6146,
+ "meon": 6147,
+ "eumyeo": 6148,
+ "dolog": 6149,
+ "neung": 6150,
+ "modu": 6151,
+ "[ko]": 6152,
+ "\u0014": 6153,
+ "\u0016": 6154,
+ "$": 6155,
+ "*": 6156,
+ "|": 6157,
+ "°": 6158,
+ "º": 6159,
+ "ँ": 6160,
+ "ं": 6161,
+ "ः": 6162,
+ "अ": 6163,
+ "आ": 6164,
+ "इ": 6165,
+ "ई": 6166,
+ "उ": 6167,
+ "ऊ": 6168,
+ "ऋ": 6169,
+ "ऎ": 6170,
+ "ए": 6171,
+ "ऐ": 6172,
+ "ऑ": 6173,
+ "ऒ": 6174,
+ "ओ": 6175,
+ "औ": 6176,
+ "क": 6177,
+ "ख": 6178,
+ "ग": 6179,
+ "घ": 6180,
+ "ङ": 6181,
+ "च": 6182,
+ "छ": 6183,
+ "ज": 6184,
+ "झ": 6185,
+ "ञ": 6186,
+ "ट": 6187,
+ "ठ": 6188,
+ "ड": 6189,
+ "ढ": 6190,
+ "ण": 6191,
+ "त": 6192,
+ "थ": 6193,
+ "द": 6194,
+ "ध": 6195,
+ "न": 6196,
+ "ऩ": 6197,
+ "प": 6198,
+ "फ": 6199,
+ "ब": 6200,
+ "भ": 6201,
+ "म": 6202,
+ "य": 6203,
+ "र": 6204,
+ "ऱ": 6205,
+ "ल": 6206,
+ "ळ": 6207,
+ "व": 6208,
+ "श": 6209,
+ "ष": 6210,
+ "स": 6211,
+ "ह": 6212,
+ "़": 6213,
+ "ा": 6214,
+ "ि": 6215,
+ "ी": 6216,
+ "ु": 6217,
+ "ू": 6218,
+ "ृ": 6219,
+ "ॄ": 6220,
+ "ॅ": 6221,
+ "ॆ": 6222,
+ "े": 6223,
+ "ै": 6224,
+ "ॉ": 6225,
+ "ॊ": 6226,
+ "ो": 6227,
+ "ौ": 6228,
+ "्": 6229,
+ "ॐ": 6230,
+ "ॖ": 6231,
+ "क़": 6232,
+ "ख़": 6233,
+ "ग़": 6234,
+ "ज़": 6235,
+ "ड़": 6236,
+ "ढ़": 6237,
+ "फ़": 6238,
+ "य़": 6239,
+ "ॠ": 6240,
+ "।": 6241,
+ "॥": 6242,
+ "०": 6243,
+ "१": 6244,
+ "२": 6245,
+ "३": 6246,
+ "४": 6247,
+ "५": 6248,
+ "६": 6249,
+ "७": 6250,
+ "८": 6251,
+ "९": 6252,
+ "॰": 6253,
+ "ॲ": 6254,
+ "": 6255,
+ "": 6256,
+ "": 6257,
+ "": 6258,
+ "₹": 6259,
+ "के": 6260,
+ "है": 6261,
+ "ें": 6262,
+ "्र": 6263,
+ "ार": 6264,
+ "ने": 6265,
+ "या": 6266,
+ "में": 6267,
+ "से": 6268,
+ "की": 6269,
+ "का": 6270,
+ "ों": 6271,
+ "ता": 6272,
+ "कर": 6273,
+ "स्": 6274,
+ "कि": 6275,
+ "को": 6276,
+ "र्": 6277,
+ "ना": 6278,
+ "क्": 6279,
+ "ही": 6280,
+ "और": 6281,
+ "पर": 6282,
+ "ते": 6283,
+ "हो": 6284,
+ "प्र": 6285,
+ "ान": 6286,
+ "्य": 6287,
+ "ला": 6288,
+ "वा": 6289,
+ "ले": 6290,
+ "सा": 6291,
+ "हैं": 6292,
+ "लि": 6293,
+ "जा": 6294,
+ "हा": 6295,
+ "भी": 6296,
+ "वि": 6297,
+ "इस": 6298,
+ "ती": 6299,
+ "न्": 6300,
+ "रा": 6301,
+ "मा": 6302,
+ "दे": 6303,
+ "दि": 6304,
+ "बा": 6305,
+ "ति": 6306,
+ "था": 6307,
+ "नि": 6308,
+ "कार": 6309,
+ "एक": 6310,
+ "हीं": 6311,
+ "हु": 6312,
+ "ंग": 6313,
+ "ैं": 6314,
+ "नी": 6315,
+ "सी": 6316,
+ "अप": 6317,
+ "त्": 6318,
+ "नहीं": 6319,
+ "री": 6320,
+ "मे": 6321,
+ "मु": 6322,
+ "ित": 6323,
+ "तो": 6324,
+ "पा": 6325,
+ "ली": 6326,
+ "लिए": 6327,
+ "गा": 6328,
+ "ल्": 6329,
+ "रह": 6330,
+ "रे": 6331,
+ "क्ष": 6332,
+ "मैं": 6333,
+ "सम": 6334,
+ "उस": 6335,
+ "जि": 6336,
+ "त्र": 6337,
+ "मि": 6338,
+ "चा": 6339,
+ "ोग": 6340,
+ "सं": 6341,
+ "द्": 6342,
+ "सि": 6343,
+ "आप": 6344,
+ "तु": 6345,
+ "दा": 6346,
+ "कु": 6347,
+ "यों": 6348,
+ "वे": 6349,
+ "जी": 6350,
+ "्या": 6351,
+ "उन": 6352,
+ "िक": 6353,
+ "ये": 6354,
+ "भा": 6355,
+ "्ट": 6356,
+ "हम": 6357,
+ "स्ट": 6358,
+ "शा": 6359,
+ "ड़": 6360,
+ "ंद": 6361,
+ "खा": 6362,
+ "म्": 6363,
+ "श्": 6364,
+ "यह": 6365,
+ "सक": 6366,
+ "पू": 6367,
+ "किया": 6368,
+ "अपने": 6369,
+ "रू": 6370,
+ "सु": 6371,
+ "मी": 6372,
+ "हि": 6373,
+ "जो": 6374,
+ "थे": 6375,
+ "रि": 6376,
+ "दी": 6377,
+ "थी": 6378,
+ "गी": 6379,
+ "लोग": 6380,
+ "गया": 6381,
+ "तर": 6382,
+ "न्ह": 6383,
+ "च्": 6384,
+ "वार": 6385,
+ "बी": 6386,
+ "प्": 6387,
+ "दो": 6388,
+ "टी": 6389,
+ "शि": 6390,
+ "करने": 6391,
+ "गे": 6392,
+ "ैसे": 6393,
+ "इन": 6394,
+ "ंड": 6395,
+ "साथ": 6396,
+ "पु": 6397,
+ "बे": 6398,
+ "बार": 6399,
+ "वी": 6400,
+ "अन": 6401,
+ "हर": 6402,
+ "उन्ह": 6403,
+ "होता": 6404,
+ "जब": 6405,
+ "कुछ": 6406,
+ "मान": 6407,
+ "क्र": 6408,
+ "बि": 6409,
+ "पह": 6410,
+ "फि": 6411,
+ "सर": 6412,
+ "ारी": 6413,
+ "रो": 6414,
+ "दू": 6415,
+ "कहा": 6416,
+ "तक": 6417,
+ "शन": 6418,
+ "ब्": 6419,
+ "स्थ": 6420,
+ "वह": 6421,
+ "बाद": 6422,
+ "ओं": 6423,
+ "गु": 6424,
+ "ज्": 6425,
+ "्रे": 6426,
+ "गर": 6427,
+ "रहे": 6428,
+ "वर्": 6429,
+ "हू": 6430,
+ "ार्": 6431,
+ "पी": 6432,
+ "बहु": 6433,
+ "मुझ": 6434,
+ "्रा": 6435,
+ "दिया": 6436,
+ "सब": 6437,
+ "करते": 6438,
+ "अपनी": 6439,
+ "बहुत": 6440,
+ "कह": 6441,
+ "टे": 6442,
+ "हुए": 6443,
+ "किसी": 6444,
+ "रहा": 6445,
+ "ष्ट": 6446,
+ "ज़": 6447,
+ "बना": 6448,
+ "सो": 6449,
+ "डि": 6450,
+ "कोई": 6451,
+ "व्य": 6452,
+ "बात": 6453,
+ "रु": 6454,
+ "वो": 6455,
+ "मुझे": 6456,
+ "द्ध": 6457,
+ "चार": 6458,
+ "मेरे": 6459,
+ "वर": 6460,
+ "्री": 6461,
+ "जाता": 6462,
+ "नों": 6463,
+ "प्रा": 6464,
+ "देख": 6465,
+ "टा": 6466,
+ "क्या": 6467,
+ "अध": 6468,
+ "लग": 6469,
+ "लो": 6470,
+ "पि": 6471,
+ "यु": 6472,
+ "चे": 6473,
+ "जिस": 6474,
+ "ंत": 6475,
+ "ानी": 6476,
+ "पै": 6477,
+ "जन": 6478,
+ "ारे": 6479,
+ "ची": 6480,
+ "मिल": 6481,
+ "दु": 6482,
+ "देश": 6483,
+ "च्छ": 6484,
+ "ष्": 6485,
+ "सू": 6486,
+ "खे": 6487,
+ "चु": 6488,
+ "िया": 6489,
+ "लगा": 6490,
+ "बु": 6491,
+ "उनके": 6492,
+ "ज्ञ": 6493,
+ "क्षा": 6494,
+ "तरह": 6495,
+ "्यादा": 6496,
+ "वाले": 6497,
+ "पूर्": 6498,
+ "मैंने": 6499,
+ "काम": 6500,
+ "रूप": 6501,
+ "होती": 6502,
+ "उप": 6503,
+ "जान": 6504,
+ "प्रकार": 6505,
+ "भार": 6506,
+ "मन": 6507,
+ "हुआ": 6508,
+ "टर": 6509,
+ "हूँ": 6510,
+ "परि": 6511,
+ "पास": 6512,
+ "अनु": 6513,
+ "राज": 6514,
+ "लोगों": 6515,
+ "अब": 6516,
+ "समझ": 6517,
+ "डी": 6518,
+ "मौ": 6519,
+ "शु": 6520,
+ "चि": 6521,
+ "पे": 6522,
+ "कृ": 6523,
+ "सकते": 6524,
+ "मह": 6525,
+ "योग": 6526,
+ "दर्": 6527,
+ "उसे": 6528,
+ "ंध": 6529,
+ "डा": 6530,
+ "जाए": 6531,
+ "बो": 6532,
+ "ूल": 6533,
+ "मो": 6534,
+ "ोंने": 6535,
+ "ंस": 6536,
+ "तुम": 6537,
+ "पहले": 6538,
+ "बता": 6539,
+ "तथा": 6540,
+ "यो": 6541,
+ "गई": 6542,
+ "उत्": 6543,
+ "सकता": 6544,
+ "कम": 6545,
+ "ज्यादा": 6546,
+ "रख": 6547,
+ "समय": 6548,
+ "ारा": 6549,
+ "अगर": 6550,
+ "स्त": 6551,
+ "चल": 6552,
+ "फिर": 6553,
+ "वारा": 6554,
+ "करना": 6555,
+ "शी": 6556,
+ "गए": 6557,
+ "बन": 6558,
+ "ौर": 6559,
+ "होने": 6560,
+ "चाह": 6561,
+ "खु": 6562,
+ "हाँ": 6563,
+ "उन्हें": 6564,
+ "उन्होंने": 6565,
+ "छो": 6566,
+ "म्ह": 6567,
+ "प्रति": 6568,
+ "निक": 6569,
+ "वन": 6570,
+ "्यू": 6571,
+ "रही": 6572,
+ "तुम्ह": 6573,
+ "जैसे": 6574,
+ "ियों": 6575,
+ "क्यों": 6576,
+ "लों": 6577,
+ "फ़": 6578,
+ "ंत्र": 6579,
+ "होते": 6580,
+ "क्ति": 6581,
+ "त्य": 6582,
+ "कर्": 6583,
+ "कई": 6584,
+ "वं": 6585,
+ "किन": 6586,
+ "पो": 6587,
+ "कारण": 6588,
+ "ड़ी": 6589,
+ "भि": 6590,
+ "इसके": 6591,
+ "बर": 6592,
+ "उसके": 6593,
+ "द्वारा": 6594,
+ "शे": 6595,
+ "कॉ": 6596,
+ "दिन": 6597,
+ "न्न": 6598,
+ "ड़ा": 6599,
+ "स्व": 6600,
+ "निर्": 6601,
+ "मुख": 6602,
+ "लिया": 6603,
+ "टि": 6604,
+ "ज्ञान": 6605,
+ "क्त": 6606,
+ "द्र": 6607,
+ "ग्": 6608,
+ "क्स": 6609,
+ "मै": 6610,
+ "गो": 6611,
+ "जे": 6612,
+ "ट्र": 6613,
+ "मार": 6614,
+ "त्व": 6615,
+ "धार": 6616,
+ "भाव": 6617,
+ "करता": 6618,
+ "खि": 6619,
+ "कं": 6620,
+ "चाहि": 6621,
+ "यर": 6622,
+ "प्त": 6623,
+ "कों": 6624,
+ "ंच": 6625,
+ "जु": 6626,
+ "मत": 6627,
+ "अच्छ": 6628,
+ "हुई": 6629,
+ "कभी": 6630,
+ "लेकिन": 6631,
+ "भू": 6632,
+ "अपना": 6633,
+ "दूस": 6634,
+ "चाहिए": 6635,
+ "यू": 6636,
+ "घर": 6637,
+ "सबसे": 6638,
+ "मेरी": 6639,
+ "नाम": 6640,
+ "ढ़": 6641,
+ "ंट": 6642,
+ "ेंगे": 6643,
+ "बै": 6644,
+ "फा": 6645,
+ "एवं": 6646,
+ "यी": 6647,
+ "ग्र": 6648,
+ "क्षे": 6649,
+ "आज": 6650,
+ "आपको": 6651,
+ "भाग": 6652,
+ "ठा": 6653,
+ "कै": 6654,
+ "भारत": 6655,
+ "उनकी": 6656,
+ "पहु": 6657,
+ "सभी": 6658,
+ "धा": 6659,
+ "णा": 6660,
+ "सान": 6661,
+ "होगा": 6662,
+ "तब": 6663,
+ "संग": 6664,
+ "पर्": 6665,
+ "अव": 6666,
+ "तना": 6667,
+ "गि": 6668,
+ "यन": 6669,
+ "स्था": 6670,
+ "चित": 6671,
+ "ट्": 6672,
+ "छा": 6673,
+ "जाने": 6674,
+ "क्षेत्र": 6675,
+ "वाली": 6676,
+ "पूर्ण": 6677,
+ "समा": 6678,
+ "कारी": 6679,
+ "[hi]": 6680
+ },
+ "merges": [
+ "t h",
+ "i n",
+ "th e",
+ "a n",
+ "e r",
+ "o u",
+ "r e",
+ "o n",
+ "a t",
+ "e d",
+ "e n",
+ "t o",
+ "in g",
+ "an d",
+ "i s",
+ "a s",
+ "a l",
+ "o r",
+ "o f",
+ "a r",
+ "i t",
+ "e s",
+ "h e",
+ "s t",
+ "l e",
+ "o m",
+ "s e",
+ "b e",
+ "a d",
+ "o w",
+ "l y",
+ "c h",
+ "w h",
+ "th at",
+ "y ou",
+ "l i",
+ "v e",
+ "a c",
+ "t i",
+ "l d",
+ "m e",
+ "w as",
+ "g h",
+ "i d",
+ "l l",
+ "w i",
+ "en t",
+ "f or",
+ "a y",
+ "r o",
+ "v er",
+ "i c",
+ "h er",
+ "k e",
+ "h is",
+ "n o",
+ "u t",
+ "u n",
+ "i r",
+ "l o",
+ "w e",
+ "r i",
+ "h a",
+ "wi th",
+ "gh t",
+ "ou t",
+ "i m",
+ "i on",
+ "al l",
+ "a b",
+ "on e",
+ "n e",
+ "g e",
+ "ou ld",
+ "t er",
+ "m o",
+ "h ad",
+ "c e",
+ "s he",
+ "g o",
+ "s h",
+ "u r",
+ "a m",
+ "s o",
+ "p e",
+ "m y",
+ "d e",
+ "a re",
+ "b ut",
+ "om e",
+ "f r",
+ "the r",
+ "f e",
+ "s u",
+ "d o",
+ "c on",
+ "t e",
+ "a in",
+ "er e",
+ "p o",
+ "i f",
+ "the y",
+ "u s",
+ "a g",
+ "t r",
+ "n ow",
+ "ou n",
+ "th is",
+ "ha ve",
+ "no t",
+ "s a",
+ "i l",
+ "u p",
+ "th ing",
+ "fr om",
+ "a p",
+ "h im",
+ "ac k",
+ "at ion",
+ "an t",
+ "ou r",
+ "o p",
+ "li ke",
+ "u st",
+ "es s",
+ "b o",
+ "o k",
+ "u l",
+ "in d",
+ "e x",
+ "c om",
+ "s ome",
+ "the re",
+ "er s",
+ "c o",
+ "re s",
+ "m an",
+ "ar d",
+ "p l",
+ "w or",
+ "w ay",
+ "ti on",
+ "f o",
+ "c a",
+ "w ere",
+ "b y",
+ "at e",
+ "p ro",
+ "t ed",
+ "oun d",
+ "ow n",
+ "w ould",
+ "t s",
+ "wh at",
+ "q u",
+ "al ly",
+ "i ght",
+ "c k",
+ "g r",
+ "wh en",
+ "v en",
+ "c an",
+ "ou gh",
+ "in e",
+ "en d",
+ "p er",
+ "ou s",
+ "o d",
+ "id e",
+ "k now",
+ "t y",
+ "ver y",
+ "s i",
+ "a k",
+ "wh o",
+ "ab out",
+ "i ll",
+ "the m",
+ "es t",
+ "re d",
+ "y e",
+ "c ould",
+ "on g",
+ "you r",
+ "the ir",
+ "e m",
+ "j ust",
+ "o ther",
+ "in to",
+ "an y",
+ "wh i",
+ "u m",
+ "t w",
+ "as t",
+ "d er",
+ "d id",
+ "i e",
+ "be en",
+ "ac e",
+ "in k",
+ "it y",
+ "b ack",
+ "t ing",
+ "b r",
+ "mo re",
+ "a ke",
+ "p p",
+ "the n",
+ "s p",
+ "e l",
+ "u se",
+ "b l",
+ "sa id",
+ "o ver",
+ "ge t",
+ "e n",
+ "e r",
+ "c h",
+ "e i",
+ "i e",
+ "u n",
+ "i ch",
+ "ei n",
+ "s t",
+ "a n",
+ "t e",
+ "g e",
+ "a u",
+ "i n",
+ "s ch",
+ "d er",
+ "un d",
+ "d ie",
+ "d a",
+ "e s",
+ "a l",
+ "d en",
+ "a r",
+ "g en",
+ "z u",
+ "d e",
+ "h r",
+ "o n",
+ "t en",
+ "e l",
+ "o r",
+ "m i",
+ "s ie",
+ "da s",
+ "a t",
+ "b e",
+ "ein e",
+ "ich t",
+ "b er",
+ "l e",
+ "a ch",
+ "v er",
+ "s e",
+ "au f",
+ "w i",
+ "s o",
+ "t er",
+ "l ich",
+ "c k",
+ "u r",
+ "n icht",
+ "m m",
+ "b en",
+ "a s",
+ "w ar",
+ "r e",
+ "mi t",
+ "s ich",
+ "i g",
+ "l l",
+ "au s",
+ "i st",
+ "w ie",
+ "o ch",
+ "un g",
+ "an n",
+ "ü r",
+ "h n",
+ "i hr",
+ "s a",
+ "s en",
+ "t z",
+ "de m",
+ "ei t",
+ "u m",
+ "h at",
+ "wi r",
+ "v on",
+ "h a",
+ "s p",
+ "w ei",
+ "i er",
+ "r o",
+ "h er",
+ "r a",
+ "ein en",
+ "n e",
+ "v or",
+ "al s",
+ "an d",
+ "al l",
+ "w as",
+ "w o",
+ "r ei",
+ "st e",
+ "l ie",
+ "au ch",
+ "d u",
+ "d es",
+ "k o",
+ "ü ber",
+ "a m",
+ "b ei",
+ "h en",
+ "h m",
+ "l ei",
+ "a ber",
+ "w en",
+ "h l",
+ "g er",
+ "i m",
+ "u t",
+ "n ach",
+ "h e",
+ "i s",
+ "b r",
+ "f t",
+ "en t",
+ "i mm",
+ "j e",
+ "sch en",
+ "w er",
+ "s er",
+ "a b",
+ "ä n",
+ "m e",
+ "s ein",
+ "i t",
+ "o l",
+ "ch t",
+ "f ür",
+ "k l",
+ "f f",
+ "eine m",
+ "n en",
+ "w e",
+ "j a",
+ "u s",
+ "n och",
+ "hat te",
+ "t r",
+ "p f",
+ "h in",
+ "d i",
+ "ch en",
+ "b l",
+ "m an",
+ "r ü",
+ "ie l",
+ "s el",
+ "das s",
+ "i hn",
+ "mi r",
+ "sch l",
+ "ö n",
+ "g an",
+ "g t",
+ "ein er",
+ "st en",
+ "m ich",
+ "wen n",
+ "el l",
+ "g te",
+ "in d",
+ "m al",
+ "ge l",
+ "k en",
+ "n ur",
+ "mm en",
+ "f ü",
+ "er n",
+ "ö r",
+ "un ter",
+ "f r",
+ "an der",
+ "g r",
+ "i l",
+ "d ur",
+ "u ch",
+ "f e",
+ "t a",
+ "m en",
+ "m ach",
+ "d och",
+ "t i",
+ "dur ch",
+ "o s",
+ "g l",
+ "h al",
+ "ihr e",
+ "w ä",
+ "imm er",
+ "i hm",
+ "k ann",
+ "or t",
+ "d ann",
+ "l an",
+ "tz t",
+ "o der",
+ "hr en",
+ "e t",
+ "k ön",
+ "i ck",
+ "f a",
+ "in g",
+ "i r",
+ "wie der",
+ "da ß",
+ "m ein",
+ "f en",
+ "gan z",
+ "die se",
+ "st er",
+ "da r",
+ "w a",
+ "ge s",
+ "n a",
+ "f l",
+ "i gen",
+ "sch e",
+ "un gen",
+ "me hr",
+ "ß en",
+ "o t",
+ "k on",
+ "ge w",
+ "ha ben",
+ "ge h",
+ "ä t",
+ "s ind",
+ "d r",
+ "w el",
+ "un s",
+ "v o",
+ "m a",
+ "u te",
+ "sch on",
+ "b es",
+ "ge sch",
+ "b t",
+ "ch e",
+ "s on",
+ "o b",
+ "l a",
+ "p p",
+ "rü ck",
+ "s eine",
+ "k r",
+ "f re",
+ "ei l",
+ "zu m",
+ "u l",
+ "h ier",
+ "k t",
+ "i ge",
+ "sp r",
+ "k e",
+ "le ben",
+ "b st",
+ "z eit",
+ "i on",
+ "g ro",
+ "den n",
+ "h o",
+ "sch a",
+ "b ar",
+ "al le",
+ "ge gen",
+ "w ür",
+ "m ü",
+ "z e",
+ "wer den",
+ "je tzt",
+ "ko mmen",
+ "n ie",
+ "s ei",
+ "h eit",
+ "so ll",
+ "g lei",
+ "m eine",
+ "wo ll",
+ "n er",
+ "ha be",
+ "w ur",
+ "lich en",
+ "p er",
+ "as sen",
+ "n te",
+ "se hen",
+ "wir d",
+ "b is",
+ "g ar",
+ "i en",
+ "m us",
+ "u ß",
+ "ä r",
+ "st ell",
+ "k eit",
+ "z wei",
+ "sel bst",
+ "st a",
+ "p a",
+ "sa gte",
+ "te t",
+ "k am",
+ "s sen",
+ "v iel",
+ "u g",
+ "z en",
+ "h ei",
+ "m ann",
+ "wi ll",
+ "ge b",
+ "war en",
+ "ü ck",
+ "ä ch",
+ "m er",
+ "r u",
+ "w or",
+ "h au",
+ "ei gen",
+ "an g",
+ "we g",
+ "bl ick",
+ "f ra",
+ "all es",
+ "k a",
+ "au gen",
+ "f in",
+ "lich e",
+ "t o",
+ "un ser",
+ "der n",
+ "her r",
+ "n un",
+ "v ie",
+ "ch te",
+ "wo hl",
+ "f all",
+ "h t",
+ "ü n",
+ "et was",
+ "st and",
+ "en d",
+ "ä u",
+ "e m",
+ "m ö",
+ "te l",
+ "r ie",
+ "d ich",
+ "die s",
+ "h and",
+ "b in",
+ "ff en",
+ "nicht s",
+ "d an",
+ "p l",
+ "hn e",
+ "ihn en",
+ "es en",
+ "die ser",
+ "fr au",
+ "an t",
+ "ar t",
+ "di r",
+ "i sch",
+ "er st",
+ "glei ch",
+ "ko mm",
+ "h ör",
+ "ß e",
+ "d ig",
+ "se hr",
+ "z ei",
+ "sa m",
+ "au m",
+ "h ät",
+ "in gen",
+ "g ut",
+ "b o",
+ "m ut",
+ "ck en",
+ "kon nte",
+ "st imm",
+ "p ro",
+ "zu r",
+ "i tz",
+ "wei l",
+ "wür de",
+ "f ä",
+ "kön nen",
+ "k eine",
+ "f er",
+ "i schen",
+ "vo ll",
+ "ein es",
+ "se tz",
+ "z ie",
+ "de l",
+ "te te",
+ "sein er",
+ "ier en",
+ "ge st",
+ "zu rück",
+ "wur de",
+ "sch n",
+ "p r",
+ "lie ß",
+ "t ra",
+ "m ä",
+ "gen d",
+ "f ol",
+ "i k",
+ "schl a",
+ "scha ft",
+ "at er",
+ "wei ß",
+ "s einen",
+ "l assen",
+ "l u",
+ "und en",
+ "t eil",
+ "ne u",
+ "ier t",
+ "men schen",
+ "hm en",
+ "st r",
+ "g i",
+ "sa h",
+ "ihr en",
+ "el n",
+ "wei ter",
+ "ge hen",
+ "ig er",
+ "mach t",
+ "ta g",
+ "al so",
+ "hal ten",
+ "n is",
+ "ach t",
+ "ge ben",
+ "f or",
+ "o g",
+ "n at",
+ "m ar",
+ "de t",
+ "o hne",
+ "h aus",
+ "t ro",
+ "an ge",
+ "l au",
+ "sp iel",
+ "t re",
+ "sch r",
+ "in n",
+ "s u",
+ "l os",
+ "mach en",
+ "hät te",
+ "be g",
+ "wir k",
+ "al t",
+ "g lich",
+ "te s",
+ "r icht",
+ "fre und",
+ "m o",
+ "ihr er",
+ "f el",
+ "b el",
+ "so l",
+ "ein mal",
+ "e ben",
+ "h ol",
+ "h än",
+ "q u",
+ "ter n",
+ "h ö",
+ "sch w",
+ "re cht",
+ "wa hr",
+ "s einem",
+ "ste hen",
+ "hl en",
+ "in s",
+ "g ing",
+ "woll te",
+ "wi ssen",
+ "ung s",
+ "al d",
+ "as s",
+ "ja hr",
+ "m or",
+ "wel t",
+ "un der",
+ "zu sa",
+ "at ion",
+ "ko pf",
+ "lan g",
+ "hin ter",
+ "at z",
+ "st ra",
+ "an gen",
+ "an k",
+ "a de",
+ "gl au",
+ "f ach",
+ "hat ten",
+ "l o",
+ "f ort",
+ "ei cht",
+ "i ff",
+ "l er",
+ "m ei",
+ "diese m",
+ "k ein",
+ "f rei",
+ "fü hr",
+ "vo m",
+ "e s",
+ "e n",
+ "a i",
+ "o u",
+ "o n",
+ "l e",
+ "d e",
+ "r e",
+ "q u",
+ "a n",
+ "e r",
+ "en t",
+ "e t",
+ "l a",
+ "n e",
+ "i l",
+ "a r",
+ "i s",
+ "ai t",
+ "t e",
+ "a u",
+ "i n",
+ "qu e",
+ "i t",
+ "u r",
+ "s e",
+ "l es",
+ "c h",
+ "c e",
+ "m e",
+ "o r",
+ "ou r",
+ "a s",
+ "p r",
+ "a v",
+ "o m",
+ "ai s",
+ "u n",
+ "an t",
+ "ou s",
+ "t r",
+ "t i",
+ "l u",
+ "o i",
+ "e u",
+ "l le",
+ "s i",
+ "p ar",
+ "d es",
+ "an s",
+ "m ent",
+ "é t",
+ "es t",
+ "j e",
+ "u ne",
+ "a l",
+ "p as",
+ "t re",
+ "qu i",
+ "d u",
+ "r i",
+ "c on",
+ "s on",
+ "c om",
+ "e lle",
+ "d é",
+ "p our",
+ "d ans",
+ "l i",
+ "s a",
+ "r é",
+ "t ou",
+ "v ous",
+ "d i",
+ "v i",
+ "a g",
+ "a m",
+ "a t",
+ "ou v",
+ "a p",
+ "ti on",
+ "m on",
+ "s ur",
+ "c i",
+ "o s",
+ "p lu",
+ "s u",
+ "en d",
+ "a b",
+ "è re",
+ "ai n",
+ "m ais",
+ "o is",
+ "r es",
+ "plu s",
+ "é e",
+ "ai ent",
+ "m p",
+ "ch e",
+ "lu i",
+ "av e",
+ "ét ait",
+ "m a",
+ "s es",
+ "tou t",
+ "i r",
+ "v o",
+ "a c",
+ "s er",
+ "an d",
+ "f f",
+ "oi r",
+ "g r",
+ "av ait",
+ "é s",
+ "m es",
+ "n ous",
+ "eu x",
+ "b i",
+ "t er",
+ "c o",
+ "on s",
+ "p u",
+ "c es",
+ "g e",
+ "t u",
+ "le ur",
+ "pr o",
+ "d on",
+ "e ur",
+ "et te",
+ "ai re",
+ "ave c",
+ "d it",
+ "t é",
+ "i e",
+ "u s",
+ "il le",
+ "p er",
+ "com me",
+ "c r",
+ "or t",
+ "m i",
+ "e x",
+ "u x",
+ "v er",
+ "m o",
+ "è s",
+ "v e",
+ "au x",
+ "r a",
+ "j our",
+ "il s",
+ "bi en",
+ "c ou",
+ "p e",
+ "que l",
+ "p eu",
+ "c ette",
+ "t es",
+ "p o",
+ "in s",
+ "c u",
+ "m ê",
+ "s o",
+ "f ait",
+ "g u",
+ "m ar",
+ "ê tre",
+ "l o",
+ "it é",
+ "f r",
+ "a tion",
+ "en s",
+ "b r",
+ "n i",
+ "l é",
+ "d is",
+ "b le",
+ "m an",
+ "n é",
+ "pu is",
+ "mê me",
+ "qu es",
+ "f i",
+ "e l",
+ "ag e",
+ "g ar",
+ "m oi",
+ "en ce",
+ "on t",
+ "m ain",
+ "or s",
+ "au t",
+ "an ce",
+ "v en",
+ "m é",
+ "s ans",
+ "e m",
+ "s é",
+ "l on",
+ "h om",
+ "r o",
+ "u t",
+ "c ar",
+ "ab le",
+ "i m",
+ "de r",
+ "ch er",
+ "n o",
+ "vi e",
+ "au s",
+ "b e",
+ "de ux",
+ "en f",
+ "o ù",
+ "t en",
+ "p h",
+ "u re",
+ "te mp",
+ "p os",
+ "r ent",
+ "p é",
+ "f aire",
+ "p i",
+ "tr es",
+ "ç a",
+ "an g",
+ "end re",
+ "f or",
+ "p a",
+ "b on",
+ "s ou",
+ "in t",
+ "pr é",
+ "s ent",
+ "t ant",
+ "n er",
+ "c er",
+ "l à",
+ "l ais",
+ "pr ès",
+ "b re",
+ "c our",
+ "p et",
+ "i on",
+ "i ne",
+ "com p",
+ "l ait",
+ "tr ouv",
+ "t a",
+ "ent re",
+ "son t",
+ "de v",
+ "n u",
+ "temp s",
+ "d ou",
+ "r ait",
+ "b ou",
+ "qu and",
+ "jour s",
+ "l an",
+ "er s",
+ "av oir",
+ "ét é",
+ "a le",
+ "p re",
+ "f ois",
+ "or te",
+ "v é",
+ "m er",
+ "n on",
+ "t ous",
+ "j us",
+ "cou p",
+ "t s",
+ "hom me",
+ "ê te",
+ "a d",
+ "aus si",
+ "ur s",
+ "se u",
+ "or d",
+ "o b",
+ "m in",
+ "g é",
+ "co re",
+ "v a",
+ "v re",
+ "en core",
+ "se m",
+ "i te",
+ "au tre",
+ "pr is",
+ "peu t",
+ "u e",
+ "an te",
+ "m al",
+ "g n",
+ "ré p",
+ "h u",
+ "si on",
+ "vo tre",
+ "di re",
+ "e z",
+ "f em",
+ "leur s",
+ "m et",
+ "f in",
+ "c ri",
+ "m is",
+ "t our",
+ "r ai",
+ "j am",
+ "re gar",
+ "ri en",
+ "ver s",
+ "su is",
+ "p ouv",
+ "o p",
+ "v is",
+ "gr and",
+ "ant s",
+ "c or",
+ "re r",
+ "ar d",
+ "c é",
+ "t ent",
+ "pr es",
+ "v ou",
+ "f a",
+ "al ors",
+ "si eur",
+ "ai ne",
+ "le r",
+ "qu oi",
+ "f on",
+ "end ant",
+ "ar ri",
+ "eu re",
+ "a près",
+ "don c",
+ "it u",
+ "l è",
+ "s ait",
+ "t oi",
+ "ch a",
+ "ai l",
+ "as se",
+ "i mp",
+ "vo y",
+ "con n",
+ "p la",
+ "pet it",
+ "av ant",
+ "n om",
+ "t in",
+ "don t",
+ "d a",
+ "s ous",
+ "e mp",
+ "per son",
+ "el les",
+ "be au",
+ "par ti",
+ "ch o",
+ "pr it",
+ "tou jours",
+ "m en",
+ "r ais",
+ "jam ais",
+ "tr av",
+ "tion s",
+ "tr ès",
+ "v oi",
+ "r en",
+ "y eux",
+ "f er",
+ "v oir",
+ "pre mi",
+ "c a",
+ "g ne",
+ "h eure",
+ "r ou",
+ "e ff",
+ "no tre",
+ "ment s",
+ "t on",
+ "f ais",
+ "ce la",
+ "i er",
+ "rép on",
+ "con s",
+ "ai r",
+ "ô t",
+ "p endant",
+ "i ci",
+ "tou te",
+ "j et",
+ "p ort",
+ "ét aient",
+ "p en",
+ "h é",
+ "au tres",
+ "p ère",
+ "o c",
+ "quel ques",
+ "i que",
+ "l is",
+ "fem me",
+ "j ou",
+ "te ur",
+ "mon de",
+ "u se",
+ "n es",
+ "d re",
+ "a ff",
+ "r ap",
+ "par t",
+ "le ment",
+ "c la",
+ "f ut",
+ "quel que",
+ "pr endre",
+ "r ê",
+ "ai lle",
+ "s ais",
+ "ch es",
+ "le t",
+ "ch ar",
+ "è res",
+ "ent s",
+ "b er",
+ "g er",
+ "mo ins",
+ "e au",
+ "a î",
+ "j eu",
+ "h eur",
+ "é es",
+ "tr i",
+ "po int",
+ "m om",
+ "v ent",
+ "n ouv",
+ "gr an",
+ "tr ois",
+ "s ant",
+ "tout es",
+ "con tre",
+ "è rent",
+ "che z",
+ "ave z",
+ "û t",
+ "a lle",
+ "at t",
+ "p au",
+ "p orte",
+ "ouv er",
+ "b ar",
+ "l it",
+ "f ort",
+ "o t",
+ "as s",
+ "pr és",
+ "cho se",
+ "v it",
+ "mon sieur",
+ "h ab",
+ "t ête",
+ "j u",
+ "te ment",
+ "c tion",
+ "v rai",
+ "la r",
+ "c et",
+ "regar d",
+ "l ant",
+ "de m",
+ "s om",
+ "mom ent",
+ "il les",
+ "p le",
+ "p s",
+ "b es",
+ "m ère",
+ "c l",
+ "s our",
+ "y s",
+ "tr op",
+ "en ne",
+ "jus qu",
+ "av aient",
+ "av ais",
+ "jeu ne",
+ "de puis",
+ "person ne",
+ "f it",
+ "cer t",
+ "j o",
+ "g es",
+ "ou i",
+ "r est",
+ "sem b",
+ "c ap",
+ "m at",
+ "m u",
+ "lon g",
+ "fr an",
+ "f aut",
+ "it i",
+ "b li",
+ "che v",
+ "pr i",
+ "ent e",
+ "ain si",
+ "ch am",
+ "l ors",
+ "c as",
+ "d o",
+ "il i",
+ "b é",
+ "n os",
+ "an ge",
+ "su i",
+ "r it",
+ "cr o",
+ "gu e",
+ "d e",
+ "e n",
+ "e s",
+ "o s",
+ "l a",
+ "e r",
+ "q u",
+ "a r",
+ "a n",
+ "o n",
+ "qu e",
+ "a s",
+ "o r",
+ "e l",
+ "d o",
+ "a l",
+ "c i",
+ "u n",
+ "r e",
+ "a b",
+ "i n",
+ "t e",
+ "t o",
+ "s e",
+ "d i",
+ "t r",
+ "d a",
+ "c on",
+ "t a",
+ "s u",
+ "m i",
+ "c o",
+ "t i",
+ "l e",
+ "l os",
+ "n o",
+ "l o",
+ "í a",
+ "c u",
+ "c a",
+ "s i",
+ "v i",
+ "m e",
+ "p or",
+ "m o",
+ "p ar",
+ "r a",
+ "r i",
+ "la s",
+ "c h",
+ "r o",
+ "m a",
+ "p er",
+ "ó n",
+ "m en",
+ "de s",
+ "un a",
+ "m p",
+ "s o",
+ "ab a",
+ "p u",
+ "d os",
+ "t u",
+ "g u",
+ "er a",
+ "de l",
+ "h a",
+ "m u",
+ "l i",
+ "en t",
+ "m b",
+ "h ab",
+ "es t",
+ "g o",
+ "p a",
+ "r es",
+ "par a",
+ "p o",
+ "á s",
+ "m os",
+ "tr a",
+ "t en",
+ "an do",
+ "p i",
+ "qu i",
+ "b i",
+ "m an",
+ "co mo",
+ "v e",
+ "m ás",
+ "j o",
+ "ci ón",
+ "i s",
+ "t an",
+ "v o",
+ "da d",
+ "c e",
+ "a do",
+ "v er",
+ "f u",
+ "ci a",
+ "c er",
+ "p e",
+ "c as",
+ "c ar",
+ "men te",
+ "n i",
+ "su s",
+ "t ar",
+ "n a",
+ "f i",
+ "t er",
+ "z a",
+ "p ro",
+ "tr o",
+ "s a",
+ "l u",
+ "b a",
+ "per o",
+ "s er",
+ "c es",
+ "d as",
+ "d u",
+ "s in",
+ "e mp",
+ "m ar",
+ "l la",
+ "e x",
+ "á n",
+ "c or",
+ "i a",
+ "v a",
+ "r an",
+ "ch o",
+ "g a",
+ "y o",
+ "t os",
+ "c os",
+ "mi s",
+ "l es",
+ "t es",
+ "v en",
+ "h o",
+ "y a",
+ "en te",
+ "on es",
+ "hab ía",
+ "n u",
+ "u s",
+ "p as",
+ "h i",
+ "n os",
+ "es ta",
+ "la n",
+ "m as",
+ "t or",
+ "l le",
+ "h e",
+ "s on",
+ "b re",
+ "p re",
+ "ab an",
+ "d or",
+ "í an",
+ "i r",
+ "t as",
+ "é n",
+ "r u",
+ "en do",
+ "a que",
+ "er o",
+ "i o",
+ "qu é",
+ "m in",
+ "c ab",
+ "j a",
+ "de r",
+ "t al",
+ "é s",
+ "se ñ",
+ "or a",
+ "to do",
+ "la r",
+ "d on",
+ "g ar",
+ "s al",
+ "p r",
+ "cu ando",
+ "j e",
+ "h u",
+ "g un",
+ "b u",
+ "g i",
+ "d ar",
+ "n e",
+ "r as",
+ "de n",
+ "es to",
+ "par e",
+ "p en",
+ "é l",
+ "tr as",
+ "c an",
+ "b o",
+ "j os",
+ "mi en",
+ "pu e",
+ "c re",
+ "co mp",
+ "p on",
+ "d ía",
+ "tr os",
+ "s ab",
+ "so bre",
+ "es e",
+ "mb re",
+ "er on",
+ "a ñ",
+ "m or",
+ "f or",
+ "i do",
+ "por que",
+ "el la",
+ "p ri",
+ "g ran",
+ "f a",
+ "c en",
+ "di s",
+ "c ri",
+ "mu y",
+ "ch a",
+ "c al",
+ "es te",
+ "h as",
+ "c ó",
+ "g ra",
+ "r os",
+ "p os",
+ "o b",
+ "al l",
+ "aque l",
+ "j u",
+ "p res",
+ "m er",
+ "di jo",
+ "c ía",
+ "ent re",
+ "z o",
+ "ci ones",
+ "bi en",
+ "mb i",
+ "el o",
+ "t ó",
+ "in a",
+ "to dos",
+ "g en",
+ "ti en",
+ "est aba",
+ "de ci",
+ "ci o",
+ "h er",
+ "ñ o",
+ "l or",
+ "nu es",
+ "me di",
+ "l en",
+ "vi da",
+ "f e",
+ "al i",
+ "m on",
+ "c la",
+ "d re",
+ "pu es",
+ "al es",
+ "vo l",
+ "m í",
+ "r ar",
+ "b le",
+ "ci on",
+ "has ta",
+ "señ or",
+ "con o",
+ "a h",
+ "di os",
+ "s en",
+ "es a",
+ "ú n",
+ "v ar",
+ "s an",
+ "gu i",
+ "a c",
+ "o tros",
+ "ta do",
+ "bu en",
+ "ñ a",
+ "ti emp",
+ "ha cer",
+ "j er",
+ "f er",
+ "v u",
+ "f in",
+ "an a",
+ "as í",
+ "an tes",
+ "t in",
+ "ve z",
+ "mien to",
+ "j ar",
+ "la b",
+ "ch e",
+ "cas a",
+ "d r",
+ "es o",
+ "e go",
+ "di ó",
+ "an te",
+ "est á",
+ "m al",
+ "en cia",
+ "el i",
+ "í as",
+ "tiemp o",
+ "z ar",
+ "v an",
+ "m un",
+ "er ta",
+ "ta mbi",
+ "s í",
+ "b ar",
+ "a un",
+ "al e",
+ "mis mo",
+ "ent es",
+ "vi s",
+ "man o",
+ "el e",
+ "na da",
+ "se gu",
+ "me j",
+ "er ra",
+ "ab le",
+ "b e",
+ "ti r",
+ "un o",
+ "don de",
+ "to da",
+ "des de",
+ "r en",
+ "tambi én",
+ "cu er",
+ "per son",
+ "ho mbre",
+ "o tro",
+ "li b",
+ "tr ar",
+ "cu al",
+ "ha y",
+ "a u",
+ "ca da",
+ "t aba",
+ "i mp",
+ "men to",
+ "ten ía",
+ "qu er",
+ "er an",
+ "si emp",
+ "siemp re",
+ "er to",
+ "qu í",
+ "g os",
+ "pu és",
+ "el los",
+ "des pués",
+ "nu e",
+ "g an",
+ "l lo",
+ "in ter",
+ "có mo",
+ "tr i",
+ "ah ora",
+ "us te",
+ "tr aba",
+ "la do",
+ "in o",
+ "po co",
+ "er te",
+ "mu jer",
+ "i m",
+ "qui er",
+ "al gun",
+ "fu e",
+ "o jos",
+ "ent on",
+ "v os",
+ "es per",
+ "mu ch",
+ "o tra",
+ "a z",
+ "a d",
+ "in g",
+ "e za",
+ "a quí",
+ "ci as",
+ "gu a",
+ "mu cho",
+ "deci r",
+ "es ti",
+ "i dad",
+ "al go",
+ "e z",
+ "o cu",
+ "enton ces",
+ "di do",
+ "ent os",
+ "g ri",
+ "da do",
+ "i os",
+ "so l",
+ "dos e",
+ "uste d",
+ "qui en",
+ "a mi",
+ "un to",
+ "f r",
+ "mi r",
+ "mej or",
+ "b as",
+ "so lo",
+ "pre gun",
+ "tu r",
+ "al g",
+ "p la",
+ "to das",
+ "par te",
+ "e mb",
+ "c to",
+ "mun do",
+ "tien e",
+ "tan te",
+ "pa lab",
+ "tr an",
+ "aque lla",
+ "ci os",
+ "aun que",
+ "a y",
+ "cu en",
+ "ten er",
+ "f un",
+ "res pon",
+ "all í",
+ "x i",
+ "h an",
+ "pen s",
+ "con tra",
+ "tu ra",
+ "v al",
+ "di o",
+ "tr es",
+ "t re",
+ "tan to",
+ "ca min",
+ "m ó",
+ "es p",
+ "a da",
+ "í o",
+ "in s",
+ "ha cia",
+ "de j",
+ "est ar",
+ "i ón",
+ "g as",
+ "b er",
+ "v as",
+ "no che",
+ "é r",
+ "añ os",
+ "pa dre",
+ "gu s",
+ "á r",
+ "sin o",
+ "man os",
+ "ci do",
+ "es tu",
+ "a de",
+ "hu bi",
+ "vi r",
+ "b ri",
+ "ra z",
+ "ch i",
+ "pue de",
+ "men os",
+ "hab i",
+ "ho mb",
+ "ne ces",
+ "ma y",
+ "er os",
+ "r ía",
+ "he cho",
+ "es cu",
+ "l ti",
+ "án do",
+ "b us",
+ "cos as",
+ "t ú",
+ "es pa",
+ "re ci",
+ "c tor",
+ "pri m",
+ "di a",
+ "de se",
+ "mien tras",
+ "h or",
+ "fu er",
+ "i da",
+ "pos i",
+ "lan te",
+ "t on",
+ "an o",
+ "est as",
+ "p li",
+ "ch ar",
+ "lu ego",
+ "si ón",
+ "ci n",
+ "ti erra",
+ "m es",
+ "gu ar",
+ "ca do",
+ "en con",
+ "pr en",
+ "may or",
+ "f al",
+ "e r",
+ "o n",
+ "a n",
+ "t o",
+ "d i",
+ "r e",
+ "l a",
+ "i n",
+ "e n",
+ "a l",
+ "t a",
+ "c h",
+ "e l",
+ "r i",
+ "c o",
+ "t i",
+ "t e",
+ "s i",
+ "r a",
+ "u n",
+ "l e",
+ "l i",
+ "ch e",
+ "r o",
+ "c i",
+ "c a",
+ "s e",
+ "q u",
+ "m a",
+ "p o",
+ "s o",
+ "i l",
+ "d o",
+ "e s",
+ "v a",
+ "p er",
+ "l o",
+ "c on",
+ "d el",
+ "p a",
+ "m o",
+ "s a",
+ "p i",
+ "d a",
+ "m i",
+ "g i",
+ "s u",
+ "d e",
+ "v i",
+ "z i",
+ "m e",
+ "g li",
+ "n o",
+ "m en",
+ "v o",
+ "t u",
+ "n on",
+ "v e",
+ "t to",
+ "s t",
+ "on e",
+ "an o",
+ "ch i",
+ "er a",
+ "er e",
+ "f a",
+ "c e",
+ "z a",
+ "un a",
+ "b i",
+ "p re",
+ "s ta",
+ "o r",
+ "a r",
+ "f i",
+ "on o",
+ "t ra",
+ "n a",
+ "n el",
+ "n e",
+ "p ro",
+ "t ro",
+ "al e",
+ "v er",
+ "n i",
+ "c u",
+ "t ti",
+ "men te",
+ "del la",
+ "t er",
+ "zi one",
+ "g u",
+ "p e",
+ "t ta",
+ "an do",
+ "t à",
+ "al i",
+ "u o",
+ "qu el",
+ "co m",
+ "s en",
+ "co me",
+ "b a",
+ "al la",
+ "p ri",
+ "d u",
+ "qu es",
+ "l u",
+ "on i",
+ "g gi",
+ "pa r",
+ "s si",
+ "v en",
+ "in a",
+ "g a",
+ "pi ù",
+ "ci a",
+ "i m",
+ "co r",
+ "m an",
+ "in o",
+ "in i",
+ "t en",
+ "r an",
+ "b b",
+ "g o",
+ "s to",
+ "t re",
+ "a ve",
+ "a v",
+ "s ono",
+ "er i",
+ "a c",
+ "s se",
+ "er o",
+ "h a",
+ "s c",
+ "su l",
+ "f or",
+ "v ano",
+ "po r",
+ "s ti",
+ "su o",
+ "c chi",
+ "t an",
+ "z za",
+ "an che",
+ "p u",
+ "i o",
+ "t te",
+ "vo l",
+ "es s",
+ "s ci",
+ "co l",
+ "r u",
+ "p en",
+ "f u",
+ "al l",
+ "s so",
+ "s te",
+ "se m",
+ "s sa",
+ "d en",
+ "a d",
+ "t ri",
+ "de i",
+ "in e",
+ "ave va",
+ "men to",
+ "z z",
+ "a mo",
+ "g no",
+ "f o",
+ "un o",
+ "su a",
+ "g en",
+ "ri a",
+ "g e",
+ "st ra",
+ "s ì",
+ "c er",
+ "ch é",
+ "b u",
+ "a p",
+ "c en",
+ "d al",
+ "on a",
+ "s pe",
+ "g ni",
+ "b o",
+ "t t",
+ "del le",
+ "ques to",
+ "nel la",
+ "f f",
+ "d ere",
+ "an no",
+ "del l",
+ "un i",
+ "bb e",
+ "an ti",
+ "g ra",
+ "s p",
+ "en e",
+ "gi o",
+ "u to",
+ "qu al",
+ "gli a",
+ "qu ando",
+ "tu tto",
+ "c an",
+ "gli o",
+ "zi oni",
+ "ca m",
+ "h o",
+ "es so",
+ "s s",
+ "mo l",
+ "a t",
+ "lo ro",
+ "per ché",
+ "co sa",
+ "du e",
+ "po i",
+ "ca r",
+ "s co",
+ "ci o",
+ "to r",
+ "c co",
+ "c re",
+ "a m",
+ "g na",
+ "te m",
+ "pri ma",
+ "lu i",
+ "co sì",
+ "qu e",
+ "gu ar",
+ "ess ere",
+ "an i",
+ "con o",
+ "b ra",
+ "al le",
+ "m on",
+ "ri o",
+ "an co",
+ "cu i",
+ "s pi",
+ "vi a",
+ "g ran",
+ "gi or",
+ "a i",
+ "bi le",
+ "u l",
+ "ggi o",
+ "f e",
+ "an te",
+ "ma i",
+ "ta re",
+ "in ter",
+ "in di",
+ "re bbe",
+ "sen za",
+ "so lo",
+ "zi o",
+ "e d",
+ "en te",
+ "tu tti",
+ "sta to",
+ "zi a",
+ "d alla",
+ "tu ra",
+ "mi a",
+ "vi ta",
+ "quel la",
+ "qu a",
+ "ma r",
+ "do ve",
+ "g h",
+ "al lo",
+ "sem pre",
+ "zz o",
+ "si a",
+ "mo r",
+ "do po",
+ "por ta",
+ "d re",
+ "c cia",
+ "er ano",
+ "an ni",
+ "di o",
+ "chi a",
+ "en za",
+ "pro pri",
+ "qu i",
+ "m u",
+ "m b",
+ "an da",
+ "c ca",
+ "o cchi",
+ "ques ta",
+ "f fi",
+ "le i",
+ "par te",
+ "d on",
+ "r on",
+ "mi o",
+ "tan to",
+ "ri s",
+ "o gni",
+ "di s",
+ "r in",
+ "fa r",
+ "men ti",
+ "t el",
+ "anco ra",
+ "f ra",
+ "fa tto",
+ "man i",
+ "sen ti",
+ "p ra",
+ "tem po",
+ "es si",
+ "b bi",
+ "f in",
+ "a re",
+ "la re",
+ "per s",
+ "f on",
+ "b el",
+ "so r",
+ "d er",
+ "pre n",
+ "an za",
+ "di re",
+ "pi e",
+ "o ra",
+ "ver so",
+ "se gu",
+ "al tro",
+ "ta to",
+ "ca to",
+ "a to",
+ "vol ta",
+ "c c",
+ "fa re",
+ "pa re",
+ "ci ò",
+ "li b",
+ "bi li",
+ "n uo",
+ "s er",
+ "quel lo",
+ "co lo",
+ "p po",
+ "ca sa",
+ "tro va",
+ "o re",
+ "f er",
+ "r ono",
+ "d es",
+ "mol to",
+ "al mente",
+ "s ca",
+ "vo le",
+ "t ali",
+ "sul la",
+ "s ce",
+ "men o",
+ "an to",
+ "p un",
+ "s tu",
+ "ca pi",
+ "so l",
+ "gi u",
+ "m ini",
+ "m ano",
+ "z e",
+ "pi a",
+ "par ti",
+ "s al",
+ "la vo",
+ "ver o",
+ "r si",
+ "al tri",
+ "es ti",
+ "s cia",
+ "suo i",
+ "gli e",
+ "so tto",
+ "b ene",
+ "sc ri",
+ "t ale",
+ "de gli",
+ "n u",
+ "al c",
+ "uo mo",
+ "p el",
+ "f re",
+ "po te",
+ "es sa",
+ "s cu",
+ "si gno",
+ "el e",
+ "st ro",
+ "u ti",
+ "di a",
+ "si one",
+ "g re",
+ "f ini",
+ "ar ri",
+ "l un",
+ "c ri",
+ "e si",
+ "pa ssa",
+ "r à",
+ "men tre",
+ "an d",
+ "h anno",
+ "el o",
+ "u sci",
+ "gi a",
+ "gi à",
+ "di e",
+ "m ina",
+ "b e",
+ "ti ca",
+ "gior no",
+ "t in",
+ "es se",
+ "mo do",
+ "c al",
+ "s pa",
+ "propri o",
+ "l en",
+ "o ri",
+ "con tro",
+ "st ru",
+ "di ven",
+ "di sse",
+ "ra to",
+ "no i",
+ "v ere",
+ "pu ò",
+ "di ce",
+ "s an",
+ "es a",
+ "c ci",
+ "se con",
+ "re n",
+ "c cio",
+ "qual che",
+ "tu tta",
+ "g g",
+ "mon do",
+ "for ma",
+ "p li",
+ "m ma",
+ "pen sa",
+ "de va",
+ "tu r",
+ "fo sse",
+ "so pra",
+ "ta mente",
+ "n ess",
+ "qu anto",
+ "ra ga",
+ "un que",
+ "ca re",
+ "st re",
+ "gran de",
+ "pi cco",
+ "guar da",
+ "b en",
+ "nel l",
+ "a ff",
+ "po ssi",
+ "pre sen",
+ "r ò",
+ "pa ro",
+ "tu a",
+ "v in",
+ "an e",
+ "a s",
+ "ste sso",
+ "da v",
+ "ne i",
+ "nel le",
+ "gh i",
+ "pi o",
+ "ta r",
+ "an a",
+ "la to",
+ "si d",
+ "f ine",
+ "f uo",
+ "m er",
+ "z o",
+ "qua si",
+ "ul ti",
+ "i to",
+ "su e",
+ "si e",
+ "f il",
+ "allo ra",
+ "m in",
+ "ven i",
+ "t ano",
+ "el lo",
+ "d e",
+ "r a",
+ "e s",
+ "d o",
+ "e n",
+ "q u",
+ "c o",
+ "a s",
+ "o s",
+ "e r",
+ "a r",
+ "s e",
+ "qu e",
+ "a n",
+ "i n",
+ "i s",
+ "t o",
+ "ã o",
+ "t e",
+ "d a",
+ "m a",
+ "e l",
+ "t a",
+ "o r",
+ "i a",
+ "r e",
+ "e m",
+ "a l",
+ "co m",
+ "p a",
+ "o u",
+ "c a",
+ "u m",
+ "r o",
+ "v a",
+ "t i",
+ "s o",
+ "m en",
+ "n ão",
+ "h a",
+ "co n",
+ "m e",
+ "r i",
+ "pa ra",
+ "p o",
+ "d i",
+ "s a",
+ "v o",
+ "u ma",
+ "c i",
+ "n a",
+ "p or",
+ "n o",
+ "g u",
+ "s u",
+ "h o",
+ "an do",
+ "t ra",
+ "e i",
+ "v i",
+ "e u",
+ "i m",
+ "do s",
+ "el e",
+ "r es",
+ "m o",
+ "en t",
+ "f i",
+ "l a",
+ "e ra",
+ "l e",
+ "de s",
+ "el a",
+ "men te",
+ "l h",
+ "p er",
+ "l i",
+ "ç ão",
+ "m as",
+ "t er",
+ "m u",
+ "es t",
+ "v e",
+ "g o",
+ "l o",
+ "u s",
+ "ma is",
+ "v er",
+ "c ê",
+ "in ha",
+ "vo cê",
+ "f a",
+ "t u",
+ "c u",
+ "p ar",
+ "com o",
+ "p ro",
+ "s i",
+ "m os",
+ "e c",
+ "p re",
+ "d as",
+ "ç a",
+ "es ta",
+ "s er",
+ "u n",
+ "da de",
+ "d is",
+ "f o",
+ "e x",
+ "c h",
+ "i r",
+ "ra n",
+ "t ar",
+ "en te",
+ "g a",
+ "t r",
+ "p e",
+ "t os",
+ "b o",
+ "c ia",
+ "p en",
+ "c ar",
+ "s en",
+ "su a",
+ "se m",
+ "c as",
+ "f or",
+ "to u",
+ "n os",
+ "te m",
+ "r ia",
+ "m es",
+ "se u",
+ "co r",
+ "o n",
+ "a o",
+ "p os",
+ "ra m",
+ "v el",
+ "é m",
+ "t en",
+ "po de",
+ "t es",
+ "esta va",
+ "c e",
+ "b a",
+ "qu ando",
+ "m i",
+ "qu er",
+ "men to",
+ "se gu",
+ "t as",
+ "is so",
+ "mu i",
+ "g ar",
+ "t ro",
+ "d u",
+ "fa z",
+ "õ es",
+ "p es",
+ "an to",
+ "l u",
+ "p i",
+ "i x",
+ "ve z",
+ "s im",
+ "j a",
+ "p r",
+ "m in",
+ "b e",
+ "ra s",
+ "m an",
+ "p res",
+ "est á",
+ "c er",
+ "b re",
+ "p as",
+ "d ia",
+ "m b",
+ "dis se",
+ "n i",
+ "r os",
+ "es se",
+ "v ia",
+ "o lh",
+ "is a",
+ "an te",
+ "ê n",
+ "z a",
+ "qu i",
+ "b i",
+ "t inha",
+ "me u",
+ "s ão",
+ "m inha",
+ "a c",
+ "ri o",
+ "m ar",
+ "a t",
+ "p el",
+ "mui to",
+ "ta l",
+ "to r",
+ "fo i",
+ "h or",
+ "j o",
+ "b em",
+ "g i",
+ "f al",
+ "vo l",
+ "po n",
+ "di z",
+ "l ar",
+ "gu n",
+ "m or",
+ "r u",
+ "par ec",
+ "ç o",
+ "do r",
+ "pes so",
+ "n e",
+ "f er",
+ "b er",
+ "p u",
+ "po is",
+ "in a",
+ "es p",
+ "d ar",
+ "en do",
+ "de n",
+ "so bre",
+ "co s",
+ "p ri",
+ "al i",
+ "mes mo",
+ "ç ões",
+ "g ra",
+ "se us",
+ "me i",
+ "b ra",
+ "vi da",
+ "an tes",
+ "b ri",
+ "at é",
+ "ên cia",
+ "lh e",
+ "ti v",
+ "m ã",
+ "al g",
+ "qu anto",
+ "s ó",
+ "g os",
+ "de r",
+ "t ão",
+ "tu do",
+ "ent ão",
+ "r ou",
+ "es s",
+ "in da",
+ "b al",
+ "in do",
+ "ci o",
+ "n do",
+ "j á",
+ "va m",
+ "re i",
+ "l es",
+ "ei to",
+ "v is",
+ "tem po",
+ "de pois",
+ "c ha",
+ "m el",
+ "ch e",
+ "l ha",
+ "a inda",
+ "faz er",
+ "con tra",
+ "p ou",
+ "per gun",
+ "de ix",
+ "ta mb",
+ "ra r",
+ "al a",
+ "v en",
+ "t in",
+ "pel o",
+ "tamb ém",
+ "fi ca",
+ "pre c",
+ "el es",
+ "tra n",
+ "ha via",
+ "l á",
+ "to dos",
+ "j u",
+ "qu al",
+ "c an",
+ "ta do",
+ "cas a",
+ "es sa",
+ "n as",
+ "g em",
+ "m em",
+ "se i",
+ "na da",
+ "sen ti",
+ "c ri",
+ "ó s",
+ "de u",
+ "ei ro",
+ ". .",
+ "f un",
+ "as sim",
+ "s ou",
+ "ent re",
+ "com e",
+ "i or",
+ "h ar",
+ "f e",
+ "por que",
+ "s or",
+ "f in",
+ "ta mente",
+ "a qui",
+ "cu l",
+ "t ó",
+ "for ma",
+ "s ar",
+ "ou tra",
+ "olh os",
+ "i ma",
+ "m im",
+ "a go",
+ "in s",
+ "co u",
+ "g ran",
+ "v al",
+ "pesso as",
+ "era m",
+ "ei ra",
+ "a que",
+ "com p",
+ "de i",
+ "p ela",
+ "co isa",
+ "m ão",
+ "con h",
+ "ca da",
+ "ago ra",
+ "ia m",
+ "h á",
+ "con s",
+ "su as",
+ "gu ém",
+ "o b",
+ "l an",
+ "es ti",
+ "á s",
+ "la do",
+ "in ter",
+ "ca be",
+ "por ta",
+ "n em",
+ "í vel",
+ "r is",
+ "j e",
+ "n un",
+ "sem pre",
+ "con segu",
+ "h as",
+ "tra bal",
+ "f u",
+ "le v",
+ "l em",
+ "l as",
+ "va i",
+ "tr os",
+ "t ante",
+ "te i",
+ "pr ó",
+ "que m",
+ "tu ra",
+ "on de",
+ "cabe ça",
+ "nun ca",
+ "men tos",
+ "h um",
+ "de le",
+ "ver dade",
+ "t á",
+ "h os",
+ "el i",
+ "ent es",
+ "m er",
+ "alg um",
+ "diz er",
+ "s in",
+ "pen as",
+ "n ós",
+ "en quanto",
+ "ou tro",
+ "l ho",
+ "es te",
+ "mel hor",
+ "est ar",
+ "g an",
+ "b ar",
+ "pri mei",
+ "a u",
+ "i u",
+ "pen sa",
+ "a penas",
+ "p ra",
+ "es tou",
+ "con te",
+ "res pon",
+ "ho mem",
+ "do is",
+ "a do",
+ "c al",
+ "a b",
+ "l os",
+ "ç as",
+ "pou co",
+ "sen hor",
+ "t ando",
+ "esp era",
+ "pa i",
+ "ri os",
+ "no i",
+ "i da",
+ "ba ix",
+ "as e",
+ "is as",
+ "f r",
+ "ho ra",
+ "mu ndo",
+ "pas sa",
+ "fi car",
+ "to do",
+ "se ja",
+ "al mente",
+ "â n",
+ "c lar",
+ "a d",
+ "in c",
+ "f os",
+ "lo n",
+ "g ri",
+ "ou vi",
+ "v em",
+ "g e",
+ "ta va",
+ "á rio",
+ "mo n",
+ "s os",
+ "in ho",
+ "ma l",
+ "t an",
+ "t re",
+ "gran de",
+ "ran do",
+ "b u",
+ "v ou",
+ "ê s",
+ "co isas",
+ "a conte",
+ "lh er",
+ "g en",
+ "ci on",
+ "an os",
+ "i do",
+ "tal vez",
+ "est ão",
+ "li v",
+ "sa b",
+ "su r",
+ "ou tros",
+ "c re",
+ "qual quer",
+ "g ou",
+ "t ri",
+ "l í",
+ "tiv esse",
+ "ra do",
+ "prec isa",
+ "mã e",
+ "su s",
+ "t anto",
+ "de la",
+ "men os",
+ "s al",
+ "en tra",
+ "p é",
+ "ma ior",
+ "noi te",
+ "ti va",
+ "p ala",
+ "so n",
+ "ra ção",
+ "de us",
+ "s as",
+ "un i",
+ "l or",
+ "u l",
+ "in te",
+ "f ei",
+ "an o",
+ "par ti",
+ "pala v",
+ "tr ás",
+ "par te",
+ "b el",
+ "ci dade",
+ "lu gar",
+ "v os",
+ "vez es",
+ "do u",
+ "en contra",
+ "tr u",
+ "e ci",
+ "a r",
+ "e r",
+ "a n",
+ "e n",
+ "i n",
+ "i r",
+ "o r",
+ "d e",
+ "a k",
+ "ı n",
+ "a l",
+ "d i",
+ "d a",
+ "b u",
+ "b ir",
+ "y or",
+ "i l",
+ "e k",
+ "y a",
+ "m a",
+ "l a",
+ "e l",
+ "u n",
+ "k a",
+ "l ar",
+ "i m",
+ "d ı",
+ "e t",
+ "o n",
+ "d u",
+ "o l",
+ "e y",
+ "t ı",
+ "m i",
+ "h a",
+ "b a",
+ "l er",
+ "ü n",
+ "m ı",
+ "i z",
+ "l e",
+ "ı r",
+ "m e",
+ "i s",
+ "n e",
+ "o k",
+ "t a",
+ "s a",
+ "u m",
+ "r a",
+ "g ö",
+ "i k",
+ "s ı",
+ "d en",
+ "e s",
+ "b il",
+ "t i",
+ "l ı",
+ "ü z",
+ "i ç",
+ "ü r",
+ "g i",
+ "u r",
+ "t e",
+ "b en",
+ "d an",
+ "i y",
+ "ı m",
+ "u z",
+ "v e",
+ "c ak",
+ "a y",
+ "c e",
+ "i ş",
+ "ın ı",
+ "i yor",
+ "ba ş",
+ "d ü",
+ "a t",
+ "a m",
+ "g el",
+ "de ğ",
+ "k ar",
+ "i ̇",
+ "m u",
+ "e v",
+ "ö y",
+ "bu n",
+ "v ar",
+ "ya p",
+ "s en",
+ "an a",
+ "s un",
+ "in i",
+ "gö r",
+ "y ı",
+ "k i",
+ "l i",
+ "ar a",
+ "al ı",
+ "on u",
+ "ç ı",
+ "ş ey",
+ "s ın",
+ "k ı",
+ "ka d",
+ "s e",
+ "t an",
+ "a ğ",
+ "değ il",
+ "s in",
+ "ü k",
+ "a z",
+ "ç ok",
+ "s on",
+ "ş ı",
+ "b i",
+ "ü l",
+ "t u",
+ "v er",
+ "iç in",
+ "g e",
+ "k en",
+ "ey e",
+ "ol du",
+ "mı ş",
+ "y e",
+ "k al",
+ "m ek",
+ "l an",
+ "öy le",
+ "yor du",
+ "er i",
+ "y üz",
+ "mi ş",
+ "b e",
+ "m ak",
+ "o la",
+ "in e",
+ "y an",
+ "h er",
+ "c ek",
+ "yor um",
+ "b ak",
+ "ü m",
+ "ö n",
+ "lar ı",
+ "o ğ",
+ "d er",
+ "kad ar",
+ "h al",
+ "ar ı",
+ "s t",
+ "s an",
+ "ın da",
+ "du r",
+ "g ün",
+ "v a",
+ "y ok",
+ "y er",
+ "dı m",
+ "k o",
+ "da ha",
+ "l u",
+ "ın a",
+ "di m",
+ "e m",
+ "bil ir",
+ "ik i",
+ "s iz",
+ "s i",
+ "n a",
+ "di ğ",
+ "s u",
+ "b ü",
+ "ha y",
+ "s or",
+ "dü ş",
+ "ü ç",
+ "un u",
+ "ö r",
+ "d ir",
+ "m ü",
+ "c a",
+ "am an",
+ "f ak",
+ "a da",
+ "e de",
+ "son ra",
+ "h iç",
+ "ak i",
+ "ğ ı",
+ "bu l",
+ "r u",
+ "ma z",
+ "an la",
+ "bu ra",
+ "ge ç",
+ "ma ya",
+ "l en",
+ "k onu",
+ "c i",
+ "c u",
+ "d in",
+ "t ek",
+ "z aman",
+ "el er",
+ "ö z",
+ "dı r",
+ "gi bi",
+ "o t",
+ "ş a",
+ "g er",
+ "ler i",
+ "k im",
+ "k u",
+ "fak at",
+ "y ar",
+ "gö z",
+ "c ı",
+ "yor sun",
+ "b ek",
+ "in de",
+ "r o",
+ "p ek",
+ "bun u",
+ "l ik",
+ "m an",
+ "il er",
+ "e di",
+ "ö l",
+ "s ür",
+ "b in",
+ "s ır",
+ "çı k",
+ "sı l",
+ "al ar",
+ "k es",
+ "y ak",
+ "ç ek",
+ "yı l",
+ "e cek",
+ "ı z",
+ "gi t",
+ "ka p",
+ "a ma",
+ "ı l",
+ "lar ın",
+ "b iz",
+ "tı r",
+ "o y",
+ "an cak",
+ "d oğ",
+ "ç a",
+ "b ana",
+ "ş im",
+ "baş la",
+ "l ü",
+ "ma dı",
+ "ben i",
+ "t ir",
+ "y ük",
+ "lı k",
+ "be ş",
+ "b el",
+ "b er",
+ "m er",
+ "na sıl",
+ "tı k",
+ "k e",
+ "t ür",
+ "a v",
+ ". .",
+ "d aki",
+ "p ar",
+ "t er",
+ "ce ğ",
+ "t en",
+ "z ı",
+ "iy i",
+ "d ok",
+ "ben im",
+ "c ağ",
+ "n er",
+ "y en",
+ "ş u",
+ "me z",
+ "düş ün",
+ "ken di",
+ "şim di",
+ "y ol",
+ "y u",
+ "de v",
+ "is te",
+ "s ek",
+ "ma m",
+ "s öyle",
+ "di k",
+ "t o",
+ "k ur",
+ "oldu ğ",
+ "s ını",
+ "t ar",
+ "bil iyor",
+ "k an",
+ "y al",
+ "m eye",
+ "mu ş",
+ "f a",
+ "ka ç",
+ "bil e",
+ "iy e",
+ "t ü",
+ "e f",
+ "tı m",
+ "ev et",
+ "ç o",
+ "y et",
+ "g en",
+ "bura da",
+ "t im",
+ "bir az",
+ "es i",
+ "k or",
+ "doğ ru",
+ "in in",
+ "kı z",
+ "di ye",
+ "d ör",
+ "et ti",
+ "on un",
+ "is ti",
+ "ğ i",
+ "h e",
+ "s ana",
+ "ü ş",
+ "ar ka",
+ "hay ır",
+ "kar şı",
+ "h ar",
+ "il e",
+ "h ak",
+ "ı yor",
+ "ne den",
+ "s ev",
+ "sı z",
+ "ço cu",
+ "me m",
+ "ç alı",
+ "ol ur",
+ "b ır",
+ "g ir",
+ "is e",
+ "i h",
+ "c an",
+ "k ır",
+ "d ön",
+ "b öyle",
+ "sen i",
+ "! \"",
+ "al t",
+ "dör t",
+ "s öy",
+ "o ş",
+ "mu sun",
+ "la ş",
+ "h an",
+ "i p",
+ "ka y",
+ "h em",
+ "bü yük",
+ "a ç",
+ "bır ak",
+ "mi sin",
+ "s öz",
+ "u l",
+ "değ iş",
+ "ün ü",
+ "g ül",
+ "k ö",
+ "kar ı",
+ "ta mam",
+ "ol u",
+ "r ar",
+ "yen i",
+ "la m",
+ "mış tı",
+ "ya ş",
+ "al a",
+ "in iz",
+ "kad ın",
+ "bun un",
+ "m ey",
+ "al tı",
+ "y i",
+ "s o",
+ "in den",
+ "sen in",
+ "ya t",
+ "to p",
+ "s er",
+ "is i",
+ "d ün",
+ "s es",
+ "hiç bir",
+ "y on",
+ "d ın",
+ "t ün",
+ "baş ka",
+ "a s",
+ "he p",
+ "i t",
+ "ir mi",
+ "dev am",
+ "ola cak",
+ "ar tık",
+ "r e",
+ "dur um",
+ "im iz",
+ "üz el",
+ "ler ini",
+ "sa ğ",
+ "p ro",
+ "ger ek",
+ "y irmi",
+ "ş ek",
+ "ba ğ",
+ "me di",
+ "lar a",
+ "a h",
+ "t ur",
+ "y ür",
+ "ma sı",
+ "ka tı",
+ "de di",
+ "g ü",
+ "sor un",
+ "el i",
+ "ün e",
+ "mı z",
+ "yap ı",
+ "m il",
+ "ğ ını",
+ "t ara",
+ "m en",
+ "ha t",
+ "var dı",
+ "m et",
+ "konu ş",
+ "ar ak",
+ "lar ak",
+ "çocu k",
+ "bü tün",
+ "l ey",
+ "d ür",
+ "g üzel",
+ "ay ı",
+ "yap a",
+ "n ı",
+ "ay r",
+ "ö ne",
+ "yordu m",
+ "b an",
+ "i̇ ş",
+ "du m",
+ "un a",
+ "on a",
+ "yor lar",
+ "lar ını",
+ "çı kar",
+ "z an",
+ "se ç",
+ "l iyor",
+ "t ak",
+ "şı k",
+ "tek rar",
+ "a ş",
+ "e ş",
+ "miş ti",
+ "f ar",
+ "k in",
+ "im i",
+ "i f",
+ "e ğ",
+ "gi di",
+ "le ş",
+ "başla dı",
+ "gi de",
+ "ot ur",
+ "d de",
+ "ın dan",
+ "üz er",
+ "ın ın",
+ "n ız",
+ "u y",
+ "ye di",
+ "ka t",
+ "o larak",
+ "la dı",
+ "yal nız",
+ "ba h",
+ "iy et",
+ "m al",
+ "s ak",
+ "a çık",
+ "sın da",
+ ".. .",
+ "in san",
+ "ay nı",
+ "e der",
+ "is tan",
+ "uz un",
+ "sa h",
+ "d o",
+ "g eri",
+ "er ek",
+ "ol an",
+ "ger çek",
+ "f en",
+ "al an",
+ "dı ş",
+ "alı k",
+ "far k",
+ "ü st",
+ "sa de",
+ "r i",
+ "k iş",
+ "l dı",
+ "z or",
+ "et ir",
+ "her kes",
+ "s al",
+ "ö mer",
+ "s el",
+ "un da",
+ "ha f",
+ "bun a",
+ "y dı",
+ "pek i",
+ "ada m",
+ "ha z",
+ "sın a",
+ "kap ı",
+ "gör üş",
+ "sade ce",
+ "al dı",
+ "gel di",
+ "i e",
+ "n ie",
+ "n a",
+ "r z",
+ "s z",
+ "c z",
+ "p o",
+ "s t",
+ "c h",
+ "i ę",
+ "d z",
+ "n i",
+ "a ł",
+ "r a",
+ "j e",
+ "r o",
+ "d o",
+ "s ię",
+ "z a",
+ "g o",
+ "e m",
+ "w i",
+ "c i",
+ "rz e",
+ "k o",
+ "l e",
+ "l i",
+ "w a",
+ "t o",
+ "k a",
+ "m i",
+ "ż e",
+ "t a",
+ "w ie",
+ "b y",
+ "m o",
+ "w y",
+ "rz y",
+ "ł a",
+ "j a",
+ "n o",
+ "ł o",
+ "w o",
+ "p a",
+ "m a",
+ "t e",
+ "t y",
+ "n y",
+ "k i",
+ "d a",
+ "n e",
+ "dz ie",
+ "dz i",
+ "cz y",
+ "c ie",
+ "m y",
+ "p rze",
+ "d y",
+ "o d",
+ "l a",
+ "k ie",
+ "r y",
+ "st a",
+ "j ą",
+ "ó w",
+ "c e",
+ "p rzy",
+ "c o",
+ "k u",
+ "m ie",
+ "sz y",
+ "cz e",
+ "r e",
+ "b a",
+ "s i",
+ "b ie",
+ "m u",
+ "w e",
+ "c y",
+ "ni a",
+ "ś ci",
+ "sz e",
+ "je st",
+ "k t",
+ "s a",
+ "b o",
+ "t u",
+ "ż y",
+ "n ą",
+ "b i",
+ "r u",
+ "a le",
+ "kt ó",
+ "p ra",
+ "ał a",
+ "m nie",
+ "p ie",
+ "ł y",
+ "cz a",
+ "ja k",
+ "ro z",
+ "r ó",
+ "l u",
+ "z na",
+ "g a",
+ "ra z",
+ "ł u",
+ "ta k",
+ "j u",
+ "p i",
+ "ś ć",
+ "s o",
+ "wi a",
+ "m ó",
+ "ch o",
+ "w szy",
+ "p e",
+ "s po",
+ "c a",
+ "g dy",
+ "w ał",
+ "w ię",
+ "d e",
+ "b e",
+ "p ro",
+ "ł em",
+ "j ę",
+ "s k",
+ "z e",
+ "l o",
+ "g i",
+ "r ę",
+ "do b",
+ "d u",
+ "ju ż",
+ "st o",
+ "b ę",
+ "ał em",
+ "sz a",
+ "m e",
+ "po d",
+ "d la",
+ "pa n",
+ "n ę",
+ "z o",
+ "mo że",
+ "ś li",
+ "s ie",
+ "ał o",
+ "t em",
+ "l ko",
+ "ny ch",
+ "po wie",
+ "c ię",
+ "s u",
+ "ty lko",
+ "i n",
+ "b u",
+ "na j",
+ "ch a",
+ "te go",
+ "p u",
+ "s ki",
+ "ne go",
+ "wszy st",
+ "sz cze",
+ "je d",
+ "je j",
+ "t wo",
+ "ą d",
+ "ś my",
+ "cz ę",
+ "wa ć",
+ "je go",
+ "ż a",
+ "i m",
+ "s y",
+ "pra w",
+ "ty m",
+ "któ ry",
+ "ał y",
+ "t rze",
+ "nie j",
+ "s e",
+ "ny m",
+ "i ch",
+ "o b",
+ ". .",
+ "g ło",
+ "ją c",
+ "mó wi",
+ "s ka",
+ "o n",
+ "ne j",
+ "s łu",
+ "w ła",
+ "bę dzie",
+ "d ę",
+ "p ó",
+ "be z",
+ "ni c",
+ "p ła",
+ "ś cie",
+ "mi a",
+ "s ą",
+ "t rzy",
+ "kie m",
+ "by ł",
+ "mo g",
+ "ro bi",
+ "ta m",
+ "c u",
+ "te n",
+ "m ię",
+ "z y",
+ "pe w",
+ "ci a",
+ "my ś",
+ "prze d",
+ "s ko",
+ "n u",
+ "któ re",
+ "a l",
+ "l ę",
+ "w sze",
+ "ą c",
+ "by ło",
+ "so bie",
+ "p y",
+ "ci ą",
+ "ba r",
+ "je szcze",
+ "h a",
+ "t ę",
+ "b ra",
+ "cza s",
+ "sz ę",
+ "g ł",
+ "k ę",
+ "ma r",
+ "cz u",
+ "prze z",
+ "f i",
+ "s ło",
+ "w z",
+ "k to",
+ "k ów",
+ "cz o",
+ "li śmy",
+ "st ra",
+ "wię c",
+ "r ą",
+ "ma m",
+ "w ó",
+ "rz a",
+ "g ro",
+ "no ści",
+ "f a",
+ "we t",
+ "ną ł",
+ "ś mie",
+ "na wet",
+ "mu si",
+ "s wo",
+ "te j",
+ "w ą",
+ "w u",
+ "wi ą",
+ "ni u",
+ "cz ą",
+ "b li",
+ "dz o",
+ "s kie",
+ "n em",
+ "je śli",
+ "cze go",
+ "ch y",
+ "d ł",
+ "ty ch",
+ "by m",
+ "ż o",
+ "e ś",
+ "si ą",
+ "kie dy",
+ "na s",
+ "w ró",
+ "dz e",
+ "d ro",
+ "t ra",
+ "r ów",
+ "pa ni",
+ "z ie",
+ "ku l",
+ "na d",
+ "ch wi",
+ "ni m",
+ "t ro",
+ "by ć",
+ "cho dzi",
+ "ni o",
+ "dob rze",
+ "te raz",
+ "wo kul",
+ "co ś",
+ "k ł",
+ "pie r",
+ "h e",
+ "g dzie",
+ "dz y",
+ "p ię",
+ "d ź",
+ "k ą",
+ "g ó",
+ "z da",
+ "ch ce",
+ "st ę",
+ "o r",
+ "ś wia",
+ "wszyst ko",
+ "st ro",
+ "pe ł",
+ "wie m",
+ "wie l",
+ "ka ż",
+ "ki m",
+ "rz u",
+ "s ły",
+ "jed na",
+ "z u",
+ "myś l",
+ "mó j",
+ "g u",
+ "wa r",
+ "jest em",
+ "ó ż",
+ "mie j",
+ "mo ż",
+ "k ła",
+ "re sz",
+ "d łu",
+ "st wo",
+ "n ię",
+ "ma sz",
+ "że by",
+ "nie m",
+ "ja kie",
+ "st y",
+ "ni ą",
+ "we j",
+ "o j",
+ "g ra",
+ "s ła",
+ "no ść",
+ "z ło",
+ "sz czę",
+ ".. .",
+ "r i",
+ "le j",
+ "we go",
+ "c ał",
+ "dzi ał",
+ "ki ch",
+ "dz a",
+ "dz ię",
+ "o czy",
+ "zo sta",
+ "cz ło",
+ "na m",
+ "ki l",
+ "o na",
+ "sz u",
+ "w ę",
+ "pa r",
+ "mi ał",
+ "st rze",
+ "ce j",
+ "e j",
+ "zna j",
+ "da ć",
+ "miej s",
+ "k ró",
+ "k ry",
+ "bar dzo",
+ "si a",
+ "z i",
+ "ś nie",
+ "l ą",
+ "g ie",
+ "cie bie",
+ "d ni",
+ "st u",
+ "po trze",
+ "wokul ski",
+ "u wa",
+ "u mie",
+ "jedna k",
+ "k ra",
+ "wró ci",
+ "czło wie",
+ "czy ć",
+ "by ła",
+ "że li",
+ "m ę",
+ "c ę",
+ "z robi",
+ "mog ę",
+ "pro wa",
+ "r em",
+ "nie ch",
+ "cz nie",
+ "k ro",
+ "t ą",
+ "ch ci",
+ "b ro",
+ "dzie ć",
+ "sz ą",
+ "pa d",
+ "t rz",
+ "t ru",
+ "je m",
+ "a ni",
+ "t ów",
+ "a r",
+ "d ru",
+ "ta j",
+ "rze kł",
+ "sa m",
+ "st e",
+ "nie go",
+ "ta kie",
+ "w ała",
+ "to wa",
+ "ka pła",
+ "wi dzi",
+ "po dob",
+ "dz ę",
+ "t ał",
+ "stę p",
+ "b ą",
+ "po ko",
+ "w em",
+ "g ę",
+ "a by",
+ "g e",
+ "al bo",
+ "s pra",
+ "z no",
+ "de n",
+ "s mo",
+ "je sz",
+ "k się",
+ "jest eś",
+ "po z",
+ "ni gdy",
+ "k sią",
+ "c óż",
+ "w s",
+ "po w",
+ "t ka",
+ "ś wie",
+ "sz ka",
+ "sa mo",
+ "s ł",
+ "rz ę",
+ "na le",
+ "chce sz",
+ "ni k",
+ "p ę",
+ "chy ba",
+ "cią g",
+ "ją cy",
+ "wo j",
+ "na sze",
+ "mnie j",
+ "wię cej",
+ "z wy",
+ "o sta",
+ "f e",
+ "wa ż",
+ "h o",
+ "se r",
+ "śmie r",
+ "wie r",
+ "dz ą",
+ "za ś",
+ "gdy by",
+ "ja ki",
+ "wo l",
+ "wi n",
+ "d ą",
+ "ści a",
+ "roz ma",
+ "wa l",
+ "pa nie",
+ "sta r",
+ "ka z",
+ "je żeli",
+ "d em",
+ "w ra",
+ "ko ń",
+ "sie bie",
+ "zno wu",
+ "p ró",
+ "cz em",
+ "st wa",
+ "i sto",
+ "pó ł",
+ "d ał",
+ "ko bie",
+ "ała m",
+ "wy ch",
+ "ce sa",
+ "ni ch",
+ "za wsze",
+ "dzi ć",
+ "te ż",
+ "le pie",
+ "pro szę",
+ "k re",
+ "t wa",
+ "o t",
+ "ł ą",
+ "ch u",
+ "c ą",
+ "p rz",
+ "ł e",
+ "sze dł",
+ "od powie",
+ "my śli",
+ "ś wią",
+ "e n",
+ "e r",
+ "d e",
+ "a n",
+ "e t",
+ "i j",
+ "i n",
+ "e l",
+ "a a",
+ "s t",
+ "o r",
+ "g e",
+ "i s",
+ "a t",
+ "i e",
+ "c h",
+ "o n",
+ "e en",
+ "h et",
+ "i t",
+ "v er",
+ "aa r",
+ "a l",
+ "o or",
+ "g en",
+ "v an",
+ "o p",
+ "d en",
+ "h e",
+ "o m",
+ "t e",
+ "w e",
+ "i k",
+ "r e",
+ "z e",
+ "ij n",
+ "d at",
+ "b e",
+ "d er",
+ "in g",
+ "o e",
+ "ij k",
+ "a an",
+ "ch t",
+ "v oor",
+ "l e",
+ "i et",
+ "r o",
+ "m o",
+ "k en",
+ "z ijn",
+ "m en",
+ "i g",
+ "j e",
+ "n iet",
+ "a r",
+ "o o",
+ "i d",
+ "u n",
+ "i l",
+ "s ch",
+ "mo et",
+ "st e",
+ "u r",
+ "o l",
+ "he b",
+ "u it",
+ "g el",
+ "w ij",
+ "a s",
+ "m e",
+ "t en",
+ "w or",
+ "o u",
+ "v en",
+ "l en",
+ "aa t",
+ "d it",
+ "m et",
+ "r a",
+ "b en",
+ "s p",
+ "o ver",
+ "d ie",
+ "n o",
+ "w er",
+ "l ijk",
+ "f t",
+ "s l",
+ "an d",
+ "v e",
+ "t er",
+ "i er",
+ "i en",
+ "t o",
+ "d aar",
+ "g r",
+ "b el",
+ "de ze",
+ "d u",
+ "a g",
+ "k an",
+ "wor den",
+ "in gen",
+ "moet en",
+ "n en",
+ "on der",
+ "heb ben",
+ "r u",
+ "oo k",
+ "s en",
+ "c t",
+ "k t",
+ "no g",
+ "aa l",
+ "w as",
+ "u l",
+ "e er",
+ "b ij",
+ "m ijn",
+ "p ro",
+ "v ol",
+ "d o",
+ "k om",
+ "at ie",
+ "e ft",
+ "k el",
+ "al s",
+ "r ij",
+ "he id",
+ "a f",
+ "st el",
+ "m aar",
+ "a p",
+ "we e",
+ "a d",
+ "he eft",
+ "w aar",
+ "i cht",
+ "d an",
+ "er en",
+ "n e",
+ "w el",
+ "w at",
+ "w il",
+ "a cht",
+ "aa g",
+ "ge b",
+ "c on",
+ "z o",
+ "k e",
+ "b et",
+ "h ij",
+ "d ig",
+ "k un",
+ "u w",
+ "d t",
+ "d oor",
+ "t ij",
+ "a m",
+ "an g",
+ "on d",
+ "er s",
+ "is ch",
+ "ge en",
+ "i ge",
+ "ge v",
+ "ve el",
+ "n u",
+ "m a",
+ "on s",
+ "o f",
+ "b l",
+ "n aar",
+ "g ro",
+ "p l",
+ "an der",
+ "at en",
+ "kun nen",
+ "e cht",
+ "h ier",
+ "g oe",
+ "an t",
+ "u s",
+ "t wee",
+ "on t",
+ "de lijk",
+ "el e",
+ "u ur",
+ "al le",
+ "t oe",
+ "me er",
+ "i st",
+ "n a",
+ "n ie",
+ "on ze",
+ "l o",
+ "i m",
+ "p en",
+ "h ad",
+ "tij d",
+ "h oe",
+ "to t",
+ "z ou",
+ "a k",
+ "aa k",
+ "a men",
+ "d r",
+ "w oor",
+ "s e",
+ "wor dt",
+ "o t",
+ "gel ijk",
+ "g aan",
+ "i c",
+ "g er",
+ "k er",
+ "el d",
+ "e m",
+ "h ou",
+ "de l",
+ "z en",
+ "z el",
+ "te gen",
+ "b o",
+ "kom en",
+ "c om",
+ "i gen",
+ "e it",
+ "wer k",
+ "goe d",
+ "z al",
+ "z ij",
+ "sl ag",
+ "e s",
+ "z ien",
+ "a st",
+ "echt er",
+ "it ie",
+ "t ie",
+ "el ijk",
+ "m is",
+ "isch e",
+ "bel an",
+ "h aar",
+ "i ch",
+ "b er",
+ "h an",
+ "v r",
+ "al e",
+ "c i",
+ "gr ijk",
+ "in d",
+ "do en",
+ "l and",
+ "belan grijk",
+ "p un",
+ "op en",
+ "ct ie",
+ "zel f",
+ "m ij",
+ "it eit",
+ "ste m",
+ "me e",
+ "ar en",
+ "al l",
+ "b r",
+ "re cht",
+ "d ien",
+ "h u",
+ "g aat",
+ "pro b",
+ "m oe",
+ "p er",
+ "a u",
+ "ul len",
+ "z ich",
+ "daar om",
+ "or m",
+ "k l",
+ "v o",
+ "en t",
+ "st aat",
+ "z it",
+ "du i",
+ "n at",
+ "du s",
+ "d s",
+ "ver slag",
+ "kel ijk",
+ "prob le",
+ "w et",
+ "ge m",
+ "c r",
+ "i on",
+ "p r",
+ "sch ap",
+ "g d",
+ "h un",
+ "z a",
+ "er d",
+ "z et",
+ "st aan",
+ "st r",
+ "m aal",
+ "in der",
+ "e id",
+ "st en",
+ "p ar",
+ "k ken",
+ "ge d",
+ "z ullen",
+ "re s",
+ "men sen",
+ "j aar",
+ "re gel",
+ "ie der",
+ "vol gen",
+ "ge ven",
+ "e ven",
+ "l u",
+ "bl ij",
+ "i ë",
+ "k o",
+ "u we",
+ "m an",
+ "ma ken",
+ "l ie",
+ "g a",
+ "oe k",
+ "nie uwe",
+ "b aar",
+ "h o",
+ "h er",
+ "in ter",
+ "ander e",
+ "ru ik",
+ "s u",
+ "a gen",
+ "or t",
+ "m er",
+ "ou w",
+ "st er",
+ "wil len",
+ "aa kt",
+ "h oo",
+ "an den",
+ "f f",
+ "l ig",
+ "t re",
+ "s amen",
+ "ze er",
+ "dui delijk",
+ "ant woor",
+ "he el",
+ "men t",
+ "pun t",
+ "hou den",
+ "we g",
+ "vr aag",
+ "gel e",
+ "een s",
+ "be sch",
+ "om en",
+ "er g",
+ "do el",
+ "d ag",
+ "sp e",
+ "ur en",
+ "ing s",
+ "or en",
+ "l ang",
+ "de len",
+ "m ar",
+ "ste un",
+ "in nen",
+ "p ol",
+ "o on",
+ "i de",
+ "s n",
+ "s ie",
+ "r icht",
+ "z onder",
+ "no dig",
+ "all een",
+ "m id",
+ "ra gen",
+ "iet s",
+ "ver sch",
+ "geb ruik",
+ "st u",
+ "ro uw",
+ "stel len",
+ "be g",
+ "men ten",
+ "v in",
+ "eer ste",
+ "l aat",
+ "gro ot",
+ "oo d",
+ "to ch",
+ "l aten",
+ "aar d",
+ "s le",
+ "de el",
+ "st and",
+ "pl aat",
+ "re e",
+ "bet re",
+ "d i",
+ "l id",
+ "uit en",
+ "ra cht",
+ "bel eid",
+ "g et",
+ "ar t",
+ "st ie",
+ "st aten",
+ "g gen",
+ "re ken",
+ "e in",
+ "al en",
+ "m ing",
+ "mo gelijk",
+ "gro te",
+ "al tijd",
+ "z or",
+ "en kel",
+ "w ik",
+ "pol itie",
+ "e igen",
+ "el k",
+ "han del",
+ "g t",
+ "k we",
+ "m aat",
+ "el en",
+ "i p",
+ "v rij",
+ "s om",
+ "je s",
+ "aa m",
+ "hu is",
+ "v al",
+ "we er",
+ "lid staten",
+ "k ing",
+ "k le",
+ "be d",
+ "gev al",
+ "stel l",
+ "a i",
+ "wik kel",
+ "kwe stie",
+ "t al",
+ "ste e",
+ "a b",
+ "h el",
+ "kom st",
+ "p as",
+ "s s",
+ "it u",
+ "i den",
+ "eer d",
+ "m in",
+ "c e",
+ "p o",
+ "twee de",
+ "proble em",
+ "w aren",
+ "us sen",
+ "sn el",
+ "t ig",
+ "ge w",
+ "j u",
+ "ul t",
+ "ne men",
+ "com mis",
+ "versch il",
+ "k on",
+ "z oek",
+ "k rij",
+ "gr aag",
+ "den k",
+ "l anden",
+ "re den",
+ "be sl",
+ "oe g",
+ "bet er",
+ "he den",
+ "m ag",
+ "p e",
+ "bo ven",
+ "a c",
+ "con t",
+ "f d",
+ "h ele",
+ "k r",
+ "v ier",
+ "w in",
+ "ge z",
+ "k w",
+ "m il",
+ "v or",
+ "he m",
+ "ra m",
+ "aa s",
+ "ont wikkel",
+ "dr ie",
+ "v aak",
+ "plaat s",
+ "l a",
+ "g ang",
+ "ij f",
+ "f in",
+ "nat uur",
+ "t ussen",
+ "u g",
+ "in e",
+ "d a",
+ "b at",
+ "kom t",
+ "w acht",
+ "aa d",
+ "u t",
+ "é n",
+ "acht er",
+ "geb ie",
+ "ver k",
+ "lig t",
+ "c es",
+ "nie uw",
+ "van d",
+ "s t",
+ "n í",
+ "j e",
+ "p o",
+ "c h",
+ "r o",
+ "n a",
+ "s e",
+ "t o",
+ "n e",
+ "l e",
+ "k o",
+ "l a",
+ "d o",
+ "r a",
+ "n o",
+ "t e",
+ "h o",
+ "n ě",
+ "v a",
+ "l i",
+ "l o",
+ "ř e",
+ "c e",
+ "d e",
+ "v e",
+ "b y",
+ "n i",
+ "s k",
+ "t a",
+ "n á",
+ "z a",
+ "p ro",
+ "v o",
+ "v ě",
+ "m e",
+ "v á",
+ "s o",
+ "k a",
+ "r á",
+ "v y",
+ "z e",
+ "m i",
+ "p a",
+ "t i",
+ "st a",
+ "m ě",
+ "n é",
+ "ř i",
+ "ř í",
+ "m o",
+ "ž e",
+ "m a",
+ "j í",
+ "v ý",
+ "j i",
+ "d ě",
+ "r e",
+ "d a",
+ "k u",
+ "j a",
+ "c i",
+ "r u",
+ "č e",
+ "o b",
+ "t ě",
+ "m u",
+ "k y",
+ "d i",
+ "š e",
+ "k é",
+ "š í",
+ "t u",
+ "v i",
+ "p ře",
+ "v í",
+ "s i",
+ "n ý",
+ "o d",
+ "so u",
+ "v é",
+ "n y",
+ "r i",
+ "d y",
+ "b u",
+ "b o",
+ "t y",
+ "l á",
+ "l u",
+ "n u",
+ "ž i",
+ "m á",
+ "st i",
+ "c í",
+ "z á",
+ "p ra",
+ "sk é",
+ "m í",
+ "c o",
+ "d u",
+ "d á",
+ "by l",
+ "st o",
+ "s a",
+ "t í",
+ "je d",
+ "p ří",
+ "p ři",
+ "t é",
+ "s í",
+ "č i",
+ "v ní",
+ "č a",
+ "d í",
+ "z i",
+ "st u",
+ "p e",
+ "b a",
+ "d ní",
+ "ro z",
+ "va l",
+ "l í",
+ "s po",
+ "k á",
+ "b e",
+ "p i",
+ "no u",
+ "ta k",
+ "st e",
+ "r y",
+ "l é",
+ "vě t",
+ "se m",
+ "p ě",
+ "ko n",
+ "ne j",
+ "l y",
+ "ko u",
+ "ý ch",
+ "b ě",
+ "p r",
+ "f i",
+ "p rá",
+ "a le",
+ "ja ko",
+ "po d",
+ "ž í",
+ "z í",
+ "j sou",
+ "j sem",
+ "ch o",
+ "l ní",
+ "c ké",
+ "t á",
+ "m y",
+ "a k",
+ "h u",
+ "va t",
+ "pře d",
+ "h la",
+ "k e",
+ "st á",
+ "č í",
+ "š i",
+ "s le",
+ "k la",
+ "š tě",
+ "lo u",
+ "m ů",
+ "z na",
+ "ch á",
+ "o r",
+ "p ů",
+ "h a",
+ "b i",
+ "ta ké",
+ "d ů",
+ "no st",
+ "t ře",
+ "te r",
+ "p u",
+ "i n",
+ "v r",
+ "ve l",
+ "sk u",
+ "v še",
+ "t ní",
+ "do b",
+ "by la",
+ "č ní",
+ "ja k",
+ "v u",
+ "je ho",
+ "b ý",
+ "vá ní",
+ "ný ch",
+ "po u",
+ "te n",
+ "t ři",
+ "v z",
+ "st ře",
+ "d va",
+ "h le",
+ "č á",
+ "no sti",
+ "c k",
+ "v š",
+ "vo u",
+ "s u",
+ "h e",
+ "h ra",
+ "je n",
+ "s y",
+ "da l",
+ "po z",
+ "s lo",
+ "te l",
+ "d ru",
+ "de n",
+ "vš ak",
+ "g i",
+ "k dy",
+ "by lo",
+ "bu de",
+ "st ra",
+ "j ší",
+ "m é",
+ "me n",
+ "vý ch",
+ "ní m",
+ "s m",
+ "ko li",
+ "r ů",
+ "t ra",
+ "mů že",
+ "ne ní",
+ "ho d",
+ "b í",
+ "do u",
+ "sk a",
+ "t ý",
+ "st ě",
+ "u je",
+ "s á",
+ "pě t",
+ "ne s",
+ "k rá",
+ "to m",
+ "st ví",
+ "v ně",
+ "se d",
+ "s vé",
+ "p í",
+ "z o",
+ "mu sí",
+ "u ž",
+ "tí m",
+ "jí cí",
+ "jed no",
+ "t r",
+ "ča s",
+ "e v",
+ "č ty",
+ "sk ý",
+ "ni c",
+ "ev ro",
+ "to ho",
+ "h y",
+ "k ter",
+ "r ní",
+ "st í",
+ "s vě",
+ "pa k",
+ "vše ch",
+ "k ů",
+ "n g",
+ "á d",
+ "chá zí",
+ "a ni",
+ "a r",
+ "jed na",
+ "bý t",
+ "t ro",
+ "k ra",
+ "pr vní",
+ "m no",
+ "ské ho",
+ "p á",
+ "p la",
+ "le m",
+ "ne bo",
+ "ke m",
+ "st ro",
+ "s la",
+ "né ho",
+ "z de",
+ "dal ší",
+ "ř a",
+ "čty ři",
+ "h rá",
+ "dru h",
+ "l ně",
+ "v la",
+ "sk ých",
+ "š ko",
+ "pů so",
+ "pro to",
+ "v ů",
+ "sk á",
+ "ve n",
+ "še st",
+ "d ně",
+ "je ště",
+ "me zi",
+ "te k",
+ "s ko",
+ "ch a",
+ "ně koli",
+ "be z",
+ "g ra",
+ "ji ž",
+ "č ně",
+ "j á",
+ "s lu",
+ "z ná",
+ "ve r",
+ "sed m",
+ "k ro",
+ "ta m",
+ "a no",
+ "v lá",
+ "o sm",
+ "byl y",
+ "vá m",
+ "ck ý",
+ "te ch",
+ "dě ji",
+ "vel mi",
+ "le ži",
+ "va la",
+ "l ý",
+ "t vo",
+ "spo le",
+ "ch u",
+ "stu p",
+ "mo ž",
+ "evro p",
+ "g e",
+ "sta l",
+ "j de",
+ "ch y",
+ "ro di",
+ "je jí",
+ "po li",
+ "de vět",
+ "s me",
+ "a ž",
+ "té to",
+ "re m",
+ "d é",
+ "f or",
+ "u ni",
+ "f o",
+ "ten to",
+ "a u",
+ "ka ž",
+ "nu la",
+ "na d",
+ "by ch",
+ "mo c",
+ "sto u",
+ "e x",
+ "le n",
+ "k do",
+ "z d",
+ "pra co",
+ "to mu",
+ "ný m",
+ "ži vo",
+ "ze m",
+ "f e",
+ "f u",
+ "ná sle",
+ "j o",
+ "sk y",
+ "ji ch",
+ "h á",
+ "mě l",
+ "dě la",
+ "j sme",
+ "p re",
+ "ni ce",
+ "ste j",
+ "ne m",
+ "st ní",
+ "he m",
+ "ná ro",
+ "z u",
+ "b li",
+ "ni t",
+ "pa r",
+ "a l",
+ "poz ději",
+ "ta ko",
+ "n ce",
+ "če r",
+ "ší m",
+ "ně co",
+ "vá l",
+ "ře j",
+ "krá t",
+ "á lní",
+ "u r",
+ ". .",
+ "a si",
+ "kter é",
+ "sta v",
+ "ma jí",
+ "my s",
+ "do bě",
+ "s ně",
+ "ce n",
+ "z y",
+ "z ku",
+ "t ů",
+ "ch od",
+ "s pě",
+ "je jich",
+ "sou čas",
+ "d r",
+ "va li",
+ "ri e",
+ "k te",
+ "pr ů",
+ "ze ní",
+ "pa t",
+ "a n",
+ "po tře",
+ "de m",
+ "d nes",
+ "ze mí",
+ "sa mo",
+ "zna m",
+ "b ra",
+ "má m",
+ "te dy",
+ "g o",
+ "hla vní",
+ "pou ží",
+ "b ní",
+ "ve de",
+ "le p",
+ "je k",
+ "pra v",
+ "poli ti",
+ "d ne",
+ "je m",
+ "le t",
+ "če ní",
+ "pro b",
+ "ne ž",
+ "dě l",
+ "fi l",
+ "č o",
+ "cí ch",
+ "st é",
+ "d lou",
+ "h i",
+ "a by",
+ "to u",
+ "několi k",
+ "d la",
+ "vy u",
+ "vi t",
+ "ho u",
+ "ck ých",
+ "no vé",
+ "či n",
+ "st y",
+ "dě lá",
+ "k ý",
+ "ob la",
+ "pod le",
+ "ra n",
+ "dů leži",
+ "ta to",
+ "po ku",
+ "ko ne",
+ "d ý",
+ "d vě",
+ "ž ád",
+ "nou t",
+ "t ku",
+ "t vr",
+ "cké ho",
+ "ro v",
+ "r é",
+ "te le",
+ "p sa",
+ "s vět",
+ "ti vní",
+ "do sta",
+ "te m",
+ "še l",
+ "druh é",
+ "s kou",
+ "ž o",
+ "jed ná",
+ "vý znam",
+ "prob lé",
+ "pu bli",
+ "vá n",
+ "od po",
+ "pod po",
+ "d le",
+ "ja ké",
+ "še ní",
+ "ví m",
+ "bě hem",
+ "na chází",
+ "s lou",
+ "pou ze",
+ "o tá",
+ "p lo",
+ "to vé",
+ "vět ši",
+ "ko mi",
+ "va jí",
+ "ty to",
+ "zá pa",
+ "z mě",
+ "mo h",
+ "ví ce",
+ "spole č",
+ "au to",
+ "pro ti",
+ "st ru",
+ "dě t",
+ "chá ze",
+ "že l",
+ "с т",
+ "е н",
+ "н о",
+ "н а",
+ "п р",
+ "т о",
+ "п о",
+ "р а",
+ "г о",
+ "к о",
+ "н е",
+ "в о",
+ "в а",
+ "е т",
+ "е р",
+ "н и",
+ "е л",
+ "и т",
+ "н ы",
+ "з а",
+ "р о",
+ "ен и",
+ "к а",
+ "л и",
+ "е м",
+ "д а",
+ "о б",
+ "л а",
+ "д о",
+ "с я",
+ "т ь",
+ "о т",
+ "л о",
+ "л ь",
+ "е д",
+ "с о",
+ "м и",
+ "р е",
+ "м о",
+ "ц и",
+ "пр о",
+ "т а",
+ "э то",
+ "к и",
+ "р у",
+ "пр и",
+ "т и",
+ "с е",
+ "ст а",
+ "в ы",
+ "м ы",
+ "в и",
+ "б ы",
+ "м а",
+ "е с",
+ "л я",
+ "ст и",
+ "л е",
+ "ч то",
+ "м е",
+ "р и",
+ "ч а",
+ "о д",
+ "е й",
+ "ел ь",
+ "ени я",
+ "г а",
+ "н у",
+ "с и",
+ "п а",
+ "ра з",
+ "б о",
+ "ст о",
+ "с у",
+ "с а",
+ "д у",
+ "е го",
+ "е ст",
+ "и н",
+ "ит ь",
+ "и з",
+ "ж е",
+ "м у",
+ "п ер",
+ "по д",
+ "ени е",
+ "с ь",
+ "к у",
+ "пр ед",
+ "но го",
+ "ны х",
+ "в ер",
+ "т е",
+ "но й",
+ "ци и",
+ "д е",
+ "р ы",
+ "д ел",
+ "л ю",
+ "в е",
+ "о н",
+ "м ен",
+ "г и",
+ "н я",
+ "б у",
+ "пр а",
+ "в се",
+ "ет ся",
+ "ст ь",
+ "ж а",
+ "до л",
+ "ж и",
+ "б е",
+ "ко н",
+ "с л",
+ "ш и",
+ "д и",
+ "ст в",
+ "с ко",
+ "ны е",
+ "ч и",
+ "ю т",
+ "д ер",
+ "ст ра",
+ "т ы",
+ "х од",
+ "щ и",
+ "з о",
+ "з на",
+ "но сти",
+ "ч ес",
+ "в ля",
+ "ва ть",
+ "о р",
+ "по л",
+ "в ет",
+ "та к",
+ "ш а",
+ "т у",
+ "с во",
+ "пр е",
+ "о на",
+ "ит ель",
+ "ны й",
+ "с ло",
+ "ка к",
+ "в л",
+ "но сть",
+ "х о",
+ "мо ж",
+ "п е",
+ "д ля",
+ "ни я",
+ "но е",
+ "ра с",
+ "дол ж",
+ "да р",
+ "т ель",
+ "с ка",
+ "п у",
+ "ст во",
+ "ко то",
+ "ра б",
+ "е е",
+ "ро д",
+ "э ти",
+ "с об",
+ "о ру",
+ "ж ен",
+ "ны м",
+ "ит и",
+ "ни е",
+ "ко м",
+ "д ет",
+ "ст у",
+ "г у",
+ "п и",
+ "ме ж",
+ "ени ю",
+ "т ер",
+ "раб от",
+ "во з",
+ "ци я",
+ "ко й",
+ "щ ест",
+ "г ра",
+ "з и",
+ "р я",
+ "меж ду",
+ "ст ва",
+ "в с",
+ "ел о",
+ "ш е",
+ "м ер",
+ "б а",
+ "з ы",
+ "л у",
+ "а ль",
+ "д ей",
+ "г ла",
+ "на род",
+ "к ти",
+ "пред ста",
+ "л ся",
+ "я вля",
+ "с ки",
+ "но в",
+ "ед ин",
+ "ро в",
+ "и с",
+ "ни ма",
+ "р ем",
+ "ход и",
+ "так же",
+ "д ру",
+ "а ть",
+ "сл ед",
+ "го во",
+ "на я",
+ "ю щи",
+ "ен ь",
+ "кото ры",
+ "х от",
+ "в у",
+ "и х",
+ "ем у",
+ "ч ит",
+ "ва ж",
+ "ор га",
+ "чес ки",
+ "щ е",
+ "к е",
+ "х а",
+ "по с",
+ "то м",
+ "бо ль",
+ "м не",
+ "па с",
+ "об ъ",
+ "пра в",
+ "кон ф",
+ "сл у",
+ "под дер",
+ "ст ви",
+ "на ш",
+ "ль ко",
+ "сто я",
+ "ну ю",
+ "л ем",
+ "ен ных",
+ "к ра",
+ "д ы",
+ "между народ",
+ "г да",
+ "не об",
+ "го су",
+ "ств у",
+ "ени и",
+ "госу дар",
+ "к то",
+ "и м",
+ "ч ест",
+ "р ет",
+ "во про",
+ "л ен",
+ "ел и",
+ "ро ва",
+ "ци й",
+ "на м",
+ "это й",
+ "ж ения",
+ "необ ходи",
+ "мен я",
+ "бы ло",
+ "си ли",
+ "ф и",
+ "в я",
+ "ш ь",
+ "это го",
+ "о ни",
+ "орга ни",
+ "бе зо",
+ "пр об",
+ "и ме",
+ "ре ш",
+ "б и",
+ "безо пас",
+ "ют ся",
+ "о ста",
+ "ен но",
+ "го д",
+ "ел а",
+ "предста в",
+ "ть ся",
+ "сло во",
+ "органи за",
+ "долж ны",
+ "это м",
+ "б ла",
+ "ч е",
+ "ч у",
+ "бла го",
+ "это му",
+ "в рем",
+ "с пе",
+ "но м",
+ "ени й",
+ "с по",
+ "на с",
+ "не т",
+ "з у",
+ "в ед",
+ "е ще",
+ "ска за",
+ "се й",
+ "ер ен",
+ "да н",
+ "са м",
+ "ел я",
+ "ра н",
+ "зы ва",
+ "явля ется",
+ "бу дет",
+ "кти в",
+ "т ре",
+ "дел е",
+ "м от",
+ "конф ерен",
+ "ла сь",
+ "ча с",
+ "сто ро",
+ "ко го",
+ "е з",
+ "не й",
+ "о с",
+ "ли сь",
+ "раз ору",
+ "пер е",
+ "с си",
+ "ны ми",
+ "про ц",
+ "го ло",
+ "ч ело",
+ "бо ле",
+ "чело ве",
+ "с ер",
+ "п л",
+ "ч ет",
+ "стра н",
+ "п я",
+ "бы л",
+ "к ла",
+ "то в",
+ "ж д",
+ "дел а",
+ "е ра",
+ "у же",
+ "со вет",
+ "г ен",
+ "безопас ности",
+ "ц а",
+ "се да",
+ "по з",
+ "от вет",
+ "проб лем",
+ "на ко",
+ "т ем",
+ "до ста",
+ "п ы",
+ "щ а",
+ "во й",
+ "су щест",
+ "необходи мо",
+ "бы ть",
+ "мож ет",
+ "д ем",
+ "что бы",
+ "е к",
+ "ч ер",
+ "у сили",
+ "ре с",
+ "ру д",
+ "един енных",
+ "д об",
+ "до сти",
+ "ств ен",
+ "я дер",
+ "год ня",
+ "ка за",
+ "се годня",
+ "сей час",
+ "то лько",
+ "во д",
+ "ес ь",
+ "м ного",
+ "бу ду",
+ "е в",
+ "ест ь",
+ "т ри",
+ "об щест",
+ ". .",
+ "я вл",
+ "вы сту",
+ "р ед",
+ "с чит",
+ "с ит",
+ "деле га",
+ "ло ж",
+ "это т",
+ "ф ор",
+ "к лю",
+ "воз мож",
+ "ва ния",
+ "б ли",
+ "и ли",
+ "в з",
+ "на ций",
+ "ско го",
+ "при ня",
+ "п ла",
+ "о ч",
+ "ить ся",
+ "ст е",
+ "на ши",
+ "которы е",
+ "а р",
+ "име ет",
+ "с от",
+ "зна ч",
+ "пер ь",
+ "след у",
+ "ен ы",
+ "та ки",
+ "объ единенных",
+ "ст ро",
+ "те перь",
+ "б ле",
+ "благо дар",
+ "раз в",
+ "а н",
+ "жи ва",
+ "оч ень",
+ "я т",
+ "бе з",
+ "об ес",
+ "г ро",
+ "ло сь",
+ "с ы",
+ "организа ции",
+ "ч лен",
+ "то го",
+ "она ль",
+ "ж да",
+ "все х",
+ "с вя",
+ "боле е",
+ "со в",
+ "ко гда",
+ "во т",
+ "к ре",
+ "к ры",
+ "по этому",
+ "во ль",
+ "о й",
+ "ген ера",
+ "ч ем",
+ "л ы",
+ "пол ити",
+ "в ен",
+ "конферен ции",
+ "проц ес",
+ "б я",
+ "ит е",
+ "от но",
+ "разв ити",
+ "а ф",
+ "ю щ",
+ "в но",
+ "ми р",
+ "ни и",
+ "ка я",
+ "а с",
+ "итель но",
+ "в то",
+ "ени ем",
+ "генера ль",
+ "пр от",
+ "вс ем",
+ "сам бле",
+ "ас самбле",
+ "о м",
+ "з д",
+ "с мот",
+ "ре ги",
+ "ч его",
+ "од нако",
+ "усили я",
+ "дей стви",
+ "ч но",
+ "у ча",
+ "об раз",
+ "во с",
+ "э та",
+ "пер его",
+ "гово р",
+ "ва м",
+ "мо ло",
+ "врем я",
+ "д ь",
+ "хот ел",
+ "г ру",
+ "за явл",
+ "пре доста",
+ "по ль",
+ "не е",
+ "ре зо",
+ "перего во",
+ "резо лю",
+ "к рет",
+ "поддер ж",
+ "обес пе",
+ "не го",
+ "представ ит",
+ "на де",
+ "к ри",
+ "ч ь",
+ "про ек",
+ "л ет",
+ "дру ги",
+ "ا ل",
+ "َ ا",
+ "و َ",
+ "ّ َ",
+ "ِ ي",
+ "أ َ",
+ "ل َ",
+ "ن َ",
+ "ال ْ",
+ "ه ُ",
+ "ُ و",
+ "م ا",
+ "ن ْ",
+ "م ن",
+ "ع َ",
+ "ن ا",
+ "ل ا",
+ "م َ",
+ "ت َ",
+ "ف َ",
+ "أ ن",
+ "ل ي",
+ "م ِ",
+ "ا ن",
+ "ف ي",
+ "ر َ",
+ "ي َ",
+ "ه ِ",
+ "م ْ",
+ "ق َ",
+ "ب ِ",
+ "ل ى",
+ "ي ن",
+ "إ ِ",
+ "ل ِ",
+ "و ا",
+ "ك َ",
+ "ه ا",
+ "ً ا",
+ "م ُ",
+ "و ن",
+ "ال م",
+ "ب َ",
+ "ي ا",
+ "ذ ا",
+ "س ا",
+ "ال ل",
+ "م ي",
+ "ي ْ",
+ "ر ا",
+ "ر ي",
+ "ل ك",
+ "م َا",
+ "ن َّ",
+ "ل م",
+ "إ ن",
+ "س ت",
+ "و م",
+ "ّ َا",
+ "ل َا",
+ "ه م",
+ "ّ ِ",
+ "ك ُ",
+ "ك ان",
+ "س َ",
+ "ب ا",
+ "د ي",
+ "ح َ",
+ "ع ْ",
+ "ب ي",
+ "ال أ",
+ "و ل",
+ "ف ِي",
+ "ر ِ",
+ "د ا",
+ "مِ نْ",
+ "ُو نَ",
+ "و ْ",
+ "ه َا",
+ "ّ ُ",
+ "ال س",
+ "ال َ",
+ "ن ي",
+ "ل ْ",
+ "ت ُ",
+ "ه ل",
+ "ر ة",
+ "د َ",
+ "س ْ",
+ "ت ِ",
+ "ن َا",
+ "ر ْ",
+ "الل َّ",
+ "سا مي",
+ "ك ن",
+ "ك ل",
+ "ه َ",
+ "عَ لَ",
+ "ع لى",
+ "م ع",
+ "إ لى",
+ "ق د",
+ "ال ر",
+ "ُو ا",
+ "ي ر",
+ "ع ن",
+ "ي ُ",
+ "ن ِ",
+ "ب ْ",
+ "ال ح",
+ "هُ مْ",
+ "ق ا",
+ "ذ ه",
+ "ال ت",
+ "ِي نَ",
+ "ج َ",
+ "ه ذا",
+ "ع د",
+ "ال ع",
+ "د ْ",
+ "قَ الَ",
+ "ر ُ",
+ "ي م",
+ "ي ة",
+ "ن ُ",
+ "خ َ",
+ "ر ب",
+ "ال ك",
+ "و َا",
+ "أ نا",
+ "ة ِ",
+ "ال ن",
+ "ح د",
+ "ع ِ",
+ "ت ا",
+ "ه و",
+ "ف ا",
+ "ع ا",
+ "ال ش",
+ "ل ُ",
+ "ي ت",
+ "ذ َا",
+ "ي ع",
+ "ال ذ",
+ "ح ْ",
+ "ال ص",
+ "إِ نَّ",
+ "ج ا",
+ "ع لي",
+ "ك َا",
+ "ب ُ",
+ "ت ع",
+ "و ق",
+ "م ل",
+ "ل َّ",
+ "ي د",
+ "أ خ",
+ "ر ف",
+ "ت ي",
+ "ال ِ",
+ "ّ ا",
+ "ذ لك",
+ "أَ نْ",
+ "س ِ",
+ "ت وم",
+ "م ر",
+ "مَ نْ",
+ "ب ل",
+ "ال ق",
+ "الل ه",
+ "ِي َ",
+ "ك م",
+ "ذ َ",
+ "ع ل",
+ "ح ب",
+ "س ي",
+ "ع ُ",
+ "ال ج",
+ "ال د",
+ "ش َ",
+ "ت ك",
+ "ف ْ",
+ "ص َ",
+ "ل ل",
+ "د ِ",
+ "ب ر",
+ "ف ِ",
+ "ت ه",
+ "أ ع",
+ "ت ْ",
+ "ق ْ",
+ "الْ أَ",
+ "ئ ِ",
+ "عَ نْ",
+ "و ر",
+ "ح ا",
+ "ال َّ",
+ "م ت",
+ "ف ر",
+ "د ُ",
+ "ه نا",
+ "وَ أَ",
+ "ت ب",
+ "ة ُ",
+ "أ ي",
+ "س ب",
+ "ري د",
+ "و ج",
+ "كُ مْ",
+ "ح ِ",
+ "ك ْ",
+ "د ر",
+ "َا ء",
+ "ه ذه",
+ "ال ط",
+ "الْ مُ",
+ "د ة",
+ "ق ل",
+ "غ َ",
+ "ي وم",
+ "الَّ ذ",
+ "ك ر",
+ "ت ر",
+ "ك ِ",
+ "ك ي",
+ "عَلَ ى",
+ "رَ ب",
+ "ع ة",
+ "ق ُ",
+ "ج ْ",
+ "ف ض",
+ "ل ة",
+ "ه ْ",
+ "ر َا",
+ "وَ لَ",
+ "الْ مَ",
+ "أَ نَّ",
+ "ي َا",
+ "أ ُ",
+ "ش ي",
+ "اللَّ هُ",
+ "لَ ى",
+ "ق ِ",
+ "أ ت",
+ "عَلَ يْ",
+ "اللَّ هِ",
+ "ال ب",
+ "ض َ",
+ "ة ً",
+ "ق ي",
+ "ا ر",
+ "ب د",
+ "خ ْ",
+ "سْ تَ",
+ "ط َ",
+ "قَ دْ",
+ "ذه ب",
+ "أ م",
+ "ما ذا",
+ "وَ إِ",
+ "ة ٌ",
+ "و نَ",
+ "لي لى",
+ "و لا",
+ "ح ُ",
+ "ه ي",
+ "ص ل",
+ "ال خ",
+ "و د",
+ "لي س",
+ "ل دي",
+ "ق ال",
+ "كَا نَ",
+ "م َّ",
+ "ح ي",
+ "ت م",
+ "ل ن",
+ "وَ لَا",
+ "ب ع",
+ "يم كن",
+ "س ُ",
+ "ة َ",
+ "ح ت",
+ "ر ًا",
+ "ك ا",
+ "ش ا",
+ "هِ مْ",
+ "لَ هُ",
+ "ز َ",
+ "دا ً",
+ "م س",
+ "ك ث",
+ "الْ عَ",
+ "ج ِ",
+ "ص ْ",
+ "ف َا",
+ "ل ه",
+ "و ي",
+ "ع َا",
+ "هُ وَ",
+ "ب ِي",
+ "ب َا",
+ "أ س",
+ "ث َ",
+ "ل ِي",
+ "ر ض",
+ "الر َّ",
+ "لِ كَ",
+ "ت َّ",
+ "ف ُ",
+ "ق ة",
+ "ف عل",
+ "مِ ن",
+ "ال آ",
+ "ث ُ",
+ "س م",
+ "م َّا",
+ "بِ هِ",
+ "ت ق",
+ "خ ر",
+ "ل قد",
+ "خ ل",
+ "ش ر",
+ "أن ت",
+ "ل َّا",
+ "س ن",
+ "الس َّ",
+ "الذ ي",
+ "س َا",
+ "و ما",
+ "ز ل",
+ "و ب",
+ "أ ْ",
+ "إ ذا",
+ "ر ِي",
+ "ح ة",
+ "ن ِي",
+ "الْ حَ",
+ "وَ قَالَ",
+ "ب ه",
+ "ة ٍ",
+ "س أ",
+ "ر ٌ",
+ "ب ال",
+ "م ة",
+ "ش ْ",
+ "و ت",
+ "عن د",
+ "ف س",
+ "بَ عْ",
+ "ه ر",
+ "ق ط",
+ "أ ح",
+ "إن ه",
+ "و ع",
+ "ف ت",
+ "غ ا",
+ "هنا ك",
+ "ب ت",
+ "مِ نَ",
+ "س ر",
+ "ذَ لِكَ",
+ "ر س",
+ "حد ث",
+ "غ ْ",
+ "ّ ِي",
+ "ال إ",
+ "وَ يَ",
+ "ج ل",
+ "ا ست",
+ "ق ِي",
+ "ع ب",
+ "و س",
+ "ي ش",
+ "الَّذ ِينَ",
+ "تا ب",
+ "د ِي",
+ "ج ب",
+ "ك ون",
+ "ب ن",
+ "ال ث",
+ "لَ يْ",
+ "ب عد",
+ "وَ الْ",
+ "فَ أَ",
+ "ع م",
+ "هُ م",
+ "ت ن",
+ "ذ ْ",
+ "أ ص",
+ "أ ين",
+ "رَب ِّ",
+ "الذ ين",
+ "إِ ن",
+ "ب ين",
+ "ج ُ",
+ "عَلَيْ هِ",
+ "ح َا",
+ "ل و",
+ "ست ط",
+ "ظ ر",
+ "لَ مْ",
+ "ء ِ",
+ "كُ ل",
+ "ط ل",
+ "ت َا",
+ "ض ُ",
+ "كن ت",
+ "ل ًا",
+ "م ٌ",
+ "ق بل",
+ "ـ ـ",
+ "ذ ِ",
+ "قَ وْ",
+ "ص ِ",
+ "م ًا",
+ "كان ت",
+ "ص ا",
+ "ي ق",
+ "ال ف",
+ "ال نا",
+ "م ٍ",
+ "إِ نْ",
+ "ال نَّ",
+ "ج د",
+ "وَ مَا",
+ "ت ت",
+ "ب ح",
+ "م كان",
+ "كي ف",
+ "ّ ة",
+ "ال ا",
+ "ج َا",
+ "أ و",
+ "سا عد",
+ "ض ِ",
+ "إ لا",
+ "را ً",
+ "ق َا",
+ "ر أ",
+ "ع ت",
+ "أ حد",
+ "ه د",
+ "ض ا",
+ "ط ر",
+ "أ ق",
+ "ما ء",
+ "د َّ",
+ "ال با",
+ "م ُو",
+ "أَ وْ",
+ "ط ا",
+ "ق ُو",
+ "خ ِ",
+ "ت ل",
+ "ستط يع",
+ "د َا",
+ "الن َّا",
+ "إ لَى",
+ "وَ تَ",
+ "هَ ذَا",
+ "ب ة",
+ "علي ك",
+ "ج ر",
+ "ال من",
+ "ز ا",
+ "ر ٍ",
+ "د ع",
+ "ّ ًا",
+ "س ة",
+ "ثُ مَّ",
+ "شي ء",
+ "ال غ",
+ "ت ح",
+ "ر ُونَ",
+ "ال يوم",
+ "م ِي",
+ "ن ُوا",
+ "أ ر",
+ "تُ مْ",
+ "ع ر",
+ "ي ف",
+ "أ ب",
+ "د ًا",
+ "ص َا",
+ "الت َّ",
+ "أ ريد",
+ "ال ز",
+ "يَ وْ",
+ "إ لي",
+ "ج ي",
+ "يَ عْ",
+ "فض ل",
+ "ال إن",
+ "أن ه",
+ "n g",
+ "i 4",
+ "a n",
+ "s h",
+ "z h",
+ "i 2",
+ "ng 1",
+ "u 4",
+ "i 1",
+ "ng 2",
+ "d e",
+ "j i",
+ "a o",
+ "x i",
+ "u 3",
+ "de 5",
+ "e 4",
+ "i 3",
+ "ng 4",
+ "an 4",
+ "e n",
+ "u o",
+ "sh i4",
+ "an 2",
+ "u 2",
+ "c h",
+ "u 1",
+ "ng 3",
+ "a 1",
+ "an 1",
+ "e 2",
+ "a 4",
+ "e i4",
+ "o ng1",
+ "a i4",
+ "ao 4",
+ "h u",
+ "a ng1",
+ "l i",
+ "y o",
+ "an 3",
+ "w ei4",
+ "uo 2",
+ "n 1",
+ "en 2",
+ "ao 3",
+ "e 1",
+ "y u",
+ "q i",
+ "e ng2",
+ "zh o",
+ "a ng3",
+ "a ng4",
+ "a ng2",
+ "uo 4",
+ "m i",
+ "g e4",
+ "y i1",
+ "g uo2",
+ "e r",
+ "b i",
+ "a 3",
+ "h e2",
+ "e 3",
+ "y i2",
+ "d i4",
+ "zh ong1",
+ "b u4",
+ "g u",
+ "a i2",
+ "n 2",
+ "z ai4",
+ "sh i2",
+ "e ng1",
+ "r en2",
+ "o ng2",
+ "xi an4",
+ "y i",
+ "x u",
+ "n 4",
+ "l i4",
+ "en 4",
+ "y u2",
+ "e i2",
+ "yi2 ge4",
+ "o u4",
+ "e i3",
+ "d i",
+ "u i4",
+ "a 2",
+ "yo u3",
+ "ao 1",
+ "d a4",
+ "ch eng2",
+ "en 1",
+ "e ng4",
+ "y i4",
+ "s i1",
+ "zh i4",
+ "ji a1",
+ "yu an2",
+ "n i",
+ "t a1",
+ "de5 yi2ge4",
+ "k e1",
+ "sh u3",
+ "x i1",
+ "j i2",
+ "ao 2",
+ "t i",
+ "o u3",
+ "o ng4",
+ "xi a4",
+ "a i1",
+ "g ong1",
+ "zh i1",
+ "en 3",
+ "w ei2",
+ "j u",
+ "xu e2",
+ "q u1",
+ "zho u1",
+ "er 3",
+ "mi ng2",
+ "zho ng3",
+ "l i3",
+ "w u4",
+ "y i3",
+ "uo 1",
+ "e 5",
+ "j i4",
+ "xi ng2",
+ "ji an4",
+ "hu a4",
+ "y u3",
+ "uo 3",
+ "j i1",
+ "a i3",
+ "z uo4",
+ "h ou4",
+ "hu i4",
+ "e i1",
+ "ni an2",
+ "q i2",
+ "p i",
+ "d ao4",
+ "sh eng1",
+ "de 2",
+ "d ai4",
+ "u an2",
+ "zh e4",
+ "zh eng4",
+ "b en3",
+ "sh ang4",
+ "zh u3",
+ "b ei4",
+ "y e4",
+ "ch u1",
+ "zh an4",
+ "l e5",
+ "l ai2",
+ "sh i3",
+ "n an2",
+ "r en4",
+ "yo u2",
+ "k e4",
+ "b a1",
+ "f u4",
+ "d ui4",
+ "y a4",
+ "m ei3",
+ "z i4",
+ "xi n1",
+ "ji ng1",
+ "zh u",
+ "n 3",
+ "yo ng4",
+ "m u4",
+ "ji ao4",
+ "y e3",
+ "ji n4",
+ "bi an4",
+ "l u4",
+ "q i1",
+ "sh e4",
+ "xi ang1",
+ "o ng3",
+ "sh u4",
+ "d ong4",
+ "s uo3",
+ "gu an1",
+ "s an1",
+ "b o",
+ "t e4",
+ "d uo1",
+ "f u2",
+ "mi n2",
+ "l a1",
+ "zh i2",
+ "zh en4",
+ "o u1",
+ "w u3",
+ "m a3",
+ "i 5",
+ "z i5",
+ "j u4",
+ "er 4",
+ "y ao4",
+ "xia4 de5yi2ge4",
+ "s i4",
+ "t u2",
+ "sh an1",
+ "z ui4",
+ "ch u",
+ "yi n1",
+ "er 2",
+ "t ong2",
+ "d ong1",
+ "y u4",
+ "y an2",
+ "qi an2",
+ "shu3 xia4de5yi2ge4",
+ "ju n1",
+ "k e3",
+ "w en2",
+ "f a3",
+ "l uo2",
+ "zh u4",
+ "x i4",
+ "k ou3",
+ "b ei3",
+ "ji an1",
+ "f a1",
+ "di an4",
+ "ji ang1",
+ "wei4 yu2",
+ "xi ang4",
+ "zh i3",
+ "e ng3",
+ "f ang1",
+ "l an2",
+ "sh u",
+ "r i4",
+ "li an2",
+ "sh ou3",
+ "m o",
+ "qi u2",
+ "ji n1",
+ "h uo4",
+ "shu3xia4de5yi2ge4 zhong3",
+ "f en1",
+ "n ei4",
+ "g ai1",
+ "mei3 guo2",
+ "u n2",
+ "g e2",
+ "b ao3",
+ "qi ng1",
+ "g ao1",
+ "t ai2",
+ "d u",
+ "xi ao3",
+ "ji e2",
+ "ti an1",
+ "ch ang2",
+ "q uan2",
+ "li e4",
+ "h ai3",
+ "f ei1",
+ "t i3",
+ "ju e2",
+ "o u2",
+ "c i3",
+ "z u2",
+ "n i2",
+ "bi ao3",
+ "zhong1 guo2",
+ "d u4",
+ "yu e4",
+ "xi ng4",
+ "sh eng4",
+ "ch e1",
+ "d an1",
+ "ji e1",
+ "li n2",
+ "pi ng2",
+ "f u3",
+ "g u3",
+ "ji e4",
+ "w o",
+ "v 3",
+ "sh eng3",
+ "n a4",
+ "yu an4",
+ "zh ang3",
+ "gu an3",
+ "d ao3",
+ "z u3",
+ "di ng4",
+ "di an3",
+ "c eng2",
+ "ren2 kou3",
+ "t ai4",
+ "t ong1",
+ "g uo4",
+ "n eng2",
+ "ch ang3",
+ "hu a2",
+ "li u2",
+ "yi ng1",
+ "xi ao4",
+ "c i4",
+ "bian4 hua4",
+ "li ang3",
+ "g ong4",
+ "zho ng4",
+ "de5 yi1",
+ "s e4",
+ "k ai1",
+ "w ang2",
+ "ji u4",
+ "sh i1",
+ "sh ou4",
+ "m ei2",
+ "k u",
+ "s u",
+ "f eng1",
+ "z e2",
+ "tu2 shi4",
+ "t i2",
+ "q i4",
+ "ji u3",
+ "sh en1",
+ "zh e3",
+ "ren2kou3 bian4hua4",
+ "ren2kou3bian4hua4 tu2shi4",
+ "di4 qu1",
+ "y ang2",
+ "m en",
+ "men 5",
+ "l ong2",
+ "bi ng4",
+ "ch an3",
+ "zh u1",
+ "w ei3",
+ "w ai4",
+ "xi ng1",
+ "bo 1",
+ "b i3",
+ "t ang2",
+ "hu a1",
+ "bo 2",
+ "shu i3",
+ "sh u1",
+ "d ou1",
+ "s ai4",
+ "ch ao2",
+ "b i4",
+ "li ng2",
+ "l ei4",
+ "da4 xue2",
+ "f en4",
+ "shu3 de5",
+ "m u3",
+ "ji ao1",
+ "d ang1",
+ "ch eng1",
+ "t ong3",
+ "n v3",
+ "q i3",
+ "y an3",
+ "mi an4",
+ "l uo4",
+ "ji ng4",
+ "g e1",
+ "r u4",
+ "d an4",
+ "ri4 ben3",
+ "p u3",
+ "yu n4",
+ "hu ang2",
+ "wo 3",
+ "l v",
+ "h ai2",
+ "shi4 yi1",
+ "xi e1",
+ "yi ng3",
+ "w u2",
+ "sh en2",
+ "w ang3",
+ "gu ang3",
+ "li u4",
+ "s u4",
+ "shi4 zhen4",
+ "c an1",
+ "c ao3",
+ "xi a2",
+ "k a3",
+ "d a2",
+ "h u4",
+ "b an4",
+ "d ang3",
+ "h u2",
+ "z ong3",
+ "de ng3",
+ "de5yi2ge4 shi4zhen4",
+ "ch uan2",
+ "mo 4",
+ "zh ang1",
+ "b an1",
+ "mo 2",
+ "ch a2",
+ "c e4",
+ "zhu3 yao4",
+ "t ou2",
+ "j u2",
+ "shi4 wei4yu2",
+ "s a4",
+ "u n1",
+ "ke3 yi3",
+ "d u1",
+ "h an4",
+ "li ang4",
+ "sh a1",
+ "ji a3",
+ "z i1",
+ "lv 4",
+ "f u1",
+ "xi an1",
+ "x u4",
+ "gu ang1",
+ "m eng2",
+ "b ao4",
+ "yo u4",
+ "r ong2",
+ "zhi1 yi1",
+ "w ei1",
+ "m ao2",
+ "guo2 jia1",
+ "c ong2",
+ "g ou4",
+ "ti e3",
+ "zh en1",
+ "d u2",
+ "bi an1",
+ "c i2",
+ "q u3",
+ "f an4",
+ "xi ang3",
+ "m en2",
+ "j u1",
+ "h ong2",
+ "z i3",
+ "ta1 men5",
+ "ji 3",
+ "z ong1",
+ "zhou1 de5yi2ge4shi4zhen4",
+ "t uan2",
+ "ji ng3",
+ "gong1 si1",
+ "xi e4",
+ "l i2",
+ "li4 shi3",
+ "b ao1",
+ "g ang3",
+ "gu i1",
+ "zh eng1",
+ "zhi2 wu4",
+ "ta1 de5",
+ "pi n3",
+ "zhu an1",
+ "ch ong2",
+ "shi3 yong4",
+ "w a3",
+ "sh uo1",
+ "chu an1",
+ "l ei2",
+ "w an1",
+ "h uo2",
+ "q u",
+ "s u1",
+ "z ao3",
+ "g ai3",
+ "q u4",
+ "g u4",
+ "l u",
+ "x i2",
+ "h ang2",
+ "yi ng4",
+ "c un1",
+ "g en1",
+ "yi ng2",
+ "ti ng2",
+ "cheng2 shi4",
+ "ji ang3",
+ "li ng3",
+ "l un2",
+ "bu4 fen4",
+ "de ng1",
+ "xu an3",
+ "dong4 wu4",
+ "de2 guo2",
+ "xi an3",
+ "f an3",
+ "zh e5",
+ "h an2",
+ "h ao4",
+ "m i4",
+ "r an2",
+ "qi n1",
+ "ti ao2",
+ "zh an3",
+ "h i",
+ "k a",
+ "n o",
+ "t e",
+ "s u",
+ "s hi",
+ "t a",
+ "t o",
+ "n a",
+ "w a",
+ "o u",
+ "r u",
+ "n i",
+ "k u",
+ "k i",
+ "g a",
+ "d e",
+ "k o",
+ "m a",
+ "r e",
+ "r a",
+ "m o",
+ "t su",
+ "w o",
+ "e n",
+ "r i",
+ "s a",
+ "d a",
+ "s e",
+ "j i",
+ "h a",
+ "c hi",
+ "k e",
+ "te ki",
+ "m i",
+ "y ou",
+ "s h",
+ "s o",
+ "y o",
+ "y a",
+ "na i",
+ "t te",
+ "a ru",
+ "b a",
+ "u u",
+ "t ta",
+ "ka i",
+ "ka n",
+ "shi te",
+ "m e",
+ "d o",
+ "mo no",
+ "se i",
+ "r o",
+ "ko to",
+ "ka ra",
+ "shi ta",
+ "b u",
+ "m u",
+ "c h",
+ "su ru",
+ "k ou",
+ "g o",
+ "ma su",
+ "ta i",
+ "f u",
+ "k en",
+ "i u",
+ "g en",
+ "wa re",
+ "shi n",
+ "z u",
+ "a i",
+ "o n",
+ "o ku",
+ "g i",
+ "d ou",
+ "n e",
+ "y uu",
+ "i ru",
+ "i te",
+ "ji ko",
+ "de su",
+ "j u",
+ "ra re",
+ "sh u",
+ "b e",
+ "sh ou",
+ "s ha",
+ "se kai",
+ "s ou",
+ "k you",
+ "ma shita",
+ "s en",
+ "na ra",
+ "sa n",
+ "ke i",
+ "i ta",
+ "a ri",
+ "i tsu",
+ "ko no",
+ "j ou",
+ "na ka",
+ "ch ou",
+ "so re",
+ "g u",
+ "na ru",
+ "ga ku",
+ "re ba",
+ "g e",
+ "h o",
+ "i n",
+ "hi to",
+ "sa i",
+ "na n",
+ "da i",
+ "tsu ku",
+ "shi ki",
+ "sa re",
+ "na ku",
+ "p p",
+ "bu n",
+ "ju n",
+ "so no",
+ "ka ku",
+ "z ai",
+ "b i",
+ "to u",
+ "wa ta",
+ "sh uu",
+ "i i",
+ "te i",
+ "ka re",
+ "y u",
+ "shi i",
+ "ma de",
+ "sh o",
+ "a n",
+ "ke reba",
+ "shi ka",
+ "i chi",
+ "ha n",
+ "de ki",
+ "ni n",
+ "ware ware",
+ "na kereba",
+ "o ite",
+ "h ou",
+ "ya ku",
+ "ra i",
+ "mu jun",
+ "l e",
+ "yo ku",
+ "bu tsu",
+ "o o",
+ "ko n",
+ "o mo",
+ "ga e",
+ "nara nai",
+ "ta chi",
+ "z en",
+ "ch uu",
+ "kan gae",
+ "ta ra",
+ "to ki",
+ "ko ro",
+ "mujun teki",
+ "z e",
+ "na ga",
+ "ji n",
+ "shi ma",
+ "te n",
+ "i ki",
+ "i ku",
+ "no u",
+ "i masu",
+ "r ou",
+ "h on",
+ "ka e",
+ "t to",
+ "ko re",
+ "ta n",
+ "ki ta",
+ "i s",
+ "da tta",
+ "ji tsu",
+ "ma e",
+ "i e",
+ "me i",
+ "da n",
+ "h e",
+ "to ku",
+ "dou itsu",
+ "ri tsu",
+ "k yuu",
+ "h you",
+ "rare ta",
+ "kei sei",
+ "k kan",
+ "rare ru",
+ "m ou",
+ "do ko",
+ "r you",
+ "da ke",
+ "naka tta",
+ "so ko",
+ "ta be",
+ "e r",
+ "ha na",
+ "c o",
+ "fu ku",
+ "p a",
+ "so n",
+ "ya su",
+ "ch o",
+ "wata ku",
+ "ya ma",
+ "z a",
+ "k yo",
+ "gen zai",
+ "b oku",
+ "a ta",
+ "j a",
+ "ka wa",
+ "ma sen",
+ "j uu",
+ "ro n",
+ "b o",
+ "na tte",
+ "wataku shi",
+ "yo tte",
+ "ma i",
+ "g ou",
+ "ha i",
+ "mo n",
+ "ba n",
+ "ji shin",
+ "c a",
+ "re te",
+ "n en",
+ "o ka",
+ "ka gaku",
+ "na tta",
+ "p o",
+ "ka ru",
+ "na ri",
+ "m en",
+ "ma ta",
+ "e i",
+ "ku ru",
+ "ga i",
+ "ka ri",
+ "sha kai",
+ "kou i",
+ "yo ri",
+ "se tsu",
+ "j o",
+ "re ru",
+ "to koro",
+ "ju tsu",
+ "i on",
+ "sa ku",
+ "tta i",
+ "c ha",
+ "nin gen",
+ "n u",
+ "c e",
+ "ta me",
+ "kan kyou",
+ "de n",
+ "o oku",
+ "i ma",
+ "wata shi",
+ "tsuku ru",
+ "su gi",
+ "b en",
+ "ji bun",
+ "shi tsu",
+ "ke ru",
+ "ki n",
+ "ki shi",
+ "shika shi",
+ "mo to",
+ "ma ri",
+ "i tte",
+ "de shita",
+ "n de",
+ "ari masu",
+ "te r",
+ "z ou",
+ "ko e",
+ "ze ttai",
+ "kkan teki",
+ "h en",
+ "re kishi",
+ "deki ru",
+ "tsu ka",
+ "l a",
+ "i tta",
+ "o i",
+ "ko butsu",
+ "mi ru",
+ "sh oku",
+ "shi masu",
+ "gi jutsu",
+ "g you",
+ "jou shiki",
+ "a tta",
+ "ho do",
+ "ko ko",
+ "tsuku rareta",
+ "z oku",
+ "hi tei",
+ "ko ku",
+ "rekishi teki",
+ "ke te",
+ "o ri",
+ "i mi",
+ "ka ko",
+ "naga ra",
+ "ka karu",
+ "shu tai",
+ "ha ji",
+ "ma n",
+ "ta ku",
+ "ra n",
+ "douitsu teki",
+ "z o",
+ "me te",
+ "re i",
+ "tsu u",
+ "sare te",
+ "gen jitsu",
+ "p e",
+ "s t",
+ "ba i",
+ "na wa",
+ "ji kan",
+ "wa ru",
+ "r t",
+ "a tsu",
+ "so ku",
+ "koui teki",
+ "a ra",
+ "u ma",
+ "a no",
+ "i de",
+ "ka ta",
+ "te tsu",
+ "ga wa",
+ "ke do",
+ "re ta",
+ "mi n",
+ "sa you",
+ "tte ru",
+ "to ri",
+ "p u",
+ "ki mi",
+ "b ou",
+ "mu ra",
+ "sare ru",
+ "ma chi",
+ "k ya",
+ "o sa",
+ "kon na",
+ "a ku",
+ "a l",
+ "sare ta",
+ "i pp",
+ "shi ku",
+ "u chi",
+ "hito tsu",
+ "ha tara",
+ "tachi ba",
+ "shi ro",
+ "ka tachi",
+ "to mo",
+ "e te",
+ "me ru",
+ "ni chi",
+ "da re",
+ "ka tta",
+ "e ru",
+ "su ki",
+ "a ge",
+ "oo ki",
+ "ma ru",
+ "mo ku",
+ "o ko",
+ "kangae rareru",
+ "o to",
+ "tan ni",
+ "ta da",
+ "tai teki",
+ "mo tte",
+ "ki nou",
+ "shi nai",
+ "k ki",
+ "u e",
+ "ta ri",
+ "l i",
+ "ra nai",
+ "k kou",
+ "mi rai",
+ "pp on",
+ "go to",
+ "hi n",
+ "hi tsu",
+ "te ru",
+ "mo chi",
+ "ka tsu",
+ "re n",
+ "n yuu",
+ "su i",
+ "zu ka",
+ "tsu ite",
+ "no mi",
+ "su gu",
+ "ku da",
+ "tetsu gaku",
+ "i ka",
+ "ron ri",
+ "o ki",
+ "ni ppon",
+ "p er",
+ "shi mashita",
+ "chi shiki",
+ "cho kkanteki",
+ "su ko",
+ "t ion",
+ "ku u",
+ "a na",
+ "a rou",
+ "ka tte",
+ "ku ri",
+ "i nai",
+ "hyou gen",
+ "i shiki",
+ "do ku",
+ "a tte",
+ "a tara",
+ "to n",
+ "wa ri",
+ "ka o",
+ "sei san",
+ "hana shi",
+ "s i",
+ "ka ke",
+ "na ji",
+ "su nawa",
+ "sunawa chi",
+ "u go",
+ "su u",
+ "ba ra",
+ "le v",
+ "hi ro",
+ "i wa",
+ "be tsu",
+ "yo i",
+ "se ru",
+ "shite ru",
+ "rare te",
+ "to shi",
+ "se ki",
+ "tai ritsu",
+ "wa kara",
+ "to kyo",
+ "k ka",
+ "k yoku",
+ "u n",
+ "i ro",
+ "mi te",
+ "sa ki",
+ "kan ji",
+ "mi ta",
+ "su be",
+ "r yoku",
+ "ma tta",
+ "kuda sai",
+ "omo i",
+ "ta no",
+ "ware ru",
+ "co m",
+ "hitsu you",
+ "ka shi",
+ "re nai",
+ "kan kei",
+ "a to",
+ "ga tte",
+ "o chi",
+ "mo tsu",
+ "in g",
+ "son zai",
+ "l l",
+ "o re",
+ "tai shite",
+ "a me",
+ "sei mei",
+ "ka no",
+ "gi ri",
+ "kangae ru",
+ "yu e",
+ "a sa",
+ "o naji",
+ "yo ru",
+ "ni ku",
+ "osa ka",
+ "suko shi",
+ "c k",
+ "ta ma",
+ "kano jo",
+ "ki te",
+ "mon dai",
+ "a mari",
+ "e ki",
+ "ko jin",
+ "ha ya",
+ "i t",
+ "de te",
+ "atara shii",
+ "a wa",
+ "ga kkou",
+ "tsu zu",
+ "shu kan",
+ "i mashita",
+ "mi na",
+ "ata e",
+ "da rou",
+ "hatara ku",
+ "ga ta",
+ "da chi",
+ "ma tsu",
+ "ari masen",
+ "sei butsu",
+ "mi tsu",
+ "he ya",
+ "yasu i",
+ "d i",
+ "de ni",
+ "no ko",
+ "ha ha",
+ "do mo",
+ "ka mi",
+ "su deni",
+ "na o",
+ "ra ku",
+ "i ke",
+ "a ki",
+ "me ta",
+ "l o",
+ "ko domo",
+ "so shite",
+ "ga me",
+ "ba kari",
+ "to te",
+ "ha tsu",
+ "mi se",
+ "moku teki",
+ "da kara",
+ "s z",
+ "e l",
+ "g y",
+ "e n",
+ "t t",
+ "e m",
+ "a n",
+ "a k",
+ "e r",
+ "a z",
+ "a l",
+ "e t",
+ "o l",
+ "e g",
+ "e k",
+ "m i",
+ "o n",
+ "é s",
+ "c s",
+ "a t",
+ "á r",
+ "h o",
+ "e z",
+ "á l",
+ "i s",
+ "á n",
+ "o r",
+ "a r",
+ "e gy",
+ "e s",
+ "é r",
+ "á t",
+ "o tt",
+ "e tt",
+ "m eg",
+ "t a",
+ "o k",
+ "o s",
+ "ho gy",
+ "n em",
+ "é g",
+ "n y",
+ "k i",
+ "é l",
+ "h a",
+ "á s",
+ "ü l",
+ "i n",
+ "mi n",
+ "n a",
+ "e d",
+ "o m",
+ "i k",
+ "k ö",
+ "m a",
+ "n i",
+ "v a",
+ "v ol",
+ "é t",
+ "b b",
+ "f el",
+ "i g",
+ "l e",
+ "r a",
+ "é n",
+ "t e",
+ "d e",
+ "a d",
+ "ó l",
+ "b e",
+ "on d",
+ "j a",
+ "r e",
+ "u l",
+ "b en",
+ "n ek",
+ "u t",
+ "vol t",
+ "b an",
+ "ö r",
+ "o g",
+ "a p",
+ "o d",
+ "á g",
+ "n k",
+ "é k",
+ "v al",
+ "k or",
+ "a m",
+ "i l",
+ "í t",
+ "á k",
+ "b a",
+ "u d",
+ "sz er",
+ "min d",
+ "o z",
+ "é p",
+ "el l",
+ "ér t",
+ "m ond",
+ "i t",
+ "sz t",
+ "n ak",
+ "a mi",
+ "n e",
+ "ő l",
+ "cs ak",
+ "n é",
+ "ma g",
+ "ol y",
+ "m er",
+ "ál l",
+ "án y",
+ "ö n",
+ "ö l",
+ "min t",
+ "m ár",
+ "ö tt",
+ "na gy",
+ "é sz",
+ "az t",
+ "el ő",
+ "t ud",
+ "o t",
+ "é ny",
+ "á z",
+ "m ég",
+ "kö z",
+ "el y",
+ "s ég",
+ "en t",
+ "s em",
+ "ta m",
+ "h et",
+ "h al",
+ "f i",
+ "a s",
+ "v an",
+ "ho z",
+ "v e",
+ "u k",
+ "k ez",
+ "á m",
+ "v el",
+ "b er",
+ "a j",
+ "u nk",
+ "i z",
+ "va gy",
+ "m os",
+ "sz em",
+ "em ber",
+ "f og",
+ "mer t",
+ "ü k",
+ "l en",
+ "ö s",
+ "e j",
+ "t al",
+ "h at",
+ "t ak",
+ "h i",
+ "m ás",
+ "s ág",
+ "ett e",
+ "l eg",
+ "ü nk",
+ "h át",
+ "sz a",
+ "on y",
+ "ez t",
+ "mind en",
+ "en d",
+ "ül t",
+ "h an",
+ "j ó",
+ "k is",
+ "á j",
+ "in t",
+ "ú gy",
+ "i d",
+ "mos t",
+ "ar t",
+ "í r",
+ "k er",
+ "i tt",
+ "a tt",
+ "el t",
+ "mond ta",
+ "k ell",
+ "l á",
+ "ak i",
+ "ál t",
+ "ér d",
+ "t ö",
+ "l an",
+ "v ár",
+ "h ol",
+ "t el",
+ "l át",
+ "ő k",
+ "v et",
+ "s e",
+ "ut án",
+ "k ét",
+ "na p",
+ "í v",
+ "ál y",
+ "v ég",
+ "ö k",
+ "i r",
+ "d ul",
+ "v is",
+ "né z",
+ "t er",
+ "á ban",
+ "k ül",
+ "ak kor",
+ "k ap",
+ "sz él",
+ "y en",
+ "ú j",
+ "i m",
+ "oly an",
+ "es en",
+ "k ed",
+ "h ely",
+ "t ör",
+ "b ól",
+ "el m",
+ "r á",
+ "ár a",
+ "r ó",
+ "l ó",
+ "vol na",
+ "t an",
+ "le het",
+ "e bb",
+ "t en",
+ "t ek",
+ "s ok",
+ "k al",
+ "f or",
+ "u g",
+ "ol t",
+ "k a",
+ "ek et",
+ "b or",
+ "f ej",
+ "g ond",
+ "a g",
+ "ak ar",
+ "f él",
+ "ú l",
+ "b el",
+ "ott a",
+ "mi t",
+ "val ami",
+ "j el",
+ "é d",
+ "ar c",
+ "u r",
+ "hal l",
+ "t i",
+ "f öl",
+ "á ba",
+ "ol g",
+ "ki r",
+ "ol d",
+ "m ar",
+ "k érd",
+ "j ár",
+ "ú r",
+ "sz e",
+ "z s",
+ "él et",
+ "j át",
+ "o v",
+ "u s",
+ "é z",
+ "v il",
+ "v er",
+ "ő r",
+ "á d",
+ "ö g",
+ "le sz",
+ "on t",
+ "b iz",
+ "k oz",
+ "á bb",
+ "kir ály",
+ "es t",
+ "a b",
+ "en g",
+ "ig az",
+ "b ar",
+ "ha j",
+ "d i",
+ "o b",
+ "k od",
+ "r ól",
+ "v ez",
+ "tö bb",
+ "sz ó",
+ "é ben",
+ "ö t",
+ "ny i",
+ "t á",
+ "sz ól",
+ "gond ol",
+ "eg ész",
+ "í gy",
+ "ő s",
+ "o bb",
+ "os an",
+ "b ől",
+ "a bb",
+ "c i",
+ "ő t",
+ "n ál",
+ "k ép",
+ "azt án",
+ "v i",
+ "t art",
+ "be szél",
+ "m en",
+ "elő tt",
+ "a szt",
+ "ma j",
+ "kö r",
+ "han g",
+ "í z",
+ "in cs",
+ "a i",
+ "é v",
+ "ó d",
+ "ó k",
+ "hoz z",
+ "t em",
+ "ok at",
+ "an y",
+ "nagy on",
+ "h áz",
+ "p er",
+ "p ed",
+ "ez te",
+ "et len",
+ "nek i",
+ "maj d",
+ "sz ony",
+ "án ak",
+ "fel é",
+ "egy szer",
+ "j e",
+ "ad t",
+ "gy er",
+ "ami kor",
+ "f oly",
+ "sz ak",
+ "ő d",
+ "h ú",
+ "á sz",
+ "am ely",
+ "h ar",
+ "ér e",
+ "il yen",
+ "od a",
+ "j ák",
+ "t ár",
+ "á val",
+ "l ak",
+ "t ó",
+ "m ent",
+ "gy an",
+ "él y",
+ "ú t",
+ "v ar",
+ "kez d",
+ "m ell",
+ "mi kor",
+ "h ez",
+ "val ó",
+ "k o",
+ "m es",
+ "szer et",
+ "r end",
+ "l et",
+ "vis sza",
+ "ig en",
+ "f ő",
+ "va s",
+ "as szony",
+ "r ől",
+ "ped ig",
+ "p i",
+ "sz ép",
+ "t ák",
+ "ö v",
+ "an i",
+ "vil ág",
+ "p en",
+ "mag a",
+ "t et",
+ "sz ik",
+ "é j",
+ "én t",
+ "j ött",
+ "s an",
+ "sz í",
+ "i de",
+ "g at",
+ "ett em",
+ "ul t",
+ "h ány",
+ "ás t",
+ "a hol",
+ "ők et",
+ "h ár",
+ "k el",
+ "n ő",
+ "cs i",
+ "tal ál",
+ "el te",
+ "lá tt",
+ "tör t",
+ "ha gy",
+ "e sz",
+ "s en",
+ "n él",
+ "p ar",
+ "v ál",
+ "k ut",
+ "l ány",
+ "ami t",
+ "s ő",
+ "ell en",
+ "mag át",
+ "in k",
+ "u gyan",
+ "kül ön",
+ "a sz",
+ "mind ig",
+ "l ép",
+ "tal án",
+ "u n",
+ "sz or",
+ "k e",
+ "il lan",
+ "n incs",
+ "z et",
+ "vagy ok",
+ "tel en",
+ "is mer",
+ "s or",
+ "is ten",
+ "ít ott",
+ "j obb",
+ "v es",
+ "dul t",
+ "j uk",
+ "sz en",
+ "r o",
+ "ö m",
+ "l ett",
+ "k ar",
+ "egy ik",
+ "b ár",
+ "sz i",
+ "sz ív",
+ "az on",
+ "e szt",
+ "föl d",
+ "kut y",
+ "p illan",
+ "f ér",
+ "k om",
+ "t ől",
+ "t ű",
+ "é be",
+ "t ött",
+ "bar át",
+ "í g",
+ "a hogy",
+ "e h",
+ "e p",
+ "s o",
+ "v en",
+ "jel ent",
+ "t at",
+ "sz eg",
+ "mint ha",
+ "f al",
+ "egy en",
+ "mi l",
+ "sza b",
+ "r i",
+ "é m",
+ "biz ony",
+ "j on",
+ "ör eg",
+ "d olg",
+ "cs ap",
+ "ti szt",
+ "áll t",
+ "an cs",
+ "id ő",
+ "k at",
+ "ü gy",
+ "mi ért",
+ "ó t",
+ "ü r",
+ "cs in",
+ "h az",
+ "b et",
+ "én ek",
+ "v ér",
+ "j ól",
+ "al att",
+ "m ely",
+ "l o",
+ "sem mi",
+ "ny ug",
+ "v ág",
+ "kö vet",
+ "ös sze",
+ "ma d",
+ "l i",
+ "a cs",
+ "fi ú",
+ "kö n",
+ "más ik",
+ "j ön",
+ "sz ám",
+ "g er",
+ "s ó",
+ "r ész",
+ "k ér",
+ "z el",
+ "é vel",
+ "e o",
+ "e u",
+ "a n",
+ "eu l",
+ "eu n",
+ "eo n",
+ "a e",
+ "d a",
+ "a l",
+ "s s",
+ "i n",
+ "i l",
+ "a g",
+ "an g",
+ "y eon",
+ "y eo",
+ "d o",
+ "c h",
+ "n g",
+ "j i",
+ "h an",
+ "g a",
+ "g o",
+ "u i",
+ "h ae",
+ "a m",
+ "u l",
+ "u n",
+ "g eo",
+ "s i",
+ "n eun",
+ "ss da",
+ "s eo",
+ "eon g",
+ "y o",
+ "i da",
+ "t t",
+ "k k",
+ "j eo",
+ "d eul",
+ "w a",
+ "eu m",
+ "g e",
+ "o n",
+ "o g",
+ "s al",
+ "m an",
+ "yeon g",
+ "geo s",
+ "h ag",
+ "an eun",
+ "j a",
+ "g i",
+ "s u",
+ "i ss",
+ "o l",
+ "d ae",
+ "eo b",
+ "h a",
+ "j u",
+ "eo l",
+ "g eu",
+ "j eong",
+ "s ae",
+ "do e",
+ "g eul",
+ "s eu",
+ "s in",
+ "eul o",
+ "b n",
+ "s ang",
+ "bn ida",
+ "h al",
+ "b o",
+ "han eun",
+ "m al",
+ "i m",
+ "m o",
+ "b u",
+ "jeo g",
+ "sae ng",
+ "in eun",
+ "an h",
+ "m a",
+ "sal am",
+ "j o",
+ "s a",
+ "eo m",
+ "n ae",
+ "w i",
+ "l o",
+ "g wa",
+ "yeo l",
+ "n a",
+ "e seo",
+ "y e",
+ "m yeon",
+ "tt ae",
+ "h w",
+ "j e",
+ "eob s",
+ "j ang",
+ "g u",
+ "g w",
+ "il eul",
+ "yeo g",
+ "j eon",
+ "si g",
+ "j ag",
+ "j in",
+ "y u",
+ "o e",
+ "s e",
+ "hag o",
+ "d eun",
+ "y a",
+ "m un",
+ "s eong",
+ "g ag",
+ "h am",
+ "d ang",
+ "b a",
+ "l eul",
+ "s il",
+ "do ng",
+ "kk a",
+ "b al",
+ "da l",
+ "han da",
+ "eo ssda",
+ "ae g",
+ "l i",
+ "ha ji",
+ "s eon",
+ "o ng",
+ "hae ssda",
+ "d e",
+ "i ssda",
+ "e ge",
+ "b un",
+ "m ul",
+ "ju ng",
+ "ji g",
+ "m u",
+ "iss neun",
+ "b i",
+ "g eun",
+ "seu bnida",
+ "w on",
+ "p p",
+ "d aneun",
+ "eo h",
+ "d eo",
+ "ga m",
+ "j al",
+ "hae ng",
+ "ag o",
+ "y ang",
+ "b ul",
+ "b ang",
+ "u m",
+ "s o",
+ "h i",
+ "j ae",
+ "si m",
+ "saeng gag",
+ "hag e",
+ "s og",
+ "eo ss",
+ "d an",
+ "ja sin",
+ "j il",
+ "eo g",
+ "g yeong",
+ "doe n",
+ "go ng",
+ "m i",
+ "ch i",
+ "d eu",
+ "d eon",
+ "hae ss",
+ "d u",
+ "n am",
+ "eun g",
+ "jo h",
+ "n al",
+ "m yeong",
+ "w o",
+ "eon a",
+ "i go",
+ "g yeol",
+ "y ag",
+ "gw an",
+ "ul i",
+ "yo ng",
+ "n o",
+ "l yeo",
+ "j og",
+ "eoh ge",
+ "ga t",
+ "b og",
+ "mo s",
+ "t ong",
+ "ch a",
+ "man h",
+ "jeo l",
+ "geo l",
+ "h oe",
+ "ag a",
+ "n aneun",
+ "g an",
+ "un eun",
+ "ch eol",
+ "ch e",
+ "do l",
+ "b on",
+ "b an",
+ "ba d",
+ "ch u",
+ "ham yeon",
+ "yeo ssda",
+ "i bnida",
+ "g ye",
+ "eo s",
+ "hw al",
+ "salam deul",
+ "ji man",
+ "dang sin",
+ "ji b",
+ "ttae mun",
+ "m ae",
+ "i b",
+ "e neun",
+ "eu g",
+ "jeo m",
+ "geul eon",
+ "h wa",
+ "a ssda",
+ "b eob",
+ "bu t",
+ "b ae",
+ "yeo ss",
+ "ch in",
+ "ch aeg",
+ "g eon",
+ "g ae",
+ "nae ga",
+ "i ga",
+ "m og",
+ "sig an",
+ "g il",
+ "h yeon",
+ "l yeog",
+ "gu g",
+ "p yeon",
+ "s an",
+ "w ae",
+ "j ul",
+ "s eul",
+ "deun g",
+ "haji man",
+ "eum yeon",
+ "p il",
+ "m ol",
+ "n eu",
+ "a ss",
+ "n yeon",
+ "t ae",
+ "h u",
+ "p yo",
+ "s ul",
+ "g ang",
+ "j ineun",
+ "b eon",
+ "ha da",
+ "seo l",
+ "si p",
+ "dal eun",
+ "a p",
+ "sal m",
+ "g yo",
+ "ch eon",
+ "hag i",
+ "in a",
+ "cheol eom",
+ "g al",
+ "il a",
+ "kka ji",
+ "anh neun",
+ "ha bnida",
+ "tt eon",
+ "n u",
+ "hae seo",
+ "doen da",
+ "s ol",
+ "tt al",
+ "l a",
+ "il o",
+ "seu b",
+ "b yeon",
+ "m yeo",
+ "b eol",
+ "s on",
+ "n un",
+ "j un",
+ "j am",
+ "j eung",
+ "tt o",
+ "e n",
+ "mo m",
+ "h o",
+ "ch im",
+ "hw ang",
+ "eun eun",
+ "jo ng",
+ "bo da",
+ "n ol",
+ "n eom",
+ "but eo",
+ "jig eum",
+ "eobs da",
+ "dae lo",
+ "i g",
+ "y ul",
+ "p yeong",
+ "seon eun",
+ "sal ang",
+ "seu t",
+ "h im",
+ "n an",
+ "h eom",
+ "h yang",
+ "p i",
+ "gw ang",
+ "eobs neun",
+ "hw ag",
+ "ge ss",
+ "jag i",
+ "il eon",
+ "wi hae",
+ "dae han",
+ "ga ji",
+ "m eog",
+ "j yeo",
+ "cha j",
+ "b yeong",
+ "eo d",
+ "g yeo",
+ "do n",
+ "eo ji",
+ "g ul",
+ "mo deun",
+ "j on",
+ "in saeng",
+ "geul ae",
+ "h ang",
+ "sa sil",
+ "si b",
+ "ch al",
+ "il ago",
+ "doe l",
+ "g eum",
+ "doe neun",
+ "b ol",
+ "ga jang",
+ "geul igo",
+ "e l",
+ "h yeong",
+ "haeng bog",
+ "ch ul",
+ "h on",
+ "ch ae",
+ "s am",
+ "m ang",
+ "in da",
+ "da m",
+ "w ol",
+ "ch oe",
+ "d ul",
+ "si jag",
+ "ch eong",
+ "il aneun",
+ "ul ineun",
+ "ae n",
+ "kk e",
+ "mun je",
+ "a do",
+ "t eu",
+ "g un",
+ "geun eun",
+ "b ge",
+ "ch eo",
+ "b aeg",
+ "ju g",
+ "t a",
+ "sang dae",
+ "geu geos",
+ "do g",
+ "eu s",
+ "deu s",
+ "ja b",
+ "h yeo",
+ "tt eohge",
+ "u g",
+ "ma j",
+ "ch il",
+ "s wi",
+ "j ileul",
+ "ch ang",
+ "g aneun",
+ "m ag",
+ "i ji",
+ "da go",
+ "m in",
+ "yo han",
+ "t eug",
+ "pp un",
+ "al eul",
+ "haeng dong",
+ "p o",
+ "m il",
+ "ch am",
+ "se sang",
+ "e do",
+ "p an",
+ "man deul",
+ "am yeon",
+ "a b",
+ "kk ae",
+ "b ag",
+ "i deul",
+ "p um",
+ "m eol",
+ "s un",
+ "n eul",
+ "ham kke",
+ "chu ng",
+ "da b",
+ "yu g",
+ "s ag",
+ "gwang ye",
+ "il eohge",
+ "bal o",
+ "neun de",
+ "ham yeo",
+ "go s",
+ "geul eoh",
+ "an ila",
+ "bang beob",
+ "da si",
+ "b yeol",
+ "g yeon",
+ "gam jeong",
+ "on eul",
+ "j aneun",
+ "yeo m",
+ "l ago",
+ "i gi",
+ "hw an",
+ "t eul",
+ "eo seo",
+ "si k",
+ "ch o",
+ "jag a",
+ "geul eom",
+ "geul eona",
+ "jeong do",
+ "g yeog",
+ "geul eohge",
+ "geu deul",
+ "eu t",
+ "im yeon",
+ "j jae",
+ "k eun",
+ "i sang",
+ "mal haessda",
+ "eu ge",
+ "no p",
+ "in gan",
+ "bo myeon",
+ "t aeg",
+ "seu s",
+ "d wi",
+ "s aneun",
+ "w an",
+ "anh go",
+ "t an",
+ "nu gu",
+ "su ng",
+ "da myeon",
+ "a deul",
+ "p eul",
+ "ttal a",
+ "d i",
+ "geos do",
+ "a ji",
+ "m eon",
+ "eum yeo",
+ "dol og",
+ "neun g",
+ "mo du",
+ "क े",
+ "ह ै",
+ "े ं",
+ "् र",
+ "ा र",
+ "न े",
+ "य ा",
+ "म ें",
+ "स े",
+ "क ी",
+ "क ा",
+ "ो ं",
+ "त ा",
+ "क र",
+ "स ्",
+ "क ि",
+ "क ो",
+ "र ्",
+ "न ा",
+ "क ्",
+ "ह ी",
+ "औ र",
+ "प र",
+ "त े",
+ "ह ो",
+ "प ्र",
+ "ा न",
+ "् य",
+ "ल ा",
+ "व ा",
+ "ल े",
+ "स ा",
+ "है ं",
+ "ल ि",
+ "ज ा",
+ "ह ा",
+ "भ ी",
+ "व ि",
+ "इ स",
+ "त ी",
+ "न ्",
+ "र ा",
+ "म ा",
+ "द े",
+ "द ि",
+ "ब ा",
+ "त ि",
+ "थ ा",
+ "न ि",
+ "क ार",
+ "ए क",
+ "ही ं",
+ "ह ु",
+ "ं ग",
+ "ै ं",
+ "न ी",
+ "स ी",
+ "अ प",
+ "त ्",
+ "न हीं",
+ "र ी",
+ "म े",
+ "म ु",
+ "ि त",
+ "त ो",
+ "प ा",
+ "ल ी",
+ "लि ए",
+ "ग ा",
+ "ल ्",
+ "र ह",
+ "र े",
+ "क् ष",
+ "म ैं",
+ "स म",
+ "उ स",
+ "ज ि",
+ "त ्र",
+ "म ि",
+ "च ा",
+ "ो ग",
+ "स ं",
+ "द ्",
+ "स ि",
+ "आ प",
+ "त ु",
+ "द ा",
+ "क ु",
+ "य ों",
+ "व े",
+ "ज ी",
+ "् या",
+ "उ न",
+ "ि क",
+ "य े",
+ "भ ा",
+ "् ट",
+ "ह म",
+ "स् ट",
+ "श ा",
+ "ड ़",
+ "ं द",
+ "ख ा",
+ "म ्",
+ "श ्",
+ "य ह",
+ "स क",
+ "प ू",
+ "कि या",
+ "अप ने",
+ "र ू",
+ "स ु",
+ "म ी",
+ "ह ि",
+ "ज ो",
+ "थ े",
+ "र ि",
+ "द ी",
+ "थ ी",
+ "ग ी",
+ "ल ोग",
+ "ग या",
+ "त र",
+ "न् ह",
+ "च ्",
+ "व ार",
+ "ब ी",
+ "प ्",
+ "द ो",
+ "ट ी",
+ "श ि",
+ "कर ने",
+ "ग े",
+ "ै से",
+ "इ न",
+ "ं ड",
+ "सा थ",
+ "प ु",
+ "ब े",
+ "ब ार",
+ "व ी",
+ "अ न",
+ "ह र",
+ "उ न्ह",
+ "हो ता",
+ "ज ब",
+ "कु छ",
+ "म ान",
+ "क ्र",
+ "ब ि",
+ "प ह",
+ "फ ि",
+ "स र",
+ "ार ी",
+ "र ो",
+ "द ू",
+ "क हा",
+ "त क",
+ "श न",
+ "ब ्",
+ "स् थ",
+ "व ह",
+ "बा द",
+ "ओ ं",
+ "ग ु",
+ "ज ्",
+ "्र े",
+ "ग र",
+ "रह े",
+ "व र्",
+ "ह ू",
+ "ार ्",
+ "प ी",
+ "ब हु",
+ "मु झ",
+ "्र ा",
+ "दि या",
+ "स ब",
+ "कर ते",
+ "अप नी",
+ "बहु त",
+ "क ह",
+ "ट े",
+ "हु ए",
+ "कि सी",
+ "र हा",
+ "ष ्ट",
+ "ज ़",
+ "ब ना",
+ "स ो",
+ "ड ि",
+ "को ई",
+ "व ्य",
+ "बा त",
+ "र ु",
+ "व ो",
+ "मुझ े",
+ "द् ध",
+ "च ार",
+ "मे रे",
+ "व र",
+ "्र ी",
+ "जा ता",
+ "न ों",
+ "प्र ा",
+ "दे ख",
+ "ट ा",
+ "क् या",
+ "अ ध",
+ "ल ग",
+ "ल ो",
+ "प ि",
+ "य ु",
+ "च े",
+ "जि स",
+ "ं त",
+ "ान ी",
+ "प ै",
+ "ज न",
+ "ार े",
+ "च ी",
+ "मि ल",
+ "द ु",
+ "दे श",
+ "च् छ",
+ "ष ्",
+ "स ू",
+ "ख े",
+ "च ु",
+ "ि या",
+ "ल गा",
+ "ब ु",
+ "उन के",
+ "ज् ञ",
+ "क्ष ा",
+ "त रह",
+ "्या दा",
+ "वा ले",
+ "पू र्",
+ "मैं ने",
+ "का म",
+ "रू प",
+ "हो ती",
+ "उ प",
+ "ज ान",
+ "प्र कार",
+ "भ ार",
+ "म न",
+ "हु आ",
+ "ट र",
+ "हू ँ",
+ "पर ि",
+ "पा स",
+ "अन ु",
+ "रा ज",
+ "लोग ों",
+ "अ ब",
+ "सम झ",
+ "ड ी",
+ "म ौ",
+ "श ु",
+ "च ि",
+ "प े",
+ "क ृ",
+ "सक ते",
+ "म ह",
+ "य ोग",
+ "द र्",
+ "उ से",
+ "ं ध",
+ "ड ा",
+ "जा ए",
+ "ब ो",
+ "ू ल",
+ "म ो",
+ "ों ने",
+ "ं स",
+ "तु म",
+ "पह ले",
+ "ब ता",
+ "त था",
+ "य ो",
+ "ग ई",
+ "उ त्",
+ "सक ता",
+ "क म",
+ "ज ्यादा",
+ "र ख",
+ "सम य",
+ "ार ा",
+ "अ गर",
+ "स् त",
+ "च ल",
+ "फि र",
+ "वार ा",
+ "कर ना",
+ "श ी",
+ "ग ए",
+ "ब न",
+ "ौ र",
+ "हो ने",
+ "चा ह",
+ "ख ु",
+ "हा ँ",
+ "उन्ह ें",
+ "उन्ह ोंने",
+ "छ ो",
+ "म् ह",
+ "प्र ति",
+ "नि क",
+ "व न",
+ "्य ू",
+ "र ही",
+ "तु म्ह",
+ "ज ैसे",
+ "ि यों",
+ "क् यों",
+ "ल ों",
+ "फ ़",
+ "ं त्र",
+ "हो ते",
+ "क् ति",
+ "त ्य",
+ "कर ्",
+ "क ई",
+ "व ं",
+ "कि न",
+ "प ो",
+ "कार ण",
+ "ड़ ी",
+ "भ ि",
+ "इस के",
+ "ब र",
+ "उस के",
+ "द् वारा",
+ "श े",
+ "क ॉ",
+ "दि न",
+ "न् न",
+ "ड़ ा",
+ "स् व",
+ "नि र्",
+ "मु ख",
+ "लि या",
+ "ट ि",
+ "ज्ञ ान",
+ "क् त",
+ "द ्र",
+ "ग ्",
+ "क् स",
+ "म ै",
+ "ग ो",
+ "ज े",
+ "ट ्र",
+ "म ार",
+ "त् व",
+ "ध ार",
+ "भा व",
+ "कर ता",
+ "ख ि",
+ "क ं",
+ "चा हि",
+ "य र",
+ "प् त",
+ "क ों",
+ "ं च",
+ "ज ु",
+ "म त",
+ "अ च्छ",
+ "हु ई",
+ "क भी",
+ "ले किन",
+ "भ ू",
+ "अप ना",
+ "दू स",
+ "चाहि ए",
+ "य ू",
+ "घ र",
+ "सब से",
+ "मे री",
+ "ना म",
+ "ढ ़",
+ "ं ट",
+ "ें गे",
+ "ब ै",
+ "फ ा",
+ "ए वं",
+ "य ी",
+ "ग ्र",
+ "क्ष े",
+ "आ ज",
+ "आप को",
+ "भा ग",
+ "ठ ा",
+ "क ै",
+ "भार त",
+ "उन की",
+ "प हु",
+ "स भी",
+ "ध ा",
+ "ण ा",
+ "स ान",
+ "हो गा",
+ "त ब",
+ "स ंग",
+ "प र्",
+ "अ व",
+ "त ना",
+ "ग ि",
+ "य न",
+ "स् था",
+ "च ित",
+ "ट ्",
+ "छ ा",
+ "जा ने",
+ "क्षे त्र",
+ "वा ली",
+ "पूर् ण",
+ "स मा",
+ "कार ी"
+ ]
+ }
+}
\ No newline at end of file
diff --git a/comfy/text_encoders/ace_text_cleaners.py b/comfy/text_encoders/ace_text_cleaners.py
new file mode 100644
index 000000000..cd31d8d8c
--- /dev/null
+++ b/comfy/text_encoders/ace_text_cleaners.py
@@ -0,0 +1,395 @@
+# Basic text cleaners for the ACE Step model.
+# I didn't copy the ones from the reference code because I didn't want to deal with the dependencies.
+# TODO: support more languages than English?
+
+import re
+
+def japanese_to_romaji(japanese_text):
+ """
+ Convert Japanese hiragana and katakana to romaji (Latin alphabet representation).
+
+ Args:
+ japanese_text (str): Text containing hiragana and/or katakana characters
+
+ Returns:
+ str: The romaji (Latin alphabet) equivalent
+ """
+ # Dictionary mapping kana characters to their romaji equivalents
+ kana_map = {
+ # Katakana characters
+ 'ア': 'a', 'イ': 'i', 'ウ': 'u', 'エ': 'e', 'オ': 'o',
+ 'カ': 'ka', 'キ': 'ki', 'ク': 'ku', 'ケ': 'ke', 'コ': 'ko',
+ 'サ': 'sa', 'シ': 'shi', 'ス': 'su', 'セ': 'se', 'ソ': 'so',
+ 'タ': 'ta', 'チ': 'chi', 'ツ': 'tsu', 'テ': 'te', 'ト': 'to',
+ 'ナ': 'na', 'ニ': 'ni', 'ヌ': 'nu', 'ネ': 'ne', 'ノ': 'no',
+ 'ハ': 'ha', 'ヒ': 'hi', 'フ': 'fu', 'ヘ': 'he', 'ホ': 'ho',
+ 'マ': 'ma', 'ミ': 'mi', 'ム': 'mu', 'メ': 'me', 'モ': 'mo',
+ 'ヤ': 'ya', 'ユ': 'yu', 'ヨ': 'yo',
+ 'ラ': 'ra', 'リ': 'ri', 'ル': 'ru', 'レ': 're', 'ロ': 'ro',
+ 'ワ': 'wa', 'ヲ': 'wo', 'ン': 'n',
+
+ # Katakana voiced consonants
+ 'ガ': 'ga', 'ギ': 'gi', 'グ': 'gu', 'ゲ': 'ge', 'ゴ': 'go',
+ 'ザ': 'za', 'ジ': 'ji', 'ズ': 'zu', 'ゼ': 'ze', 'ゾ': 'zo',
+ 'ダ': 'da', 'ヂ': 'ji', 'ヅ': 'zu', 'デ': 'de', 'ド': 'do',
+ 'バ': 'ba', 'ビ': 'bi', 'ブ': 'bu', 'ベ': 'be', 'ボ': 'bo',
+ 'パ': 'pa', 'ピ': 'pi', 'プ': 'pu', 'ペ': 'pe', 'ポ': 'po',
+
+ # Katakana combinations
+ 'キャ': 'kya', 'キュ': 'kyu', 'キョ': 'kyo',
+ 'シャ': 'sha', 'シュ': 'shu', 'ショ': 'sho',
+ 'チャ': 'cha', 'チュ': 'chu', 'チョ': 'cho',
+ 'ニャ': 'nya', 'ニュ': 'nyu', 'ニョ': 'nyo',
+ 'ヒャ': 'hya', 'ヒュ': 'hyu', 'ヒョ': 'hyo',
+ 'ミャ': 'mya', 'ミュ': 'myu', 'ミョ': 'myo',
+ 'リャ': 'rya', 'リュ': 'ryu', 'リョ': 'ryo',
+ 'ギャ': 'gya', 'ギュ': 'gyu', 'ギョ': 'gyo',
+ 'ジャ': 'ja', 'ジュ': 'ju', 'ジョ': 'jo',
+ 'ビャ': 'bya', 'ビュ': 'byu', 'ビョ': 'byo',
+ 'ピャ': 'pya', 'ピュ': 'pyu', 'ピョ': 'pyo',
+
+ # Katakana small characters and special cases
+ 'ッ': '', # Small tsu (doubles the following consonant)
+ 'ャ': 'ya', 'ュ': 'yu', 'ョ': 'yo',
+
+ # Katakana extras
+ 'ヴ': 'vu', 'ファ': 'fa', 'フィ': 'fi', 'フェ': 'fe', 'フォ': 'fo',
+ 'ウィ': 'wi', 'ウェ': 'we', 'ウォ': 'wo',
+
+ # Hiragana characters
+ 'あ': 'a', 'い': 'i', 'う': 'u', 'え': 'e', 'お': 'o',
+ 'か': 'ka', 'き': 'ki', 'く': 'ku', 'け': 'ke', 'こ': 'ko',
+ 'さ': 'sa', 'し': 'shi', 'す': 'su', 'せ': 'se', 'そ': 'so',
+ 'た': 'ta', 'ち': 'chi', 'つ': 'tsu', 'て': 'te', 'と': 'to',
+ 'な': 'na', 'に': 'ni', 'ぬ': 'nu', 'ね': 'ne', 'の': 'no',
+ 'は': 'ha', 'ひ': 'hi', 'ふ': 'fu', 'へ': 'he', 'ほ': 'ho',
+ 'ま': 'ma', 'み': 'mi', 'む': 'mu', 'め': 'me', 'も': 'mo',
+ 'や': 'ya', 'ゆ': 'yu', 'よ': 'yo',
+ 'ら': 'ra', 'り': 'ri', 'る': 'ru', 'れ': 're', 'ろ': 'ro',
+ 'わ': 'wa', 'を': 'wo', 'ん': 'n',
+
+ # Hiragana voiced consonants
+ 'が': 'ga', 'ぎ': 'gi', 'ぐ': 'gu', 'げ': 'ge', 'ご': 'go',
+ 'ざ': 'za', 'じ': 'ji', 'ず': 'zu', 'ぜ': 'ze', 'ぞ': 'zo',
+ 'だ': 'da', 'ぢ': 'ji', 'づ': 'zu', 'で': 'de', 'ど': 'do',
+ 'ば': 'ba', 'び': 'bi', 'ぶ': 'bu', 'べ': 'be', 'ぼ': 'bo',
+ 'ぱ': 'pa', 'ぴ': 'pi', 'ぷ': 'pu', 'ぺ': 'pe', 'ぽ': 'po',
+
+ # Hiragana combinations
+ 'きゃ': 'kya', 'きゅ': 'kyu', 'きょ': 'kyo',
+ 'しゃ': 'sha', 'しゅ': 'shu', 'しょ': 'sho',
+ 'ちゃ': 'cha', 'ちゅ': 'chu', 'ちょ': 'cho',
+ 'にゃ': 'nya', 'にゅ': 'nyu', 'にょ': 'nyo',
+ 'ひゃ': 'hya', 'ひゅ': 'hyu', 'ひょ': 'hyo',
+ 'みゃ': 'mya', 'みゅ': 'myu', 'みょ': 'myo',
+ 'りゃ': 'rya', 'りゅ': 'ryu', 'りょ': 'ryo',
+ 'ぎゃ': 'gya', 'ぎゅ': 'gyu', 'ぎょ': 'gyo',
+ 'じゃ': 'ja', 'じゅ': 'ju', 'じょ': 'jo',
+ 'びゃ': 'bya', 'びゅ': 'byu', 'びょ': 'byo',
+ 'ぴゃ': 'pya', 'ぴゅ': 'pyu', 'ぴょ': 'pyo',
+
+ # Hiragana small characters and special cases
+ 'っ': '', # Small tsu (doubles the following consonant)
+ 'ゃ': 'ya', 'ゅ': 'yu', 'ょ': 'yo',
+
+ # Common punctuation and spaces
+ ' ': ' ', # Japanese space
+ '、': ', ', '。': '. ',
+ }
+
+ result = []
+ i = 0
+
+ while i < len(japanese_text):
+ # Check for small tsu (doubling the following consonant)
+        if i < len(japanese_text) - 1 and (japanese_text[i] == 'っ' or japanese_text[i] == 'ッ'):
+            if japanese_text[i+1] in kana_map:
+ next_romaji = kana_map[japanese_text[i+1]]
+ if next_romaji and next_romaji[0] not in 'aiueon':
+ result.append(next_romaji[0]) # Double the consonant
+ i += 1
+ continue
+
+        # Check for combinations with small ya, yu, yo (and small a/i/e/o, so map entries like 'ファ' and 'ウィ' are reachable)
+        if i < len(japanese_text) - 1 and japanese_text[i+1] in ('ゃ', 'ゅ', 'ょ', 'ャ', 'ュ', 'ョ', 'ァ', 'ィ', 'ェ', 'ォ'):
+ combo = japanese_text[i:i+2]
+ if combo in kana_map:
+ result.append(kana_map[combo])
+ i += 2
+ continue
+
+ # Regular character
+ if japanese_text[i] in kana_map:
+ result.append(kana_map[japanese_text[i]])
+ else:
+ # If it's not in our map, keep it as is (might be kanji, romaji, etc.)
+ result.append(japanese_text[i])
+
+ i += 1
+
+ return ''.join(result)
+
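+# A rough sanity check of the mapping above (hypothetical usage, not called anywhere in this module):
+#   japanese_to_romaji("こんにちは") -> "konnichiha"  (no particle handling: は always maps to "ha")
+#   japanese_to_romaji("きょうと")   -> "kyouto"      (small-kana combinations convert as one unit)
+#   japanese_to_romaji("がっこう")   -> "gakkou"      (small っ doubles the following consonant)
+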
+def number_to_text(num, ordinal=False):
+    """
+    Convert a number (int or float) to its English text representation.
+
+    Args:
+        num: The number to convert
+        ordinal: Accepted for compatibility but currently unused; cardinal
+            text is always returned.
+
+    Returns:
+        str: Text representation of the number
+    """
+
+ if not isinstance(num, (int, float)):
+ return "Input must be a number"
+
+ # Handle special case of zero
+ if num == 0:
+ return "zero"
+
+ # Handle negative numbers
+ negative = num < 0
+ num = abs(num)
+
+ # Handle floats
+ if isinstance(num, float):
+ # Split into integer and decimal parts
+ int_part = int(num)
+
+ # Convert both parts
+ int_text = _int_to_text(int_part)
+
+ # Handle decimal part (convert to string and remove '0.')
+ decimal_str = str(num).split('.')[1]
+ decimal_text = " point " + " ".join(_digit_to_text(int(digit)) for digit in decimal_str)
+
+ result = int_text + decimal_text
+ else:
+ # Handle integers
+ result = _int_to_text(num)
+
+ # Add 'negative' prefix for negative numbers
+ if negative:
+ result = "negative " + result
+
+ return result
+
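+# Rough expected outputs (a sketch of the current behavior):
+#   number_to_text(0)    -> "zero"
+#   number_to_text(123)  -> "one hundred twenty three"
+#   number_to_text(-3.5) -> "negative three point five"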
+
+def _int_to_text(num):
+ """Helper function to convert an integer to text"""
+
+ ones = ["", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine",
+ "ten", "eleven", "twelve", "thirteen", "fourteen", "fifteen", "sixteen",
+ "seventeen", "eighteen", "nineteen"]
+
+ tens = ["", "", "twenty", "thirty", "forty", "fifty", "sixty", "seventy", "eighty", "ninety"]
+
+ if num < 20:
+ return ones[num]
+
+ if num < 100:
+ return tens[num // 10] + (" " + ones[num % 10] if num % 10 != 0 else "")
+
+ if num < 1000:
+ return ones[num // 100] + " hundred" + (" " + _int_to_text(num % 100) if num % 100 != 0 else "")
+
+ if num < 1000000:
+ return _int_to_text(num // 1000) + " thousand" + (" " + _int_to_text(num % 1000) if num % 1000 != 0 else "")
+
+ if num < 1000000000:
+ return _int_to_text(num // 1000000) + " million" + (" " + _int_to_text(num % 1000000) if num % 1000000 != 0 else "")
+
+ return _int_to_text(num // 1000000000) + " billion" + (" " + _int_to_text(num % 1000000000) if num % 1000000000 != 0 else "")
+
+
+def _digit_to_text(digit):
+ """Convert a single digit to text"""
+ digits = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
+ return digits[digit]
+
+
+_whitespace_re = re.compile(r"\s+")
+
+
+# List of (regular expression, replacement) pairs for abbreviations:
+_abbreviations = {
+ "en": [
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
+ for x in [
+ ("mrs", "misess"),
+ ("mr", "mister"),
+ ("dr", "doctor"),
+ ("st", "saint"),
+ ("co", "company"),
+ ("jr", "junior"),
+ ("maj", "major"),
+ ("gen", "general"),
+ ("drs", "doctors"),
+ ("rev", "reverend"),
+ ("lt", "lieutenant"),
+ ("hon", "honorable"),
+ ("sgt", "sergeant"),
+ ("capt", "captain"),
+ ("esq", "esquire"),
+ ("ltd", "limited"),
+ ("col", "colonel"),
+ ("ft", "fort"),
+ ]
+ ],
+}
+
+
+def expand_abbreviations_multilingual(text, lang="en"):
+ for regex, replacement in _abbreviations[lang]:
+ text = re.sub(regex, replacement, text)
+ return text
+
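+# e.g. (hypothetical): expand_abbreviations_multilingual("dr. who met mr. smith")
+# returns "doctor who met mister smith" (the patterns match case-insensitively)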
+
+_symbols_multilingual = {
+ "en": [
+ (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1])
+ for x in [
+ ("&", " and "),
+ ("@", " at "),
+ ("%", " percent "),
+ ("#", " hash "),
+ ("$", " dollar "),
+ ("£", " pound "),
+ ("°", " degree "),
+ ]
+ ],
+}
+
+
+def expand_symbols_multilingual(text, lang="en"):
+ for regex, replacement in _symbols_multilingual[lang]:
+ text = re.sub(regex, replacement, text)
+        text = text.replace("  ", " ")  # Ensure there are no double spaces
+ return text.strip()
+
+
+_ordinal_re = {
+ "en": re.compile(r"([0-9]+)(st|nd|rd|th)"),
+}
+_number_re = re.compile(r"[0-9]+")
+_currency_re = {
+ "USD": re.compile(r"((\$[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+\$))"),
+ "GBP": re.compile(r"((£[0-9\.\,]*[0-9]+)|([0-9\.\,]*[0-9]+£))"),
+ "EUR": re.compile(r"(([0-9\.\,]*[0-9]+€)|((€[0-9\.\,]*[0-9]+)))"),
+}
+
+_comma_number_re = re.compile(r"\b\d{1,3}(,\d{3})*(\.\d+)?\b")
+_dot_number_re = re.compile(r"\b\d{1,3}(\.\d{3})*(\,\d+)?\b")
+_decimal_number_re = re.compile(r"([0-9]+[.,][0-9]+)")
+
+
+def _remove_commas(m):
+ text = m.group(0)
+ if "," in text:
+ text = text.replace(",", "")
+ return text
+
+
+def _remove_dots(m):
+ text = m.group(0)
+ if "." in text:
+ text = text.replace(".", "")
+ return text
+
+
+def _expand_decimal_point(m, lang="en"):
+ amount = m.group(1).replace(",", ".")
+ return number_to_text(float(amount))
+
+
+def _expand_currency(m, lang="en", currency="USD"):
+ amount = float((re.sub(r"[^\d.]", "", m.group(0).replace(",", "."))))
+ full_amount = number_to_text(amount)
+
+ and_equivalents = {
+ "en": ", ",
+ "es": " con ",
+ "fr": " et ",
+ "de": " und ",
+ "pt": " e ",
+ "it": " e ",
+ "pl": ", ",
+ "cs": ", ",
+ "ru": ", ",
+ "nl": ", ",
+ "ar": ", ",
+ "tr": ", ",
+ "hu": ", ",
+ "ko": ", ",
+ }
+
+ if amount.is_integer():
+ last_and = full_amount.rfind(and_equivalents[lang])
+ if last_and != -1:
+ full_amount = full_amount[:last_and]
+
+ return full_amount
+
+
+def _expand_ordinal(m, lang="en"):
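+    # NOTE: ordinal=True is currently a no-op in number_to_text, so "3rd" expands to "three"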
+ return number_to_text(int(m.group(1)), ordinal=True)
+
+
+def _expand_number(m, lang="en"):
+ return number_to_text(int(m.group(0)))
+
+
+def expand_numbers_multilingual(text, lang="en"):
+ if lang in ["en", "ru"]:
+ text = re.sub(_comma_number_re, _remove_commas, text)
+ else:
+ text = re.sub(_dot_number_re, _remove_dots, text)
+ try:
+ text = re.sub(_currency_re["GBP"], lambda m: _expand_currency(m, lang, "GBP"), text)
+ text = re.sub(_currency_re["USD"], lambda m: _expand_currency(m, lang, "USD"), text)
+ text = re.sub(_currency_re["EUR"], lambda m: _expand_currency(m, lang, "EUR"), text)
+    except Exception:  # currency expansion is best-effort
+        pass
+
+ text = re.sub(_decimal_number_re, lambda m: _expand_decimal_point(m, lang), text)
+ text = re.sub(_ordinal_re[lang], lambda m: _expand_ordinal(m, lang), text)
+ text = re.sub(_number_re, lambda m: _expand_number(m, lang), text)
+ return text
+
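+# e.g. (sketch, English): expand_numbers_multilingual("over 1,000 items") -> "over one thousand items"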
+
+def lowercase(text):
+ return text.lower()
+
+
+def collapse_whitespace(text):
+ return re.sub(_whitespace_re, " ", text)
+
+
+def multilingual_cleaners(text, lang):
+ text = text.replace('"', "")
+ if lang == "tr":
+ text = text.replace("İ", "i")
+ text = text.replace("Ö", "ö")
+ text = text.replace("Ü", "ü")
+ text = lowercase(text)
+ try:
+ text = expand_numbers_multilingual(text, lang)
+    except Exception:  # number expansion is best-effort
+        pass
+ try:
+ text = expand_abbreviations_multilingual(text, lang)
+    except Exception:  # abbreviation expansion is best-effort
+        pass
+ try:
+ text = expand_symbols_multilingual(text, lang=lang)
+    except Exception:  # symbol expansion is best-effort
+        pass
+ text = collapse_whitespace(text)
+ return text
+
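+# End-to-end sketch (hypothetical input/output):
+#   multilingual_cleaners('Dr. Who owns 2 cats.', "en") -> "doctor who owns two cats."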
+
+def basic_cleaners(text):
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
+ text = lowercase(text)
+ text = collapse_whitespace(text)
+ return text
diff --git a/comfy/text_encoders/aura_t5.py b/comfy/text_encoders/aura_t5.py
index e9ad45a7f..cf4252eea 100644
--- a/comfy/text_encoders/aura_t5.py
+++ b/comfy/text_encoders/aura_t5.py
@@ -11,7 +11,7 @@ class PT5XlModel(sd1_clip.SDClipModel):
class PT5XlTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_pile_tokenizer"), "tokenizer.model")
- super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1)
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='pile_t5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, pad_token=1, tokenizer_data=tokenizer_data)
class AuraT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
diff --git a/comfy/text_encoders/bert.py b/comfy/text_encoders/bert.py
index fc9bac1d2..551b03162 100644
--- a/comfy/text_encoders/bert.py
+++ b/comfy/text_encoders/bert.py
@@ -93,8 +93,11 @@ class BertEmbeddings(torch.nn.Module):
self.LayerNorm = operations.LayerNorm(embed_dim, eps=layer_norm_eps, dtype=dtype, device=device)
- def forward(self, input_tokens, token_type_ids=None, dtype=None):
- x = self.word_embeddings(input_tokens, out_dtype=dtype)
+ def forward(self, input_tokens, embeds=None, token_type_ids=None, dtype=None):
+ if embeds is not None:
+ x = embeds
+ else:
+ x = self.word_embeddings(input_tokens, out_dtype=dtype)
x += comfy.ops.cast_to_input(self.position_embeddings.weight[:x.shape[1]], x)
if token_type_ids is not None:
x += self.token_type_embeddings(token_type_ids, out_dtype=x.dtype)
@@ -113,12 +116,12 @@ class BertModel_(torch.nn.Module):
self.embeddings = BertEmbeddings(config_dict["vocab_size"], config_dict["max_position_embeddings"], config_dict["type_vocab_size"], config_dict["pad_token_id"], embed_dim, layer_norm_eps, dtype, device, operations)
self.encoder = BertEncoder(config_dict["num_hidden_layers"], embed_dim, config_dict["intermediate_size"], config_dict["num_attention_heads"], layer_norm_eps, dtype, device, operations)
- def forward(self, input_tokens, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
- x = self.embeddings(input_tokens, dtype=dtype)
+ def forward(self, input_tokens, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+ x = self.embeddings(input_tokens, embeds=embeds, dtype=dtype)
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
- mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+ mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
x, i = self.encoder(x, mask, intermediate_output)
return x, i
diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py
index 5441c8952..a1adb5242 100644
--- a/comfy/text_encoders/cosmos.py
+++ b/comfy/text_encoders/cosmos.py
@@ -22,7 +22,7 @@ class CosmosT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
- super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512)
+ super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=1024, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, tokenizer_data=tokenizer_data)
class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer):
diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py
index b945b1aaa..d61ef6668 100644
--- a/comfy/text_encoders/flux.py
+++ b/comfy/text_encoders/flux.py
@@ -9,19 +9,18 @@ import os
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
- super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+ super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
class FluxTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
- clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
- self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
- self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
+ self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
- def tokenize_with_weights(self, text:str, return_word_ids=False):
+ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
- out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
- out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+ out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+ out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
@@ -35,8 +34,7 @@ class FluxClipModel(torch.nn.Module):
def __init__(self, dtype_t5=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
- clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
- self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+ self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.t5xxl = comfy.text_encoders.sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options)
self.dtypes = set([dtype, dtype_t5])
diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py
index 45987a480..9dcf190a2 100644
--- a/comfy/text_encoders/genmo.py
+++ b/comfy/text_encoders/genmo.py
@@ -18,7 +18,7 @@ class MochiT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
- super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+ super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
class MochiT5Tokenizer(sd1_clip.SD1Tokenizer):
diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py
new file mode 100644
index 000000000..dbcf52784
--- /dev/null
+++ b/comfy/text_encoders/hidream.py
@@ -0,0 +1,155 @@
+from . import hunyuan_video
+from . import sd3_clip
+from comfy import sd1_clip
+from comfy import sdxl_clip
+import comfy.model_management
+import torch
+import logging
+
+
+class HiDreamTokenizer:
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data)
+ self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
+
+ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+ out = {}
+ out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+ out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+ t5xxl = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
+ out["t5xxl"] = [t5xxl[0]] # Use only first 128 tokens
+ out["llama"] = self.llama.tokenize_with_weights(text, return_word_ids, **kwargs)
+ return out
+
+ def untokenize(self, token_weight_pair):
+ return self.clip_g.untokenize(token_weight_pair)
+
+ def state_dict(self):
+ return {}
+
+
+class HiDreamTEModel(torch.nn.Module):
+ def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
+ super().__init__()
+ self.dtypes = set()
+ if clip_l:
+ self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)
+ self.dtypes.add(dtype)
+ else:
+ self.clip_l = None
+
+ if clip_g:
+ self.clip_g = sdxl_clip.SDXLClipG(device=device, dtype=dtype, model_options=model_options)
+ self.dtypes.add(dtype)
+ else:
+ self.clip_g = None
+
+ if t5:
+ dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
+ self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
+ self.dtypes.add(dtype_t5)
+ else:
+ self.t5xxl = None
+
+ if llama:
+ dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
+ if "vocab_size" not in model_options:
+ model_options["vocab_size"] = 128256
+ self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})
+ self.dtypes.add(dtype_llama)
+ else:
+ self.llama = None
+
+ logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))
+
+ def set_clip_options(self, options):
+ if self.clip_l is not None:
+ self.clip_l.set_clip_options(options)
+ if self.clip_g is not None:
+ self.clip_g.set_clip_options(options)
+ if self.t5xxl is not None:
+ self.t5xxl.set_clip_options(options)
+ if self.llama is not None:
+ self.llama.set_clip_options(options)
+
+ def reset_clip_options(self):
+ if self.clip_l is not None:
+ self.clip_l.reset_clip_options()
+ if self.clip_g is not None:
+ self.clip_g.reset_clip_options()
+ if self.t5xxl is not None:
+ self.t5xxl.reset_clip_options()
+ if self.llama is not None:
+ self.llama.reset_clip_options()
+
+ def encode_token_weights(self, token_weight_pairs):
+ token_weight_pairs_l = token_weight_pairs["l"]
+ token_weight_pairs_g = token_weight_pairs["g"]
+ token_weight_pairs_t5 = token_weight_pairs["t5xxl"]
+ token_weight_pairs_llama = token_weight_pairs["llama"]
+ lg_out = None
+ pooled = None
+ extra = {}
+
+ if len(token_weight_pairs_g) > 0 or len(token_weight_pairs_l) > 0:
+ if self.clip_l is not None:
+ lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
+ else:
+ l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
+
+ if self.clip_g is not None:
+ g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
+ else:
+ g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
+
+ pooled = torch.cat((l_pooled, g_pooled), dim=-1)
+
+ if self.t5xxl is not None:
+ t5_output = self.t5xxl.encode_token_weights(token_weight_pairs_t5)
+ t5_out, t5_pooled = t5_output[:2]
+ else:
+ t5_out = None
+
+ if self.llama is not None:
+ ll_output = self.llama.encode_token_weights(token_weight_pairs_llama)
+ ll_out, ll_pooled = ll_output[:2]
+ ll_out = ll_out[:, 1:]
+ else:
+ ll_out = None
+
+ if t5_out is None:
+ t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())
+
+ if ll_out is None:
+ ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
+
+ if pooled is None:
+ pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
+
+ extra["conditioning_llama3"] = ll_out
+ return t5_out, pooled, extra
+
+ def load_sd(self, sd):
+ if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
+ return self.clip_g.load_sd(sd)
+ elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
+ return self.clip_l.load_sd(sd)
+ elif "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
+ return self.t5xxl.load_sd(sd)
+ else:
+ return self.llama.load_sd(sd)
+
+
+def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
+ class HiDreamTEModel_(HiDreamTEModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
+ model_options = model_options.copy()
+ model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
+ if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options:
+ model_options = model_options.copy()
+ model_options["llama_scaled_fp8"] = llama_scaled_fp8
+ super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
+ return HiDreamTEModel_
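
`HiDreamTEModel.encode_token_weights` lets any of the four encoders be absent and substitutes fixed-shape zero tensors, so the pooled concatenation and the `conditioning_llama3` output keep their expected shapes regardless of which weights were loaded. A reduced sketch of that fallback, reusing the pooled sizes from the code above (CLIP-L 768, CLIP-G 1280):

```python
import torch

# Sketch of the zero-fallback pattern for optional encoders (not the full model).
def pooled_with_fallback(clip_l, clip_g, tokens_l, tokens_g, device="cpu"):
    if clip_l is not None:
        _, l_pooled = clip_l.encode_token_weights(tokens_l)
    else:
        l_pooled = torch.zeros((1, 768), device=device)   # missing encoder -> zeros
    if clip_g is not None:
        _, g_pooled = clip_g.encode_token_weights(tokens_g)
    else:
        g_pooled = torch.zeros((1, 1280), device=device)
    # Concatenation stays valid no matter which encoders were loaded.
    return torch.cat((l_pooled, g_pooled), dim=-1)
```
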
diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py
index 7149d6878..b02148b33 100644
--- a/comfy/text_encoders/hunyuan_video.py
+++ b/comfy/text_encoders/hunyuan_video.py
@@ -4,6 +4,7 @@ import comfy.text_encoders.llama
from transformers import LlamaTokenizerFast
import torch
import os
+import numbers
def llama_detect(state_dict, prefix=""):
@@ -20,33 +21,49 @@ def llama_detect(state_dict, prefix=""):
class LLAMA3Tokenizer(sd1_clip.SDTokenizer):
- def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256):
+ def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=256, pad_token=128258):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "llama_tokenizer")
- super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=128258, end_token=128009, min_length=min_length)
+ super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='llama', tokenizer_class=LlamaTokenizerFast, has_start_token=True, has_end_token=False, pad_to_max_length=False, max_length=99999999, pad_token=pad_token, min_length=min_length, tokenizer_data=tokenizer_data)
class LLAMAModel(sd1_clip.SDClipModel):
- def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
+ def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}):
llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None)
if llama_scaled_fp8 is not None:
model_options = model_options.copy()
model_options["scaled_fp8"] = llama_scaled_fp8
- super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 128000, "pad": 128258}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+ textmodel_json_config = {}
+ vocab_size = model_options.get("vocab_size", None)
+ if vocab_size is not None:
+ textmodel_json_config["vocab_size"] = vocab_size
+
+ model_options = {**model_options, "model_name": "llama"}
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens=special_tokens, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Llama2, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
class HunyuanVideoTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
- clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
- self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
- self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n""" # 95 tokens
- self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1)
+ self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.llama_template = """<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: 1. The main content and theme of the video.2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.4. background environment, light, style and atmosphere.5. camera angles, movements, and transitions used in the video:<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>""" # 95 tokens
+ self.llama = LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=1, tokenizer_data=tokenizer_data)
- def tokenize_with_weights(self, text:str, return_word_ids=False):
+ def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, image_embeds=None, image_interleave=1, **kwargs):
out = {}
- out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
+ out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
- llama_text = "{}{}".format(self.llama_template, text)
- out["llama"] = self.llama.tokenize_with_weights(llama_text, return_word_ids)
+ if llama_template is None:
+ llama_text = self.llama_template.format(text)
+ else:
+ llama_text = llama_template.format(text)
+ llama_text_tokens = self.llama.tokenize_with_weights(llama_text, return_word_ids, **kwargs)
+ embed_count = 0
+ for r in llama_text_tokens:
+ for i in range(len(r)):
+ if r[i][0] == 128257:
+ if image_embeds is not None and embed_count < image_embeds.shape[0]:
+ r[i] = ({"type": "embedding", "data": image_embeds[embed_count], "original_type": "image", "image_interleave": image_interleave},) + r[i][1:]
+ embed_count += 1
+ out["llama"] = llama_text_tokens
return out
def untokenize(self, token_weight_pair):
@@ -60,8 +77,7 @@ class HunyuanVideoClipModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
super().__init__()
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
- clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
- self.clip_l = clip_l_class(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
+ self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options)
self.llama = LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options)
self.dtypes = set([dtype, dtype_llama])
@@ -80,20 +96,51 @@ class HunyuanVideoClipModel(torch.nn.Module):
llama_out, llama_pooled, llama_extra_out = self.llama.encode_token_weights(token_weight_pairs_llama)
template_end = 0
- for i, v in enumerate(token_weight_pairs_llama[0]):
- if v[0] == 128007: # <|end_header_id|>
- template_end = i
+ extra_template_end = 0
+ extra_sizes = 0
+ user_end = 9999999999999
+ images = []
+
+ tok_pairs = token_weight_pairs_llama[0]
+ for i, v in enumerate(tok_pairs):
+ elem = v[0]
+ if not torch.is_tensor(elem):
+ if isinstance(elem, numbers.Integral):
+ if elem == 128006:
+ if tok_pairs[i + 1][0] == 882:
+ if tok_pairs[i + 2][0] == 128007:
+ template_end = i + 2
+ user_end = -1
+ if elem == 128009 and user_end == -1:
+ user_end = i + 1
+ else:
+ if elem.get("original_type") == "image":
+ elem_size = elem.get("data").shape[0]
+ if template_end > 0:
+ if user_end == -1:
+ extra_template_end += elem_size - 1
+ else:
+ image_start = i + extra_sizes
+ image_end = i + elem_size + extra_sizes
+ images.append((image_start, image_end, elem.get("image_interleave", 1)))
+ extra_sizes += elem_size - 1
if llama_out.shape[1] > (template_end + 2):
- if token_weight_pairs_llama[0][template_end + 1][0] == 271:
+ if tok_pairs[template_end + 1][0] == 271:
template_end += 2
- llama_out = llama_out[:, template_end:]
- llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end:]
+ llama_output = llama_out[:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
+ llama_extra_out["attention_mask"] = llama_extra_out["attention_mask"][:, template_end + extra_sizes:user_end + extra_sizes + extra_template_end]
if llama_extra_out["attention_mask"].sum() == torch.numel(llama_extra_out["attention_mask"]):
llama_extra_out.pop("attention_mask") # attention mask is useless if no masked elements
+ if len(images) > 0:
+ out = []
+ for i in images:
+ out.append(llama_out[:, i[0]: i[1]: i[2]])
+ llama_output = torch.cat(out + [llama_output], dim=1)
+
l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
- return llama_out, l_pooled, llama_extra_out
+ return llama_output, l_pooled, llama_extra_out
def load_sd(self, sd):
if "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
diff --git a/comfy/text_encoders/hydit.py b/comfy/text_encoders/hydit.py
index 7cb790f45..ac6994529 100644
--- a/comfy/text_encoders/hydit.py
+++ b/comfy/text_encoders/hydit.py
@@ -9,24 +9,26 @@ import torch
class HyditBertModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip.json")
+ model_options = {**model_options, "model_name": "hydit_clip"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "hydit_clip_tokenizer")
- super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77)
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1024, embedding_key='chinese_roberta', tokenizer_class=BertTokenizer, pad_to_max_length=False, max_length=512, min_length=77, tokenizer_data=tokenizer_data)
class MT5XLModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_config_xl.json")
+ model_options = {**model_options, "model_name": "mt5xl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
#tokenizer_path = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "mt5_tokenizer"), "spiece.model")
tokenizer = tokenizer_data.get("spiece_model", None)
- super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256)
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=2048, embedding_key='mt5xl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=256, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
@@ -35,12 +37,12 @@ class HyditTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
mt5_tokenizer_data = tokenizer_data.get("mt5xl.spiece_model", None)
self.hydit_clip = HyditBertTokenizer(embedding_directory=embedding_directory)
- self.mt5xl = MT5XLTokenizer(tokenizer_data={"spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
+ self.mt5xl = MT5XLTokenizer(tokenizer_data={**tokenizer_data, "spiece_model": mt5_tokenizer_data}, embedding_directory=embedding_directory)
- def tokenize_with_weights(self, text:str, return_word_ids=False):
+ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
- out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids)
- out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids)
+ out["hydit_clip"] = self.hydit_clip.tokenize_with_weights(text, return_word_ids, **kwargs)
+ out["mt5xl"] = self.mt5xl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py
index ad4b4623e..34eb870e3 100644
--- a/comfy/text_encoders/llama.py
+++ b/comfy/text_encoders/llama.py
@@ -1,6 +1,5 @@
import torch
import torch.nn as nn
-import torch.nn.functional as F
from dataclasses import dataclass
from typing import Optional, Any
@@ -21,15 +20,41 @@ class Llama2Config:
max_position_embeddings: int = 8192
rms_norm_eps: float = 1e-5
rope_theta: float = 500000.0
+ transformer_type: str = "llama"
+ head_dim = 128
+ rms_norm_add = False
+ mlp_activation = "silu"
+
+@dataclass
+class Gemma2_2B_Config:
+ vocab_size: int = 256000
+ hidden_size: int = 2304
+ intermediate_size: int = 9216
+ num_hidden_layers: int = 26
+ num_attention_heads: int = 8
+ num_key_value_heads: int = 4
+ max_position_embeddings: int = 8192
+ rms_norm_eps: float = 1e-6
+ rope_theta: float = 10000.0
+ transformer_type: str = "gemma2"
+ head_dim = 256
+ rms_norm_add = True
+ mlp_activation = "gelu_pytorch_tanh"
class RMSNorm(nn.Module):
- def __init__(self, dim: int, eps: float = 1e-5, device=None, dtype=None):
+ def __init__(self, dim: int, eps: float = 1e-5, add=False, device=None, dtype=None):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
+ self.add = add
def forward(self, x: torch.Tensor):
- return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)
+ w = self.weight
+ if self.add:
+ w = w + 1.0
+
+ return comfy.ldm.common_dit.rms_norm(x, w, self.eps)
+
def rotate_half(x):
@@ -68,13 +93,15 @@ class Attention(nn.Module):
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.hidden_size = config.hidden_size
- self.head_dim = self.hidden_size // self.num_heads
+
+ self.head_dim = config.head_dim
+ self.inner_size = self.num_heads * self.head_dim
ops = ops or nn
- self.q_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
+ self.q_proj = ops.Linear(config.hidden_size, self.inner_size, bias=False, device=device, dtype=dtype)
self.k_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
self.v_proj = ops.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False, device=device, dtype=dtype)
- self.o_proj = ops.Linear(config.hidden_size, config.hidden_size, bias=False, device=device, dtype=dtype)
+ self.o_proj = ops.Linear(self.inner_size, config.hidden_size, bias=False, device=device, dtype=dtype)
def forward(
self,
@@ -84,7 +111,6 @@ class Attention(nn.Module):
optimized_attention=None,
):
batch_size, seq_length, _ = hidden_states.shape
-
xq = self.q_proj(hidden_states)
xk = self.k_proj(hidden_states)
xv = self.v_proj(hidden_states)
@@ -108,9 +134,13 @@ class MLP(nn.Module):
self.gate_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.up_proj = ops.Linear(config.hidden_size, config.intermediate_size, bias=False, device=device, dtype=dtype)
self.down_proj = ops.Linear(config.intermediate_size, config.hidden_size, bias=False, device=device, dtype=dtype)
+ if config.mlp_activation == "silu":
+ self.activation = torch.nn.functional.silu
+ elif config.mlp_activation == "gelu_pytorch_tanh":
+ self.activation = lambda a: torch.nn.functional.gelu(a, approximate="tanh")
def forward(self, x):
- return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
+ return self.down_proj(self.activation(self.gate_proj(x)) * self.up_proj(x))
class TransformerBlock(nn.Module):
def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
@@ -146,6 +176,45 @@ class TransformerBlock(nn.Module):
return x
+class TransformerBlockGemma2(nn.Module):
+ def __init__(self, config: Llama2Config, device=None, dtype=None, ops: Any = None):
+ super().__init__()
+ self.self_attn = Attention(config, device=device, dtype=dtype, ops=ops)
+ self.mlp = MLP(config, device=device, dtype=dtype, ops=ops)
+ self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+ self.pre_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+ self.post_feedforward_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ freqs_cis: Optional[torch.Tensor] = None,
+ optimized_attention=None,
+ ):
+ # Self Attention
+ residual = x
+ x = self.input_layernorm(x)
+ x = self.self_attn(
+ hidden_states=x,
+ attention_mask=attention_mask,
+ freqs_cis=freqs_cis,
+ optimized_attention=optimized_attention,
+ )
+
+ x = self.post_attention_layernorm(x)
+ x = residual + x
+
+ # MLP
+ residual = x
+ x = self.pre_feedforward_layernorm(x)
+ x = self.mlp(x)
+ x = self.post_feedforward_layernorm(x)
+ x = residual + x
+
+ return x
+
class Llama2_(nn.Module):
def __init__(self, config, device=None, dtype=None, ops=None):
super().__init__()
@@ -158,17 +227,30 @@ class Llama2_(nn.Module):
device=device,
dtype=dtype
)
+ if self.config.transformer_type == "gemma2":
+ transformer = TransformerBlockGemma2
+ self.normalize_in = True
+ else:
+ transformer = TransformerBlock
+ self.normalize_in = False
+
self.layers = nn.ModuleList([
- TransformerBlock(config, device=device, dtype=dtype, ops=ops)
+ transformer(config, device=device, dtype=dtype, ops=ops)
for _ in range(config.num_hidden_layers)
])
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, device=device, dtype=dtype)
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype)
# self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype)
- def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
- x = self.embed_tokens(x, out_dtype=dtype)
+ def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None):
+ if embeds is not None:
+ x = embeds
+ else:
+ x = self.embed_tokens(x, out_dtype=dtype)
- freqs_cis = precompute_freqs_cis(self.config.hidden_size // self.config.num_attention_heads,
+ if self.normalize_in:
+ x *= self.config.hidden_size ** 0.5
+
+ freqs_cis = precompute_freqs_cis(self.config.head_dim,
x.shape[1],
self.config.rope_theta,
device=x.device)
@@ -186,11 +268,17 @@ class Llama2_(nn.Module):
optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
intermediate = None
+ all_intermediate = None
if intermediate_output is not None:
- if intermediate_output < 0:
+ if intermediate_output == "all":
+ all_intermediate = []
+ intermediate_output = None
+ elif intermediate_output < 0:
intermediate_output = len(self.layers) + intermediate_output
for i, layer in enumerate(self.layers):
+ if all_intermediate is not None:
+ all_intermediate.append(x.unsqueeze(1).clone())
x = layer(
x=x,
attention_mask=mask,
@@ -201,21 +289,18 @@ class Llama2_(nn.Module):
intermediate = x.clone()
x = self.norm(x)
+ if all_intermediate is not None:
+ all_intermediate.append(x.unsqueeze(1).clone())
+
+ if all_intermediate is not None:
+ intermediate = torch.cat(all_intermediate, dim=1)
+
if intermediate is not None and final_layer_norm_intermediate:
intermediate = self.norm(intermediate)
return x, intermediate
-
-class Llama2(torch.nn.Module):
- def __init__(self, config_dict, dtype, device, operations):
- super().__init__()
- config = Llama2Config(**config_dict)
- self.num_layers = config.num_hidden_layers
-
- self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
- self.dtype = dtype
-
+class BaseLlama:
def get_input_embeddings(self):
return self.model.embed_tokens
@@ -224,3 +309,23 @@ class Llama2(torch.nn.Module):
def forward(self, input_ids, *args, **kwargs):
return self.model(input_ids, *args, **kwargs)
+
+
+class Llama2(BaseLlama, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Llama2Config(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
+
+
+class Gemma2_2B(BaseLlama, torch.nn.Module):
+ def __init__(self, config_dict, dtype, device, operations):
+ super().__init__()
+ config = Gemma2_2B_Config(**config_dict)
+ self.num_layers = config.num_hidden_layers
+
+ self.model = Llama2_(config, device=device, dtype=dtype, ops=operations)
+ self.dtype = dtype
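
`Gemma2_2B_Config` encodes four departures from Llama that the code above implements: `(1 + weight)` RMSNorm scaling, pre/post-feedforward norms in each block, a tanh-approximated GELU MLP, and embeddings scaled by `hidden_size ** 0.5`. A small sketch of the norm difference, using a generic RMSNorm rather than the `comfy.ldm.common_dit` implementation:

```python
import torch

# Llama scales the normalized activations by w, Gemma by (1 + w), which is why
# freshly-initialized Gemma norm weights can sit near zero and still be identity.
def rms_norm(x, weight, eps, add_one=False):
    w = weight + 1.0 if add_one else weight
    normed = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return normed * w

x = torch.randn(2, 4, 8)
w = torch.zeros(8)
print(rms_norm(x, w, 1e-6, add_one=True).abs().mean())   # ~ normalized x
print(rms_norm(x, w, 1e-6, add_one=False).abs().mean())  # 0: scaled by zero
```
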
diff --git a/comfy/text_encoders/long_clipl.py b/comfy/text_encoders/long_clipl.py
index b81912cb3..8d4c7619d 100644
--- a/comfy/text_encoders/long_clipl.py
+++ b/comfy/text_encoders/long_clipl.py
@@ -1,30 +1,27 @@
-from comfy import sd1_clip
-import os
-class LongClipTokenizer_(sd1_clip.SDTokenizer):
- def __init__(self, embedding_directory=None, tokenizer_data={}):
- super().__init__(max_length=248, embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
-
-class LongClipModel_(sd1_clip.SDClipModel):
- def __init__(self, *args, **kwargs):
- textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "long_clipl.json")
- super().__init__(*args, textmodel_json_config=textmodel_json_config, **kwargs)
-
-class LongClipTokenizer(sd1_clip.SD1Tokenizer):
- def __init__(self, embedding_directory=None, tokenizer_data={}):
- super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, tokenizer=LongClipTokenizer_)
-
-class LongClipModel(sd1_clip.SD1ClipModel):
- def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
- super().__init__(device=device, dtype=dtype, model_options=model_options, clip_model=LongClipModel_, **kwargs)
def model_options_long_clip(sd, tokenizer_data, model_options):
w = sd.get("clip_l.text_model.embeddings.position_embedding.weight", None)
+    if w is None:
+        w = sd.get("clip_g.text_model.embeddings.position_embedding.weight", None)
+        if w is not None:
+            model_name = "clip_g"
+    else:
+        model_name = "clip_l"
+
if w is None:
w = sd.get("text_model.embeddings.position_embedding.weight", None)
- if w is not None and w.shape[0] == 248:
+        if w is not None:
+            if "text_model.encoder.layers.30.mlp.fc1.weight" in sd:
+                model_name = "clip_g"
+            elif "text_model.encoder.layers.1.mlp.fc1.weight" in sd:
+                model_name = "clip_l"
+            else:
+                model_name = "clip_l"
+
+ if w is not None:
tokenizer_data = tokenizer_data.copy()
model_options = model_options.copy()
- tokenizer_data["clip_l_tokenizer_class"] = LongClipTokenizer_
- model_options["clip_l_class"] = LongClipModel_
+ model_config = model_options.get("model_config", {})
+ model_config["max_position_embeddings"] = w.shape[0]
+ model_options["{}_model_config".format(model_name)] = model_config
+ tokenizer_data["{}_max_length".format(model_name)] = w.shape[0]
return tokenizer_data, model_options
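
Long-CLIP detection now reads the checkpoint's position-embedding length instead of special-casing 248, and writes per-model overrides into `tokenizer_data` and `model_options`. A hypothetical smoke test of the detection path (the weight shape here is illustrative):

```python
import torch
from comfy.text_encoders.long_clipl import model_options_long_clip

# Illustrative only: a long-CLIP-L checkpoint carries 248 position embeddings,
# so both the tokenizer max length and the model config should pick up 248.
sd = {"clip_l.text_model.embeddings.position_embedding.weight": torch.zeros(248, 768)}
tokenizer_data, model_options = model_options_long_clip(sd, {}, {})
print(tokenizer_data["clip_l_max_length"])                               # 248
print(model_options["clip_l_model_config"]["max_position_embeddings"])  # 248
```
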
diff --git a/comfy/text_encoders/lt.py b/comfy/text_encoders/lt.py
index 5c2ce583f..48ea67e67 100644
--- a/comfy/text_encoders/lt.py
+++ b/comfy/text_encoders/lt.py
@@ -6,7 +6,7 @@ import comfy.text_encoders.genmo
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
- super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128) #pad to 128?
+ super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data) #pad to 128?
class LTXVT5Tokenizer(sd1_clip.SD1Tokenizer):
diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py
new file mode 100644
index 000000000..674461b75
--- /dev/null
+++ b/comfy/text_encoders/lumina2.py
@@ -0,0 +1,39 @@
+from comfy import sd1_clip
+from .spiece_tokenizer import SPieceTokenizer
+import comfy.text_encoders.llama
+
+
+class Gemma2BTokenizer(sd1_clip.SDTokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ tokenizer = tokenizer_data.get("spiece_model", None)
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
+
+ def state_dict(self):
+ return {"spiece_model": self.tokenizer.serialize_model()}
+
+
+class LuminaTokenizer(sd1_clip.SD1Tokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer)
+
+
+class Gemma2_2BModel(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}):
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
+
+
+class LuminaModel(sd1_clip.SD1ClipModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ super().__init__(device=device, dtype=dtype, name="gemma2_2b", clip_model=Gemma2_2BModel, model_options=model_options)
+
+
+def te(dtype_llama=None, llama_scaled_fp8=None):
+ class LuminaTEModel_(LuminaModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ model_options = model_options.copy()
+ model_options["scaled_fp8"] = llama_scaled_fp8
+ if dtype_llama is not None:
+ dtype = dtype_llama
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return LuminaTEModel_
diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py
index d56d57f1b..b8de6bc4e 100644
--- a/comfy/text_encoders/pixart_t5.py
+++ b/comfy/text_encoders/pixart_t5.py
@@ -24,7 +24,7 @@ class PixArtT5XXL(sd1_clip.SD1ClipModel):
class T5XXLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
- super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding
+ super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) # no padding
class PixArtTokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
diff --git a/comfy/text_encoders/sa_t5.py b/comfy/text_encoders/sa_t5.py
index 7778ce47a..2803926ac 100644
--- a/comfy/text_encoders/sa_t5.py
+++ b/comfy/text_encoders/sa_t5.py
@@ -11,7 +11,7 @@ class T5BaseModel(sd1_clip.SDClipModel):
class T5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
- super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128)
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_size=768, embedding_key='t5base', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=128, tokenizer_data=tokenizer_data)
class SAT5Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
diff --git a/comfy/text_encoders/sd2_clip.py b/comfy/text_encoders/sd2_clip.py
index 31fc89869..700a23bf0 100644
--- a/comfy/text_encoders/sd2_clip.py
+++ b/comfy/text_encoders/sd2_clip.py
@@ -12,7 +12,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
class SD2ClipHTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_path=None, embedding_directory=None, tokenizer_data={}):
- super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024)
+ super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=1024, embedding_key='clip_h', tokenizer_data=tokenizer_data)
class SD2Tokenizer(sd1_clip.SD1Tokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py
index 00d7e31ad..ff5d412db 100644
--- a/comfy/text_encoders/sd3_clip.py
+++ b/comfy/text_encoders/sd3_clip.py
@@ -15,6 +15,7 @@ class T5XXLModel(sd1_clip.SDClipModel):
model_options = model_options.copy()
model_options["scaled_fp8"] = t5xxl_scaled_fp8
+ model_options = {**model_options, "model_name": "t5xxl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
@@ -31,23 +32,22 @@ def t5_xxl_detect(state_dict, prefix=""):
return out
class T5XXLTokenizer(sd1_clip.SDTokenizer):
- def __init__(self, embedding_directory=None, tokenizer_data={}):
+ def __init__(self, embedding_directory=None, tokenizer_data={}, min_length=77, max_length=99999999):
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_tokenizer")
- super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=77)
+ super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=max_length, min_length=min_length, tokenizer_data=tokenizer_data)
class SD3Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
- clip_l_tokenizer_class = tokenizer_data.get("clip_l_tokenizer_class", sd1_clip.SDTokenizer)
- self.clip_l = clip_l_tokenizer_class(embedding_directory=embedding_directory)
- self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory)
- self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory)
+ self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+ self.t5xxl = T5XXLTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
- def tokenize_with_weights(self, text:str, return_word_ids=False):
+ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
out = {}
- out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids)
- out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids)
- out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids)
+ out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
+ out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
+ out["t5xxl"] = self.t5xxl.tokenize_with_weights(text, return_word_ids, **kwargs)
return out
def untokenize(self, token_weight_pair):
@@ -61,8 +61,7 @@ class SD3ClipModel(torch.nn.Module):
super().__init__()
self.dtypes = set()
if clip_l:
- clip_l_class = model_options.get("clip_l_class", sd1_clip.SDClipModel)
- self.clip_l = clip_l_class(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
+ self.clip_l = sd1_clip.SDClipModel(layer="hidden", layer_idx=-2, device=device, dtype=dtype, layer_norm_hidden_state=False, return_projected_pooled=False, model_options=model_options)
self.dtypes.add(dtype)
else:
self.clip_l = None
diff --git a/comfy/text_encoders/spiece_tokenizer.py b/comfy/text_encoders/spiece_tokenizer.py
index cbaa99ba5..caccb3ca2 100644
--- a/comfy/text_encoders/spiece_tokenizer.py
+++ b/comfy/text_encoders/spiece_tokenizer.py
@@ -1,21 +1,24 @@
import torch
+import os
class SPieceTokenizer:
- add_eos = True
-
@staticmethod
- def from_pretrained(path):
- return SPieceTokenizer(path)
+ def from_pretrained(path, **kwargs):
+ return SPieceTokenizer(path, **kwargs)
- def __init__(self, tokenizer_path):
+ def __init__(self, tokenizer_path, add_bos=False, add_eos=True):
+ self.add_bos = add_bos
+ self.add_eos = add_eos
import sentencepiece
if torch.is_tensor(tokenizer_path):
tokenizer_path = tokenizer_path.numpy().tobytes()
if isinstance(tokenizer_path, bytes):
- self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_eos=self.add_eos)
+ self.tokenizer = sentencepiece.SentencePieceProcessor(model_proto=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
else:
- self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_eos=self.add_eos)
+ if not os.path.isfile(tokenizer_path):
+ raise ValueError("invalid tokenizer")
+ self.tokenizer = sentencepiece.SentencePieceProcessor(model_file=tokenizer_path, add_bos=self.add_bos, add_eos=self.add_eos)
def get_vocab(self):
out = {}
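
`SPieceTokenizer` now exposes `add_bos`/`add_eos`, which is what lets the Gemma tokenizer above request a BOS and no EOS while T5-style callers keep the old default. A usage sketch, where the model path is a placeholder:

```python
from comfy.text_encoders.spiece_tokenizer import SPieceTokenizer

# Placeholder path: Gemma wants BOS and no EOS, T5 keeps the add_eos=True default.
gemma_tok = SPieceTokenizer.from_pretrained("/path/to/spiece.model", add_bos=True, add_eos=False)
t5_tok = SPieceTokenizer.from_pretrained("/path/to/spiece.model")
```
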
diff --git a/comfy/text_encoders/t5.py b/comfy/text_encoders/t5.py
index 7405528e2..49f0ba4fe 100644
--- a/comfy/text_encoders/t5.py
+++ b/comfy/text_encoders/t5.py
@@ -203,7 +203,7 @@ class T5Stack(torch.nn.Module):
mask = None
if attention_mask is not None:
mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
- mask = mask.masked_fill(mask.to(torch.bool), float("-inf"))
+ mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)
intermediate = None
optimized_attention = optimized_attention_for_device(x.device, mask=attention_mask is not None, small_input=True)
@@ -239,8 +239,11 @@ class T5(torch.nn.Module):
def set_input_embeddings(self, embeddings):
self.shared = embeddings
- def forward(self, input_ids, *args, **kwargs):
- x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
+ def forward(self, input_ids, attention_mask, embeds=None, num_tokens=None, **kwargs):
+ if input_ids is None:
+ x = embeds
+ else:
+ x = self.shared(input_ids, out_dtype=kwargs.get("dtype", torch.float32))
if self.dtype not in [torch.float32, torch.float16, torch.bfloat16]:
x = torch.nan_to_num(x) #Fix for fp8 T5 base
- return self.encoder(x, *args, **kwargs)
+ return self.encoder(x, attention_mask=attention_mask, **kwargs)
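
Replacing `float("-inf")` with the largest finite negative value matters in reduced precision: a fully masked attention row softmaxes to NaN with `-inf`, but stays finite with `-finfo.max`. A two-line demonstration of the failure mode the change avoids:

```python
import torch

# A fully masked row: -inf turns the softmax into NaNs, while the largest
# finite negative value yields a harmless uniform distribution instead.
scores = torch.zeros(1, 4, dtype=torch.float16)
print(torch.softmax(scores + float("-inf"), dim=-1))                   # all nan
print(torch.softmax(scores - torch.finfo(torch.float16).max, dim=-1))  # 0.25 each
```
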
diff --git a/comfy/text_encoders/umt5_config_base.json b/comfy/text_encoders/umt5_config_base.json
new file mode 100644
index 000000000..6b3618f07
--- /dev/null
+++ b/comfy/text_encoders/umt5_config_base.json
@@ -0,0 +1,22 @@
+{
+ "d_ff": 2048,
+ "d_kv": 64,
+ "d_model": 768,
+ "decoder_start_token_id": 0,
+ "dropout_rate": 0.1,
+ "eos_token_id": 1,
+ "dense_act_fn": "gelu_pytorch_tanh",
+ "initializer_factor": 1.0,
+ "is_encoder_decoder": true,
+ "is_gated_act": true,
+ "layer_norm_epsilon": 1e-06,
+ "model_type": "umt5",
+ "num_decoder_layers": 12,
+ "num_heads": 12,
+ "num_layers": 12,
+ "output_past": true,
+ "pad_token_id": 0,
+ "relative_attention_num_buckets": 32,
+ "tie_word_embeddings": false,
+ "vocab_size": 256384
+}
diff --git a/comfy/text_encoders/umt5_config_xxl.json b/comfy/text_encoders/umt5_config_xxl.json
new file mode 100644
index 000000000..dfcb4b54b
--- /dev/null
+++ b/comfy/text_encoders/umt5_config_xxl.json
@@ -0,0 +1,22 @@
+{
+ "d_ff": 10240,
+ "d_kv": 64,
+ "d_model": 4096,
+ "decoder_start_token_id": 0,
+ "dropout_rate": 0.1,
+ "eos_token_id": 1,
+ "dense_act_fn": "gelu_pytorch_tanh",
+ "initializer_factor": 1.0,
+ "is_encoder_decoder": true,
+ "is_gated_act": true,
+ "layer_norm_epsilon": 1e-06,
+ "model_type": "umt5",
+ "num_decoder_layers": 24,
+ "num_heads": 64,
+ "num_layers": 24,
+ "output_past": true,
+ "pad_token_id": 0,
+ "relative_attention_num_buckets": 32,
+ "tie_word_embeddings": false,
+ "vocab_size": 256384
+}
diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py
new file mode 100644
index 000000000..d50fa4b28
--- /dev/null
+++ b/comfy/text_encoders/wan.py
@@ -0,0 +1,37 @@
+from comfy import sd1_clip
+from .spiece_tokenizer import SPieceTokenizer
+import comfy.text_encoders.t5
+import os
+
+class UMT5XXlModel(sd1_clip.SDClipModel):
+ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
+ textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_xxl.json")
+ super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
+
+class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ tokenizer = tokenizer_data.get("spiece_model", None)
+ super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data)
+
+ def state_dict(self):
+ return {"spiece_model": self.tokenizer.serialize_model()}
+
+
+class WanT5Tokenizer(sd1_clip.SD1Tokenizer):
+ def __init__(self, embedding_directory=None, tokenizer_data={}):
+ super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="umt5xxl", tokenizer=UMT5XXlTokenizer)
+
+class WanT5Model(sd1_clip.SD1ClipModel):
+ def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
+ super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs)
+
+def te(dtype_t5=None, t5xxl_scaled_fp8=None):
+ class WanTEModel(WanT5Model):
+ def __init__(self, device="cpu", dtype=None, model_options={}):
+ if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+ model_options = model_options.copy()
+ model_options["scaled_fp8"] = t5xxl_scaled_fp8
+ if dtype_t5 is not None:
+ dtype = dtype_t5
+ super().__init__(device=device, dtype=dtype, model_options=model_options)
+ return WanTEModel
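
`te()` follows the same factory pattern as `lumina2.te()` and `hidream_clip()` above: checkpoint detection picks the dtype and scaled-fp8 settings once, and the returned class bakes them in so the loader only supplies device and options. A sketch of the intended call sequence, with hypothetical detection results:

```python
import torch
import comfy.text_encoders.wan

# Hypothetical detection result: fp16 T5 weights, no scaled-fp8 tensor found.
WanTE = comfy.text_encoders.wan.te(dtype_t5=torch.float16, t5xxl_scaled_fp8=None)
# Later, the loader does: model = WanTE(device=load_device, model_options={})
# (not instantiated here, since that builds the full UMT5-XXL encoder).
```
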
diff --git a/comfy/utils.py b/comfy/utils.py
index b486b2deb..1f8d71292 100644
--- a/comfy/utils.py
+++ b/comfy/utils.py
@@ -28,23 +28,56 @@ import logging
import itertools
from torch.nn.functional import interpolate
from einops import rearrange
+from comfy.cli_args import args
-def load_torch_file(ckpt, safe_load=False, device=None):
+MMAP_TORCH_FILES = args.mmap_torch_files
+
+ALWAYS_SAFE_LOAD = False
+if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in pytorch 2.4, the unsafe path should be removed once earlier versions are deprecated
+ class ModelCheckpoint:
+ pass
+ ModelCheckpoint.__module__ = "pytorch_lightning.callbacks.model_checkpoint"
+
+ from numpy.core.multiarray import scalar
+ from numpy import dtype
+ from numpy.dtypes import Float64DType
+ from _codecs import encode
+
+ torch.serialization.add_safe_globals([ModelCheckpoint, scalar, dtype, Float64DType, encode])
+ ALWAYS_SAFE_LOAD = True
+ logging.info("Checkpoint files will always be loaded safely.")
+else:
+ logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
+
+def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
if device is None:
device = torch.device("cpu")
+ metadata = None
if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"):
- sd = safetensors.torch.load_file(ckpt, device=device.type)
+ try:
+ with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f:
+ sd = {}
+ for k in f.keys():
+ sd[k] = f.get_tensor(k)
+ if return_metadata:
+ metadata = f.metadata()
+ except Exception as e:
+ if len(e.args) > 0:
+ message = e.args[0]
+ if "HeaderTooLarge" in message:
+ raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt))
+ if "MetadataIncompleteBuffer" in message:
+ raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt))
+ raise e
else:
- if safe_load:
- if not 'weights_only' in torch.load.__code__.co_varnames:
- logging.warning("Warning torch.load doesn't support weights_only on this pytorch version, loading unsafely.")
- safe_load = False
- if safe_load:
- pl_sd = torch.load(ckpt, map_location=device, weights_only=True)
+ torch_args = {}
+ if MMAP_TORCH_FILES:
+ torch_args["mmap"] = True
+
+ if safe_load or ALWAYS_SAFE_LOAD:
+ pl_sd = torch.load(ckpt, map_location=device, weights_only=True, **torch_args)
else:
pl_sd = torch.load(ckpt, map_location=device, pickle_module=comfy.checkpoint_pickle)
- if "global_step" in pl_sd:
- logging.debug(f"Global Step: {pl_sd['global_step']}")
if "state_dict" in pl_sd:
sd = pl_sd["state_dict"]
else:
@@ -55,7 +88,7 @@ def load_torch_file(ckpt, safe_load=False, device=None):
sd = pl_sd
else:
sd = pl_sd
- return sd
+ return (sd, metadata) if return_metadata else sd
def save_torch_file(sd, ckpt, metadata=None):
if metadata is not None:
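
`load_torch_file` can now hand back the safetensors header metadata alongside the state dict, and the tuple return is gated on `return_metadata` so existing call sites are untouched. A usage sketch, with an illustrative file name:

```python
import comfy.utils

# metadata is the safetensors header dict, or None for .ckpt/.pt files.
sd, metadata = comfy.utils.load_torch_file("model.safetensors", return_metadata=True)
sd_only = comfy.utils.load_torch_file("model.safetensors")  # old callers unchanged
```
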
diff --git a/comfy/weight_adapter/__init__.py b/comfy/weight_adapter/__init__.py
new file mode 100644
index 000000000..d2a1d0151
--- /dev/null
+++ b/comfy/weight_adapter/__init__.py
@@ -0,0 +1,17 @@
+from .base import WeightAdapterBase
+from .lora import LoRAAdapter
+from .loha import LoHaAdapter
+from .lokr import LoKrAdapter
+from .glora import GLoRAAdapter
+from .oft import OFTAdapter
+from .boft import BOFTAdapter
+
+
+adapters: list[type[WeightAdapterBase]] = [
+ LoRAAdapter,
+ LoHaAdapter,
+ LoKrAdapter,
+ GLoRAAdapter,
+ OFTAdapter,
+ BOFTAdapter,
+]
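
The `adapters` list is ordered so callers can probe each format in turn: each `load()` returns `None` when a key prefix does not match that adapter's naming scheme. A sketch of the detection loop this registry implies (the helper name is hypothetical, not part of the module):

```python
from comfy.weight_adapter import adapters

# Hypothetical helper: try each adapter's load() until one recognizes the keys.
def detect_adapter(key_prefix, lora_sd, alpha=1.0, dora_scale=None):
    for adapter_cls in adapters:
        adapter = adapter_cls.load(key_prefix, lora_sd, alpha, dora_scale)
        if adapter is not None:
            return adapter
    return None
```
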
diff --git a/comfy/weight_adapter/base.py b/comfy/weight_adapter/base.py
new file mode 100644
index 000000000..29873519d
--- /dev/null
+++ b/comfy/weight_adapter/base.py
@@ -0,0 +1,104 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+import comfy.model_management
+
+
+class WeightAdapterBase:
+ name: str
+ loaded_keys: set[str]
+ weights: list[torch.Tensor]
+
+ @classmethod
+ def load(cls, x: str, lora: dict[str, torch.Tensor]) -> Optional["WeightAdapterBase"]:
+ raise NotImplementedError
+
+ def to_train(self) -> "WeightAdapterTrainBase":
+ raise NotImplementedError
+
+ def calculate_weight(
+ self,
+ weight,
+ key,
+ strength,
+ strength_model,
+ offset,
+ function,
+ intermediate_dtype=torch.float32,
+ original_weight=None,
+ ):
+ raise NotImplementedError
+
+
+class WeightAdapterTrainBase(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ # [TODO] Collaborate with LoRA training PR #7032
+
+
+def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
+ dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
+ lora_diff *= alpha
+ weight_calc = weight + function(lora_diff).type(weight.dtype)
+
+ wd_on_output_axis = dora_scale.shape[0] == weight_calc.shape[0]
+ if wd_on_output_axis:
+ weight_norm = (
+ weight.reshape(weight.shape[0], -1)
+ .norm(dim=1, keepdim=True)
+ .reshape(weight.shape[0], *[1] * (weight.dim() - 1))
+ )
+ else:
+ weight_norm = (
+ weight_calc.transpose(0, 1)
+ .reshape(weight_calc.shape[1], -1)
+ .norm(dim=1, keepdim=True)
+ .reshape(weight_calc.shape[1], *[1] * (weight_calc.dim() - 1))
+ .transpose(0, 1)
+ )
+ weight_norm = weight_norm + torch.finfo(weight.dtype).eps
+
+ weight_calc *= (dora_scale / weight_norm).type(weight.dtype)
+ if strength != 1.0:
+ weight_calc -= weight
+ weight += strength * (weight_calc)
+ else:
+ weight[:] = weight_calc
+ return weight
+
+
+def pad_tensor_to_shape(tensor: torch.Tensor, new_shape: list[int]) -> torch.Tensor:
+ """
+ Pad a tensor to a new shape with zeros.
+
+ Args:
+ tensor (torch.Tensor): The original tensor to be padded.
+ new_shape (List[int]): The desired shape of the padded tensor.
+
+ Returns:
+ torch.Tensor: A new tensor padded with zeros to the specified shape.
+
+    Note:
+        If the new shape is smaller than the original tensor in any
+        dimension, a ValueError is raised; the tensor is never truncated.
+ """
+ if any([new_shape[i] < tensor.shape[i] for i in range(len(new_shape))]):
+ raise ValueError("The new shape must be larger than the original tensor in all dimensions")
+
+ if len(new_shape) != len(tensor.shape):
+ raise ValueError("The new shape must have the same number of dimensions as the original tensor")
+
+ # Create a new tensor filled with zeros
+ padded_tensor = torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device)
+
+ # Create slicing tuples for both tensors
+ orig_slices = tuple(slice(0, dim) for dim in tensor.shape)
+ new_slices = tuple(slice(0, dim) for dim in tensor.shape)
+
+ # Copy the original tensor into the new tensor
+ padded_tensor[new_slices] = tensor[orig_slices]
+
+ return padded_tensor
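
`pad_tensor_to_shape` only ever grows a tensor, zero-filling the high side of each dimension. A quick check of that contract, assuming the function is imported from `comfy.weight_adapter.base`:

```python
import torch
from comfy.weight_adapter.base import pad_tensor_to_shape

t = torch.ones(2, 3)
padded = pad_tensor_to_shape(t, [4, 5])
print(padded.shape)          # torch.Size([4, 5])
print(padded[:2, :3].sum())  # tensor(6.): original block intact at the origin
print(padded.sum())          # tensor(6.): everything added is zero
```
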
diff --git a/comfy/weight_adapter/boft.py b/comfy/weight_adapter/boft.py
new file mode 100644
index 000000000..b2a2f1bd4
--- /dev/null
+++ b/comfy/weight_adapter/boft.py
@@ -0,0 +1,115 @@
+import logging
+from typing import Optional
+
+import torch
+import comfy.model_management
+from .base import WeightAdapterBase, weight_decompose
+
+
+class BOFTAdapter(WeightAdapterBase):
+ name = "boft"
+
+ def __init__(self, loaded_keys, weights):
+ self.loaded_keys = loaded_keys
+ self.weights = weights
+
+ @classmethod
+ def load(
+ cls,
+ x: str,
+ lora: dict[str, torch.Tensor],
+ alpha: float,
+ dora_scale: torch.Tensor,
+ loaded_keys: set[str] = None,
+ ) -> Optional["BOFTAdapter"]:
+ if loaded_keys is None:
+ loaded_keys = set()
+ blocks_name = "{}.oft_blocks".format(x)
+ rescale_name = "{}.rescale".format(x)
+
+ blocks = None
+ if blocks_name in lora.keys():
+ blocks = lora[blocks_name]
+ if blocks.ndim == 4:
+ loaded_keys.add(blocks_name)
+ else:
+ blocks = None
+ if blocks is None:
+ return None
+
+ rescale = None
+ if rescale_name in lora.keys():
+ rescale = lora[rescale_name]
+ loaded_keys.add(rescale_name)
+
+ weights = (blocks, rescale, alpha, dora_scale)
+ return cls(loaded_keys, weights)
+
+ def calculate_weight(
+ self,
+ weight,
+ key,
+ strength,
+ strength_model,
+ offset,
+ function,
+ intermediate_dtype=torch.float32,
+ original_weight=None,
+ ):
+ v = self.weights
+ blocks = v[0]
+ rescale = v[1]
+ alpha = v[2]
+ dora_scale = v[3]
+
+ blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
+ if rescale is not None:
+ rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
+
+ boft_m, block_num, boft_b, *_ = blocks.shape
+
+ try:
+ # Get r
+ I = torch.eye(boft_b, device=blocks.device, dtype=blocks.dtype)
+ # for Q = -Q^T
+ q = blocks - blocks.transpose(-1, -2)
+ normed_q = q
+ if alpha > 0: # alpha in boft/bboft is for constraint
+ q_norm = torch.norm(q) + 1e-8
+ if q_norm > alpha:
+ normed_q = q * alpha / q_norm
+ # use float() to prevent unsupported type in .inverse()
+ r = (I + normed_q) @ (I - normed_q).float().inverse()
+ r = r.to(weight)
+ inp = org = weight
+
+ r_b = boft_b//2
+ for i in range(boft_m):
+ bi = r[i]
+ g = 2
+ k = 2**i * r_b
+ if strength != 1:
+ bi = bi * strength + (1-strength) * I
+ inp = (
+ inp.unflatten(0, (-1, g, k))
+ .transpose(1, 2)
+ .flatten(0, 2)
+ .unflatten(0, (-1, boft_b))
+ )
+ inp = torch.einsum("b i j, b j ...-> b i ...", bi, inp)
+ inp = (
+ inp.flatten(0, 1).unflatten(0, (-1, k, g)).transpose(1, 2).flatten(0, 2)
+ )
+
+ if rescale is not None:
+ inp = inp * rescale
+
+ lora_diff = inp - org
+ lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
+ if dora_scale is not None:
+ weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
+ else:
+ weight += function((strength * lora_diff).type(weight.dtype))
+ except Exception as e:
+ logging.error("ERROR {} {} {}".format(self.name, key, e))
+ return weight
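
The core of the BOFT update above is a Cayley transform: for a skew-symmetric `Q = -Q^T`, the matrix `R = (I + Q)(I - Q)^{-1}` is orthogonal, so each block applies a pure rotation to its slice of the weight. A numerical sanity check of that identity:

```python
import torch

b = 4
blocks = torch.randn(b, b)
q = blocks - blocks.transpose(-1, -2)         # skew-symmetric part: q == -q.T
I = torch.eye(b)
r = (I + q) @ torch.linalg.inv(I - q)         # Cayley transform
print(torch.allclose(r @ r.T, I, atol=1e-4))  # True: R is orthogonal
```
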
diff --git a/comfy/weight_adapter/glora.py b/comfy/weight_adapter/glora.py
new file mode 100644
index 000000000..939abbba5
--- /dev/null
+++ b/comfy/weight_adapter/glora.py
@@ -0,0 +1,93 @@
+import logging
+from typing import Optional
+
+import torch
+import comfy.model_management
+from .base import WeightAdapterBase, weight_decompose
+
+
+class GLoRAAdapter(WeightAdapterBase):
+ name = "glora"
+
+ def __init__(self, loaded_keys, weights):
+ self.loaded_keys = loaded_keys
+ self.weights = weights
+
+ @classmethod
+ def load(
+ cls,
+ x: str,
+ lora: dict[str, torch.Tensor],
+ alpha: float,
+ dora_scale: torch.Tensor,
+ loaded_keys: set[str] = None,
+ ) -> Optional["GLoRAAdapter"]:
+ if loaded_keys is None:
+ loaded_keys = set()
+ a1_name = "{}.a1.weight".format(x)
+ a2_name = "{}.a2.weight".format(x)
+ b1_name = "{}.b1.weight".format(x)
+ b2_name = "{}.b2.weight".format(x)
+ if a1_name in lora:
+ weights = (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha, dora_scale)
+ loaded_keys.add(a1_name)
+ loaded_keys.add(a2_name)
+ loaded_keys.add(b1_name)
+ loaded_keys.add(b2_name)
+ return cls(loaded_keys, weights)
+ else:
+ return None
+
+ def calculate_weight(
+ self,
+ weight,
+ key,
+ strength,
+ strength_model,
+ offset,
+ function,
+ intermediate_dtype=torch.float32,
+ original_weight=None,
+ ):
+ v = self.weights
+ dora_scale = v[5]
+
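+ # Two GLoRA layouts exist: the old LyCORIS layout, where rank is read
+ # from a1 (v[0]), and the newer layout, where rank comes from a2 (v[1]).
+ # The shape checks below assume the tensors match one of the two layouts;
+ # otherwise rank would be left undefined.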
+ old_glora = False
+ if v[3].shape[1] == v[2].shape[0] == v[0].shape[0] == v[1].shape[1]:
+ rank = v[0].shape[0]
+ old_glora = True
+
+ if v[3].shape[0] == v[2].shape[1] == v[0].shape[1] == v[1].shape[0]:
+ if old_glora and v[1].shape[0] == weight.shape[0] and weight.shape[0] == weight.shape[1]:
+ pass
+ else:
+ old_glora = False
+ rank = v[1].shape[0]
+
+ a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
+ a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
+ b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
+ b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
+
+ if v[4] is not None:
+ alpha = v[4] / rank
+ else:
+ alpha = 1.0
+
+ try:
+ if old_glora:
+ lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape) #old lycoris glora
+ else:
+ if weight.dim() > 2:
+ lora_diff = torch.einsum("o i ..., i j -> o j ...", torch.einsum("o i ..., i j -> o j ...", weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
+ else:
+ lora_diff = torch.mm(torch.mm(weight.to(dtype=intermediate_dtype), a1), a2).reshape(weight.shape)
+ lora_diff += torch.mm(b1, b2).reshape(weight.shape)
+
+ if dora_scale is not None:
+ weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
+ else:
+ weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+ except Exception as e:
+ logging.error("ERROR {} {} {}".format(self.name, key, e))
+ return weight
diff --git a/comfy/weight_adapter/loha.py b/comfy/weight_adapter/loha.py
new file mode 100644
index 000000000..ce79abad5
--- /dev/null
+++ b/comfy/weight_adapter/loha.py
@@ -0,0 +1,100 @@
+import logging
+from typing import Optional
+
+import torch
+import comfy.model_management
+from .base import WeightAdapterBase, weight_decompose
+
+
+class LoHaAdapter(WeightAdapterBase):
+ name = "loha"
+
+ def __init__(self, loaded_keys, weights):
+ self.loaded_keys = loaded_keys
+ self.weights = weights
+
+ @classmethod
+ def load(
+ cls,
+ x: str,
+ lora: dict[str, torch.Tensor],
+ alpha: float,
+ dora_scale: torch.Tensor,
+ loaded_keys: set[str] = None,
+ ) -> Optional["LoHaAdapter"]:
+ if loaded_keys is None:
+ loaded_keys = set()
+
+ hada_w1_a_name = "{}.hada_w1_a".format(x)
+ hada_w1_b_name = "{}.hada_w1_b".format(x)
+ hada_w2_a_name = "{}.hada_w2_a".format(x)
+ hada_w2_b_name = "{}.hada_w2_b".format(x)
+ hada_t1_name = "{}.hada_t1".format(x)
+ hada_t2_name = "{}.hada_t2".format(x)
+ if hada_w1_a_name in lora.keys():
+ hada_t1 = None
+ hada_t2 = None
+ if hada_t1_name in lora.keys():
+ hada_t1 = lora[hada_t1_name]
+ hada_t2 = lora[hada_t2_name]
+ loaded_keys.add(hada_t1_name)
+ loaded_keys.add(hada_t2_name)
+
+ weights = (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2, dora_scale)
+ loaded_keys.add(hada_w1_a_name)
+ loaded_keys.add(hada_w1_b_name)
+ loaded_keys.add(hada_w2_a_name)
+ loaded_keys.add(hada_w2_b_name)
+ return cls(loaded_keys, weights)
+ else:
+ return None
+
+ def calculate_weight(
+ self,
+ weight,
+ key,
+ strength,
+ strength_model,
+ offset,
+ function,
+ intermediate_dtype=torch.float32,
+ original_weight=None,
+ ):
+ v = self.weights
+ w1a = v[0]
+ w1b = v[1]
+ if v[2] is not None:
+ alpha = v[2] / w1b.shape[0]
+ else:
+ alpha = 1.0
+
+ w2a = v[3]
+ w2b = v[4]
+ dora_scale = v[7]
+ if v[5] is not None: #cp decomposition
+ t1 = v[5]
+ t2 = v[6]
+ m1 = torch.einsum('i j k l, j r, i p -> p r k l',
+ comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
+ comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
+ comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
+
+ m2 = torch.einsum('i j k l, j r, i p -> p r k l',
+ comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+ comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
+ comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
+ else:
+ m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
+ comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
+ m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
+ comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
+
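+ # LoHa diff: delta_W = (w1a @ w1b) * (w2a @ w2b), i.e. the Hadamard
+ # product of two low-rank factors (or their CP-decomposed conv
+ # equivalents computed above).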
+ try:
+ lora_diff = (m1 * m2).reshape(weight.shape)
+ if dora_scale is not None:
+ weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
+ else:
+ weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+ except Exception as e:
+ logging.error("ERROR {} {} {}".format(self.name, key, e))
+ return weight
diff --git a/comfy/weight_adapter/lokr.py b/comfy/weight_adapter/lokr.py
new file mode 100644
index 000000000..51233db2d
--- /dev/null
+++ b/comfy/weight_adapter/lokr.py
@@ -0,0 +1,133 @@
+import logging
+from typing import Optional
+
+import torch
+import comfy.model_management
+from .base import WeightAdapterBase, weight_decompose
+
+
+class LoKrAdapter(WeightAdapterBase):
+ name = "lokr"
+
+ def __init__(self, loaded_keys, weights):
+ self.loaded_keys = loaded_keys
+ self.weights = weights
+
+ @classmethod
+ def load(
+ cls,
+ x: str,
+ lora: dict[str, torch.Tensor],
+ alpha: float,
+ dora_scale: torch.Tensor,
+ loaded_keys: set[str] = None,
+ ) -> Optional["LoKrAdapter"]:
+ if loaded_keys is None:
+ loaded_keys = set()
+ lokr_w1_name = "{}.lokr_w1".format(x)
+ lokr_w2_name = "{}.lokr_w2".format(x)
+ lokr_w1_a_name = "{}.lokr_w1_a".format(x)
+ lokr_w1_b_name = "{}.lokr_w1_b".format(x)
+ lokr_t2_name = "{}.lokr_t2".format(x)
+ lokr_w2_a_name = "{}.lokr_w2_a".format(x)
+ lokr_w2_b_name = "{}.lokr_w2_b".format(x)
+
+ lokr_w1 = None
+ if lokr_w1_name in lora.keys():
+ lokr_w1 = lora[lokr_w1_name]
+ loaded_keys.add(lokr_w1_name)
+
+ lokr_w2 = None
+ if lokr_w2_name in lora.keys():
+ lokr_w2 = lora[lokr_w2_name]
+ loaded_keys.add(lokr_w2_name)
+
+ lokr_w1_a = None
+ if lokr_w1_a_name in lora.keys():
+ lokr_w1_a = lora[lokr_w1_a_name]
+ loaded_keys.add(lokr_w1_a_name)
+
+ lokr_w1_b = None
+ if lokr_w1_b_name in lora.keys():
+ lokr_w1_b = lora[lokr_w1_b_name]
+ loaded_keys.add(lokr_w1_b_name)
+
+ lokr_w2_a = None
+ if lokr_w2_a_name in lora.keys():
+ lokr_w2_a = lora[lokr_w2_a_name]
+ loaded_keys.add(lokr_w2_a_name)
+
+ lokr_w2_b = None
+ if lokr_w2_b_name in lora.keys():
+ lokr_w2_b = lora[lokr_w2_b_name]
+ loaded_keys.add(lokr_w2_b_name)
+
+ lokr_t2 = None
+ if lokr_t2_name in lora.keys():
+ lokr_t2 = lora[lokr_t2_name]
+ loaded_keys.add(lokr_t2_name)
+
+ if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
+ weights = (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2, dora_scale)
+ return cls(loaded_keys, weights)
+ else:
+ return None
+
+ def calculate_weight(
+ self,
+ weight,
+ key,
+ strength,
+ strength_model,
+ offset,
+ function,
+ intermediate_dtype=torch.float32,
+ original_weight=None,
+ ):
+ v = self.weights
+ w1 = v[0]
+ w2 = v[1]
+ w1_a = v[3]
+ w1_b = v[4]
+ w2_a = v[5]
+ w2_b = v[6]
+ t2 = v[7]
+ dora_scale = v[8]
+ dim = None
+
+ if w1 is None:
+ dim = w1_b.shape[0]
+ w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
+ comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
+ else:
+ w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
+
+ if w2 is None:
+ dim = w2_b.shape[0]
+ if t2 is None:
+ w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
+ comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
+ else:
+ w2 = torch.einsum('i j k l, j r, i p -> p r k l',
+ comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
+ comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
+ comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
+ else:
+ w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
+
+ if len(w2.shape) == 4:
+ w1 = w1.unsqueeze(2).unsqueeze(2)
+ if v[2] is not None and dim is not None:
+ alpha = v[2] / dim
+ else:
+ alpha = 1.0
+
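+ # LoKr diff: delta_W = kron(w1, w2), where each Kronecker factor may
+ # itself be stored low-rank (w1_a @ w1_b, w2_a @ w2_b) as handled above.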
+ try:
+ lora_diff = torch.kron(w1, w2).reshape(weight.shape)
+ if dora_scale is not None:
+ weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
+ else:
+ weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+ except Exception as e:
+ logging.error("ERROR {} {} {}".format(self.name, key, e))
+ return weight
diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py
new file mode 100644
index 000000000..b2e623924
--- /dev/null
+++ b/comfy/weight_adapter/lora.py
@@ -0,0 +1,142 @@
+import logging
+from typing import Optional
+
+import torch
+import comfy.model_management
+from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
+
+
+class LoRAAdapter(WeightAdapterBase):
+ name = "lora"
+
+ def __init__(self, loaded_keys, weights):
+ self.loaded_keys = loaded_keys
+ self.weights = weights
+
+ @classmethod
+ def load(
+ cls,
+ x: str,
+ lora: dict[str, torch.Tensor],
+ alpha: float,
+ dora_scale: torch.Tensor,
+ loaded_keys: set[str] = None,
+ ) -> Optional["LoRAAdapter"]:
+ if loaded_keys is None:
+ loaded_keys = set()
+
+ reshape_name = "{}.reshape_weight".format(x)
+ regular_lora = "{}.lora_up.weight".format(x)
+ diffusers_lora = "{}_lora.up.weight".format(x)
+ diffusers2_lora = "{}.lora_B.weight".format(x)
+ diffusers3_lora = "{}.lora.up.weight".format(x)
+ mochi_lora = "{}.lora_B".format(x)
+ transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
+ A_name = None
+
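+ # Note: in this file A_name is the up/output projection and B_name the
+ # down/input projection; the branches below match the kohya (lora_up/
+ # lora_down), diffusers, PEFT (lora_A/lora_B), mochi, and transformers
+ # naming schemes.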
+ if regular_lora in lora.keys():
+ A_name = regular_lora
+ B_name = "{}.lora_down.weight".format(x)
+ mid_name = "{}.lora_mid.weight".format(x)
+ elif diffusers_lora in lora.keys():
+ A_name = diffusers_lora
+ B_name = "{}_lora.down.weight".format(x)
+ mid_name = None
+ elif diffusers2_lora in lora.keys():
+ A_name = diffusers2_lora
+ B_name = "{}.lora_A.weight".format(x)
+ mid_name = None
+ elif diffusers3_lora in lora.keys():
+ A_name = diffusers3_lora
+ B_name = "{}.lora.down.weight".format(x)
+ mid_name = None
+ elif mochi_lora in lora.keys():
+ A_name = mochi_lora
+ B_name = "{}.lora_A".format(x)
+ mid_name = None
+ elif transformers_lora in lora.keys():
+ A_name = transformers_lora
+ B_name = "{}.lora_linear_layer.down.weight".format(x)
+ mid_name = None
+
+ if A_name is not None:
+ mid = None
+ if mid_name is not None and mid_name in lora.keys():
+ mid = lora[mid_name]
+ loaded_keys.add(mid_name)
+ reshape = None
+ if reshape_name in lora.keys():
+ try:
+ reshape = lora[reshape_name].tolist()
+ loaded_keys.add(reshape_name)
+ except Exception:
+ pass
+ weights = (lora[A_name], lora[B_name], alpha, mid, dora_scale, reshape)
+ loaded_keys.add(A_name)
+ loaded_keys.add(B_name)
+ return cls(loaded_keys, weights)
+ else:
+ return None
+
+ def calculate_weight(
+ self,
+ weight,
+ key,
+ strength,
+ strength_model,
+ offset,
+ function,
+ intermediate_dtype=torch.float32,
+ original_weight=None,
+ ):
+ v = self.weights
+ mat1 = comfy.model_management.cast_to_device(
+ v[0], weight.device, intermediate_dtype
+ )
+ mat2 = comfy.model_management.cast_to_device(
+ v[1], weight.device, intermediate_dtype
+ )
+ dora_scale = v[4]
+ reshape = v[5]
+
+ if reshape is not None:
+ weight = pad_tensor_to_shape(weight, reshape)
+
+ if v[2] is not None:
+ alpha = v[2] / mat2.shape[0]
+ else:
+ alpha = 1.0
+
+ if v[3] is not None:
+ # locon mid weights, hopefully the math is fine because I didn't properly test it
+ mat3 = comfy.model_management.cast_to_device(
+ v[3], weight.device, intermediate_dtype
+ )
+ final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
+ mat2 = (
+ torch.mm(
+ mat2.transpose(0, 1).flatten(start_dim=1),
+ mat3.transpose(0, 1).flatten(start_dim=1),
+ )
+ .reshape(final_shape)
+ .transpose(0, 1)
+ )
+ try:
+ lora_diff = torch.mm(
+ mat1.flatten(start_dim=1), mat2.flatten(start_dim=1)
+ ).reshape(weight.shape)
+ if dora_scale is not None:
+ weight = weight_decompose(
+ dora_scale,
+ weight,
+ lora_diff,
+ alpha,
+ strength,
+ intermediate_dtype,
+ function,
+ )
+ else:
+ weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
+ except Exception as e:
+ logging.error("ERROR {} {} {}".format(self.name, key, e))
+ return weight
diff --git a/comfy/weight_adapter/oft.py b/comfy/weight_adapter/oft.py
new file mode 100644
index 000000000..25009eca3
--- /dev/null
+++ b/comfy/weight_adapter/oft.py
@@ -0,0 +1,96 @@
+import logging
+from typing import Optional
+
+import torch
+import comfy.model_management
+from .base import WeightAdapterBase, weight_decompose
+
+
+class OFTAdapter(WeightAdapterBase):
+ name = "oft"
+
+ def __init__(self, loaded_keys, weights):
+ self.loaded_keys = loaded_keys
+ self.weights = weights
+
+ @classmethod
+ def load(
+ cls,
+ x: str,
+ lora: dict[str, torch.Tensor],
+ alpha: float,
+ dora_scale: torch.Tensor,
+ loaded_keys: set[str] = None,
+ ) -> Optional["OFTAdapter"]:
+ if loaded_keys is None:
+ loaded_keys = set()
+ blocks_name = "{}.oft_blocks".format(x)
+ rescale_name = "{}.rescale".format(x)
+
+ blocks = None
+ if blocks_name in lora.keys():
+ blocks = lora[blocks_name]
+ if blocks.ndim == 3:
+ loaded_keys.add(blocks_name)
+ else:
+ blocks = None
+ if blocks is None:
+ return None
+
+ rescale = None
+ if rescale_name in lora.keys():
+ rescale = lora[rescale_name]
+ loaded_keys.add(rescale_name)
+
+ weights = (blocks, rescale, alpha, dora_scale)
+ return cls(loaded_keys, weights)
+
+ def calculate_weight(
+ self,
+ weight,
+ key,
+ strength,
+ strength_model,
+ offset,
+ function,
+ intermediate_dtype=torch.float32,
+ original_weight=None,
+ ):
+ v = self.weights
+ blocks = v[0]
+ rescale = v[1]
+ alpha = v[2]
+ dora_scale = v[3]
+
+ blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
+ if rescale is not None:
+ rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
+
+ block_num, block_size, *_ = blocks.shape
+
+ try:
+ # Get r
+ I = torch.eye(block_size, device=blocks.device, dtype=blocks.dtype)
+ # for Q = -Q^T
+ q = blocks - blocks.transpose(1, 2)
+ normed_q = q
+ if alpha > 0: # alpha in oft/boft is for constraint
+ q_norm = torch.norm(q) + 1e-8
+ if q_norm > alpha:
+ normed_q = q * alpha / q_norm
+ # use float() to prevent unsupported type in .inverse()
+ r = (I + normed_q) @ (I - normed_q).float().inverse()
+ r = r.to(weight)
+ _, *shape = weight.shape
+ lora_diff = torch.einsum(
+ "k n m, k n ... -> k m ...",
+ (r * strength) - strength * I,
+ weight.view(block_num, block_size, *shape),
+ ).view(-1, *shape)
+ if dora_scale is not None:
+ weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
+ else:
+ weight += function((strength * lora_diff).type(weight.dtype))
+ except Exception as e:
+ logging.error("ERROR {} {} {}".format(self.name, key, e))
+ return weight
diff --git a/comfy_api/input/__init__.py b/comfy_api/input/__init__.py
new file mode 100644
index 000000000..66667946f
--- /dev/null
+++ b/comfy_api/input/__init__.py
@@ -0,0 +1,8 @@
+from .basic_types import ImageInput, AudioInput
+from .video_types import VideoInput
+
+__all__ = [
+ "ImageInput",
+ "AudioInput",
+ "VideoInput",
+]
diff --git a/comfy_api/input/basic_types.py b/comfy_api/input/basic_types.py
new file mode 100644
index 000000000..033fb7e27
--- /dev/null
+++ b/comfy_api/input/basic_types.py
@@ -0,0 +1,20 @@
+import torch
+from typing import TypedDict
+
+ImageInput = torch.Tensor
+"""
+An image in format [B, H, W, C] where B is the batch size, H is the height, W is the width, and C is the number of channels.
+"""
+
+class AudioInput(TypedDict):
+ """
+ TypedDict representing audio input.
+ """
+
+ waveform: torch.Tensor
+ """
+ Tensor in the format [B, C, T] where B is the batch size, C is the number of channels, and T is the number of samples.
+ """
+
+ sample_rate: int
+
diff --git a/comfy_api/input/video_types.py b/comfy_api/input/video_types.py
new file mode 100644
index 000000000..dc22d34ff
--- /dev/null
+++ b/comfy_api/input/video_types.py
@@ -0,0 +1,55 @@
+from __future__ import annotations
+from abc import ABC, abstractmethod
+from typing import Optional
+from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
+
+class VideoInput(ABC):
+ """
+ Abstract base class for video input types.
+ """
+
+ @abstractmethod
+ def get_components(self) -> VideoComponents:
+ """
+ Abstract method to get the video components (images, audio, and frame rate).
+
+ Returns:
+ VideoComponents containing images, audio, and frame rate
+ """
+ pass
+
+ @abstractmethod
+ def save_to(
+ self,
+ path: str,
+ format: VideoContainer = VideoContainer.AUTO,
+ codec: VideoCodec = VideoCodec.AUTO,
+ metadata: Optional[dict] = None
+ ):
+ """
+ Abstract method to save the video input to a file.
+ """
+ pass
+
+ # Provide a default implementation, but subclasses can provide optimized versions
+ # if possible.
+ def get_dimensions(self) -> tuple[int, int]:
+ """
+ Returns the dimensions of the video input.
+
+ Returns:
+ Tuple of (width, height)
+ """
+ components = self.get_components()
+ return components.images.shape[2], components.images.shape[1]
+
+ def get_duration(self) -> float:
+ """
+ Returns the duration of the video in seconds.
+
+ Returns:
+ Duration in seconds
+ """
+ components = self.get_components()
+ frame_count = components.images.shape[0]
+ return float(frame_count / components.frame_rate)
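+
+# Example (sketch): for a 48-frame clip with frame_rate Fraction(24, 1),
+# get_duration() returns 48 / 24 = 2.0 seconds.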
diff --git a/comfy_api/input_impl/__init__.py b/comfy_api/input_impl/__init__.py
new file mode 100644
index 000000000..02901b8b9
--- /dev/null
+++ b/comfy_api/input_impl/__init__.py
@@ -0,0 +1,7 @@
+from .video_types import VideoFromFile, VideoFromComponents
+
+__all__ = [
+ # Implementations
+ "VideoFromFile",
+ "VideoFromComponents",
+]
diff --git a/comfy_api/input_impl/video_types.py b/comfy_api/input_impl/video_types.py
new file mode 100644
index 000000000..197f6558c
--- /dev/null
+++ b/comfy_api/input_impl/video_types.py
@@ -0,0 +1,303 @@
+from __future__ import annotations
+from av.container import InputContainer
+from av.subtitles.stream import SubtitleStream
+from fractions import Fraction
+from typing import Optional
+from comfy_api.input import AudioInput
+import av
+import io
+import json
+import numpy as np
+import torch
+from comfy_api.input import VideoInput
+from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
+
+
+def container_to_output_format(container_format: str | None) -> str | None:
+ """
+ A container's `format` may be a comma-separated list of formats.
+ E.g., iso container's `format` may be `mov,mp4,m4a,3gp,3g2,mj2`.
+ However, writing to a file/stream with `av.open` requires a single format,
+ or `None` to auto-detect.
+ """
+ if not container_format:
+ return None # Auto-detect
+
+ if "," not in container_format:
+ return container_format
+
+ formats = container_format.split(",")
+ return formats[0]
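+
+# Example: the iso container's format string is "mov,mp4,m4a,3gp,3g2,mj2";
+# container_to_output_format("mov,mp4,m4a,3gp,3g2,mj2") -> "mov", and
+# container_to_output_format(None) -> None (auto-detect).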
+
+
+def get_open_write_kwargs(
+ dest: str | io.BytesIO, container_format: str, to_format: str | None
+) -> dict:
+ """Get kwargs for writing a `VideoFromFile` to a file/stream with `av.open`"""
+ open_kwargs = {
+ "mode": "w",
+ # If isobmff, preserve custom metadata tags (workflow, prompt, extra_pnginfo)
+ "options": {"movflags": "use_metadata_tags"},
+ }
+
+ is_write_to_buffer = isinstance(dest, io.BytesIO)
+ if is_write_to_buffer:
+ # Set output format explicitly, since it cannot be inferred from file extension
+ if to_format == VideoContainer.AUTO:
+ to_format = container_format.lower()
+ elif isinstance(to_format, str):
+ to_format = to_format.lower()
+ open_kwargs["format"] = container_to_output_format(to_format)
+
+ return open_kwargs
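+
+# Example: writing an iso-container video to a BytesIO buffer produces
+# {"mode": "w", "options": {"movflags": "use_metadata_tags"}, "format": "mov"};
+# writing to a file path omits "format" so av infers it from the extension.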
+
+
+class VideoFromFile(VideoInput):
+ """
+ Class representing video input from a file.
+ """
+
+ def __init__(self, file: str | io.BytesIO):
+ """
+ Initialize the VideoFromFile object based on either a path on disk or a BytesIO object
+ containing the file contents.
+ """
+ self.__file = file
+
+ def get_dimensions(self) -> tuple[int, int]:
+ """
+ Returns the dimensions of the video input.
+
+ Returns:
+ Tuple of (width, height)
+ """
+ if isinstance(self.__file, io.BytesIO):
+ self.__file.seek(0) # Reset the BytesIO object to the beginning
+ with av.open(self.__file, mode='r') as container:
+ for stream in container.streams:
+ if stream.type == 'video':
+ assert isinstance(stream, av.VideoStream)
+ return stream.width, stream.height
+ raise ValueError(f"No video stream found in file '{self.__file}'")
+
+ def get_duration(self) -> float:
+ """
+ Returns the duration of the video in seconds.
+
+ Returns:
+ Duration in seconds
+ """
+ if isinstance(self.__file, io.BytesIO):
+ self.__file.seek(0)
+ with av.open(self.__file, mode="r") as container:
+ if container.duration is not None:
+ return float(container.duration / av.time_base)
+
+ # Fallback: calculate from frame count and frame rate
+ video_stream = next(
+ (s for s in container.streams if s.type == "video"), None
+ )
+ if video_stream and video_stream.frames and video_stream.average_rate:
+ return float(video_stream.frames / video_stream.average_rate)
+
+ # Last resort: decode frames to count them
+ if video_stream and video_stream.average_rate:
+ frame_count = 0
+ container.seek(0)
+ for packet in container.demux(video_stream):
+ for _ in packet.decode():
+ frame_count += 1
+ if frame_count > 0:
+ return float(frame_count / video_stream.average_rate)
+
+ raise ValueError(f"Could not determine duration for file '{self.__file}'")
+
+ def get_components_internal(self, container: InputContainer) -> VideoComponents:
+ # Get video frames
+ frames = []
+ for frame in container.decode(video=0):
+ img = frame.to_ndarray(format='rgb24') # shape: (H, W, 3)
+ img = torch.from_numpy(img) / 255.0 # shape: (H, W, 3)
+ frames.append(img)
+
+ images = torch.stack(frames) if len(frames) > 0 else torch.zeros(0, 3, 0, 0)
+
+ # Get frame rate
+ video_stream = next(s for s in container.streams if s.type == 'video')
+ frame_rate = Fraction(video_stream.average_rate) if video_stream and video_stream.average_rate else Fraction(1)
+
+ # Get audio if available
+ audio = None
+ try:
+ container.seek(0) # Reset the container to the beginning
+ for stream in container.streams:
+ if stream.type != 'audio':
+ continue
+ assert isinstance(stream, av.AudioStream)
+ audio_frames = []
+ for packet in container.demux(stream):
+ for frame in packet.decode():
+ assert isinstance(frame, av.AudioFrame)
+ audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
+ if len(audio_frames) > 0:
+ audio_data = np.concatenate(audio_frames, axis=1) # shape: (channels, total_samples)
+ audio_tensor = torch.from_numpy(audio_data).unsqueeze(0) # shape: (1, channels, total_samples)
+ audio = AudioInput({
+ "waveform": audio_tensor,
+ "sample_rate": int(stream.sample_rate) if stream.sample_rate else 1,
+ })
+ except StopIteration:
+ pass # No audio stream
+
+ metadata = container.metadata
+ return VideoComponents(images=images, audio=audio, frame_rate=frame_rate, metadata=metadata)
+
+ def get_components(self) -> VideoComponents:
+ if isinstance(self.__file, io.BytesIO):
+ self.__file.seek(0) # Reset the BytesIO object to the beginning
+ with av.open(self.__file, mode='r') as container:
+ return self.get_components_internal(container)
+ raise ValueError(f"No video stream found in file '{self.__file}'")
+
+ def save_to(
+ self,
+ path: str | io.BytesIO,
+ format: VideoContainer = VideoContainer.AUTO,
+ codec: VideoCodec = VideoCodec.AUTO,
+ metadata: Optional[dict] = None
+ ):
+ if isinstance(self.__file, io.BytesIO):
+ self.__file.seek(0) # Reset the BytesIO object to the beginning
+ with av.open(self.__file, mode='r') as container:
+ container_format = container.format.name
+ video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
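+ # If the requested container/codec already matches the source, packets
+ # are remuxed as-is below; otherwise the video is decoded and
+ # re-encoded via VideoFromComponents.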
+ reuse_streams = True
+ if format != VideoContainer.AUTO and format not in container_format.split(","):
+ reuse_streams = False
+ if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
+ reuse_streams = False
+
+ if not reuse_streams:
+ components = self.get_components_internal(container)
+ video = VideoFromComponents(components)
+ return video.save_to(
+ path,
+ format=format,
+ codec=codec,
+ metadata=metadata
+ )
+
+ streams = container.streams
+
+ open_kwargs = get_open_write_kwargs(path, container_format, format)
+ with av.open(path, **open_kwargs) as output_container:
+ # Copy over the original metadata
+ for key, value in container.metadata.items():
+ if metadata is None or key not in metadata:
+ output_container.metadata[key] = value
+
+ # Add our new metadata
+ if metadata is not None:
+ for key, value in metadata.items():
+ if isinstance(value, str):
+ output_container.metadata[key] = value
+ else:
+ output_container.metadata[key] = json.dumps(value)
+
+ # Add streams to the new container
+ stream_map = {}
+ for stream in streams:
+ if isinstance(stream, (av.VideoStream, av.AudioStream, SubtitleStream)):
+ out_stream = output_container.add_stream_from_template(template=stream, opaque=True)
+ stream_map[stream] = out_stream
+
+ # Write packets to the new container
+ for packet in container.demux():
+ if packet.stream in stream_map and packet.dts is not None:
+ packet.stream = stream_map[packet.stream]
+ output_container.mux(packet)
+
+class VideoFromComponents(VideoInput):
+ """
+ Class representing video input from tensors.
+ """
+
+ def __init__(self, components: VideoComponents):
+ self.__components = components
+
+ def get_components(self) -> VideoComponents:
+ return VideoComponents(
+ images=self.__components.images,
+ audio=self.__components.audio,
+ frame_rate=self.__components.frame_rate
+ )
+
+ def save_to(
+ self,
+ path: str,
+ format: VideoContainer = VideoContainer.AUTO,
+ codec: VideoCodec = VideoCodec.AUTO,
+ metadata: Optional[dict] = None
+ ):
+ if format != VideoContainer.AUTO and format != VideoContainer.MP4:
+ raise ValueError("Only MP4 format is supported for now")
+ if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
+ raise ValueError("Only H264 codec is supported for now")
+ with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output:
+ # Add metadata before writing any streams
+ if metadata is not None:
+ for key, value in metadata.items():
+ output.metadata[key] = json.dumps(value)
+
+ frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
+ # Create a video stream
+ video_stream = output.add_stream('h264', rate=frame_rate)
+ video_stream.width = self.__components.images.shape[2]
+ video_stream.height = self.__components.images.shape[1]
+ video_stream.pix_fmt = 'yuv420p'
+
+ # Create an audio stream
+ audio_sample_rate = 1
+ audio_stream: Optional[av.AudioStream] = None
+ if self.__components.audio:
+ audio_sample_rate = int(self.__components.audio['sample_rate'])
+ audio_stream = output.add_stream('aac', rate=audio_sample_rate)
+ audio_stream.sample_rate = audio_sample_rate
+ audio_stream.format = 'fltp'
+
+ # Encode video
+ for i, frame in enumerate(self.__components.images):
+ img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
+ frame = av.VideoFrame.from_ndarray(img, format='rgb24')
+ frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
+ packet = video_stream.encode(frame)
+ output.mux(packet)
+
+ # Flush video
+ packet = video_stream.encode(None)
+ output.mux(packet)
+
+ if audio_stream and self.__components.audio:
+ # Encode audio
+ samples_per_frame = int(audio_sample_rate / frame_rate)
+ num_frames = self.__components.audio['waveform'].shape[2] // samples_per_frame
+ for i in range(num_frames):
+ start = i * samples_per_frame
+ end = start + samples_per_frame
+ # TODO(Feature) - Add support for stereo audio
+ chunk = (
+ self.__components.audio["waveform"][0, 0, start:end]
+ .unsqueeze(0)
+ .contiguous()
+ .numpy()
+ )
+ audio_frame = av.AudioFrame.from_ndarray(chunk, format='fltp', layout='mono')
+ audio_frame.sample_rate = audio_sample_rate
+ audio_frame.pts = i * samples_per_frame
+ for packet in audio_stream.encode(audio_frame):
+ output.mux(packet)
+
+ # Flush audio
+ for packet in audio_stream.encode(None):
+ output.mux(packet)
+
diff --git a/comfy_api/torch_helpers/__init__.py b/comfy_api/torch_helpers/__init__.py
new file mode 100644
index 000000000..be7ae7a61
--- /dev/null
+++ b/comfy_api/torch_helpers/__init__.py
@@ -0,0 +1,5 @@
+from .torch_compile import set_torch_compile_wrapper
+
+__all__ = [
+ "set_torch_compile_wrapper",
+]
diff --git a/comfy_api/torch_helpers/torch_compile.py b/comfy_api/torch_helpers/torch_compile.py
new file mode 100644
index 000000000..9223f58db
--- /dev/null
+++ b/comfy_api/torch_helpers/torch_compile.py
@@ -0,0 +1,69 @@
+from __future__ import annotations
+import torch
+
+import comfy.utils
+from comfy.patcher_extension import WrappersMP
+from typing import TYPE_CHECKING, Callable, Optional
+if TYPE_CHECKING:
+ from comfy.model_patcher import ModelPatcher
+ from comfy.patcher_extension import WrapperExecutor
+
+
+COMPILE_KEY = "torch.compile"
+TORCH_COMPILE_KWARGS = "torch_compile_kwargs"
+
+
+def apply_torch_compile_factory(compiled_module_dict: dict[str, Callable]) -> Callable:
+ '''
+ Create a wrapper that will refer to the compiled_diffusion_model.
+ '''
+ def apply_torch_compile_wrapper(executor: WrapperExecutor, *args, **kwargs):
+ try:
+ orig_modules = {}
+ for key, value in compiled_module_dict.items():
+ orig_modules[key] = comfy.utils.get_attr(executor.class_obj, key)
+ comfy.utils.set_attr(executor.class_obj, key, value)
+ return executor(*args, **kwargs)
+ finally:
+ for key, value in orig_modules.items():
+ comfy.utils.set_attr(executor.class_obj, key, value)
+ return apply_torch_compile_wrapper
+
+
+def set_torch_compile_wrapper(model: ModelPatcher, backend: str, options: Optional[dict[str,str]]=None,
+ mode: Optional[str]=None, fullgraph=False, dynamic: Optional[bool]=None,
+ keys: list[str]=["diffusion_model"], *args, **kwargs):
+ '''
+ Perform torch.compile that will be applied at sample time for either the whole model or specific params of the BaseModel instance.
+
+ When keys is None, it will default to using ["diffusion_model"], compiling the whole diffusion_model.
+ When a list of keys is provided, it will perform torch.compile on only the selected modules.
+ '''
+ # clear out any other torch.compile wrappers
+ model.remove_wrappers_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY)
+ # if no keys, default to 'diffusion_model'
+ if not keys:
+ keys = ["diffusion_model"]
+ # create kwargs dict that can be referenced later
+ compile_kwargs = {
+ "backend": backend,
+ "options": options,
+ "mode": mode,
+ "fullgraph": fullgraph,
+ "dynamic": dynamic,
+ }
+ # get a dict of compiled keys
+ compiled_modules = {}
+ for key in keys:
+ compiled_modules[key] = torch.compile(
+ model=model.get_model_object(key),
+ **compile_kwargs,
+ )
+ # add torch.compile wrapper
+ wrapper_func = apply_torch_compile_factory(
+ compiled_module_dict=compiled_modules,
+ )
+ # store wrapper to run on BaseModel's apply_model function
+ model.add_wrapper_with_key(WrappersMP.APPLY_MODEL, COMPILE_KEY, wrapper_func)
+ # keep compile kwargs for reference
+ model.model_options[TORCH_COMPILE_KWARGS] = compile_kwargs
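+
+
+# Example usage (sketch; `patcher` is assumed to be a ModelPatcher, e.g. the
+# MODEL output of a checkpoint loader node):
+#
+#   set_torch_compile_wrapper(patcher, backend="inductor", keys=["diffusion_model"])
+#
+# The compiled module is swapped in only while apply_model runs, so the
+# original modules on the patcher are left untouched.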
diff --git a/comfy_api/util/__init__.py b/comfy_api/util/__init__.py
new file mode 100644
index 000000000..9019c46db
--- /dev/null
+++ b/comfy_api/util/__init__.py
@@ -0,0 +1,8 @@
+from .video_types import VideoContainer, VideoCodec, VideoComponents
+
+__all__ = [
+ # Utility Types
+ "VideoContainer",
+ "VideoCodec",
+ "VideoComponents",
+]
diff --git a/comfy_api/util/video_types.py b/comfy_api/util/video_types.py
new file mode 100644
index 000000000..d09663db9
--- /dev/null
+++ b/comfy_api/util/video_types.py
@@ -0,0 +1,51 @@
+from __future__ import annotations
+from dataclasses import dataclass
+from enum import Enum
+from fractions import Fraction
+from typing import Optional
+from comfy_api.input import ImageInput, AudioInput
+
+class VideoCodec(str, Enum):
+ AUTO = "auto"
+ H264 = "h264"
+
+ @classmethod
+ def as_input(cls) -> list[str]:
+ """
+ Returns a list of codec names that can be used as node input.
+ """
+ return [member.value for member in cls]
+
+class VideoContainer(str, Enum):
+ AUTO = "auto"
+ MP4 = "mp4"
+
+ @classmethod
+ def as_input(cls) -> list[str]:
+ """
+ Returns a list of container names that can be used as node input.
+ """
+ return [member.value for member in cls]
+
+ @classmethod
+ def get_extension(cls, value) -> str:
+ """
+ Returns the file extension for the container.
+ """
+ if isinstance(value, str):
+ value = cls(value)
+ if value == VideoContainer.MP4 or value == VideoContainer.AUTO:
+ return "mp4"
+ return ""
+
+@dataclass
+class VideoComponents:
+ """
+ Dataclass representing the components of a video.
+ """
+
+ images: ImageInput
+ frame_rate: Fraction
+ audio: Optional[AudioInput] = None
+ metadata: Optional[dict] = None
+
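+# Example (sketch): a 2-second silent 24 fps clip could be represented as
+# VideoComponents(images=torch.zeros(48, 256, 256, 3), frame_rate=Fraction(24, 1)).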
diff --git a/comfy_api_nodes/README.md b/comfy_api_nodes/README.md
new file mode 100644
index 000000000..64a389cc1
--- /dev/null
+++ b/comfy_api_nodes/README.md
@@ -0,0 +1,65 @@
+# ComfyUI API Nodes
+
+## Introduction
+
+Below is a collection of nodes that work by calling external APIs. More information is available in our [docs](https://docs.comfy.org/tutorials/api-nodes/overview#api-nodes).
+
+## Development
+
+While developing, you should be testing against the Staging environment. To test against staging:
+
+**Install ComfyUI_frontend**
+
+Follow the instructions [here](https://github.com/Comfy-Org/ComfyUI_frontend) to start the frontend server. By default, it will connect to Staging authentication.
+
+> **Hint:** If you use the --front-end-version argument for ComfyUI, it will use production authentication.
+
+```bash
+python main.py --comfy-api-base https://stagingapi.comfy.org
+```
+
+To authenticate against staging, log in and then ask a member of the Comfy Org team to whitelist you for staging access.
+
+API stubs are generated from OpenAPI definitions using automatic codegen tools. Since the Comfy Org OpenAPI definition also contains many definitions from the Comfy Registry, we use redocly/cli to filter out only the paths relevant to API nodes.
+
+### Redocly Instructions
+
+> **Tip:** When developing locally, use the `redocly-dev.yaml` file to generate pydantic models. This lets you use stubs for APIs that are not marked `Released` yet.
+
+Before your API node PR merges, make sure to add the `Released` tag to the `openapi.yaml` file and test in staging.
+
+```bash
+# Download the OpenAPI file from the staging server.
+curl -o openapi.yaml https://stagingapi.comfy.org/openapi
+
+# Filter out unneeded API definitions.
+npm install -g @redocly/cli
+redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly-dev.yaml --remove-unused-components
+
+# Generate the pydantic datamodels for validation.
+datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
+
+```
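+
+> **Note:** The `datamodel-codegen` CLI is provided by the [datamodel-code-generator](https://github.com/koxudaxi/datamodel-code-generator) package; if it is missing, install it with `pip install datamodel-code-generator`.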
+
+
+## Merging to Master
+
+Before merging to comfyanonymous/ComfyUI master, follow these steps:
+
+1. Add the "Released" tag to the ComfyUI OpenAPI yaml file for each endpoint you are using in the nodes.
+1. Make sure the ComfyUI API is deployed to prod with your changes.
+1. Run the code generation again with `redocly.yaml` and the production OpenAPI yaml file.
+
+```bash
+# Download the OpenAPI file from the prod server.
+curl -o openapi.yaml https://api.comfy.org/openapi
+
+# Filter out unneeded API definitions.
+npm install -g @redocly/cli
+redocly bundle openapi.yaml --output filtered-openapi.yaml --config comfy_api_nodes/redocly.yaml --remove-unused-components
+
+# Generate the pydantic datamodels for validation.
+datamodel-codegen --use-subclass-enum --field-constraints --strict-types bytes --input filtered-openapi.yaml --output comfy_api_nodes/apis/__init__.py --output-model-type pydantic_v2.BaseModel
+
+```
diff --git a/comfy_api_nodes/__init__.py b/comfy_api_nodes/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py
new file mode 100644
index 000000000..788e2803f
--- /dev/null
+++ b/comfy_api_nodes/apinode_utils.py
@@ -0,0 +1,678 @@
+from __future__ import annotations
+import io
+import logging
+import mimetypes
+from typing import Optional, Union
+from comfy.utils import common_upscale
+from comfy_api.input_impl import VideoFromFile
+from comfy_api.util import VideoContainer, VideoCodec
+from comfy_api.input.video_types import VideoInput
+from comfy_api.input.basic_types import AudioInput
+from comfy_api_nodes.apis.client import (
+ ApiClient,
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ UploadRequest,
+ UploadResponse,
+)
+from server import PromptServer
+
+
+import numpy as np
+from PIL import Image
+import requests
+import torch
+import math
+import base64
+import uuid
+from io import BytesIO
+import av
+
+
+def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
+ """Downloads a video from a URL and returns a `VIDEO` output.
+
+ Args:
+ video_url: The URL of the video to download.
+ timeout: Optional request timeout in seconds.
+
+ Returns:
+ A Comfy node `VIDEO` output.
+ """
+ video_io = download_url_to_bytesio(video_url, timeout)
+ if video_io is None:
+ error_msg = f"Failed to download video from {video_url}"
+ logging.error(error_msg)
+ raise ValueError(error_msg)
+ return VideoFromFile(video_io)
+
+
+def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
+ """Downscale input image tensor to roughly the specified total pixels."""
+ samples = image.movedim(-1, 1)
+ total = int(total_pixels)
+ scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
+ if scale_by >= 1:
+ return image
+ width = round(samples.shape[3] * scale_by)
+ height = round(samples.shape[2] * scale_by)
+
+ s = common_upscale(samples, width, height, "lanczos", "disabled")
+ s = s.movedim(1, -1)
+ return s
+
+
+def validate_and_cast_response(
+ response, timeout: int = None, node_id: Union[str, None] = None
+) -> torch.Tensor:
+ """Validates and casts a response to a torch.Tensor.
+
+ Args:
+ response: The response to validate and cast.
+ timeout: Request timeout in seconds. Defaults to None (no timeout).
+ node_id: Optional node id; when set, the result URL is reported as progress text.
+
+ Returns:
+ A torch.Tensor representing the image (1, H, W, C).
+
+ Raises:
+ ValueError: If the response is not valid.
+ """
+ # validate raw JSON response
+ data = response.data
+ if not data or len(data) == 0:
+ raise ValueError("No images returned from API endpoint")
+
+ # Initialize list to store image tensors
+ image_tensors: list[torch.Tensor] = []
+
+ # Process each image in the data array
+ for image_data in data:
+ image_url = image_data.url
+ b64_data = image_data.b64_json
+
+ if not image_url and not b64_data:
+ raise ValueError("No image was generated in the response")
+
+ if b64_data:
+ img_data = base64.b64decode(b64_data)
+ img = Image.open(io.BytesIO(img_data))
+
+ elif image_url:
+ if node_id:
+ PromptServer.instance.send_progress_text(
+ f"Result URL: {image_url}", node_id
+ )
+ img_response = requests.get(image_url, timeout=timeout)
+ if img_response.status_code != 200:
+ raise ValueError("Failed to download the image")
+ img = Image.open(io.BytesIO(img_response.content))
+
+ img = img.convert("RGBA")
+
+ # Convert to numpy array, normalize to float32 between 0 and 1
+ img_array = np.array(img).astype(np.float32) / 255.0
+ img_tensor = torch.from_numpy(img_array)
+
+ # Add to list of tensors
+ image_tensors.append(img_tensor)
+
+ return torch.stack(image_tensors, dim=0)
+
+
+def validate_aspect_ratio(
+ aspect_ratio: str,
+ minimum_ratio: float,
+ maximum_ratio: float,
+ minimum_ratio_str: str,
+ maximum_ratio_str: str,
+) -> str:
+ """Validates and casts an aspect ratio string to a float.
+
+ Args:
+ aspect_ratio: The aspect ratio string to validate.
+ minimum_ratio: The minimum aspect ratio.
+ maximum_ratio: The maximum aspect ratio.
+ minimum_ratio_str: The minimum aspect ratio string.
+ maximum_ratio_str: The maximum aspect ratio string.
+
+ Returns:
+ The validated aspect ratio string.
+
+ Raises:
+ TypeError: If the aspect ratio is not valid.
+ """
+ # get ratio values
+ numbers = aspect_ratio.split(":")
+ if len(numbers) != 2:
+ raise TypeError(
+ f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}."
+ )
+ try:
+ numerator = int(numbers[0])
+ denominator = int(numbers[1])
+ except ValueError as exc:
+ raise TypeError(
+ f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}."
+ ) from exc
+ calculated_ratio = numerator / denominator
+ # ratios exactly at (close to) a bound pass; everything else is range-checked
+ if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose(
+ calculated_ratio, maximum_ratio
+ ):
+ if calculated_ratio < minimum_ratio:
+ raise TypeError(
+ f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
+ )
+ elif calculated_ratio > maximum_ratio:
+ raise TypeError(
+ f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})."
+ )
+ return aspect_ratio
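+
+# Example: validate_aspect_ratio("16:9", 0.4, 2.5, "2:5", "5:2") returns "16:9";
+# "16:5" (ratio 3.2) would raise a TypeError for exceeding the 5:2 maximum.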
+
+
+def mimetype_to_extension(mime_type: str) -> str:
+ """Converts a MIME type to a file extension."""
+ return mime_type.split("/")[-1].lower()
+
+
+def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
+ """Downloads content from a URL using requests and returns it as BytesIO.
+
+ Args:
+ url: The URL to download.
+ timeout: Request timeout in seconds. Defaults to None (no timeout).
+
+ Returns:
+ BytesIO object containing the downloaded content.
+ """
+ response = requests.get(url, stream=True, timeout=timeout)
+ response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
+ return BytesIO(response.content)
+
+
+def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
+ """Converts image data from BytesIO to a torch.Tensor.
+
+ Args:
+ image_bytesio: BytesIO object containing the image data.
+ mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA").
+
+ Returns:
+ A torch.Tensor representing the image (1, H, W, C).
+
+ Raises:
+ PIL.UnidentifiedImageError: If the image data cannot be identified.
+ ValueError: If the specified mode is invalid.
+ """
+ image = Image.open(image_bytesio)
+ image = image.convert(mode)
+ image_array = np.array(image).astype(np.float32) / 255.0
+ return torch.from_numpy(image_array).unsqueeze(0)
+
+
+def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
+ """Downloads an image from a URL and returns a [B, H, W, C] tensor."""
+ image_bytesio = download_url_to_bytesio(url, timeout)
+ return bytesio_to_image_tensor(image_bytesio)
+
+
+def process_image_response(response: requests.Response) -> torch.Tensor:
+ """Uses content from a Response object and converts it to a torch.Tensor"""
+ return bytesio_to_image_tensor(BytesIO(response.content))
+
+
+def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
+ """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling."""
+ if len(image.shape) > 3:
+ image = image[0]
+ # TODO: remove alpha if not allowed and present
+ input_tensor = image.cpu()
+ input_tensor = downscale_image_tensor(
+ input_tensor.unsqueeze(0), total_pixels=total_pixels
+ ).squeeze()
+ image_np = (input_tensor.numpy() * 255).astype(np.uint8)
+ img = Image.fromarray(image_np)
+ return img
+
+
+def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO:
+ """Converts a PIL Image to a BytesIO object."""
+ if not mime_type:
+ mime_type = "image/png"
+
+ img_byte_arr = io.BytesIO()
+ # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG')
+ pil_format = mime_type.split("/")[-1].upper()
+ if pil_format == "JPG":
+ pil_format = "JPEG"
+ img.save(img_byte_arr, format=pil_format)
+ img_byte_arr.seek(0)
+ return img_byte_arr
+
+
+def tensor_to_bytesio(
+ image: torch.Tensor,
+ name: Optional[str] = None,
+ total_pixels: int = 2048 * 2048,
+ mime_type: str = "image/png",
+) -> BytesIO:
+ """Converts a torch.Tensor image to a named BytesIO object.
+
+ Args:
+ image: Input torch.Tensor image.
+ name: Optional filename for the BytesIO object.
+ total_pixels: Maximum total pixels for potential downscaling.
+ mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
+
+ Returns:
+ Named BytesIO object containing the image data.
+ """
+ if not mime_type:
+ mime_type = "image/png"
+
+ pil_image = _tensor_to_pil(image, total_pixels=total_pixels)
+ img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type)
+ img_binary.name = (
+ f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}"
+ )
+ return img_binary
+
+
+def tensor_to_base64_string(
+ image_tensor: torch.Tensor,
+ total_pixels: int = 2048 * 2048,
+ mime_type: str = "image/png",
+) -> str:
+ """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string.
+
+ Args:
+ image_tensor: Input torch.Tensor image.
+ total_pixels: Maximum total pixels for potential downscaling.
+ mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4').
+
+ Returns:
+ Base64 encoded string of the image.
+ """
+ pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels)
+ img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type)
+ img_bytes = img_byte_arr.getvalue()
+ # Encode bytes to base64 string
+ base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8")
+ return base64_encoded_string
+
+
+def tensor_to_data_uri(
+ image_tensor: torch.Tensor,
+ total_pixels: int = 2048 * 2048,
+ mime_type: str = "image/png",
+) -> str:
+ """Converts a tensor image to a Data URI string.
+
+ Args:
+ image_tensor: Input torch.Tensor image.
+ total_pixels: Maximum total pixels for potential downscaling.
+ mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp').
+
+ Returns:
+ Data URI string (e.g., 'data:image/png;base64,...').
+ """
+ base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type)
+ return f"data:{mime_type};base64,{base64_string}"
+
+
+def text_filepath_to_base64_string(filepath: str) -> str:
+ """Converts a text file to a base64 string."""
+ with open(filepath, "rb") as f:
+ file_content = f.read()
+ return base64.b64encode(file_content).decode("utf-8")
+
+
+def text_filepath_to_data_uri(filepath: str) -> str:
+ """Converts a text file to a data URI."""
+ base64_string = text_filepath_to_base64_string(filepath)
+ mime_type, _ = mimetypes.guess_type(filepath)
+ if mime_type is None:
+ mime_type = "application/octet-stream"
+ return f"data:{mime_type};base64,{base64_string}"
+
+
+def upload_file_to_comfyapi(
+ file_bytes_io: BytesIO,
+ filename: str,
+ upload_mime_type: str,
+ auth_kwargs: Optional[dict[str, str]] = None,
+) -> str:
+ """
+ Uploads a single file to ComfyUI API and returns its download URL.
+
+ Args:
+ file_bytes_io: BytesIO object containing the file data.
+ filename: The filename of the file.
+ upload_mime_type: MIME type of the file.
+ auth_kwargs: Optional authentication token(s).
+
+ Returns:
+ The download URL for the uploaded file.
+ """
+ request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/customers/storage",
+ method=HttpMethod.POST,
+ request_model=UploadRequest,
+ response_model=UploadResponse,
+ ),
+ request=request_object,
+ auth_kwargs=auth_kwargs,
+ )
+
+ response: UploadResponse = operation.execute()
+ upload_response = ApiClient.upload_file(
+ response.upload_url, file_bytes_io, content_type=upload_mime_type
+ )
+ upload_response.raise_for_status()
+
+ return response.download_url
+
+
+def video_to_base64_string(
+ video: VideoInput,
+ container_format: VideoContainer = None,
+ codec: VideoCodec = None
+) -> str:
+ """
+ Converts a video input to a base64 string.
+
+ Args:
+ video: The video input to convert
+ container_format: Optional container format to use (defaults to video.container if available)
+ codec: Optional codec to use (defaults to video.codec if available)
+ """
+ video_bytes_io = io.BytesIO()
+
+ # Use provided format/codec if specified, otherwise use video's own if available
+ format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4)
+ codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264)
+
+ video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use)
+ video_bytes_io.seek(0)
+ return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")
+
+
+def upload_video_to_comfyapi(
+ video: VideoInput,
+ auth_kwargs: Optional[dict[str, str]] = None,
+ container: VideoContainer = VideoContainer.MP4,
+ codec: VideoCodec = VideoCodec.H264,
+ max_duration: Optional[int] = None,
+) -> str:
+ """
+ Uploads a single video to ComfyUI API and returns its download URL.
+ Uses the specified container and codec for saving the video before upload.
+
+ Args:
+ video: VideoInput object (Comfy VIDEO type).
+ auth_kwargs: Optional authentication token(s).
+ container: The video container format to use (default: MP4).
+ codec: The video codec to use (default: H264).
+ max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised.
+
+ Returns:
+ The download URL for the uploaded video file.
+ """
+ if max_duration is not None:
+ try:
+ actual_duration = video.get_duration()
+ if actual_duration is not None and actual_duration > max_duration:
+ raise ValueError(
+ f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)."
+ )
+ except Exception as e:
+ logging.error(f"Error getting video duration: {e}")
+ raise ValueError(f"Could not verify video duration from source: {e}") from e
+
+ upload_mime_type = f"video/{container.value.lower()}"
+ filename = f"uploaded_video.{container.value.lower()}"
+
+ # Convert VideoInput to BytesIO using specified container/codec
+ video_bytes_io = io.BytesIO()
+ video.save_to(video_bytes_io, format=container, codec=codec)
+ video_bytes_io.seek(0)
+
+ return upload_file_to_comfyapi(
+ video_bytes_io, filename, upload_mime_type, auth_kwargs
+ )
+
+
+def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
+ """
+ Prepares audio waveform for av library by converting to a contiguous numpy array.
+
+ Args:
+ waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type.
+
+ Returns:
+ Contiguous numpy array of the audio waveform. If the audio was batched,
+ the first item is taken.
+ """
+ if waveform.ndim != 3 or waveform.shape[0] != 1:
+ raise ValueError("Expected waveform tensor shape (1, channels, samples)")
+
+ # If batch is > 1, take first item
+ if waveform.shape[0] > 1:
+ waveform = waveform[0]
+
+ # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array
+ audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy()
+ if audio_data_np.dtype != np.float32:
+ audio_data_np = audio_data_np.astype(np.float32)
+
+ return audio_data_np
+
+
+def audio_ndarray_to_bytesio(
+ audio_data_np: np.ndarray,
+ sample_rate: int,
+ container_format: str = "mp4",
+ codec_name: str = "aac",
+) -> BytesIO:
+ """
+ Encodes a numpy array of audio data into a BytesIO object.
+ """
+ audio_bytes_io = io.BytesIO()
+ with av.open(audio_bytes_io, mode="w", format=container_format) as output_container:
+ audio_stream = output_container.add_stream(codec_name, rate=sample_rate)
+ frame = av.AudioFrame.from_ndarray(
+ audio_data_np,
+ format="fltp",
+ layout="stereo" if audio_data_np.shape[0] > 1 else "mono",
+ )
+ frame.sample_rate = sample_rate
+ frame.pts = 0
+
+ for packet in audio_stream.encode(frame):
+ output_container.mux(packet)
+
+ # Flush stream
+ for packet in audio_stream.encode(None):
+ output_container.mux(packet)
+
+ audio_bytes_io.seek(0)
+ return audio_bytes_io
+
+
+def upload_audio_to_comfyapi(
+ audio: AudioInput,
+ auth_kwargs: Optional[dict[str, str]] = None,
+ container_format: str = "mp4",
+ codec_name: str = "aac",
+ mime_type: str = "audio/mp4",
+ filename: str = "uploaded_audio.mp4",
+) -> str:
+ """
+ Uploads a single audio input to ComfyUI API and returns its download URL.
+ Encodes the raw waveform into the specified format before uploading.
+
+ Args:
+ audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate)
+ auth_kwargs: Optional authentication token(s).
+
+ Returns:
+ The download URL for the uploaded audio file.
+ """
+ sample_rate: int = audio["sample_rate"]
+ waveform: torch.Tensor = audio["waveform"]
+ audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
+ audio_bytes_io = audio_ndarray_to_bytesio(
+ audio_data_np, sample_rate, container_format, codec_name
+ )
+
+ return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
+
+
+def audio_to_base64_string(
+ audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac"
+) -> str:
+ """Converts an audio input to a base64 string."""
+ sample_rate: int = audio["sample_rate"]
+ waveform: torch.Tensor = audio["waveform"]
+ audio_data_np = audio_tensor_to_contiguous_ndarray(waveform)
+ audio_bytes_io = audio_ndarray_to_bytesio(
+ audio_data_np, sample_rate, container_format, codec_name
+ )
+ audio_bytes = audio_bytes_io.getvalue()
+ return base64.b64encode(audio_bytes).decode("utf-8")
+
+
+def upload_images_to_comfyapi(
+ image: torch.Tensor,
+ max_images=8,
+ auth_kwargs: Optional[dict[str, str]] = None,
+ mime_type: Optional[str] = None,
+) -> list[str]:
+ """
+ Uploads images to ComfyUI API and returns download URLs.
+ To upload multiple images, stack them in the batch dimension first.
+
+ Args:
+ image: Input torch.Tensor image.
+ max_images: Maximum number of images to upload.
+ auth_kwargs: Optional authentication token(s).
+ mime_type: Optional MIME type for the image.
+ """
+ # if batch, try to upload each file if max_images is greater than 0
+ idx_image = 0
+ download_urls: list[str] = []
+ is_batch = len(image.shape) > 3
+ batch_length = 1
+ if is_batch:
+ batch_length = image.shape[0]
+ while True:
+ curr_image = image
+ if len(image.shape) > 3:
+ curr_image = image[idx_image]
+ # get BytesIO version of image
+ img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
+ # first, request upload/download urls from comfy API
+ if not mime_type:
+ request_object = UploadRequest(file_name=img_binary.name)
+ else:
+ request_object = UploadRequest(
+ file_name=img_binary.name, content_type=mime_type
+ )
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/customers/storage",
+ method=HttpMethod.POST,
+ request_model=UploadRequest,
+ response_model=UploadResponse,
+ ),
+ request=request_object,
+ auth_kwargs=auth_kwargs,
+ )
+ response = operation.execute()
+
+ upload_response = ApiClient.upload_file(
+ response.upload_url, img_binary, content_type=mime_type
+ )
+ # verify success
+ try:
+ upload_response.raise_for_status()
+ except requests.exceptions.HTTPError as e:
+ raise ValueError(f"Could not upload one or more images: {e}") from e
+ # add download_url to list
+ download_urls.append(response.download_url)
+
+ idx_image += 1
+ # stop uploading additional files if done
+ if is_batch and max_images > 0:
+ if idx_image >= max_images:
+ break
+ if idx_image >= batch_length:
+ break
+ return download_urls
+
+
+def resize_mask_to_image(
+ mask: torch.Tensor,
+ image: torch.Tensor,
+ upscale_method="nearest-exact",
+ crop="disabled",
+ allow_gradient=True,
+ add_channel_dim=False,
+):
+ """
+ Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.
+ """
+ _, H, W, _ = image.shape
+ mask = mask.unsqueeze(-1)
+ mask = mask.movedim(-1, 1)
+ mask = common_upscale(
+ mask, width=W, height=H, upscale_method=upscale_method, crop=crop
+ )
+ mask = mask.movedim(1, -1)
+ if not add_channel_dim:
+ mask = mask.squeeze(-1)
+ if not allow_gradient:
+ mask = (mask > 0.5).float()
+ return mask
+
+
+def validate_string(
+ string: str,
+ strip_whitespace=True,
+ field_name="prompt",
+ min_length=None,
+ max_length=None,
+):
+ if string is None:
+ raise Exception(f"Field '{field_name}' cannot be empty.")
+ if strip_whitespace:
+ string = string.strip()
+ if min_length and len(string) < min_length:
+ raise Exception(
+ f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long."
+ )
+ if max_length and len(string) > max_length:
+ raise Exception(
+ f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long."
+ )
+
+
+def image_tensor_pair_to_batch(
+ image1: torch.Tensor, image2: torch.Tensor
+) -> torch.Tensor:
+ """
+ Converts a pair of image tensors to a batch tensor.
+ If the images are not the same size, the smaller image is resized to
+ match the larger image.
+ """
+ if image1.shape[1:] != image2.shape[1:]:
+ image2 = common_upscale(
+ image2.movedim(-1, 1),
+ image1.shape[2],
+ image1.shape[1],
+ "bilinear",
+ "center",
+ ).movedim(1, -1)
+ return torch.cat((image1, image2), dim=0)
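
The helpers above share one convention: images are NHWC tensors, and anything that must match in shape is resized with `common_upscale` before use. A minimal, self-contained sketch of the resize-then-concatenate pattern from `image_tensor_pair_to_batch`, using plain `torch.nn.functional.interpolate` in place of the repo's `common_upscale` (shapes illustrative only):

```python
import torch
import torch.nn.functional as F

# Two NHWC images of different sizes, as ComfyUI image tensors.
image1 = torch.rand(1, 512, 768, 3)
image2 = torch.rand(1, 256, 384, 3)

if image1.shape[1:] != image2.shape[1:]:
    # interpolate expects NCHW, so move channels first and back afterwards.
    image2 = F.interpolate(
        image2.movedim(-1, 1),
        size=(image1.shape[1], image1.shape[2]),
        mode="bilinear",
        align_corners=False,
    ).movedim(1, -1)

batch = torch.cat((image1, image2), dim=0)
print(batch.shape)  # torch.Size([2, 512, 768, 3])
```
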
diff --git a/comfy_api_nodes/apis/PixverseController.py b/comfy_api_nodes/apis/PixverseController.py
new file mode 100644
index 000000000..310c0f546
--- /dev/null
+++ b/comfy_api_nodes/apis/PixverseController.py
@@ -0,0 +1,17 @@
+# generated by datamodel-codegen:
+# filename: filtered-openapi.yaml
+# timestamp: 2025-04-29T23:44:54+00:00
+
+from __future__ import annotations
+
+from typing import Optional
+
+from pydantic import BaseModel
+
+from . import PixverseDto
+
+
+class ResponseData(BaseModel):
+ ErrCode: Optional[int] = None
+ ErrMsg: Optional[str] = None
+ Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None
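
`ResponseData` wraps every Pixverse controller reply in the same error-code envelope. A quick sketch of validating one (made-up values; assumes Pydantic v2, which the generated models target):

```python
from comfy_api_nodes.apis.PixverseController import ResponseData

payload = {"ErrCode": 0, "ErrMsg": "success", "Resp": {"video_id": 12345}}
data = ResponseData.model_validate(payload)
if data.ErrCode == 0 and data.Resp is not None:
    print(data.Resp.video_id)  # 12345
```
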
diff --git a/comfy_api_nodes/apis/PixverseDto.py b/comfy_api_nodes/apis/PixverseDto.py
new file mode 100644
index 000000000..323c38e96
--- /dev/null
+++ b/comfy_api_nodes/apis/PixverseDto.py
@@ -0,0 +1,57 @@
+# generated by datamodel-codegen:
+# filename: filtered-openapi.yaml
+# timestamp: 2025-04-29T23:44:54+00:00
+
+from __future__ import annotations
+
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+class V2OpenAPII2VResp(BaseModel):
+ video_id: Optional[int] = Field(None, description='Video_id')
+
+
+class V2OpenAPIT2VReq(BaseModel):
+ aspect_ratio: str = Field(
+ ..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9']
+ )
+ duration: int = Field(
+ ...,
+ description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)',
+ examples=[5],
+ )
+ model: str = Field(
+ ..., description='Model version (only supports v3.5)', examples=['v3.5']
+ )
+ motion_mode: Optional[str] = Field(
+ 'normal',
+ description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)',
+ examples=['normal'],
+ )
+ negative_prompt: Optional[str] = Field(
+ None, description='Negative prompt\n', max_length=2048
+ )
+ prompt: str = Field(..., description='Prompt', max_length=2048)
+ quality: str = Field(
+ ...,
+ description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")',
+ examples=['540p'],
+ )
+ seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647')
+ style: Optional[str] = Field(
+ None,
+ description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed',
+ examples=['anime'],
+ )
+ template_id: Optional[int] = Field(
+ None,
+ description='Template ID (template_id must be activated before use)',
+ examples=[302325299692608],
+ )
+ water_mark: Optional[bool] = Field(
+ False,
+ description='Watermark (true: add watermark, false: no watermark)',
+ examples=[False],
+ )
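
A text-to-video request only needs the required fields; the optional ones carry the defaults documented above. A sketch under the same Pydantic v2 assumption, with values chosen to satisfy the documented constraints (`model` must be `v3.5`, `duration` 5 or 8, and 1080p supports neither 8s nor fast motion):

```python
from comfy_api_nodes.apis.PixverseDto import V2OpenAPIT2VReq

req = V2OpenAPIT2VReq(
    aspect_ratio="16:9",
    duration=5,
    model="v3.5",
    prompt="A paper boat drifting down a rainy street",
    quality="540p",
)
print(req.model_dump_json(exclude_none=True))
```
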
diff --git a/comfy_api_nodes/apis/__init__.py b/comfy_api_nodes/apis/__init__.py
new file mode 100644
index 000000000..e38d38cc9
--- /dev/null
+++ b/comfy_api_nodes/apis/__init__.py
@@ -0,0 +1,3453 @@
+# generated by datamodel-codegen:
+# filename: filtered-openapi.yaml
+# timestamp: 2025-05-19T21:38:55+00:00
+
+from __future__ import annotations
+
+from datetime import date, datetime
+from enum import Enum
+from typing import Any, Dict, List, Literal, Optional, Union
+from uuid import UUID
+
+from pydantic import AnyUrl, BaseModel, ConfigDict, Field, RootModel, StrictBytes
+
+
+class APIKey(BaseModel):
+ created_at: Optional[datetime] = None
+ description: Optional[str] = None
+ id: Optional[str] = None
+ key_prefix: Optional[str] = None
+ name: Optional[str] = None
+
+
+class APIKeyWithPlaintext(APIKey):
+ plaintext_key: Optional[str] = Field(
+ None, description='The full API key (only returned at creation)'
+ )
+
+
+class AuditLog(BaseModel):
+ createdAt: Optional[datetime] = Field(
+ None, description='The date and time the event was created'
+ )
+ event_id: Optional[str] = Field(None, description='the id of the event')
+ event_type: Optional[str] = Field(None, description='the type of the event')
+ params: Optional[Dict[str, Any]] = Field(
+ None, description='data related to the event'
+ )
+
+
+class OutputFormat(str, Enum):
+ jpeg = 'jpeg'
+ png = 'png'
+
+
+class BFLFluxPro11GenerateRequest(BaseModel):
+ height: int = Field(..., description='Height of the generated image')
+ image_prompt: Optional[str] = Field(None, description='Optional image prompt')
+ output_format: Optional[OutputFormat] = Field(
+ None, description='Output image format'
+ )
+ prompt: str = Field(..., description='The main text prompt for image generation')
+ prompt_upsampling: Optional[bool] = Field(
+ None, description='Whether to use prompt upsampling'
+ )
+ safety_tolerance: Optional[int] = Field(None, description='Safety tolerance level')
+ seed: Optional[int] = Field(None, description='Random seed for reproducibility')
+ webhook_secret: Optional[str] = Field(
+ None, description='Optional webhook secret for async processing'
+ )
+ webhook_url: Optional[str] = Field(
+ None, description='Optional webhook URL for async processing'
+ )
+ width: int = Field(..., description='Width of the generated image')
+
+
+class BFLFluxPro11GenerateResponse(BaseModel):
+ id: str = Field(..., description='Job ID for tracking')
+ polling_url: str = Field(..., description='URL to poll for results')
+
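# Example (sketch, not part of this diff): the minimal Flux Pro 1.1 request.
# Only `prompt`, `width`, and `height` are required above; everything else
# defaults to None and can be dropped from the wire payload (Pydantic v2 API).
from comfy_api_nodes.apis import BFLFluxPro11GenerateRequest

req = BFLFluxPro11GenerateRequest(
    prompt="a lighthouse at dusk, long exposure",
    width=1024,
    height=768,
)
payload = req.model_dump(exclude_none=True)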
+
+class BFLFluxProGenerateRequest(BaseModel):
+ guidance_scale: Optional[float] = Field(
+ None, description='The guidance scale for generation.', ge=1.0, le=20.0
+ )
+ height: int = Field(
+ ..., description='The height of the image to generate.', ge=64, le=2048
+ )
+ negative_prompt: Optional[str] = Field(
+ None, description='The negative prompt for image generation.'
+ )
+ num_images: Optional[int] = Field(
+ None, description='The number of images to generate.', ge=1, le=4
+ )
+ num_inference_steps: Optional[int] = Field(
+ None, description='The number of inference steps.', ge=1, le=100
+ )
+ prompt: str = Field(..., description='The text prompt for image generation.')
+ seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
+ width: int = Field(
+ ..., description='The width of the image to generate.', ge=64, le=2048
+ )
+
+
+class BFLFluxProGenerateResponse(BaseModel):
+ id: str = Field(..., description='The unique identifier for the generation task.')
+ polling_url: str = Field(..., description='URL to poll for the generation result.')
+
+
+class Status(str, Enum):
+ in_progress = 'in_progress'
+ completed = 'completed'
+ incomplete = 'incomplete'
+
+
+class Type(str, Enum):
+ computer_call = 'computer_call'
+
+
+class ComputerToolCall(BaseModel):
+ action: Dict[str, Any]
+ call_id: str = Field(
+ ...,
+ description='An identifier used when responding to the tool call with output.\n',
+ )
+ id: str = Field(..., description='The unique ID of the computer call.')
+ status: Status = Field(
+ ...,
+ description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n',
+ )
+ type: Type = Field(
+ ..., description='The type of the computer call. Always `computer_call`.'
+ )
+
+
+class Environment(str, Enum):
+ windows = 'windows'
+ mac = 'mac'
+ linux = 'linux'
+ ubuntu = 'ubuntu'
+ browser = 'browser'
+
+
+class Type1(str, Enum):
+ computer_use_preview = 'computer_use_preview'
+
+
+class ComputerUsePreviewTool(BaseModel):
+ display_height: int = Field(..., description='The height of the computer display.')
+ display_width: int = Field(..., description='The width of the computer display.')
+ environment: Environment = Field(
+ ..., description='The type of computer environment to control.'
+ )
+ type: Literal['ComputerUsePreviewTool'] = Field(
+ ...,
+ description='The type of the computer use tool. Always `computer_use_preview`.',
+ )
+
+
+class CreateAPIKeyRequest(BaseModel):
+ description: Optional[str] = None
+ name: str
+
+
+class Customer(BaseModel):
+ createdAt: Optional[datetime] = Field(
+ None, description='The date and time the user was created'
+ )
+ email: Optional[str] = Field(None, description='The email address for this user')
+ id: str = Field(..., description='The firebase UID of the user')
+ is_admin: Optional[bool] = Field(None, description='Whether the user is an admin')
+ metronome_id: Optional[str] = Field(None, description='The Metronome customer ID')
+ name: Optional[str] = Field(None, description='The name for this user')
+ stripe_id: Optional[str] = Field(None, description='The Stripe customer ID')
+ updatedAt: Optional[datetime] = Field(
+ None, description='The date and time the user was last updated'
+ )
+
+
+class CustomerStorageResourceResponse(BaseModel):
+ download_url: Optional[str] = Field(
+ None,
+ description='The signed URL to use for downloading the file from the specified path',
+ )
+ existing_file: Optional[bool] = Field(
+ None, description='Whether an existing file with the same hash was found'
+ )
+ expires_at: Optional[datetime] = Field(
+ None, description='When the signed URL will expire'
+ )
+ upload_url: Optional[str] = Field(
+ None,
+ description='The signed URL to use for uploading the file to the specified path',
+ )
+
+
+class Role(str, Enum):
+ user = 'user'
+ assistant = 'assistant'
+ system = 'system'
+ developer = 'developer'
+
+
+class Type2(str, Enum):
+ message = 'message'
+
+
+class ErrorResponse(BaseModel):
+ error: str
+ message: str
+
+
+class Type3(str, Enum):
+ file_search = 'file_search'
+
+
+class FileSearchTool(BaseModel):
+ type: Literal['FileSearchTool'] = Field(..., description='The type of tool')
+ vector_store_ids: List[str] = Field(
+ ..., description='IDs of vector stores to search in'
+ )
+
+
+class Result(BaseModel):
+ file_id: Optional[str] = Field(None, description='The unique ID of the file.\n')
+ filename: Optional[str] = Field(None, description='The name of the file.\n')
+ score: Optional[float] = Field(
+ None, description='The relevance score of the file - a value between 0 and 1.\n'
+ )
+ text: Optional[str] = Field(
+ None, description='The text that was retrieved from the file.\n'
+ )
+
+
+class Status1(str, Enum):
+ in_progress = 'in_progress'
+ searching = 'searching'
+ completed = 'completed'
+ incomplete = 'incomplete'
+ failed = 'failed'
+
+
+class Type4(str, Enum):
+ file_search_call = 'file_search_call'
+
+
+class FileSearchToolCall(BaseModel):
+ id: str = Field(..., description='The unique ID of the file search tool call.\n')
+ queries: List[str] = Field(
+ ..., description='The queries used to search for files.\n'
+ )
+ results: Optional[List[Result]] = Field(
+ None, description='The results of the file search tool call.\n'
+ )
+ status: Status1 = Field(
+ ...,
+ description='The status of the file search tool call. One of `in_progress`, \n`searching`, `incomplete` or `failed`,\n',
+ )
+ type: Type4 = Field(
+ ...,
+ description='The type of the file search tool call. Always `file_search_call`.\n',
+ )
+
+
+class Type5(str, Enum):
+ function = 'function'
+
+
+class FunctionTool(BaseModel):
+ description: Optional[str] = Field(
+ None, description='Description of what the function does'
+ )
+ name: str = Field(..., description='Name of the function')
+ parameters: Dict[str, Any] = Field(
+ ..., description='JSON Schema object describing the function parameters'
+ )
+ type: Literal['FunctionTool'] = Field(..., description='The type of tool')
+
+
+class Status2(str, Enum):
+ in_progress = 'in_progress'
+ completed = 'completed'
+ incomplete = 'incomplete'
+
+
+class Type6(str, Enum):
+ function_call = 'function_call'
+
+
+class FunctionToolCall(BaseModel):
+ arguments: str = Field(
+ ..., description='A JSON string of the arguments to pass to the function.\n'
+ )
+ call_id: str = Field(
+ ...,
+ description='The unique ID of the function tool call generated by the model.\n',
+ )
+ id: Optional[str] = Field(
+ None, description='The unique ID of the function tool call.\n'
+ )
+ name: str = Field(..., description='The name of the function to run.\n')
+ status: Optional[Status2] = Field(
+ None,
+ description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n',
+ )
+ type: Type6 = Field(
+ ..., description='The type of the function tool call. Always `function_call`.\n'
+ )
+
+
+class GeminiCitation(BaseModel):
+ authors: Optional[List[str]] = None
+ endIndex: Optional[int] = None
+ license: Optional[str] = None
+ publicationDate: Optional[date] = None
+ startIndex: Optional[int] = None
+ title: Optional[str] = None
+ uri: Optional[str] = None
+
+
+class GeminiCitationMetadata(BaseModel):
+ citations: Optional[List[GeminiCitation]] = None
+
+
+class Role1(str, Enum):
+ user = 'user'
+ model = 'model'
+
+
+class GeminiFunctionDeclaration(BaseModel):
+ description: Optional[str] = None
+ name: str
+ parameters: Dict[str, Any] = Field(
+ ..., description='JSON schema for the function parameters'
+ )
+
+
+class GeminiGenerationConfig(BaseModel):
+ maxOutputTokens: Optional[int] = Field(
+ None,
+ description='Maximum number of tokens that can be generated in the response. A token is approximately 4 characters. 100 tokens correspond to roughly 60-80 words.\n',
+ examples=[2048],
+ ge=16,
+ le=8192,
+ )
+ seed: Optional[int] = Field(
+ None,
+ description="When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used. Available for the following models:, gemini-2.5-flash-preview-04-1, gemini-2.5-pro-preview-05-0, gemini-2.0-flash-lite-00, gemini-2.0-flash-001\n",
+ examples=[343940597],
+ )
+ stopSequences: Optional[List[str]] = None
+ temperature: Optional[float] = Field(
+ 1,
+ description="The temperature is used for sampling during response generation, which occurs when topP and topK are applied. Temperature controls the degree of randomness in token selection. Lower temperatures are good for prompts that require a less open-ended or creative response, while higher temperatures can lead to more diverse or creative results. A temperature of 0 means that the highest probability tokens are always selected. In this case, responses for a given prompt are mostly deterministic, but a small amount of variation is still possible. If the model returns a response that's too generic, too short, or the model gives a fallback response, try increasing the temperature\n",
+ ge=0.0,
+ le=2.0,
+ )
+ topK: Optional[int] = Field(
+ 40,
+ description="Top-K changes how the model selects tokens for output. A top-K of 1 means the next selected token is the most probable among all tokens in the model's vocabulary. A top-K of 3 means that the next token is selected from among the 3 most probable tokens by using temperature.\n",
+ examples=[40],
+ ge=1,
+ )
+ topP: Optional[float] = Field(
+ 0.95,
+ description='If specified, nucleus sampling is used.\nTop-P changes how the model selects tokens for output. Tokens are selected from the most (see top-K) to least probable until the sum of their probabilities equals the top-P value. For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-P value is 0.5, then the model will select either A or B as the next token by using temperature and excludes C as a candidate.\nSpecify a lower value for less random responses and a higher value for more random responses.\n',
+ ge=0.0,
+ le=1.0,
+ )
+
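# Example (sketch, not part of this diff): a sampling config within the bounds
# documented above (maxOutputTokens 16-8192, temperature 0-2, topP 0-1).
from comfy_api_nodes.apis import GeminiGenerationConfig

cfg = GeminiGenerationConfig(
    maxOutputTokens=2048,  # roughly 60-80 words per 100 tokens, per the field docs
    temperature=0.7,       # lower = less random token selection
    topK=40,
    topP=0.95,
)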
+
+class GeminiMimeType(str, Enum):
+ application_pdf = 'application/pdf'
+ audio_mpeg = 'audio/mpeg'
+ audio_mp3 = 'audio/mp3'
+ audio_wav = 'audio/wav'
+ image_png = 'image/png'
+ image_jpeg = 'image/jpeg'
+ image_webp = 'image/webp'
+ text_plain = 'text/plain'
+ video_mov = 'video/mov'
+ video_mpeg = 'video/mpeg'
+ video_mp4 = 'video/mp4'
+ video_mpg = 'video/mpg'
+ video_avi = 'video/avi'
+ video_wmv = 'video/wmv'
+ video_mpegps = 'video/mpegps'
+ video_flv = 'video/flv'
+
+
+class GeminiOffset(BaseModel):
+ nanos: Optional[int] = Field(
+ None,
+ description='Signed fractions of a second at nanosecond resolution. Negative second values with fractions must still have non-negative nanos values.\n',
+ examples=[0],
+ ge=0,
+ le=999999999,
+ )
+ seconds: Optional[int] = Field(
+ None,
+ description='Signed seconds of the span of time. Must be from -315,576,000,000 to +315,576,000,000 inclusive.\n',
+ examples=[60],
+ ge=-315576000000,
+ le=315576000000,
+ )
+
+
+class GeminiSafetyCategory(str, Enum):
+ HARM_CATEGORY_SEXUALLY_EXPLICIT = 'HARM_CATEGORY_SEXUALLY_EXPLICIT'
+ HARM_CATEGORY_HATE_SPEECH = 'HARM_CATEGORY_HATE_SPEECH'
+ HARM_CATEGORY_HARASSMENT = 'HARM_CATEGORY_HARASSMENT'
+ HARM_CATEGORY_DANGEROUS_CONTENT = 'HARM_CATEGORY_DANGEROUS_CONTENT'
+
+
+class Probability(str, Enum):
+ NEGLIGIBLE = 'NEGLIGIBLE'
+ LOW = 'LOW'
+ MEDIUM = 'MEDIUM'
+ HIGH = 'HIGH'
+ UNKNOWN = 'UNKNOWN'
+
+
+class GeminiSafetyRating(BaseModel):
+ category: Optional[GeminiSafetyCategory] = None
+ probability: Optional[Probability] = Field(
+ None,
+ description='The probability that the content violates the specified safety category',
+ )
+
+
+class GeminiSafetyThreshold(str, Enum):
+ OFF = 'OFF'
+ BLOCK_NONE = 'BLOCK_NONE'
+ BLOCK_LOW_AND_ABOVE = 'BLOCK_LOW_AND_ABOVE'
+ BLOCK_MEDIUM_AND_ABOVE = 'BLOCK_MEDIUM_AND_ABOVE'
+ BLOCK_ONLY_HIGH = 'BLOCK_ONLY_HIGH'
+
+
+class GeminiTextPart(BaseModel):
+ text: Optional[str] = Field(
+ None,
+ description='A text prompt or code snippet.',
+ examples=['Answer as concisely as possible'],
+ )
+
+
+class GeminiTool(BaseModel):
+ functionDeclarations: Optional[List[GeminiFunctionDeclaration]] = None
+
+
+class GeminiVideoMetadata(BaseModel):
+ endOffset: Optional[GeminiOffset] = None
+ startOffset: Optional[GeminiOffset] = None
+
+
+class IdeogramColorPalette1(BaseModel):
+ name: str = Field(..., description='Name of the preset color palette')
+
+
+class Member(BaseModel):
+ color: Optional[str] = Field(
+ None, description='Hexadecimal color code', pattern='^#[0-9A-Fa-f]{6}$'
+ )
+ weight: Optional[float] = Field(
+ None, description='Optional weight for the color (0-1)', ge=0.0, le=1.0
+ )
+
+
+class IdeogramColorPalette2(BaseModel):
+ members: List[Member] = Field(
+ ..., description='Array of color definitions with optional weights'
+ )
+
+
+class IdeogramColorPalette(
+ RootModel[Union[IdeogramColorPalette1, IdeogramColorPalette2]]
+):
+ root: Union[IdeogramColorPalette1, IdeogramColorPalette2] = Field(
+ ...,
+ description='A color palette specification that can either use a preset name or explicit color definitions with weights',
+ )
+
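# Example (sketch, not part of this diff): both palette shapes validate through
# the same root model -- a preset name, or explicit members with weights.
from comfy_api_nodes.apis import IdeogramColorPalette

preset = IdeogramColorPalette.model_validate({"name": "PASTEL"})
explicit = IdeogramColorPalette.model_validate(
    {"members": [{"color": "#FF8800", "weight": 0.7}, {"color": "#0044FF"}]}
)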
+
+class ImageRequest(BaseModel):
+ aspect_ratio: Optional[str] = Field(
+ None,
+ description="Optional. The aspect ratio (e.g., 'ASPECT_16_9', 'ASPECT_1_1'). Cannot be used with resolution. Defaults to 'ASPECT_1_1' if unspecified.",
+ )
+ color_palette: Optional[Dict[str, Any]] = Field(
+ None, description='Optional. Color palette object. Only for V_2, V_2_TURBO.'
+ )
+ magic_prompt_option: Optional[str] = Field(
+ None, description="Optional. MagicPrompt usage ('AUTO', 'ON', 'OFF')."
+ )
+ model: str = Field(..., description="The model used (e.g., 'V_2', 'V_2A_TURBO')")
+ negative_prompt: Optional[str] = Field(
+ None,
+ description='Optional. Description of what to exclude. Only for V_1, V_1_TURBO, V_2, V_2_TURBO.',
+ )
+ num_images: Optional[int] = Field(
+ 1,
+ description='Optional. Number of images to generate (1-8). Defaults to 1.',
+ ge=1,
+ le=8,
+ )
+ prompt: str = Field(
+ ..., description='Required. The prompt to use to generate the image.'
+ )
+ resolution: Optional[str] = Field(
+ None,
+ description="Optional. Resolution (e.g., 'RESOLUTION_1024_1024'). Only for model V_2. Cannot be used with aspect_ratio.",
+ )
+ seed: Optional[int] = Field(
+ None,
+ description='Optional. A number between 0 and 2147483647.',
+ ge=0,
+ le=2147483647,
+ )
+ style_type: Optional[str] = Field(
+ None,
+ description="Optional. Style type ('AUTO', 'GENERAL', 'REALISTIC', 'DESIGN', 'RENDER_3D', 'ANIME'). Only for models V_2 and above.",
+ )
+
+
+class IdeogramGenerateRequest(BaseModel):
+ image_request: ImageRequest = Field(
+ ..., description='The image generation request parameters.'
+ )
+
+
+class Datum(BaseModel):
+ is_image_safe: Optional[bool] = Field(
+ None, description='Indicates whether the image is considered safe.'
+ )
+ prompt: Optional[str] = Field(
+ None, description='The prompt used to generate this image.'
+ )
+ resolution: Optional[str] = Field(
+ None, description="The resolution of the generated image (e.g., '1024x1024')."
+ )
+ seed: Optional[int] = Field(
+ None, description='The seed value used for this generation.'
+ )
+ style_type: Optional[str] = Field(
+ None,
+ description="The style type used for generation (e.g., 'REALISTIC', 'ANIME').",
+ )
+ url: Optional[str] = Field(None, description='URL to the generated image.')
+
+
+class IdeogramGenerateResponse(BaseModel):
+ created: Optional[datetime] = Field(
+ None, description='Timestamp when the generation was created.'
+ )
+ data: Optional[List[Datum]] = Field(
+ None, description='Array of generated image information.'
+ )
+
+
+class StyleCode(RootModel[str]):
+ root: str = Field(..., pattern='^[0-9A-Fa-f]{8}$')
+
+
+class Datum1(BaseModel):
+ is_image_safe: Optional[bool] = None
+ prompt: Optional[str] = None
+ resolution: Optional[str] = None
+ seed: Optional[int] = None
+ style_type: Optional[str] = None
+ url: Optional[str] = None
+
+
+class IdeogramV3IdeogramResponse(BaseModel):
+ created: Optional[datetime] = None
+ data: Optional[List[Datum1]] = None
+
+
+class RenderingSpeed1(str, Enum):
+ TURBO = 'TURBO'
+ DEFAULT = 'DEFAULT'
+ QUALITY = 'QUALITY'
+
+
+class IdeogramV3ReframeRequest(BaseModel):
+ color_palette: Optional[Dict[str, Any]] = None
+ image: Optional[StrictBytes] = None
+ num_images: Optional[int] = Field(None, ge=1, le=8)
+ rendering_speed: Optional[RenderingSpeed1] = None
+ resolution: str
+ seed: Optional[int] = Field(None, ge=0, le=2147483647)
+ style_codes: Optional[List[str]] = None
+ style_reference_images: Optional[List[StrictBytes]] = None
+
+
+class MagicPrompt(str, Enum):
+ AUTO = 'AUTO'
+ ON = 'ON'
+ OFF = 'OFF'
+
+
+class StyleType(str, Enum):
+ AUTO = 'AUTO'
+ GENERAL = 'GENERAL'
+ REALISTIC = 'REALISTIC'
+ DESIGN = 'DESIGN'
+
+
+class IdeogramV3RemixRequest(BaseModel):
+ aspect_ratio: Optional[str] = None
+ color_palette: Optional[Dict[str, Any]] = None
+ image: Optional[StrictBytes] = None
+ image_weight: Optional[int] = Field(50, ge=1, le=100)
+ magic_prompt: Optional[MagicPrompt] = None
+ negative_prompt: Optional[str] = None
+ num_images: Optional[int] = Field(None, ge=1, le=8)
+ prompt: str
+ rendering_speed: Optional[RenderingSpeed1] = None
+ resolution: Optional[str] = None
+ seed: Optional[int] = Field(None, ge=0, le=2147483647)
+ style_codes: Optional[List[str]] = None
+ style_reference_images: Optional[List[StrictBytes]] = None
+ style_type: Optional[StyleType] = None
+
+
+class IdeogramV3ReplaceBackgroundRequest(BaseModel):
+ color_palette: Optional[Dict[str, Any]] = None
+ image: Optional[StrictBytes] = None
+ magic_prompt: Optional[MagicPrompt] = None
+ num_images: Optional[int] = Field(None, ge=1, le=8)
+ prompt: str
+ rendering_speed: Optional[RenderingSpeed1] = None
+ seed: Optional[int] = Field(None, ge=0, le=2147483647)
+ style_codes: Optional[List[str]] = None
+ style_reference_images: Optional[List[StrictBytes]] = None
+
+
+class ColorPalette(BaseModel):
+ name: str = Field(..., description='Name of the color palette', examples=['PASTEL'])
+
+
+class MagicPrompt2(str, Enum):
+ ON = 'ON'
+ OFF = 'OFF'
+
+
+class StyleType1(str, Enum):
+ GENERAL = 'GENERAL'
+
+
+class ImagenImageGenerationInstance(BaseModel):
+ prompt: str = Field(..., description='Text prompt for image generation')
+
+
+class AspectRatio(str, Enum):
+ field_1_1 = '1:1'
+ field_9_16 = '9:16'
+ field_16_9 = '16:9'
+ field_3_4 = '3:4'
+ field_4_3 = '4:3'
+
+
+class PersonGeneration(str, Enum):
+ dont_allow = 'dont_allow'
+ allow_adult = 'allow_adult'
+ allow_all = 'allow_all'
+
+
+class SafetySetting(str, Enum):
+ block_most = 'block_most'
+ block_some = 'block_some'
+ block_few = 'block_few'
+ block_fewest = 'block_fewest'
+
+
+class ImagenImagePrediction(BaseModel):
+ bytesBase64Encoded: Optional[str] = Field(
+ None, description='Base64-encoded image content'
+ )
+ mimeType: Optional[str] = Field(
+ None, description='MIME type of the generated image'
+ )
+ prompt: Optional[str] = Field(
+ None, description='Enhanced or rewritten prompt used to generate this image'
+ )
+
+
+class MimeType(str, Enum):
+ image_png = 'image/png'
+ image_jpeg = 'image/jpeg'
+
+
+class ImagenOutputOptions(BaseModel):
+ compressionQuality: Optional[int] = Field(None, ge=0, le=100)
+ mimeType: Optional[MimeType] = None
+
+
+class Includable(str, Enum):
+ file_search_call_results = 'file_search_call.results'
+ message_input_image_image_url = 'message.input_image.image_url'
+ computer_call_output_output_image_url = 'computer_call_output.output.image_url'
+
+
+class Type7(str, Enum):
+ input_file = 'input_file'
+
+
+class InputFileContent(BaseModel):
+ file_data: Optional[str] = Field(
+ None, description='The content of the file to be sent to the model.\n'
+ )
+ file_id: Optional[str] = Field(
+ None, description='The ID of the file to be sent to the model.'
+ )
+ filename: Optional[str] = Field(
+ None, description='The name of the file to be sent to the model.'
+ )
+ type: Type7 = Field(
+ ..., description='The type of the input item. Always `input_file`.'
+ )
+
+
+class Detail(str, Enum):
+ low = 'low'
+ high = 'high'
+ auto = 'auto'
+
+
+class Type8(str, Enum):
+ input_image = 'input_image'
+
+
+class InputImageContent(BaseModel):
+ detail: Detail = Field(
+ ...,
+ description='The detail level of the image to be sent to the model. One of `high`, `low`, or `auto`. Defaults to `auto`.',
+ )
+ file_id: Optional[str] = Field(
+ None, description='The ID of the file to be sent to the model.'
+ )
+ image_url: Optional[str] = Field(
+ None,
+ description='The URL of the image to be sent to the model. A fully qualified URL or base64 encoded image in a data URL.',
+ )
+ type: Type8 = Field(
+ ..., description='The type of the input item. Always `input_image`.'
+ )
+
+
+class Role3(str, Enum):
+ user = 'user'
+ system = 'system'
+ developer = 'developer'
+
+
+class Type9(str, Enum):
+ message = 'message'
+
+
+class Type10(str, Enum):
+ input_text = 'input_text'
+
+
+class InputTextContent(BaseModel):
+ text: str = Field(..., description='The text input to the model.')
+ type: Type10 = Field(
+ ..., description='The type of the input item. Always `input_text`.'
+ )
+
+
+class KlingAudioUploadType(str, Enum):
+ file = 'file'
+ url = 'url'
+
+
+class KlingCameraConfig(BaseModel):
+ horizontal: Optional[float] = Field(
+ None,
+ description="Controls camera's movement along horizontal axis (x-axis). Negative indicates left, positive indicates right.",
+ ge=-10.0,
+ le=10.0,
+ )
+ pan: Optional[float] = Field(
+ None,
+ description="Controls camera's rotation in vertical plane (x-axis). Negative indicates downward rotation, positive indicates upward rotation.",
+ ge=-10.0,
+ le=10.0,
+ )
+ roll: Optional[float] = Field(
+ None,
+ description="Controls camera's rolling amount (z-axis). Negative indicates counterclockwise, positive indicates clockwise.",
+ ge=-10.0,
+ le=10.0,
+ )
+ tilt: Optional[float] = Field(
+ None,
+ description="Controls camera's rotation in horizontal plane (y-axis). Negative indicates left rotation, positive indicates right rotation.",
+ ge=-10.0,
+ le=10.0,
+ )
+ vertical: Optional[float] = Field(
+ None,
+ description="Controls camera's movement along vertical axis (y-axis). Negative indicates downward, positive indicates upward.",
+ ge=-10.0,
+ le=10.0,
+ )
+ zoom: Optional[float] = Field(
+ None,
+ description="Controls change in camera's focal length. Negative indicates narrower field of view, positive indicates wider field of view.",
+ ge=-10.0,
+ le=10.0,
+ )
+
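# Example (sketch, not part of this diff): a slow rightward dolly with a gentle
# zoom-in (negative zoom narrows the field of view, per the docs above). Every
# axis accepts -10..10, and unset axes stay None.
from comfy_api_nodes.apis import KlingCameraConfig

camera = KlingCameraConfig(horizontal=2.5, zoom=-1.0)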
+
+class KlingCameraControlType(str, Enum):
+ simple = 'simple'
+ down_back = 'down_back'
+ forward_up = 'forward_up'
+ right_turn_forward = 'right_turn_forward'
+ left_turn_forward = 'left_turn_forward'
+
+
+class KlingCharacterEffectModelName(str, Enum):
+ kling_v1 = 'kling-v1'
+ kling_v1_5 = 'kling-v1-5'
+ kling_v1_6 = 'kling-v1-6'
+
+
+class KlingDualCharacterEffectsScene(str, Enum):
+ hug = 'hug'
+ kiss = 'kiss'
+ heart_gesture = 'heart_gesture'
+
+
+class KlingDualCharacterImages(RootModel[List[str]]):
+ root: List[str] = Field(..., max_length=2, min_length=2)
+
+
+class KlingErrorResponse(BaseModel):
+ code: int = Field(
+ ...,
+ description='- 1000: Authentication failed\n- 1001: Authorization is empty\n- 1002: Authorization is invalid\n- 1003: Authorization is not yet valid\n- 1004: Authorization has expired\n- 1100: Account exception\n- 1101: Account in arrears (postpaid scenario)\n- 1102: Resource pack depleted or expired (prepaid scenario)\n- 1103: Unauthorized access to requested resource\n- 1200: Invalid request parameters\n- 1201: Invalid parameters\n- 1202: Invalid request method\n- 1203: Requested resource does not exist\n- 1300: Trigger platform strategy\n- 1301: Trigger content security policy\n- 1302: API request too frequent\n- 1303: Concurrency/QPS exceeds limit\n- 1304: Trigger IP whitelist policy\n- 5000: Internal server error\n- 5001: Service temporarily unavailable\n- 5002: Server internal timeout\n',
+ )
+ message: str = Field(..., description='Human-readable error message')
+ request_id: str = Field(
+ ..., description='Request ID for tracking and troubleshooting'
+ )
+
+
+class Trajectory(BaseModel):
+ x: Optional[int] = Field(
+ None,
+ description='The horizontal coordinate of trajectory point. Based on bottom-left corner of image as origin (0,0).',
+ )
+ y: Optional[int] = Field(
+ None,
+ description='The vertical coordinate of trajectory point. Based on bottom-left corner of image as origin (0,0).',
+ )
+
+
+class DynamicMask(BaseModel):
+ mask: Optional[AnyUrl] = Field(
+ None,
+ description='Dynamic Brush Application Area (Mask image created by users using the motion brush). The aspect ratio must match the input image.',
+ )
+ trajectories: Optional[List[Trajectory]] = None
+
+
+class TaskInfo(BaseModel):
+ external_task_id: Optional[str] = None
+
+
+class KlingImageGenAspectRatio(str, Enum):
+ field_16_9 = '16:9'
+ field_9_16 = '9:16'
+ field_1_1 = '1:1'
+ field_4_3 = '4:3'
+ field_3_4 = '3:4'
+ field_3_2 = '3:2'
+ field_2_3 = '2:3'
+ field_21_9 = '21:9'
+
+
+class KlingImageGenImageReferenceType(str, Enum):
+ subject = 'subject'
+ face = 'face'
+
+
+class KlingImageGenModelName(str, Enum):
+ kling_v1 = 'kling-v1'
+ kling_v1_5 = 'kling-v1-5'
+ kling_v2 = 'kling-v2'
+
+
+class KlingImageGenerationsRequest(BaseModel):
+ aspect_ratio: Optional[KlingImageGenAspectRatio] = '16:9'
+ callback_url: Optional[AnyUrl] = Field(
+ None, description='The callback notification address'
+ )
+ human_fidelity: Optional[float] = Field(
+ 0.45, description='Subject reference similarity', ge=0.0, le=1.0
+ )
+ image: Optional[str] = Field(
+ None, description='Reference Image - Base64 encoded string or image URL'
+ )
+ image_fidelity: Optional[float] = Field(
+ 0.5, description='Reference intensity for user-uploaded images', ge=0.0, le=1.0
+ )
+ image_reference: Optional[KlingImageGenImageReferenceType] = None
+ model_name: Optional[KlingImageGenModelName] = 'kling-v1'
+ n: Optional[int] = Field(1, description='Number of generated images', ge=1, le=9)
+ negative_prompt: Optional[str] = Field(
+ None, description='Negative text prompt', max_length=200
+ )
+ prompt: str = Field(..., description='Positive text prompt', max_length=500)
+
+
+class KlingImageResult(BaseModel):
+ index: Optional[int] = Field(None, description='Image Number (0-9)')
+ url: Optional[AnyUrl] = Field(None, description='URL for generated image')
+
+
+class KlingLipSyncMode(str, Enum):
+ text2video = 'text2video'
+ audio2video = 'audio2video'
+
+
+class KlingLipSyncVoiceLanguage(str, Enum):
+ zh = 'zh'
+ en = 'en'
+
+
+class ResourcePackType(str, Enum):
+ decreasing_total = 'decreasing_total'
+ constant_period = 'constant_period'
+
+
+class Status4(str, Enum):
+ toBeOnline = 'toBeOnline'
+ online = 'online'
+ expired = 'expired'
+ runOut = 'runOut'
+
+
+class ResourcePackSubscribeInfo(BaseModel):
+ effective_time: Optional[int] = Field(
+ None, description='Effective time, Unix timestamp in ms'
+ )
+ invalid_time: Optional[int] = Field(
+ None, description='Expiration time, Unix timestamp in ms'
+ )
+ purchase_time: Optional[int] = Field(
+ None, description='Purchase time, Unix timestamp in ms'
+ )
+ remaining_quantity: Optional[float] = Field(
+ None, description='Remaining quantity (updated with a 12-hour delay)'
+ )
+ resource_pack_id: Optional[str] = Field(None, description='Resource package ID')
+ resource_pack_name: Optional[str] = Field(None, description='Resource package name')
+ resource_pack_type: Optional[ResourcePackType] = Field(
+ None,
+ description='Resource package type (decreasing_total=decreasing total, constant_period=constant periodicity)',
+ )
+ status: Optional[Status4] = Field(None, description='Resource Package Status')
+ total_quantity: Optional[float] = Field(None, description='Total quantity')
+
+
+class Data3(BaseModel):
+ code: Optional[int] = Field(None, description='Error code; 0 indicates success')
+ msg: Optional[str] = Field(None, description='Error information')
+ resource_pack_subscribe_infos: Optional[List[ResourcePackSubscribeInfo]] = Field(
+ None, description='Resource package list'
+ )
+
+
+class KlingResourcePackageResponse(BaseModel):
+ code: Optional[int] = Field(None, description='Error code; 0 indicates success')
+ data: Optional[Data3] = None
+ message: Optional[str] = Field(None, description='Error information')
+ request_id: Optional[str] = Field(
+ None,
+ description='Request ID, generated by the system, used to track requests and troubleshoot problems',
+ )
+
+
+class KlingSingleImageEffectDuration(str, Enum):
+ field_5 = '5'
+
+
+class KlingSingleImageEffectModelName(str, Enum):
+ kling_v1_6 = 'kling-v1-6'
+
+
+class KlingSingleImageEffectsScene(str, Enum):
+ bloombloom = 'bloombloom'
+ dizzydizzy = 'dizzydizzy'
+ fuzzyfuzzy = 'fuzzyfuzzy'
+ squish = 'squish'
+ expansion = 'expansion'
+
+
+class KlingTaskStatus(str, Enum):
+ submitted = 'submitted'
+ processing = 'processing'
+ succeed = 'succeed'
+ failed = 'failed'
+
+
+class KlingTextToVideoModelName(str, Enum):
+ kling_v1 = 'kling-v1'
+ kling_v1_6 = 'kling-v1-6'
+
+
+class KlingVideoGenAspectRatio(str, Enum):
+ field_16_9 = '16:9'
+ field_9_16 = '9:16'
+ field_1_1 = '1:1'
+
+
+class KlingVideoGenCfgScale(RootModel[float]):
+ root: float = Field(
+ ...,
+ description="Flexibility in video generation. The higher the value, the lower the model's degree of flexibility, and the stronger the relevance to the user's prompt.",
+ ge=0.0,
+ le=1.0,
+ )
+
+
+class KlingVideoGenDuration(str, Enum):
+ field_5 = '5'
+ field_10 = '10'
+
+
+class KlingVideoGenMode(str, Enum):
+ std = 'std'
+ pro = 'pro'
+
+
+class KlingVideoGenModelName(str, Enum):
+ kling_v1 = 'kling-v1'
+ kling_v1_5 = 'kling-v1-5'
+ kling_v1_6 = 'kling-v1-6'
+ kling_v2_master = 'kling-v2-master'
+
+
+class KlingVideoResult(BaseModel):
+ duration: Optional[str] = Field(None, description='Total video duration')
+ id: Optional[str] = Field(None, description='Generated video ID')
+ url: Optional[AnyUrl] = Field(None, description='URL for generated video')
+
+
+class KlingVirtualTryOnModelName(str, Enum):
+ kolors_virtual_try_on_v1 = 'kolors-virtual-try-on-v1'
+ kolors_virtual_try_on_v1_5 = 'kolors-virtual-try-on-v1-5'
+
+
+class KlingVirtualTryOnRequest(BaseModel):
+ callback_url: Optional[AnyUrl] = Field(
+ None, description='The callback notification address'
+ )
+ cloth_image: Optional[str] = Field(
+ None,
+ description='Reference clothing image - Base64 encoded string or image URL',
+ )
+ human_image: str = Field(
+ ..., description='Reference human image - Base64 encoded string or image URL'
+ )
+ model_name: Optional[KlingVirtualTryOnModelName] = 'kolors-virtual-try-on-v1'
+
+
+class TaskResult6(BaseModel):
+ images: Optional[List[KlingImageResult]] = None
+
+
+class Data7(BaseModel):
+ created_at: Optional[int] = Field(None, description='Task creation time')
+ task_id: Optional[str] = Field(None, description='Task ID')
+ task_result: Optional[TaskResult6] = None
+ task_status: Optional[KlingTaskStatus] = None
+ task_status_msg: Optional[str] = Field(None, description='Task status information')
+ updated_at: Optional[int] = Field(None, description='Task update time')
+
+
+class KlingVirtualTryOnResponse(BaseModel):
+ code: Optional[int] = Field(None, description='Error code')
+ data: Optional[Data7] = None
+ message: Optional[str] = Field(None, description='Error message')
+ request_id: Optional[str] = Field(None, description='Request ID')
+
+
+class LumaAspectRatio(str, Enum):
+ field_1_1 = '1:1'
+ field_16_9 = '16:9'
+ field_9_16 = '9:16'
+ field_4_3 = '4:3'
+ field_3_4 = '3:4'
+ field_21_9 = '21:9'
+ field_9_21 = '9:21'
+
+
+class LumaAssets(BaseModel):
+ image: Optional[AnyUrl] = Field(None, description='The URL of the image')
+ progress_video: Optional[AnyUrl] = Field(
+ None, description='The URL of the progress video'
+ )
+ video: Optional[AnyUrl] = Field(None, description='The URL of the video')
+
+
+class GenerationType(str, Enum):
+ add_audio = 'add_audio'
+
+
+class LumaAudioGenerationRequest(BaseModel):
+ callback_url: Optional[AnyUrl] = Field(
+ None, description='The callback URL for the audio'
+ )
+ generation_type: Optional[GenerationType] = 'add_audio'
+ negative_prompt: Optional[str] = Field(
+ None, description='The negative prompt of the audio'
+ )
+ prompt: Optional[str] = Field(None, description='The prompt of the audio')
+
+
+class LumaError(BaseModel):
+ detail: Optional[str] = Field(None, description='The error message')
+
+
+class Type11(str, Enum):
+ generation = 'generation'
+
+
+class LumaGenerationReference(BaseModel):
+ id: UUID = Field(..., description='The ID of the generation')
+ type: Literal['generation']
+
+
+class GenerationType1(str, Enum):
+ video = 'video'
+
+
+class LumaGenerationType(str, Enum):
+ video = 'video'
+ image = 'image'
+
+
+class GenerationType2(str, Enum):
+ image = 'image'
+
+
+class LumaImageIdentity(BaseModel):
+ images: Optional[List[AnyUrl]] = Field(
+ None, description='The URLs of the image identity'
+ )
+
+
+class LumaImageModel(str, Enum):
+ photon_1 = 'photon-1'
+ photon_flash_1 = 'photon-flash-1'
+
+
+class LumaImageRef(BaseModel):
+ url: Optional[AnyUrl] = Field(None, description='The URL of the image reference')
+ weight: Optional[float] = Field(
+ None, description='The weight of the image reference'
+ )
+
+
+class Type12(str, Enum):
+ image = 'image'
+
+
+class LumaImageReference(BaseModel):
+ type: Literal['image']
+ url: AnyUrl = Field(..., description='The URL of the image')
+
+
+class LumaKeyframe(RootModel[Union[LumaGenerationReference, LumaImageReference]]):
+ root: Union[LumaGenerationReference, LumaImageReference] = Field(
+ ...,
+ description='A keyframe can be either a Generation reference, an Image, or a Video',
+ discriminator='type',
+ )
+
+
+class LumaKeyframes(BaseModel):
+ frame0: Optional[LumaKeyframe] = None
+ frame1: Optional[LumaKeyframe] = None
+
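# Example (sketch, not part of this diff): keyframes are a discriminated union
# on `type`, so a plain dict picks the right variant during validation.
from comfy_api_nodes.apis import LumaKeyframe, LumaKeyframes

frame0 = LumaKeyframe.model_validate(
    {"type": "image", "url": "https://example.com/start.png"}
)
keyframes = LumaKeyframes(frame0=frame0)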
+
+class LumaModifyImageRef(BaseModel):
+ url: Optional[AnyUrl] = Field(None, description='The URL of the image reference')
+ weight: Optional[float] = Field(
+ None, description='The weight of the modify image reference'
+ )
+
+
+class LumaState(str, Enum):
+ queued = 'queued'
+ dreaming = 'dreaming'
+ completed = 'completed'
+ failed = 'failed'
+
+
+class GenerationType3(str, Enum):
+ upscale_video = 'upscale_video'
+
+
+class LumaVideoModel(str, Enum):
+ ray_2 = 'ray-2'
+ ray_flash_2 = 'ray-flash-2'
+ ray_1_6 = 'ray-1-6'
+
+
+class LumaVideoModelOutputDuration1(str, Enum):
+ field_5s = '5s'
+ field_9s = '9s'
+
+
+class LumaVideoModelOutputDuration(
+ RootModel[Union[LumaVideoModelOutputDuration1, str]]
+):
+ root: Union[LumaVideoModelOutputDuration1, str]
+
+
+class LumaVideoModelOutputResolution1(str, Enum):
+ field_540p = '540p'
+ field_720p = '720p'
+ field_1080p = '1080p'
+ field_4k = '4k'
+
+
+class LumaVideoModelOutputResolution(
+ RootModel[Union[LumaVideoModelOutputResolution1, str]]
+):
+ root: Union[LumaVideoModelOutputResolution1, str]
+
+
+class MinimaxBaseResponse(BaseModel):
+ status_code: int = Field(
+ ...,
+ description='Status code. 0 indicates success, other values indicate errors.',
+ )
+ status_msg: str = Field(
+ ..., description='Specific error details or success message.'
+ )
+
+
+class File(BaseModel):
+ bytes: Optional[int] = Field(None, description='File size in bytes')
+ created_at: Optional[int] = Field(
+ None, description='Unix timestamp when the file was created, in seconds'
+ )
+ download_url: Optional[str] = Field(
+ None, description='The URL to download the video'
+ )
+ file_id: Optional[int] = Field(None, description='Unique identifier for the file')
+ filename: Optional[str] = Field(None, description='The name of the file')
+ purpose: Optional[str] = Field(None, description='The purpose of using the file')
+
+
+class MinimaxFileRetrieveResponse(BaseModel):
+ base_resp: MinimaxBaseResponse
+ file: File
+
+
+class Status5(str, Enum):
+ Queueing = 'Queueing'
+ Preparing = 'Preparing'
+ Processing = 'Processing'
+ Success = 'Success'
+ Fail = 'Fail'
+
+
+class MinimaxTaskResultResponse(BaseModel):
+ base_resp: MinimaxBaseResponse
+ file_id: Optional[str] = Field(
+ None,
+ description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.',
+ )
+ status: Status5 = Field(
+ ...,
+ description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).",
+ )
+ task_id: str = Field(..., description='The task ID being queried.')
+
+
+class Model(str, Enum):
+ T2V_01_Director = 'T2V-01-Director'
+ I2V_01_Director = 'I2V-01-Director'
+ S2V_01 = 'S2V-01'
+ I2V_01 = 'I2V-01'
+ I2V_01_live = 'I2V-01-live'
+ T2V_01 = 'T2V-01'
+
+
+class SubjectReferenceItem(BaseModel):
+ image: Optional[str] = Field(
+ None, description='URL or base64 encoding of the subject reference image.'
+ )
+ mask: Optional[str] = Field(
+ None,
+ description='URL or base64 encoding of the mask for the subject reference image.',
+ )
+
+
+class MinimaxVideoGenerationRequest(BaseModel):
+ callback_url: Optional[str] = Field(
+ None,
+ description='Optional. URL to receive real-time status updates about the video generation task.',
+ )
+ first_frame_image: Optional[str] = Field(
+ None,
+ description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.',
+ )
+ model: Model = Field(
+ ...,
+ description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01',
+ )
+ prompt: Optional[str] = Field(
+ None,
+ description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].',
+ max_length=2000,
+ )
+ prompt_optimizer: Optional[bool] = Field(
+ True,
+ description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.',
+ )
+ subject_reference: Optional[List[SubjectReferenceItem]] = Field(
+ None,
+ description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.',
+ )
+
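# Example (sketch, not part of this diff): a text-to-video request. T2V models
# need no first_frame_image, and the Director variants accept bracketed camera
# directions inside the prompt, per the field docs above.
from comfy_api_nodes.apis import MinimaxVideoGenerationRequest, Model

req = MinimaxVideoGenerationRequest(
    model=Model.T2V_01_Director,
    prompt="[Pan right] a fishing village at sunrise",
)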
+
+class MinimaxVideoGenerationResponse(BaseModel):
+ base_resp: MinimaxBaseResponse
+ task_id: str = Field(
+ ..., description='The task ID for the asynchronous video generation task.'
+ )
+
+
+class Truncation(str, Enum):
+ disabled = 'disabled'
+ auto = 'auto'
+
+
+class ModelResponseProperties(BaseModel):
+ instructions: Optional[str] = Field(
+ None, description='Instructions for the model on how to generate the response'
+ )
+ max_output_tokens: Optional[int] = Field(
+ None, description='Maximum number of tokens to generate'
+ )
+ model: Optional[str] = Field(
+ None, description='The model used to generate the response'
+ )
+ temperature: Optional[float] = Field(
+ 1, description='Controls randomness in the response', ge=0.0, le=2.0
+ )
+ top_p: Optional[float] = Field(
+ 1,
+ description='Controls diversity of the response via nucleus sampling',
+ ge=0.0,
+ le=1.0,
+ )
+ truncation: Optional[Truncation] = Field(
+ 'disabled', description='How to handle truncation of the response'
+ )
+
+
+class Moderation(str, Enum):
+ low = 'low'
+ auto = 'auto'
+
+
+class OutputFormat1(str, Enum):
+ png = 'png'
+ webp = 'webp'
+ jpeg = 'jpeg'
+
+
+class OpenAIImageEditRequest(BaseModel):
+ background: Optional[str] = Field(
+ None, description='Background transparency', examples=['opaque']
+ )
+ model: str = Field(
+ ..., description='The model to use for image editing', examples=['gpt-image-1']
+ )
+ moderation: Optional[Moderation] = Field(
+ None, description='Content moderation setting', examples=['auto']
+ )
+ n: Optional[int] = Field(
+ None, description='The number of images to generate', examples=[1]
+ )
+ output_compression: Optional[int] = Field(
+ None, description='Compression level for JPEG or WebP (0-100)', examples=[100]
+ )
+ output_format: Optional[OutputFormat1] = Field(
+ None, description='Format of the output image', examples=['png']
+ )
+ prompt: str = Field(
+ ...,
+ description='A text description of the desired edit',
+ examples=['Give the rocketship rainbow coloring'],
+ )
+ quality: Optional[str] = Field(
+ None, description='The quality of the edited image', examples=['low']
+ )
+ size: Optional[str] = Field(
+ None, description='Size of the output image', examples=['1024x1024']
+ )
+ user: Optional[str] = Field(
+ None,
+ description='A unique identifier for end-user monitoring',
+ examples=['user-1234'],
+ )
+
+
+class Background(str, Enum):
+ transparent = 'transparent'
+ opaque = 'opaque'
+
+
+class Quality(str, Enum):
+ low = 'low'
+ medium = 'medium'
+ high = 'high'
+ standard = 'standard'
+ hd = 'hd'
+
+
+class ResponseFormat(str, Enum):
+ url = 'url'
+ b64_json = 'b64_json'
+
+
+class Style(str, Enum):
+ vivid = 'vivid'
+ natural = 'natural'
+
+
+class OpenAIImageGenerationRequest(BaseModel):
+ background: Optional[Background] = Field(
+ None, description='Background transparency', examples=['opaque']
+ )
+ model: Optional[str] = Field(
+ None, description='The model to use for image generation', examples=['dall-e-3']
+ )
+ moderation: Optional[Moderation] = Field(
+ None, description='Content moderation setting', examples=['auto']
+ )
+ n: Optional[int] = Field(
+ None,
+ description='The number of images to generate (1-10). Only 1 supported for dall-e-3.',
+ examples=[1],
+ )
+ output_compression: Optional[int] = Field(
+ None, description='Compression level for JPEG or WebP (0-100)', examples=[100]
+ )
+ output_format: Optional[OutputFormat1] = Field(
+ None, description='Format of the output image', examples=['png']
+ )
+ prompt: str = Field(
+ ...,
+ description='A text description of the desired image',
+ examples=['Draw a rocket in front of a blackhole in deep space'],
+ )
+ quality: Optional[Quality] = Field(
+ None, description='The quality of the generated image', examples=['high']
+ )
+ response_format: Optional[ResponseFormat] = Field(
+ None, description='Response format of image data', examples=['b64_json']
+ )
+ size: Optional[str] = Field(
+ None,
+ description='Size of the image (e.g., 1024x1024, 1536x1024, auto)',
+ examples=['1024x1536'],
+ )
+ style: Optional[Style] = Field(
+ None, description='Style of the image (only for dall-e-3)', examples=['vivid']
+ )
+ user: Optional[str] = Field(
+ None,
+ description='A unique identifier for end-user monitoring',
+ examples=['user-1234'],
+ )
+
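# Example (sketch, not part of this diff): a dall-e-3 style request; values
# follow the examples documented on the fields above.
from comfy_api_nodes.apis import OpenAIImageGenerationRequest, Quality, ResponseFormat

req = OpenAIImageGenerationRequest(
    model="dall-e-3",
    prompt="Draw a rocket in front of a blackhole in deep space",
    n=1,
    quality=Quality.hd,
    response_format=ResponseFormat.b64_json,
    size="1024x1024",
)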
+
+class Datum2(BaseModel):
+ b64_json: Optional[str] = Field(None, description='Base64 encoded image data')
+ revised_prompt: Optional[str] = Field(None, description='Revised prompt')
+ url: Optional[str] = Field(None, description='URL of the image')
+
+
+class InputTokensDetails(BaseModel):
+ image_tokens: Optional[int] = None
+ text_tokens: Optional[int] = None
+
+
+class Usage(BaseModel):
+ input_tokens: Optional[int] = None
+ input_tokens_details: Optional[InputTokensDetails] = None
+ output_tokens: Optional[int] = None
+ total_tokens: Optional[int] = None
+
+
+class OpenAIImageGenerationResponse(BaseModel):
+ data: Optional[List[Datum2]] = None
+ usage: Optional[Usage] = None
+
+
+class OpenAIModels(str, Enum):
+ gpt_4 = 'gpt-4'
+ gpt_4_0314 = 'gpt-4-0314'
+ gpt_4_0613 = 'gpt-4-0613'
+ gpt_4_32k = 'gpt-4-32k'
+ gpt_4_32k_0314 = 'gpt-4-32k-0314'
+ gpt_4_32k_0613 = 'gpt-4-32k-0613'
+ gpt_4_0125_preview = 'gpt-4-0125-preview'
+ gpt_4_turbo = 'gpt-4-turbo'
+ gpt_4_turbo_2024_04_09 = 'gpt-4-turbo-2024-04-09'
+ gpt_4_turbo_preview = 'gpt-4-turbo-preview'
+ gpt_4_1106_preview = 'gpt-4-1106-preview'
+ gpt_4_vision_preview = 'gpt-4-vision-preview'
+ gpt_3_5_turbo = 'gpt-3.5-turbo'
+ gpt_3_5_turbo_16k = 'gpt-3.5-turbo-16k'
+ gpt_3_5_turbo_0301 = 'gpt-3.5-turbo-0301'
+ gpt_3_5_turbo_0613 = 'gpt-3.5-turbo-0613'
+ gpt_3_5_turbo_1106 = 'gpt-3.5-turbo-1106'
+ gpt_3_5_turbo_0125 = 'gpt-3.5-turbo-0125'
+ gpt_3_5_turbo_16k_0613 = 'gpt-3.5-turbo-16k-0613'
+ gpt_4_1 = 'gpt-4.1'
+ gpt_4_1_mini = 'gpt-4.1-mini'
+ gpt_4_1_nano = 'gpt-4.1-nano'
+ gpt_4_1_2025_04_14 = 'gpt-4.1-2025-04-14'
+ gpt_4_1_mini_2025_04_14 = 'gpt-4.1-mini-2025-04-14'
+ gpt_4_1_nano_2025_04_14 = 'gpt-4.1-nano-2025-04-14'
+ o1 = 'o1'
+ o1_mini = 'o1-mini'
+ o1_preview = 'o1-preview'
+ o1_pro = 'o1-pro'
+ o1_2024_12_17 = 'o1-2024-12-17'
+ o1_preview_2024_09_12 = 'o1-preview-2024-09-12'
+ o1_mini_2024_09_12 = 'o1-mini-2024-09-12'
+ o1_pro_2025_03_19 = 'o1-pro-2025-03-19'
+ o3 = 'o3'
+ o3_mini = 'o3-mini'
+ o3_2025_04_16 = 'o3-2025-04-16'
+ o3_mini_2025_01_31 = 'o3-mini-2025-01-31'
+ o4_mini = 'o4-mini'
+ o4_mini_2025_04_16 = 'o4-mini-2025-04-16'
+ gpt_4o = 'gpt-4o'
+ gpt_4o_mini = 'gpt-4o-mini'
+ gpt_4o_2024_11_20 = 'gpt-4o-2024-11-20'
+ gpt_4o_2024_08_06 = 'gpt-4o-2024-08-06'
+ gpt_4o_2024_05_13 = 'gpt-4o-2024-05-13'
+ gpt_4o_mini_2024_07_18 = 'gpt-4o-mini-2024-07-18'
+ gpt_4o_audio_preview = 'gpt-4o-audio-preview'
+ gpt_4o_audio_preview_2024_10_01 = 'gpt-4o-audio-preview-2024-10-01'
+ gpt_4o_audio_preview_2024_12_17 = 'gpt-4o-audio-preview-2024-12-17'
+ gpt_4o_mini_audio_preview = 'gpt-4o-mini-audio-preview'
+ gpt_4o_mini_audio_preview_2024_12_17 = 'gpt-4o-mini-audio-preview-2024-12-17'
+ gpt_4o_search_preview = 'gpt-4o-search-preview'
+ gpt_4o_mini_search_preview = 'gpt-4o-mini-search-preview'
+ gpt_4o_search_preview_2025_03_11 = 'gpt-4o-search-preview-2025-03-11'
+ gpt_4o_mini_search_preview_2025_03_11 = 'gpt-4o-mini-search-preview-2025-03-11'
+ computer_use_preview = 'computer-use-preview'
+ computer_use_preview_2025_03_11 = 'computer-use-preview-2025-03-11'
+ chatgpt_4o_latest = 'chatgpt-4o-latest'
+
+
+class Reason(str, Enum):
+ max_output_tokens = 'max_output_tokens'
+ content_filter = 'content_filter'
+
+
+class IncompleteDetails(BaseModel):
+ reason: Optional[Reason] = Field(
+ None, description='The reason why the response is incomplete.'
+ )
+
+
+class Object(str, Enum):
+ response = 'response'
+
+
+class Status6(str, Enum):
+ completed = 'completed'
+ failed = 'failed'
+ in_progress = 'in_progress'
+ incomplete = 'incomplete'
+
+
+class Type13(str, Enum):
+ output_audio = 'output_audio'
+
+
+class OutputAudioContent(BaseModel):
+ data: str = Field(..., description='Base64-encoded audio data')
+ transcript: str = Field(..., description='Transcript of the audio')
+ type: Type13 = Field(..., description='The type of output content')
+
+
+class Role4(str, Enum):
+ assistant = 'assistant'
+
+
+class Type14(str, Enum):
+ message = 'message'
+
+
+class Type15(str, Enum):
+ output_text = 'output_text'
+
+
+class OutputTextContent(BaseModel):
+ text: str = Field(..., description='The text content')
+ type: Type15 = Field(..., description='The type of output content')
+
+
+class AspectRatio1(RootModel[float]):
+ root: float = Field(
+ ...,
+ description='Aspect ratio (width / height)',
+ ge=0.4,
+ le=2.5,
+ title='Aspectratio',
+ )
+
+
+class IngredientsMode(str, Enum):
+ creative = 'creative'
+ precise = 'precise'
+
+
+class PikaBodyGenerate22C2vGenerate22PikascenesPost(BaseModel):
+ aspectRatio: Optional[AspectRatio1] = Field(
+ None, description='Aspect ratio (width / height)', title='Aspectratio'
+ )
+ duration: Optional[int] = Field(5, title='Duration')
+ images: Optional[List[StrictBytes]] = Field(None, title='Images')
+ ingredientsMode: IngredientsMode = Field(..., title='Ingredientsmode')
+ negativePrompt: Optional[str] = Field(None, title='Negativeprompt')
+ promptText: Optional[str] = Field(None, title='Prompttext')
+ resolution: Optional[str] = Field('1080p', title='Resolution')
+ seed: Optional[int] = Field(None, title='Seed')
+
+
+class PikaBodyGeneratePikadditionsGeneratePikadditionsPost(BaseModel):
+ image: Optional[StrictBytes] = Field(None, title='Image')
+ negativePrompt: Optional[str] = Field(None, title='Negativeprompt')
+ promptText: Optional[str] = Field(None, title='Prompttext')
+ seed: Optional[int] = Field(None, title='Seed')
+ video: Optional[StrictBytes] = Field(None, title='Video')
+
+
+class PikaBodyGeneratePikaswapsGeneratePikaswapsPost(BaseModel):
+ image: Optional[StrictBytes] = Field(None, title='Image')
+ modifyRegionMask: Optional[StrictBytes] = Field(
+ None,
+ description='A mask image that specifies the region to modify, where the mask is white and the background is black',
+ title='Modifyregionmask',
+ )
+ modifyRegionRoi: Optional[str] = Field(
+ None,
+ description='Plaintext description of the object / region to modify',
+ title='Modifyregionroi',
+ )
+ negativePrompt: Optional[str] = Field(None, title='Negativeprompt')
+ promptText: Optional[str] = Field(None, title='Prompttext')
+ seed: Optional[int] = Field(None, title='Seed')
+ video: Optional[StrictBytes] = Field(None, title='Video')
+
+
+class PikaDurationEnum(int, Enum):
+ integer_5 = 5
+ integer_10 = 10
+
+
+class PikaGenerateResponse(BaseModel):
+ video_id: str = Field(..., title='Video Id')
+
+
+class PikaResolutionEnum(str, Enum):
+ field_1080p = '1080p'
+ field_720p = '720p'
+
+
+class PikaStatusEnum(str, Enum):
+ queued = 'queued'
+ started = 'started'
+ finished = 'finished'
+
+
+class PikaValidationError(BaseModel):
+ loc: List[Union[str, int]] = Field(..., title='Location')
+ msg: str = Field(..., title='Message')
+ type: str = Field(..., title='Error Type')
+
+
+class PikaVideoResponse(BaseModel):
+ id: str = Field(..., title='Id')
+ progress: Optional[int] = Field(None, title='Progress')
+ status: PikaStatusEnum
+ url: Optional[str] = Field(None, title='Url')
+
+
+class Pikaffect(str, Enum):
+ Cake_ify = 'Cake-ify'
+ Crumble = 'Crumble'
+ Crush = 'Crush'
+ Decapitate = 'Decapitate'
+ Deflate = 'Deflate'
+ Dissolve = 'Dissolve'
+ Explode = 'Explode'
+ Eye_pop = 'Eye-pop'
+ Inflate = 'Inflate'
+ Levitate = 'Levitate'
+ Melt = 'Melt'
+ Peel = 'Peel'
+ Poke = 'Poke'
+ Squish = 'Squish'
+ Ta_da = 'Ta-da'
+ Tear = 'Tear'
+
+
+class Resp(BaseModel):
+ img_id: Optional[int] = None
+
+
+class PixverseImageUploadResponse(BaseModel):
+ ErrCode: Optional[int] = None
+ ErrMsg: Optional[str] = None
+ Resp_1: Optional[Resp] = Field(None, alias='Resp')
+
+
+class Duration(int, Enum):
+ integer_5 = 5
+ integer_8 = 8
+
+
+class Model1(str, Enum):
+ v3_5 = 'v3.5'
+
+
+class MotionMode(str, Enum):
+ normal = 'normal'
+ fast = 'fast'
+
+
+class Quality1(str, Enum):
+ field_360p = '360p'
+ field_540p = '540p'
+ field_720p = '720p'
+ field_1080p = '1080p'
+
+
+class Style1(str, Enum):
+ anime = 'anime'
+ field_3d_animation = '3d_animation'
+ clay = 'clay'
+ comic = 'comic'
+ cyberpunk = 'cyberpunk'
+
+
+class PixverseImageVideoRequest(BaseModel):
+ duration: Duration
+ img_id: int
+ model: Model1
+ motion_mode: Optional[MotionMode] = None
+ prompt: str
+ quality: Quality1
+ seed: Optional[int] = None
+ style: Optional[Style1] = None
+ template_id: Optional[int] = None
+ water_mark: Optional[bool] = None
+
+
+class AspectRatio2(str, Enum):
+ field_16_9 = '16:9'
+ field_4_3 = '4:3'
+ field_1_1 = '1:1'
+ field_3_4 = '3:4'
+ field_9_16 = '9:16'
+
+
+class PixverseTextVideoRequest(BaseModel):
+ aspect_ratio: AspectRatio2
+ duration: Duration
+ model: Model1
+ motion_mode: Optional[MotionMode] = None
+ negative_prompt: Optional[str] = None
+ prompt: str
+ quality: Quality1
+ seed: Optional[int] = None
+ style: Optional[Style1] = None
+ template_id: Optional[int] = None
+ water_mark: Optional[bool] = None
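+
+# Construction sketch (values illustrative). Optional fields can be omitted;
+# `model_dump(exclude_none=True)` yields a JSON-ready request body.
+#
+#     req = PixverseTextVideoRequest(
+#         aspect_ratio=AspectRatio2.field_16_9,
+#         duration=Duration.integer_5,
+#         model=Model1.v3_5,
+#         prompt='a red fox running through fresh snow',
+#         quality=Quality1.field_540p,
+#     )
+#     payload = req.model_dump(exclude_none=True)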
+
+
+class PixverseTransitionVideoRequest(BaseModel):
+ duration: Duration
+ first_frame_img: int
+ last_frame_img: int
+ model: Model1
+ motion_mode: MotionMode
+ prompt: str
+ quality: Quality1
+ seed: int
+ style: Optional[Style1] = None
+ template_id: Optional[int] = None
+ water_mark: Optional[bool] = None
+
+
+class Resp1(BaseModel):
+ video_id: Optional[int] = None
+
+
+class PixverseVideoResponse(BaseModel):
+ ErrCode: Optional[int] = None
+ ErrMsg: Optional[str] = None
+ Resp: Optional[Resp1] = None
+
+
+class Status7(int, Enum):
+ integer_1 = 1
+ integer_5 = 5
+ integer_6 = 6
+ integer_7 = 7
+ integer_8 = 8
+
+
+class Resp2(BaseModel):
+ create_time: Optional[str] = None
+ id: Optional[int] = None
+ modify_time: Optional[str] = None
+ negative_prompt: Optional[str] = None
+ outputHeight: Optional[int] = None
+ outputWidth: Optional[int] = None
+ prompt: Optional[str] = None
+ resolution_ratio: Optional[int] = None
+ seed: Optional[int] = None
+ size: Optional[int] = None
+ status: Optional[Status7] = Field(
+ None,
+        description='Video generation status codes:\n* 1 - Generation successful\n* 5 - Generating\n* 6 - Deleted\n* 7 - Content moderation failed\n* 8 - Generation failed\n',
+ )
+ style: Optional[str] = None
+ url: Optional[str] = None
+
+
+class PixverseVideoResultResponse(BaseModel):
+ ErrCode: Optional[int] = None
+ ErrMsg: Optional[str] = None
+ Resp: Optional[Resp2] = None
+
+
+class RgbItem(RootModel[int]):
+ root: int = Field(..., ge=0, le=255)
+
+
+class RGBColor(BaseModel):
+ rgb: List[RgbItem] = Field(..., max_length=3, min_length=3)
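+
+# Note: each channel is an RgbItem root model, but pydantic v2 coerces plain
+# ints into root models, so both sketches below validate (exactly three
+# channels, each 0-255):
+#
+#     RGBColor(rgb=[255, 128, 0])
+#     RGBColor(rgb=[RgbItem(255), RgbItem(128), RgbItem(0)])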
+
+
+class GenerateSummary(str, Enum):
+ auto = 'auto'
+ concise = 'concise'
+ detailed = 'detailed'
+
+
+class Summary(str, Enum):
+ auto = 'auto'
+ concise = 'concise'
+ detailed = 'detailed'
+
+
+class ReasoningEffort(str, Enum):
+ low = 'low'
+ medium = 'medium'
+ high = 'high'
+
+
+class Status8(str, Enum):
+ in_progress = 'in_progress'
+ completed = 'completed'
+ incomplete = 'incomplete'
+
+
+class Type16(str, Enum):
+ summary_text = 'summary_text'
+
+
+class SummaryItem(BaseModel):
+ text: str = Field(
+ ...,
+ description='A short summary of the reasoning used by the model when generating\nthe response.\n',
+ )
+ type: Type16 = Field(
+ ..., description='The type of the object. Always `summary_text`.\n'
+ )
+
+
+class Type17(str, Enum):
+ reasoning = 'reasoning'
+
+
+class ReasoningItem(BaseModel):
+ id: str = Field(
+ ..., description='The unique identifier of the reasoning content.\n'
+ )
+ status: Optional[Status8] = Field(
+ None,
+ description='The status of the item. One of `in_progress`, `completed`, or\n`incomplete`. Populated when items are returned via API.\n',
+ )
+ summary: List[SummaryItem] = Field(..., description='Reasoning text contents.\n')
+ type: Type17 = Field(
+ ..., description='The type of the object. Always `reasoning`.\n'
+ )
+
+
+class Controls(BaseModel):
+ artistic_level: Optional[int] = Field(
+ None,
+ description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity.',
+ ge=0,
+ le=5,
+ )
+ background_color: Optional[RGBColor] = None
+ colors: Optional[List[RGBColor]] = Field(
+ None, description='An array of preferable colors'
+ )
+ no_text: Optional[bool] = Field(None, description='Do not embed text layouts')
+
+
+class RecraftImageGenerationRequest(BaseModel):
+ controls: Optional[Controls] = Field(
+ None, description='The controls for the generated image'
+ )
+ model: str = Field(
+ ..., description='The model to use for generation (e.g., "recraftv3")'
+ )
+ n: int = Field(..., description='The number of images to generate', ge=1, le=4)
+ prompt: str = Field(
+ ..., description='The text prompt describing the image to generate'
+ )
+ size: str = Field(
+ ..., description='The size of the generated image (e.g., "1024x1024")'
+ )
+ style: Optional[str] = Field(
+ None,
+ description='The style to apply to the generated image (e.g., "digital_illustration")',
+ )
+ style_id: Optional[str] = Field(
+ None,
+ description='The style ID to apply to the generated image (e.g., "123e4567-e89b-12d3-a456-426614174000"). If style_id is provided, style should not be provided.',
+ )
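+
+# Request sketch (values follow the field descriptions above; note that
+# `style` and `style_id` are mutually exclusive):
+#
+#     req = RecraftImageGenerationRequest(
+#         model='recraftv3',
+#         n=1,
+#         prompt='isometric illustration of a tiny workshop',
+#         size='1024x1024',
+#         style='digital_illustration',
+#     )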
+
+
+class Datum3(BaseModel):
+ image_id: Optional[str] = Field(
+ None, description='Unique identifier for the generated image'
+ )
+ url: Optional[str] = Field(None, description='URL to access the generated image')
+
+
+class RecraftImageGenerationResponse(BaseModel):
+ created: int = Field(
+ ..., description='Unix timestamp when the generation was created'
+ )
+ credits: int = Field(..., description='Number of credits used for the generation')
+ data: List[Datum3] = Field(..., description='Array of generated image information')
+
+
+class RenderingSpeed(str, Enum):
+ BALANCED = 'BALANCED'
+ TURBO = 'TURBO'
+ QUALITY = 'QUALITY'
+
+
+class ResponseErrorCode(str, Enum):
+ server_error = 'server_error'
+ rate_limit_exceeded = 'rate_limit_exceeded'
+ invalid_prompt = 'invalid_prompt'
+ vector_store_timeout = 'vector_store_timeout'
+ invalid_image = 'invalid_image'
+ invalid_image_format = 'invalid_image_format'
+ invalid_base64_image = 'invalid_base64_image'
+ invalid_image_url = 'invalid_image_url'
+ image_too_large = 'image_too_large'
+ image_too_small = 'image_too_small'
+ image_parse_error = 'image_parse_error'
+ image_content_policy_violation = 'image_content_policy_violation'
+ invalid_image_mode = 'invalid_image_mode'
+ image_file_too_large = 'image_file_too_large'
+ unsupported_image_media_type = 'unsupported_image_media_type'
+ empty_image_file = 'empty_image_file'
+ failed_to_download_image = 'failed_to_download_image'
+ image_file_not_found = 'image_file_not_found'
+
+
+class Type18(str, Enum):
+ json_object = 'json_object'
+
+
+class ResponseFormatJsonObject(BaseModel):
+ type: Type18 = Field(
+ ...,
+ description='The type of response format being defined. Always `json_object`.',
+ )
+
+
+class ResponseFormatJsonSchemaSchema(BaseModel):
+    model_config = ConfigDict(
+        extra='allow',
+    )
+
+
+class Type19(str, Enum):
+ text = 'text'
+
+
+class ResponseFormatText(BaseModel):
+ type: Type19 = Field(
+ ..., description='The type of response format being defined. Always `text`.'
+ )
+
+
+class Truncation1(str, Enum):
+ auto = 'auto'
+ disabled = 'disabled'
+
+
+class InputTokensDetails1(BaseModel):
+ cached_tokens: int = Field(
+ ...,
+ description='The number of tokens that were retrieved from the cache. \n[More on prompt caching](/docs/guides/prompt-caching).\n',
+ )
+
+
+class OutputTokensDetails(BaseModel):
+ reasoning_tokens: int = Field(..., description='The number of reasoning tokens.')
+
+
+class ResponseUsage(BaseModel):
+ input_tokens: int = Field(..., description='The number of input tokens.')
+ input_tokens_details: InputTokensDetails1 = Field(
+ ..., description='A detailed breakdown of the input tokens.'
+ )
+ output_tokens: int = Field(..., description='The number of output tokens.')
+ output_tokens_details: OutputTokensDetails = Field(
+ ..., description='A detailed breakdown of the output tokens.'
+ )
+ total_tokens: int = Field(..., description='The total number of tokens used.')
+
+
+class Rodin3DCheckStatusRequest(BaseModel):
+ subscription_key: str = Field(
+        ..., description='Subscription key returned by the generate endpoint'
+ )
+
+
+class Rodin3DCheckStatusResponse(BaseModel):
+ pass
+
+
+class Rodin3DDownloadRequest(BaseModel):
+ task_uuid: str = Field(..., description='Task UUID')
+
+
+class RodinGenerateJobsData(BaseModel):
+ subscription_key: Optional[str] = Field(None, description='Subscription Key.')
+ uuids: Optional[List[str]] = Field(None, description='subjobs uuid.')
+
+
+class RodinMaterialType(str, Enum):
+ PBR = 'PBR'
+ Shaded = 'Shaded'
+
+
+class RodinMeshModeType(str, Enum):
+ Quad = 'Quad'
+ Raw = 'Raw'
+
+
+class RodinQualityType(str, Enum):
+ extra_low = 'extra-low'
+ low = 'low'
+ medium = 'medium'
+ high = 'high'
+
+
+class RodinResourceItem(BaseModel):
+ name: Optional[str] = Field(None, description='File name')
+ url: Optional[str] = Field(None, description='Download url')
+
+
+class RodinTierType(str, Enum):
+ Regular = 'Regular'
+ Sketch = 'Sketch'
+ Detail = 'Detail'
+ Smooth = 'Smooth'
+
+
+class RunwayAspectRatioEnum(str, Enum):
+ field_1280_720 = '1280:720'
+ field_720_1280 = '720:1280'
+ field_1104_832 = '1104:832'
+ field_832_1104 = '832:1104'
+ field_960_960 = '960:960'
+ field_1584_672 = '1584:672'
+ field_1280_768 = '1280:768'
+ field_768_1280 = '768:1280'
+
+
+class RunwayDurationEnum(int, Enum):
+ integer_5 = 5
+ integer_10 = 10
+
+
+class RunwayImageToVideoResponse(BaseModel):
+ id: Optional[str] = Field(None, description='Task ID')
+
+
+class RunwayModelEnum(str, Enum):
+ gen4_turbo = 'gen4_turbo'
+ gen3a_turbo = 'gen3a_turbo'
+
+
+class Position(str, Enum):
+ first = 'first'
+ last = 'last'
+
+
+class RunwayPromptImageDetailedObject(BaseModel):
+ position: Position = Field(
+ ...,
+ description="The position of the image in the output video. 'last' is currently supported for gen3a_turbo only.",
+ )
+ uri: str = Field(
+ ..., description='A HTTPS URL or data URI containing an encoded image.'
+ )
+
+
+class RunwayPromptImageObject(
+ RootModel[Union[str, List[RunwayPromptImageDetailedObject]]]
+):
+ root: Union[str, List[RunwayPromptImageDetailedObject]] = Field(
+ ...,
+ description='Image(s) to use for the video generation. Can be a single URI or an array of image objects with positions.',
+ )
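+
+# The root union accepts either a bare URI string or a list of positioned
+# images; both sketches below validate (URIs illustrative):
+#
+#     RunwayPromptImageObject('https://example.com/frame.png')
+#     RunwayPromptImageObject([RunwayPromptImageDetailedObject(
+#         position=Position.first, uri='https://example.com/frame.png')])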
+
+
+class RunwayTaskStatusEnum(str, Enum):
+ SUCCEEDED = 'SUCCEEDED'
+ RUNNING = 'RUNNING'
+ FAILED = 'FAILED'
+ PENDING = 'PENDING'
+ CANCELLED = 'CANCELLED'
+ THROTTLED = 'THROTTLED'
+
+
+class RunwayTaskStatusResponse(BaseModel):
+ createdAt: datetime = Field(..., description='Task creation timestamp')
+ id: str = Field(..., description='Task ID')
+ output: Optional[List[str]] = Field(None, description='Array of output video URLs')
+ progress: Optional[float] = Field(
+ None,
+ description='Float value between 0 and 1 representing the progress of the task. Only available if status is RUNNING.',
+ ge=0.0,
+ le=1.0,
+ )
+ status: RunwayTaskStatusEnum
+
+
+class RunwayTextToImageAspectRatioEnum(str, Enum):
+ field_1920_1080 = '1920:1080'
+ field_1080_1920 = '1080:1920'
+ field_1024_1024 = '1024:1024'
+ field_1360_768 = '1360:768'
+ field_1080_1080 = '1080:1080'
+ field_1168_880 = '1168:880'
+ field_1440_1080 = '1440:1080'
+ field_1080_1440 = '1080:1440'
+ field_1808_768 = '1808:768'
+ field_2112_912 = '2112:912'
+
+
+class Model4(str, Enum):
+ gen4_image = 'gen4_image'
+
+
+class ReferenceImage(BaseModel):
+ uri: Optional[str] = Field(
+ None, description='A HTTPS URL or data URI containing an encoded image'
+ )
+
+
+class RunwayTextToImageRequest(BaseModel):
+ model: Model4 = Field(..., description='Model to use for generation')
+ promptText: str = Field(
+ ..., description='Text prompt for the image generation', max_length=1000
+ )
+ ratio: RunwayTextToImageAspectRatioEnum
+ referenceImages: Optional[List[ReferenceImage]] = Field(
+ None, description='Array of reference images to guide the generation'
+ )
+
+
+class RunwayTextToImageResponse(BaseModel):
+ id: Optional[str] = Field(None, description='Task ID')
+
+
+class StabilityError(BaseModel):
+ errors: List[str] = Field(
+ ...,
+ description='One or more error messages indicating what went wrong.',
+        examples=[['some-field: is required']],
+ min_length=1,
+ )
+ id: str = Field(
+ ...,
+ description='A unique identifier associated with this error. Please include this in any [support tickets](https://kb.stability.ai/knowledge-base/kb-tickets/new) you file, as it will greatly assist us in diagnosing the root cause of the problem.\n',
+ examples=['a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4'],
+ min_length=1,
+ )
+ name: str = Field(
+ ...,
+ description='Short-hand name for an error, useful for discriminating between errors with the same status code.',
+ examples=['bad_request'],
+ min_length=1,
+ )
+
+
+class Status9(str, Enum):
+ in_progress = 'in-progress'
+
+
+class StabilityGetResultResponse202(BaseModel):
+ id: Optional[str] = Field(
+        None, description='The ID of the generation result.', examples=['1234567890']
+ )
+ status: Optional[Status9] = None
+
+
+class Type20(str, Enum):
+ json_schema = 'json_schema'
+
+
+class TextResponseFormatJsonSchema(BaseModel):
+ description: Optional[str] = Field(
+ None,
+ description='A description of what the response format is for, used by the model to\ndetermine how to respond in the format.\n',
+ )
+ name: str = Field(
+ ...,
+ description='The name of the response format. Must be a-z, A-Z, 0-9, or contain\nunderscores and dashes, with a maximum length of 64.\n',
+ )
+ schema_: ResponseFormatJsonSchemaSchema = Field(..., alias='schema')
+ strict: Optional[bool] = Field(
+ False,
+ description='Whether to enable strict schema adherence when generating the output.\nIf set to true, the model will always follow the exact schema defined\nin the `schema` field. Only a subset of JSON Schema is supported when\n`strict` is `true`. To learn more, read the [Structured Outputs\nguide](/docs/guides/structured-outputs).\n',
+ )
+ type: Type20 = Field(
+ ...,
+ description='The type of response format being defined. Always `json_schema`.',
+ )
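+
+# Sketch: enabling strict structured outputs. The schema payload is
+# illustrative; ResponseFormatJsonSchemaSchema accepts arbitrary extra keys,
+# and the `schema` alias is used in place of the `schema_` attribute.
+#
+#     fmt = TextResponseFormatJsonSchema(
+#         name='weather_report',
+#         schema=ResponseFormatJsonSchemaSchema(type='object'),
+#         strict=True,
+#         type=Type20.json_schema,
+#     )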
+
+
+class Type21(str, Enum):
+ function = 'function'
+
+
+class ToolChoiceFunction(BaseModel):
+ name: str = Field(..., description='The name of the function to call.')
+ type: Type21 = Field(
+ ..., description='For function calling, the type is always `function`.'
+ )
+
+
+class ToolChoiceOptions(str, Enum):
+ none = 'none'
+ auto = 'auto'
+ required = 'required'
+
+
+class Type22(str, Enum):
+ file_search = 'file_search'
+ web_search_preview = 'web_search_preview'
+ computer_use_preview = 'computer_use_preview'
+ web_search_preview_2025_03_11 = 'web_search_preview_2025_03_11'
+
+
+class ToolChoiceTypes(BaseModel):
+ type: Type22 = Field(
+ ...,
+        description='The type of hosted tool the model should use. Learn more about\n[built-in tools](/docs/guides/tools).\n\nAllowed values are:\n- `file_search`\n- `web_search_preview`\n- `computer_use_preview`\n',
+ )
+
+
+class TripoAnimation(str, Enum):
+ preset_idle = 'preset:idle'
+ preset_walk = 'preset:walk'
+ preset_climb = 'preset:climb'
+ preset_jump = 'preset:jump'
+ preset_run = 'preset:run'
+ preset_slash = 'preset:slash'
+ preset_shoot = 'preset:shoot'
+ preset_hurt = 'preset:hurt'
+ preset_fall = 'preset:fall'
+ preset_turn = 'preset:turn'
+
+
+class TripoBalance(BaseModel):
+ balance: float
+ frozen: float
+
+
+class TripoConvertFormat(str, Enum):
+ GLTF = 'GLTF'
+ USDZ = 'USDZ'
+ FBX = 'FBX'
+ OBJ = 'OBJ'
+ STL = 'STL'
+ field_3MF = '3MF'
+
+
+class Code(int, Enum):
+ integer_1001 = 1001
+ integer_2000 = 2000
+ integer_2001 = 2001
+ integer_2002 = 2002
+ integer_2003 = 2003
+ integer_2004 = 2004
+ integer_2006 = 2006
+ integer_2007 = 2007
+ integer_2008 = 2008
+ integer_2010 = 2010
+
+
+class TripoErrorResponse(BaseModel):
+ code: Code
+ message: str
+ suggestion: str
+
+
+class TripoImageToModel(str, Enum):
+ image_to_model = 'image_to_model'
+
+
+class TripoModelStyle(str, Enum):
+ person_person2cartoon = 'person:person2cartoon'
+ animal_venom = 'animal:venom'
+ object_clay = 'object:clay'
+ object_steampunk = 'object:steampunk'
+ object_christmas = 'object:christmas'
+ object_barbie = 'object:barbie'
+ gold = 'gold'
+ ancient_bronze = 'ancient_bronze'
+
+
+class TripoModelVersion(str, Enum):
+ V2_5 = 'v2.5-20250123'
+ V2_0 = 'v2.0-20240919'
+ V1_4 = 'v1.4-20240625'
+
+
+class TripoMultiviewMode(str, Enum):
+ LEFT = 'LEFT'
+ RIGHT = 'RIGHT'
+
+
+class TripoMultiviewToModel(str, Enum):
+ multiview_to_model = 'multiview_to_model'
+
+
+class TripoOrientation(str, Enum):
+ align_image = 'align_image'
+ default = 'default'
+
+
+class TripoResponseSuccessCode(RootModel[int]):
+ root: int = Field(
+ ...,
+ description='Standard success code for Tripo API responses. Typically 0 for success.',
+ examples=[0],
+ )
+
+
+class TripoSpec(str, Enum):
+ mixamo = 'mixamo'
+ tripo = 'tripo'
+
+
+class TripoStandardFormat(str, Enum):
+ glb = 'glb'
+ fbx = 'fbx'
+
+
+class TripoStylizeOptions(str, Enum):
+ lego = 'lego'
+ voxel = 'voxel'
+ voronoi = 'voronoi'
+ minecraft = 'minecraft'
+
+
+class Code1(int, Enum):
+ integer_0 = 0
+
+
+class Data8(BaseModel):
+ task_id: str = Field(..., description='used for getTask')
+
+
+class TripoSuccessTask(BaseModel):
+ code: Code1
+ data: Data8
+
+
+class Topology(str, Enum):
+ bip = 'bip'
+ quad = 'quad'
+
+
+class Output(BaseModel):
+ base_model: Optional[str] = None
+ model: Optional[str] = None
+ pbr_model: Optional[str] = None
+ rendered_image: Optional[str] = None
+ riggable: Optional[bool] = None
+ topology: Optional[Topology] = None
+
+
+class Status10(str, Enum):
+ queued = 'queued'
+ running = 'running'
+ success = 'success'
+ failed = 'failed'
+ cancelled = 'cancelled'
+ unknown = 'unknown'
+ banned = 'banned'
+ expired = 'expired'
+
+
+class TripoTask(BaseModel):
+ create_time: int
+ input: Dict[str, Any]
+ output: Output
+ progress: int = Field(..., ge=0, le=100)
+ status: Status10
+ task_id: str
+ type: str
+
+
+class TripoTextToModel(str, Enum):
+ text_to_model = 'text_to_model'
+
+
+class TripoTextureAlignment(str, Enum):
+ original_image = 'original_image'
+ geometry = 'geometry'
+
+
+class TripoTextureFormat(str, Enum):
+ BMP = 'BMP'
+ DPX = 'DPX'
+ HDR = 'HDR'
+ JPEG = 'JPEG'
+ OPEN_EXR = 'OPEN_EXR'
+ PNG = 'PNG'
+ TARGA = 'TARGA'
+ TIFF = 'TIFF'
+ WEBP = 'WEBP'
+
+
+class TripoTextureQuality(str, Enum):
+ standard = 'standard'
+ detailed = 'detailed'
+
+
+class TripoTopology(str, Enum):
+ bip = 'bip'
+ quad = 'quad'
+
+
+class TripoTypeAnimatePrerigcheck(str, Enum):
+ animate_prerigcheck = 'animate_prerigcheck'
+
+
+class TripoTypeAnimateRetarget(str, Enum):
+ animate_retarget = 'animate_retarget'
+
+
+class TripoTypeAnimateRig(str, Enum):
+ animate_rig = 'animate_rig'
+
+
+class TripoTypeConvertModel(str, Enum):
+ convert_model = 'convert_model'
+
+
+class TripoTypeRefineModel(str, Enum):
+ refine_model = 'refine_model'
+
+
+class TripoTypeStylizeModel(str, Enum):
+ stylize_model = 'stylize_model'
+
+
+class TripoTypeTextureModel(str, Enum):
+ texture_model = 'texture_model'
+
+
+class Veo2GenVidPollRequest(BaseModel):
+ operationName: str = Field(
+ ...,
+ description='Full operation name (from predict response)',
+ examples=[
+ 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
+ ],
+ )
+
+
+class Error(BaseModel):
+ code: Optional[int] = Field(None, description='Error code')
+ message: Optional[str] = Field(None, description='Error message')
+
+
+class Video(BaseModel):
+ bytesBase64Encoded: Optional[str] = Field(
+ None, description='Base64-encoded video content'
+ )
+ gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
+ mimeType: Optional[str] = Field(None, description='Video MIME type')
+
+
+class Response(BaseModel):
+ field_type: Optional[str] = Field(
+ None,
+ alias='@type',
+ examples=[
+ 'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
+ ],
+ )
+ raiMediaFilteredCount: Optional[int] = Field(
+ None, description='Count of media filtered by responsible AI policies'
+ )
+ raiMediaFilteredReasons: Optional[List[str]] = Field(
+ None, description='Reasons why media was filtered by responsible AI policies'
+ )
+ videos: Optional[List[Video]] = None
+
+
+class Veo2GenVidPollResponse(BaseModel):
+ done: Optional[bool] = None
+ error: Optional[Error] = Field(
+ None, description='Error details if operation failed'
+ )
+ name: Optional[str] = None
+ response: Optional[Response] = Field(
+ None, description='The actual prediction response if done is true'
+ )
+
+
+class Image(BaseModel):
+ bytesBase64Encoded: str
+ gcsUri: Optional[str] = None
+ mimeType: Optional[str] = None
+
+
+class Image1(BaseModel):
+ bytesBase64Encoded: Optional[str] = None
+ gcsUri: str
+ mimeType: Optional[str] = None
+
+
+class Instance(BaseModel):
+ image: Optional[Union[Image, Image1]] = Field(
+ None, description='Optional image to guide video generation'
+ )
+ prompt: str = Field(..., description='Text description of the video')
+
+
+class PersonGeneration1(str, Enum):
+ ALLOW = 'ALLOW'
+ BLOCK = 'BLOCK'
+
+
+class Parameters(BaseModel):
+ aspectRatio: Optional[str] = Field(None, examples=['16:9'])
+ durationSeconds: Optional[int] = None
+ enhancePrompt: Optional[bool] = None
+ negativePrompt: Optional[str] = None
+ personGeneration: Optional[PersonGeneration1] = None
+ sampleCount: Optional[int] = None
+ seed: Optional[int] = None
+ storageUri: Optional[str] = Field(
+ None, description='Optional Cloud Storage URI to upload the video'
+ )
+
+
+class Veo2GenVidRequest(BaseModel):
+ instances: Optional[List[Instance]] = None
+ parameters: Optional[Parameters] = None
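+
+# Predict-request sketch (values illustrative). Image carries inline base64
+# content, while Image1 references a Cloud Storage URI; a text-only request
+# needs neither.
+#
+#     req = Veo2GenVidRequest(
+#         instances=[Instance(prompt='aerial shot of a coastline at dawn')],
+#         parameters=Parameters(aspectRatio='16:9', durationSeconds=5,
+#                               sampleCount=1),
+#     )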
+
+
+class Veo2GenVidResponse(BaseModel):
+ name: str = Field(
+ ...,
+ description='Operation resource name',
+ examples=[
+ 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8'
+ ],
+ )
+
+
+class SearchContextSize(str, Enum):
+ low = 'low'
+ medium = 'medium'
+ high = 'high'
+
+
+class Type23(str, Enum):
+ web_search_preview = 'web_search_preview'
+ web_search_preview_2025_03_11 = 'web_search_preview_2025_03_11'
+
+
+class WebSearchPreviewTool(BaseModel):
+ search_context_size: Optional[SearchContextSize] = Field(
+ None,
+ description='High level guidance for the amount of context window space to use for the search. One of `low`, `medium`, or `high`. `medium` is the default.',
+ )
+    type: Literal['web_search_preview', 'web_search_preview_2025_03_11'] = Field(
+ ...,
+ description='The type of the web search tool. One of `web_search_preview` or `web_search_preview_2025_03_11`.',
+ )
+
+
+class Status11(str, Enum):
+ in_progress = 'in_progress'
+ searching = 'searching'
+ completed = 'completed'
+ failed = 'failed'
+
+
+class Type24(str, Enum):
+ web_search_call = 'web_search_call'
+
+
+class WebSearchToolCall(BaseModel):
+ id: str = Field(..., description='The unique ID of the web search tool call.\n')
+ status: Status11 = Field(
+ ..., description='The status of the web search tool call.\n'
+ )
+ type: Type24 = Field(
+ ...,
+ description='The type of the web search tool call. Always `web_search_call`.\n',
+ )
+
+
+class CreateModelResponseProperties(ModelResponseProperties):
+ pass
+
+
+class GeminiInlineData(BaseModel):
+ data: Optional[str] = Field(
+ None,
+ description='The base64 encoding of the image, PDF, or video to include inline in the prompt. When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB\n',
+ )
+ mimeType: Optional[GeminiMimeType] = None
+
+
+class GeminiPart(BaseModel):
+ inlineData: Optional[GeminiInlineData] = None
+ text: Optional[str] = Field(
+ None,
+ description='A text prompt or code snippet.',
+ examples=['Write a story about a robot learning to paint'],
+ )
+
+
+class GeminiPromptFeedback(BaseModel):
+ blockReason: Optional[str] = None
+ blockReasonMessage: Optional[str] = None
+ safetyRatings: Optional[List[GeminiSafetyRating]] = None
+
+
+class GeminiSafetySetting(BaseModel):
+ category: GeminiSafetyCategory
+ threshold: GeminiSafetyThreshold
+
+
+class GeminiSystemInstructionContent(BaseModel):
+ parts: List[GeminiTextPart] = Field(
+ ...,
+ description='A list of ordered parts that make up a single message. Different parts may have different IANA MIME types. For limits on the inputs, such as the maximum number of tokens or the number of images, see the model specifications on the Google models page.\n',
+ )
+ role: Role1 = Field(
+ ...,
+ description='The identity of the entity that creates the message. The following values are supported: user: This indicates that the message is sent by a real person, typically a user-generated message. model: This indicates that the message is generated by the model. The model value is used to insert messages from the model into the conversation during multi-turn conversations. For non-multi-turn conversations, this field can be left blank or unset.\n',
+ examples=['user'],
+ )
+
+
+class IdeogramV3EditRequest(BaseModel):
+ color_palette: Optional[IdeogramColorPalette] = None
+ image: Optional[StrictBytes] = Field(
+ None,
+ description='The image being edited (max size 10MB); only JPEG, WebP and PNG formats are supported at this time.',
+ )
+ magic_prompt: Optional[str] = Field(
+ None,
+ description='Determine if MagicPrompt should be used in generating the request or not.',
+ )
+ mask: Optional[StrictBytes] = Field(
+ None,
+ description='A black and white image of the same size as the image being edited (max size 10MB). Black regions in the mask should match up with the regions of the image that you would like to edit; only JPEG, WebP and PNG formats are supported at this time.',
+ )
+ num_images: Optional[int] = Field(
+ None, description='The number of images to generate.'
+ )
+ prompt: str = Field(
+ ..., description='The prompt used to describe the edited result.'
+ )
+ rendering_speed: RenderingSpeed
+ seed: Optional[int] = Field(
+ None, description='Random seed. Set for reproducible generation.'
+ )
+ style_codes: Optional[List[StyleCode]] = Field(
+ None,
+ description='A list of 8 character hexadecimal codes representing the style of the image. Cannot be used in conjunction with style_reference_images or style_type.',
+ )
+ style_reference_images: Optional[List[StrictBytes]] = Field(
+ None,
+ description='A set of images to use as style references (maximum total size 10MB across all style references). The images should be in JPEG, PNG or WebP format.',
+ )
+
+
+class IdeogramV3Request(BaseModel):
+ aspect_ratio: Optional[str] = Field(
+ None, description='Aspect ratio in format WxH', examples=['1x3']
+ )
+ color_palette: Optional[ColorPalette] = None
+ magic_prompt: Optional[MagicPrompt2] = Field(
+ None, description='Whether to enable magic prompt enhancement'
+ )
+ negative_prompt: Optional[str] = Field(
+ None, description='Text prompt specifying what to avoid in the generation'
+ )
+ num_images: Optional[int] = Field(
+ None, description='Number of images to generate', ge=1
+ )
+ prompt: str = Field(..., description='The text prompt for image generation')
+ rendering_speed: RenderingSpeed
+ resolution: Optional[str] = Field(
+ None, description='Image resolution in format WxH', examples=['1280x800']
+ )
+ seed: Optional[int] = Field(
+ None, description='Seed value for reproducible generation'
+ )
+ style_codes: Optional[List[StyleCode]] = Field(
+ None, description='Array of style codes in hexadecimal format'
+ )
+ style_reference_images: Optional[List[str]] = Field(
+ None, description='Array of reference image URLs or identifiers'
+ )
+ style_type: Optional[StyleType1] = Field(
+ None, description='The type of style to apply'
+ )
+
+
+class ImagenGenerateImageResponse(BaseModel):
+ predictions: Optional[List[ImagenImagePrediction]] = None
+
+
+class ImagenImageGenerationParameters(BaseModel):
+ addWatermark: Optional[bool] = None
+ aspectRatio: Optional[AspectRatio] = None
+ enhancePrompt: Optional[bool] = None
+ includeRaiReason: Optional[bool] = None
+ includeSafetyAttributes: Optional[bool] = None
+ outputOptions: Optional[ImagenOutputOptions] = None
+ personGeneration: Optional[PersonGeneration] = None
+ safetySetting: Optional[SafetySetting] = None
+ sampleCount: Optional[int] = Field(None, ge=1, le=4)
+ seed: Optional[int] = None
+ storageUri: Optional[AnyUrl] = None
+
+
+class InputContent(
+ RootModel[Union[InputTextContent, InputImageContent, InputFileContent]]
+):
+ root: Union[InputTextContent, InputImageContent, InputFileContent]
+
+
+class InputMessageContentList(RootModel[List[InputContent]]):
+ root: List[InputContent] = Field(
+ ...,
+ description='A list of one or many input items to the model, containing different content \ntypes.\n',
+ title='Input item content list',
+ )
+
+
+class KlingCameraControl(BaseModel):
+ config: Optional[KlingCameraConfig] = None
+ type: Optional[KlingCameraControlType] = None
+
+
+class KlingDualCharacterEffectInput(BaseModel):
+ duration: KlingVideoGenDuration
+ images: KlingDualCharacterImages
+ mode: Optional[KlingVideoGenMode] = 'std'
+ model_name: Optional[KlingCharacterEffectModelName] = 'kling-v1'
+
+
+class KlingImage2VideoRequest(BaseModel):
+ aspect_ratio: Optional[KlingVideoGenAspectRatio] = '16:9'
+ callback_url: Optional[AnyUrl] = Field(
+ None,
+ description='The callback notification address. Server will notify when the task status changes.',
+ )
+ camera_control: Optional[KlingCameraControl] = None
+ cfg_scale: Optional[KlingVideoGenCfgScale] = Field(
+ default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5)
+ )
+ duration: Optional[KlingVideoGenDuration] = '5'
+ dynamic_masks: Optional[List[DynamicMask]] = Field(
+ None,
+ description='Dynamic Brush Configuration List (up to 6 groups). For 5-second videos, trajectory length must not exceed 77 coordinates.',
+ )
+ external_task_id: Optional[str] = Field(
+ None,
+ description='Customized Task ID. Must be unique within a single user account.',
+ )
+ image: Optional[str] = Field(
+ None,
+ description='Reference Image - URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1. Base64 should not include data:image prefix.',
+ )
+ image_tail: Optional[str] = Field(
+ None,
+ description='Reference Image - End frame control. URL or Base64 encoded string, cannot exceed 10MB, resolution not less than 300*300px. Base64 should not include data:image prefix.',
+ )
+ mode: Optional[KlingVideoGenMode] = 'std'
+ model_name: Optional[KlingVideoGenModelName] = 'kling-v2-master'
+ negative_prompt: Optional[str] = Field(
+ None, description='Negative text prompt', max_length=2500
+ )
+ prompt: Optional[str] = Field(
+ None, description='Positive text prompt', max_length=2500
+ )
+ static_mask: Optional[str] = Field(
+ None,
+ description='Static Brush Application Area (Mask image created by users using the motion brush). The aspect ratio must match the input image.',
+ )
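+
+# Minimal sketch (illustrative values): only `image` and `prompt` are set;
+# mode, model_name, duration, and cfg_scale fall back to the defaults above.
+# Base64 images must omit the data:image prefix per the field descriptions.
+#
+#     req = KlingImage2VideoRequest(
+#         image='https://example.com/start.png',
+#         prompt='the camera slowly pushes in on the subject',
+#     )
+#     body = req.model_dump(exclude_none=True)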
+
+
+class TaskResult(BaseModel):
+ videos: Optional[List[KlingVideoResult]] = None
+
+
+class Data(BaseModel):
+ created_at: Optional[int] = Field(None, description='Task creation time')
+ task_id: Optional[str] = Field(None, description='Task ID')
+ task_info: Optional[TaskInfo] = None
+ task_result: Optional[TaskResult] = None
+ task_status: Optional[KlingTaskStatus] = None
+ updated_at: Optional[int] = Field(None, description='Task update time')
+
+
+class KlingImage2VideoResponse(BaseModel):
+ code: Optional[int] = Field(None, description='Error code')
+ data: Optional[Data] = None
+ message: Optional[str] = Field(None, description='Error message')
+ request_id: Optional[str] = Field(None, description='Request ID')
+
+
+class TaskResult1(BaseModel):
+ images: Optional[List[KlingImageResult]] = None
+
+
+class Data1(BaseModel):
+ created_at: Optional[int] = Field(None, description='Task creation time')
+ task_id: Optional[str] = Field(None, description='Task ID')
+ task_result: Optional[TaskResult1] = None
+ task_status: Optional[KlingTaskStatus] = None
+ task_status_msg: Optional[str] = Field(None, description='Task status information')
+ updated_at: Optional[int] = Field(None, description='Task update time')
+
+
+class KlingImageGenerationsResponse(BaseModel):
+ code: Optional[int] = Field(None, description='Error code')
+ data: Optional[Data1] = None
+ message: Optional[str] = Field(None, description='Error message')
+ request_id: Optional[str] = Field(None, description='Request ID')
+
+
+class KlingLipSyncInputObject(BaseModel):
+ audio_file: Optional[str] = Field(
+ None,
+        description='Audio file content as a Base64-encoded string. Supported formats: .mp3/.wav/.m4a/.aac, maximum file size of 5MB.',
+ )
+ audio_type: Optional[KlingAudioUploadType] = None
+ audio_url: Optional[str] = Field(
+ None,
+ description='Audio File Download URL. Supported formats: .mp3/.wav/.m4a/.aac, maximum file size of 5MB.',
+ )
+ mode: KlingLipSyncMode
+ text: Optional[str] = Field(
+ None,
+ description='Text Content for Lip-Sync Video Generation. Required when mode is text2video. Maximum length is 120 characters.',
+ )
+ video_id: Optional[str] = Field(
+ None,
+ description='The ID of the video generated by Kling AI. Only supports 5-second and 10-second videos generated within the last 30 days.',
+ )
+ video_url: Optional[str] = Field(
+ None,
+ description='Get link for uploaded video. Video files support .mp4/.mov, file size does not exceed 100MB, video length between 2-10s.',
+ )
+ voice_id: Optional[str] = Field(
+ None,
+ description='Voice ID. Required when mode is text2video. The system offers a variety of voice options to choose from.',
+ )
+ voice_language: Optional[KlingLipSyncVoiceLanguage] = 'en'
+ voice_speed: Optional[float] = Field(
+ 1,
+ description='Speech Rate. Valid range: 0.8~2.0, accurate to one decimal place.',
+ ge=0.8,
+ le=2.0,
+ )
+
+
+class KlingLipSyncRequest(BaseModel):
+ callback_url: Optional[AnyUrl] = Field(
+ None,
+ description='The callback notification address. Server will notify when the task status changes.',
+ )
+ input: KlingLipSyncInputObject
+
+
+class TaskResult2(BaseModel):
+ videos: Optional[List[KlingVideoResult]] = None
+
+
+class Data2(BaseModel):
+ created_at: Optional[int] = Field(None, description='Task creation time')
+ task_id: Optional[str] = Field(None, description='Task ID')
+ task_info: Optional[TaskInfo] = None
+ task_result: Optional[TaskResult2] = None
+ task_status: Optional[KlingTaskStatus] = None
+ updated_at: Optional[int] = Field(None, description='Task update time')
+
+
+class KlingLipSyncResponse(BaseModel):
+ code: Optional[int] = Field(None, description='Error code')
+ data: Optional[Data2] = None
+ message: Optional[str] = Field(None, description='Error message')
+ request_id: Optional[str] = Field(None, description='Request ID')
+
+
+class KlingSingleImageEffectInput(BaseModel):
+ duration: KlingSingleImageEffectDuration
+ image: str = Field(
+ ...,
+ description='Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1.',
+ )
+ model_name: KlingSingleImageEffectModelName
+
+
+class KlingText2VideoRequest(BaseModel):
+ aspect_ratio: Optional[KlingVideoGenAspectRatio] = '16:9'
+ callback_url: Optional[AnyUrl] = Field(
+ None, description='The callback notification address'
+ )
+ camera_control: Optional[KlingCameraControl] = None
+ cfg_scale: Optional[KlingVideoGenCfgScale] = Field(
+ default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5)
+ )
+ duration: Optional[KlingVideoGenDuration] = '5'
+ external_task_id: Optional[str] = Field(None, description='Customized Task ID')
+ mode: Optional[KlingVideoGenMode] = 'std'
+ model_name: Optional[KlingTextToVideoModelName] = 'kling-v1'
+ negative_prompt: Optional[str] = Field(
+ None, description='Negative text prompt', max_length=2500
+ )
+ prompt: Optional[str] = Field(
+ None, description='Positive text prompt', max_length=2500
+ )
+
+
+class Data4(BaseModel):
+ created_at: Optional[int] = Field(None, description='Task creation time')
+ task_id: Optional[str] = Field(None, description='Task ID')
+ task_info: Optional[TaskInfo] = None
+ task_result: Optional[TaskResult2] = None
+ task_status: Optional[KlingTaskStatus] = None
+ updated_at: Optional[int] = Field(None, description='Task update time')
+
+
+class KlingText2VideoResponse(BaseModel):
+ code: Optional[int] = Field(None, description='Error code')
+ data: Optional[Data4] = None
+ message: Optional[str] = Field(None, description='Error message')
+ request_id: Optional[str] = Field(None, description='Request ID')
+
+
+class KlingVideoEffectsInput(
+ RootModel[Union[KlingSingleImageEffectInput, KlingDualCharacterEffectInput]]
+):
+ root: Union[KlingSingleImageEffectInput, KlingDualCharacterEffectInput]
+
+
+class KlingVideoEffectsRequest(BaseModel):
+ callback_url: Optional[AnyUrl] = Field(
+ None,
+ description='The callback notification address for the result of this task.',
+ )
+ effect_scene: Union[KlingDualCharacterEffectsScene, KlingSingleImageEffectsScene]
+ external_task_id: Optional[str] = Field(
+ None,
+ description='Customized Task ID. Must be unique within a single user account.',
+ )
+ input: KlingVideoEffectsInput
+
+
+class Data5(BaseModel):
+ created_at: Optional[int] = Field(None, description='Task creation time')
+ task_id: Optional[str] = Field(None, description='Task ID')
+ task_info: Optional[TaskInfo] = None
+ task_result: Optional[TaskResult2] = None
+ task_status: Optional[KlingTaskStatus] = None
+ updated_at: Optional[int] = Field(None, description='Task update time')
+
+
+class KlingVideoEffectsResponse(BaseModel):
+ code: Optional[int] = Field(None, description='Error code')
+ data: Optional[Data5] = None
+ message: Optional[str] = Field(None, description='Error message')
+ request_id: Optional[str] = Field(None, description='Request ID')
+
+
+class KlingVideoExtendRequest(BaseModel):
+ callback_url: Optional[AnyUrl] = Field(
+ None,
+ description='The callback notification address. Server will notify when the task status changes.',
+ )
+ cfg_scale: Optional[KlingVideoGenCfgScale] = Field(
+ default_factory=lambda: KlingVideoGenCfgScale.model_validate(0.5)
+ )
+ negative_prompt: Optional[str] = Field(
+ None,
+ description='Negative text prompt for elements to avoid in the extended video',
+ max_length=2500,
+ )
+ prompt: Optional[str] = Field(
+ None,
+ description='Positive text prompt for guiding the video extension',
+ max_length=2500,
+ )
+ video_id: Optional[str] = Field(
+ None,
+ description='The ID of the video to be extended. Supports videos generated by text-to-video, image-to-video, and previous video extension operations. Cannot exceed 3 minutes total duration after extension.',
+ )
+
+
+class Data6(BaseModel):
+ created_at: Optional[int] = Field(None, description='Task creation time')
+ task_id: Optional[str] = Field(None, description='Task ID')
+ task_info: Optional[TaskInfo] = None
+ task_result: Optional[TaskResult2] = None
+ task_status: Optional[KlingTaskStatus] = None
+ updated_at: Optional[int] = Field(None, description='Task update time')
+
+
+class KlingVideoExtendResponse(BaseModel):
+ code: Optional[int] = Field(None, description='Error code')
+ data: Optional[Data6] = None
+ message: Optional[str] = Field(None, description='Error message')
+ request_id: Optional[str] = Field(None, description='Request ID')
+
+
+class LumaGenerationRequest(BaseModel):
+ aspect_ratio: LumaAspectRatio
+ callback_url: Optional[AnyUrl] = Field(
+ None,
+ description='The callback URL of the generation, a POST request with Generation object will be sent to the callback URL when the generation is dreaming, completed, or failed',
+ )
+ duration: LumaVideoModelOutputDuration
+ generation_type: Optional[GenerationType1] = 'video'
+ keyframes: Optional[LumaKeyframes] = None
+ loop: Optional[bool] = Field(None, description='Whether to loop the video')
+ model: LumaVideoModel
+ prompt: str = Field(..., description='The prompt of the generation')
+ resolution: LumaVideoModelOutputResolution
+
+
+class CharacterRef(BaseModel):
+ identity0: Optional[LumaImageIdentity] = None
+
+
+class LumaImageGenerationRequest(BaseModel):
+ aspect_ratio: Optional[LumaAspectRatio] = '16:9'
+ callback_url: Optional[AnyUrl] = Field(
+ None, description='The callback URL for the generation'
+ )
+ character_ref: Optional[CharacterRef] = None
+ generation_type: Optional[GenerationType2] = 'image'
+ image_ref: Optional[List[LumaImageRef]] = None
+ model: Optional[LumaImageModel] = 'photon-1'
+ modify_image_ref: Optional[LumaModifyImageRef] = None
+ prompt: Optional[str] = Field(None, description='The prompt of the generation')
+ style_ref: Optional[List[LumaImageRef]] = None
+
+
+class LumaUpscaleVideoGenerationRequest(BaseModel):
+ callback_url: Optional[AnyUrl] = Field(
+ None, description='The callback URL for the upscale'
+ )
+ generation_type: Optional[GenerationType3] = 'upscale_video'
+ resolution: Optional[LumaVideoModelOutputResolution] = None
+
+
+class OutputContent(RootModel[Union[OutputTextContent, OutputAudioContent]]):
+ root: Union[OutputTextContent, OutputAudioContent]
+
+
+class OutputMessage(BaseModel):
+ content: List[OutputContent] = Field(..., description='The content of the message')
+ role: Role4 = Field(..., description='The role of the message')
+ type: Type14 = Field(..., description='The type of output item')
+
+
+class PikaBodyGenerate22I2vGenerate22I2vPost(BaseModel):
+ duration: Optional[PikaDurationEnum] = 5
+ image: Optional[StrictBytes] = Field(None, title='Image')
+ negativePrompt: Optional[str] = Field(None, title='Negativeprompt')
+ promptText: Optional[str] = Field(None, title='Prompttext')
+ resolution: Optional[PikaResolutionEnum] = '1080p'
+ seed: Optional[int] = Field(None, title='Seed')
+
+
+class PikaBodyGenerate22KeyframeGenerate22PikaframesPost(BaseModel):
+ duration: Optional[int] = Field(None, ge=5, le=10, title='Duration')
+ keyFrames: Optional[List[StrictBytes]] = Field(
+ None, description='Array of keyframe images', title='Keyframes'
+ )
+ negativePrompt: Optional[str] = Field(None, title='Negativeprompt')
+ promptText: str = Field(..., title='Prompttext')
+ resolution: Optional[PikaResolutionEnum] = '1080p'
+ seed: Optional[int] = Field(None, title='Seed')
+
+
+class PikaBodyGenerate22T2vGenerate22T2vPost(BaseModel):
+ aspectRatio: Optional[float] = Field(
+ 1.7777777777777777,
+ description='Aspect ratio (width / height)',
+ ge=0.4,
+ le=2.5,
+ title='Aspectratio',
+ )
+ duration: Optional[PikaDurationEnum] = 5
+ negativePrompt: Optional[str] = Field(None, title='Negativeprompt')
+ promptText: str = Field(..., title='Prompttext')
+ resolution: Optional[PikaResolutionEnum] = '1080p'
+ seed: Optional[int] = Field(None, title='Seed')
+
+
+class PikaBodyGeneratePikaffectsGeneratePikaffectsPost(BaseModel):
+ image: Optional[StrictBytes] = Field(None, title='Image')
+ negativePrompt: Optional[str] = Field(None, title='Negativeprompt')
+ pikaffect: Optional[Pikaffect] = None
+ promptText: Optional[str] = Field(None, title='Prompttext')
+ seed: Optional[int] = Field(None, title='Seed')
+
+
+class PikaHTTPValidationError(BaseModel):
+ detail: Optional[List[PikaValidationError]] = Field(None, title='Detail')
+
+
+class Reasoning(BaseModel):
+ effort: Optional[ReasoningEffort] = 'medium'
+ generate_summary: Optional[GenerateSummary] = Field(
+ None,
+ description="**Deprecated:** use `summary` instead.\n\nA summary of the reasoning performed by the model. This can be\nuseful for debugging and understanding the model's reasoning process.\nOne of `auto`, `concise`, or `detailed`.\n",
+ )
+ summary: Optional[Summary] = Field(
+ None,
+ description="A summary of the reasoning performed by the model. This can be\nuseful for debugging and understanding the model's reasoning process.\nOne of `auto`, `concise`, or `detailed`.\n",
+ )
+
+
+class ResponseError(BaseModel):
+ code: ResponseErrorCode
+ message: str = Field(..., description='A human-readable description of the error.')
+
+
+class Rodin3DDownloadResponse(BaseModel):
+    list: Optional[List[RodinResourceItem]] = None
+
+
+class Rodin3DGenerateRequest(BaseModel):
+ images: str = Field(..., description='The reference images to generate 3D Assets.')
+ material: Optional[RodinMaterialType] = None
+ mesh_mode: Optional[RodinMeshModeType] = None
+ quality: Optional[RodinQualityType] = None
+ seed: Optional[int] = Field(None, description='Seed.')
+ tier: Optional[RodinTierType] = None
+
+
+class Rodin3DGenerateResponse(BaseModel):
+ jobs: Optional[RodinGenerateJobsData] = None
+    message: Optional[str] = Field(None, description='Response message')
+    prompt: Optional[str] = Field(None, description='The prompt used for the generation')
+    submit_time: Optional[str] = Field(None, description='Submission time')
+ uuid: Optional[str] = Field(None, description='Task UUID')
+
+
+class RunwayImageToVideoRequest(BaseModel):
+ duration: RunwayDurationEnum
+ model: RunwayModelEnum
+ promptImage: RunwayPromptImageObject
+ promptText: Optional[str] = Field(
+ None, description='Text prompt for the generation', max_length=1000
+ )
+ ratio: RunwayAspectRatioEnum
+ seed: int = Field(
+ ..., description='Random seed for generation', ge=0, le=4294967295
+ )
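+
+# Request sketch (values illustrative; seed must be within 0..4294967295):
+#
+#     req = RunwayImageToVideoRequest(
+#         duration=RunwayDurationEnum.integer_5,
+#         model=RunwayModelEnum.gen4_turbo,
+#         promptImage=RunwayPromptImageObject('https://example.com/a.png'),
+#         ratio=RunwayAspectRatioEnum.field_1280_720,
+#         seed=42,
+#     )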
+
+
+class TextResponseFormatConfiguration(
+ RootModel[
+ Union[
+ ResponseFormatText, TextResponseFormatJsonSchema, ResponseFormatJsonObject
+ ]
+ ]
+):
+ root: Union[
+ ResponseFormatText, TextResponseFormatJsonSchema, ResponseFormatJsonObject
+ ] = Field(
+ ...,
+ description='An object specifying the format that the model must output.\n\nConfiguring `{ "type": "json_schema" }` enables Structured Outputs, \nwhich ensures the model will match your supplied JSON schema. Learn more in the \n[Structured Outputs guide](/docs/guides/structured-outputs).\n\nThe default format is `{ "type": "text" }` with no additional options.\n\n**Not recommended for gpt-4o and newer models:**\n\nSetting to `{ "type": "json_object" }` enables the older JSON mode, which\nensures the message the model generates is valid JSON. Using `json_schema`\nis preferred for models that support it.\n',
+ )
+
+
+class Tool(
+ RootModel[
+ Union[
+ FileSearchTool, FunctionTool, WebSearchPreviewTool, ComputerUsePreviewTool
+ ]
+ ]
+):
+ root: Union[
+ FileSearchTool, FunctionTool, WebSearchPreviewTool, ComputerUsePreviewTool
+ ] = Field(..., discriminator='type')
+
+
+class EasyInputMessage(BaseModel):
+ content: Union[str, InputMessageContentList] = Field(
+ ...,
+ description='Text, image, or audio input to the model, used to generate a response.\nCan also contain previous assistant responses.\n',
+ )
+ role: Role = Field(
+ ...,
+ description='The role of the message input. One of `user`, `assistant`, `system`, or\n`developer`.\n',
+ )
+ type: Optional[Type2] = Field(
+ None, description='The type of the message input. Always `message`.\n'
+ )
+
+
+class GeminiContent(BaseModel):
+ parts: List[GeminiPart]
+ role: Role1 = Field(..., examples=['user'])
+
+
+class GeminiGenerateContentRequest(BaseModel):
+ contents: List[GeminiContent]
+ generationConfig: Optional[GeminiGenerationConfig] = None
+ safetySettings: Optional[List[GeminiSafetySetting]] = None
+ systemInstruction: Optional[GeminiSystemInstructionContent] = None
+ tools: Optional[List[GeminiTool]] = None
+ videoMetadata: Optional[GeminiVideoMetadata] = None
+
+
+class ImagenGenerateImageRequest(BaseModel):
+ instances: List[ImagenImageGenerationInstance]
+ parameters: ImagenImageGenerationParameters
+
+
+class InputMessage(BaseModel):
+ content: Optional[InputMessageContentList] = None
+ role: Optional[Role3] = None
+ status: Optional[Status2] = None
+ type: Optional[Type9] = None
+
+
+class Item(
+ RootModel[
+ Union[
+ InputMessage,
+ OutputMessage,
+ FileSearchToolCall,
+ ComputerToolCall,
+ WebSearchToolCall,
+ FunctionToolCall,
+ ReasoningItem,
+ ]
+ ]
+):
+ root: Union[
+ InputMessage,
+ OutputMessage,
+ FileSearchToolCall,
+ ComputerToolCall,
+ WebSearchToolCall,
+ FunctionToolCall,
+ ReasoningItem,
+ ] = Field(..., description='Content item used to generate a response.\n')
+
+
+class LumaGeneration(BaseModel):
+ assets: Optional[LumaAssets] = None
+ created_at: Optional[datetime] = Field(
+ None, description='The date and time when the generation was created'
+ )
+ failure_reason: Optional[str] = Field(
+ None, description='The reason for the state of the generation'
+ )
+ generation_type: Optional[LumaGenerationType] = None
+ id: Optional[UUID] = Field(None, description='The ID of the generation')
+ model: Optional[str] = Field(None, description='The model used for the generation')
+ request: Optional[
+ Union[
+ LumaGenerationRequest,
+ LumaImageGenerationRequest,
+ LumaUpscaleVideoGenerationRequest,
+ LumaAudioGenerationRequest,
+ ]
+ ] = Field(None, description='The request of the generation')
+ state: Optional[LumaState] = None
+
+
+class OutputItem(
+ RootModel[
+ Union[
+ OutputMessage,
+ FileSearchToolCall,
+ FunctionToolCall,
+ WebSearchToolCall,
+ ComputerToolCall,
+ ReasoningItem,
+ ]
+ ]
+):
+ root: Union[
+ OutputMessage,
+ FileSearchToolCall,
+ FunctionToolCall,
+ WebSearchToolCall,
+ ComputerToolCall,
+ ReasoningItem,
+ ]
+
+
+class Text(BaseModel):
+ format: Optional[TextResponseFormatConfiguration] = None
+
+
+class ResponseProperties(BaseModel):
+ instructions: Optional[str] = Field(
+ None,
+ description="Inserts a system (or developer) message as the first item in the model's context.\n\nWhen using along with `previous_response_id`, the instructions from a previous\nresponse will not be carried over to the next response. This makes it simple\nto swap out system (or developer) messages in new responses.\n",
+ )
+ max_output_tokens: Optional[int] = Field(
+ None,
+ description='An upper bound for the number of tokens that can be generated for a response, including visible output tokens and [reasoning tokens](/docs/guides/reasoning).\n',
+ )
+ model: Optional[OpenAIModels] = None
+ previous_response_id: Optional[str] = Field(
+ None,
+ description='The unique ID of the previous response to the model. Use this to\ncreate multi-turn conversations. Learn more about \n[conversation state](/docs/guides/conversation-state).\n',
+ )
+ reasoning: Optional[Reasoning] = None
+ text: Optional[Text] = None
+ tool_choice: Optional[
+ Union[ToolChoiceOptions, ToolChoiceTypes, ToolChoiceFunction]
+ ] = Field(
+ None,
+ description='How the model should select which tool (or tools) to use when generating\na response. See the `tools` parameter to see how to specify which tools\nthe model can call.\n',
+ )
+ tools: Optional[List[Tool]] = None
+ truncation: Optional[Truncation1] = Field(
+ 'disabled',
+ description="The truncation strategy to use for the model response.\n- `auto`: If the context of this response and previous ones exceeds\n the model's context window size, the model will truncate the \n response to fit the context window by dropping input items in the\n middle of the conversation. \n- `disabled` (default): If a model response will exceed the context window \n size for a model, the request will fail with a 400 error.\n",
+ )
+
+
+class GeminiCandidate(BaseModel):
+ citationMetadata: Optional[GeminiCitationMetadata] = None
+ content: Optional[GeminiContent] = None
+ finishReason: Optional[str] = None
+ safetyRatings: Optional[List[GeminiSafetyRating]] = None
+
+
+class GeminiGenerateContentResponse(BaseModel):
+ candidates: Optional[List[GeminiCandidate]] = None
+ promptFeedback: Optional[GeminiPromptFeedback] = None
+
+
+class InputItem(RootModel[Union[EasyInputMessage, Item]]):
+ root: Union[EasyInputMessage, Item]
+
+
+class OpenAICreateResponse(CreateModelResponseProperties, ResponseProperties):
+ include: Optional[List[Includable]] = Field(
+ None,
+ description='Specify additional output data to include in the model response. Currently\nsupported values are:\n- `file_search_call.results`: Include the search results of\n the file search tool call.\n- `message.input_image.image_url`: Include image urls from the input message.\n- `computer_call_output.output.image_url`: Include image urls from the computer call output.\n',
+ )
+ input: Union[str, List[InputItem]] = Field(
+ ...,
+ description='Text, image, or file inputs to the model, used to generate a response.\n\nLearn more:\n- [Text inputs and outputs](/docs/guides/text)\n- [Image inputs](/docs/guides/images)\n- [File inputs](/docs/guides/pdf-files)\n- [Conversation state](/docs/guides/conversation-state)\n- [Function calling](/docs/guides/function-calling)\n',
+ )
+ parallel_tool_calls: Optional[bool] = Field(
+ True, description='Whether to allow the model to run tool calls in parallel.\n'
+ )
+ store: Optional[bool] = Field(
+ True,
+ description='Whether to store the generated model response for later retrieval via\nAPI.\n',
+ )
+ stream: Optional[bool] = Field(
+ False,
+ description='If set to true, the model response data will be streamed to the client\nas it is generated using [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format).\nSee the [Streaming section below](/docs/api-reference/responses-streaming)\nfor more information.\n',
+ )
+ usage: Optional[ResponseUsage] = None
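+
+# Minimal sketch: `input` is the only required field here; everything
+# inherited from ResponseProperties (and, assuming the same pattern,
+# ModelResponseProperties) is optional.
+#
+#     req = OpenAICreateResponse(
+#         input='Summarize the change in one sentence.',
+#     )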
+
+
+class OpenAIResponse(ModelResponseProperties, ResponseProperties):
+ created_at: Optional[float] = Field(
+ None,
+ description='Unix timestamp (in seconds) of when this Response was created.',
+ )
+ error: Optional[ResponseError] = None
+ id: Optional[str] = Field(None, description='Unique identifier for this Response.')
+ incomplete_details: Optional[IncompleteDetails] = Field(
+ None, description='Details about why the response is incomplete.\n'
+ )
+ object: Optional[Object] = Field(
+ None, description='The object type of this resource - always set to `response`.'
+ )
+ output: Optional[List[OutputItem]] = Field(
+ None,
+ description="An array of content items generated by the model.\n\n- The length and order of items in the `output` array is dependent\n on the model's response.\n- Rather than accessing the first item in the `output` array and \n assuming it's an `assistant` message with the content generated by\n the model, you might consider using the `output_text` property where\n supported in SDKs.\n",
+ )
+ output_text: Optional[str] = Field(
+ None,
+ description='SDK-only convenience property that contains the aggregated text output \nfrom all `output_text` items in the `output` array, if any are present. \nSupported in the Python and JavaScript SDKs.\n',
+ )
+ parallel_tool_calls: Optional[bool] = Field(
+ True, description='Whether to allow the model to run tool calls in parallel.\n'
+ )
+ status: Optional[Status6] = Field(
+ None,
+ description='The status of the response generation. One of `completed`, `failed`, `in_progress`, or `incomplete`.',
+ )
+ usage: Optional[ResponseUsage] = None
diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl_api.py
new file mode 100644
index 000000000..c189038fb
--- /dev/null
+++ b/comfy_api_nodes/apis/bfl_api.py
@@ -0,0 +1,156 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Any, Dict, Optional
+
+from pydantic import BaseModel, Field, confloat, conint
+
+
+class BFLOutputFormat(str, Enum):
+ png = 'png'
+ jpeg = 'jpeg'
+
+
+class BFLFluxExpandImageRequest(BaseModel):
+ prompt: str = Field(..., description='The description of the changes you want to make. This text guides the expansion process, allowing you to specify features, styles, or modifications for the expanded areas.')
+ prompt_upsampling: Optional[bool] = Field(
+ None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
+ )
+ seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
+ top: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the top of the image')
+ bottom: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the bottom of the image')
+ left: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the left side of the image')
+ right: conint(ge=0, le=2048) = Field(..., description='Number of pixels to expand at the right side of the image')
+ steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
+ guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
+ safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
+        6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 6.'
+ )
+ output_format: Optional[BFLOutputFormat] = Field(
+ BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
+ )
+ image: Optional[str] = Field(None, description='A Base64-encoded string representing the image you wish to expand')
+
+
+class BFLFluxFillImageRequest(BaseModel):
+ prompt: str = Field(..., description='The description of the changes you want to make. This text guides the inpainting process, allowing you to specify features, styles, or modifications for the masked areas.')
+ prompt_upsampling: Optional[bool] = Field(
+ None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
+ )
+ seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
+ steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
+ guidance: confloat(ge=1.5, le=100) = Field(..., description='Guidance strength for the image generation process')
+ safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
+ 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 6.'
+ )
+ output_format: Optional[BFLOutputFormat] = Field(
+ BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
+ )
+ image: Optional[str] = Field(None, description='A Base64-encoded string representing the image you wish to modify. Can contain alpha mask if desired.')
+ mask: Optional[str] = Field(None, description='A Base64-encoded string representing the mask of the areas you wish to modify.')
+
+
+class BFLFluxCannyImageRequest(BaseModel):
+ prompt: str = Field(..., description='Text prompt for image generation')
+ prompt_upsampling: Optional[bool] = Field(
+ None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
+ )
+ canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection')
+ canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection')
+ seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
+ steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
+ guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
+ safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
+ 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 6.'
+ )
+ output_format: Optional[BFLOutputFormat] = Field(
+ BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
+ )
+ control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
+ preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
+
+
+class BFLFluxDepthImageRequest(BaseModel):
+ prompt: str = Field(..., description='Text prompt for image generation')
+ prompt_upsampling: Optional[bool] = Field(
+ None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
+ )
+ seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
+ steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process')
+ guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process')
+ safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
+ 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 6.'
+ )
+ output_format: Optional[BFLOutputFormat] = Field(
+ BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
+ )
+ control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided')
+ preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step')
+
+
+class BFLFluxProGenerateRequest(BaseModel):
+ prompt: str = Field(..., description='The text prompt for image generation.')
+ prompt_upsampling: Optional[bool] = Field(
+ None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
+ )
+ seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
+ width: conint(ge=256, le=1440) = Field(1024, description='Width of the generated image in pixels. Must be a multiple of 32.')
+ height: conint(ge=256, le=1440) = Field(768, description='Height of the generated image in pixels. Must be a multiple of 32.')
+ safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
+ 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 6.'
+ )
+ output_format: Optional[BFLOutputFormat] = Field(
+ BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
+ )
+ image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
+ # image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
+ # None, description='Blend between the prompt and the image prompt.'
+ # )
+
+
+class BFLFluxProUltraGenerateRequest(BaseModel):
+ prompt: str = Field(..., description='The text prompt for image generation.')
+ prompt_upsampling: Optional[bool] = Field(
+ None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.'
+ )
+ seed: Optional[int] = Field(None, description='The seed value for reproducibility.')
+ aspect_ratio: Optional[str] = Field(None, description='Aspect ratio of the image between 21:9 and 9:21.')
+ safety_tolerance: Optional[conint(ge=0, le=6)] = Field(
+ 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 6.'
+ )
+ output_format: Optional[BFLOutputFormat] = Field(
+ BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png']
+ )
+ raw: Optional[bool] = Field(None, description='Generate less processed, more natural-looking images.')
+ image_prompt: Optional[str] = Field(None, description='Optional image to remix in base64 format')
+ image_prompt_strength: Optional[confloat(ge=0.0, le=1.0)] = Field(
+ None, description='Blend between the prompt and the image prompt.'
+ )
+
+
+class BFLFluxProGenerateResponse(BaseModel):
+ id: str = Field(..., description='The unique identifier for the generation task.')
+ polling_url: str = Field(..., description='URL to poll for the generation result.')
+
+
+class BFLStatus(str, Enum):
+ task_not_found = "Task not found"
+ pending = "Pending"
+ request_moderated = "Request Moderated"
+ content_moderated = "Content Moderated"
+ ready = "Ready"
+ error = "Error"
+
+
+class BFLFluxProStatusResponse(BaseModel):
+ id: str = Field(..., description="The unique identifier for the generation task.")
+ status: BFLStatus = Field(..., description="The status of the task.")
+ result: Optional[Dict[str, Any]] = Field(
+ None, description="The result of the task (null if not completed)."
+ )
+ progress: confloat(ge=0.0, le=1.0) = Field(
+ ..., description="The progress of the task (0.0 to 1.0)."
+ )
+ details: Optional[Dict[str, Any]] = Field(
+ None, description="Additional details about the task (null if not available)."
+ )
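+
+
+# A rough sketch of how these models fit together (illustrative; the generic client
+# framework lives in comfy_api_nodes/apis/client.py):
+# 1. POST a BFLFluxProGenerateRequest and parse the reply as BFLFluxProGenerateResponse.
+# 2. GET the returned polling_url and parse each reply as BFLFluxProStatusResponse.
+# 3. Stop when status reaches BFLStatus.ready (result then holds the task output) or one
+# of the moderation/error states is returned.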
diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py
new file mode 100644
index 000000000..0897d5d78
--- /dev/null
+++ b/comfy_api_nodes/apis/client.py
@@ -0,0 +1,1124 @@
+"""
+API Client Framework for api.comfy.org.
+
+This module provides a flexible framework for making API requests from ComfyUI nodes.
+It supports both synchronous and asynchronous API operations with proper type validation.
+
+Key Components:
+--------------
+1. ApiClient - Handles HTTP requests with authentication and error handling
+2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models
+3. SynchronousOperation - Executes a single synchronous API operation
+4. PollingOperation - Polls a status endpoint until a task completes or fails
+
+Usage Examples:
+--------------
+
+# Example 1: Synchronous API Operation
+# ------------------------------------
+# For a simple API call that returns the result immediately:
+
+# 1. Create the API client
+api_client = ApiClient(
+ base_url="https://api.example.com",
+ auth_token="your_auth_token_here",
+ comfy_api_key="your_comfy_api_key_here",
+ timeout=30.0,
+ verify_ssl=True
+)
+
+# 2. Define the endpoint
+user_info_endpoint = ApiEndpoint(
+ path="/v1/users/me",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest, # No request body needed
+ response_model=UserProfile, # Pydantic model for the response
+ query_params=None
+)
+
+# 3. Create the request object
+request = EmptyRequest()
+
+# 4. Create and execute the operation
+operation = SynchronousOperation(
+ endpoint=user_info_endpoint,
+ request=request
+)
+user_profile = operation.execute(client=api_client) # Returns immediately with the result
+
+
+# Example 2: Asynchronous API Operation with Polling
+# -------------------------------------------------
+# For an API that starts a task and requires polling for completion:
+
+# 1. Define the endpoint that starts the task
+generate_image_endpoint = ApiEndpoint(
+ path="/v1/images/generate",
+ method=HttpMethod.POST,
+ request_model=ImageGenerationRequest,
+ response_model=TaskCreatedResponse,
+ query_params=None
+)
+
+# 2. Create the request object and start the task
+request = ImageGenerationRequest(
+ prompt="a beautiful sunset over mountains",
+ width=1024,
+ height=1024,
+ num_images=1
+)
+
+task_created = SynchronousOperation(
+ endpoint=generate_image_endpoint,
+ request=request
+).execute(client=api_client)
+
+# 3. Define the polling endpoint using the returned task id
+check_task_endpoint = ApiEndpoint(
+ path=f"/v1/tasks/{task_created.task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=ImageGenerationResult,
+ query_params=None
+)
+
+# 4. Create and execute the polling operation
+operation = PollingOperation(
+ poll_endpoint=check_task_endpoint,
+ completed_statuses=["completed"],
+ failed_statuses=["failed", "error"],
+ status_extractor=lambda response: response.status,
+)
+
+# This will poll until the task reaches a terminal status
+result = operation.execute(client=api_client) # Returns the final ImageGenerationResult when done
+"""
+
+from __future__ import annotations
+import logging
+import time
+import io
+import socket
+from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple
+from enum import Enum
+import json
+import requests
+from urllib.parse import urljoin, urlparse
+from pydantic import BaseModel, Field
+import uuid # For generating unique operation IDs
+
+from server import PromptServer
+from comfy.cli_args import args
+from comfy import utils
+from . import request_logger
+
+T = TypeVar("T", bound=BaseModel)
+R = TypeVar("R", bound=BaseModel)
+P = TypeVar("P", bound=BaseModel) # For poll response
+
+PROGRESS_BAR_MAX = 100
+
+
+class NetworkError(Exception):
+ """Base exception for network-related errors with diagnostic information."""
+ pass
+
+
+class LocalNetworkError(NetworkError):
+ """Exception raised when local network connectivity issues are detected."""
+ pass
+
+
+class ApiServerError(NetworkError):
+ """Exception raised when the API server is unreachable but internet is working."""
+ pass
+
+
+class EmptyRequest(BaseModel):
+ """Base class for empty request bodies.
+ For GET requests, fields will be sent as query parameters."""
+
+ pass
+
+
+class UploadRequest(BaseModel):
+ file_name: str = Field(..., description="Filename to upload")
+ content_type: Optional[str] = Field(
+ None,
+ description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.",
+ )
+
+
+class UploadResponse(BaseModel):
+ download_url: str = Field(..., description="URL to GET uploaded file")
+ upload_url: str = Field(..., description="URL to PUT file to upload")
+
+
+class HttpMethod(str, Enum):
+ GET = "GET"
+ POST = "POST"
+ PUT = "PUT"
+ DELETE = "DELETE"
+ PATCH = "PATCH"
+
+
+class ApiClient:
+ """
+ Client for making HTTP requests to an API with authentication, error handling, and retry logic.
+ """
+
+ def __init__(
+ self,
+ base_url: str,
+ auth_token: Optional[str] = None,
+ comfy_api_key: Optional[str] = None,
+ timeout: float = 3600.0,
+ verify_ssl: bool = True,
+ max_retries: int = 3,
+ retry_delay: float = 1.0,
+ retry_backoff_factor: float = 2.0,
+ retry_status_codes: Optional[Tuple[int, ...]] = None,
+ ):
+ self.base_url = base_url
+ self.auth_token = auth_token
+ self.comfy_api_key = comfy_api_key
+ self.timeout = timeout
+ self.verify_ssl = verify_ssl
+ self.max_retries = max_retries
+ self.retry_delay = retry_delay
+ self.retry_backoff_factor = retry_backoff_factor
+ # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests),
+ # 500, 502, 503, 504 (Server Errors)
+ self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504)
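+ # With the defaults (retry_delay=1.0, retry_backoff_factor=2.0, max_retries=3),
+ # retried calls back off roughly 1s, 2s, then 4s before the final failure is raised.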
+
+ def _generate_operation_id(self, path: str) -> str:
+ """Generates a unique operation ID for logging."""
+ return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}"
+
+ def _create_json_payload_args(
+ self,
+ data: Optional[Dict[str, Any]] = None,
+ headers: Optional[Dict[str, str]] = None,
+ ) -> Dict[str, Any]:
+ return {
+ "json": data,
+ "headers": headers,
+ }
+
+ def _create_form_data_args(
+ self,
+ data: Dict[str, Any],
+ files: Dict[str, Any],
+ headers: Optional[Dict[str, str]] = None,
+ multipart_parser = None,
+ ) -> Dict[str, Any]:
+ if headers and "Content-Type" in headers:
+ del headers["Content-Type"]
+
+ if multipart_parser:
+ data = multipart_parser(data)
+
+ return {
+ "data": data,
+ "files": files,
+ "headers": headers,
+ }
+
+ def _create_urlencoded_form_data_args(
+ self,
+ data: Dict[str, Any],
+ headers: Optional[Dict[str, str]] = None,
+ ) -> Dict[str, Any]:
+ headers = headers or {}
+ headers["Content-Type"] = "application/x-www-form-urlencoded"
+
+ return {
+ "data": data,
+ "headers": headers,
+ }
+
+ def get_headers(self) -> Dict[str, str]:
+ """Get headers for API requests, including authentication if available"""
+ headers = {"Content-Type": "application/json", "Accept": "application/json"}
+
+ if self.auth_token:
+ headers["Authorization"] = f"Bearer {self.auth_token}"
+ elif self.comfy_api_key:
+ headers["X-API-KEY"] = self.comfy_api_key
+
+ return headers
+
+ def _check_connectivity(self, target_url: str) -> Dict[str, bool]:
+ """
+ Check connectivity to determine if network issues are local or server-related.
+
+ Args:
+ target_url: URL to check connectivity to
+
+ Returns:
+ Dictionary with connectivity status details
+ """
+ results = {
+ "internet_accessible": False,
+ "api_accessible": False,
+ "is_local_issue": False,
+ "is_api_issue": False
+ }
+
+ # First check basic internet connectivity using a reliable external site
+ try:
+ # Use a reliable external domain for checking basic connectivity
+ check_response = requests.get("https://www.google.com",
+ timeout=5.0,
+ verify=self.verify_ssl)
+ if check_response.status_code < 500:
+ results["internet_accessible"] = True
+ except (requests.RequestException, socket.error):
+ results["internet_accessible"] = False
+ results["is_local_issue"] = True
+ return results
+
+ # Now check API server connectivity
+ try:
+ # Extract domain from the target URL to do a simpler health check
+ parsed_url = urlparse(target_url)
+ api_base = f"{parsed_url.scheme}://{parsed_url.netloc}"
+
+ # Try to reach the API domain
+ api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl)
+ if api_response.status_code < 500:
+ results["api_accessible"] = True
+ else:
+ results["api_accessible"] = False
+ results["is_api_issue"] = True
+ except requests.RequestException:
+ results["api_accessible"] = False
+ # If we can reach the internet but not the API, it's an API issue
+ results["is_api_issue"] = True
+
+ return results
+
+ def request(
+ self,
+ method: str,
+ path: str,
+ params: Optional[Dict[str, Any]] = None,
+ data: Optional[Dict[str, Any]] = None,
+ files: Optional[Dict[str, Any]] = None,
+ headers: Optional[Dict[str, str]] = None,
+ content_type: str = "application/json",
+ multipart_parser: Optional[Callable] = None,
+ retry_count: int = 0, # Used internally for tracking retries
+ ) -> Dict[str, Any]:
+ """
+ Make an HTTP request to the API with automatic retries for transient errors.
+
+ Args:
+ method: HTTP method (GET, POST, etc.)
+ path: API endpoint path (will be joined with base_url)
+ params: Query parameters
+ data: body data
+ files: Files to upload
+ headers: Additional headers
+ content_type: Content type of the request. Defaults to application/json.
+ retry_count: Internal parameter for tracking retries, do not set manually
+
+ Returns:
+ Parsed JSON response
+
+ Raises:
+ LocalNetworkError: If local network connectivity issues are detected
+ ApiServerError: If the API server is unreachable but internet is working
+ Exception: For other request failures
+ """
+ url = urljoin(self.base_url, path)
+ self.check_auth(self.auth_token, self.comfy_api_key)
+ # Combine default headers with any provided headers
+ request_headers = self.get_headers()
+ if headers:
+ request_headers.update(headers)
+
+ # Let requests handle the content type when files are present.
+ if files:
+ del request_headers["Content-Type"]
+
+ logging.debug(f"[DEBUG] Request Headers: {request_headers}")
+ logging.debug(f"[DEBUG] Files: {files}")
+ logging.debug(f"[DEBUG] Params: {params}")
+ logging.debug(f"[DEBUG] Data: {data}")
+
+ if content_type == "application/x-www-form-urlencoded":
+ payload_args = self._create_urlencoded_form_data_args(data, request_headers)
+ elif content_type == "multipart/form-data":
+ payload_args = self._create_form_data_args(
+ data, files, request_headers, multipart_parser
+ )
+ else:
+ payload_args = self._create_json_payload_args(data, request_headers)
+
+ operation_id = self._generate_operation_id(path)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ request_headers=request_headers,
+ request_params=params,
+ request_data=data if content_type == "application/json" else "[form-data or other]"
+ )
+
+ try:
+ response = requests.request(
+ method=method,
+ url=url,
+ params=params,
+ timeout=self.timeout,
+ verify=self.verify_ssl,
+ **payload_args,
+ )
+
+ # Check if we should retry based on status code
+ if (response.status_code in self.retry_status_codes and
+ retry_count < self.max_retries):
+
+ # Calculate delay with exponential backoff
+ delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
+
+ logging.warning(
+ f"Request failed with status {response.status_code}. "
+ f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
+ )
+
+ time.sleep(delay)
+ return self.request(
+ method=method,
+ path=path,
+ params=params,
+ data=data,
+ files=files,
+ headers=headers,
+ content_type=content_type,
+ multipart_parser=multipart_parser,
+ retry_count=retry_count + 1,
+ )
+
+ # Raise exception for error status codes
+ response.raise_for_status()
+
+ # Log successful response
+ response_content_to_log = response.content
+ try:
+ # Attempt to parse JSON for prettier logging, fallback to raw content
+ response_content_to_log = response.json()
+ except json.JSONDecodeError:
+ pass # Keep as bytes/str if not JSON
+
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method, # Pass request details again for context in log
+ request_url=url,
+ response_status_code=response.status_code,
+ response_headers=dict(response.headers),
+ response_content=response_content_to_log
+ )
+
+ except requests.ConnectionError as e:
+ error_message = f"ConnectionError: {str(e)}"
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method,
+ request_url=url,
+ error_message=error_message
+ )
+ # Only perform connectivity check if we've exhausted all retries
+ if retry_count >= self.max_retries:
+ # Check connectivity to determine if it's a local or API issue
+ connectivity = self._check_connectivity(self.base_url)
+
+ if connectivity["is_local_issue"]:
+ raise LocalNetworkError(
+ "Unable to connect to the API server due to local network issues. "
+ "Please check your internet connection and try again."
+ ) from e
+ elif connectivity["is_api_issue"]:
+ raise ApiServerError(
+ f"The API server at {self.base_url} is currently unreachable. "
+ f"The service may be experiencing issues. Please try again later."
+ ) from e
+
+ # If we haven't exhausted retries yet, retry the request
+ if retry_count < self.max_retries:
+ delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
+ logging.warning(
+ f"Connection error: {str(e)}. "
+ f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
+ )
+ time.sleep(delay)
+ return self.request(
+ method=method,
+ path=path,
+ params=params,
+ data=data,
+ files=files,
+ headers=headers,
+ content_type=content_type,
+ multipart_parser=multipart_parser,
+ retry_count=retry_count + 1,
+ )
+
+ # If we've exhausted retries and didn't identify the specific issue,
+ # raise a generic exception
+ final_error_message = (
+ f"Unable to connect to the API server after {self.max_retries} attempts. "
+ f"Please check your internet connection or try again later."
+ )
+ request_logger.log_request_response( # Log final failure
+ operation_id=operation_id,
+ request_method=method, request_url=url,
+ error_message=final_error_message
+ )
+ raise Exception(final_error_message) from e
+
+ except requests.Timeout as e:
+ error_message = f"Timeout: {str(e)}"
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method, request_url=url,
+ error_message=error_message
+ )
+ # Retry timeouts if we haven't exhausted retries
+ if retry_count < self.max_retries:
+ delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
+ logging.warning(
+ f"Request timed out. "
+ f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
+ )
+ time.sleep(delay)
+ return self.request(
+ method=method,
+ path=path,
+ params=params,
+ data=data,
+ files=files,
+ headers=headers,
+ content_type=content_type,
+ multipart_parser=multipart_parser,
+ retry_count=retry_count + 1,
+ )
+ final_error_message = (
+ f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. "
+ f"The server might be experiencing high load or the operation is taking longer than expected."
+ )
+ request_logger.log_request_response( # Log final failure
+ operation_id=operation_id,
+ request_method=method, request_url=url,
+ error_message=final_error_message
+ )
+ raise Exception(final_error_message) from e
+
+ except requests.HTTPError as e:
+ status_code = e.response.status_code if getattr(e, "response", None) is not None else None
+ original_error_message = f"HTTP Error: {str(e)}"
+ error_content_for_log = None
+ if hasattr(e, "response") and e.response is not None:
+ error_content_for_log = e.response.content
+ try:
+ error_content_for_log = e.response.json()
+ except json.JSONDecodeError:
+ pass
+
+
+ # Try to extract detailed error message from JSON response for user display
+ # but log the full error content.
+ user_display_error_message = original_error_message
+
+ try:
+ if hasattr(e, "response") and e.response is not None and e.response.content:
+ error_json = e.response.json()
+ if "error" in error_json and "message" in error_json["error"]:
+ user_display_error_message = f"API Error: {error_json['error']['message']}"
+ if "type" in error_json["error"]:
+ user_display_error_message += f" (Type: {error_json['error']['type']})"
+ elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict
+ user_display_error_message = f"API Error: {json.dumps(error_json)}"
+ else: # Non-dict JSON error
+ user_display_error_message = f"API Error: {str(error_json)}"
+ except json.JSONDecodeError:
+ # If not JSON, use the raw content if it's not too long, or a summary
+ if hasattr(e, "response") and e.response is not None and e.response.content:
+ raw_content = e.response.content.decode(errors='ignore')
+ if len(raw_content) < 200: # Arbitrary limit for display
+ user_display_error_message = f"API Error (raw): {raw_content}"
+ else:
+ user_display_error_message = f"API Error (raw, status {status_code})"
+
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method=method, request_url=url,
+ response_status_code=status_code,
+ response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None,
+ response_content=error_content_for_log,
+ error_message=original_error_message # Log the original exception string as error
+ )
+
+ logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})")
+ if hasattr(e, "response") and e.response is not None and e.response.content:
+ logging.debug(f"[DEBUG] Response content: {e.response.content}")
+
+ # Retry if the status code is in our retry list and we haven't exhausted retries
+ if (status_code in self.retry_status_codes and
+ retry_count < self.max_retries):
+
+ delay = self.retry_delay * (self.retry_backoff_factor ** retry_count)
+ logging.warning(
+ f"HTTP error {status_code}. "
+ f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})"
+ )
+ time.sleep(delay)
+ return self.request(
+ method=method,
+ path=path,
+ params=params,
+ data=data,
+ files=files,
+ headers=headers,
+ content_type=content_type,
+ multipart_parser=multipart_parser,
+ retry_count=retry_count + 1,
+ )
+
+ # Specific error messages for common status codes for user display
+ if status_code == 401:
+ user_display_error_message = "Unauthorized: Please login first to use this node."
+ elif status_code == 402:
+ user_display_error_message = "Payment Required: Please add credits to your account to use this node."
+ elif status_code == 409:
+ user_display_error_message = "There is a problem with your account. Please contact support@comfy.org."
+ elif status_code == 429:
+ user_display_error_message = "Rate Limit Exceeded: Please try again later."
+ # else, user_display_error_message remains as parsed from response or original HTTPError string
+
+ raise Exception(user_display_error_message) # Raise with the user-friendly message
+
+ # Parse and return JSON response
+ if response.content:
+ return response.json()
+ return {}
+
+ def check_auth(self, auth_token, comfy_api_key):
+ """Verify that an auth token is present or comfy_api_key is present"""
+ if auth_token is None and comfy_api_key is None:
+ raise Exception("Unauthorized: Please login first to use this node.")
+ return auth_token or comfy_api_key
+
+ @staticmethod
+ def upload_file(
+ upload_url: str,
+ file: io.BytesIO | str,
+ content_type: str | None = None,
+ max_retries: int = 3,
+ retry_delay: float = 1.0,
+ retry_backoff_factor: float = 2.0,
+ ):
+ """Upload a file to the API with retry logic.
+
+ Args:
+ upload_url: The URL to upload to
+ file: Either a file path string or a BytesIO object
+ content_type: Optional mime type to set for the upload
+ max_retries: Maximum number of retry attempts
+ retry_delay: Initial delay between retries in seconds
+ retry_backoff_factor: Multiplier for the delay after each retry
+ """
+ headers = {}
+ if content_type:
+ headers["Content-Type"] = content_type
+
+ # Prepare the file data
+ if isinstance(file, io.BytesIO):
+ file.seek(0) # Ensure we're at the start of the file
+ data = file.read()
+ elif isinstance(file, str):
+ with open(file, "rb") as f:
+ data = f.read()
+ else:
+ raise ValueError("File must be either a BytesIO object or a file path string")
+
+ # Try the upload with retries
+ last_exception = None
+ operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads
+
+ # Log initial attempt (without full file data for brevity)
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT",
+ request_url=upload_url,
+ request_headers=headers,
+ request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]"
+ )
+
+ for retry_attempt in range(max_retries + 1):
+ try:
+ response = requests.put(upload_url, data=data, headers=headers)
+ response.raise_for_status()
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT", request_url=upload_url, # For context
+ response_status_code=response.status_code,
+ response_headers=dict(response.headers),
+ response_content="File uploaded successfully." # Or response.text if available
+ )
+ return response
+
+ except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e:
+ last_exception = e
+ error_message_for_log = f"{type(e).__name__}: {str(e)}"
+ response_content_for_log = None
+ status_code_for_log = None
+ headers_for_log = None
+
+ if hasattr(e, 'response') and e.response is not None:
+ status_code_for_log = e.response.status_code
+ headers_for_log = dict(e.response.headers)
+ try:
+ response_content_for_log = e.response.json()
+ except json.JSONDecodeError:
+ response_content_for_log = e.response.content
+
+
+ request_logger.log_request_response(
+ operation_id=operation_id,
+ request_method="PUT", request_url=upload_url,
+ response_status_code=status_code_for_log,
+ response_headers=headers_for_log,
+ response_content=response_content_for_log,
+ error_message=error_message_for_log
+ )
+
+ if retry_attempt < max_retries:
+ delay = retry_delay * (retry_backoff_factor ** retry_attempt)
+ logging.warning(
+ f"File upload failed: {str(e)}. "
+ f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})"
+ )
+ time.sleep(delay)
+ else:
+ break # Max retries reached
+
+ # If we've exhausted all retries, determine the final error type and raise
+ final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}"
+ try:
+ # Check basic internet connectivity
+ check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired
+ if check_response.status_code >= 500: # Google itself has an issue (rare)
+ final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed "
+ f"(status {check_response.status_code}). Original error: {str(last_exception)}")
+ # Not raising LocalNetworkError here as Google itself might be down.
+ # If Google is reachable, the issue is likely with the upload server or a more specific local problem
+ # not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall).
+ # The original last_exception is probably most relevant.
+
+ except (requests.RequestException, socket.error) as conn_check_exc:
+ # Could not reach Google, likely a local network issue
+ final_error_message = (f"Failed to upload file due to network connectivity issues "
+ f"(cannot reach Google: {str(conn_check_exc)}). "
+ f"Original upload error: {str(last_exception)}")
+ request_logger.log_request_response( # Log final failure reason
+ operation_id=operation_id,
+ request_method="PUT", request_url=upload_url,
+ error_message=final_error_message
+ )
+ raise LocalNetworkError(final_error_message) from last_exception
+
+ request_logger.log_request_response( # Log final failure reason if not LocalNetworkError
+ operation_id=operation_id,
+ request_method="PUT", request_url=upload_url,
+ error_message=final_error_message
+ )
+ raise Exception(final_error_message) from last_exception
+
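+# A minimal upload sketch (illustrative; the URLs are placeholders, not real endpoints):
+#
+# upload = UploadResponse(
+# download_url="https://example.com/file.png",
+# upload_url="https://example.com/upload",
+# )
+# ApiClient.upload_file(upload.upload_url, io.BytesIO(b"png bytes"), content_type="image/png")
+#
+# The download_url can then be referenced by request models that expect a file URL.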
+
+class ApiEndpoint(Generic[T, R]):
+ """Defines an API endpoint with its request and response types"""
+
+ def __init__(
+ self,
+ path: str,
+ method: HttpMethod,
+ request_model: Type[T],
+ response_model: Type[R],
+ query_params: Optional[Dict[str, Any]] = None,
+ ):
+ """Initialize an API endpoint definition.
+
+ Args:
+ path: The URL path for this endpoint, can include placeholders like {id}
+ method: The HTTP method to use (GET, POST, etc.)
+ request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint
+ response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint
+ query_params: Optional dictionary of query parameters to include in the request
+ """
+ self.path = path
+ self.method = method
+ self.request_model = request_model
+ self.response_model = response_model
+ self.query_params = query_params or {}
+
+
+class SynchronousOperation(Generic[T, R]):
+ """
+ Represents a single synchronous API operation.
+ """
+
+ def __init__(
+ self,
+ endpoint: ApiEndpoint[T, R],
+ request: T,
+ files: Optional[Dict[str, Any]] = None,
+ api_base: str | None = None,
+ auth_token: Optional[str] = None,
+ comfy_api_key: Optional[str] = None,
+ auth_kwargs: Optional[Dict[str,str]] = None,
+ timeout: float = 604800.0,
+ verify_ssl: bool = True,
+ content_type: str = "application/json",
+ multipart_parser: Optional[Callable] = None,
+ max_retries: int = 3,
+ retry_delay: float = 1.0,
+ retry_backoff_factor: float = 2.0,
+ ):
+ self.endpoint = endpoint
+ self.request = request
+ self.response = None
+ self.error = None
+ self.api_base: str = api_base or args.comfy_api_base
+ self.auth_token = auth_token
+ self.comfy_api_key = comfy_api_key
+ if auth_kwargs is not None:
+ self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
+ self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
+ self.timeout = timeout
+ self.verify_ssl = verify_ssl
+ self.files = files
+ self.content_type = content_type
+ self.multipart_parser = multipart_parser
+ self.max_retries = max_retries
+ self.retry_delay = retry_delay
+ self.retry_backoff_factor = retry_backoff_factor
+
+ def execute(self, client: Optional[ApiClient] = None) -> R:
+ """Execute the API operation using the provided client or create one with retry support"""
+ try:
+ # Create client if not provided
+ if client is None:
+ client = ApiClient(
+ base_url=self.api_base,
+ auth_token=self.auth_token,
+ comfy_api_key=self.comfy_api_key,
+ timeout=self.timeout,
+ verify_ssl=self.verify_ssl,
+ max_retries=self.max_retries,
+ retry_delay=self.retry_delay,
+ retry_backoff_factor=self.retry_backoff_factor,
+ )
+
+ # Convert request model to dict, but use None for EmptyRequest
+ request_dict = (
+ None
+ if isinstance(self.request, EmptyRequest)
+ else self.request.model_dump(exclude_none=True)
+ )
+ if request_dict:
+ for key, value in request_dict.items():
+ if isinstance(value, Enum):
+ request_dict[key] = value.value
+
+ # Debug log for request
+ logging.debug(
+ f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}"
+ )
+ logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}")
+ logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}")
+
+ # Make the request with built-in retry
+ resp = client.request(
+ method=self.endpoint.method.value,
+ path=self.endpoint.path,
+ data=request_dict,
+ params=self.endpoint.query_params,
+ files=self.files,
+ content_type=self.content_type,
+ multipart_parser=self.multipart_parser
+ )
+
+ # Debug log for response
+ logging.debug("=" * 50)
+ logging.debug("[DEBUG] RESPONSE DETAILS:")
+ logging.debug("[DEBUG] Status Code: 200 (Success)")
+ logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}")
+ logging.debug("=" * 50)
+
+ # Parse and return the response
+ return self._parse_response(resp)
+
+ except LocalNetworkError as e:
+ # Propagate specific network error types
+ logging.error(f"[ERROR] Local network error: {str(e)}")
+ raise
+
+ except ApiServerError as e:
+ # Propagate API server errors
+ logging.error(f"[ERROR] API server error: {str(e)}")
+ raise
+
+ except Exception as e:
+ logging.error(f"[ERROR] API Exception: {str(e)}")
+ raise Exception(str(e))
+
+ def _parse_response(self, resp):
+ """Parse response data - can be overridden by subclasses"""
+ # The response is already the complete object, don't extract just the "data" field
+ # as that would lose the outer structure (created timestamp, etc.)
+
+ # Parse response using the provided model
+ self.response = self.endpoint.response_model.model_validate(resp)
+ logging.debug(f"[DEBUG] Parsed Response: {self.response}")
+ return self.response
+
+
+class TaskStatus(str, Enum):
+ """Enum for task status values"""
+
+ COMPLETED = "completed"
+ FAILED = "failed"
+ PENDING = "pending"
+
+
+class PollingOperation(Generic[T, R]):
+ """
+ Represents an asynchronous API operation that requires polling for completion.
+ """
+
+ def __init__(
+ self,
+ poll_endpoint: ApiEndpoint[EmptyRequest, R],
+ completed_statuses: list,
+ failed_statuses: list,
+ status_extractor: Callable[[R], str],
+ progress_extractor: Optional[Callable[[R], float]] = None,
+ result_url_extractor: Optional[Callable[[R], str]] = None,
+ request: Optional[T] = None,
+ api_base: str | None = None,
+ auth_token: Optional[str] = None,
+ comfy_api_key: Optional[str] = None,
+ auth_kwargs: Optional[Dict[str,str]] = None,
+ poll_interval: float = 5.0,
+ max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval)
+ max_retries: int = 3, # Max retries per individual API call
+ retry_delay: float = 1.0,
+ retry_backoff_factor: float = 2.0,
+ estimated_duration: Optional[float] = None,
+ node_id: Optional[str] = None,
+ ):
+ self.poll_endpoint = poll_endpoint
+ self.request = request
+ self.api_base: str = api_base or args.comfy_api_base
+ self.auth_token = auth_token
+ self.comfy_api_key = comfy_api_key
+ if auth_kwargs is not None:
+ self.auth_token = auth_kwargs.get("auth_token", self.auth_token)
+ self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key)
+ self.poll_interval = poll_interval
+ self.max_poll_attempts = max_poll_attempts
+ self.max_retries = max_retries
+ self.retry_delay = retry_delay
+ self.retry_backoff_factor = retry_backoff_factor
+ self.estimated_duration = estimated_duration
+
+ # Polling configuration
+ self.status_extractor = status_extractor or (
+ lambda x: getattr(x, "status", None)
+ )
+ self.progress_extractor = progress_extractor
+ self.result_url_extractor = result_url_extractor
+ self.node_id = node_id
+ self.completed_statuses = completed_statuses
+ self.failed_statuses = failed_statuses
+
+ # For storing response data
+ self.final_response = None
+ self.error = None
+
+ def execute(self, client: Optional[ApiClient] = None) -> R:
+ """Execute the polling operation using the provided client. If failed, raise an exception."""
+ try:
+ if client is None:
+ client = ApiClient(
+ base_url=self.api_base,
+ auth_token=self.auth_token,
+ comfy_api_key=self.comfy_api_key,
+ max_retries=self.max_retries,
+ retry_delay=self.retry_delay,
+ retry_backoff_factor=self.retry_backoff_factor,
+ )
+ return self._poll_until_complete(client)
+ except LocalNetworkError as e:
+ # Provide clear message for local network issues
+ raise Exception(
+ f"Polling failed due to local network issues. Please check your internet connection. "
+ f"Details: {str(e)}"
+ ) from e
+ except ApiServerError as e:
+ # Provide clear message for API server issues
+ raise Exception(
+ f"Polling failed due to API server issues. The service may be experiencing problems. "
+ f"Please try again later. Details: {str(e)}"
+ ) from e
+ except Exception as e:
+ raise Exception(f"Error during polling: {str(e)}")
+
+ def _display_text_on_node(self, text: str):
+ """Sends text to the client which will be displayed on the node in the UI"""
+ if not self.node_id:
+ return
+
+ PromptServer.instance.send_progress_text(text, self.node_id)
+
+ def _display_time_progress_on_node(self, time_completed: int):
+ if not self.node_id:
+ return
+
+ if self.estimated_duration is not None:
+ estimated_time_remaining = max(
+ 0, int(self.estimated_duration) - int(time_completed)
+ )
+ message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)"
+ else:
+ message = f"Task in progress: {time_completed:.0f}s"
+ self._display_text_on_node(message)
+
+ def _check_task_status(self, response: R) -> TaskStatus:
+ """Check task status using the status extractor function"""
+ try:
+ status = self.status_extractor(response)
+ if status in self.completed_statuses:
+ return TaskStatus.COMPLETED
+ elif status in self.failed_statuses:
+ return TaskStatus.FAILED
+ return TaskStatus.PENDING
+ except Exception as e:
+ logging.error(f"Error extracting status: {e}")
+ return TaskStatus.PENDING
+
+ def _poll_until_complete(self, client: ApiClient) -> R:
+ """Poll until the task is complete"""
+ poll_count = 0
+ consecutive_errors = 0
+ max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors
+ status = TaskStatus.PENDING # Pre-set so the failure check in the except block is safe before the first status is extracted
+
+ if self.progress_extractor:
+ progress = utils.ProgressBar(PROGRESS_BAR_MAX)
+
+ while poll_count < self.max_poll_attempts:
+ try:
+ poll_count += 1
+ logging.debug(f"[DEBUG] Polling attempt #{poll_count}")
+
+ request_dict = (
+ self.request.model_dump(exclude_none=True)
+ if self.request is not None
+ else None
+ )
+
+ if poll_count == 1:
+ logging.debug(
+ f"[DEBUG] Poll Request: {self.poll_endpoint.method.value} {self.poll_endpoint.path}"
+ )
+ logging.debug(
+ f"[DEBUG] Poll Request Data: {json.dumps(request_dict, indent=2) if request_dict else 'None'}"
+ )
+
+ # Query task status
+ resp = client.request(
+ method=self.poll_endpoint.method.value,
+ path=self.poll_endpoint.path,
+ params=self.poll_endpoint.query_params,
+ data=request_dict,
+ )
+
+ # Successfully got a response, reset consecutive error count
+ consecutive_errors = 0
+
+ # Parse response
+ response_obj = self.poll_endpoint.response_model.model_validate(resp)
+
+ # Check if task is complete
+ status = self._check_task_status(response_obj)
+ logging.debug(f"[DEBUG] Task Status: {status}")
+
+ # If progress extractor is provided, extract progress
+ if self.progress_extractor:
+ new_progress = self.progress_extractor(response_obj)
+ if new_progress is not None:
+ progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)
+
+ if status == TaskStatus.COMPLETED:
+ message = "Task completed successfully"
+ if self.result_url_extractor:
+ result_url = self.result_url_extractor(response_obj)
+ if result_url:
+ message = f"Result URL: {result_url}"
+ else:
+ message = "Task completed successfully!"
+ logging.debug(f"[DEBUG] {message}")
+ self._display_text_on_node(message)
+ self.final_response = response_obj
+ if self.progress_extractor:
+ progress.update(100)
+ return self.final_response
+ elif status == TaskStatus.FAILED:
+ message = f"Task failed: {json.dumps(resp)}"
+ logging.error(f"[DEBUG] {message}")
+ raise Exception(message)
+ else:
+ logging.debug("[DEBUG] Task still pending, continuing to poll...")
+
+ # Wait before polling again
+ logging.debug(
+ f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
+ )
+ for i in range(int(self.poll_interval)):
+ time_completed = (poll_count * self.poll_interval) + i
+ self._display_time_progress_on_node(time_completed)
+ time.sleep(1)
+
+ except (LocalNetworkError, ApiServerError) as e:
+ # For network-related errors, increment error count and potentially abort
+ consecutive_errors += 1
+ if consecutive_errors >= max_consecutive_errors:
+ raise Exception(
+ f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}"
+ ) from e
+
+ # Log the error but continue polling
+ logging.warning(
+ f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
+ f"Will retry in {self.poll_interval} seconds."
+ )
+ time.sleep(self.poll_interval)
+
+ except Exception as e:
+ # For other errors, increment count and potentially abort
+ consecutive_errors += 1
+ if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
+ raise Exception(
+ f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
+ ) from e
+
+ logging.error(f"[DEBUG] Polling error: {str(e)}")
+ logging.warning(
+ f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. "
+ f"Will retry in {self.poll_interval} seconds."
+ )
+ time.sleep(self.poll_interval)
+
+ # If we've exhausted all polling attempts
+ raise Exception(
+ f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). "
+ f"The operation may still be running on the server but is taking longer than expected."
+ )
diff --git a/comfy_api_nodes/apis/luma_api.py b/comfy_api_nodes/apis/luma_api.py
new file mode 100644
index 000000000..632c4ab96
--- /dev/null
+++ b/comfy_api_nodes/apis/luma_api.py
@@ -0,0 +1,253 @@
+from __future__ import annotations
+
+
+import torch
+
+from enum import Enum
+from typing import Optional, Union
+
+from pydantic import BaseModel, Field, confloat
+
+
+class LumaIO:
+ LUMA_REF = "LUMA_REF"
+ LUMA_CONCEPTS = "LUMA_CONCEPTS"
+
+
+class LumaReference:
+ def __init__(self, image: torch.Tensor, weight: float):
+ self.image = image
+ self.weight = weight
+
+ def create_api_model(self, download_url: str):
+ return LumaImageRef(url=download_url, weight=self.weight)
+
+class LumaReferenceChain:
+ def __init__(self, first_ref: LumaReference=None):
+ self.refs: list[LumaReference] = []
+ if first_ref:
+ self.refs.append(first_ref)
+
+ def add(self, luma_ref: LumaReference=None):
+ self.refs.append(luma_ref)
+
+ def create_api_model(self, download_urls: list[str], max_refs=4):
+ if len(self.refs) == 0:
+ return None
+ api_refs: list[LumaImageRef] = []
+ for ref, url in zip(self.refs, download_urls):
+ api_ref = LumaImageRef(url=url, weight=ref.weight)
+ api_refs.append(api_ref)
+ return api_refs
+
+ def clone(self):
+ c = LumaReferenceChain()
+ for ref in self.refs:
+ c.add(ref)
+ return c
+
+
+class LumaConcept:
+ def __init__(self, key: str):
+ self.key = key
+
+
+class LumaConceptChain:
+ def __init__(self, str_list: list[str] = None):
+ self.concepts: list[LumaConcept] = []
+ if str_list is not None:
+ for c in str_list:
+ if c != "None":
+ self.add(LumaConcept(key=c))
+
+ def add(self, concept: LumaConcept):
+ self.concepts.append(concept)
+
+ def create_api_model(self):
+ if len(self.concepts) == 0:
+ return None
+ api_concepts: list[LumaConceptObject] = []
+ for concept in self.concepts:
+ if concept.key == "None":
+ continue
+ api_concepts.append(LumaConceptObject(key=concept.key))
+ if len(api_concepts) == 0:
+ return None
+ return api_concepts
+
+ def clone(self):
+ c = LumaConceptChain()
+ for concept in self.concepts:
+ c.add(concept)
+ return c
+
+ def clone_and_merge(self, other: LumaConceptChain):
+ c = self.clone()
+ for concept in other.concepts:
+ c.add(concept)
+ return c
+
+
+def get_luma_concepts(include_none=False):
+ concepts = []
+ if include_none:
+ concepts.append("None")
+ return concepts + [
+ "truck_left",
+ "pan_right",
+ "pedestal_down",
+ "low_angle",
+ "pedestal_up",
+ "selfie",
+ "pan_left",
+ "roll_right",
+ "zoom_in",
+ "over_the_shoulder",
+ "orbit_right",
+ "orbit_left",
+ "static",
+ "tiny_planet",
+ "high_angle",
+ "bolt_cam",
+ "dolly_zoom",
+ "overhead",
+ "zoom_out",
+ "handheld",
+ "roll_left",
+ "pov",
+ "aerial_drone",
+ "push_in",
+ "crane_down",
+ "truck_right",
+ "tilt_down",
+ "elevator_doors",
+ "tilt_up",
+ "ground_level",
+ "pull_out",
+ "aerial",
+ "crane_up",
+ "eye_level"
+ ]
+
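+# A minimal chain sketch (illustrative):
+#
+# chain = LumaConceptChain(str_list=["pan_left", "zoom_in", "None"]) # "None" entries are skipped
+# api_concepts = chain.create_api_model() # [LumaConceptObject(key='pan_left'), LumaConceptObject(key='zoom_in')]
+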
+
+class LumaImageModel(str, Enum):
+ photon_1 = "photon-1"
+ photon_flash_1 = "photon-flash-1"
+
+
+class LumaVideoModel(str, Enum):
+ ray_2 = "ray-2"
+ ray_flash_2 = "ray-flash-2"
+ ray_1_6 = "ray-1-6"
+
+
+class LumaAspectRatio(str, Enum):
+ ratio_1_1 = "1:1"
+ ratio_16_9 = "16:9"
+ ratio_9_16 = "9:16"
+ ratio_4_3 = "4:3"
+ ratio_3_4 = "3:4"
+ ratio_21_9 = "21:9"
+ ratio_9_21 = "9:21"
+
+
+class LumaVideoOutputResolution(str, Enum):
+ res_540p = "540p"
+ res_720p = "720p"
+ res_1080p = "1080p"
+ res_4k = "4k"
+
+
+class LumaVideoModelOutputDuration(str, Enum):
+ dur_5s = "5s"
+ dur_9s = "9s"
+
+
+class LumaGenerationType(str, Enum):
+ video = 'video'
+ image = 'image'
+
+
+class LumaState(str, Enum):
+ queued = "queued"
+ dreaming = "dreaming"
+ completed = "completed"
+ failed = "failed"
+
+
+class LumaAssets(BaseModel):
+ video: Optional[str] = Field(None, description='The URL of the video')
+ image: Optional[str] = Field(None, description='The URL of the image')
+ progress_video: Optional[str] = Field(None, description='The URL of the progress video')
+
+
+class LumaImageRef(BaseModel):
+ '''Used for image gen'''
+ url: str = Field(..., description='The URL of the image reference')
+ weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
+
+
+class LumaImageReference(BaseModel):
+ '''Used for video gen'''
+ type: Optional[str] = Field('image', description='Input type, defaults to image')
+ url: str = Field(..., description='The URL of the image')
+
+
+class LumaModifyImageRef(BaseModel):
+ url: str = Field(..., description='The URL of the image reference')
+ weight: confloat(ge=0.0, le=1.0) = Field(..., description='The weight of the image reference')
+
+
+class LumaCharacterRef(BaseModel):
+ identity0: LumaImageIdentity = Field(..., description='The image identity object')
+
+
+class LumaImageIdentity(BaseModel):
+ images: list[str] = Field(..., description='The URLs of the image identity')
+
+
+class LumaGenerationReference(BaseModel):
+ type: str = Field('generation', description='Input type, defaults to generation')
+ id: str = Field(..., description='The ID of the generation')
+
+
+class LumaKeyframes(BaseModel):
+ frame0: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='The starting keyframe of the generation')
+ frame1: Optional[Union[LumaImageReference, LumaGenerationReference]] = Field(None, description='The ending keyframe of the generation')
+
+
+class LumaConceptObject(BaseModel):
+ key: str = Field(..., description='Camera Concept name')
+
+
+class LumaImageGenerationRequest(BaseModel):
+ prompt: str = Field(..., description='The prompt of the generation')
+ model: LumaImageModel = Field(LumaImageModel.photon_1, description='The image model used for the generation')
+ aspect_ratio: Optional[LumaAspectRatio] = Field(LumaAspectRatio.ratio_16_9, description='The aspect ratio of the generation')
+ image_ref: Optional[list[LumaImageRef]] = Field(None, description='List of image reference objects')
+ style_ref: Optional[list[LumaImageRef]] = Field(None, description='List of style reference objects')
+ character_ref: Optional[LumaCharacterRef] = Field(None, description='The image identity object')
+ modify_image_ref: Optional[LumaModifyImageRef] = Field(None, description='The modify image reference object')
+
+
+class LumaGenerationRequest(BaseModel):
+ prompt: str = Field(..., description='The prompt of the generation')
+ model: LumaVideoModel = Field(LumaVideoModel.ray_2, description='The video model used for the generation')
+ duration: Optional[LumaVideoModelOutputDuration] = Field(None, description='The duration of the generation')
+ aspect_ratio: Optional[LumaAspectRatio] = Field(None, description='The aspect ratio of the generation')
+ resolution: Optional[LumaVideoOutputResolution] = Field(None, description='The resolution of the generation')
+ loop: Optional[bool] = Field(None, description='Whether to loop the video')
+ keyframes: Optional[LumaKeyframes] = Field(None, description='The keyframes of the generation')
+ concepts: Optional[list[LumaConceptObject]] = Field(None, description='Camera Concepts to apply to generation')
+
+
+class LumaGeneration(BaseModel):
+ id: str = Field(..., description='The ID of the generation')
+ generation_type: LumaGenerationType = Field(..., description='Generation type, image or video')
+ state: LumaState = Field(..., description='The state of the generation')
+ failure_reason: Optional[str] = Field(None, description='The reason for the state of the generation')
+ created_at: str = Field(..., description='The date and time when the generation was created')
+ assets: Optional[LumaAssets] = Field(None, description='The assets of the generation')
+ model: str = Field(..., description='The model used for the generation')
+ request: Union[LumaGenerationRequest, LumaImageGenerationRequest] = Field(..., description="The request used for the generation")
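+
+
+# A minimal request sketch (illustrative values; the URL is a placeholder):
+#
+# LumaGenerationRequest(
+# prompt="a red fox running through snow",
+# model=LumaVideoModel.ray_2,
+# keyframes=LumaKeyframes(frame0=LumaImageReference(url="https://example.com/start.png")),
+# )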
diff --git a/comfy_api_nodes/apis/pixverse_api.py b/comfy_api_nodes/apis/pixverse_api.py
new file mode 100644
index 000000000..9bb29c383
--- /dev/null
+++ b/comfy_api_nodes/apis/pixverse_api.py
@@ -0,0 +1,146 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel, Field
+
+
+pixverse_templates = {
+ "Microwave": 324641385496960,
+ "Suit Swagger": 328545151283968,
+ "Anything, Robot": 313358700761536,
+ "Subject 3 Fever": 327828816843648,
+ "kiss kiss": 315446315336768,
+}
+
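+# The values above are PixVerse template IDs; a template is selected by passing one of
+# these IDs as the `template_id` field of a PixverseTextVideoRequest or
+# PixverseImageVideoRequest.
+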
+
+class PixverseIO:
+ TEMPLATE = "PIXVERSE_TEMPLATE"
+
+
+class PixverseStatus(int, Enum):
+ successful = 1
+ generating = 5
+ deleted = 6
+ contents_moderation = 7
+ failed = 8
+
+
+class PixverseAspectRatio(str, Enum):
+ ratio_16_9 = "16:9"
+ ratio_4_3 = "4:3"
+ ratio_1_1 = "1:1"
+ ratio_3_4 = "3:4"
+ ratio_9_16 = "9:16"
+
+
+class PixverseQuality(str, Enum):
+ res_360p = "360p"
+ res_540p = "540p"
+ res_720p = "720p"
+ res_1080p = "1080p"
+
+
+class PixverseDuration(int, Enum):
+ dur_5 = 5
+ dur_8 = 8
+
+
+class PixverseMotionMode(str, Enum):
+ normal = "normal"
+ fast = "fast"
+
+
+class PixverseStyle(str, Enum):
+ anime = "anime"
+ animation_3d = "3d_animation"
+ clay = "clay"
+ comic = "comic"
+ cyberpunk = "cyberpunk"
+
+
+# NOTE: descriptions are omitted for now in favor of dev speed
+class PixverseTextVideoRequest(BaseModel):
+ aspect_ratio: PixverseAspectRatio = Field(...)
+ quality: PixverseQuality = Field(...)
+ duration: PixverseDuration = Field(...)
+ model: Optional[str] = Field("v3.5")
+ motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
+ prompt: str = Field(...)
+ negative_prompt: Optional[str] = Field(None)
+ seed: Optional[int] = Field(None)
+ style: Optional[str] = Field(None)
+ template_id: Optional[int] = Field(None)
+ water_mark: Optional[bool] = Field(None)
+
+
+class PixverseImageVideoRequest(BaseModel):
+ quality: PixverseQuality = Field(...)
+ duration: PixverseDuration = Field(...)
+ img_id: int = Field(...)
+ model: Optional[str] = Field("v3.5")
+ motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
+ prompt: str = Field(...)
+ negative_prompt: Optional[str] = Field(None)
+ seed: Optional[int] = Field(None)
+ style: Optional[str] = Field(None)
+ template_id: Optional[int] = Field(None)
+ water_mark: Optional[bool] = Field(None)
+
+
+class PixverseTransitionVideoRequest(BaseModel):
+ quality: PixverseQuality = Field(...)
+ duration: PixverseDuration = Field(...)
+ first_frame_img: int = Field(...)
+ last_frame_img: int = Field(...)
+ model: Optional[str] = Field("v3.5")
+ motion_mode: Optional[PixverseMotionMode] = Field(PixverseMotionMode.normal)
+ prompt: str = Field(...)
+ # negative_prompt: Optional[str] = Field(None)
+ seed: Optional[int] = Field(None)
+ # style: Optional[str] = Field(None)
+ # template_id: Optional[int] = Field(None)
+ # water_mark: Optional[bool] = Field(None)
+
+
+class PixverseImageUploadResponse(BaseModel):
+ ErrCode: Optional[int] = None
+ ErrMsg: Optional[str] = None
+ Resp: Optional[PixverseImgIdResponseObject] = Field(None, alias='Resp')
+
+
+class PixverseImgIdResponseObject(BaseModel):
+ img_id: Optional[int] = None
+
+
+class PixverseVideoResponse(BaseModel):
+ ErrCode: Optional[int] = Field(None)
+ ErrMsg: Optional[str] = Field(None)
+ Resp: Optional[PixverseVideoIdResponseObject] = Field(None)
+
+
+class PixverseVideoIdResponseObject(BaseModel):
+ video_id: int = Field(..., description='Video_id')
+
+
+class PixverseGenerationStatusResponse(BaseModel):
+ ErrCode: Optional[int] = Field(None)
+ ErrMsg: Optional[str] = Field(None)
+ Resp: Optional[PixverseGenerationStatusResponseObject] = Field(None)
+
+
+class PixverseGenerationStatusResponseObject(BaseModel):
+ create_time: Optional[str] = Field(None)
+ id: Optional[int] = Field(None)
+ modify_time: Optional[str] = Field(None)
+ negative_prompt: Optional[str] = Field(None)
+ outputHeight: Optional[int] = Field(None)
+ outputWidth: Optional[int] = Field(None)
+ prompt: Optional[str] = Field(None)
+ resolution_ratio: Optional[int] = Field(None)
+ seed: Optional[int] = Field(None)
+ size: Optional[int] = Field(None)
+ status: Optional[int] = Field(None)
+ style: Optional[str] = Field(None)
+ url: Optional[str] = Field(None)
diff --git a/comfy_api_nodes/apis/recraft_api.py b/comfy_api_nodes/apis/recraft_api.py
new file mode 100644
index 000000000..c36d95f24
--- /dev/null
+++ b/comfy_api_nodes/apis/recraft_api.py
@@ -0,0 +1,262 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel, Field, conint, confloat
+
+
+class RecraftColor:
+ def __init__(self, r: int, g: int, b: int):
+ self.color = [r, g, b]
+
+ def create_api_model(self):
+ return RecraftColorObject(rgb=self.color)
+
+
+class RecraftColorChain:
+ def __init__(self):
+ self.colors: list[RecraftColor] = []
+
+ def get_first(self):
+ if len(self.colors) > 0:
+ return self.colors[0]
+ return None
+
+ def add(self, color: RecraftColor):
+ self.colors.append(color)
+
+ def create_api_model(self):
+ if not self.colors:
+ return None
+ colors_api = [x.create_api_model() for x in self.colors]
+ return colors_api
+
+ def clone(self):
+ c = RecraftColorChain()
+ for color in self.colors:
+ c.add(color)
+ return c
+
+ def clone_and_merge(self, other: RecraftColorChain):
+ c = self.clone()
+ for color in other.colors:
+ c.add(color)
+ return c
+
+
+class RecraftControls:
+    def __init__(self, colors: Optional[RecraftColorChain]=None, background_color: Optional[RecraftColorChain]=None,
+                 artistic_level: Optional[int]=None, no_text: Optional[bool]=None):
+ self.colors = colors
+ self.background_color = background_color
+ self.artistic_level = artistic_level
+ self.no_text = no_text
+
+ def create_api_model(self):
+ if self.colors is None and self.background_color is None and self.artistic_level is None and self.no_text is None:
+ return None
+ colors_api = None
+ background_color_api = None
+ if self.colors:
+ colors_api = self.colors.create_api_model()
+ if self.background_color:
+ first_background = self.background_color.get_first()
+ background_color_api = first_background.create_api_model() if first_background else None
+
+ return RecraftControlsObject(colors=colors_api, background_color=background_color_api,
+ artistic_level=self.artistic_level, no_text=self.no_text)
+
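+# Illustrative usage (values assumed): chain a palette from node inputs, then
+# convert it to API models when building the request.
+#   palette = RecraftColorChain()
+#   palette.add(RecraftColor(255, 0, 0))
+#   RecraftControls(colors=palette, artistic_level=2).create_api_model()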
+
+class RecraftStyle:
+    def __init__(self, style: Optional[str]=None, substyle: Optional[str]=None, style_id: Optional[str]=None):
+ self.style = style
+ if substyle == "None":
+ substyle = None
+ self.substyle = substyle
+ self.style_id = style_id
+
+
+class RecraftIO:
+ STYLEV3 = "RECRAFT_V3_STYLE"
+ COLOR = "RECRAFT_COLOR"
+ CONTROLS = "RECRAFT_CONTROLS"
+
+
+class RecraftStyleV3(str, Enum):
+ #any = 'any' NOTE: this does not work for some reason... why?
+ realistic_image = 'realistic_image'
+ digital_illustration = 'digital_illustration'
+ vector_illustration = 'vector_illustration'
+ logo_raster = 'logo_raster'
+
+
+def get_v3_substyles(style_v3: str, include_none=True) -> list[str]:
+ substyles: list[str] = []
+ if include_none:
+ substyles.append("None")
+ return substyles + dict_recraft_substyles_v3.get(style_v3, [])
+
+
+dict_recraft_substyles_v3 = {
+ RecraftStyleV3.realistic_image: [
+ "b_and_w",
+ "enterprise",
+ "evening_light",
+ "faded_nostalgia",
+ "forest_life",
+ "hard_flash",
+ "hdr",
+ "motion_blur",
+ "mystic_naturalism",
+ "natural_light",
+ "natural_tones",
+ "organic_calm",
+ "real_life_glow",
+ "retro_realism",
+ "retro_snapshot",
+ "studio_portrait",
+ "urban_drama",
+ "village_realism",
+ "warm_folk"
+ ],
+ RecraftStyleV3.digital_illustration: [
+ "2d_art_poster",
+ "2d_art_poster_2",
+ "antiquarian",
+ "bold_fantasy",
+ "child_book",
+ "child_books",
+ "cover",
+ "crosshatch",
+ "digital_engraving",
+ "engraving_color",
+ "expressionism",
+ "freehand_details",
+ "grain",
+ "grain_20",
+ "graphic_intensity",
+ "hand_drawn",
+ "hand_drawn_outline",
+ "handmade_3d",
+ "hard_comics",
+ "infantile_sketch",
+ "long_shadow",
+ "modern_folk",
+ "multicolor",
+ "neon_calm",
+ "noir",
+ "nostalgic_pastel",
+ "outline_details",
+ "pastel_gradient",
+ "pastel_sketch",
+ "pixel_art",
+ "plastic",
+ "pop_art",
+ "pop_renaissance",
+ "seamless",
+ "street_art",
+ "tablet_sketch",
+ "urban_glow",
+ "urban_sketching",
+ "vanilla_dreams",
+ "young_adult_book",
+ "young_adult_book_2"
+ ],
+ RecraftStyleV3.vector_illustration: [
+ "bold_stroke",
+ "chemistry",
+ "colored_stencil",
+ "contour_pop_art",
+ "cosmics",
+ "cutout",
+ "depressive",
+ "editorial",
+ "emotional_flat",
+ "engraving",
+ "infographical",
+ "line_art",
+ "line_circuit",
+ "linocut",
+ "marker_outline",
+ "mosaic",
+ "naivector",
+ "roundish_flat",
+ "seamless",
+ "segmented_colors",
+ "sharp_contrast",
+ "thin",
+ "vector_photo",
+ "vivid_shapes"
+ ],
+ RecraftStyleV3.logo_raster: [
+ "emblem_graffiti",
+ "emblem_pop_art",
+ "emblem_punk",
+ "emblem_stamp",
+ "emblem_vintage"
+ ],
+}
+
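+# Illustrative: combo options for a given base style.
+#   get_v3_substyles(RecraftStyleV3.logo_raster)
+#   -> ["None", "emblem_graffiti", "emblem_pop_art", "emblem_punk", "emblem_stamp", "emblem_vintage"]
+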
+
+class RecraftModel(str, Enum):
+ recraftv3 = 'recraftv3'
+ recraftv2 = 'recraftv2'
+
+
+class RecraftImageSize(str, Enum):
+ res_1024x1024 = '1024x1024'
+ res_1365x1024 = '1365x1024'
+ res_1024x1365 = '1024x1365'
+ res_1536x1024 = '1536x1024'
+ res_1024x1536 = '1024x1536'
+ res_1820x1024 = '1820x1024'
+ res_1024x1820 = '1024x1820'
+ res_1024x2048 = '1024x2048'
+ res_2048x1024 = '2048x1024'
+ res_1434x1024 = '1434x1024'
+ res_1024x1434 = '1024x1434'
+ res_1024x1280 = '1024x1280'
+ res_1280x1024 = '1280x1024'
+ res_1024x1707 = '1024x1707'
+ res_1707x1024 = '1707x1024'
+
+
+class RecraftColorObject(BaseModel):
+ rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
+
+
+class RecraftControlsObject(BaseModel):
+ colors: Optional[list[RecraftColorObject]] = Field(None, description='An array of preferable colors')
+ background_color: Optional[RecraftColorObject] = Field(None, description='Use given color as a desired background color')
+ no_text: Optional[bool] = Field(None, description='Do not embed text layouts')
+ artistic_level: Optional[conint(ge=0, le=5)] = Field(None, description='Defines artistic tone of your image. At a simple level, the person looks straight at the camera in a static and clean style. Dynamic and eccentric levels introduce movement and creativity. The value should be in range [0..5].')
+
+
+class RecraftImageGenerationRequest(BaseModel):
+ prompt: str = Field(..., description='The text prompt describing the image to generate')
+ size: Optional[RecraftImageSize] = Field(None, description='The size of the generated image (e.g., "1024x1024")')
+ n: conint(ge=1, le=6) = Field(..., description='The number of images to generate')
+ negative_prompt: Optional[str] = Field(None, description='A text description of undesired elements on an image')
+ model: Optional[RecraftModel] = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
+ style: Optional[str] = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
+ substyle: Optional[str] = Field(None, description='The substyle to apply to the generated image, depending on the style input')
+ controls: Optional[RecraftControlsObject] = Field(None, description='A set of custom parameters to tweak generation process')
+ style_id: Optional[str] = Field(None, description='Use a previously uploaded style as a reference; UUID')
+ strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
+    random_seed: Optional[int] = Field(None, description="Seed for image generation")
+ # text_layout
+
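+# Illustrative request (values assumed; substyle must belong to the chosen style):
+#   RecraftImageGenerationRequest(
+#       prompt="flat isometric city block",
+#       n=1,
+#       model=RecraftModel.recraftv3,
+#       style=RecraftStyleV3.digital_illustration.value,
+#       substyle="pixel_art",
+#   ).model_dump(exclude_none=True)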
+
+class RecraftReturnedObject(BaseModel):
+ image_id: str = Field(..., description='Unique identifier for the generated image')
+ url: str = Field(..., description='URL to access the generated image')
+
+
+class RecraftImageGenerationResponse(BaseModel):
+ created: int = Field(..., description='Unix timestamp when the generation was created')
+ credits: int = Field(..., description='Number of credits used for the generation')
+ data: Optional[list[RecraftReturnedObject]] = Field(None, description='Array of generated image information')
+ image: Optional[RecraftReturnedObject] = Field(None, description='Single generated image')
diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/apis/request_logger.py
new file mode 100644
index 000000000..93517ede9
--- /dev/null
+++ b/comfy_api_nodes/apis/request_logger.py
@@ -0,0 +1,125 @@
+import os
+import datetime
+import json
+import logging
+from typing import Any
+
+import folder_paths
+
+# Get the logger instance
+logger = logging.getLogger(__name__)
+
+def get_log_directory():
+ """
+ Ensures the API log directory exists within ComfyUI's temp directory
+ and returns its path.
+ """
+ base_temp_dir = folder_paths.get_temp_directory()
+ log_dir = os.path.join(base_temp_dir, "api_logs")
+ try:
+ os.makedirs(log_dir, exist_ok=True)
+ except Exception as e:
+ logger.error(f"Error creating API log directory {log_dir}: {e}")
+ # Fallback to base temp directory if sub-directory creation fails
+ return base_temp_dir
+ return log_dir
+
+def _format_data_for_logging(data):
+ """Helper to format data (dict, str, bytes) for logging."""
+ if isinstance(data, bytes):
+ try:
+ return data.decode('utf-8') # Try to decode as text
+ except UnicodeDecodeError:
+ return f"[Binary data of length {len(data)} bytes]"
+ elif isinstance(data, (dict, list)):
+ try:
+ return json.dumps(data, indent=2, ensure_ascii=False)
+ except TypeError:
+ return str(data) # Fallback for non-serializable objects
+ return str(data)
+
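+# Illustrative behavior of _format_data_for_logging: UTF-8 bytes are logged as
+# text, while undecodable binary (e.g. raw PNG bytes) collapses to
+# "[Binary data of length N bytes]".
+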
+def log_request_response(
+ operation_id: str,
+ request_method: str,
+ request_url: str,
+ request_headers: dict | None = None,
+ request_params: dict | None = None,
+    request_data: Any = None,
+    response_status_code: int | None = None,
+    response_headers: dict | None = None,
+    response_content: Any = None,
+ error_message: str | None = None
+):
+ """
+ Logs API request and response details to a file in the temp/api_logs directory.
+ """
+ log_dir = get_log_directory()
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ filename = f"{timestamp}_{operation_id.replace('/', '_').replace(':', '_')}.log"
+ filepath = os.path.join(log_dir, filename)
+
+ log_content = []
+
+ log_content.append(f"Timestamp: {datetime.datetime.now().isoformat()}")
+ log_content.append(f"Operation ID: {operation_id}")
+ log_content.append("-" * 30 + " REQUEST " + "-" * 30)
+ log_content.append(f"Method: {request_method}")
+ log_content.append(f"URL: {request_url}")
+ if request_headers:
+ log_content.append(f"Headers:\n{_format_data_for_logging(request_headers)}")
+ if request_params:
+ log_content.append(f"Params:\n{_format_data_for_logging(request_params)}")
+ if request_data:
+ log_content.append(f"Data/Body:\n{_format_data_for_logging(request_data)}")
+
+ log_content.append("\n" + "-" * 30 + " RESPONSE " + "-" * 30)
+ if response_status_code is not None:
+ log_content.append(f"Status Code: {response_status_code}")
+ if response_headers:
+ log_content.append(f"Headers:\n{_format_data_for_logging(response_headers)}")
+ if response_content:
+ log_content.append(f"Content:\n{_format_data_for_logging(response_content)}")
+ if error_message:
+ log_content.append(f"Error:\n{error_message}")
+
+ try:
+ with open(filepath, "w", encoding="utf-8") as f:
+ f.write("\n".join(log_content))
+ logger.debug(f"API log saved to: {filepath}")
+ except Exception as e:
+ logger.error(f"Error writing API log to {filepath}: {e}")
+
+if __name__ == '__main__':
+ # Example usage (for testing the logger directly)
+ logger.setLevel(logging.DEBUG)
+ # Mock folder_paths for direct execution if not running within ComfyUI full context
+ if not hasattr(folder_paths, 'get_temp_directory'):
+ class MockFolderPaths:
+ def get_temp_directory(self):
+ # Create a local temp dir for testing if needed
+ p = os.path.join(os.path.dirname(__file__), 'temp_test_logs')
+ os.makedirs(p, exist_ok=True)
+ return p
+ folder_paths = MockFolderPaths()
+
+ log_request_response(
+ operation_id="test_operation_get",
+ request_method="GET",
+ request_url="https://api.example.com/test",
+ request_headers={"Authorization": "Bearer testtoken"},
+ request_params={"param1": "value1"},
+ response_status_code=200,
+ response_content={"message": "Success!"}
+ )
+ log_request_response(
+ operation_id="test_operation_post_error",
+ request_method="POST",
+ request_url="https://api.example.com/submit",
+ request_data={"key": "value", "nested": {"num": 123}},
+ error_message="Connection timed out"
+ )
+ log_request_response(
+ operation_id="test_binary_response",
+ request_method="GET",
+ request_url="https://api.example.com/image.png",
+ response_status_code=200,
+ response_content=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR...' # Sample binary data
+ )
diff --git a/comfy_api_nodes/apis/rodin_api.py b/comfy_api_nodes/apis/rodin_api.py
new file mode 100644
index 000000000..b0cf171fa
--- /dev/null
+++ b/comfy_api_nodes/apis/rodin_api.py
@@ -0,0 +1,57 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Optional, List
+from pydantic import BaseModel, Field
+
+
+class Rodin3DGenerateRequest(BaseModel):
+    seed: int = Field(..., description="Random seed for the generation.")
+    tier: str = Field(..., description="Tier of generation.")
+    material: str = Field(..., description="The material type.")
+    quality: str = Field(..., description="The generation quality of the mesh.")
+    mesh_mode: str = Field(..., description="Controls the type of faces of the generated mesh.")
+
+class GenerateJobsData(BaseModel):
+    uuids: List[str] = Field(..., description="List of job UUIDs.")
+    subscription_key: str = Field(..., description="Subscription key used to poll job status.")
+
+class Rodin3DGenerateResponse(BaseModel):
+ message: Optional[str] = Field(None, description="Return message.")
+ prompt: Optional[str] = Field(None, description="Generated Prompt from image.")
+ submit_time: Optional[str] = Field(None, description="Submit Time")
+    uuid: Optional[str] = Field(None, description="Task UUID.")
+ jobs: Optional[GenerateJobsData] = Field(None, description="Details of jobs")
+
+class JobStatus(str, Enum):
+ """
+ Status for jobs
+ """
+ Done = "Done"
+ Failed = "Failed"
+ Generating = "Generating"
+ Waiting = "Waiting"
+
+class Rodin3DCheckStatusRequest(BaseModel):
+    subscription_key: str = Field(..., description="Subscription key returned by the generate endpoint.")
+
+class JobItem(BaseModel):
+    uuid: str = Field(..., description="Job UUID.")
+    status: JobStatus = Field(..., description="Current status of the job.")
+
+class Rodin3DCheckStatusResponse(BaseModel):
+ jobs: List[JobItem] = Field(..., description="Job status List")
+
+class Rodin3DDownloadRequest(BaseModel):
+    task_uuid: str = Field(..., description="Task UUID.")
+
+class RodinResourceItem(BaseModel):
+ url: str = Field(..., description="Download Url")
+ name: str = Field(..., description="File name with ext")
+
+class Rodin3DDownloadResponse(BaseModel):
+ list: List[RodinResourceItem] = Field(..., description="Source List")
+
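+# Illustrative polling body (key value assumed; subscription_key comes from the
+# generate response's jobs data):
+#   Rodin3DCheckStatusRequest(subscription_key="sk_example").model_dump()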
diff --git a/comfy_api_nodes/apis/stability_api.py b/comfy_api_nodes/apis/stability_api.py
new file mode 100644
index 000000000..47c87daec
--- /dev/null
+++ b/comfy_api_nodes/apis/stability_api.py
@@ -0,0 +1,127 @@
+from __future__ import annotations
+
+from enum import Enum
+from typing import Optional
+
+from pydantic import BaseModel, Field, confloat
+
+
+class StabilityFormat(str, Enum):
+ png = 'png'
+ jpeg = 'jpeg'
+ webp = 'webp'
+
+
+class StabilityAspectRatio(str, Enum):
+ ratio_1_1 = "1:1"
+ ratio_16_9 = "16:9"
+ ratio_9_16 = "9:16"
+ ratio_3_2 = "3:2"
+ ratio_2_3 = "2:3"
+ ratio_5_4 = "5:4"
+ ratio_4_5 = "4:5"
+ ratio_21_9 = "21:9"
+ ratio_9_21 = "9:21"
+
+
+def get_stability_style_presets(include_none=True):
+ presets = []
+ if include_none:
+ presets.append("None")
+ return presets + [x.value for x in StabilityStylePreset]
+
+
+class StabilityStylePreset(str, Enum):
+ _3d_model = "3d-model"
+ analog_film = "analog-film"
+ anime = "anime"
+ cinematic = "cinematic"
+ comic_book = "comic-book"
+ digital_art = "digital-art"
+ enhance = "enhance"
+ fantasy_art = "fantasy-art"
+ isometric = "isometric"
+ line_art = "line-art"
+ low_poly = "low-poly"
+ modeling_compound = "modeling-compound"
+ neon_punk = "neon-punk"
+ origami = "origami"
+ photographic = "photographic"
+ pixel_art = "pixel-art"
+ tile_texture = "tile-texture"
+
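+# Illustrative: options list as shown in a node's combo input.
+#   get_stability_style_presets() -> ["None", "3d-model", "analog-film", ...]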
+
+class Stability_SD3_5_Model(str, Enum):
+ sd3_5_large = "sd3.5-large"
+ # sd3_5_large_turbo = "sd3.5-large-turbo"
+ sd3_5_medium = "sd3.5-medium"
+
+
+class Stability_SD3_5_GenerationMode(str, Enum):
+ text_to_image = "text-to-image"
+ image_to_image = "image-to-image"
+
+
+class StabilityStable3_5Request(BaseModel):
+ model: str = Field(...)
+ mode: str = Field(...)
+ prompt: str = Field(...)
+ negative_prompt: Optional[str] = Field(None)
+ aspect_ratio: Optional[str] = Field(None)
+ seed: Optional[int] = Field(None)
+ output_format: Optional[str] = Field(StabilityFormat.png.value)
+ image: Optional[str] = Field(None)
+ style_preset: Optional[str] = Field(None)
+ cfg_scale: float = Field(...)
+ strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
+
+
+class StabilityUpscaleConservativeRequest(BaseModel):
+ prompt: str = Field(...)
+ negative_prompt: Optional[str] = Field(None)
+ seed: Optional[int] = Field(None)
+ output_format: Optional[str] = Field(StabilityFormat.png.value)
+ image: Optional[str] = Field(None)
+ creativity: Optional[confloat(ge=0.2, le=0.5)] = Field(None)
+
+
+class StabilityUpscaleCreativeRequest(BaseModel):
+ prompt: str = Field(...)
+ negative_prompt: Optional[str] = Field(None)
+ seed: Optional[int] = Field(None)
+ output_format: Optional[str] = Field(StabilityFormat.png.value)
+ image: Optional[str] = Field(None)
+ creativity: Optional[confloat(ge=0.1, le=0.5)] = Field(None)
+ style_preset: Optional[str] = Field(None)
+
+
+class StabilityStableUltraRequest(BaseModel):
+ prompt: str = Field(...)
+ negative_prompt: Optional[str] = Field(None)
+ aspect_ratio: Optional[str] = Field(None)
+ seed: Optional[int] = Field(None)
+ output_format: Optional[str] = Field(StabilityFormat.png.value)
+ image: Optional[str] = Field(None)
+ style_preset: Optional[str] = Field(None)
+ strength: Optional[confloat(ge=0.0, le=1.0)] = Field(None)
+
+
+class StabilityStableUltraResponse(BaseModel):
+ image: Optional[str] = Field(None)
+ finish_reason: Optional[str] = Field(None)
+ seed: Optional[int] = Field(None)
+
+
+class StabilityResultsGetResponse(BaseModel):
+ image: Optional[str] = Field(None)
+ finish_reason: Optional[str] = Field(None)
+ seed: Optional[int] = Field(None)
+ id: Optional[str] = Field(None)
+ name: Optional[str] = Field(None)
+ errors: Optional[list[str]] = Field(None)
+ status: Optional[str] = Field(None)
+ result: Optional[str] = Field(None)
+
+
+class StabilityAsyncResponse(BaseModel):
+ id: Optional[str] = Field(None)
diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo_api.py
new file mode 100644
index 000000000..626e8d277
--- /dev/null
+++ b/comfy_api_nodes/apis/tripo_api.py
@@ -0,0 +1,275 @@
+from __future__ import annotations
+from comfy_api_nodes.apis import (
+ TripoModelVersion,
+ TripoTextureQuality,
+)
+from enum import Enum
+from typing import Optional, List, Dict, Any, Union
+
+from pydantic import BaseModel, Field, RootModel
+
+class TripoStyle(str, Enum):
+ PERSON_TO_CARTOON = "person:person2cartoon"
+ ANIMAL_VENOM = "animal:venom"
+ OBJECT_CLAY = "object:clay"
+ OBJECT_STEAMPUNK = "object:steampunk"
+ OBJECT_CHRISTMAS = "object:christmas"
+ OBJECT_BARBIE = "object:barbie"
+ GOLD = "gold"
+ ANCIENT_BRONZE = "ancient_bronze"
+ NONE = "None"
+
+class TripoTaskType(str, Enum):
+ TEXT_TO_MODEL = "text_to_model"
+ IMAGE_TO_MODEL = "image_to_model"
+ MULTIVIEW_TO_MODEL = "multiview_to_model"
+ TEXTURE_MODEL = "texture_model"
+ REFINE_MODEL = "refine_model"
+ ANIMATE_PRERIGCHECK = "animate_prerigcheck"
+ ANIMATE_RIG = "animate_rig"
+ ANIMATE_RETARGET = "animate_retarget"
+ STYLIZE_MODEL = "stylize_model"
+ CONVERT_MODEL = "convert_model"
+
+class TripoTextureAlignment(str, Enum):
+ ORIGINAL_IMAGE = "original_image"
+ GEOMETRY = "geometry"
+
+class TripoOrientation(str, Enum):
+ ALIGN_IMAGE = "align_image"
+ DEFAULT = "default"
+
+class TripoOutFormat(str, Enum):
+ GLB = "glb"
+ FBX = "fbx"
+
+class TripoTopology(str, Enum):
+ BIP = "bip"
+ QUAD = "quad"
+
+class TripoSpec(str, Enum):
+ MIXAMO = "mixamo"
+ TRIPO = "tripo"
+
+class TripoAnimation(str, Enum):
+ IDLE = "preset:idle"
+ WALK = "preset:walk"
+ CLIMB = "preset:climb"
+ JUMP = "preset:jump"
+ RUN = "preset:run"
+ SLASH = "preset:slash"
+ SHOOT = "preset:shoot"
+ HURT = "preset:hurt"
+ FALL = "preset:fall"
+ TURN = "preset:turn"
+
+class TripoStylizeStyle(str, Enum):
+ LEGO = "lego"
+ VOXEL = "voxel"
+ VORONOI = "voronoi"
+ MINECRAFT = "minecraft"
+
+class TripoConvertFormat(str, Enum):
+ GLTF = "GLTF"
+ USDZ = "USDZ"
+ FBX = "FBX"
+ OBJ = "OBJ"
+ STL = "STL"
+ _3MF = "3MF"
+
+class TripoTextureFormat(str, Enum):
+ BMP = "BMP"
+ DPX = "DPX"
+ HDR = "HDR"
+ JPEG = "JPEG"
+ OPEN_EXR = "OPEN_EXR"
+ PNG = "PNG"
+ TARGA = "TARGA"
+ TIFF = "TIFF"
+ WEBP = "WEBP"
+
+class TripoTaskStatus(str, Enum):
+ QUEUED = "queued"
+ RUNNING = "running"
+ SUCCESS = "success"
+ FAILED = "failed"
+ CANCELLED = "cancelled"
+ UNKNOWN = "unknown"
+ BANNED = "banned"
+ EXPIRED = "expired"
+
+class TripoFileTokenReference(BaseModel):
+ type: Optional[str] = Field(None, description='The type of the reference')
+ file_token: str
+
+class TripoUrlReference(BaseModel):
+ type: Optional[str] = Field(None, description='The type of the reference')
+ url: str
+
+class TripoObjectStorage(BaseModel):
+ bucket: str
+ key: str
+
+class TripoObjectReference(BaseModel):
+ type: str
+ object: TripoObjectStorage
+
+class TripoFileEmptyReference(BaseModel):
+ pass
+
+class TripoFileReference(RootModel):
+ root: Union[TripoFileTokenReference, TripoUrlReference, TripoObjectReference, TripoFileEmptyReference]
+
+class TripoGetStsTokenRequest(BaseModel):
+ format: str = Field(..., description='The format of the image')
+
+class TripoTextToModelRequest(BaseModel):
+ type: TripoTaskType = Field(TripoTaskType.TEXT_TO_MODEL, description='Type of task')
+ prompt: str = Field(..., description='The text prompt describing the model to generate', max_length=1024)
+ negative_prompt: Optional[str] = Field(None, description='The negative text prompt', max_length=1024)
+ model_version: Optional[TripoModelVersion] = TripoModelVersion.V2_5
+ face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
+ texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
+ pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
+ image_seed: Optional[int] = Field(None, description='The seed for the text')
+ model_seed: Optional[int] = Field(None, description='The seed for the model')
+ texture_seed: Optional[int] = Field(None, description='The seed for the texture')
+ texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
+ style: Optional[TripoStyle] = None
+ auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
+ quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
+
+class TripoImageToModelRequest(BaseModel):
+ type: TripoTaskType = Field(TripoTaskType.IMAGE_TO_MODEL, description='Type of task')
+ file: TripoFileReference = Field(..., description='The file reference to convert to a model')
+ model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
+ face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
+ texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
+ pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
+ model_seed: Optional[int] = Field(None, description='The seed for the model')
+ texture_seed: Optional[int] = Field(None, description='The seed for the texture')
+ texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
+ texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
+ style: Optional[TripoStyle] = Field(None, description='The style to apply to the generated model')
+ auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
+ orientation: Optional[TripoOrientation] = TripoOrientation.DEFAULT
+ quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
+
+class TripoMultiviewToModelRequest(BaseModel):
+ type: TripoTaskType = TripoTaskType.MULTIVIEW_TO_MODEL
+ files: List[TripoFileReference] = Field(..., description='The file references to convert to a model')
+ model_version: Optional[TripoModelVersion] = Field(None, description='The model version to use for generation')
+ orthographic_projection: Optional[bool] = Field(False, description='Whether to use orthographic projection')
+ face_limit: Optional[int] = Field(None, description='The number of faces to limit the generation to')
+ texture: Optional[bool] = Field(True, description='Whether to apply texture to the generated model')
+ pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the generated model')
+ model_seed: Optional[int] = Field(None, description='The seed for the model')
+ texture_seed: Optional[int] = Field(None, description='The seed for the texture')
+ texture_quality: Optional[TripoTextureQuality] = TripoTextureQuality.standard
+ texture_alignment: Optional[TripoTextureAlignment] = TripoTextureAlignment.ORIGINAL_IMAGE
+ auto_size: Optional[bool] = Field(False, description='Whether to auto-size the model')
+ orientation: Optional[TripoOrientation] = Field(TripoOrientation.DEFAULT, description='The orientation for the model')
+ quad: Optional[bool] = Field(False, description='Whether to apply quad to the generated model')
+
+class TripoTextureModelRequest(BaseModel):
+ type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description='Type of task')
+ original_model_task_id: str = Field(..., description='The task ID of the original model')
+ texture: Optional[bool] = Field(True, description='Whether to apply texture to the model')
+ pbr: Optional[bool] = Field(True, description='Whether to apply PBR to the model')
+ model_seed: Optional[int] = Field(None, description='The seed for the model')
+ texture_seed: Optional[int] = Field(None, description='The seed for the texture')
+ texture_quality: Optional[TripoTextureQuality] = Field(None, description='The quality of the texture')
+ texture_alignment: Optional[TripoTextureAlignment] = Field(TripoTextureAlignment.ORIGINAL_IMAGE, description='The texture alignment method')
+
+class TripoRefineModelRequest(BaseModel):
+ type: TripoTaskType = Field(TripoTaskType.REFINE_MODEL, description='Type of task')
+ draft_model_task_id: str = Field(..., description='The task ID of the draft model')
+
+class TripoAnimatePrerigcheckRequest(BaseModel):
+ type: TripoTaskType = Field(TripoTaskType.ANIMATE_PRERIGCHECK, description='Type of task')
+ original_model_task_id: str = Field(..., description='The task ID of the original model')
+
+class TripoAnimateRigRequest(BaseModel):
+ type: TripoTaskType = Field(TripoTaskType.ANIMATE_RIG, description='Type of task')
+ original_model_task_id: str = Field(..., description='The task ID of the original model')
+ out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
+ spec: Optional[TripoSpec] = Field(TripoSpec.TRIPO, description='The specification for rigging')
+
+class TripoAnimateRetargetRequest(BaseModel):
+ type: TripoTaskType = Field(TripoTaskType.ANIMATE_RETARGET, description='Type of task')
+ original_model_task_id: str = Field(..., description='The task ID of the original model')
+ animation: TripoAnimation = Field(..., description='The animation to apply')
+ out_format: Optional[TripoOutFormat] = Field(TripoOutFormat.GLB, description='The output format')
+ bake_animation: Optional[bool] = Field(True, description='Whether to bake the animation')
+
+class TripoStylizeModelRequest(BaseModel):
+ type: TripoTaskType = Field(TripoTaskType.STYLIZE_MODEL, description='Type of task')
+ style: TripoStylizeStyle = Field(..., description='The style to apply to the model')
+ original_model_task_id: str = Field(..., description='The task ID of the original model')
+ block_size: Optional[int] = Field(80, description='The block size for stylization')
+
+class TripoConvertModelRequest(BaseModel):
+ type: TripoTaskType = Field(TripoTaskType.CONVERT_MODEL, description='Type of task')
+ format: TripoConvertFormat = Field(..., description='The format to convert to')
+ original_model_task_id: str = Field(..., description='The task ID of the original model')
+ quad: Optional[bool] = Field(False, description='Whether to apply quad to the model')
+ force_symmetry: Optional[bool] = Field(False, description='Whether to force symmetry')
+ face_limit: Optional[int] = Field(10000, description='The number of faces to limit the conversion to')
+ flatten_bottom: Optional[bool] = Field(False, description='Whether to flatten the bottom of the model')
+ flatten_bottom_threshold: Optional[float] = Field(0.01, description='The threshold for flattening the bottom')
+ texture_size: Optional[int] = Field(4096, description='The size of the texture')
+ texture_format: Optional[TripoTextureFormat] = Field(TripoTextureFormat.JPEG, description='The format of the texture')
+ pivot_to_center_bottom: Optional[bool] = Field(False, description='Whether to pivot to the center bottom')
+
+class TripoTaskRequest(RootModel):
+ root: Union[
+ TripoTextToModelRequest,
+ TripoImageToModelRequest,
+ TripoMultiviewToModelRequest,
+ TripoTextureModelRequest,
+ TripoRefineModelRequest,
+ TripoAnimatePrerigcheckRequest,
+ TripoAnimateRigRequest,
+ TripoAnimateRetargetRequest,
+ TripoStylizeModelRequest,
+ TripoConvertModelRequest
+ ]
+
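+# Illustrative: any concrete task request can be wrapped in the RootModel
+# before serialization (prompt value assumed):
+#   task = TripoTaskRequest(root=TripoTextToModelRequest(prompt="a small cactus"))
+#   task.model_dump(exclude_none=True)
+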
+class TripoTaskOutput(BaseModel):
+ model: Optional[str] = Field(None, description='URL to the model')
+ base_model: Optional[str] = Field(None, description='URL to the base model')
+ pbr_model: Optional[str] = Field(None, description='URL to the PBR model')
+ rendered_image: Optional[str] = Field(None, description='URL to the rendered image')
+ riggable: Optional[bool] = Field(None, description='Whether the model is riggable')
+
+class TripoTask(BaseModel):
+ task_id: str = Field(..., description='The task ID')
+ type: Optional[str] = Field(None, description='The type of task')
+ status: Optional[TripoTaskStatus] = Field(None, description='The status of the task')
+ input: Optional[Dict[str, Any]] = Field(None, description='The input parameters for the task')
+ output: Optional[TripoTaskOutput] = Field(None, description='The output of the task')
+ progress: Optional[int] = Field(None, description='The progress of the task', ge=0, le=100)
+ create_time: Optional[int] = Field(None, description='The creation time of the task')
+ running_left_time: Optional[int] = Field(None, description='The estimated time left for the task')
+ queue_position: Optional[int] = Field(None, description='The position in the queue')
+
+class TripoTaskResponse(BaseModel):
+ code: int = Field(0, description='The response code')
+ data: TripoTask = Field(..., description='The task data')
+
+class TripoGeneralResponse(BaseModel):
+ code: int = Field(0, description='The response code')
+ data: Dict[str, str] = Field(..., description='The task ID data')
+
+class TripoBalanceData(BaseModel):
+ balance: float = Field(..., description='The account balance')
+ frozen: float = Field(..., description='The frozen balance')
+
+class TripoBalanceResponse(BaseModel):
+ code: int = Field(0, description='The response code')
+ data: TripoBalanceData = Field(..., description='The balance data')
+
+class TripoErrorResponse(BaseModel):
+ code: int = Field(..., description='The error code')
+ message: str = Field(..., description='The error message')
+ suggestion: str = Field(..., description='The suggestion for fixing the error')
diff --git a/comfy_api_nodes/canary.py b/comfy_api_nodes/canary.py
new file mode 100644
index 000000000..4df7590b6
--- /dev/null
+++ b/comfy_api_nodes/canary.py
@@ -0,0 +1,10 @@
+import av
+
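+# Canary check: the API nodes require PyAV >= 14.2, so fail fast at import time
+# with an explicit message instead of erroring later mid-workflow.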
+ver = av.__version__.split(".")
+if int(ver[0]) < 14:
+ raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
+
+if int(ver[0]) == 14 and int(ver[1]) < 2:
+ raise Exception("INSTALL NEW VERSION OF PYAV TO USE API NODES.")
+
+NODE_CLASS_MAPPINGS = {}
diff --git a/comfy_api_nodes/mapper_utils.py b/comfy_api_nodes/mapper_utils.py
new file mode 100644
index 000000000..6fab8f4bb
--- /dev/null
+++ b/comfy_api_nodes/mapper_utils.py
@@ -0,0 +1,116 @@
+from enum import Enum
+from typing import Optional
+
+from pydantic.fields import FieldInfo
+from pydantic import BaseModel
+from pydantic_core import PydanticUndefined
+
+from comfy.comfy_types.node_typing import IO, InputTypeOptions
+
+NodeInput = tuple[IO, InputTypeOptions]
+
+
+def _create_base_config(field_info: FieldInfo) -> InputTypeOptions:
+ config = {}
+ if hasattr(field_info, "default") and field_info.default is not PydanticUndefined:
+ config["default"] = field_info.default
+ if hasattr(field_info, "description") and field_info.description is not None:
+ config["tooltip"] = field_info.description
+ return config
+
+
+def _get_number_constraints_config(field_info: FieldInfo) -> dict:
+ config = {}
+ if hasattr(field_info, "metadata"):
+ metadata = field_info.metadata
+ for constraint in metadata:
+ if hasattr(constraint, "ge"):
+ config["min"] = constraint.ge
+ if hasattr(constraint, "le"):
+ config["max"] = constraint.le
+ if hasattr(constraint, "multiple_of"):
+ config["step"] = constraint.multiple_of
+ return config
+
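+# Illustrative: constraints declared on a pydantic field, e.g.
+# Field(..., ge=1, le=6), surface as the node's "min"/"max" options via
+# _get_number_constraints_config, and multiple_of becomes "step".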
+
+def _model_field_to_image_input(field_info: FieldInfo, **kwargs) -> NodeInput:
+ return IO.IMAGE, {
+ **_create_base_config(field_info),
+ **kwargs,
+ }
+
+
+def _model_field_to_string_input(field_info: FieldInfo, **kwargs) -> NodeInput:
+ return IO.STRING, {
+ **_create_base_config(field_info),
+ **kwargs,
+ }
+
+
+def _model_field_to_float_input(field_info: FieldInfo, **kwargs) -> NodeInput:
+ return IO.FLOAT, {
+ **_create_base_config(field_info),
+ **_get_number_constraints_config(field_info),
+ **kwargs,
+ }
+
+
+def _model_field_to_int_input(field_info: FieldInfo, **kwargs) -> NodeInput:
+ return IO.INT, {
+ **_create_base_config(field_info),
+ **_get_number_constraints_config(field_info),
+ **kwargs,
+ }
+
+
+def _model_field_to_combo_input(
+    field_info: FieldInfo, enum_type: Optional[type[Enum]] = None, **kwargs
+) -> NodeInput:
+ combo_config = {}
+ if enum_type is not None:
+ combo_config["options"] = [option.value for option in enum_type]
+ combo_config = {
+ **combo_config,
+ **_create_base_config(field_info),
+ **kwargs,
+ }
+ return IO.COMBO, combo_config
+
+
+def model_field_to_node_input(
+ input_type: IO, base_model: type[BaseModel], field_name: str, **kwargs
+) -> NodeInput:
+ """
+ Maps a field from a Pydantic model to a Comfy node input.
+
+ Args:
+ input_type: The type of the input.
+ base_model: The Pydantic model to map the field from.
+ field_name: The name of the field to map.
+ **kwargs: Additional key/values to include in the input options.
+
+ Note:
+ For combo inputs, pass an `Enum` to the `enum_type` keyword argument to populate the options automatically.
+
+ Example:
+ >>> model_field_to_node_input(IO.STRING, MyModel, "my_field", multiline=True)
+ >>> model_field_to_node_input(IO.COMBO, MyModel, "my_field", enum_type=MyEnum)
+ >>> model_field_to_node_input(IO.FLOAT, MyModel, "my_field", slider=True)
+ """
+ field_info: FieldInfo = base_model.model_fields[field_name]
+ result: NodeInput
+
+ if input_type == IO.IMAGE:
+ result = _model_field_to_image_input(field_info, **kwargs)
+ elif input_type == IO.STRING:
+ result = _model_field_to_string_input(field_info, **kwargs)
+ elif input_type == IO.FLOAT:
+ result = _model_field_to_float_input(field_info, **kwargs)
+ elif input_type == IO.INT:
+ result = _model_field_to_int_input(field_info, **kwargs)
+ elif input_type == IO.COMBO:
+ result = _model_field_to_combo_input(field_info, **kwargs)
+ else:
+ message = f"Invalid input type: {input_type}"
+ raise ValueError(message)
+
+ return result
diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py
new file mode 100644
index 000000000..509170b34
--- /dev/null
+++ b/comfy_api_nodes/nodes_bfl.py
@@ -0,0 +1,931 @@
+import io
+from inspect import cleandoc
+from typing import Union
+from comfy.comfy_types.node_typing import IO, ComfyNodeABC
+from comfy_api_nodes.apis.bfl_api import (
+ BFLStatus,
+ BFLFluxExpandImageRequest,
+ BFLFluxFillImageRequest,
+ BFLFluxCannyImageRequest,
+ BFLFluxDepthImageRequest,
+ BFLFluxProGenerateRequest,
+ BFLFluxProUltraGenerateRequest,
+ BFLFluxProGenerateResponse,
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+)
+from comfy_api_nodes.apinode_utils import (
+ downscale_image_tensor,
+ validate_aspect_ratio,
+ process_image_response,
+ resize_mask_to_image,
+ validate_string,
+)
+
+import numpy as np
+from PIL import Image
+import requests
+import torch
+import base64
+import time
+from server import PromptServer
+
+
+def convert_mask_to_image(mask: torch.Tensor):
+ """
+ Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image.
+ """
+ mask = mask.unsqueeze(-1)
+ mask = torch.cat([mask]*3, dim=-1)
+ return mask
+
+
+def handle_bfl_synchronous_operation(
+ operation: SynchronousOperation,
+ timeout_bfl_calls=360,
+ node_id: Union[str, None] = None,
+):
+ response_api: BFLFluxProGenerateResponse = operation.execute()
+ return _poll_until_generated(
+ response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
+ )
+
+
+def _poll_until_generated(
+ polling_url: str, timeout=360, node_id: Union[str, None] = None
+):
+ # used bfl-comfy-nodes to verify code implementation:
+ # https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
+ start_time = time.time()
+ retries_404 = 0
+ max_retries_404 = 5
+ retry_404_seconds = 2
+ retry_202_seconds = 2
+ retry_pending_seconds = 1
+ request = requests.Request(method=HttpMethod.GET, url=polling_url)
+ # NOTE: should True loop be replaced with checking if workflow has been interrupted?
+    while True:
+        # Check the overall timeout on every iteration; otherwise a task that
+        # stays pending (or keeps returning 202) could poll forever.
+        if time.time() - start_time > timeout:
+            raise Exception(
+                f"BFL API experienced a timeout; could not return request under {timeout} seconds."
+            )
+ if node_id:
+ time_elapsed = time.time() - start_time
+ PromptServer.instance.send_progress_text(
+ f"Generating ({time_elapsed:.0f}s)", node_id
+ )
+
+ response = requests.Session().send(request.prepare())
+ if response.status_code == 200:
+ result = response.json()
+ if result["status"] == BFLStatus.ready:
+ img_url = result["result"]["sample"]
+ if node_id:
+ PromptServer.instance.send_progress_text(
+ f"Result URL: {img_url}", node_id
+ )
+ img_response = requests.get(img_url)
+ return process_image_response(img_response)
+ elif result["status"] in [
+ BFLStatus.request_moderated,
+ BFLStatus.content_moderated,
+ ]:
+ status = result["status"]
+ raise Exception(
+ f"BFL API did not return an image due to: {status}."
+ )
+ elif result["status"] == BFLStatus.error:
+ raise Exception(f"BFL API encountered an error: {result}.")
+ elif result["status"] == BFLStatus.pending:
+ time.sleep(retry_pending_seconds)
+ continue
+ elif response.status_code == 404:
+ if retries_404 < max_retries_404:
+ retries_404 += 1
+ time.sleep(retry_404_seconds)
+ continue
+ raise Exception(
+ f"BFL API could not find task after {max_retries_404} tries."
+ )
+ elif response.status_code == 202:
+ time.sleep(retry_202_seconds)
+ else:
+ raise Exception(f"BFL API encountered an error: {response.json()}")
+
+def convert_image_to_base64(image: torch.Tensor):
+ scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048)
+ # remove batch dimension if present
+ if len(scaled_image.shape) > 3:
+ scaled_image = scaled_image[0]
+ image_np = (scaled_image.numpy() * 255).astype(np.uint8)
+ img = Image.fromarray(image_np)
+ img_byte_arr = io.BytesIO()
+ img.save(img_byte_arr, format="PNG")
+ return base64.b64encode(img_byte_arr.getvalue()).decode()
+
+
+class FluxProUltraImageNode(ComfyNodeABC):
+ """
+ Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution.
+ """
+
+ MINIMUM_RATIO = 1 / 4
+ MAXIMUM_RATIO = 4 / 1
+ MINIMUM_RATIO_STR = "1:4"
+ MAXIMUM_RATIO_STR = "4:1"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation",
+ },
+ ),
+ "prompt_upsampling": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ "aspect_ratio": (
+ IO.STRING,
+ {
+ "default": "16:9",
+ "tooltip": "Aspect ratio of image; must be between 1:4 and 4:1.",
+ },
+ ),
+ "raw": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ "tooltip": "When True, generate less processed, more natural-looking images.",
+ },
+ ),
+ },
+ "optional": {
+ "image_prompt": (IO.IMAGE,),
+ "image_prompt_strength": (
+ IO.FLOAT,
+ {
+ "default": 0.1,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ "tooltip": "Blend between the prompt and the image prompt.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ @classmethod
+ def VALIDATE_INPUTS(cls, aspect_ratio: str):
+ try:
+ validate_aspect_ratio(
+ aspect_ratio,
+ minimum_ratio=cls.MINIMUM_RATIO,
+ maximum_ratio=cls.MAXIMUM_RATIO,
+ minimum_ratio_str=cls.MINIMUM_RATIO_STR,
+ maximum_ratio_str=cls.MAXIMUM_RATIO_STR,
+ )
+ except Exception as e:
+ return str(e)
+ return True
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/BFL"
+
+ def api_call(
+ self,
+ prompt: str,
+ aspect_ratio: str,
+ prompt_upsampling=False,
+ raw=False,
+ seed=0,
+ image_prompt=None,
+ image_prompt_strength=0.1,
+ unique_id: Union[str, None] = None,
+ **kwargs,
+ ):
+ if image_prompt is None:
+ validate_string(prompt, strip_whitespace=False)
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/bfl/flux-pro-1.1-ultra/generate",
+ method=HttpMethod.POST,
+ request_model=BFLFluxProUltraGenerateRequest,
+ response_model=BFLFluxProGenerateResponse,
+ ),
+ request=BFLFluxProUltraGenerateRequest(
+ prompt=prompt,
+ prompt_upsampling=prompt_upsampling,
+ seed=seed,
+ aspect_ratio=validate_aspect_ratio(
+ aspect_ratio,
+ minimum_ratio=self.MINIMUM_RATIO,
+ maximum_ratio=self.MAXIMUM_RATIO,
+ minimum_ratio_str=self.MINIMUM_RATIO_STR,
+ maximum_ratio_str=self.MAXIMUM_RATIO_STR,
+ ),
+ raw=raw,
+ image_prompt=(
+ image_prompt
+ if image_prompt is None
+ else convert_image_to_base64(image_prompt)
+ ),
+ image_prompt_strength=(
+ None if image_prompt is None else round(image_prompt_strength, 2)
+ ),
+ ),
+ auth_kwargs=kwargs,
+ )
+ output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+ return (output_image,)
+
+
+class FluxProImageNode(ComfyNodeABC):
+ """
+ Generates images synchronously based on prompt and resolution.
+ """
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation",
+ },
+ ),
+ "prompt_upsampling": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
+ },
+ ),
+ "width": (
+ IO.INT,
+ {
+ "default": 1024,
+ "min": 256,
+ "max": 1440,
+ "step": 32,
+ },
+ ),
+ "height": (
+ IO.INT,
+ {
+ "default": 768,
+ "min": 256,
+ "max": 1440,
+ "step": 32,
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "optional": {
+ "image_prompt": (IO.IMAGE,),
+ # "image_prompt_strength": (
+ # IO.FLOAT,
+ # {
+ # "default": 0.1,
+ # "min": 0.0,
+ # "max": 1.0,
+ # "step": 0.01,
+ # "tooltip": "Blend between the prompt and the image prompt.",
+ # },
+ # ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/BFL"
+
+ def api_call(
+ self,
+ prompt: str,
+ prompt_upsampling,
+ width: int,
+ height: int,
+ seed=0,
+ image_prompt=None,
+ # image_prompt_strength=0.1,
+ unique_id: Union[str, None] = None,
+ **kwargs,
+ ):
+ image_prompt = (
+ image_prompt
+ if image_prompt is None
+ else convert_image_to_base64(image_prompt)
+ )
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/bfl/flux-pro-1.1/generate",
+ method=HttpMethod.POST,
+ request_model=BFLFluxProGenerateRequest,
+ response_model=BFLFluxProGenerateResponse,
+ ),
+ request=BFLFluxProGenerateRequest(
+ prompt=prompt,
+ prompt_upsampling=prompt_upsampling,
+ width=width,
+ height=height,
+ seed=seed,
+ image_prompt=image_prompt,
+ ),
+ auth_kwargs=kwargs,
+ )
+ output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+ return (output_image,)
+
+
+class FluxProExpandNode(ComfyNodeABC):
+ """
+ Outpaints image based on prompt.
+ """
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE,),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation",
+ },
+ ),
+ "prompt_upsampling": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
+ },
+ ),
+ "top": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2048,
+ "tooltip": "Number of pixels to expand at the top of the image"
+ },
+ ),
+ "bottom": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2048,
+ "tooltip": "Number of pixels to expand at the bottom of the image"
+ },
+ ),
+ "left": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2048,
+ "tooltip": "Number of pixels to expand at the left side of the image"
+ },
+ ),
+ "right": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2048,
+ "tooltip": "Number of pixels to expand at the right side of the image"
+ },
+ ),
+ "guidance": (
+ IO.FLOAT,
+ {
+ "default": 60,
+ "min": 1.5,
+ "max": 100,
+ "tooltip": "Guidance strength for the image generation process"
+ },
+ ),
+ "steps": (
+ IO.INT,
+ {
+ "default": 50,
+ "min": 15,
+ "max": 50,
+ "tooltip": "Number of steps for the image generation process"
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "optional": {},
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/BFL"
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ prompt: str,
+ prompt_upsampling: bool,
+ top: int,
+ bottom: int,
+ left: int,
+ right: int,
+ steps: int,
+ guidance: float,
+ seed=0,
+ unique_id: Union[str, None] = None,
+ **kwargs,
+ ):
+ image = convert_image_to_base64(image)
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/bfl/flux-pro-1.0-expand/generate",
+ method=HttpMethod.POST,
+ request_model=BFLFluxExpandImageRequest,
+ response_model=BFLFluxProGenerateResponse,
+ ),
+ request=BFLFluxExpandImageRequest(
+ prompt=prompt,
+ prompt_upsampling=prompt_upsampling,
+ top=top,
+ bottom=bottom,
+ left=left,
+ right=right,
+ steps=steps,
+ guidance=guidance,
+ seed=seed,
+ image=image,
+ ),
+ auth_kwargs=kwargs,
+ )
+ output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+ return (output_image,)
+
+
+class FluxProFillNode(ComfyNodeABC):
+ """
+ Inpaints image based on mask and prompt.
+ """
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE,),
+ "mask": (IO.MASK,),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation",
+ },
+ ),
+ "prompt_upsampling": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
+ },
+ ),
+ "guidance": (
+ IO.FLOAT,
+ {
+ "default": 60,
+ "min": 1.5,
+ "max": 100,
+ "tooltip": "Guidance strength for the image generation process"
+ },
+ ),
+ "steps": (
+ IO.INT,
+ {
+ "default": 50,
+ "min": 15,
+ "max": 50,
+ "tooltip": "Number of steps for the image generation process"
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "optional": {},
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/BFL"
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ mask: torch.Tensor,
+ prompt: str,
+ prompt_upsampling: bool,
+ steps: int,
+ guidance: float,
+ seed=0,
+ unique_id: Union[str, None] = None,
+ **kwargs,
+ ):
+ # prepare mask
+ mask = resize_mask_to_image(mask, image)
+ mask = convert_image_to_base64(convert_mask_to_image(mask))
+ # make sure image will have alpha channel removed
+ image = convert_image_to_base64(image[:, :, :, :3])
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/bfl/flux-pro-1.0-fill/generate",
+ method=HttpMethod.POST,
+ request_model=BFLFluxFillImageRequest,
+ response_model=BFLFluxProGenerateResponse,
+ ),
+ request=BFLFluxFillImageRequest(
+ prompt=prompt,
+ prompt_upsampling=prompt_upsampling,
+ steps=steps,
+ guidance=guidance,
+ seed=seed,
+ image=image,
+ mask=mask,
+ ),
+ auth_kwargs=kwargs,
+ )
+ output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+ return (output_image,)
+
+
+class FluxProCannyNode(ComfyNodeABC):
+ """
+ Generate image using a control image (canny).
+ """
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "control_image": (IO.IMAGE,),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation",
+ },
+ ),
+ "prompt_upsampling": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
+ },
+ ),
+ "canny_low_threshold": (
+ IO.FLOAT,
+ {
+ "default": 0.1,
+ "min": 0.01,
+ "max": 0.99,
+ "step": 0.01,
+ "tooltip": "Low threshold for Canny edge detection; ignored if skip_processing is True"
+ },
+ ),
+ "canny_high_threshold": (
+ IO.FLOAT,
+ {
+ "default": 0.4,
+ "min": 0.01,
+ "max": 0.99,
+ "step": 0.01,
+ "tooltip": "High threshold for Canny edge detection; ignored if skip_processing is True"
+ },
+ ),
+ "skip_preprocessing": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ "tooltip": "Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.",
+ },
+ ),
+ "guidance": (
+ IO.FLOAT,
+ {
+ "default": 30,
+ "min": 1,
+ "max": 100,
+ "tooltip": "Guidance strength for the image generation process"
+ },
+ ),
+ "steps": (
+ IO.INT,
+ {
+ "default": 50,
+ "min": 15,
+ "max": 50,
+ "tooltip": "Number of steps for the image generation process"
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "optional": {},
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/BFL"
+
+ def api_call(
+ self,
+ control_image: torch.Tensor,
+ prompt: str,
+ prompt_upsampling: bool,
+ canny_low_threshold: float,
+ canny_high_threshold: float,
+ skip_preprocessing: bool,
+ steps: int,
+ guidance: float,
+ seed=0,
+ unique_id: Union[str, None] = None,
+ **kwargs,
+ ):
+ control_image = convert_image_to_base64(control_image[:, :, :, :3])
+ preprocessed_image = None
+
+ # scale canny threshold between 0-500, to match BFL's API
+ def scale_value(value: float, min_val=0, max_val=500):
+ return min_val + value * (max_val - min_val)
+ canny_low_threshold = int(round(scale_value(canny_low_threshold)))
+ canny_high_threshold = int(round(scale_value(canny_high_threshold)))
+
+ if skip_preprocessing:
+ preprocessed_image = control_image
+ control_image = None
+ canny_low_threshold = None
+ canny_high_threshold = None
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/bfl/flux-pro-1.0-canny/generate",
+ method=HttpMethod.POST,
+ request_model=BFLFluxCannyImageRequest,
+ response_model=BFLFluxProGenerateResponse,
+ ),
+ request=BFLFluxCannyImageRequest(
+ prompt=prompt,
+ prompt_upsampling=prompt_upsampling,
+ steps=steps,
+ guidance=guidance,
+ seed=seed,
+ control_image=control_image,
+ canny_low_threshold=canny_low_threshold,
+ canny_high_threshold=canny_high_threshold,
+ preprocessed_image=preprocessed_image,
+ ),
+ auth_kwargs=kwargs,
+ )
+ output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+ return (output_image,)
+
+
+class FluxProDepthNode(ComfyNodeABC):
+ """
+ Generate image using a control image (depth).
+ """
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "control_image": (IO.IMAGE,),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation",
+ },
+ ),
+ "prompt_upsampling": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ "tooltip": "Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).",
+ },
+ ),
+ "skip_preprocessing": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ "tooltip": "Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.",
+ },
+ ),
+ "guidance": (
+ IO.FLOAT,
+ {
+ "default": 15,
+ "min": 1,
+ "max": 100,
+ "tooltip": "Guidance strength for the image generation process"
+ },
+ ),
+ "steps": (
+ IO.INT,
+ {
+ "default": 50,
+ "min": 15,
+ "max": 50,
+ "tooltip": "Number of steps for the image generation process"
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "optional": {},
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/BFL"
+
+ def api_call(
+ self,
+ control_image: torch.Tensor,
+ prompt: str,
+ prompt_upsampling: bool,
+ skip_preprocessing: bool,
+ steps: int,
+ guidance: float,
+ seed=0,
+ unique_id: Union[str, None] = None,
+ **kwargs,
+ ):
+ control_image = convert_image_to_base64(control_image[:,:,:,:3])
+ preprocessed_image = None
+
+ if skip_preprocessing:
+ preprocessed_image = control_image
+ control_image = None
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/bfl/flux-pro-1.0-depth/generate",
+ method=HttpMethod.POST,
+ request_model=BFLFluxDepthImageRequest,
+ response_model=BFLFluxProGenerateResponse,
+ ),
+ request=BFLFluxDepthImageRequest(
+ prompt=prompt,
+ prompt_upsampling=prompt_upsampling,
+ steps=steps,
+ guidance=guidance,
+ seed=seed,
+ control_image=control_image,
+ preprocessed_image=preprocessed_image,
+ ),
+ auth_kwargs=kwargs,
+ )
+ output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
+ return (output_image,)
+
+
+# A dictionary that contains all nodes you want to export with their names
+# NOTE: names should be globally unique
+NODE_CLASS_MAPPINGS = {
+ "FluxProUltraImageNode": FluxProUltraImageNode,
+ # "FluxProImageNode": FluxProImageNode,
+ "FluxProExpandNode": FluxProExpandNode,
+ "FluxProFillNode": FluxProFillNode,
+ "FluxProCannyNode": FluxProCannyNode,
+ "FluxProDepthNode": FluxProDepthNode,
+}
+
+# A dictionary that contains the friendly/humanly readable titles for the nodes
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "FluxProUltraImageNode": "Flux 1.1 [pro] Ultra Image",
+ # "FluxProImageNode": "Flux 1.1 [pro] Image",
+ "FluxProExpandNode": "Flux.1 Expand Image",
+ "FluxProFillNode": "Flux.1 Fill Image",
+ "FluxProCannyNode": "Flux.1 Canny Control Image",
+ "FluxProDepthNode": "Flux.1 Depth Control Image",
+}
diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py
new file mode 100644
index 000000000..ae7b04846
--- /dev/null
+++ b/comfy_api_nodes/nodes_gemini.py
@@ -0,0 +1,446 @@
+"""
+API Nodes for Gemini Multimodal LLM Usage via Remote API
+See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
+"""
+
+import base64
+import os
+from enum import Enum
+from typing import Optional, Literal
+
+import torch
+
+import folder_paths
+from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
+from server import PromptServer
+from comfy_api_nodes.apis import (
+ GeminiContent,
+ GeminiGenerateContentRequest,
+ GeminiGenerateContentResponse,
+ GeminiInlineData,
+ GeminiPart,
+ GeminiMimeType,
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+)
+from comfy_api_nodes.apinode_utils import (
+ validate_string,
+ audio_to_base64_string,
+ video_to_base64_string,
+ tensor_to_base64_string,
+)
+
+
+GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
+GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
+
+
+class GeminiModel(str, Enum):
+ """
+ Gemini Model Names allowed by comfy-api
+ """
+
+ gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
+ gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
+
+
+def get_gemini_endpoint(
+ model: GeminiModel,
+) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]:
+ """
+ Get the API endpoint for a given Gemini model.
+
+ Args:
+ model: The Gemini model to use, either as enum or string value.
+
+ Returns:
+ ApiEndpoint configured for the specific Gemini model.
+ """
+ if isinstance(model, str):
+ model = GeminiModel(model)
+ return ApiEndpoint(
+ path=f"{GEMINI_BASE_ENDPOINT}/{model.value}",
+ method=HttpMethod.POST,
+ request_model=GeminiGenerateContentRequest,
+ response_model=GeminiGenerateContentResponse,
+ )
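+
+ # For example, get_gemini_endpoint(GeminiModel.gemini_2_5_flash_preview_04_17)
+ # builds an endpoint whose path is "/proxy/vertexai/gemini/gemini-2.5-flash-preview-04-17".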
+
+
+class GeminiNode(ComfyNodeABC):
+ """
+ Node to generate text responses from a Gemini model.
+
+ This node allows users to interact with Google's Gemini AI models, providing
+ multimodal inputs (text, images, audio, video, files) to generate coherent
+ text responses. The node works with the latest Gemini models, handling the
+ API communication and response parsing.
+ """
+
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypeDict:
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Text inputs to the model, used to generate a response. You can include detailed instructions, questions, or context for the model.",
+ },
+ ),
+ "model": (
+ IO.COMBO,
+ {
+ "tooltip": "The Gemini model to use for generating responses.",
+ "options": [model.value for model in GeminiModel],
+ "default": GeminiModel.gemini_2_5_pro_preview_05_06.value,
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 42,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.",
+ },
+ ),
+ },
+ "optional": {
+ "images": (
+ IO.IMAGE,
+ {
+ "default": None,
+ "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
+ },
+ ),
+ "audio": (
+ IO.AUDIO,
+ {
+ "tooltip": "Optional audio to use as context for the model.",
+ "default": None,
+ },
+ ),
+ "video": (
+ IO.VIDEO,
+ {
+ "tooltip": "Optional video to use as context for the model.",
+ "default": None,
+ },
+ ),
+ "files": (
+ "GEMINI_INPUT_FILES",
+ {
+ "default": None,
+ "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses."
+ RETURN_TYPES = ("STRING",)
+ FUNCTION = "api_call"
+ CATEGORY = "api node/text/Gemini"
+ API_NODE = True
+
+ def get_parts_from_response(
+ self, response: GeminiGenerateContentResponse
+ ) -> list[GeminiPart]:
+ """
+ Extract all parts from the Gemini API response.
+
+ Args:
+ response: The API response from Gemini.
+
+ Returns:
+ List of response parts from the first candidate.
+ """
+ return response.candidates[0].content.parts
+
+ def get_parts_by_type(
+ self, response: GeminiGenerateContentResponse, part_type: Literal["text"] | str
+ ) -> list[GeminiPart]:
+ """
+ Filter response parts by their type.
+
+ Args:
+ response: The API response from Gemini.
+ part_type: Type of parts to extract ("text" or a MIME type).
+
+ Returns:
+ List of response parts matching the requested type.
+ """
+ parts = []
+ for part in self.get_parts_from_response(response):
+ if part_type == "text" and hasattr(part, "text") and part.text:
+ parts.append(part)
+ elif (
+ hasattr(part, "inlineData")
+ and part.inlineData
+ and part.inlineData.mimeType == part_type
+ ):
+ parts.append(part)
+ # Skip parts that don't match the requested type
+ return parts
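+
+ # For instance (assuming mime types compare as plain strings),
+ # get_parts_by_type(response, "image/png") returns only the inline PNG parts,
+ # while get_parts_by_type(response, "text") returns the text parts.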
+
+ def get_text_from_response(self, response: GeminiGenerateContentResponse) -> str:
+ """
+ Extract and concatenate all text parts from the response.
+
+ Args:
+ response: The API response from Gemini.
+
+ Returns:
+ Combined text from all text parts in the response.
+ """
+ parts = self.get_parts_by_type(response, "text")
+ return "\n".join([part.text for part in parts])
+
+ def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]:
+ """
+ Convert video input to Gemini API compatible parts.
+
+ Args:
+ video_input: Video tensor from ComfyUI.
+ **kwargs: Additional arguments to pass to the conversion function.
+
+ Returns:
+ List of GeminiPart objects containing the encoded video.
+ """
+ from comfy_api.util import VideoContainer, VideoCodec
+ base_64_string = video_to_base64_string(
+ video_input,
+ container_format=VideoContainer.MP4,
+ codec=VideoCodec.H264
+ )
+ return [
+ GeminiPart(
+ inlineData=GeminiInlineData(
+ mimeType=GeminiMimeType.video_mp4,
+ data=base_64_string,
+ )
+ )
+ ]
+
+ def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]:
+ """
+ Convert audio input to Gemini API compatible parts.
+
+ Args:
+ audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
+
+ Returns:
+ List of GeminiPart objects containing the encoded audio.
+ """
+ audio_parts: list[GeminiPart] = []
+ for batch_index in range(audio_input["waveform"].shape[0]):
+ # Recreate an IO.AUDIO object for the given batch dimension index
+ audio_at_index = {
+ "waveform": audio_input["waveform"][batch_index].unsqueeze(0),
+ "sample_rate": audio_input["sample_rate"],
+ }
+ # Convert to MP3 format for compatibility with Gemini API
+ audio_bytes = audio_to_base64_string(
+ audio_at_index,
+ container_format="mp3",
+ codec_name="libmp3lame",
+ )
+ audio_parts.append(
+ GeminiPart(
+ inlineData=GeminiInlineData(
+ mimeType=GeminiMimeType.audio_mp3,
+ data=audio_bytes,
+ )
+ )
+ )
+ return audio_parts
+
+ def create_image_parts(self, image_input: torch.Tensor) -> list[GeminiPart]:
+ """
+ Convert image tensor input to Gemini API compatible parts.
+
+ Args:
+ image_input: Batch of image tensors from ComfyUI.
+
+ Returns:
+ List of GeminiPart objects containing the encoded images.
+ """
+ image_parts: list[GeminiPart] = []
+ for image_index in range(image_input.shape[0]):
+ image_as_b64 = tensor_to_base64_string(
+ image_input[image_index].unsqueeze(0)
+ )
+ image_parts.append(
+ GeminiPart(
+ inlineData=GeminiInlineData(
+ mimeType=GeminiMimeType.image_png,
+ data=image_as_b64,
+ )
+ )
+ )
+ return image_parts
+
+ def create_text_part(self, text: str) -> GeminiPart:
+ """
+ Create a text part for the Gemini API request.
+
+ Args:
+ text: The text content to include in the request.
+
+ Returns:
+ A GeminiPart object with the text content.
+ """
+ return GeminiPart(text=text)
+
+ def api_call(
+ self,
+ prompt: str,
+ model: GeminiModel,
+ images: Optional[IO.IMAGE] = None,
+ audio: Optional[IO.AUDIO] = None,
+ video: Optional[IO.VIDEO] = None,
+ files: Optional[list[GeminiPart]] = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ) -> tuple[str]:
+ # Validate inputs
+ validate_string(prompt, strip_whitespace=False)
+
+ # Create parts list with text prompt as the first part
+ parts: list[GeminiPart] = [self.create_text_part(prompt)]
+
+ # Add other modal parts
+ if images is not None:
+ image_parts = self.create_image_parts(images)
+ parts.extend(image_parts)
+ if audio is not None:
+ parts.extend(self.create_audio_parts(audio))
+ if video is not None:
+ parts.extend(self.create_video_parts(video))
+ if files is not None:
+ parts.extend(files)
+
+ # Create response
+ response = SynchronousOperation(
+ endpoint=get_gemini_endpoint(model),
+ request=GeminiGenerateContentRequest(
+ contents=[
+ GeminiContent(
+ role="user",
+ parts=parts,
+ )
+ ]
+ ),
+ auth_kwargs=kwargs,
+ ).execute()
+
+ # Get result output
+ output_text = self.get_text_from_response(response)
+ if unique_id and output_text:
+ PromptServer.instance.send_progress_text(output_text, node_id=unique_id)
+
+ return (output_text or "Empty response from Gemini model...",)
+
+
+class GeminiInputFiles(ComfyNodeABC):
+ """
+ Loads and formats input files for use with the Gemini API.
+
+ This node allows users to include text (.txt) and PDF (.pdf) files as input
+ context for the Gemini model. Files are converted to the appropriate format
+ required by the API and can be chained together to include multiple files
+ in a single request.
+ """
+
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypeDict:
+ """
+ For details about the supported file input types, see:
+ https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
+ """
+ input_dir = folder_paths.get_input_directory()
+ input_files = [
+ f
+ for f in os.scandir(input_dir)
+ if f.is_file()
+ and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
+ and f.stat().st_size < GEMINI_MAX_INPUT_FILE_SIZE
+ ]
+ input_files = sorted(input_files, key=lambda x: x.name)
+ input_files = [f.name for f in input_files]
+ return {
+ "required": {
+ "file": (
+ IO.COMBO,
+ {
+ "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
+ "options": input_files,
+ "default": input_files[0] if input_files else None,
+ },
+ ),
+ },
+ "optional": {
+ "GEMINI_INPUT_FILES": (
+ "GEMINI_INPUT_FILES",
+ {
+ "tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
+ "default": None,
+ },
+ ),
+ },
+ }
+
+ DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes."
+ RETURN_TYPES = ("GEMINI_INPUT_FILES",)
+ FUNCTION = "prepare_files"
+ CATEGORY = "api node/text/Gemini"
+
+ def create_file_part(self, file_path: str) -> GeminiPart:
+ mime_type = (
+ GeminiMimeType.pdf
+ if file_path.endswith(".pdf")
+ else GeminiMimeType.text_plain
+ )
+ # Use the raw base64 string directly, not a data URI
+ with open(file_path, "rb") as f:
+ file_content = f.read()
+ base64_str = base64.b64encode(file_content).decode("utf-8")
+
+ return GeminiPart(
+ inlineData=GeminiInlineData(
+ mimeType=mime_type,
+ data=base64_str,
+ )
+ )
+
+ def prepare_files(
+ self, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None
+ ) -> tuple[list[GeminiPart]]:
+ """
+ Loads and formats input files for Gemini API.
+ """
+ file_path = folder_paths.get_annotated_filepath(file)
+ input_file_content = self.create_file_part(file_path)
+ files = [input_file_content] + (GEMINI_INPUT_FILES or [])
+ return (files,)
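+
+ # A minimal sketch of chaining two file nodes (file names are hypothetical and
+ # must exist in the ComfyUI input directory):
+ #   loader = GeminiInputFiles()
+ #   (files_a,) = loader.prepare_files("notes.txt")
+ #   (files_b,) = loader.prepare_files("paper.pdf", GEMINI_INPUT_FILES=files_a)
+ #   # files_b now holds both GeminiPart objects for GeminiNode's `files` input.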
+
+
+NODE_CLASS_MAPPINGS = {
+ "GeminiNode": GeminiNode,
+ "GeminiInputFiles": GeminiInputFiles,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "GeminiNode": "Google Gemini",
+ "GeminiInputFiles": "Gemini Input Files",
+}
diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py
new file mode 100644
index 000000000..b1cbf511d
--- /dev/null
+++ b/comfy_api_nodes/nodes_ideogram.py
@@ -0,0 +1,801 @@
+from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
+from inspect import cleandoc
+from PIL import Image
+import numpy as np
+import io
+import torch
+from comfy_api_nodes.apis import (
+ IdeogramGenerateRequest,
+ IdeogramGenerateResponse,
+ ImageRequest,
+ IdeogramV3Request,
+ IdeogramV3EditRequest,
+)
+
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+)
+
+from comfy_api_nodes.apinode_utils import (
+ download_url_to_bytesio,
+ bytesio_to_image_tensor,
+ resize_mask_to_image,
+)
+from server import PromptServer
+
+V1_V1_RES_MAP = {
+ "Auto":"AUTO",
+ "512 x 1536":"RESOLUTION_512_1536",
+ "576 x 1408":"RESOLUTION_576_1408",
+ "576 x 1472":"RESOLUTION_576_1472",
+ "576 x 1536":"RESOLUTION_576_1536",
+ "640 x 1024":"RESOLUTION_640_1024",
+ "640 x 1344":"RESOLUTION_640_1344",
+ "640 x 1408":"RESOLUTION_640_1408",
+ "640 x 1472":"RESOLUTION_640_1472",
+ "640 x 1536":"RESOLUTION_640_1536",
+ "704 x 1152":"RESOLUTION_704_1152",
+ "704 x 1216":"RESOLUTION_704_1216",
+ "704 x 1280":"RESOLUTION_704_1280",
+ "704 x 1344":"RESOLUTION_704_1344",
+ "704 x 1408":"RESOLUTION_704_1408",
+ "704 x 1472":"RESOLUTION_704_1472",
+ "720 x 1280":"RESOLUTION_720_1280",
+ "736 x 1312":"RESOLUTION_736_1312",
+ "768 x 1024":"RESOLUTION_768_1024",
+ "768 x 1088":"RESOLUTION_768_1088",
+ "768 x 1152":"RESOLUTION_768_1152",
+ "768 x 1216":"RESOLUTION_768_1216",
+ "768 x 1232":"RESOLUTION_768_1232",
+ "768 x 1280":"RESOLUTION_768_1280",
+ "768 x 1344":"RESOLUTION_768_1344",
+ "832 x 960":"RESOLUTION_832_960",
+ "832 x 1024":"RESOLUTION_832_1024",
+ "832 x 1088":"RESOLUTION_832_1088",
+ "832 x 1152":"RESOLUTION_832_1152",
+ "832 x 1216":"RESOLUTION_832_1216",
+ "832 x 1248":"RESOLUTION_832_1248",
+ "864 x 1152":"RESOLUTION_864_1152",
+ "896 x 960":"RESOLUTION_896_960",
+ "896 x 1024":"RESOLUTION_896_1024",
+ "896 x 1088":"RESOLUTION_896_1088",
+ "896 x 1120":"RESOLUTION_896_1120",
+ "896 x 1152":"RESOLUTION_896_1152",
+ "960 x 832":"RESOLUTION_960_832",
+ "960 x 896":"RESOLUTION_960_896",
+ "960 x 1024":"RESOLUTION_960_1024",
+ "960 x 1088":"RESOLUTION_960_1088",
+ "1024 x 640":"RESOLUTION_1024_640",
+ "1024 x 768":"RESOLUTION_1024_768",
+ "1024 x 832":"RESOLUTION_1024_832",
+ "1024 x 896":"RESOLUTION_1024_896",
+ "1024 x 960":"RESOLUTION_1024_960",
+ "1024 x 1024":"RESOLUTION_1024_1024",
+ "1088 x 768":"RESOLUTION_1088_768",
+ "1088 x 832":"RESOLUTION_1088_832",
+ "1088 x 896":"RESOLUTION_1088_896",
+ "1088 x 960":"RESOLUTION_1088_960",
+ "1120 x 896":"RESOLUTION_1120_896",
+ "1152 x 704":"RESOLUTION_1152_704",
+ "1152 x 768":"RESOLUTION_1152_768",
+ "1152 x 832":"RESOLUTION_1152_832",
+ "1152 x 864":"RESOLUTION_1152_864",
+ "1152 x 896":"RESOLUTION_1152_896",
+ "1216 x 704":"RESOLUTION_1216_704",
+ "1216 x 768":"RESOLUTION_1216_768",
+ "1216 x 832":"RESOLUTION_1216_832",
+ "1232 x 768":"RESOLUTION_1232_768",
+ "1248 x 832":"RESOLUTION_1248_832",
+ "1280 x 704":"RESOLUTION_1280_704",
+ "1280 x 720":"RESOLUTION_1280_720",
+ "1280 x 768":"RESOLUTION_1280_768",
+ "1280 x 800":"RESOLUTION_1280_800",
+ "1312 x 736":"RESOLUTION_1312_736",
+ "1344 x 640":"RESOLUTION_1344_640",
+ "1344 x 704":"RESOLUTION_1344_704",
+ "1344 x 768":"RESOLUTION_1344_768",
+ "1408 x 576":"RESOLUTION_1408_576",
+ "1408 x 640":"RESOLUTION_1408_640",
+ "1408 x 704":"RESOLUTION_1408_704",
+ "1472 x 576":"RESOLUTION_1472_576",
+ "1472 x 640":"RESOLUTION_1472_640",
+ "1472 x 704":"RESOLUTION_1472_704",
+ "1536 x 512":"RESOLUTION_1536_512",
+ "1536 x 576":"RESOLUTION_1536_576",
+ "1536 x 640":"RESOLUTION_1536_640",
+}
+
+V1_V2_RATIO_MAP = {
+ "1:1":"ASPECT_1_1",
+ "4:3":"ASPECT_4_3",
+ "3:4":"ASPECT_3_4",
+ "16:9":"ASPECT_16_9",
+ "9:16":"ASPECT_9_16",
+ "2:1":"ASPECT_2_1",
+ "1:2":"ASPECT_1_2",
+ "3:2":"ASPECT_3_2",
+ "2:3":"ASPECT_2_3",
+ "4:5":"ASPECT_4_5",
+ "5:4":"ASPECT_5_4",
+}
+
+V3_RATIO_MAP = {
+ "1:3":"1x3",
+ "3:1":"3x1",
+ "1:2":"1x2",
+ "2:1":"2x1",
+ "9:16":"9x16",
+ "16:9":"16x9",
+ "10:16":"10x16",
+ "16:10":"16x10",
+ "2:3":"2x3",
+ "3:2":"3x2",
+ "3:4":"3x4",
+ "4:3":"4x3",
+ "4:5":"4x5",
+ "5:4":"5x4",
+ "1:1":"1x1",
+}
+
+V3_RESOLUTIONS = [
+ "Auto",
+ "512x1536",
+ "576x1408",
+ "576x1472",
+ "576x1536",
+ "640x1344",
+ "640x1408",
+ "640x1472",
+ "640x1536",
+ "704x1152",
+ "704x1216",
+ "704x1280",
+ "704x1344",
+ "704x1408",
+ "704x1472",
+ "736x1312",
+ "768x1088",
+ "768x1216",
+ "768x1280",
+ "768x1344",
+ "800x1280",
+ "832x960",
+ "832x1024",
+ "832x1088",
+ "832x1152",
+ "832x1216",
+ "832x1248",
+ "864x1152",
+ "896x960",
+ "896x1024",
+ "896x1088",
+ "896x1120",
+ "896x1152",
+ "960x832",
+ "960x896",
+ "960x1024",
+ "960x1088",
+ "1024x832",
+ "1024x896",
+ "1024x960",
+ "1024x1024",
+ "1088x768",
+ "1088x832",
+ "1088x896",
+ "1088x960",
+ "1120x896",
+ "1152x704",
+ "1152x832",
+ "1152x864",
+ "1152x896",
+ "1216x704",
+ "1216x768",
+ "1216x832",
+ "1248x832",
+ "1280x704",
+ "1280x768",
+ "1280x800",
+ "1312x736",
+ "1344x640",
+ "1344x704",
+ "1344x768",
+ "1408x576",
+ "1408x640",
+ "1408x704",
+ "1472x576",
+ "1472x640",
+ "1472x704",
+ "1536x512",
+ "1536x576",
+ "1536x640"
+]
+
+def download_and_process_images(image_urls):
+ """Helper function to download and process multiple images from URLs"""
+
+ # Initialize list to store image tensors
+ image_tensors = []
+
+ for image_url in image_urls:
+ # Using functions from apinode_utils.py to handle downloading and processing
+ image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO
+ img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode
+ image_tensors.append(img_tensor)
+
+ # Stack tensors to match (N, height, width, channels)
+ if image_tensors:
+ stacked_tensors = torch.cat(image_tensors, dim=0)
+ else:
+ raise Exception("No valid images were processed")
+
+ return stacked_tensors
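+
+ # For instance (URLs hypothetical), two 1024x1024 results stack into a single
+ # (2, 1024, 1024, 3) RGB tensor:
+ #   batch = download_and_process_images(["https://example.com/a.png",
+ #                                        "https://example.com/b.png"])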
+
+
+def display_image_urls_on_node(image_urls, node_id):
+ if node_id and image_urls:
+ if len(image_urls) == 1:
+ PromptServer.instance.send_progress_text(
+ f"Generated Image URL:\n{image_urls[0]}", node_id
+ )
+ else:
+ urls_text = "Generated Image URLs:\n" + "\n".join(
+ f"{i+1}. {url}" for i, url in enumerate(image_urls)
+ )
+ PromptServer.instance.send_progress_text(urls_text, node_id)
+
+
+class IdeogramV1(ComfyNodeABC):
+ """
+ Generates images using the Ideogram V1 model.
+ """
+
+ def __init__(self):
+ pass
+
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypeDict:
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation",
+ },
+ ),
+ "turbo": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
+ }
+ ),
+ },
+ "optional": {
+ "aspect_ratio": (
+ IO.COMBO,
+ {
+ "options": list(V1_V2_RATIO_MAP.keys()),
+ "default": "1:1",
+ "tooltip": "The aspect ratio for image generation.",
+ },
+ ),
+ "magic_prompt_option": (
+ IO.COMBO,
+ {
+ "options": ["AUTO", "ON", "OFF"],
+ "default": "AUTO",
+ "tooltip": "Determine if MagicPrompt should be used in generation",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2147483647,
+ "step": 1,
+ "control_after_generate": True,
+ "display": "number",
+ },
+ ),
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Description of what to exclude from the image",
+ },
+ ),
+ "num_images": (
+ IO.INT,
+ {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.IMAGE,)
+ FUNCTION = "api_call"
+ CATEGORY = "api node/image/Ideogram/v1"
+ DESCRIPTION = cleandoc(__doc__ or "")
+ API_NODE = True
+
+ def api_call(
+ self,
+ prompt,
+ turbo=False,
+ aspect_ratio="1:1",
+ magic_prompt_option="AUTO",
+ seed=0,
+ negative_prompt="",
+ num_images=1,
+ unique_id=None,
+ **kwargs,
+ ):
+ aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
+ # Determine the model based on turbo setting
+ model = "V_1_TURBO" if turbo else "V_1"
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/ideogram/generate",
+ method=HttpMethod.POST,
+ request_model=IdeogramGenerateRequest,
+ response_model=IdeogramGenerateResponse,
+ ),
+ request=IdeogramGenerateRequest(
+ image_request=ImageRequest(
+ prompt=prompt,
+ model=model,
+ num_images=num_images,
+ seed=seed,
+ aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None,
+ magic_prompt_option=(
+ magic_prompt_option if magic_prompt_option != "AUTO" else None
+ ),
+ negative_prompt=negative_prompt if negative_prompt else None,
+ )
+ ),
+ auth_kwargs=kwargs,
+ )
+
+ response = operation.execute()
+
+ if not response.data or len(response.data) == 0:
+ raise Exception("No images were generated in the response")
+
+ image_urls = [image_data.url for image_data in response.data if image_data.url]
+
+ if not image_urls:
+ raise Exception("No image URLs were generated in the response")
+
+ display_image_urls_on_node(image_urls, unique_id)
+ return (download_and_process_images(image_urls),)
+
+
+class IdeogramV2(ComfyNodeABC):
+ """
+ Generates images using the Ideogram V2 model.
+ """
+
+ def __init__(self):
+ pass
+
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypeDict:
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation",
+ },
+ ),
+ "turbo": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ "tooltip": "Whether to use turbo mode (faster generation, potentially lower quality)",
+ }
+ ),
+ },
+ "optional": {
+ "aspect_ratio": (
+ IO.COMBO,
+ {
+ "options": list(V1_V2_RATIO_MAP.keys()),
+ "default": "1:1",
+ "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to AUTO.",
+ },
+ ),
+ "resolution": (
+ IO.COMBO,
+ {
+ "options": list(V1_V1_RES_MAP.keys()),
+ "default": "Auto",
+ "tooltip": "The resolution for image generation. If not set to AUTO, this overrides the aspect_ratio setting.",
+ },
+ ),
+ "magic_prompt_option": (
+ IO.COMBO,
+ {
+ "options": ["AUTO", "ON", "OFF"],
+ "default": "AUTO",
+ "tooltip": "Determine if MagicPrompt should be used in generation",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2147483647,
+ "step": 1,
+ "control_after_generate": True,
+ "display": "number",
+ },
+ ),
+ "style_type": (
+ IO.COMBO,
+ {
+ "options": ["AUTO", "GENERAL", "REALISTIC", "DESIGN", "RENDER_3D", "ANIME"],
+ "default": "NONE",
+ "tooltip": "Style type for generation (V2 only)",
+ },
+ ),
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Description of what to exclude from the image",
+ },
+ ),
+ "num_images": (
+ IO.INT,
+ {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
+ ),
+ #"color_palette": (
+ # IO.STRING,
+ # {
+ # "multiline": False,
+ # "default": "",
+ # "tooltip": "Color palette preset name or hex colors with weights",
+ # },
+ #),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.IMAGE,)
+ FUNCTION = "api_call"
+ CATEGORY = "api node/image/Ideogram/v2"
+ DESCRIPTION = cleandoc(__doc__ or "")
+ API_NODE = True
+
+ def api_call(
+ self,
+ prompt,
+ turbo=False,
+ aspect_ratio="1:1",
+ resolution="Auto",
+ magic_prompt_option="AUTO",
+ seed=0,
+ style_type="NONE",
+ negative_prompt="",
+ num_images=1,
+ color_palette="",
+ unique_id=None,
+ **kwargs,
+ ):
+ aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None)
+ resolution = V1_V1_RES_MAP.get(resolution, None)
+ # Determine the model based on turbo setting
+ model = "V_2_TURBO" if turbo else "V_2"
+
+ # Handle resolution vs aspect_ratio logic
+ # If resolution is not AUTO, it overrides aspect_ratio
+ final_resolution = None
+ final_aspect_ratio = None
+
+ if resolution != "AUTO":
+ final_resolution = resolution
+ else:
+ final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/ideogram/generate",
+ method=HttpMethod.POST,
+ request_model=IdeogramGenerateRequest,
+ response_model=IdeogramGenerateResponse,
+ ),
+ request=IdeogramGenerateRequest(
+ image_request=ImageRequest(
+ prompt=prompt,
+ model=model,
+ num_images=num_images,
+ seed=seed,
+ aspect_ratio=final_aspect_ratio,
+ resolution=final_resolution,
+ magic_prompt_option=(
+ magic_prompt_option if magic_prompt_option != "AUTO" else None
+ ),
+ style_type=style_type if style_type not in ("AUTO", "NONE") else None,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ color_palette=color_palette if color_palette else None,
+ )
+ ),
+ auth_kwargs=kwargs,
+ )
+
+ response = operation.execute()
+
+ if not response.data or len(response.data) == 0:
+ raise Exception("No images were generated in the response")
+
+ image_urls = [image_data.url for image_data in response.data if image_data.url]
+
+ if not image_urls:
+ raise Exception("No image URLs were generated in the response")
+
+ display_image_urls_on_node(image_urls, unique_id)
+ return (download_and_process_images(image_urls),)
+
+
+class IdeogramV3(ComfyNodeABC):
+ """
+ Generates images using the Ideogram V3 model. Supports both regular image generation from text prompts and image editing with mask.
+ """
+
+ def __init__(self):
+ pass
+
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypeDict:
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation or editing",
+ },
+ ),
+ },
+ "optional": {
+ "image": (
+ IO.IMAGE,
+ {
+ "default": None,
+ "tooltip": "Optional reference image for image editing.",
+ },
+ ),
+ "mask": (
+ IO.MASK,
+ {
+ "default": None,
+ "tooltip": "Optional mask for inpainting (white areas will be replaced)",
+ },
+ ),
+ "aspect_ratio": (
+ IO.COMBO,
+ {
+ "options": list(V3_RATIO_MAP.keys()),
+ "default": "1:1",
+ "tooltip": "The aspect ratio for image generation. Ignored if resolution is not set to Auto.",
+ },
+ ),
+ "resolution": (
+ IO.COMBO,
+ {
+ "options": V3_RESOLUTIONS,
+ "default": "Auto",
+ "tooltip": "The resolution for image generation. If not set to Auto, this overrides the aspect_ratio setting.",
+ },
+ ),
+ "magic_prompt_option": (
+ IO.COMBO,
+ {
+ "options": ["AUTO", "ON", "OFF"],
+ "default": "AUTO",
+ "tooltip": "Determine if MagicPrompt should be used in generation",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2147483647,
+ "step": 1,
+ "control_after_generate": True,
+ "display": "number",
+ },
+ ),
+ "num_images": (
+ IO.INT,
+ {"default": 1, "min": 1, "max": 8, "step": 1, "display": "number"},
+ ),
+ "rendering_speed": (
+ IO.COMBO,
+ {
+ "options": ["BALANCED", "TURBO", "QUALITY"],
+ "default": "BALANCED",
+ "tooltip": "Controls the trade-off between generation speed and quality",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.IMAGE,)
+ FUNCTION = "api_call"
+ CATEGORY = "api node/image/Ideogram/v3"
+ DESCRIPTION = cleandoc(__doc__ or "")
+ API_NODE = True
+
+ def api_call(
+ self,
+ prompt,
+ image=None,
+ mask=None,
+ resolution="Auto",
+ aspect_ratio="1:1",
+ magic_prompt_option="AUTO",
+ seed=0,
+ num_images=1,
+ rendering_speed="BALANCED",
+ unique_id=None,
+ **kwargs,
+ ):
+ # Check if both image and mask are provided for editing mode
+ if image is not None and mask is not None:
+ # Edit mode
+ path = "/proxy/ideogram/ideogram-v3/edit"
+
+ # Process image and mask
+ input_tensor = image.squeeze().cpu()  # assumes a single-image batch for editing
+ # Resize mask to match image dimension
+ mask = resize_mask_to_image(mask, image, allow_gradient=False)
+ # Invert mask, as Ideogram API will edit black areas instead of white areas (opposite of convention).
+ mask = 1.0 - mask
+
+ # Validate mask dimensions match image
+ if mask.shape[1:] != image.shape[1:-1]:
+ raise Exception("Mask and Image must be the same size")
+
+ # Process image
+ img_np = (input_tensor.numpy() * 255).astype(np.uint8)
+ img = Image.fromarray(img_np)
+ img_byte_arr = io.BytesIO()
+ img.save(img_byte_arr, format="PNG")
+ img_byte_arr.seek(0)
+ img_binary = img_byte_arr
+ img_binary.name = "image.png"
+
+ # Encode the inverted mask (after inversion, the areas to edit are black, as the API expects)
+ mask_np = (mask.squeeze().cpu().numpy() * 255).astype(np.uint8)
+ mask_img = Image.fromarray(mask_np)
+ mask_byte_arr = io.BytesIO()
+ mask_img.save(mask_byte_arr, format="PNG")
+ mask_byte_arr.seek(0)
+ mask_binary = mask_byte_arr
+ mask_binary.name = "mask.png"
+
+ # Create edit request
+ edit_request = IdeogramV3EditRequest(
+ prompt=prompt,
+ rendering_speed=rendering_speed,
+ )
+
+ # Add optional parameters
+ if magic_prompt_option != "AUTO":
+ edit_request.magic_prompt = magic_prompt_option
+ if seed != 0:
+ edit_request.seed = seed
+ if num_images > 1:
+ edit_request.num_images = num_images
+
+ # Execute the operation for edit mode
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=path,
+ method=HttpMethod.POST,
+ request_model=IdeogramV3EditRequest,
+ response_model=IdeogramGenerateResponse,
+ ),
+ request=edit_request,
+ files={
+ "image": img_binary,
+ "mask": mask_binary,
+ },
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+
+ elif image is not None or mask is not None:
+ # If only one of image or mask is provided, raise an error
+ raise Exception("Ideogram V3 image editing requires both an image AND a mask")
+ else:
+ # Generation mode
+ path = "/proxy/ideogram/ideogram-v3/generate"
+
+ # Create generation request
+ gen_request = IdeogramV3Request(
+ prompt=prompt,
+ rendering_speed=rendering_speed,
+ )
+
+ # Handle resolution vs aspect ratio
+ if resolution != "Auto":
+ gen_request.resolution = resolution
+ elif aspect_ratio != "1:1":
+ v3_aspect = V3_RATIO_MAP.get(aspect_ratio)
+ if v3_aspect:
+ gen_request.aspect_ratio = v3_aspect
+
+ # Add optional parameters
+ if magic_prompt_option != "AUTO":
+ gen_request.magic_prompt = magic_prompt_option
+ if seed != 0:
+ gen_request.seed = seed
+ if num_images > 1:
+ gen_request.num_images = num_images
+
+ # Execute the operation for generation mode
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=path,
+ method=HttpMethod.POST,
+ request_model=IdeogramV3Request,
+ response_model=IdeogramGenerateResponse,
+ ),
+ request=gen_request,
+ auth_kwargs=kwargs,
+ )
+
+ # Execute the operation and process response
+ response = operation.execute()
+
+ if not response.data or len(response.data) == 0:
+ raise Exception("No images were generated in the response")
+
+ image_urls = [image_data.url for image_data in response.data if image_data.url]
+
+ if not image_urls:
+ raise Exception("No image URLs were generated in the response")
+
+ display_image_urls_on_node(image_urls, unique_id)
+ return (download_and_process_images(image_urls),)
+
+
+NODE_CLASS_MAPPINGS = {
+ "IdeogramV1": IdeogramV1,
+ "IdeogramV2": IdeogramV2,
+ "IdeogramV3": IdeogramV3,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "IdeogramV1": "Ideogram V1",
+ "IdeogramV2": "Ideogram V2",
+ "IdeogramV3": "Ideogram V3",
+}
diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py
new file mode 100644
index 000000000..641cd6353
--- /dev/null
+++ b/comfy_api_nodes/nodes_kling.py
@@ -0,0 +1,1758 @@
+"""Kling API Nodes
+
+For source of truth on the allowed permutations of request fields, please reference:
+- [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
+"""
+
+from __future__ import annotations
+from typing import Optional, TypeVar, Any
+from collections.abc import Callable
+import math
+import logging
+
+import torch
+
+from comfy_api_nodes.apis import (
+ KlingTaskStatus,
+ KlingCameraControl,
+ KlingCameraConfig,
+ KlingCameraControlType,
+ KlingVideoGenDuration,
+ KlingVideoGenMode,
+ KlingVideoGenAspectRatio,
+ KlingVideoGenModelName,
+ KlingText2VideoRequest,
+ KlingText2VideoResponse,
+ KlingImage2VideoRequest,
+ KlingImage2VideoResponse,
+ KlingVideoExtendRequest,
+ KlingVideoExtendResponse,
+ KlingLipSyncVoiceLanguage,
+ KlingLipSyncInputObject,
+ KlingLipSyncRequest,
+ KlingLipSyncResponse,
+ KlingVirtualTryOnModelName,
+ KlingVirtualTryOnRequest,
+ KlingVirtualTryOnResponse,
+ KlingVideoResult,
+ KlingImageResult,
+ KlingImageGenerationsRequest,
+ KlingImageGenerationsResponse,
+ KlingImageGenImageReferenceType,
+ KlingImageGenModelName,
+ KlingImageGenAspectRatio,
+ KlingVideoEffectsRequest,
+ KlingVideoEffectsResponse,
+ KlingDualCharacterEffectsScene,
+ KlingSingleImageEffectsScene,
+ KlingDualCharacterEffectInput,
+ KlingSingleImageEffectInput,
+ KlingCharacterEffectModelName,
+ KlingSingleImageEffectModelName,
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ PollingOperation,
+ EmptyRequest,
+)
+from comfy_api_nodes.apinode_utils import (
+ tensor_to_base64_string,
+ download_url_to_video_output,
+ upload_video_to_comfyapi,
+ upload_audio_to_comfyapi,
+ download_url_to_image_tensor,
+)
+from comfy_api_nodes.mapper_utils import model_field_to_node_input
+from comfy_api_nodes.util.validation_utils import (
+ validate_image_dimensions,
+ validate_image_aspect_ratio,
+ validate_video_dimensions,
+ validate_video_duration,
+)
+from comfy_api.input.basic_types import AudioInput
+from comfy_api.input.video_types import VideoInput
+from comfy_api.input_impl import VideoFromFile
+from comfy.comfy_types.node_typing import IO, InputTypeOptions, ComfyNodeABC
+
+KLING_API_VERSION = "v1"
+PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video"
+PATH_IMAGE_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/image2video"
+PATH_VIDEO_EXTEND = f"/proxy/kling/{KLING_API_VERSION}/videos/video-extend"
+PATH_LIP_SYNC = f"/proxy/kling/{KLING_API_VERSION}/videos/lip-sync"
+PATH_VIDEO_EFFECTS = f"/proxy/kling/{KLING_API_VERSION}/videos/effects"
+PATH_CHARACTER_IMAGE = f"/proxy/kling/{KLING_API_VERSION}/images/generations"
+PATH_VIRTUAL_TRY_ON = f"/proxy/kling/{KLING_API_VERSION}/images/kolors-virtual-try-on"
+PATH_IMAGE_GENERATIONS = f"/proxy/kling/{KLING_API_VERSION}/images/generations"
+
+MAX_PROMPT_LENGTH_T2V = 2500
+MAX_PROMPT_LENGTH_I2V = 500
+MAX_PROMPT_LENGTH_IMAGE_GEN = 500
+MAX_NEGATIVE_PROMPT_LENGTH_IMAGE_GEN = 200
+MAX_PROMPT_LENGTH_LIP_SYNC = 120
+
+AVERAGE_DURATION_T2V = 319
+AVERAGE_DURATION_I2V = 164
+AVERAGE_DURATION_LIP_SYNC = 455
+AVERAGE_DURATION_VIRTUAL_TRY_ON = 19
+AVERAGE_DURATION_IMAGE_GEN = 32
+AVERAGE_DURATION_VIDEO_EFFECTS = 320
+AVERAGE_DURATION_VIDEO_EXTEND = 320
+
+R = TypeVar("R")
+
+
+class KlingApiError(Exception):
+ """Base exception for Kling API errors."""
+
+ pass
+
+
+def poll_until_finished(
+ auth_kwargs: dict[str, str],
+ api_endpoint: ApiEndpoint[Any, R],
+ result_url_extractor: Optional[Callable[[R], str]] = None,
+ estimated_duration: Optional[int] = None,
+ node_id: Optional[str] = None,
+) -> R:
+ """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response."""
+ return PollingOperation(
+ poll_endpoint=api_endpoint,
+ completed_statuses=[
+ KlingTaskStatus.succeed.value,
+ ],
+ failed_statuses=[KlingTaskStatus.failed.value],
+ status_extractor=lambda response: (
+ response.data.task_status.value
+ if response.data and response.data.task_status
+ else None
+ ),
+ auth_kwargs=auth_kwargs,
+ result_url_extractor=result_url_extractor,
+ estimated_duration=estimated_duration,
+ node_id=node_id,
+ ).execute()
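+
+ # Usage sketch (task_id and auth values are placeholders): poll a text2video task
+ # until it reaches a terminal state, then receive the finished response.
+ #   response = poll_until_finished(
+ #       {"auth_token": "..."},
+ #       ApiEndpoint(
+ #           path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
+ #           method=HttpMethod.GET,
+ #           request_model=EmptyRequest,
+ #           response_model=KlingText2VideoResponse,
+ #       ),
+ #   )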
+
+
+def is_valid_camera_control_configs(configs: list[float]) -> bool:
+ """Verifies that at least one camera control configuration is non-zero."""
+ return any(not math.isclose(value, 0.0) for value in configs)
+
+
+def is_valid_prompt(prompt: str) -> bool:
+ """Verifies that the prompt is not empty."""
+ return bool(prompt)
+
+
+def is_valid_task_creation_response(response: KlingText2VideoResponse) -> bool:
+ """Verifies that the initial response contains a task ID."""
+ return bool(response.data.task_id)
+
+
+def is_valid_video_response(response: KlingText2VideoResponse) -> bool:
+ """Verifies that the response contains a task result with at least one video."""
+ return (
+ response.data is not None
+ and response.data.task_result is not None
+ and response.data.task_result.videos is not None
+ and len(response.data.task_result.videos) > 0
+ )
+
+
+def is_valid_image_response(response: KlingVirtualTryOnResponse) -> bool:
+ """Verifies that the response contains a task result with at least one image."""
+ return (
+ response.data is not None
+ and response.data.task_result is not None
+ and response.data.task_result.images is not None
+ and len(response.data.task_result.images) > 0
+ )
+
+
+def validate_prompts(prompt: str, negative_prompt: str, max_length: int) -> bool:
+ """Verifies that the positive prompt is not empty and that neither promt is too long."""
+ if not prompt:
+ raise ValueError("Positive prompt is empty")
+ if len(prompt) > max_length:
+ raise ValueError(f"Positive prompt is too long: {len(prompt)} characters")
+ if negative_prompt and len(negative_prompt) > max_length:
+ raise ValueError(
+ f"Negative prompt is too long: {len(negative_prompt)} characters"
+ )
+ return True
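+
+ # For example, validate_prompts("a cat", "", MAX_PROMPT_LENGTH_T2V) returns True,
+ # while validate_prompts("", "", MAX_PROMPT_LENGTH_T2V) raises ValueError.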
+
+
+def validate_task_creation_response(response) -> None:
+ """Validates that the Kling task creation request was successful."""
+ if not is_valid_task_creation_response(response):
+ error_msg = f"Kling initial request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}"
+ logging.error(error_msg)
+ raise KlingApiError(error_msg)
+
+
+def validate_video_result_response(response) -> None:
+ """Validates that the Kling task result contains a video."""
+ if not is_valid_video_response(response):
+ error_msg = f"Kling task {response.data.task_id} succeeded but no video data found in response."
+ logging.error(f"Error: {error_msg}.\nResponse: {response}")
+ raise KlingApiError(error_msg)
+
+
+def validate_image_result_response(response) -> None:
+ """Validates that the Kling task result contains an image."""
+ if not is_valid_image_response(response):
+ error_msg = f"Kling task {response.data.task_id} succeeded but no image data found in response."
+ logging.error(f"Error: {error_msg}.\nResponse: {response}")
+ raise KlingApiError(error_msg)
+
+
+def validate_input_image(image: torch.Tensor) -> None:
+ """
+ Validates the input image adheres to the expectations of the Kling API:
+ - The image resolution should not be less than 300*300px
+ - The aspect ratio of the image should be between 1:2.5 ~ 2.5:1
+
+ See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo
+ """
+ validate_image_dimensions(image, min_width=300, min_height=300)
+ validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5)
+
+
+def get_camera_control_input_config(
+ tooltip: str, default: float = 0.0
+) -> tuple[IO, InputTypeOptions]:
+ """Returns common InputTypeOptions for Kling camera control configurations."""
+ input_config = {
+ "default": default,
+ "min": -10.0,
+ "max": 10.0,
+ "step": 0.25,
+ "display": "slider",
+ "tooltip": tooltip,
+ }
+ return IO.FLOAT, input_config
+
+
+def get_video_from_response(response) -> KlingVideoResult:
+ """Returns the first video object from the Kling video generation task result.
+ Will raise an error if the response is not valid.
+ """
+ video = response.data.task_result.videos[0]
+ logging.info(
+ "Kling task %s succeeded. Video URL: %s", response.data.task_id, video.url
+ )
+ return video
+
+
+def get_video_url_from_response(response) -> Optional[str]:
+ """Returns the first video url from the Kling video generation task result.
+ Will not raise an error if the response is not valid.
+ """
+ if response and is_valid_video_response(response):
+ return str(get_video_from_response(response).url)
+ else:
+ return None
+
+
+def get_images_from_response(response) -> list[KlingImageResult]:
+ """Returns the list of image objects from the Kling image generation task result.
+ Will raise an error if the response is not valid.
+ """
+ images = response.data.task_result.images
+ logging.info("Kling task %s succeeded. Images: %s", response.data.task_id, images)
+ return images
+
+
+def get_images_urls_from_response(response) -> Optional[str]:
+ """Returns the list of image urls from the Kling image generation task result.
+ Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls.
+ """
+ if response and is_valid_image_response(response):
+ images = get_images_from_response(response)
+ image_urls = [str(image.url) for image in images]
+ return "\n".join(image_urls)
+ else:
+ return None
+
+
+def video_result_to_node_output(
+ video: KlingVideoResult,
+) -> tuple[VideoFromFile, str, str]:
+ """Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output."""
+ return (
+ download_url_to_video_output(video.url),
+ str(video.id),
+ str(video.duration),
+ )
+
+
+def image_result_to_node_output(
+ images: list[KlingImageResult],
+) -> torch.Tensor:
+ """
+ Converts a list of KlingImageResult objects to a [B, H, W, C] image tensor.
+ If multiple images are returned, they will be stacked along the batch dimension.
+ """
+ if len(images) == 1:
+ return download_url_to_image_tensor(images[0].url)
+ else:
+ return torch.cat([download_url_to_image_tensor(image.url) for image in images])
+
+
+class KlingNodeBase(ComfyNodeABC):
+ """Base class for Kling nodes."""
+
+ FUNCTION = "api_call"
+ CATEGORY = "api node/video/Kling"
+ API_NODE = True
+
+
+class KlingCameraControls(KlingNodeBase):
+ """Kling Camera Controls Node"""
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "camera_control_type": model_field_to_node_input(
+ IO.COMBO,
+ KlingCameraControl,
+ "type",
+ enum_type=KlingCameraControlType,
+ ),
+ "horizontal_movement": get_camera_control_input_config(
+ "Controls camera's movement along horizontal axis (x-axis). Negative indicates left, positive indicates right"
+ ),
+ "vertical_movement": get_camera_control_input_config(
+ "Controls camera's movement along vertical axis (y-axis). Negative indicates downward, positive indicates upward."
+ ),
+ "pan": get_camera_control_input_config(
+ "Controls camera's rotation in vertical plane (x-axis). Negative indicates downward rotation, positive indicates upward rotation.",
+ default=0.5,
+ ),
+ "tilt": get_camera_control_input_config(
+ "Controls camera's rotation in horizontal plane (y-axis). Negative indicates left rotation, positive indicates right rotation.",
+ ),
+ "roll": get_camera_control_input_config(
+ "Controls camera's rolling amount (z-axis). Negative indicates counterclockwise, positive indicates clockwise.",
+ ),
+ "zoom": get_camera_control_input_config(
+ "Controls change in camera's focal length. Negative indicates narrower field of view, positive indicates wider field of view.",
+ ),
+ }
+ }
+
+ DESCRIPTION = "Allows specifying configuration options for Kling Camera Controls and motion control effects."
+ RETURN_TYPES = ("CAMERA_CONTROL",)
+ RETURN_NAMES = ("camera_control",)
+ FUNCTION = "main"
+ API_NODE = False # This is just a helper node, it doesn't make an API call
+
+ @classmethod
+ def VALIDATE_INPUTS(
+ cls,
+ horizontal_movement: float,
+ vertical_movement: float,
+ pan: float,
+ tilt: float,
+ roll: float,
+ zoom: float,
+ ) -> bool | str:
+ if not is_valid_camera_control_configs(
+ [
+ horizontal_movement,
+ vertical_movement,
+ pan,
+ tilt,
+ roll,
+ zoom,
+ ]
+ ):
+ return "Invalid camera control configs: at least one of the values must be non-zero"
+ return True
+
+ def main(
+ self,
+ camera_control_type: str,
+ horizontal_movement: float,
+ vertical_movement: float,
+ pan: float,
+ tilt: float,
+ roll: float,
+ zoom: float,
+ ) -> tuple[KlingCameraControl]:
+ return (
+ KlingCameraControl(
+ type=KlingCameraControlType(camera_control_type),
+ config=KlingCameraConfig(
+ horizontal=horizontal_movement,
+ vertical=vertical_movement,
+ pan=pan,
+ roll=roll,
+ tilt=tilt,
+ zoom=zoom,
+ ),
+ ),
+ )
+
+
+class KlingTextToVideoNode(KlingNodeBase):
+ """Kling Text to Video Node"""
+
+ @staticmethod
+ def get_mode_string_mapping() -> dict[str, tuple[str, str, str]]:
+ """
+ Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples.
+ Only includes config combos that support the `image_tail` request field.
+
+ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
+ """
+ return {
+ "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
+ "standard mode / 10s duration / kling-v1": ("std", "10", "kling-v1"),
+ "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
+ "pro mode / 10s duration / kling-v1": ("pro", "10", "kling-v1"),
+ "standard mode / 5s duration / kling-v1-6": ("std", "5", "kling-v1-6"),
+ "standard mode / 10s duration / kling-v1-6": ("std", "10", "kling-v1-6"),
+ "pro mode / 5s duration / kling-v2-master": ("pro", "5", "kling-v2-master"),
+ "pro mode / 10s duration / kling-v2-master": ("pro", "10", "kling-v2-master"),
+ "standard mode / 5s duration / kling-v2-master": ("std", "5", "kling-v2-master"),
+ "standard mode / 10s duration / kling-v2-master": ("std", "10", "kling-v2-master"),
+ }
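+
+ # For example, unpacking one entry of the mapping:
+ #   mode, duration, model_name = KlingTextToVideoNode.get_mode_string_mapping()[
+ #       "pro mode / 5s duration / kling-v2-master"]
+ #   # -> ("pro", "5", "kling-v2-master")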
+
+ @classmethod
+ def INPUT_TYPES(s):
+ modes = list(KlingTextToVideoNode.get_mode_string_mapping().keys())
+ return {
+ "required": {
+ "prompt": model_field_to_node_input(
+ IO.STRING, KlingText2VideoRequest, "prompt", multiline=True
+ ),
+ "negative_prompt": model_field_to_node_input(
+ IO.STRING, KlingText2VideoRequest, "negative_prompt", multiline=True
+ ),
+ "cfg_scale": model_field_to_node_input(
+ IO.FLOAT,
+ KlingText2VideoRequest,
+ "cfg_scale",
+ default=1.0,
+ min=0.0,
+ max=1.0,
+ ),
+ "aspect_ratio": model_field_to_node_input(
+ IO.COMBO,
+ KlingText2VideoRequest,
+ "aspect_ratio",
+ enum_type=KlingVideoGenAspectRatio,
+ ),
+ "mode": (
+ modes,
+ {
+ "default": modes[4],
+ "tooltip": "The configuration to use for the video generation following the format: mode / duration / model_name.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("VIDEO", "STRING", "STRING")
+ RETURN_NAMES = ("VIDEO", "video_id", "duration")
+ DESCRIPTION = "Kling Text to Video Node"
+
+ def get_response(
+ self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+ ) -> KlingText2VideoResponse:
+ return poll_until_finished(
+ auth_kwargs,
+ ApiEndpoint(
+ path=f"{PATH_TEXT_TO_VIDEO}/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=KlingText2VideoResponse,
+ ),
+ result_url_extractor=get_video_url_from_response,
+ estimated_duration=AVERAGE_DURATION_T2V,
+ node_id=node_id,
+ )
+
+ def api_call(
+ self,
+ prompt: str,
+ negative_prompt: str,
+ cfg_scale: float,
+ mode: str,
+ aspect_ratio: str,
+ camera_control: Optional[KlingCameraControl] = None,
+ model_name: Optional[str] = None,
+ duration: Optional[str] = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ) -> tuple[VideoFromFile, str, str]:
+ validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
+ if model_name is None:
+ mode, duration, model_name = self.get_mode_string_mapping()[mode]
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_TEXT_TO_VIDEO,
+ method=HttpMethod.POST,
+ request_model=KlingText2VideoRequest,
+ response_model=KlingText2VideoResponse,
+ ),
+ request=KlingText2VideoRequest(
+ prompt=prompt if prompt else None,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ duration=KlingVideoGenDuration(duration),
+ mode=KlingVideoGenMode(mode),
+ model_name=KlingVideoGenModelName(model_name),
+ cfg_scale=cfg_scale,
+ aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
+ camera_control=camera_control,
+ ),
+ auth_kwargs=kwargs,
+ )
+
+ task_creation_response = initial_operation.execute()
+ validate_task_creation_response(task_creation_response)
+
+ task_id = task_creation_response.data.task_id
+ final_response = self.get_response(
+ task_id, auth_kwargs=kwargs, node_id=unique_id
+ )
+ validate_video_result_response(final_response)
+
+ video = get_video_from_response(final_response)
+ return video_result_to_node_output(video)
+
+
+class KlingCameraControlT2VNode(KlingTextToVideoNode):
+ """
+ Kling Text to Video Camera Control Node. A text to video node that additionally supports camera control.
+ Duration, mode, and model_name request fields are hard-coded because camera control is only supported in std mode with the kling-v1 model at 5s duration as of 2025-05-02.
+ """
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": model_field_to_node_input(
+ IO.STRING, KlingText2VideoRequest, "prompt", multiline=True
+ ),
+ "negative_prompt": model_field_to_node_input(
+ IO.STRING,
+ KlingText2VideoRequest,
+ "negative_prompt",
+ multiline=True,
+ ),
+ "cfg_scale": model_field_to_node_input(
+ IO.FLOAT,
+ KlingText2VideoRequest,
+ "cfg_scale",
+ default=0.75,
+ min=0.0,
+ max=1.0,
+ ),
+ "aspect_ratio": model_field_to_node_input(
+ IO.COMBO,
+ KlingText2VideoRequest,
+ "aspect_ratio",
+ enum_type=KlingVideoGenAspectRatio,
+ ),
+ "camera_control": (
+ "CAMERA_CONTROL",
+ {
+ "tooltip": "Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text."
+
+ def api_call(
+ self,
+ prompt: str,
+ negative_prompt: str,
+ cfg_scale: float,
+ aspect_ratio: str,
+ camera_control: Optional[KlingCameraControl] = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ return super().api_call(
+ model_name=KlingVideoGenModelName.kling_v1,
+ cfg_scale=cfg_scale,
+ mode=KlingVideoGenMode.std,
+ aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
+ duration=KlingVideoGenDuration.field_5,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ camera_control=camera_control,
+ unique_id=unique_id,
+ **kwargs,
+ )
+
+
+class KlingImage2VideoNode(KlingNodeBase):
+ """Kling Image to Video Node"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "start_frame": model_field_to_node_input(
+ IO.IMAGE,
+ KlingImage2VideoRequest,
+ "image",
+ tooltip="The reference image used to generate the video.",
+ ),
+ "prompt": model_field_to_node_input(
+ IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
+ ),
+ "negative_prompt": model_field_to_node_input(
+ IO.STRING,
+ KlingImage2VideoRequest,
+ "negative_prompt",
+ multiline=True,
+ ),
+ "model_name": model_field_to_node_input(
+ IO.COMBO,
+ KlingImage2VideoRequest,
+ "model_name",
+ enum_type=KlingVideoGenModelName,
+ ),
+ "cfg_scale": model_field_to_node_input(
+ IO.FLOAT,
+ KlingImage2VideoRequest,
+ "cfg_scale",
+ default=0.8,
+ min=0.0,
+ max=1.0,
+ ),
+ "mode": model_field_to_node_input(
+ IO.COMBO,
+ KlingImage2VideoRequest,
+ "mode",
+ enum_type=KlingVideoGenMode,
+ ),
+ "aspect_ratio": model_field_to_node_input(
+ IO.COMBO,
+ KlingImage2VideoRequest,
+ "aspect_ratio",
+ enum_type=KlingVideoGenAspectRatio,
+ ),
+ "duration": model_field_to_node_input(
+ IO.COMBO,
+ KlingImage2VideoRequest,
+ "duration",
+ enum_type=KlingVideoGenDuration,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("VIDEO", "STRING", "STRING")
+ RETURN_NAMES = ("VIDEO", "video_id", "duration")
+ DESCRIPTION = "Kling Image to Video Node"
+
+ def get_response(
+ self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+ ) -> KlingImage2VideoResponse:
+ return poll_until_finished(
+ auth_kwargs,
+ ApiEndpoint(
+ path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=KlingImage2VideoResponse,
+ ),
+ result_url_extractor=get_video_url_from_response,
+ estimated_duration=AVERAGE_DURATION_I2V,
+ node_id=node_id,
+ )
+
+ def api_call(
+ self,
+ start_frame: torch.Tensor,
+ prompt: str,
+ negative_prompt: str,
+ model_name: str,
+ cfg_scale: float,
+ mode: str,
+ aspect_ratio: str,
+ duration: str,
+ camera_control: Optional[KlingCameraControl] = None,
+ end_frame: Optional[torch.Tensor] = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ) -> tuple[VideoFromFile]:
+ validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V)
+ validate_input_image(start_frame)
+
+ if camera_control is not None:
+ # Camera control type for image 2 video is always `simple`
+ camera_control.type = KlingCameraControlType.simple
+
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_IMAGE_TO_VIDEO,
+ method=HttpMethod.POST,
+ request_model=KlingImage2VideoRequest,
+ response_model=KlingImage2VideoResponse,
+ ),
+ request=KlingImage2VideoRequest(
+ model_name=KlingVideoGenModelName(model_name),
+ image=tensor_to_base64_string(start_frame),
+ image_tail=(
+ tensor_to_base64_string(end_frame)
+ if end_frame is not None
+ else None
+ ),
+ prompt=prompt,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ cfg_scale=cfg_scale,
+ mode=KlingVideoGenMode(mode),
+ duration=KlingVideoGenDuration(duration),
+ camera_control=camera_control,
+ ),
+ auth_kwargs=kwargs,
+ )
+
+ task_creation_response = initial_operation.execute()
+ validate_task_creation_response(task_creation_response)
+ task_id = task_creation_response.data.task_id
+
+ final_response = self.get_response(
+ task_id, auth_kwargs=kwargs, node_id=unique_id
+ )
+ validate_video_result_response(final_response)
+
+ video = get_video_from_response(final_response)
+ return video_result_to_node_output(video)
+
+
+class KlingCameraControlI2VNode(KlingImage2VideoNode):
+ """
+ Kling Image to Video Camera Control Node. An image to video node that additionally supports camera control.
+ Duration, mode, and model_name request fields are hard-coded because camera control is only supported in pro mode with the kling-v1-5 model at 5s duration as of 2025-05-02.
+ """
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "start_frame": model_field_to_node_input(
+ IO.IMAGE, KlingImage2VideoRequest, "image"
+ ),
+ "prompt": model_field_to_node_input(
+ IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
+ ),
+ "negative_prompt": model_field_to_node_input(
+ IO.STRING,
+ KlingImage2VideoRequest,
+ "negative_prompt",
+ multiline=True,
+ ),
+ "cfg_scale": model_field_to_node_input(
+ IO.FLOAT,
+ KlingImage2VideoRequest,
+ "cfg_scale",
+ default=0.75,
+ min=0.0,
+ max=1.0,
+ ),
+ "aspect_ratio": model_field_to_node_input(
+ IO.COMBO,
+ KlingImage2VideoRequest,
+ "aspect_ratio",
+ enum_type=KlingVideoGenAspectRatio,
+ ),
+ "camera_control": (
+ "CAMERA_CONTROL",
+ {
+ "tooltip": "Can be created using the Kling Camera Controls node. Controls the camera movement and motion during the video generation.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image."
+
+ def api_call(
+ self,
+ start_frame: torch.Tensor,
+ prompt: str,
+ negative_prompt: str,
+ cfg_scale: float,
+ aspect_ratio: str,
+ camera_control: KlingCameraControl,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ return super().api_call(
+ model_name=KlingVideoGenModelName.kling_v1_5,
+ start_frame=start_frame,
+ cfg_scale=cfg_scale,
+ mode=KlingVideoGenMode.pro,
+ aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
+ duration=KlingVideoGenDuration.field_5,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ camera_control=camera_control,
+ unique_id=unique_id,
+ **kwargs,
+ )
+
+
+class KlingStartEndFrameNode(KlingImage2VideoNode):
+ """
+    Kling Start-End Frame Node. Creates a video from a first and last frame. It calls the normal image-to-video endpoint, but only allows the subset of input options that support the `image_tail` request field.
+ """
+
+ @staticmethod
+ def get_mode_string_mapping() -> dict[str, tuple[str, str, str]]:
+ """
+ Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples.
+ Only includes config combos that support the `image_tail` request field.
+
+ See: [Kling API Docs Capability Map](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap)
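+
+        Example (illustrative):
+            >>> KlingStartEndFrameNode.get_mode_string_mapping()["pro mode / 5s duration / kling-v1-5"]
+            ('pro', '5', 'kling-v1-5')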
+ """
+ return {
+ "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"),
+ "pro mode / 5s duration / kling-v1": ("pro", "5", "kling-v1"),
+ "pro mode / 5s duration / kling-v1-5": ("pro", "5", "kling-v1-5"),
+ "pro mode / 10s duration / kling-v1-5": ("pro", "10", "kling-v1-5"),
+ "pro mode / 5s duration / kling-v1-6": ("pro", "5", "kling-v1-6"),
+ "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"),
+ }
+
+ @classmethod
+ def INPUT_TYPES(s):
+ modes = list(KlingStartEndFrameNode.get_mode_string_mapping().keys())
+ return {
+ "required": {
+ "start_frame": model_field_to_node_input(
+ IO.IMAGE, KlingImage2VideoRequest, "image"
+ ),
+ "end_frame": model_field_to_node_input(
+ IO.IMAGE, KlingImage2VideoRequest, "image_tail"
+ ),
+ "prompt": model_field_to_node_input(
+ IO.STRING, KlingImage2VideoRequest, "prompt", multiline=True
+ ),
+ "negative_prompt": model_field_to_node_input(
+ IO.STRING,
+ KlingImage2VideoRequest,
+ "negative_prompt",
+ multiline=True,
+ ),
+ "cfg_scale": model_field_to_node_input(
+ IO.FLOAT,
+ KlingImage2VideoRequest,
+ "cfg_scale",
+ default=0.5,
+ min=0.0,
+ max=1.0,
+ ),
+ "aspect_ratio": model_field_to_node_input(
+ IO.COMBO,
+ KlingImage2VideoRequest,
+ "aspect_ratio",
+ enum_type=KlingVideoGenAspectRatio,
+ ),
+ "mode": (
+ modes,
+ {
+ "default": modes[2],
+ "tooltip": "The configuration to use for the video generation following the format: mode / duration / model_name.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last."
+
+ def api_call(
+ self,
+ start_frame: torch.Tensor,
+ end_frame: torch.Tensor,
+ prompt: str,
+ negative_prompt: str,
+ cfg_scale: float,
+ aspect_ratio: str,
+ mode: str,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[
+ mode
+ ]
+ return super().api_call(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ model_name=model_name,
+ start_frame=start_frame,
+ cfg_scale=cfg_scale,
+ mode=mode,
+ aspect_ratio=aspect_ratio,
+ duration=duration,
+ end_frame=end_frame,
+ unique_id=unique_id,
+ **kwargs,
+ )
+
+
+class KlingVideoExtendNode(KlingNodeBase):
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": model_field_to_node_input(
+ IO.STRING, KlingVideoExtendRequest, "prompt", multiline=True
+ ),
+ "negative_prompt": model_field_to_node_input(
+ IO.STRING,
+ KlingVideoExtendRequest,
+ "negative_prompt",
+ multiline=True,
+ ),
+ "cfg_scale": model_field_to_node_input(
+ IO.FLOAT,
+ KlingVideoExtendRequest,
+ "cfg_scale",
+ default=0.5,
+ min=0.0,
+ max=1.0,
+ ),
+ "video_id": model_field_to_node_input(
+ IO.STRING, KlingVideoExtendRequest, "video_id", forceInput=True
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("VIDEO", "STRING", "STRING")
+ RETURN_NAMES = ("VIDEO", "video_id", "duration")
+ DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes."
+
+ def get_response(
+ self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+ ) -> KlingVideoExtendResponse:
+ return poll_until_finished(
+ auth_kwargs,
+ ApiEndpoint(
+ path=f"{PATH_VIDEO_EXTEND}/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=KlingVideoExtendResponse,
+ ),
+ result_url_extractor=get_video_url_from_response,
+ estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND,
+ node_id=node_id,
+ )
+
+ def api_call(
+ self,
+ prompt: str,
+ negative_prompt: str,
+ cfg_scale: float,
+ video_id: str,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ) -> tuple[VideoFromFile, str, str]:
+ validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V)
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_VIDEO_EXTEND,
+ method=HttpMethod.POST,
+ request_model=KlingVideoExtendRequest,
+ response_model=KlingVideoExtendResponse,
+ ),
+ request=KlingVideoExtendRequest(
+ prompt=prompt if prompt else None,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ cfg_scale=cfg_scale,
+ video_id=video_id,
+ ),
+ auth_kwargs=kwargs,
+ )
+
+ task_creation_response = initial_operation.execute()
+ validate_task_creation_response(task_creation_response)
+ task_id = task_creation_response.data.task_id
+
+ final_response = self.get_response(
+ task_id, auth_kwargs=kwargs, node_id=unique_id
+ )
+ validate_video_result_response(final_response)
+
+ video = get_video_from_response(final_response)
+ return video_result_to_node_output(video)
+
+
+class KlingVideoEffectsBase(KlingNodeBase):
+ """Kling Video Effects Base"""
+
+ RETURN_TYPES = ("VIDEO", "STRING", "STRING")
+ RETURN_NAMES = ("VIDEO", "video_id", "duration")
+
+ def get_response(
+ self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+ ) -> KlingVideoEffectsResponse:
+ return poll_until_finished(
+ auth_kwargs,
+ ApiEndpoint(
+ path=f"{PATH_VIDEO_EFFECTS}/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=KlingVideoEffectsResponse,
+ ),
+ result_url_extractor=get_video_url_from_response,
+ estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS,
+ node_id=node_id,
+ )
+
+ def api_call(
+ self,
+ dual_character: bool,
+ effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene,
+ model_name: str,
+ duration: KlingVideoGenDuration,
+ image_1: torch.Tensor,
+ image_2: Optional[torch.Tensor] = None,
+ mode: Optional[KlingVideoGenMode] = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
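+        # Dual-character effects take two images ordered [left, right]; single-image
+        # effects take exactly one. The request input model differs accordingly.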
+ if dual_character:
+ request_input_field = KlingDualCharacterEffectInput(
+ model_name=model_name,
+ mode=mode,
+ images=[
+ tensor_to_base64_string(image_1),
+ tensor_to_base64_string(image_2),
+ ],
+ duration=duration,
+ )
+ else:
+ request_input_field = KlingSingleImageEffectInput(
+ model_name=model_name,
+ image=tensor_to_base64_string(image_1),
+ duration=duration,
+ )
+
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_VIDEO_EFFECTS,
+ method=HttpMethod.POST,
+ request_model=KlingVideoEffectsRequest,
+ response_model=KlingVideoEffectsResponse,
+ ),
+ request=KlingVideoEffectsRequest(
+ effect_scene=effect_scene,
+ input=request_input_field,
+ ),
+ auth_kwargs=kwargs,
+ )
+
+ task_creation_response = initial_operation.execute()
+ validate_task_creation_response(task_creation_response)
+ task_id = task_creation_response.data.task_id
+
+ final_response = self.get_response(
+ task_id, auth_kwargs=kwargs, node_id=unique_id
+ )
+ validate_video_result_response(final_response)
+
+ video = get_video_from_response(final_response)
+ return video_result_to_node_output(video)
+
+
+class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase):
+ """Kling Dual Character Video Effect Node"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image_left": (IO.IMAGE, {"tooltip": "Left side image"}),
+ "image_right": (IO.IMAGE, {"tooltip": "Right side image"}),
+ "effect_scene": model_field_to_node_input(
+ IO.COMBO,
+ KlingVideoEffectsRequest,
+ "effect_scene",
+ enum_type=KlingDualCharacterEffectsScene,
+ ),
+ "model_name": model_field_to_node_input(
+ IO.COMBO,
+ KlingDualCharacterEffectInput,
+ "model_name",
+ enum_type=KlingCharacterEffectModelName,
+ ),
+ "mode": model_field_to_node_input(
+ IO.COMBO,
+ KlingDualCharacterEffectInput,
+ "mode",
+ enum_type=KlingVideoGenMode,
+ ),
+ "duration": model_field_to_node_input(
+ IO.COMBO,
+ KlingDualCharacterEffectInput,
+ "duration",
+ enum_type=KlingVideoGenDuration,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene. First image will be positioned on left side, second on right side of the composite."
+ RETURN_TYPES = ("VIDEO", "STRING")
+ RETURN_NAMES = ("VIDEO", "duration")
+
+ def api_call(
+ self,
+ image_left: torch.Tensor,
+ image_right: torch.Tensor,
+ effect_scene: KlingDualCharacterEffectsScene,
+ model_name: KlingCharacterEffectModelName,
+ mode: KlingVideoGenMode,
+ duration: KlingVideoGenDuration,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ video, _, duration = super().api_call(
+ dual_character=True,
+ effect_scene=effect_scene,
+ model_name=model_name,
+ mode=mode,
+ duration=duration,
+ image_1=image_left,
+ image_2=image_right,
+ unique_id=unique_id,
+ **kwargs,
+ )
+ return video, duration
+
+
+class KlingSingleImageVideoEffectNode(KlingVideoEffectsBase):
+ """Kling Single Image Video Effect Node"""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (
+ IO.IMAGE,
+ {
+ "tooltip": " Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"
+ },
+ ),
+ "effect_scene": model_field_to_node_input(
+ IO.COMBO,
+ KlingVideoEffectsRequest,
+ "effect_scene",
+ enum_type=KlingSingleImageEffectsScene,
+ ),
+ "model_name": model_field_to_node_input(
+ IO.COMBO,
+ KlingSingleImageEffectInput,
+ "model_name",
+ enum_type=KlingSingleImageEffectModelName,
+ ),
+ "duration": model_field_to_node_input(
+ IO.COMBO,
+ KlingSingleImageEffectInput,
+ "duration",
+ enum_type=KlingVideoGenDuration,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene."
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ effect_scene: KlingSingleImageEffectsScene,
+ model_name: KlingSingleImageEffectModelName,
+ duration: KlingVideoGenDuration,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ return super().api_call(
+ dual_character=False,
+ effect_scene=effect_scene,
+ model_name=model_name,
+ duration=duration,
+ image_1=image,
+ unique_id=unique_id,
+ **kwargs,
+ )
+
+
+class KlingLipSyncBase(KlingNodeBase):
+ """Kling Lip Sync Base"""
+
+ RETURN_TYPES = ("VIDEO", "STRING", "STRING")
+ RETURN_NAMES = ("VIDEO", "video_id", "duration")
+
+ def validate_lip_sync_video(self, video: VideoInput):
+ """
+ Validates the input video adheres to the expectations of the Kling Lip Sync API:
+ - Video length does not exceed 10s and is not shorter than 2s
+        - Height and width dimensions should both be between 720px and 1920px
+
+ See: https://app.klingai.com/global/dev/document-api/apiReference/model/videoTolip
+ """
+ validate_video_dimensions(video, 720, 1920)
+ validate_video_duration(video, 2, 10)
+
+ def validate_text(self, text: str):
+ if not text:
+ raise ValueError("Text is required")
+ if len(text) > MAX_PROMPT_LENGTH_LIP_SYNC:
+ raise ValueError(
+ f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters."
+ )
+
+ def get_response(
+ self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+ ) -> KlingLipSyncResponse:
+ """Polls the Kling API endpoint until the task reaches a terminal state."""
+ return poll_until_finished(
+ auth_kwargs,
+ ApiEndpoint(
+ path=f"{PATH_LIP_SYNC}/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=KlingLipSyncResponse,
+ ),
+ result_url_extractor=get_video_url_from_response,
+ estimated_duration=AVERAGE_DURATION_LIP_SYNC,
+ node_id=node_id,
+ )
+
+ def api_call(
+ self,
+ video: VideoInput,
+ audio: Optional[AudioInput] = None,
+ voice_language: Optional[str] = None,
+ mode: Optional[str] = None,
+ text: Optional[str] = None,
+ voice_speed: Optional[float] = None,
+ voice_id: Optional[str] = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ) -> tuple[VideoFromFile, str, str]:
+ if text:
+ self.validate_text(text)
+ self.validate_lip_sync_video(video)
+
+ # Upload video to Comfy API and get download URL
+ video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs)
+ logging.info("Uploaded video to Comfy API. URL: %s", video_url)
+
+ # Upload the audio file to Comfy API and get download URL
+ if audio:
+ audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs)
+ logging.info("Uploaded audio to Comfy API. URL: %s", audio_url)
+ else:
+ audio_url = None
+
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_LIP_SYNC,
+ method=HttpMethod.POST,
+ request_model=KlingLipSyncRequest,
+ response_model=KlingLipSyncResponse,
+ ),
+ request=KlingLipSyncRequest(
+ input=KlingLipSyncInputObject(
+ video_url=video_url,
+ mode=mode,
+ text=text,
+ voice_language=voice_language,
+ voice_speed=voice_speed,
+ audio_type="url",
+ audio_url=audio_url,
+ voice_id=voice_id,
+ ),
+ ),
+ auth_kwargs=kwargs,
+ )
+
+ task_creation_response = initial_operation.execute()
+ validate_task_creation_response(task_creation_response)
+ task_id = task_creation_response.data.task_id
+
+ final_response = self.get_response(
+ task_id, auth_kwargs=kwargs, node_id=unique_id
+ )
+ validate_video_result_response(final_response)
+
+ video = get_video_from_response(final_response)
+ return video_result_to_node_output(video)
+
+
+class KlingLipSyncAudioToVideoNode(KlingLipSyncBase):
+ """Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file."""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "video": (IO.VIDEO, {}),
+ "audio": (IO.AUDIO, {}),
+ "voice_language": model_field_to_node_input(
+ IO.COMBO,
+ KlingLipSyncInputObject,
+ "voice_language",
+ enum_type=KlingLipSyncVoiceLanguage,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
+
+ def api_call(
+ self,
+ video: VideoInput,
+ audio: AudioInput,
+ voice_language: str,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ return super().api_call(
+ video=video,
+ audio=audio,
+ voice_language=voice_language,
+ mode="audio2video",
+ unique_id=unique_id,
+ **kwargs,
+ )
+
+
+class KlingLipSyncTextToVideoNode(KlingLipSyncBase):
+ """Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt."""
+
+ @staticmethod
+ def get_voice_config() -> dict[str, tuple[str, str]]:
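+        """Maps a human-readable voice name to its (voice_id, voice_language) tuple for the Kling API."""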
+ return {
+ # English voices
+ "Melody": ("girlfriend_4_speech02", "en"),
+ "Sunny": ("genshin_vindi2", "en"),
+ "Sage": ("zhinen_xuesheng", "en"),
+ "Ace": ("AOT", "en"),
+ "Blossom": ("ai_shatang", "en"),
+ "Peppy": ("genshin_klee2", "en"),
+ "Dove": ("genshin_kirara", "en"),
+ "Shine": ("ai_kaiya", "en"),
+ "Anchor": ("oversea_male1", "en"),
+ "Lyric": ("ai_chenjiahao_712", "en"),
+ "Tender": ("chat1_female_new-3", "en"),
+ "Siren": ("chat_0407_5-1", "en"),
+ "Zippy": ("cartoon-boy-07", "en"),
+ "Bud": ("uk_boy1", "en"),
+ "Sprite": ("cartoon-girl-01", "en"),
+ "Candy": ("PeppaPig_platform", "en"),
+ "Beacon": ("ai_huangzhong_712", "en"),
+ "Rock": ("ai_huangyaoshi_712", "en"),
+ "Titan": ("ai_laoguowang_712", "en"),
+ "Grace": ("chengshu_jiejie", "en"),
+ "Helen": ("you_pingjing", "en"),
+ "Lore": ("calm_story1", "en"),
+ "Crag": ("uk_man2", "en"),
+ "Prattle": ("laopopo_speech02", "en"),
+ "Hearth": ("heainainai_speech02", "en"),
+ "The Reader": ("reader_en_m-v1", "en"),
+ "Commercial Lady": ("commercial_lady_en_f-v1", "en"),
+ # Chinese voices
+ "阳光少年": ("genshin_vindi2", "zh"),
+ "懂事小弟": ("zhinen_xuesheng", "zh"),
+ "运动少年": ("tiyuxi_xuedi", "zh"),
+ "青春少女": ("ai_shatang", "zh"),
+ "温柔小妹": ("genshin_klee2", "zh"),
+ "元气少女": ("genshin_kirara", "zh"),
+ "阳光男生": ("ai_kaiya", "zh"),
+ "幽默小哥": ("tiexin_nanyou", "zh"),
+ "文艺小哥": ("ai_chenjiahao_712", "zh"),
+ "甜美邻家": ("girlfriend_1_speech02", "zh"),
+ "温柔姐姐": ("chat1_female_new-3", "zh"),
+ "职场女青": ("girlfriend_2_speech02", "zh"),
+ "活泼男童": ("cartoon-boy-07", "zh"),
+ "俏皮女童": ("cartoon-girl-01", "zh"),
+ "稳重老爸": ("ai_huangyaoshi_712", "zh"),
+ "温柔妈妈": ("you_pingjing", "zh"),
+ "严肃上司": ("ai_laoguowang_712", "zh"),
+ "优雅贵妇": ("chengshu_jiejie", "zh"),
+ "慈祥爷爷": ("zhuxi_speech02", "zh"),
+ "唠叨爷爷": ("uk_oldman3", "zh"),
+ "唠叨奶奶": ("laopopo_speech02", "zh"),
+ "和蔼奶奶": ("heainainai_speech02", "zh"),
+ "东北老铁": ("dongbeilaotie_speech02", "zh"),
+ "重庆小伙": ("chongqingxiaohuo_speech02", "zh"),
+ "四川妹子": ("chuanmeizi_speech02", "zh"),
+ "潮汕大叔": ("chaoshandashu_speech02", "zh"),
+ "台湾男生": ("ai_taiwan_man2_speech02", "zh"),
+ "西安掌柜": ("xianzhanggui_speech02", "zh"),
+ "天津姐姐": ("tianjinjiejie_speech02", "zh"),
+ "新闻播报男": ("diyinnansang_DB_CN_M_04-v2", "zh"),
+ "译制片男": ("yizhipiannan-v1", "zh"),
+ "撒娇女友": ("tianmeixuemei-v1", "zh"),
+ "刀片烟嗓": ("daopianyansang-v1", "zh"),
+ "乖巧正太": ("mengwa-v1", "zh"),
+ }
+
+ @classmethod
+ def INPUT_TYPES(s):
+ voice_options = list(s.get_voice_config().keys())
+ return {
+ "required": {
+ "video": (IO.VIDEO, {}),
+ "text": model_field_to_node_input(
+ IO.STRING, KlingLipSyncInputObject, "text", multiline=True
+ ),
+ "voice": (voice_options, {"default": voice_options[0]}),
+ "voice_speed": model_field_to_node_input(
+ IO.FLOAT, KlingLipSyncInputObject, "voice_speed", slider=True
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length."
+
+ def api_call(
+ self,
+ video: VideoInput,
+ text: str,
+ voice: str,
+ voice_speed: float,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice]
+ return super().api_call(
+ video=video,
+ text=text,
+ voice_language=voice_language,
+ voice_id=voice_id,
+ voice_speed=voice_speed,
+ mode="text2video",
+ unique_id=unique_id,
+ **kwargs,
+ )
+
+
+class KlingImageGenerationBase(KlingNodeBase):
+ """Kling Image Generation Base Node."""
+
+ RETURN_TYPES = ("IMAGE",)
+ CATEGORY = "api node/image/Kling"
+
+ def validate_prompt(self, prompt: str, negative_prompt: Optional[str] = None):
+        if not prompt or len(prompt) > MAX_PROMPT_LENGTH_IMAGE_GEN:
+            raise ValueError(
+                f"Prompt is required and must be no longer than {MAX_PROMPT_LENGTH_IMAGE_GEN} characters"
+            )
+        if negative_prompt and len(negative_prompt) > MAX_PROMPT_LENGTH_IMAGE_GEN:
+            raise ValueError(
+                f"Negative prompt must be no longer than {MAX_PROMPT_LENGTH_IMAGE_GEN} characters"
+            )
+
+
+class KlingVirtualTryOnNode(KlingImageGenerationBase):
+ """Kling Virtual Try On Node."""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "human_image": (IO.IMAGE, {}),
+ "cloth_image": (IO.IMAGE, {}),
+ "model_name": model_field_to_node_input(
+ IO.COMBO,
+ KlingVirtualTryOnRequest,
+ "model_name",
+ enum_type=KlingVirtualTryOnModelName,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background."
+
+ def get_response(
+ self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+ ) -> KlingVirtualTryOnResponse:
+ return poll_until_finished(
+ auth_kwargs,
+ ApiEndpoint(
+ path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=KlingVirtualTryOnResponse,
+ ),
+ result_url_extractor=get_images_urls_from_response,
+ estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON,
+ node_id=node_id,
+ )
+
+ def api_call(
+ self,
+ human_image: torch.Tensor,
+ cloth_image: torch.Tensor,
+ model_name: KlingVirtualTryOnModelName,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_VIRTUAL_TRY_ON,
+ method=HttpMethod.POST,
+ request_model=KlingVirtualTryOnRequest,
+ response_model=KlingVirtualTryOnResponse,
+ ),
+ request=KlingVirtualTryOnRequest(
+ human_image=tensor_to_base64_string(human_image),
+ cloth_image=tensor_to_base64_string(cloth_image),
+ model_name=model_name,
+ ),
+ auth_kwargs=kwargs,
+ )
+
+ task_creation_response = initial_operation.execute()
+ validate_task_creation_response(task_creation_response)
+ task_id = task_creation_response.data.task_id
+
+ final_response = self.get_response(
+ task_id, auth_kwargs=kwargs, node_id=unique_id
+ )
+ validate_image_result_response(final_response)
+
+ images = get_images_from_response(final_response)
+ return (image_result_to_node_output(images),)
+
+
+class KlingImageGenerationNode(KlingImageGenerationBase):
+ """Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": model_field_to_node_input(
+ IO.STRING,
+ KlingImageGenerationsRequest,
+ "prompt",
+ multiline=True,
+ max_length=MAX_PROMPT_LENGTH_IMAGE_GEN,
+ ),
+ "negative_prompt": model_field_to_node_input(
+ IO.STRING,
+ KlingImageGenerationsRequest,
+ "negative_prompt",
+ multiline=True,
+ ),
+ "image_type": model_field_to_node_input(
+ IO.COMBO,
+ KlingImageGenerationsRequest,
+ "image_reference",
+ enum_type=KlingImageGenImageReferenceType,
+ ),
+ "image_fidelity": model_field_to_node_input(
+ IO.FLOAT,
+ KlingImageGenerationsRequest,
+ "image_fidelity",
+ slider=True,
+ step=0.01,
+ ),
+ "human_fidelity": model_field_to_node_input(
+ IO.FLOAT,
+ KlingImageGenerationsRequest,
+ "human_fidelity",
+ slider=True,
+ step=0.01,
+ ),
+ "model_name": model_field_to_node_input(
+ IO.COMBO,
+ KlingImageGenerationsRequest,
+ "model_name",
+ enum_type=KlingImageGenModelName,
+ ),
+ "aspect_ratio": model_field_to_node_input(
+ IO.COMBO,
+ KlingImageGenerationsRequest,
+ "aspect_ratio",
+ enum_type=KlingImageGenAspectRatio,
+ ),
+ "n": model_field_to_node_input(
+ IO.INT,
+ KlingImageGenerationsRequest,
+ "n",
+ ),
+ },
+ "optional": {
+ "image": (IO.IMAGE, {}),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image."
+
+ def get_response(
+ self,
+ task_id: str,
+ auth_kwargs: Optional[dict[str, str]],
+ node_id: Optional[str] = None,
+ ) -> KlingImageGenerationsResponse:
+ return poll_until_finished(
+ auth_kwargs,
+ ApiEndpoint(
+ path=f"{PATH_IMAGE_GENERATIONS}/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=KlingImageGenerationsResponse,
+ ),
+ result_url_extractor=get_images_urls_from_response,
+ estimated_duration=AVERAGE_DURATION_IMAGE_GEN,
+ node_id=node_id,
+ )
+
+ def api_call(
+ self,
+ model_name: KlingImageGenModelName,
+ prompt: str,
+ negative_prompt: str,
+ image_type: KlingImageGenImageReferenceType,
+ image_fidelity: float,
+ human_fidelity: float,
+ n: int,
+ aspect_ratio: KlingImageGenAspectRatio,
+ image: Optional[torch.Tensor] = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ self.validate_prompt(prompt, negative_prompt)
+
+ if image is not None:
+ image = tensor_to_base64_string(image)
+
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_IMAGE_GENERATIONS,
+ method=HttpMethod.POST,
+ request_model=KlingImageGenerationsRequest,
+ response_model=KlingImageGenerationsResponse,
+ ),
+ request=KlingImageGenerationsRequest(
+ model_name=model_name,
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ image=image,
+ image_reference=image_type,
+ image_fidelity=image_fidelity,
+ human_fidelity=human_fidelity,
+ n=n,
+ aspect_ratio=aspect_ratio,
+ ),
+ auth_kwargs=kwargs,
+ )
+
+ task_creation_response = initial_operation.execute()
+ validate_task_creation_response(task_creation_response)
+ task_id = task_creation_response.data.task_id
+
+ final_response = self.get_response(
+ task_id, auth_kwargs=kwargs, node_id=unique_id
+ )
+ validate_image_result_response(final_response)
+
+ images = get_images_from_response(final_response)
+ return (image_result_to_node_output(images),)
+
+
+NODE_CLASS_MAPPINGS = {
+ "KlingCameraControls": KlingCameraControls,
+ "KlingTextToVideoNode": KlingTextToVideoNode,
+ "KlingImage2VideoNode": KlingImage2VideoNode,
+ "KlingCameraControlI2VNode": KlingCameraControlI2VNode,
+ "KlingCameraControlT2VNode": KlingCameraControlT2VNode,
+ "KlingStartEndFrameNode": KlingStartEndFrameNode,
+ "KlingVideoExtendNode": KlingVideoExtendNode,
+ "KlingLipSyncAudioToVideoNode": KlingLipSyncAudioToVideoNode,
+ "KlingLipSyncTextToVideoNode": KlingLipSyncTextToVideoNode,
+ "KlingVirtualTryOnNode": KlingVirtualTryOnNode,
+ "KlingImageGenerationNode": KlingImageGenerationNode,
+ "KlingSingleImageVideoEffectNode": KlingSingleImageVideoEffectNode,
+ "KlingDualCharacterVideoEffectNode": KlingDualCharacterVideoEffectNode,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "KlingCameraControls": "Kling Camera Controls",
+ "KlingTextToVideoNode": "Kling Text to Video",
+ "KlingImage2VideoNode": "Kling Image to Video",
+ "KlingCameraControlI2VNode": "Kling Image to Video (Camera Control)",
+ "KlingCameraControlT2VNode": "Kling Text to Video (Camera Control)",
+ "KlingStartEndFrameNode": "Kling Start-End Frame to Video",
+ "KlingVideoExtendNode": "Kling Video Extend",
+ "KlingLipSyncAudioToVideoNode": "Kling Lip Sync Video with Audio",
+ "KlingLipSyncTextToVideoNode": "Kling Lip Sync Video with Text",
+ "KlingVirtualTryOnNode": "Kling Virtual Try On",
+ "KlingImageGenerationNode": "Kling Image Generation",
+ "KlingSingleImageVideoEffectNode": "Kling Video Effects",
+ "KlingDualCharacterVideoEffectNode": "Kling Dual Character Video Effects",
+}
diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py
new file mode 100644
index 000000000..525dc38e6
--- /dev/null
+++ b/comfy_api_nodes/nodes_luma.py
@@ -0,0 +1,737 @@
+from __future__ import annotations
+from inspect import cleandoc
+from typing import Optional
+from comfy.comfy_types.node_typing import IO, ComfyNodeABC
+from comfy_api.input_impl.video_types import VideoFromFile
+from comfy_api_nodes.apis.luma_api import (
+ LumaImageModel,
+ LumaVideoModel,
+ LumaVideoOutputResolution,
+ LumaVideoModelOutputDuration,
+ LumaAspectRatio,
+ LumaState,
+ LumaImageGenerationRequest,
+ LumaGenerationRequest,
+ LumaGeneration,
+ LumaCharacterRef,
+ LumaModifyImageRef,
+ LumaImageIdentity,
+ LumaReference,
+ LumaReferenceChain,
+ LumaImageReference,
+ LumaKeyframes,
+ LumaConceptChain,
+ LumaIO,
+ get_luma_concepts,
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ PollingOperation,
+ EmptyRequest,
+)
+from comfy_api_nodes.apinode_utils import (
+ upload_images_to_comfyapi,
+ process_image_response,
+ validate_string,
+)
+from server import PromptServer
+
+import requests
+import torch
+from io import BytesIO
+
+LUMA_T2V_AVERAGE_DURATION = 105
+LUMA_I2V_AVERAGE_DURATION = 100
+
+
+def image_result_url_extractor(response: LumaGeneration):
+    if hasattr(response, "assets") and hasattr(response.assets, "image"):
+        return response.assets.image
+    return None
+
+
+def video_result_url_extractor(response: LumaGeneration):
+    if hasattr(response, "assets") and hasattr(response.assets, "video"):
+        return response.assets.video
+    return None
+
+
+class LumaReferenceNode(ComfyNodeABC):
+ """
+ Holds an image and weight for use with Luma Generate Image node.
+ """
+
+ RETURN_TYPES = (LumaIO.LUMA_REF,)
+ RETURN_NAMES = ("luma_ref",)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "create_luma_reference"
+ CATEGORY = "api node/image/Luma"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (
+ IO.IMAGE,
+ {
+ "tooltip": "Image to use as reference.",
+ },
+ ),
+ "weight": (
+ IO.FLOAT,
+ {
+ "default": 1.0,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ "tooltip": "Weight of image reference.",
+ },
+ ),
+ },
+ "optional": {"luma_ref": (LumaIO.LUMA_REF,)},
+ }
+
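+    # Chaining: passing an existing luma_ref appends this reference to a clone of
+    # the chain, so multiple Luma Reference nodes can be daisy-chained (downstream
+    # generation considers up to 4 reference images).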
+ def create_luma_reference(
+ self, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None
+ ):
+ if luma_ref is not None:
+ luma_ref = luma_ref.clone()
+ else:
+ luma_ref = LumaReferenceChain()
+ luma_ref.add(LumaReference(image=image, weight=round(weight, 2)))
+ return (luma_ref,)
+
+
+class LumaConceptsNode(ComfyNodeABC):
+ """
+ Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.
+ """
+
+ RETURN_TYPES = (LumaIO.LUMA_CONCEPTS,)
+ RETURN_NAMES = ("luma_concepts",)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "create_concepts"
+ CATEGORY = "api node/video/Luma"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "concept1": (get_luma_concepts(include_none=True),),
+ "concept2": (get_luma_concepts(include_none=True),),
+ "concept3": (get_luma_concepts(include_none=True),),
+ "concept4": (get_luma_concepts(include_none=True),),
+ },
+ "optional": {
+ "luma_concepts": (
+ LumaIO.LUMA_CONCEPTS,
+ {
+ "tooltip": "Optional Camera Concepts to add to the ones chosen here."
+ },
+ ),
+ },
+ }
+
+ def create_concepts(
+ self,
+ concept1: str,
+ concept2: str,
+ concept3: str,
+ concept4: str,
+ luma_concepts: LumaConceptChain = None,
+ ):
+ chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4])
+ if luma_concepts is not None:
+ chain = luma_concepts.clone_and_merge(chain)
+ return (chain,)
+
+
+class LumaImageGenerationNode(ComfyNodeABC):
+ """
+ Generates images synchronously based on prompt and aspect ratio.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Luma"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation",
+ },
+ ),
+ "model": ([model.value for model in LumaImageModel],),
+ "aspect_ratio": (
+ [ratio.value for ratio in LumaAspectRatio],
+ {
+ "default": LumaAspectRatio.ratio_16_9,
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
+ },
+ ),
+ "style_image_weight": (
+ IO.FLOAT,
+ {
+ "default": 1.0,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ "tooltip": "Weight of style image. Ignored if no style_image provided.",
+ },
+ ),
+ },
+ "optional": {
+ "image_luma_ref": (
+ LumaIO.LUMA_REF,
+ {
+ "tooltip": "Luma Reference node connection to influence generation with input images; up to 4 images can be considered."
+ },
+ ),
+ "style_image": (
+ IO.IMAGE,
+ {"tooltip": "Style reference image; only 1 image will be used."},
+ ),
+ "character_image": (
+ IO.IMAGE,
+ {
+ "tooltip": "Character reference images; can be a batch of multiple, up to 4 images can be considered."
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def api_call(
+ self,
+ prompt: str,
+ model: str,
+ aspect_ratio: str,
+ seed,
+ style_image_weight: float,
+ image_luma_ref: LumaReferenceChain = None,
+ style_image: torch.Tensor = None,
+ character_image: torch.Tensor = None,
+ unique_id: str = None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=True, min_length=3)
+ # handle image_luma_ref
+ api_image_ref = None
+ if image_luma_ref is not None:
+ api_image_ref = self._convert_luma_refs(
+ image_luma_ref, max_refs=4, auth_kwargs=kwargs,
+ )
+ # handle style_luma_ref
+ api_style_ref = None
+ if style_image is not None:
+ api_style_ref = self._convert_style_image(
+ style_image, weight=style_image_weight, auth_kwargs=kwargs,
+ )
+ # handle character_ref images
+ character_ref = None
+ if character_image is not None:
+ download_urls = upload_images_to_comfyapi(
+ character_image, max_images=4, auth_kwargs=kwargs,
+ )
+ character_ref = LumaCharacterRef(
+ identity0=LumaImageIdentity(images=download_urls)
+ )
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/luma/generations/image",
+ method=HttpMethod.POST,
+ request_model=LumaImageGenerationRequest,
+ response_model=LumaGeneration,
+ ),
+ request=LumaImageGenerationRequest(
+ prompt=prompt,
+ model=model,
+ aspect_ratio=aspect_ratio,
+ image_ref=api_image_ref,
+ style_ref=api_style_ref,
+ character_ref=character_ref,
+ ),
+ auth_kwargs=kwargs,
+ )
+ response_api: LumaGeneration = operation.execute()
+
+ operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=f"/proxy/luma/generations/{response_api.id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=LumaGeneration,
+ ),
+ completed_statuses=[LumaState.completed],
+ failed_statuses=[LumaState.failed],
+ status_extractor=lambda x: x.state,
+ result_url_extractor=image_result_url_extractor,
+ node_id=unique_id,
+ auth_kwargs=kwargs,
+ )
+ response_poll = operation.execute()
+
+ img_response = requests.get(response_poll.assets.image)
+ img = process_image_response(img_response)
+ return (img,)
+
+ def _convert_luma_refs(
+ self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None
+ ):
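+        # Uploads up to max_refs reference images to the Comfy API and builds the
+        # Luma request model from the resulting download URLs.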
+ luma_urls = []
+ ref_count = 0
+ for ref in luma_ref.refs:
+ download_urls = upload_images_to_comfyapi(
+ ref.image, max_images=1, auth_kwargs=auth_kwargs
+ )
+ luma_urls.append(download_urls[0])
+ ref_count += 1
+ if ref_count >= max_refs:
+ break
+ return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs)
+
+ def _convert_style_image(
+ self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None
+ ):
+ chain = LumaReferenceChain(
+ first_ref=LumaReference(image=style_image, weight=weight)
+ )
+ return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs)
+
+
+class LumaImageModifyNode(ComfyNodeABC):
+ """
+    Modifies images synchronously based on a prompt and image weight.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Luma"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE,),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation",
+ },
+ ),
+ "image_weight": (
+ IO.FLOAT,
+ {
+ "default": 0.1,
+ "min": 0.0,
+ "max": 0.98,
+ "step": 0.01,
+ "tooltip": "Weight of the image; the closer to 1.0, the less the image will be modified.",
+ },
+ ),
+ "model": ([model.value for model in LumaImageModel],),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
+ },
+ ),
+ },
+ "optional": {},
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def api_call(
+ self,
+ prompt: str,
+ model: str,
+ image: torch.Tensor,
+ image_weight: float,
+ seed,
+ unique_id: str = None,
+ **kwargs,
+ ):
+ # first, upload image
+ download_urls = upload_images_to_comfyapi(
+ image, max_images=1, auth_kwargs=kwargs,
+ )
+ image_url = download_urls[0]
+ # next, make Luma call with download url provided
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/luma/generations/image",
+ method=HttpMethod.POST,
+ request_model=LumaImageGenerationRequest,
+ response_model=LumaGeneration,
+ ),
+ request=LumaImageGenerationRequest(
+ prompt=prompt,
+ model=model,
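+                # The node exposes image_weight as "how much of the input image to
+                # keep"; the API expects the inverse (1.0 - image_weight), clamped
+                # to [0.0, 0.98].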
+ modify_image_ref=LumaModifyImageRef(
+ url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2)
+ ),
+ ),
+ auth_kwargs=kwargs,
+ )
+ response_api: LumaGeneration = operation.execute()
+
+ operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=f"/proxy/luma/generations/{response_api.id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=LumaGeneration,
+ ),
+ completed_statuses=[LumaState.completed],
+ failed_statuses=[LumaState.failed],
+ status_extractor=lambda x: x.state,
+ result_url_extractor=image_result_url_extractor,
+ node_id=unique_id,
+ auth_kwargs=kwargs,
+ )
+ response_poll = operation.execute()
+
+ img_response = requests.get(response_poll.assets.image)
+ img = process_image_response(img_response)
+ return (img,)
+
+
+class LumaTextToVideoGenerationNode(ComfyNodeABC):
+ """
+    Generates videos synchronously based on a prompt, aspect ratio, resolution, and duration.
+ """
+
+ RETURN_TYPES = (IO.VIDEO,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/video/Luma"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the video generation",
+ },
+ ),
+ "model": ([model.value for model in LumaVideoModel],),
+ "aspect_ratio": (
+ [ratio.value for ratio in LumaAspectRatio],
+ {
+ "default": LumaAspectRatio.ratio_16_9,
+ },
+ ),
+ "resolution": (
+ [resolution.value for resolution in LumaVideoOutputResolution],
+ {
+ "default": LumaVideoOutputResolution.res_540p,
+ },
+ ),
+ "duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
+ "loop": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
+ },
+ ),
+ },
+ "optional": {
+ "luma_concepts": (
+ LumaIO.LUMA_CONCEPTS,
+ {
+ "tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def api_call(
+ self,
+ prompt: str,
+ model: str,
+ aspect_ratio: str,
+ resolution: str,
+ duration: str,
+ loop: bool,
+ seed,
+ luma_concepts: LumaConceptChain = None,
+ unique_id: str = None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=False, min_length=3)
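+        # duration and resolution are not used with the ray-1-6 model, so they are
+        # cleared before building the request.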
+ duration = duration if model != LumaVideoModel.ray_1_6 else None
+ resolution = resolution if model != LumaVideoModel.ray_1_6 else None
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/luma/generations",
+ method=HttpMethod.POST,
+ request_model=LumaGenerationRequest,
+ response_model=LumaGeneration,
+ ),
+ request=LumaGenerationRequest(
+ prompt=prompt,
+ model=model,
+ resolution=resolution,
+ aspect_ratio=aspect_ratio,
+ duration=duration,
+ loop=loop,
+ concepts=luma_concepts.create_api_model() if luma_concepts else None,
+ ),
+ auth_kwargs=kwargs,
+ )
+ response_api: LumaGeneration = operation.execute()
+
+ if unique_id:
+ PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
+
+ operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=f"/proxy/luma/generations/{response_api.id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=LumaGeneration,
+ ),
+ completed_statuses=[LumaState.completed],
+ failed_statuses=[LumaState.failed],
+ status_extractor=lambda x: x.state,
+ result_url_extractor=video_result_url_extractor,
+ node_id=unique_id,
+ estimated_duration=LUMA_T2V_AVERAGE_DURATION,
+ auth_kwargs=kwargs,
+ )
+ response_poll = operation.execute()
+
+ vid_response = requests.get(response_poll.assets.video)
+ return (VideoFromFile(BytesIO(vid_response.content)),)
+
+
+class LumaImageToVideoGenerationNode(ComfyNodeABC):
+ """
+    Generates videos synchronously based on a prompt, input images, resolution, and duration.
+ """
+
+ RETURN_TYPES = (IO.VIDEO,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/video/Luma"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the video generation",
+ },
+ ),
+ "model": ([model.value for model in LumaVideoModel],),
+ # "aspect_ratio": ([ratio.value for ratio in LumaAspectRatio], {
+ # "default": LumaAspectRatio.ratio_16_9,
+ # }),
+ "resolution": (
+ [resolution.value for resolution in LumaVideoOutputResolution],
+ {
+ "default": LumaVideoOutputResolution.res_540p,
+ },
+ ),
+ "duration": ([dur.value for dur in LumaVideoModelOutputDuration],),
+ "loop": (
+ IO.BOOLEAN,
+ {
+ "default": False,
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
+ },
+ ),
+ },
+ "optional": {
+ "first_image": (
+ IO.IMAGE,
+ {"tooltip": "First frame of generated video."},
+ ),
+ "last_image": (IO.IMAGE, {"tooltip": "Last frame of generated video."}),
+ "luma_concepts": (
+ LumaIO.LUMA_CONCEPTS,
+ {
+ "tooltip": "Optional Camera Concepts to dictate camera motion via the Luma Concepts node."
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def api_call(
+ self,
+ prompt: str,
+ model: str,
+ resolution: str,
+ duration: str,
+ loop: bool,
+ seed,
+ first_image: torch.Tensor = None,
+ last_image: torch.Tensor = None,
+ luma_concepts: LumaConceptChain = None,
+ unique_id: str = None,
+ **kwargs,
+ ):
+ if first_image is None and last_image is None:
+            raise Exception(
+                "At least one of first_image and last_image must be provided."
+            )
+ keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs)
+ duration = duration if model != LumaVideoModel.ray_1_6 else None
+ resolution = resolution if model != LumaVideoModel.ray_1_6 else None
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/luma/generations",
+ method=HttpMethod.POST,
+ request_model=LumaGenerationRequest,
+ response_model=LumaGeneration,
+ ),
+ request=LumaGenerationRequest(
+ prompt=prompt,
+ model=model,
+ aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason
+ resolution=resolution,
+ duration=duration,
+ loop=loop,
+ keyframes=keyframes,
+ concepts=luma_concepts.create_api_model() if luma_concepts else None,
+ ),
+ auth_kwargs=kwargs,
+ )
+ response_api: LumaGeneration = operation.execute()
+
+ if unique_id:
+ PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id)
+
+ operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=f"/proxy/luma/generations/{response_api.id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=LumaGeneration,
+ ),
+ completed_statuses=[LumaState.completed],
+ failed_statuses=[LumaState.failed],
+ status_extractor=lambda x: x.state,
+ result_url_extractor=video_result_url_extractor,
+ node_id=unique_id,
+ estimated_duration=LUMA_I2V_AVERAGE_DURATION,
+ auth_kwargs=kwargs,
+ )
+ response_poll = operation.execute()
+
+ vid_response = requests.get(response_poll.assets.video)
+ return (VideoFromFile(BytesIO(vid_response.content)),)
+
+ def _convert_to_keyframes(
+ self,
+ first_image: torch.Tensor = None,
+ last_image: torch.Tensor = None,
+ auth_kwargs: Optional[dict[str,str]] = None,
+ ):
+ if first_image is None and last_image is None:
+ return None
+ frame0 = None
+ frame1 = None
+ if first_image is not None:
+ download_urls = upload_images_to_comfyapi(
+ first_image, max_images=1, auth_kwargs=auth_kwargs,
+ )
+ frame0 = LumaImageReference(type="image", url=download_urls[0])
+ if last_image is not None:
+ download_urls = upload_images_to_comfyapi(
+ last_image, max_images=1, auth_kwargs=auth_kwargs,
+ )
+ frame1 = LumaImageReference(type="image", url=download_urls[0])
+ return LumaKeyframes(frame0=frame0, frame1=frame1)
+
+
+# A dictionary that contains all nodes you want to export with their names
+# NOTE: names should be globally unique
+NODE_CLASS_MAPPINGS = {
+ "LumaImageNode": LumaImageGenerationNode,
+ "LumaImageModifyNode": LumaImageModifyNode,
+ "LumaVideoNode": LumaTextToVideoGenerationNode,
+ "LumaImageToVideoNode": LumaImageToVideoGenerationNode,
+ "LumaReferenceNode": LumaReferenceNode,
+ "LumaConceptsNode": LumaConceptsNode,
+}
+
+# A dictionary that contains the friendly/humanly readable titles for the nodes
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "LumaImageNode": "Luma Text to Image",
+ "LumaImageModifyNode": "Luma Image to Image",
+ "LumaVideoNode": "Luma Text to Video",
+ "LumaImageToVideoNode": "Luma Image to Video",
+ "LumaReferenceNode": "Luma Reference",
+ "LumaConceptsNode": "Luma Concepts",
+}
diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py
new file mode 100644
index 000000000..9b46636db
--- /dev/null
+++ b/comfy_api_nodes/nodes_minimax.py
@@ -0,0 +1,332 @@
+from typing import Union
+import logging
+import torch
+
+from comfy.comfy_types.node_typing import IO
+from comfy_api.input_impl.video_types import VideoFromFile
+from comfy_api_nodes.apis import (
+ MinimaxVideoGenerationRequest,
+ MinimaxVideoGenerationResponse,
+ MinimaxFileRetrieveResponse,
+ MinimaxTaskResultResponse,
+ SubjectReferenceItem,
+ Model
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ PollingOperation,
+ EmptyRequest,
+)
+from comfy_api_nodes.apinode_utils import (
+ download_url_to_bytesio,
+ upload_images_to_comfyapi,
+ validate_string,
+)
+from server import PromptServer
+
+
+I2V_AVERAGE_DURATION = 114
+T2V_AVERAGE_DURATION = 234
+
+class MinimaxTextToVideoNode:
+ """
+    Generates videos synchronously based on a prompt and optional parameters using MiniMax's API.
+ """
+
+ AVERAGE_DURATION = T2V_AVERAGE_DURATION
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt_text": (
+ "STRING",
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Text prompt to guide the video generation",
+ },
+ ),
+ "model": (
+ [
+ "T2V-01",
+ "T2V-01-Director",
+ ],
+ {
+ "default": "T2V-01",
+ "tooltip": "Model to use for video generation",
+ },
+ ),
+ },
+ "optional": {
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("VIDEO",)
+ DESCRIPTION = "Generates videos from prompts using MiniMax's API"
+ FUNCTION = "generate_video"
+ CATEGORY = "api node/video/MiniMax"
+ API_NODE = True
+ OUTPUT_NODE = True
+
+ def generate_video(
+ self,
+ prompt_text,
+ seed=0,
+ model="T2V-01",
+        image: Union[torch.Tensor, None] = None,  # used for ImageToVideo
+        subject: Union[torch.Tensor, None] = None,  # used for SubjectToVideo
+        unique_id: Union[str, None] = None,
+ **kwargs,
+ ):
+        """
+        Shared implementation for the MiniMax video nodes; supports T2V, I2V, and
+        S2V depending on the provided arguments.
+        """
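+        # Flow: submit the generation task, poll /proxy/minimax/query/video_generation
+        # until it reports Success, then resolve the returned file_id to a download
+        # URL and fetch the video bytes.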
+ if image is None:
+ validate_string(prompt_text, field_name="prompt_text")
+ # upload image, if passed in
+ image_url = None
+ if image is not None:
+ image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0]
+
+ # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model
+ subject_reference = None
+ if subject is not None:
+ subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0]
+ subject_reference = [SubjectReferenceItem(image=subject_url)]
+
+ video_generate_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/minimax/video_generation",
+ method=HttpMethod.POST,
+ request_model=MinimaxVideoGenerationRequest,
+ response_model=MinimaxVideoGenerationResponse,
+ ),
+ request=MinimaxVideoGenerationRequest(
+ model=Model(model),
+ prompt=prompt_text,
+ callback_url=None,
+ first_frame_image=image_url,
+ subject_reference=subject_reference,
+ prompt_optimizer=None,
+ ),
+ auth_kwargs=kwargs,
+ )
+ response = video_generate_operation.execute()
+
+ task_id = response.task_id
+ if not task_id:
+ raise Exception(f"MiniMax generation failed: {response.base_resp}")
+
+ video_generate_operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path="/proxy/minimax/query/video_generation",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=MinimaxTaskResultResponse,
+ query_params={"task_id": task_id},
+ ),
+ completed_statuses=["Success"],
+ failed_statuses=["Fail"],
+ status_extractor=lambda x: x.status.value,
+ estimated_duration=self.AVERAGE_DURATION,
+ node_id=unique_id,
+ auth_kwargs=kwargs,
+ )
+ task_result = video_generate_operation.execute()
+
+ file_id = task_result.file_id
+ if file_id is None:
+ raise Exception("Request was not successful. Missing file ID.")
+ file_retrieve_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/minimax/files/retrieve",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=MinimaxFileRetrieveResponse,
+ query_params={"file_id": int(file_id)},
+ ),
+ request=EmptyRequest(),
+ auth_kwargs=kwargs,
+ )
+ file_result = file_retrieve_operation.execute()
+
+ file_url = file_result.file.download_url
+ if file_url is None:
+ raise Exception(
+ f"No video was found in the response. Full response: {file_result.model_dump()}"
+ )
+ logging.info(f"Generated video URL: {file_url}")
+ if unique_id:
+ if hasattr(file_result.file, "backup_download_url"):
+ message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}"
+ else:
+ message = f"Result URL: {file_url}"
+ PromptServer.instance.send_progress_text(message, unique_id)
+
+ video_io = download_url_to_bytesio(file_url)
+ if video_io is None:
+ error_msg = f"Failed to download video from {file_url}"
+ logging.error(error_msg)
+ raise Exception(error_msg)
+ return (VideoFromFile(video_io),)
+
+
+class MinimaxImageToVideoNode(MinimaxTextToVideoNode):
+ """
+    Generates videos synchronously based on an image, a prompt, and optional parameters using MiniMax's API.
+ """
+
+ AVERAGE_DURATION = I2V_AVERAGE_DURATION
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (
+ IO.IMAGE,
+ {
+ "tooltip": "Image to use as first frame of video generation"
+ },
+ ),
+ "prompt_text": (
+ "STRING",
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Text prompt to guide the video generation",
+ },
+ ),
+ "model": (
+ [
+ "I2V-01-Director",
+ "I2V-01",
+ "I2V-01-live",
+ ],
+ {
+ "default": "I2V-01",
+ "tooltip": "Model to use for video generation",
+ },
+ ),
+ },
+ "optional": {
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("VIDEO",)
+ DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
+ FUNCTION = "generate_video"
+ CATEGORY = "api node/video/MiniMax"
+ API_NODE = True
+ OUTPUT_NODE = True
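+
+    # generate_video is inherited from MinimaxTextToVideoNode; the required
+    # "image" input is passed through as the first frame of the generation.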
+
+
+class MinimaxSubjectToVideoNode(MinimaxTextToVideoNode):
+ """
+    Generates videos synchronously based on a subject reference image, a prompt, and optional parameters using MiniMax's API.
+ """
+
+ AVERAGE_DURATION = T2V_AVERAGE_DURATION
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "subject": (
+ IO.IMAGE,
+ {
+ "tooltip": "Image of subject to reference video generation"
+ },
+ ),
+ "prompt_text": (
+ "STRING",
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Text prompt to guide the video generation",
+ },
+ ),
+ "model": (
+ [
+ "S2V-01",
+ ],
+ {
+ "default": "S2V-01",
+ "tooltip": "Model to use for video generation",
+ },
+ ),
+ },
+ "optional": {
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("VIDEO",)
+ DESCRIPTION = "Generates videos from an image and prompts using MiniMax's API"
+ FUNCTION = "generate_video"
+ CATEGORY = "api node/video/MiniMax"
+ API_NODE = True
+ OUTPUT_NODE = True
+
+
+# A dictionary that contains all nodes you want to export with their names
+# NOTE: names should be globally unique
+NODE_CLASS_MAPPINGS = {
+ "MinimaxTextToVideoNode": MinimaxTextToVideoNode,
+ "MinimaxImageToVideoNode": MinimaxImageToVideoNode,
+ # "MinimaxSubjectToVideoNode": MinimaxSubjectToVideoNode,
+}
+
+# A dictionary that contains the friendly/humanly readable titles for the nodes
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "MinimaxTextToVideoNode": "MiniMax Text to Video",
+ "MinimaxImageToVideoNode": "MiniMax Image to Video",
+ "MinimaxSubjectToVideoNode": "MiniMax Subject to Video",
+}
diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py
new file mode 100644
index 000000000..be1d2de4a
--- /dev/null
+++ b/comfy_api_nodes/nodes_openai.py
@@ -0,0 +1,1008 @@
+import io
+from typing import TypedDict, Optional
+import json
+import os
+import time
+import re
+import uuid
+from enum import Enum
+from inspect import cleandoc
+import numpy as np
+import torch
+from PIL import Image
+from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict
+from server import PromptServer
+import folder_paths
+
+
+from comfy_api_nodes.apis import (
+ OpenAIImageGenerationRequest,
+ OpenAIImageEditRequest,
+ OpenAIImageGenerationResponse,
+ OpenAICreateResponse,
+ OpenAIResponse,
+ CreateModelResponseProperties,
+ Item,
+ Includable,
+ OutputContent,
+ InputImageContent,
+ Detail,
+ InputTextContent,
+ InputMessage,
+ InputMessageContentList,
+ InputContent,
+ InputFileContent,
+)
+
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ PollingOperation,
+ EmptyRequest,
+)
+
+from comfy_api_nodes.apinode_utils import (
+ downscale_image_tensor,
+ validate_and_cast_response,
+ validate_string,
+ tensor_to_base64_string,
+ text_filepath_to_data_uri,
+)
+from comfy_api_nodes.mapper_utils import model_field_to_node_input
+
+
+RESPONSES_ENDPOINT = "/proxy/openai/v1/responses"
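+# Matches an optional response-ID tag embedded in the prompt; OpenAIChatNode uses
+# it to branch a conversation from an earlier response or restart it ("start").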
+STARTING_POINT_ID_PATTERN = r"<starting_point_id:(.*)>"
+
+
+class HistoryEntry(TypedDict):
+ """Type definition for a single history entry in the chat."""
+
+ prompt: str
+ response: str
+ response_id: str
+ timestamp: float
+
+
+class ChatHistory(TypedDict):
+ """Type definition for the chat history dictionary."""
+
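+    # Maps each session ID to the ordered list of HistoryEntry items for that session.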
+ __annotations__: dict[str, list[HistoryEntry]]
+
+
+class SupportedOpenAIModel(str, Enum):
+ o4_mini = "o4-mini"
+ o1 = "o1"
+ o3 = "o3"
+ o1_pro = "o1-pro"
+ gpt_4o = "gpt-4o"
+ gpt_4_1 = "gpt-4.1"
+ gpt_4_1_mini = "gpt-4.1-mini"
+ gpt_4_1_nano = "gpt-4.1-nano"
+
+
+class OpenAIDalle2(ComfyNodeABC):
+ """
+ Generates images synchronously via OpenAI's DALL·E 2 endpoint.
+ """
+
+ def __init__(self):
+ pass
+
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypeDict:
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Text prompt for DALL·E",
+ },
+ ),
+ },
+ "optional": {
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2**31 - 1,
+ "step": 1,
+ "display": "number",
+ "control_after_generate": True,
+ "tooltip": "not implemented yet in backend",
+ },
+ ),
+ "size": (
+ IO.COMBO,
+ {
+ "options": ["256x256", "512x512", "1024x1024"],
+ "default": "1024x1024",
+ "tooltip": "Image size",
+ },
+ ),
+ "n": (
+ IO.INT,
+ {
+ "default": 1,
+ "min": 1,
+ "max": 8,
+ "step": 1,
+ "display": "number",
+ "tooltip": "How many images to generate",
+ },
+ ),
+ "image": (
+ IO.IMAGE,
+ {
+ "default": None,
+ "tooltip": "Optional reference image for image editing.",
+ },
+ ),
+ "mask": (
+ IO.MASK,
+ {
+ "default": None,
+ "tooltip": "Optional mask for inpainting (white areas will be replaced)",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.IMAGE,)
+ FUNCTION = "api_call"
+ CATEGORY = "api node/image/OpenAI"
+ DESCRIPTION = cleandoc(__doc__ or "")
+ API_NODE = True
+
+ def api_call(
+ self,
+ prompt,
+ seed=0,
+ image=None,
+ mask=None,
+ n=1,
+ size="1024x1024",
+ unique_id=None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=False)
+ model = "dall-e-2"
+ path = "/proxy/openai/images/generations"
+ content_type = "application/json"
+ request_class = OpenAIImageGenerationRequest
+ img_binary = None
+
+ if image is not None and mask is not None:
+ path = "/proxy/openai/images/edits"
+ content_type = "multipart/form-data"
+ request_class = OpenAIImageEditRequest
+
+ input_tensor = image.squeeze().cpu()
+ height, width, channels = input_tensor.shape
+ rgba_tensor = torch.ones(height, width, 4, device="cpu")
+ rgba_tensor[:, :, :channels] = input_tensor
+
+ if mask.shape[1:] != image.shape[1:-1]:
+ raise Exception("Mask and Image must be the same size")
+ rgba_tensor[:, :, 3] = 1 - mask.squeeze().cpu()
+
+ rgba_tensor = downscale_image_tensor(rgba_tensor.unsqueeze(0)).squeeze()
+
+ image_np = (rgba_tensor.numpy() * 255).astype(np.uint8)
+ img = Image.fromarray(image_np)
+ img_byte_arr = io.BytesIO()
+ img.save(img_byte_arr, format="PNG")
+ img_byte_arr.seek(0)
+            img_binary = img_byte_arr
+ img_binary.name = "image.png"
+ elif image is not None or mask is not None:
+ raise Exception("Dall-E 2 image editing requires an image AND a mask")
+
+ # Build the operation
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=path,
+ method=HttpMethod.POST,
+ request_model=request_class,
+ response_model=OpenAIImageGenerationResponse,
+ ),
+ request=request_class(
+ model=model,
+ prompt=prompt,
+ n=n,
+ size=size,
+ seed=seed,
+ ),
+ files=(
+ {
+ "image": img_binary,
+ }
+ if img_binary
+ else None
+ ),
+ content_type=content_type,
+ auth_kwargs=kwargs,
+ )
+
+ response = operation.execute()
+
+ img_tensor = validate_and_cast_response(response, node_id=unique_id)
+ return (img_tensor,)
+
+
+class OpenAIDalle3(ComfyNodeABC):
+ """
+ Generates images synchronously via OpenAI's DALL·E 3 endpoint.
+ """
+
+ def __init__(self):
+ pass
+
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypeDict:
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Text prompt for DALL·E",
+ },
+ ),
+ },
+ "optional": {
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2**31 - 1,
+ "step": 1,
+ "display": "number",
+ "control_after_generate": True,
+ "tooltip": "not implemented yet in backend",
+ },
+ ),
+ "quality": (
+ IO.COMBO,
+ {
+ "options": ["standard", "hd"],
+ "default": "standard",
+ "tooltip": "Image quality",
+ },
+ ),
+ "style": (
+ IO.COMBO,
+ {
+ "options": ["natural", "vivid"],
+ "default": "natural",
+ "tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.",
+ },
+ ),
+ "size": (
+ IO.COMBO,
+ {
+ "options": ["1024x1024", "1024x1792", "1792x1024"],
+ "default": "1024x1024",
+ "tooltip": "Image size",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.IMAGE,)
+ FUNCTION = "api_call"
+ CATEGORY = "api node/image/OpenAI"
+ DESCRIPTION = cleandoc(__doc__ or "")
+ API_NODE = True
+
+ def api_call(
+ self,
+ prompt,
+ seed=0,
+ style="natural",
+ quality="standard",
+ size="1024x1024",
+ unique_id=None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=False)
+ model = "dall-e-3"
+
+ # build the operation
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/openai/images/generations",
+ method=HttpMethod.POST,
+ request_model=OpenAIImageGenerationRequest,
+ response_model=OpenAIImageGenerationResponse,
+ ),
+ request=OpenAIImageGenerationRequest(
+ model=model,
+ prompt=prompt,
+ quality=quality,
+ size=size,
+ style=style,
+ seed=seed,
+ ),
+ auth_kwargs=kwargs,
+ )
+
+ response = operation.execute()
+
+ img_tensor = validate_and_cast_response(response, node_id=unique_id)
+ return (img_tensor,)
+
+
+class OpenAIGPTImage1(ComfyNodeABC):
+ """
+ Generates images synchronously via OpenAI's GPT Image 1 endpoint.
+ """
+
+ def __init__(self):
+ pass
+
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypeDict:
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Text prompt for GPT Image 1",
+ },
+ ),
+ },
+ "optional": {
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2**31 - 1,
+ "step": 1,
+ "display": "number",
+ "control_after_generate": True,
+ "tooltip": "not implemented yet in backend",
+ },
+ ),
+ "quality": (
+ IO.COMBO,
+ {
+ "options": ["low", "medium", "high"],
+ "default": "low",
+ "tooltip": "Image quality, affects cost and generation time.",
+ },
+ ),
+ "background": (
+ IO.COMBO,
+ {
+ "options": ["opaque", "transparent"],
+ "default": "opaque",
+ "tooltip": "Return image with or without background",
+ },
+ ),
+ "size": (
+ IO.COMBO,
+ {
+ "options": ["auto", "1024x1024", "1024x1536", "1536x1024"],
+ "default": "auto",
+ "tooltip": "Image size",
+ },
+ ),
+ "n": (
+ IO.INT,
+ {
+ "default": 1,
+ "min": 1,
+ "max": 8,
+ "step": 1,
+ "display": "number",
+ "tooltip": "How many images to generate",
+ },
+ ),
+ "image": (
+ IO.IMAGE,
+ {
+ "default": None,
+ "tooltip": "Optional reference image for image editing.",
+ },
+ ),
+ "mask": (
+ IO.MASK,
+ {
+ "default": None,
+ "tooltip": "Optional mask for inpainting (white areas will be replaced)",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.IMAGE,)
+ FUNCTION = "api_call"
+ CATEGORY = "api node/image/OpenAI"
+ DESCRIPTION = cleandoc(__doc__ or "")
+ API_NODE = True
+
+ def api_call(
+ self,
+ prompt,
+ seed=0,
+ quality="low",
+ background="opaque",
+ image=None,
+ mask=None,
+ n=1,
+ size="1024x1024",
+ unique_id=None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=False)
+ model = "gpt-image-1"
+ path = "/proxy/openai/images/generations"
+ content_type = "application/json"
+ request_class = OpenAIImageGenerationRequest
+ img_binaries = []
+ mask_binary = None
+ files = []
+
+ if image is not None:
+ path = "/proxy/openai/images/edits"
+ request_class = OpenAIImageEditRequest
+ content_type = "multipart/form-data"
+
+ batch_size = image.shape[0]
+
+ for i in range(batch_size):
+ single_image = image[i : i + 1]
+ scaled_image = downscale_image_tensor(single_image).squeeze()
+
+ image_np = (scaled_image.numpy() * 255).astype(np.uint8)
+ img = Image.fromarray(image_np)
+ img_byte_arr = io.BytesIO()
+ img.save(img_byte_arr, format="PNG")
+ img_byte_arr.seek(0)
+ img_binary = img_byte_arr
+ img_binary.name = f"image_{i}.png"
+
+ img_binaries.append(img_binary)
+ if batch_size == 1:
+ files.append(("image", img_binary))
+ else:
+ files.append(("image[]", img_binary))
+
+ if mask is not None:
+ if image is None:
+ raise Exception("Cannot use a mask without an input image")
+ if image.shape[0] != 1:
+ raise Exception("Cannot use a mask with multiple image")
+ if mask.shape[1:] != image.shape[1:-1]:
+ raise Exception("Mask and Image must be the same size")
+ batch, height, width = mask.shape
+ rgba_mask = torch.zeros(height, width, 4, device="cpu")
+ rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
+
+ scaled_mask = downscale_image_tensor(rgba_mask.unsqueeze(0)).squeeze()
+
+ mask_np = (scaled_mask.numpy() * 255).astype(np.uint8)
+ mask_img = Image.fromarray(mask_np)
+ mask_img_byte_arr = io.BytesIO()
+ mask_img.save(mask_img_byte_arr, format="PNG")
+ mask_img_byte_arr.seek(0)
+ mask_binary = mask_img_byte_arr
+ mask_binary.name = "mask.png"
+ files.append(("mask", mask_binary))
+
+ # Build the operation
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=path,
+ method=HttpMethod.POST,
+ request_model=request_class,
+ response_model=OpenAIImageGenerationResponse,
+ ),
+ request=request_class(
+ model=model,
+ prompt=prompt,
+ quality=quality,
+ background=background,
+ n=n,
+ seed=seed,
+ size=size,
+ ),
+ files=files if files else None,
+ content_type=content_type,
+ auth_kwargs=kwargs,
+ )
+
+ response = operation.execute()
+
+ img_tensor = validate_and_cast_response(response, node_id=unique_id)
+ return (img_tensor,)
+
+
+class OpenAITextNode(ComfyNodeABC):
+ """
+ Base class for OpenAI text generation nodes.
+ """
+
+ RETURN_TYPES = (IO.STRING,)
+ FUNCTION = "api_call"
+ CATEGORY = "api node/text/OpenAI"
+ API_NODE = True
+
+
+class OpenAIChatNode(OpenAITextNode):
+ """
+ Node to generate text responses from an OpenAI model.
+ """
+
+ def __init__(self) -> None:
+ """Initialize the chat node with a new session ID and empty history."""
+ self.current_session_id: str = str(uuid.uuid4())
+ self.history: dict[str, list[HistoryEntry]] = {}
+ self.previous_response_id: Optional[str] = None
+
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypeDict:
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Text inputs to the model, used to generate a response.",
+ },
+ ),
+ "persist_context": (
+ IO.BOOLEAN,
+ {
+ "default": True,
+ "tooltip": "Persist chat context between calls (multi-turn conversation)",
+ },
+ ),
+ "model": model_field_to_node_input(
+ IO.COMBO,
+ OpenAICreateResponse,
+ "model",
+ enum_type=SupportedOpenAIModel,
+ ),
+ },
+ "optional": {
+ "images": (
+ IO.IMAGE,
+ {
+ "default": None,
+ "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.",
+ },
+ ),
+ "files": (
+ "OPENAI_INPUT_FILES",
+ {
+ "default": None,
+ "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.",
+ },
+ ),
+ "advanced_options": (
+ "OPENAI_CHAT_CONFIG",
+ {
+ "default": None,
+ "tooltip": "Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Generate text responses from an OpenAI model."
+
+ def get_result_response(
+ self,
+ response_id: str,
+ include: Optional[list[Includable]] = None,
+ auth_kwargs: Optional[dict[str, str]] = None,
+ ) -> OpenAIResponse:
+ """
+ Retrieve a model response with the given ID from the OpenAI API.
+
+ Args:
+ response_id (str): The ID of the response to retrieve.
+ include (Optional[List[Includable]]): Additional fields to include
+ in the response. See the `include` parameter for Response
+ creation above for more information.
+
+ """
+ return PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=f"{RESPONSES_ENDPOINT}/{response_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=OpenAIResponse,
+ query_params={"include": include},
+ ),
+ completed_statuses=["completed"],
+ failed_statuses=["failed"],
+ status_extractor=lambda response: response.status,
+ auth_kwargs=auth_kwargs,
+ ).execute()
+
+ def get_message_content_from_response(
+ self, response: OpenAIResponse
+ ) -> list[OutputContent]:
+ """Extract message content from the API response."""
+ for output in response.output:
+ if output.root.type == "message":
+ return output.root.content
+ raise TypeError("No output message found in response")
+
+ def get_text_from_message_content(
+ self, message_content: list[OutputContent]
+ ) -> str:
+ """Extract text content from message content."""
+ for content_item in message_content:
+ if content_item.root.type == "output_text":
+ return str(content_item.root.text)
+ return "No text output found in response"
+
+ def get_history_text(self, session_id: str) -> str:
+ """Convert the entire history for a given session to JSON string."""
+ return json.dumps(self.history[session_id])
+
+ def display_history_on_node(self, session_id: str, node_id: str) -> None:
+ """Display formatted chat history on the node UI."""
+ render_spec = {
+ "node_id": node_id,
+ "component": "ChatHistoryWidget",
+ "props": {
+ "history": self.get_history_text(session_id),
+ },
+ }
+ PromptServer.instance.send_sync(
+ "display_component",
+ render_spec,
+ )
+
+ def add_to_history(
+ self, session_id: str, prompt: str, output_text: str, response_id: str
+ ) -> None:
+ """Add a new entry to the chat history."""
+ if session_id not in self.history:
+ self.history[session_id] = []
+ self.history[session_id].append(
+ {
+ "prompt": prompt,
+ "response": output_text,
+ "response_id": response_id,
+ "timestamp": time.time(),
+ }
+ )
+
+ def parse_output_text_from_response(self, response: OpenAIResponse) -> str:
+ """Extract text output from the API response."""
+ message_contents = self.get_message_content_from_response(response)
+ return self.get_text_from_message_content(message_contents)
+
+ def generate_new_session_id(self) -> str:
+ """Generate a new unique session ID."""
+ return str(uuid.uuid4())
+
+ def get_session_id(self, persist_context: bool) -> str:
+ """Get the current or generate a new session ID based on context persistence."""
+ return (
+ self.current_session_id
+ if persist_context
+ else self.generate_new_session_id()
+ )
+
+ def tensor_to_input_image_content(
+ self, image: torch.Tensor, detail_level: Detail = "auto"
+ ) -> InputImageContent:
+ """Convert a tensor to an input image content object."""
+ return InputImageContent(
+ detail=detail_level,
+ image_url=f"data:image/png;base64,{tensor_to_base64_string(image)}",
+ type="input_image",
+ )
+
+ def create_input_message_contents(
+ self,
+ prompt: str,
+ image: Optional[torch.Tensor] = None,
+ files: Optional[list[InputFileContent]] = None,
+ ) -> InputMessageContentList:
+ """Create a list of input message contents from prompt and optional image."""
+ content_list: list[InputContent] = [
+ InputTextContent(text=prompt, type="input_text"),
+ ]
+ if image is not None:
+ for i in range(image.shape[0]):
+ content_list.append(
+ self.tensor_to_input_image_content(image[i].unsqueeze(0))
+ )
+ if files is not None:
+ content_list.extend(files)
+
+ return InputMessageContentList(
+ root=content_list,
+ )
+
+ def parse_response_id_from_prompt(self, prompt: str) -> Optional[str]:
+ """Extract response ID from prompt if it exists."""
+ parsed_id = re.search(STARTING_POINT_ID_PATTERN, prompt)
+ return parsed_id.group(1) if parsed_id else None
+
+ def strip_response_tag_from_prompt(self, prompt: str) -> str:
+ """Remove the response ID tag from the prompt."""
+ return re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip())
+
+ def delete_history_after_response_id(
+ self, new_start_id: str, session_id: str
+ ) -> None:
+ """Delete history entries after a specific response ID."""
+ if session_id not in self.history:
+ return
+
+ new_history = []
+ i = 0
+ while (
+ i < len(self.history[session_id])
+ and self.history[session_id][i]["response_id"] != new_start_id
+ ):
+ new_history.append(self.history[session_id][i])
+ i += 1
+
+ # Since it's the new starting point (not the response being edited), we include it as well
+ if i < len(self.history[session_id]):
+ new_history.append(self.history[session_id][i])
+
+ self.history[session_id] = new_history
+
+ def api_call(
+ self,
+ prompt: str,
+ persist_context: bool,
+ model: SupportedOpenAIModel,
+ unique_id: Optional[str] = None,
+ images: Optional[torch.Tensor] = None,
+ files: Optional[list[InputFileContent]] = None,
+ advanced_options: Optional[CreateModelResponseProperties] = None,
+ **kwargs,
+ ) -> tuple[str]:
+ # Validate inputs
+ validate_string(prompt, strip_whitespace=False)
+
+ session_id = self.get_session_id(persist_context)
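+        # A response-ID tag embedded in the prompt (see STARTING_POINT_ID_PATTERN)
+        # lets the user branch the conversation from an earlier response; the
+        # special value "start" clears the session history entirely.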
+ response_id_override = self.parse_response_id_from_prompt(prompt)
+ if response_id_override:
+ is_starting_from_beginning = response_id_override == "start"
+ if is_starting_from_beginning:
+ self.history[session_id] = []
+ previous_response_id = None
+ else:
+ previous_response_id = response_id_override
+ self.delete_history_after_response_id(response_id_override, session_id)
+ prompt = self.strip_response_tag_from_prompt(prompt)
+ elif persist_context:
+ previous_response_id = self.previous_response_id
+ else:
+ previous_response_id = None
+
+ # Create response
+ create_response = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=RESPONSES_ENDPOINT,
+ method=HttpMethod.POST,
+ request_model=OpenAICreateResponse,
+ response_model=OpenAIResponse,
+ ),
+ request=OpenAICreateResponse(
+ input=[
+ Item(
+ root=InputMessage(
+ content=self.create_input_message_contents(
+ prompt, images, files
+ ),
+ role="user",
+ )
+ ),
+ ],
+ store=True,
+ stream=False,
+ model=model,
+ previous_response_id=previous_response_id,
+ **(
+ advanced_options.model_dump(exclude_none=True)
+ if advanced_options
+ else {}
+ ),
+ ),
+ auth_kwargs=kwargs,
+ ).execute()
+ response_id = create_response.id
+
+ # Get result output
+ result_response = self.get_result_response(response_id, auth_kwargs=kwargs)
+ output_text = self.parse_output_text_from_response(result_response)
+
+ # Update history
+ self.add_to_history(session_id, prompt, output_text, response_id)
+ self.display_history_on_node(session_id, unique_id)
+ self.previous_response_id = response_id
+
+ return (output_text,)
+
+
+class OpenAIInputFiles(ComfyNodeABC):
+ """
+ Loads and formats input files for OpenAI API.
+ """
+
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypeDict:
+ """
+ For details about the supported file input types, see:
+ https://platform.openai.com/docs/guides/pdf-files?api-mode=responses
+ """
+ input_dir = folder_paths.get_input_directory()
+ input_files = [
+ f
+ for f in os.scandir(input_dir)
+ if f.is_file()
+ and (f.name.endswith(".txt") or f.name.endswith(".pdf"))
+ and f.stat().st_size < 32 * 1024 * 1024
+ ]
+ input_files = sorted(input_files, key=lambda x: x.name)
+ input_files = [f.name for f in input_files]
+ return {
+ "required": {
+ "file": (
+ IO.COMBO,
+ {
+ "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.",
+ "options": input_files,
+ "default": input_files[0] if input_files else None,
+ },
+ ),
+ },
+ "optional": {
+ "OPENAI_INPUT_FILES": (
+ "OPENAI_INPUT_FILES",
+ {
+ "tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.",
+ "default": None,
+ },
+ ),
+ },
+ }
+
+ DESCRIPTION = "Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes."
+ RETURN_TYPES = ("OPENAI_INPUT_FILES",)
+ FUNCTION = "prepare_files"
+ CATEGORY = "api node/text/OpenAI"
+
+ def create_input_file_content(self, file_path: str) -> InputFileContent:
+ return InputFileContent(
+ file_data=text_filepath_to_data_uri(file_path),
+ filename=os.path.basename(file_path),
+ type="input_file",
+ )
+
+ def prepare_files(
+        self, file: str, OPENAI_INPUT_FILES: Optional[list[InputFileContent]] = None
+ ) -> tuple[list[InputFileContent]]:
+ """
+ Loads and formats input files for OpenAI API.
+ """
+ file_path = folder_paths.get_annotated_filepath(file)
+ input_file_content = self.create_input_file_content(file_path)
+        files = [input_file_content] + (OPENAI_INPUT_FILES or [])
+ return (files,)
+
+
+class OpenAIChatConfig(ComfyNodeABC):
+ """Allows setting additional configuration for the OpenAI Chat Node."""
+
+ RETURN_TYPES = ("OPENAI_CHAT_CONFIG",)
+ FUNCTION = "configure"
+ DESCRIPTION = (
+ "Allows specifying advanced configuration options for the OpenAI Chat Nodes."
+ )
+ CATEGORY = "api node/text/OpenAI"
+
+ @classmethod
+ def INPUT_TYPES(cls) -> InputTypeDict:
+ return {
+ "required": {
+ "truncation": (
+ IO.COMBO,
+ {
+ "options": ["auto", "disabled"],
+ "default": "auto",
+ "tooltip": "The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error",
+ },
+ ),
+ },
+ "optional": {
+ "max_output_tokens": model_field_to_node_input(
+ IO.INT,
+ OpenAICreateResponse,
+ "max_output_tokens",
+ min=16,
+ default=4096,
+ max=16384,
+ tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens",
+ ),
+ "instructions": model_field_to_node_input(
+ IO.STRING, OpenAICreateResponse, "instructions", multiline=True
+ ),
+ },
+ }
+
+ def configure(
+ self,
+        truncation: str,
+ instructions: Optional[str] = None,
+ max_output_tokens: Optional[int] = None,
+ ) -> tuple[CreateModelResponseProperties]:
+ """
+ Configure advanced options for the OpenAI Chat Node.
+
+ Note:
+ While `top_p` and `temperature` are listed as properties in the
+ spec, they are not supported for all models (e.g., o4-mini).
+ They are not exposed as inputs at all to avoid having to manually
+ remove depending on model choice.
+ """
+ return (
+ CreateModelResponseProperties(
+ instructions=instructions,
+ truncation=truncation,
+ max_output_tokens=max_output_tokens,
+ ),
+ )
+
+
+NODE_CLASS_MAPPINGS = {
+ "OpenAIDalle2": OpenAIDalle2,
+ "OpenAIDalle3": OpenAIDalle3,
+ "OpenAIGPTImage1": OpenAIGPTImage1,
+ "OpenAIChatNode": OpenAIChatNode,
+ "OpenAIInputFiles": OpenAIInputFiles,
+ "OpenAIChatConfig": OpenAIChatConfig,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "OpenAIDalle2": "OpenAI DALL·E 2",
+ "OpenAIDalle3": "OpenAI DALL·E 3",
+ "OpenAIGPTImage1": "OpenAI GPT Image 1",
+ "OpenAIChatNode": "OpenAI Chat",
+ "OpenAIInputFiles": "OpenAI Chat Input Files",
+ "OpenAIChatConfig": "OpenAI Chat Advanced Options",
+}
diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py
new file mode 100644
index 000000000..30562790a
--- /dev/null
+++ b/comfy_api_nodes/nodes_pika.py
@@ -0,0 +1,779 @@
+"""
+Pika x ComfyUI API Nodes
+
+Pika API docs: https://pika-827374fb.mintlify.app/api-reference
+"""
+from __future__ import annotations
+
+import io
+from typing import Optional, TypeVar
+import logging
+import torch
+import numpy as np
+from PIL import Image
+from comfy_api_nodes.apis import (
+ PikaBodyGenerate22T2vGenerate22T2vPost,
+ PikaGenerateResponse,
+ PikaBodyGenerate22I2vGenerate22I2vPost,
+ PikaVideoResponse,
+ PikaBodyGenerate22C2vGenerate22PikascenesPost,
+ IngredientsMode,
+ PikaDurationEnum,
+ PikaResolutionEnum,
+ PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
+ PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
+ PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
+ PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
+ Pikaffect,
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ PollingOperation,
+ EmptyRequest,
+)
+from comfy_api_nodes.apinode_utils import (
+ tensor_to_bytesio,
+ download_url_to_video_output,
+)
+from comfy_api_nodes.mapper_utils import model_field_to_node_input
+from comfy_api.input_impl.video_types import VideoInput, VideoContainer, VideoCodec
+from comfy_api.input_impl import VideoFromFile
+from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeOptions
+
+R = TypeVar("R")
+
+PATH_PIKADDITIONS = "/proxy/pika/generate/pikadditions"
+PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps"
+PATH_PIKAFFECTS = "/proxy/pika/generate/pikaffects"
+
+PIKA_API_VERSION = "2.2"
+PATH_TEXT_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/t2v"
+PATH_IMAGE_TO_VIDEO = f"/proxy/pika/generate/{PIKA_API_VERSION}/i2v"
+PATH_PIKAFRAMES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikaframes"
+PATH_PIKASCENES = f"/proxy/pika/generate/{PIKA_API_VERSION}/pikascenes"
+
+PATH_VIDEO_GET = "/proxy/pika/videos"
+
+
+class PikaApiError(Exception):
+ """Exception for Pika API errors."""
+
+ pass
+
+
+def is_valid_video_response(response: PikaVideoResponse) -> bool:
+ """Check if the video response is valid."""
+ return hasattr(response, "url") and response.url is not None
+
+
+def is_valid_initial_response(response: PikaGenerateResponse) -> bool:
+ """Check if the initial response is valid."""
+ return hasattr(response, "video_id") and response.video_id is not None
+
+
+class PikaNodeBase(ComfyNodeABC):
+ """Base class for Pika nodes."""
+
+ @classmethod
+ def get_base_inputs_types(
+ cls, request_model
+ ) -> dict[str, tuple[IO, InputTypeOptions]]:
+ """Get the base required inputs types common to all Pika nodes."""
+ return {
+ "prompt_text": model_field_to_node_input(
+ IO.STRING,
+ request_model,
+ "promptText",
+ multiline=True,
+ ),
+ "negative_prompt": model_field_to_node_input(
+ IO.STRING,
+ request_model,
+ "negativePrompt",
+ multiline=True,
+ ),
+ "seed": model_field_to_node_input(
+ IO.INT,
+ request_model,
+ "seed",
+ min=0,
+ max=0xFFFFFFFF,
+ control_after_generate=True,
+ ),
+ "resolution": model_field_to_node_input(
+ IO.COMBO,
+ request_model,
+ "resolution",
+ enum_type=PikaResolutionEnum,
+ ),
+ "duration": model_field_to_node_input(
+ IO.COMBO,
+ request_model,
+ "duration",
+ enum_type=PikaDurationEnum,
+ ),
+ }
+
+ CATEGORY = "api node/video/Pika"
+ API_NODE = True
+ FUNCTION = "api_call"
+ RETURN_TYPES = ("VIDEO",)
+
+ def poll_for_task_status(
+ self,
+ task_id: str,
+ auth_kwargs: Optional[dict[str, str]] = None,
+ node_id: Optional[str] = None,
+    ) -> PikaVideoResponse:
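+        """Poll the Pika API for the video task until it finishes, fails, or is cancelled."""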
+ polling_operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=f"{PATH_VIDEO_GET}/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=PikaVideoResponse,
+ ),
+ completed_statuses=[
+ "finished",
+ ],
+ failed_statuses=["failed", "cancelled"],
+ status_extractor=lambda response: (
+ response.status.value if response.status else None
+ ),
+ progress_extractor=lambda response: (
+ response.progress if hasattr(response, "progress") else None
+ ),
+ auth_kwargs=auth_kwargs,
+ result_url_extractor=lambda response: (
+ response.url if hasattr(response, "url") else None
+ ),
+ node_id=node_id,
+ estimated_duration=60
+ )
+ return polling_operation.execute()
+
+ def execute_task(
+ self,
+ initial_operation: SynchronousOperation[R, PikaGenerateResponse],
+ auth_kwargs: Optional[dict[str, str]] = None,
+ node_id: Optional[str] = None,
+ ) -> tuple[VideoFromFile]:
+ """Executes the initial operation then polls for the task status until it is completed.
+
+ Args:
+ initial_operation: The initial operation to execute.
+            auth_kwargs: The authentication token(s) to use for the API call.
+            node_id: Optional unique ID of the node, used for progress reporting in the UI.
+
+ Returns:
+ A tuple containing the video file as a VIDEO output.
+ """
+ initial_response = initial_operation.execute()
+ if not is_valid_initial_response(initial_response):
+ error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}"
+ logging.error(error_msg)
+ raise PikaApiError(error_msg)
+
+ task_id = initial_response.video_id
+        final_response = self.poll_for_task_status(task_id, auth_kwargs, node_id=node_id)
+ if not is_valid_video_response(final_response):
+ error_msg = (
+ f"Pika task {task_id} succeeded but no video data found in response."
+ )
+ logging.error(error_msg)
+ raise PikaApiError(error_msg)
+
+ video_url = str(final_response.url)
+ logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url)
+
+ return (download_url_to_video_output(video_url),)
+
+
+class PikaImageToVideoV2_2(PikaNodeBase):
+ """Pika 2.2 Image to Video Node."""
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "image": (
+ IO.IMAGE,
+ {"tooltip": "The image to convert to video"},
+ ),
+ **cls.get_base_inputs_types(PikaBodyGenerate22I2vGenerate22I2vPost),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video."
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ prompt_text: str,
+ negative_prompt: str,
+ seed: int,
+ resolution: str,
+ duration: int,
+ unique_id: str,
+ **kwargs,
+ ) -> tuple[VideoFromFile]:
+ # Convert image to BytesIO
+ image_bytes_io = tensor_to_bytesio(image)
+ image_bytes_io.seek(0)
+
+ pika_files = {"image": ("image.png", image_bytes_io, "image/png")}
+
+ # Prepare non-file data
+ pika_request_data = PikaBodyGenerate22I2vGenerate22I2vPost(
+ promptText=prompt_text,
+ negativePrompt=negative_prompt,
+ seed=seed,
+ resolution=resolution,
+ duration=duration,
+ )
+
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_IMAGE_TO_VIDEO,
+ method=HttpMethod.POST,
+ request_model=PikaBodyGenerate22I2vGenerate22I2vPost,
+ response_model=PikaGenerateResponse,
+ ),
+ request=pika_request_data,
+ files=pika_files,
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+
+ return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
+
+
+class PikaTextToVideoNodeV2_2(PikaNodeBase):
+ """Pika Text2Video v2.2 Node."""
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ **cls.get_base_inputs_types(PikaBodyGenerate22T2vGenerate22T2vPost),
+ "aspect_ratio": model_field_to_node_input(
+ IO.FLOAT,
+ PikaBodyGenerate22T2vGenerate22T2vPost,
+ "aspectRatio",
+ step=0.001,
+ min=0.4,
+ max=2.5,
+ default=1.7777777777777777,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video."
+
+ def api_call(
+ self,
+ prompt_text: str,
+ negative_prompt: str,
+ seed: int,
+ resolution: str,
+ duration: int,
+ aspect_ratio: float,
+ unique_id: str,
+ **kwargs,
+ ) -> tuple[VideoFromFile]:
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_TEXT_TO_VIDEO,
+ method=HttpMethod.POST,
+ request_model=PikaBodyGenerate22T2vGenerate22T2vPost,
+ response_model=PikaGenerateResponse,
+ ),
+ request=PikaBodyGenerate22T2vGenerate22T2vPost(
+ promptText=prompt_text,
+ negativePrompt=negative_prompt,
+ seed=seed,
+ resolution=resolution,
+ duration=duration,
+ aspectRatio=aspect_ratio,
+ ),
+ auth_kwargs=kwargs,
+ content_type="application/x-www-form-urlencoded",
+ )
+
+ return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
+
+
+class PikaScenesV2_2(PikaNodeBase):
+ """PikaScenes v2.2 Node."""
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ image_ingredient_input = (
+ IO.IMAGE,
+ {"tooltip": "Image that will be used as ingredient to create a video."},
+ )
+ return {
+ "required": {
+ **cls.get_base_inputs_types(
+ PikaBodyGenerate22C2vGenerate22PikascenesPost,
+ ),
+ "ingredients_mode": model_field_to_node_input(
+ IO.COMBO,
+ PikaBodyGenerate22C2vGenerate22PikascenesPost,
+ "ingredientsMode",
+ enum_type=IngredientsMode,
+ default="creative",
+ ),
+ "aspect_ratio": model_field_to_node_input(
+ IO.FLOAT,
+ PikaBodyGenerate22C2vGenerate22PikascenesPost,
+ "aspectRatio",
+ step=0.001,
+ min=0.4,
+ max=2.5,
+ default=1.7777777777777777,
+ ),
+ },
+ "optional": {
+ "image_ingredient_1": image_ingredient_input,
+ "image_ingredient_2": image_ingredient_input,
+ "image_ingredient_3": image_ingredient_input,
+ "image_ingredient_4": image_ingredient_input,
+ "image_ingredient_5": image_ingredient_input,
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them."
+
+ def api_call(
+ self,
+ prompt_text: str,
+ negative_prompt: str,
+ seed: int,
+ resolution: str,
+ duration: int,
+ ingredients_mode: str,
+ aspect_ratio: float,
+ unique_id: str,
+ image_ingredient_1: Optional[torch.Tensor] = None,
+ image_ingredient_2: Optional[torch.Tensor] = None,
+ image_ingredient_3: Optional[torch.Tensor] = None,
+ image_ingredient_4: Optional[torch.Tensor] = None,
+ image_ingredient_5: Optional[torch.Tensor] = None,
+ **kwargs,
+ ) -> tuple[VideoFromFile]:
+ # Convert all passed images to BytesIO
+ all_image_bytes_io = []
+ for image in [
+ image_ingredient_1,
+ image_ingredient_2,
+ image_ingredient_3,
+ image_ingredient_4,
+ image_ingredient_5,
+ ]:
+ if image is not None:
+ image_bytes_io = tensor_to_bytesio(image)
+ image_bytes_io.seek(0)
+ all_image_bytes_io.append(image_bytes_io)
+
+ pika_files = [
+ ("images", (f"image_{i}.png", image_bytes_io, "image/png"))
+ for i, image_bytes_io in enumerate(all_image_bytes_io)
+ ]
+
+ pika_request_data = PikaBodyGenerate22C2vGenerate22PikascenesPost(
+ ingredientsMode=ingredients_mode,
+ promptText=prompt_text,
+ negativePrompt=negative_prompt,
+ seed=seed,
+ resolution=resolution,
+ duration=duration,
+ aspectRatio=aspect_ratio,
+ )
+
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_PIKASCENES,
+ method=HttpMethod.POST,
+ request_model=PikaBodyGenerate22C2vGenerate22PikascenesPost,
+ response_model=PikaGenerateResponse,
+ ),
+ request=pika_request_data,
+ files=pika_files,
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+
+ return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
+
+
+class PikAdditionsNode(PikaNodeBase):
+ """Pika Pikadditions Node. Add an image into a video."""
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "video": (IO.VIDEO, {"tooltip": "The video to add an image to."}),
+ "image": (IO.IMAGE, {"tooltip": "The image to add to the video."}),
+ "prompt_text": model_field_to_node_input(
+ IO.STRING,
+ PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
+ "promptText",
+ multiline=True,
+ ),
+ "negative_prompt": model_field_to_node_input(
+ IO.STRING,
+ PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
+ "negativePrompt",
+ multiline=True,
+ ),
+ "seed": model_field_to_node_input(
+ IO.INT,
+ PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
+ "seed",
+ min=0,
+ max=0xFFFFFFFF,
+ control_after_generate=True,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you’d like to add to create a seamlessly integrated result."
+
+ def api_call(
+ self,
+ video: VideoInput,
+ image: torch.Tensor,
+ prompt_text: str,
+ negative_prompt: str,
+ seed: int,
+ unique_id: str,
+ **kwargs,
+ ) -> tuple[VideoFromFile]:
+ # Convert video to BytesIO
+ video_bytes_io = io.BytesIO()
+ video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
+ video_bytes_io.seek(0)
+
+ # Convert image to BytesIO
+ image_bytes_io = tensor_to_bytesio(image)
+ image_bytes_io.seek(0)
+
+ pika_files = [
+ ("video", ("video.mp4", video_bytes_io, "video/mp4")),
+ ("image", ("image.png", image_bytes_io, "image/png")),
+ ]
+
+ # Prepare non-file data
+ pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost(
+ promptText=prompt_text,
+ negativePrompt=negative_prompt,
+ seed=seed,
+ )
+
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_PIKADDITIONS,
+ method=HttpMethod.POST,
+ request_model=PikaBodyGeneratePikadditionsGeneratePikadditionsPost,
+ response_model=PikaGenerateResponse,
+ ),
+ request=pika_request_data,
+ files=pika_files,
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+
+ return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
+
+
+class PikaSwapsNode(PikaNodeBase):
+ """Pika Pikaswaps Node."""
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "video": (IO.VIDEO, {"tooltip": "The video to swap an object in."}),
+ "image": (
+ IO.IMAGE,
+ {
+ "tooltip": "The image used to replace the masked object in the video."
+ },
+ ),
+ "mask": (
+ IO.MASK,
+ {"tooltip": "Use the mask to define areas in the video to replace"},
+ ),
+ "prompt_text": model_field_to_node_input(
+ IO.STRING,
+ PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
+ "promptText",
+ multiline=True,
+ ),
+ "negative_prompt": model_field_to_node_input(
+ IO.STRING,
+ PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
+ "negativePrompt",
+ multiline=True,
+ ),
+ "seed": model_field_to_node_input(
+ IO.INT,
+ PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
+ "seed",
+ min=0,
+ max=0xFFFFFFFF,
+ control_after_generate=True,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates."
+ RETURN_TYPES = ("VIDEO",)
+
+ def api_call(
+ self,
+ video: VideoInput,
+ image: torch.Tensor,
+ mask: torch.Tensor,
+ prompt_text: str,
+ negative_prompt: str,
+ seed: int,
+ unique_id: str,
+ **kwargs,
+ ) -> tuple[VideoFromFile]:
+ # Convert video to BytesIO
+ video_bytes_io = io.BytesIO()
+ video.save_to(video_bytes_io, format=VideoContainer.MP4, codec=VideoCodec.H264)
+ video_bytes_io.seek(0)
+
+        # Convert the ComfyUI mask (batch, height, width) to a binary image and
+        # encode it as an actual PNG; writing raw tensor bytes would not produce
+        # a valid image file for the multipart upload below.
+        mask = torch.round(mask.squeeze(0).cpu())
+        mask_np = (mask.numpy() * 255).astype(np.uint8)
+        mask_img = Image.fromarray(mask_np).convert("RGB")
+        mask_bytes_io = io.BytesIO()
+        mask_img.save(mask_bytes_io, format="PNG")
+        mask_bytes_io.seek(0)
+
+ # Convert image to BytesIO
+ image_bytes_io = tensor_to_bytesio(image)
+ image_bytes_io.seek(0)
+
+ pika_files = [
+ ("video", ("video.mp4", video_bytes_io, "video/mp4")),
+ ("image", ("image.png", image_bytes_io, "image/png")),
+ ("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")),
+ ]
+
+ # Prepare non-file data
+ pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost(
+ promptText=prompt_text,
+ negativePrompt=negative_prompt,
+ seed=seed,
+ )
+
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+                path=PATH_PIKASWAPS,
+                method=HttpMethod.POST,
+                request_model=PikaBodyGeneratePikaswapsGeneratePikaswapsPost,
+ response_model=PikaGenerateResponse,
+ ),
+ request=pika_request_data,
+ files=pika_files,
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+
+ return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
+
+
+class PikaffectsNode(PikaNodeBase):
+ """Pika Pikaffects Node."""
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "image": (
+ IO.IMAGE,
+ {"tooltip": "The reference image to apply the Pikaffect to."},
+ ),
+ "pikaffect": model_field_to_node_input(
+ IO.COMBO,
+ PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
+ "pikaffect",
+ enum_type=Pikaffect,
+ default="Cake-ify",
+ ),
+ "prompt_text": model_field_to_node_input(
+ IO.STRING,
+ PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
+ "promptText",
+ multiline=True,
+ ),
+ "negative_prompt": model_field_to_node_input(
+ IO.STRING,
+ PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
+ "negativePrompt",
+ multiline=True,
+ ),
+ "seed": model_field_to_node_input(
+ IO.INT,
+ PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
+ "seed",
+ min=0,
+ max=0xFFFFFFFF,
+ control_after_generate=True,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear"
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ pikaffect: str,
+ prompt_text: str,
+ negative_prompt: str,
+ seed: int,
+ unique_id: str,
+ **kwargs,
+ ) -> tuple[VideoFromFile]:
+
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_PIKAFFECTS,
+ method=HttpMethod.POST,
+ request_model=PikaBodyGeneratePikaffectsGeneratePikaffectsPost,
+ response_model=PikaGenerateResponse,
+ ),
+ request=PikaBodyGeneratePikaffectsGeneratePikaffectsPost(
+ pikaffect=pikaffect,
+ promptText=prompt_text,
+ negativePrompt=negative_prompt,
+ seed=seed,
+ ),
+ files={"image": ("image.png", tensor_to_bytesio(image), "image/png")},
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+
+ return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
+
+
+class PikaStartEndFrameNode2_2(PikaNodeBase):
+ """PikaFrames v2.2 Node."""
+
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "image_start": (IO.IMAGE, {"tooltip": "The first image to combine."}),
+ "image_end": (IO.IMAGE, {"tooltip": "The last image to combine."}),
+ **cls.get_base_inputs_types(
+ PikaBodyGenerate22KeyframeGenerate22PikaframesPost
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them."
+
+ def api_call(
+ self,
+ image_start: torch.Tensor,
+ image_end: torch.Tensor,
+ prompt_text: str,
+ negative_prompt: str,
+ seed: int,
+ resolution: str,
+ duration: int,
+ unique_id: str,
+ **kwargs,
+ ) -> tuple[VideoFromFile]:
+
+ pika_files = [
+ (
+ "keyFrames",
+ ("image_start.png", tensor_to_bytesio(image_start), "image/png"),
+ ),
+ ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")),
+ ]
+
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_PIKAFRAMES,
+ method=HttpMethod.POST,
+ request_model=PikaBodyGenerate22KeyframeGenerate22PikaframesPost,
+ response_model=PikaGenerateResponse,
+ ),
+ request=PikaBodyGenerate22KeyframeGenerate22PikaframesPost(
+ promptText=prompt_text,
+ negativePrompt=negative_prompt,
+ seed=seed,
+ resolution=resolution,
+ duration=duration,
+ ),
+ files=pika_files,
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+
+ return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id)
+
+
+NODE_CLASS_MAPPINGS = {
+ "PikaImageToVideoNode2_2": PikaImageToVideoV2_2,
+ "PikaTextToVideoNode2_2": PikaTextToVideoNodeV2_2,
+ "PikaScenesV2_2": PikaScenesV2_2,
+ "Pikadditions": PikAdditionsNode,
+ "Pikaswaps": PikaSwapsNode,
+ "Pikaffects": PikaffectsNode,
+ "PikaStartEndFrameNode2_2": PikaStartEndFrameNode2_2,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "PikaImageToVideoNode2_2": "Pika Image to Video",
+ "PikaTextToVideoNode2_2": "Pika Text to Video",
+ "PikaScenesV2_2": "Pika Scenes (Video Image Composition)",
+ "Pikadditions": "Pikadditions (Video Object Insertion)",
+ "Pikaswaps": "Pika Swaps (Video Object Replacement)",
+ "Pikaffects": "Pikaffects (Video Effects)",
+ "PikaStartEndFrameNode2_2": "Pika Start and End Frame to Video",
+}
diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py
new file mode 100644
index 000000000..ef4a9a802
--- /dev/null
+++ b/comfy_api_nodes/nodes_pixverse.py
@@ -0,0 +1,525 @@
+from inspect import cleandoc
+from typing import Optional
+from comfy_api_nodes.apis.pixverse_api import (
+ PixverseTextVideoRequest,
+ PixverseImageVideoRequest,
+ PixverseTransitionVideoRequest,
+ PixverseImageUploadResponse,
+ PixverseVideoResponse,
+ PixverseGenerationStatusResponse,
+ PixverseAspectRatio,
+ PixverseQuality,
+ PixverseDuration,
+ PixverseMotionMode,
+ PixverseStatus,
+ PixverseIO,
+ pixverse_templates,
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ PollingOperation,
+ EmptyRequest,
+)
+from comfy_api_nodes.apinode_utils import (
+ tensor_to_bytesio,
+ validate_string,
+)
+from comfy.comfy_types.node_typing import IO, ComfyNodeABC
+from comfy_api.input_impl import VideoFromFile
+
+import torch
+import requests
+from io import BytesIO
+
+
+AVERAGE_DURATION_T2V = 32
+AVERAGE_DURATION_I2V = 30
+AVERAGE_DURATION_T2T = 52
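+# Rough average completion times (seconds), used as polling progress estimates.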
+
+
+def get_video_url_from_response(
+ response: PixverseGenerationStatusResponse,
+) -> Optional[str]:
+ if response.Resp is None or response.Resp.url is None:
+ return None
+ return str(response.Resp.url)
+
+
+def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None):
+ # first, upload image to Pixverse and get image id to use in actual generation call
+ files = {"image": tensor_to_bytesio(image)}
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/pixverse/image/upload",
+ method=HttpMethod.POST,
+ request_model=EmptyRequest,
+ response_model=PixverseImageUploadResponse,
+ ),
+ request=EmptyRequest(),
+ files=files,
+ content_type="multipart/form-data",
+ auth_kwargs=auth_kwargs,
+ )
+ response_upload: PixverseImageUploadResponse = operation.execute()
+
+ if response_upload.Resp is None:
+ raise Exception(
+ f"PixVerse image upload request failed: '{response_upload.ErrMsg}'"
+ )
+
+ return response_upload.Resp.img_id
+
+
+class PixverseTemplateNode:
+ """
+ Select template for PixVerse Video generation.
+ """
+
+ RETURN_TYPES = (PixverseIO.TEMPLATE,)
+ RETURN_NAMES = ("pixverse_template",)
+ FUNCTION = "create_template"
+ CATEGORY = "api node/video/PixVerse"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "template": (list(pixverse_templates.keys()),),
+ }
+ }
+
+ def create_template(self, template: str):
+ template_id = pixverse_templates.get(template, None)
+ if template_id is None:
+ raise Exception(f"Template '{template}' is not recognized.")
+ # just return the integer
+ return (template_id,)
+
+
+class PixverseTextToVideoNode(ComfyNodeABC):
+ """
+    Generates videos synchronously from a text prompt via PixVerse's API.
+ """
+
+ RETURN_TYPES = (IO.VIDEO,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/video/PixVerse"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the video generation",
+ },
+ ),
+ "aspect_ratio": ([ratio.value for ratio in PixverseAspectRatio],),
+ "quality": (
+ [resolution.value for resolution in PixverseQuality],
+ {
+ "default": PixverseQuality.res_540p,
+ },
+ ),
+ "duration_seconds": ([dur.value for dur in PixverseDuration],),
+ "motion_mode": ([mode.value for mode in PixverseMotionMode],),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2147483647,
+ "control_after_generate": True,
+ "tooltip": "Seed for video generation.",
+ },
+ ),
+ },
+ "optional": {
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "An optional text description of undesired elements on an image.",
+ },
+ ),
+ "pixverse_template": (
+ PixverseIO.TEMPLATE,
+ {
+ "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def api_call(
+ self,
+ prompt: str,
+ aspect_ratio: str,
+ quality: str,
+ duration_seconds: int,
+ motion_mode: str,
+        seed: int,
+        negative_prompt: Optional[str] = None,
+        pixverse_template: Optional[int] = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=False)
+ # 1080p is limited to 5 seconds duration
+ # only normal motion_mode supported for 1080p or for non-5 second duration
+ if quality == PixverseQuality.res_1080p:
+ motion_mode = PixverseMotionMode.normal
+ duration_seconds = PixverseDuration.dur_5
+ elif duration_seconds != PixverseDuration.dur_5:
+ motion_mode = PixverseMotionMode.normal
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/pixverse/video/text/generate",
+ method=HttpMethod.POST,
+ request_model=PixverseTextVideoRequest,
+ response_model=PixverseVideoResponse,
+ ),
+ request=PixverseTextVideoRequest(
+ prompt=prompt,
+ aspect_ratio=aspect_ratio,
+ quality=quality,
+ duration=duration_seconds,
+ motion_mode=motion_mode,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ template_id=pixverse_template,
+ seed=seed,
+ ),
+ auth_kwargs=kwargs,
+ )
+ response_api = operation.execute()
+
+ if response_api.Resp is None:
+ raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
+
+ operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=PixverseGenerationStatusResponse,
+ ),
+ completed_statuses=[PixverseStatus.successful],
+ failed_statuses=[
+ PixverseStatus.contents_moderation,
+ PixverseStatus.failed,
+ PixverseStatus.deleted,
+ ],
+ status_extractor=lambda x: x.Resp.status,
+ auth_kwargs=kwargs,
+ node_id=unique_id,
+ result_url_extractor=get_video_url_from_response,
+ estimated_duration=AVERAGE_DURATION_T2V,
+ )
+ response_poll = operation.execute()
+
+ vid_response = requests.get(response_poll.Resp.url)
+
+ return (VideoFromFile(BytesIO(vid_response.content)),)
+
+
+class PixverseImageToVideoNode(ComfyNodeABC):
+ """
+    Generates videos synchronously from an image and a text prompt via PixVerse's API.
+ """
+
+ RETURN_TYPES = (IO.VIDEO,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/video/PixVerse"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE,),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the video generation",
+ },
+ ),
+ "quality": (
+ [resolution.value for resolution in PixverseQuality],
+ {
+ "default": PixverseQuality.res_540p,
+ },
+ ),
+ "duration_seconds": ([dur.value for dur in PixverseDuration],),
+ "motion_mode": ([mode.value for mode in PixverseMotionMode],),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2147483647,
+ "control_after_generate": True,
+ "tooltip": "Seed for video generation.",
+ },
+ ),
+ },
+ "optional": {
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "An optional text description of undesired elements on an image.",
+ },
+ ),
+ "pixverse_template": (
+ PixverseIO.TEMPLATE,
+ {
+ "tooltip": "An optional template to influence style of generation, created by the PixVerse Template node."
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ prompt: str,
+ quality: str,
+ duration_seconds: int,
+ motion_mode: str,
+        seed: int,
+        negative_prompt: Optional[str] = None,
+        pixverse_template: Optional[int] = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=False)
+ img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs)
+
+ # 1080p is limited to 5 seconds duration
+ # only normal motion_mode supported for 1080p or for non-5 second duration
+ if quality == PixverseQuality.res_1080p:
+ motion_mode = PixverseMotionMode.normal
+ duration_seconds = PixverseDuration.dur_5
+ elif duration_seconds != PixverseDuration.dur_5:
+ motion_mode = PixverseMotionMode.normal
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/pixverse/video/img/generate",
+ method=HttpMethod.POST,
+ request_model=PixverseImageVideoRequest,
+ response_model=PixverseVideoResponse,
+ ),
+ request=PixverseImageVideoRequest(
+ img_id=img_id,
+ prompt=prompt,
+ quality=quality,
+ duration=duration_seconds,
+ motion_mode=motion_mode,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ template_id=pixverse_template,
+ seed=seed,
+ ),
+ auth_kwargs=kwargs,
+ )
+ response_api = operation.execute()
+
+ if response_api.Resp is None:
+ raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
+
+ operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=PixverseGenerationStatusResponse,
+ ),
+ completed_statuses=[PixverseStatus.successful],
+ failed_statuses=[
+ PixverseStatus.contents_moderation,
+ PixverseStatus.failed,
+ PixverseStatus.deleted,
+ ],
+ status_extractor=lambda x: x.Resp.status,
+ auth_kwargs=kwargs,
+ node_id=unique_id,
+ result_url_extractor=get_video_url_from_response,
+ estimated_duration=AVERAGE_DURATION_I2V,
+ )
+ response_poll = operation.execute()
+
+ vid_response = requests.get(response_poll.Resp.url)
+ return (VideoFromFile(BytesIO(vid_response.content)),)
+
+
+class PixverseTransitionVideoNode(ComfyNodeABC):
+ """
+    Generates transition videos between a first and a last frame, guided by a prompt, via PixVerse's API.
+ """
+
+ RETURN_TYPES = (IO.VIDEO,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/video/PixVerse"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "first_frame": (IO.IMAGE,),
+ "last_frame": (IO.IMAGE,),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the video generation",
+ },
+ ),
+ "quality": (
+ [resolution.value for resolution in PixverseQuality],
+ {
+ "default": PixverseQuality.res_540p,
+ },
+ ),
+ "duration_seconds": ([dur.value for dur in PixverseDuration],),
+ "motion_mode": ([mode.value for mode in PixverseMotionMode],),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 2147483647,
+ "control_after_generate": True,
+ "tooltip": "Seed for video generation.",
+ },
+ ),
+ },
+ "optional": {
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "An optional text description of undesired elements on an image.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def api_call(
+ self,
+ first_frame: torch.Tensor,
+ last_frame: torch.Tensor,
+ prompt: str,
+ quality: str,
+ duration_seconds: int,
+ motion_mode: str,
+ seed,
+ negative_prompt: str = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=False)
+ first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs)
+ last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs)
+
+ # 1080p is limited to a 5-second duration;
+ # only the normal motion_mode is supported at 1080p or for non-5-second durations
+ if quality == PixverseQuality.res_1080p:
+ motion_mode = PixverseMotionMode.normal
+ duration_seconds = PixverseDuration.dur_5
+ elif duration_seconds != PixverseDuration.dur_5:
+ motion_mode = PixverseMotionMode.normal
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/pixverse/video/transition/generate",
+ method=HttpMethod.POST,
+ request_model=PixverseTransitionVideoRequest,
+ response_model=PixverseVideoResponse,
+ ),
+ request=PixverseTransitionVideoRequest(
+ first_frame_img=first_frame_id,
+ last_frame_img=last_frame_id,
+ prompt=prompt,
+ quality=quality,
+ duration=duration_seconds,
+ motion_mode=motion_mode,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ seed=seed,
+ ),
+ auth_kwargs=kwargs,
+ )
+ response_api = operation.execute()
+
+ if response_api.Resp is None:
+ raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'")
+
+ operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=PixverseGenerationStatusResponse,
+ ),
+ completed_statuses=[PixverseStatus.successful],
+ failed_statuses=[
+ PixverseStatus.contents_moderation,
+ PixverseStatus.failed,
+ PixverseStatus.deleted,
+ ],
+ status_extractor=lambda x: x.Resp.status,
+ auth_kwargs=kwargs,
+ node_id=unique_id,
+ result_url_extractor=get_video_url_from_response,
+ estimated_duration=AVERAGE_DURATION_T2V,
+ )
+ response_poll = operation.execute()
+
+ vid_response = requests.get(response_poll.Resp.url)
+ return (VideoFromFile(BytesIO(vid_response.content)),)
+
+
+NODE_CLASS_MAPPINGS = {
+ "PixverseTextToVideoNode": PixverseTextToVideoNode,
+ "PixverseImageToVideoNode": PixverseImageToVideoNode,
+ "PixverseTransitionVideoNode": PixverseTransitionVideoNode,
+ "PixverseTemplateNode": PixverseTemplateNode,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "PixverseTextToVideoNode": "PixVerse Text to Video",
+ "PixverseImageToVideoNode": "PixVerse Image to Video",
+ "PixverseTransitionVideoNode": "PixVerse Transition Video",
+ "PixverseTemplateNode": "PixVerse Template",
+}
diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py
new file mode 100644
index 000000000..e369c4b7e
--- /dev/null
+++ b/comfy_api_nodes/nodes_recraft.py
@@ -0,0 +1,1138 @@
+from __future__ import annotations
+from inspect import cleandoc
+from typing import Optional
+from comfy.utils import ProgressBar
+from comfy_extras.nodes_images import SVG
+from comfy.comfy_types.node_typing import IO
+from comfy_api_nodes.apis.recraft_api import (
+ RecraftImageGenerationRequest,
+ RecraftImageGenerationResponse,
+ RecraftImageSize,
+ RecraftModel,
+ RecraftStyle,
+ RecraftStyleV3,
+ RecraftColor,
+ RecraftColorChain,
+ RecraftControls,
+ RecraftIO,
+ get_v3_substyles,
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ EmptyRequest,
+)
+from comfy_api_nodes.apinode_utils import (
+ bytesio_to_image_tensor,
+ download_url_to_bytesio,
+ tensor_to_bytesio,
+ resize_mask_to_image,
+ validate_string,
+)
+from server import PromptServer
+
+import torch
+from io import BytesIO
+from PIL import UnidentifiedImageError
+
+
+def handle_recraft_file_request(
+ image: torch.Tensor,
+ path: str,
+ mask: torch.Tensor=None,
+ total_pixels=4096*4096,
+ timeout=1024,
+ request=None,
+ auth_kwargs: dict[str,str] = None,
+ ) -> list[BytesIO]:
+ """
+ Handle sending common Recraft file-only request to get back file bytes.
+ """
+ if request is None:
+ request = EmptyRequest()
+
+ files = {
+ 'image': tensor_to_bytesio(image, total_pixels=total_pixels).read()
+ }
+ if mask is not None:
+ files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read()
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=path,
+ method=HttpMethod.POST,
+ request_model=type(request),
+ response_model=RecraftImageGenerationResponse,
+ ),
+ request=request,
+ files=files,
+ content_type="multipart/form-data",
+ auth_kwargs=auth_kwargs,
+ multipart_parser=recraft_multipart_parser,
+ )
+ response: RecraftImageGenerationResponse = operation.execute()
+ all_bytesio = []
+ if response.image is not None:
+ all_bytesio.append(download_url_to_bytesio(response.image.url, timeout=timeout))
+ else:
+ for data in response.data:
+ all_bytesio.append(download_url_to_bytesio(data.url, timeout=timeout))
+
+ return all_bytesio
+
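+# Example call (mirroring RecraftVectorizeImageNode further below):
+#   svg_bytes = handle_recraft_file_request(
+#       image=image[i], path="/proxy/recraft/images/vectorize", auth_kwargs=kwargs)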
+
+def recraft_multipart_parser(data, parent_key=None, formatter: callable=None, converted_to_check: list[list]=None, is_list=False) -> dict:
+ """
+ Formats data such that multipart/form-data will work with requests library
+ when both files and data are present.
+
+ The OpenAI client that Recraft uses has a bizarre way of serializing lists:
+
+ It does NOT keep track of indices of each list, so for background_color, that must be serialized as:
+ 'background_color[rgb][]' = [0, 0, 255]
+ where the array is assigned to a key that has '[]' at the end, to signal it's an array.
+
+ This has the consequence of nested lists having the exact same key, forcing arrays to merge; all color inputs fall under the same key:
+ if 1 color -> 'controls[colors][][rgb][]' = [0, 0, 255]
+ if 2 colors -> 'controls[colors][][rgb][]' = [0, 0, 255, 255, 0, 0]
+ if 3 colors -> 'controls[colors][][rgb][]' = [0, 0, 255, 255, 0, 0, 0, 255, 0]
+ etc.
+ Whoever made this serialization up at OpenAI added the constraint that lists must be of uniform length on objects of the same 'type'.
+ """
+ # Modification of a function that handled a different type of multipart parsing, big ups:
+ # https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b
+
+ def handle_converted_lists(data, parent_key, lists_to_check: tuple[list, ...] = ()):
+ # if the list already exists, just extend it with data
+ for check_list in lists_to_check:
+ for conv_tuple in check_list:
+ if conv_tuple[0] == parent_key and type(conv_tuple[1]) is list:
+ conv_tuple[1].append(formatter(data))
+ return True
+ return False
+
+ if converted_to_check is None:
+ converted_to_check = []
+
+ if formatter is None:
+ formatter = lambda v: v # Multipart representation of value
+
+ if type(data) is not dict:
+ # if the list already exists, just extend it with data
+ added = handle_converted_lists(data, parent_key, converted_to_check)
+ if added:
+ return {}
+ # otherwise if is_list, create new list with data
+ if is_list:
+ return {parent_key: [formatter(data)]}
+ # return new key with data
+ return {parent_key: formatter(data)}
+
+ converted = []
+ next_check = [converted]
+ next_check.extend(converted_to_check)
+
+ for key, value in data.items():
+ current_key = key if parent_key is None else f"{parent_key}[{key}]"
+ if type(value) is dict:
+ converted.extend(recraft_multipart_parser(value, current_key, formatter, next_check).items())
+ elif type(value) is list:
+ iter_key = f"{current_key}[]"
+ for list_value in value:
+ converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items())
+ else:
+ converted.append((current_key, formatter(value)))
+
+ return dict(converted)
+
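+# Worked example (follows the docstring above): two colors flatten into a single key:
+#   recraft_multipart_parser({"controls": {"colors": [{"rgb": [0, 0, 255]}, {"rgb": [255, 0, 0]}]}})
+#   -> {"controls[colors][][rgb][]": [0, 0, 255, 255, 0, 0]}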
+
+class handle_recraft_image_output:
+ """
+ Context manager that converts PIL's UnidentifiedImageError (raised when SVG data is returned instead of an image, e.g. when an Infinite Style Library style_id refers to a Vector art style) into a clearer error.
+ """
+ def __init__(self):
+ pass
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if exc_type is not None and exc_type is UnidentifiedImageError:
+ raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.")
+
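+# Typical usage, as in the nodes below: wrap image decoding so SVG payloads
+# surface the clearer error above instead of PIL's UnidentifiedImageError:
+#   with handle_recraft_image_output():
+#       image = bytesio_to_image_tensor(data)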
+
+class RecraftColorRGBNode:
+ """
+ Create Recraft Color by choosing specific RGB values.
+ """
+
+ RETURN_TYPES = (RecraftIO.COLOR,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ RETURN_NAMES = ("recraft_color",)
+ FUNCTION = "create_color"
+ CATEGORY = "api node/image/Recraft"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "r": (IO.INT, {
+ "default": 0,
+ "min": 0,
+ "max": 255,
+ "tooltip": "Red value of color."
+ }),
+ "g": (IO.INT, {
+ "default": 0,
+ "min": 0,
+ "max": 255,
+ "tooltip": "Green value of color."
+ }),
+ "b": (IO.INT, {
+ "default": 0,
+ "min": 0,
+ "max": 255,
+ "tooltip": "Blue value of color."
+ }),
+ },
+ "optional": {
+ "recraft_color": (RecraftIO.COLOR,),
+ }
+ }
+
+ def create_color(self, r: int, g: int, b: int, recraft_color: RecraftColorChain=None):
+ recraft_color = recraft_color.clone() if recraft_color else RecraftColorChain()
+ recraft_color.add(RecraftColor(r, g, b))
+ return (recraft_color, )
+
+
+class RecraftControlsNode:
+ """
+ Create Recraft Controls for customizing Recraft generation.
+ """
+
+ RETURN_TYPES = (RecraftIO.CONTROLS,)
+ RETURN_NAMES = ("recraft_controls",)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "create_controls"
+ CATEGORY = "api node/image/Recraft"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ },
+ "optional": {
+ "colors": (RecraftIO.COLOR,),
+ "background_color": (RecraftIO.COLOR,),
+ }
+ }
+
+ def create_controls(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None):
+ return (RecraftControls(colors=colors, background_color=background_color), )
+
+
+class RecraftStyleV3RealisticImageNode:
+ """
+ Select realistic_image style and optional substyle.
+ """
+
+ RETURN_TYPES = (RecraftIO.STYLEV3,)
+ RETURN_NAMES = ("recraft_style",)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "create_style"
+ CATEGORY = "api node/image/Recraft"
+
+ RECRAFT_STYLE = RecraftStyleV3.realistic_image
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "substyle": (get_v3_substyles(s.RECRAFT_STYLE),),
+ }
+ }
+
+ def create_style(self, substyle: str):
+ if substyle == "None":
+ substyle = None
+ return (RecraftStyle(self.RECRAFT_STYLE, substyle),)
+
+
+class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode):
+ """
+ Select digital_illustration style and optional substyle.
+ """
+
+ RECRAFT_STYLE = RecraftStyleV3.digital_illustration
+
+
+class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode):
+ """
+ Select vector_illustration style and optional substyle.
+ """
+
+ RECRAFT_STYLE = RecraftStyleV3.vector_illustration
+
+
+class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode):
+ """
+ Select logo_raster style and required substyle.
+ """
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "substyle": (get_v3_substyles(s.RECRAFT_STYLE, include_none=False),),
+ }
+ }
+
+ RECRAFT_STYLE = RecraftStyleV3.logo_raster
+
+
+class RecraftStyleInfiniteStyleLibrary:
+ """
+ Select style based on preexisting UUID from Recraft's Infinite Style Library.
+ """
+
+ RETURN_TYPES = (RecraftIO.STYLEV3,)
+ RETURN_NAMES = ("recraft_style",)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "create_style"
+ CATEGORY = "api node/image/Recraft"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "style_id": (IO.STRING, {
+ "default": "",
+ "tooltip": "UUID of style from Infinite Style Library.",
+ })
+ }
+ }
+
+ def create_style(self, style_id: str):
+ if not style_id:
+ raise Exception("The style_id input cannot be empty.")
+ return (RecraftStyle(style_id=style_id),)
+
+
+class RecraftTextToImageNode:
+ """
+ Generates images synchronously based on prompt and resolution.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Recraft"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation.",
+ },
+ ),
+ "size": (
+ [res.value for res in RecraftImageSize],
+ {
+ "default": RecraftImageSize.res_1024x1024,
+ "tooltip": "The size of the generated image.",
+ },
+ ),
+ "n": (
+ IO.INT,
+ {
+ "default": 1,
+ "min": 1,
+ "max": 6,
+ "tooltip": "The number of images to generate.",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
+ },
+ ),
+ },
+ "optional": {
+ "recraft_style": (RecraftIO.STYLEV3,),
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "An optional text description of undesired elements on an image.",
+ },
+ ),
+ "recraft_controls": (
+ RecraftIO.CONTROLS,
+ {
+ "tooltip": "Optional additional controls over the generation via the Recraft Controls node."
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def api_call(
+ self,
+ prompt: str,
+ size: str,
+ n: int,
+ seed,
+ recraft_style: RecraftStyle = None,
+ negative_prompt: str = None,
+ recraft_controls: RecraftControls = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=False, max_length=1000)
+ default_style = RecraftStyle(RecraftStyleV3.realistic_image)
+ if recraft_style is None:
+ recraft_style = default_style
+
+ controls_api = None
+ if recraft_controls:
+ controls_api = recraft_controls.create_api_model()
+
+ if not negative_prompt:
+ negative_prompt = None
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/recraft/image_generation",
+ method=HttpMethod.POST,
+ request_model=RecraftImageGenerationRequest,
+ response_model=RecraftImageGenerationResponse,
+ ),
+ request=RecraftImageGenerationRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ model=RecraftModel.recraftv3,
+ size=size,
+ n=n,
+ style=recraft_style.style,
+ substyle=recraft_style.substyle,
+ style_id=recraft_style.style_id,
+ controls=controls_api,
+ ),
+ auth_kwargs=kwargs,
+ )
+ response: RecraftImageGenerationResponse = operation.execute()
+ images = []
+ urls = []
+ for data in response.data:
+ with handle_recraft_image_output():
+ if unique_id and data.url:
+ urls.append(data.url)
+ urls_string = '\n'.join(urls)
+ PromptServer.instance.send_progress_text(
+ f"Result URL: {urls_string}", unique_id
+ )
+ image = bytesio_to_image_tensor(
+ download_url_to_bytesio(data.url, timeout=1024)
+ )
+ if len(image.shape) < 4:
+ image = image.unsqueeze(0)
+ images.append(image)
+ output_image = torch.cat(images, dim=0)
+
+ return (output_image,)
+
+
+class RecraftImageToImageNode:
+ """
+ Modify image based on prompt and strength.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Recraft"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE, ),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation.",
+ },
+ ),
+ "n": (
+ IO.INT,
+ {
+ "default": 1,
+ "min": 1,
+ "max": 6,
+ "tooltip": "The number of images to generate.",
+ },
+ ),
+ "strength": (
+ IO.FLOAT,
+ {
+ "default": 0.5,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ "tooltip": "Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity."
+ }
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
+ },
+ ),
+ },
+ "optional": {
+ "recraft_style": (RecraftIO.STYLEV3,),
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "An optional text description of undesired elements on an image.",
+ },
+ ),
+ "recraft_controls": (
+ RecraftIO.CONTROLS,
+ {
+ "tooltip": "Optional additional controls over the generation via the Recraft Controls node."
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ prompt: str,
+ n: int,
+ strength: float,
+ seed,
+ recraft_style: RecraftStyle = None,
+ negative_prompt: str = None,
+ recraft_controls: RecraftControls = None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=False, max_length=1000)
+ default_style = RecraftStyle(RecraftStyleV3.realistic_image)
+ if recraft_style is None:
+ recraft_style = default_style
+
+ controls_api = None
+ if recraft_controls:
+ controls_api = recraft_controls.create_api_model()
+
+ if not negative_prompt:
+ negative_prompt = None
+
+ request = RecraftImageGenerationRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ model=RecraftModel.recraftv3,
+ n=n,
+ strength=round(strength, 2),
+ style=recraft_style.style,
+ substyle=recraft_style.substyle,
+ style_id=recraft_style.style_id,
+ controls=controls_api,
+ )
+
+ images = []
+ total = image.shape[0]
+ pbar = ProgressBar(total)
+ for i in range(total):
+ sub_bytes = handle_recraft_file_request(
+ image=image[i],
+ path="/proxy/recraft/images/imageToImage",
+ request=request,
+ auth_kwargs=kwargs,
+ )
+ with handle_recraft_image_output():
+ images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
+ pbar.update(1)
+
+ images_tensor = torch.cat(images, dim=0)
+ return (images_tensor, )
+
+
+class RecraftImageInpaintingNode:
+ """
+ Modify image based on prompt and mask.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Recraft"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE, ),
+ "mask": (IO.MASK, ),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation.",
+ },
+ ),
+ "n": (
+ IO.INT,
+ {
+ "default": 1,
+ "min": 1,
+ "max": 6,
+ "tooltip": "The number of images to generate.",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
+ },
+ ),
+ },
+ "optional": {
+ "recraft_style": (RecraftIO.STYLEV3,),
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "An optional text description of undesired elements on an image.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ mask: torch.Tensor,
+ prompt: str,
+ n: int,
+ seed,
+ recraft_style: RecraftStyle = None,
+ negative_prompt: str = None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=False, max_length=1000)
+ default_style = RecraftStyle(RecraftStyleV3.realistic_image)
+ if recraft_style is None:
+ recraft_style = default_style
+
+ if not negative_prompt:
+ negative_prompt = None
+
+ request = RecraftImageGenerationRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ model=RecraftModel.recraftv3,
+ n=n,
+ style=recraft_style.style,
+ substyle=recraft_style.substyle,
+ style_id=recraft_style.style_id,
+ )
+
+ # prepare mask tensor
+ mask = resize_mask_to_image(mask, image, allow_gradient=False, add_channel_dim=True)
+
+ images = []
+ total = image.shape[0]
+ pbar = ProgressBar(total)
+ for i in range(total):
+ sub_bytes = handle_recraft_file_request(
+ image=image[i],
+ mask=mask[i:i+1],
+ path="/proxy/recraft/images/inpaint",
+ request=request,
+ auth_kwargs=kwargs,
+ )
+ with handle_recraft_image_output():
+ images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
+ pbar.update(1)
+
+ images_tensor = torch.cat(images, dim=0)
+ return (images_tensor, )
+
+
+class RecraftTextToVectorNode:
+ """
+ Generates SVG synchronously based on prompt and resolution.
+ """
+
+ RETURN_TYPES = ("SVG",) # Changed
+ DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Recraft"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation.",
+ },
+ ),
+ "substyle": (get_v3_substyles(RecraftStyleV3.vector_illustration),),
+ "size": (
+ [res.value for res in RecraftImageSize],
+ {
+ "default": RecraftImageSize.res_1024x1024,
+ "tooltip": "The size of the generated image.",
+ },
+ ),
+ "n": (
+ IO.INT,
+ {
+ "default": 1,
+ "min": 1,
+ "max": 6,
+ "tooltip": "The number of images to generate.",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
+ },
+ ),
+ },
+ "optional": {
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "An optional text description of undesired elements on an image.",
+ },
+ ),
+ "recraft_controls": (
+ RecraftIO.CONTROLS,
+ {
+ "tooltip": "Optional additional controls over the generation via the Recraft Controls node."
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def api_call(
+ self,
+ prompt: str,
+ substyle: str,
+ size: str,
+ n: int,
+ seed,
+ negative_prompt: str = None,
+ recraft_controls: RecraftControls = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ validate_string(prompt, strip_whitespace=False, max_length=1000)
+ # create RecraftStyle so strings will be formatted properly (i.e. "None" will become None)
+ recraft_style = RecraftStyle(RecraftStyleV3.vector_illustration, substyle=substyle)
+
+ controls_api = None
+ if recraft_controls:
+ controls_api = recraft_controls.create_api_model()
+
+ if not negative_prompt:
+ negative_prompt = None
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/recraft/image_generation",
+ method=HttpMethod.POST,
+ request_model=RecraftImageGenerationRequest,
+ response_model=RecraftImageGenerationResponse,
+ ),
+ request=RecraftImageGenerationRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ model=RecraftModel.recraftv3,
+ size=size,
+ n=n,
+ style=recraft_style.style,
+ substyle=recraft_style.substyle,
+ controls=controls_api,
+ ),
+ auth_kwargs=kwargs,
+ )
+ response: RecraftImageGenerationResponse = operation.execute()
+ svg_data = []
+ urls = []
+ for data in response.data:
+ if unique_id and data.url:
+ urls.append(data.url)
+ # Send the result URLs on each iteration so they remain visible even if a later download fails
+ PromptServer.instance.send_progress_text(
+ f"Result URL: {' '.join(urls)}", unique_id
+ )
+ svg_data.append(download_url_to_bytesio(data.url, timeout=1024))
+
+ return (SVG(svg_data),)
+
+
+class RecraftVectorizeImageNode:
+ """
+ Generates SVG synchronously from an input image.
+ """
+
+ RETURN_TYPES = ("SVG",) # Changed
+ DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Recraft"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE, ),
+ },
+ "optional": {
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ **kwargs,
+ ):
+ svgs = []
+ total = image.shape[0]
+ pbar = ProgressBar(total)
+ for i in range(total):
+ sub_bytes = handle_recraft_file_request(
+ image=image[i],
+ path="/proxy/recraft/images/vectorize",
+ auth_kwargs=kwargs,
+ )
+ svgs.append(SVG(sub_bytes))
+ pbar.update(1)
+
+ return (SVG.combine_all(svgs), )
+
+
+class RecraftReplaceBackgroundNode:
+ """
+ Replace background on image, based on provided prompt.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Recraft"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE, ),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Prompt for the image generation.",
+ },
+ ),
+ "n": (
+ IO.INT,
+ {
+ "default": 1,
+ "min": 1,
+ "max": 6,
+ "tooltip": "The number of images to generate.",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFFFFFFFFFF,
+ "control_after_generate": True,
+ "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.",
+ },
+ ),
+ },
+ "optional": {
+ "recraft_style": (RecraftIO.STYLEV3,),
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "An optional text description of undesired elements on an image.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ prompt: str,
+ n: int,
+ seed,
+ recraft_style: RecraftStyle = None,
+ negative_prompt: str = None,
+ **kwargs,
+ ):
+ default_style = RecraftStyle(RecraftStyleV3.realistic_image)
+ if recraft_style is None:
+ recraft_style = default_style
+
+ if not negative_prompt:
+ negative_prompt = None
+
+ request = RecraftImageGenerationRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ model=RecraftModel.recraftv3,
+ n=n,
+ style=recraft_style.style,
+ substyle=recraft_style.substyle,
+ style_id=recraft_style.style_id,
+ )
+
+ images = []
+ total = image.shape[0]
+ pbar = ProgressBar(total)
+ for i in range(total):
+ sub_bytes = handle_recraft_file_request(
+ image=image[i],
+ path="/proxy/recraft/images/replaceBackground",
+ request=request,
+ auth_kwargs=kwargs,
+ )
+ images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
+ pbar.update(1)
+
+ images_tensor = torch.cat(images, dim=0)
+ return (images_tensor, )
+
+
+class RecraftRemoveBackgroundNode:
+ """
+ Remove background from image, and return processed image and mask.
+ """
+
+ RETURN_TYPES = (IO.IMAGE, IO.MASK)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Recraft"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE, ),
+ },
+ "optional": {
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ **kwargs,
+ ):
+ images = []
+ total = image.shape[0]
+ pbar = ProgressBar(total)
+ for i in range(total):
+ sub_bytes = handle_recraft_file_request(
+ image=image[i],
+ path="/proxy/recraft/images/removeBackground",
+ auth_kwargs=kwargs,
+ )
+ images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
+ pbar.update(1)
+
+ images_tensor = torch.cat(images, dim=0)
+ # use alpha channel as masks, in B,H,W format
+ masks_tensor = images_tensor[:,:,:,-1:].squeeze(-1)
+ return (images_tensor, masks_tensor)
+
+
+class RecraftCrispUpscaleNode:
+ """
+ Upscale image synchronously.
+ Enhances a given raster image using the 'crisp upscale' tool, increasing image resolution and making the image sharper and cleaner.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Recraft"
+
+ RECRAFT_PATH = "/proxy/recraft/images/crispUpscale"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE, ),
+ },
+ "optional": {
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(
+ self,
+ image: torch.Tensor,
+ **kwargs,
+ ):
+ images = []
+ total = image.shape[0]
+ pbar = ProgressBar(total)
+ for i in range(total):
+ sub_bytes = handle_recraft_file_request(
+ image=image[i],
+ path=self.RECRAFT_PATH,
+ auth_kwargs=kwargs,
+ )
+ images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0))
+ pbar.update(1)
+
+ images_tensor = torch.cat(images, dim=0)
+ return (images_tensor,)
+
+
+class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
+ """
+ Upscale image synchronously.
+ Enhances a given raster image using the 'creative upscale' tool, boosting resolution with a focus on refining small details and faces.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Recraft"
+
+ RECRAFT_PATH = "/proxy/recraft/images/creativeUpscale"
+
+
+# A dictionary that contains all nodes you want to export with their names
+# NOTE: names should be globally unique
+NODE_CLASS_MAPPINGS = {
+ "RecraftTextToImageNode": RecraftTextToImageNode,
+ "RecraftImageToImageNode": RecraftImageToImageNode,
+ "RecraftImageInpaintingNode": RecraftImageInpaintingNode,
+ "RecraftTextToVectorNode": RecraftTextToVectorNode,
+ "RecraftVectorizeImageNode": RecraftVectorizeImageNode,
+ "RecraftRemoveBackgroundNode": RecraftRemoveBackgroundNode,
+ "RecraftReplaceBackgroundNode": RecraftReplaceBackgroundNode,
+ "RecraftCrispUpscaleNode": RecraftCrispUpscaleNode,
+ "RecraftCreativeUpscaleNode": RecraftCreativeUpscaleNode,
+ "RecraftStyleV3RealisticImage": RecraftStyleV3RealisticImageNode,
+ "RecraftStyleV3DigitalIllustration": RecraftStyleV3DigitalIllustrationNode,
+ "RecraftStyleV3LogoRaster": RecraftStyleV3LogoRasterNode,
+ "RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary,
+ "RecraftColorRGB": RecraftColorRGBNode,
+ "RecraftControls": RecraftControlsNode,
+}
+
+# A dictionary that contains the friendly/humanly readable titles for the nodes
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "RecraftTextToImageNode": "Recraft Text to Image",
+ "RecraftImageToImageNode": "Recraft Image to Image",
+ "RecraftImageInpaintingNode": "Recraft Image Inpainting",
+ "RecraftTextToVectorNode": "Recraft Text to Vector",
+ "RecraftVectorizeImageNode": "Recraft Vectorize Image",
+ "RecraftRemoveBackgroundNode": "Recraft Remove Background",
+ "RecraftReplaceBackgroundNode": "Recraft Replace Background",
+ "RecraftCrispUpscaleNode": "Recraft Crisp Upscale Image",
+ "RecraftCreativeUpscaleNode": "Recraft Creative Upscale Image",
+ "RecraftStyleV3RealisticImage": "Recraft Style - Realistic Image",
+ "RecraftStyleV3DigitalIllustration": "Recraft Style - Digital Illustration",
+ "RecraftStyleV3LogoRaster": "Recraft Style - Logo Raster",
+ "RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library",
+ "RecraftColorRGB": "Recraft Color RGB",
+ "RecraftControls": "Recraft Controls",
+}
diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py
new file mode 100644
index 000000000..67f90478c
--- /dev/null
+++ b/comfy_api_nodes/nodes_rodin.py
@@ -0,0 +1,462 @@
+"""
+ComfyUI X Rodin3D (Deemos) API Nodes
+
+Rodin API docs: https://developer.hyper3d.ai/
+
+"""
+
+from __future__ import annotations
+from inspect import cleandoc
+from comfy.comfy_types.node_typing import IO
+import folder_paths as comfy_paths
+import requests
+import os
+import datetime
+import shutil
+import time
+import io
+import logging
+import math
+from PIL import Image
+from comfy_api_nodes.apis.rodin_api import (
+ Rodin3DGenerateRequest,
+ Rodin3DGenerateResponse,
+ Rodin3DCheckStatusRequest,
+ Rodin3DCheckStatusResponse,
+ Rodin3DDownloadRequest,
+ Rodin3DDownloadResponse,
+ JobStatus,
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ PollingOperation,
+)
+
+
+COMMON_PARAMETERS = {
+ "Seed": (
+ IO.INT,
+ {
+ "default":0,
+ "min":0,
+ "max":65535,
+ "display":"number"
+ }
+ ),
+ "Material_Type": (
+ IO.COMBO,
+ {
+ "options": ["PBR", "Shaded"],
+ "default": "PBR"
+ }
+ ),
+ "Polygon_count": (
+ IO.COMBO,
+ {
+ "options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "200K-Triangle"],
+ "default": "18K-Quad"
+ }
+ )
+}
+
+def create_task_error(response: Rodin3DGenerateResponse):
+ """Check if the response has error"""
+ return hasattr(response, "error")
+
+
+
+class Rodin3DAPI:
+ """
+ Generate 3D Assets using Rodin API
+ """
+ RETURN_TYPES = (IO.STRING,)
+ RETURN_NAMES = ("3D Model Path",)
+ CATEGORY = "api node/3d/Rodin"
+ DESCRIPTION = cleandoc(__doc__ or "")
+ FUNCTION = "api_call"
+ API_NODE = True
+
+ def tensor_to_filelike(self, tensor, max_pixels: int = 2048*2048):
+ """
+ Converts a PyTorch tensor to a file-like object.
+
+ Args:
+ - tensor (torch.Tensor): A tensor representing an image of shape (H, W, C)
+ where C is the number of channels (3 for RGB), H is height, and W is width.
+
+ Returns:
+ - io.BytesIO: A file-like object containing the image data.
+ """
+ array = tensor.cpu().numpy()
+ array = (array * 255).astype('uint8')
+ image = Image.fromarray(array, 'RGB')
+
+ original_width, original_height = image.size
+ original_pixels = original_width * original_height
+ if original_pixels > max_pixels:
+ scale = math.sqrt(max_pixels / original_pixels)
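+ # e.g. a 4096x4096 input exceeds max_pixels (2048*2048), giving
+ # scale = sqrt((2048*2048)/(4096*4096)) = 0.5, so it is resized to 2048x2048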
+ new_width = int(original_width * scale)
+ new_height = int(original_height * scale)
+ else:
+ new_width, new_height = original_width, original_height
+
+ if new_width != original_width or new_height != original_height:
+ image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
+
+ img_byte_arr = io.BytesIO()
+ image.save(img_byte_arr, format='PNG') # PNG is used for lossless compression
+ img_byte_arr.seek(0)
+ return img_byte_arr
+
+ def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str:
+ has_failed = any(job.status == JobStatus.Failed for job in response.jobs)
+ all_done = all(job.status == JobStatus.Done for job in response.jobs)
+ status_list = [str(job.status) for job in response.jobs]
+ logging.info(f"[ Rodin3D API - CheckStatus ] Generate Status: {status_list}")
+ if has_failed:
+ logging.error(f"[ Rodin3D API - CheckStatus ] Generation failed: {status_list}, please try again.")
+ raise Exception("[ Rodin3D API ] Generation failed, please try again.")
+ elif all_done:
+ return "DONE"
+ else:
+ return "Generating"
+
+ def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs):
+ if images is None:
+ raise Exception("Rodin 3D generate requires at least 1 image.")
+ if len(images) > 5:
+ raise Exception("Rodin 3D generate supports at most 5 images.")
+
+ path = "/proxy/rodin/api/v2/rodin"
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=path,
+ method=HttpMethod.POST,
+ request_model=Rodin3DGenerateRequest,
+ response_model=Rodin3DGenerateResponse,
+ ),
+ request=Rodin3DGenerateRequest(
+ seed=seed,
+ tier=tier,
+ material=material,
+ quality=quality,
+ mesh_mode=mesh_mode
+ ),
+ files=[
+ (
+ "images",
+ open(image, "rb") if isinstance(image, str) else self.tensor_to_filelike(image)
+ )
+ for image in images if image is not None
+ ],
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+
+ response = operation.execute()
+
+ if create_task_error(response):
+ error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}"
+ logging.error(error_message)
+ raise Exception(error_message)
+
+ logging.info("[ Rodin3D API - Submit Jobs ] Submit Generate Task Success!")
+ subscription_key = response.jobs.subscription_key
+ task_uuid = response.uuid
+ logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}")
+ return task_uuid, subscription_key
+
+ def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse:
+ path = "/proxy/rodin/api/v2/status"
+
+ poll_operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=path,
+ method=HttpMethod.POST,
+ request_model=Rodin3DCheckStatusRequest,
+ response_model=Rodin3DCheckStatusResponse,
+ ),
+ request=Rodin3DCheckStatusRequest(
+ subscription_key=subscription_key
+ ),
+ completed_statuses=["DONE"],
+ failed_statuses=["FAILED"],
+ status_extractor=self.check_rodin_status,
+ poll_interval=3.0,
+ auth_kwargs=kwargs,
+ )
+
+ logging.info("[ Rodin3D API - CheckStatus ] Generate Start!")
+
+ return poll_operation.execute()
+
+ def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse:
+ logging.info("[ Rodin3D API - Downloading ] Generate Successfully!")
+
+ path = "/proxy/rodin/api/v2/download"
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=path,
+ method=HttpMethod.POST,
+ request_model=Rodin3DDownloadRequest,
+ response_model=Rodin3DDownloadResponse,
+ ),
+ request=Rodin3DDownloadRequest(
+ task_uuid=uuid
+ ),
+ auth_kwargs=kwargs
+ )
+
+ return operation.execute()
+
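+ # Polygon_count -> (mesh_mode, quality) mapping implemented below:
+ #   4K-Quad -> (Quad, extra-low); 8K-Quad -> (Quad, low); 18K-Quad -> (Quad, medium);
+ #   50K-Quad -> (Quad, high); 200K-Triangle -> (Raw, medium)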
+ def GetQualityAndMode(self, PolyCount):
+ if PolyCount == "200K-Triangle":
+ mesh_mode = "Raw"
+ quality = "medium"
+ else:
+ mesh_mode = "Quad"
+ if PolyCount == "4K-Quad":
+ quality = "extra-low"
+ elif PolyCount == "8K-Quad":
+ quality = "low"
+ elif PolyCount == "18K-Quad":
+ quality = "medium"
+ elif PolyCount == "50K-Quad":
+ quality = "high"
+ else:
+ quality = "medium"
+
+ return mesh_mode, quality
+
+ def DownLoadFiles(self, Url_List):
+ Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
+ os.makedirs(Save_path, exist_ok=True)
+ model_file_path = None
+ for Item in Url_List.list:
+ url = Item.url
+ file_name = Item.name
+ file_path = os.path.join(Save_path, file_name)
+ if file_path.endswith(".glb"):
+ model_file_path = file_path
+ logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}")
+ max_retries = 5
+ for attempt in range(max_retries):
+ try:
+ with requests.get(url, stream=True) as r:
+ r.raise_for_status()
+ with open(file_path, "wb") as f:
+ shutil.copyfileobj(r.raw, f)
+ break
+ except Exception as e:
+ logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}: {e}")
+ if attempt < max_retries - 1:
+ logging.info("Retrying...")
+ time.sleep(2)
+ else:
+ logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.")
+
+ return model_file_path
+
+
+class Rodin3D_Regular(Rodin3DAPI):
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "Images":
+ (
+ IO.IMAGE,
+ {
+ "forceInput":True,
+ }
+ )
+ },
+ "optional": {
+ **COMMON_PARAMETERS
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(
+ self,
+ Images,
+ Seed,
+ Material_Type,
+ Polygon_count,
+ **kwargs
+ ):
+ tier = "Regular"
+ num_images = Images.shape[0]
+ m_images = []
+ for i in range(num_images):
+ m_images.append(Images[i])
+ mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
+ task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
+ self.poll_for_task_status(subscription_key, **kwargs)
+ Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
+ model = self.DownLoadFiles(Download_List)
+
+ return (model,)
+
+class Rodin3D_Detail(Rodin3DAPI):
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "Images":
+ (
+ IO.IMAGE,
+ {
+ "forceInput":True,
+ }
+ )
+ },
+ "optional": {
+ **COMMON_PARAMETERS
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(
+ self,
+ Images,
+ Seed,
+ Material_Type,
+ Polygon_count,
+ **kwargs
+ ):
+ tier = "Detail"
+ num_images = Images.shape[0]
+ m_images = []
+ for i in range(num_images):
+ m_images.append(Images[i])
+ mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
+ task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
+ self.poll_for_task_status(subscription_key, **kwargs)
+ Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
+ model = self.DownLoadFiles(Download_List)
+
+ return (model,)
+
+class Rodin3D_Smooth(Rodin3DAPI):
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "Images":
+ (
+ IO.IMAGE,
+ {
+ "forceInput":True,
+ }
+ )
+ },
+ "optional": {
+ **COMMON_PARAMETERS
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(
+ self,
+ Images,
+ Seed,
+ Material_Type,
+ Polygon_count,
+ **kwargs
+ ):
+ tier = "Smooth"
+ num_images = Images.shape[0]
+ m_images = []
+ for i in range(num_images):
+ m_images.append(Images[i])
+ mesh_mode, quality = self.GetQualityAndMode(Polygon_count)
+ task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
+ self.poll_for_task_status(subscription_key, **kwargs)
+ Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
+ model = self.DownLoadFiles(Download_List)
+
+ return (model,)
+
+class Rodin3D_Sketch(Rodin3DAPI):
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "Images":
+ (
+ IO.IMAGE,
+ {
+ "forceInput":True,
+ }
+ )
+ },
+ "optional": {
+ "Seed":
+ (
+ IO.INT,
+ {
+ "default":0,
+ "min":0,
+ "max":65535,
+ "display":"number"
+ }
+ )
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(
+ self,
+ Images,
+ Seed,
+ **kwargs
+ ):
+ tier = "Sketch"
+ num_images = Images.shape[0]
+ m_images = []
+ for i in range(num_images):
+ m_images.append(Images[i])
+ material_type = "PBR"
+ quality = "medium"
+ mesh_mode = "Quad"
+ task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs)
+ self.poll_for_task_status(subscription_key, **kwargs)
+ Download_List = self.GetRodinDownloadList(task_uuid, **kwargs)
+ model = self.DownLoadFiles(Download_List)
+
+ return (model,)
+
+# A dictionary that contains all nodes you want to export with their names
+# NOTE: names should be globally unique
+NODE_CLASS_MAPPINGS = {
+ "Rodin3D_Regular": Rodin3D_Regular,
+ "Rodin3D_Detail": Rodin3D_Detail,
+ "Rodin3D_Smooth": Rodin3D_Smooth,
+ "Rodin3D_Sketch": Rodin3D_Sketch,
+}
+
+# A dictionary that contains the friendly/humanly readable titles for the nodes
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "Rodin3D_Regular": "Rodin 3D Generate - Regular Generate",
+ "Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
+ "Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
+ "Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
+}
diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py
new file mode 100644
index 000000000..af4b321f9
--- /dev/null
+++ b/comfy_api_nodes/nodes_runway.py
@@ -0,0 +1,635 @@
+"""Runway API Nodes
+
+API Docs:
+ - https://docs.dev.runwayml.com/api/#tag/Task-management/paths/~1v1~1tasks~1%7Bid%7D/delete
+
+User Guides:
+ - https://help.runwayml.com/hc/en-us/sections/30265301423635-Gen-3-Alpha
+ - https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video
+ - https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo
+ - https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3
+
+"""
+
+from typing import Union, Optional, Any
+from enum import Enum
+
+import torch
+
+from comfy_api_nodes.apis import (
+ RunwayImageToVideoRequest,
+ RunwayImageToVideoResponse,
+ RunwayTaskStatusResponse as TaskStatusResponse,
+ RunwayTaskStatusEnum as TaskStatus,
+ RunwayModelEnum as Model,
+ RunwayDurationEnum as Duration,
+ RunwayAspectRatioEnum as AspectRatio,
+ RunwayPromptImageObject,
+ RunwayPromptImageDetailedObject,
+ RunwayTextToImageRequest,
+ RunwayTextToImageResponse,
+ Model4,
+ ReferenceImage,
+ RunwayTextToImageAspectRatioEnum,
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ PollingOperation,
+ EmptyRequest,
+)
+from comfy_api_nodes.apinode_utils import (
+ upload_images_to_comfyapi,
+ download_url_to_video_output,
+ image_tensor_pair_to_batch,
+ validate_string,
+ download_url_to_image_tensor,
+)
+from comfy_api_nodes.mapper_utils import model_field_to_node_input
+from comfy_api.input_impl import VideoFromFile
+from comfy.comfy_types.node_typing import IO, ComfyNodeABC
+
+PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
+PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
+PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
+
+AVERAGE_DURATION_I2V_SECONDS = 64
+AVERAGE_DURATION_FLF_SECONDS = 256
+AVERAGE_DURATION_T2I_SECONDS = 41
+
+
+class RunwayApiError(Exception):
+ """Base exception for Runway API errors."""
+
+ pass
+
+
+class RunwayGen4TurboAspectRatio(str, Enum):
+ """Aspect ratios supported for Image to Video API when using gen4_turbo model."""
+
+ field_1280_720 = "1280:720"
+ field_720_1280 = "720:1280"
+ field_1104_832 = "1104:832"
+ field_832_1104 = "832:1104"
+ field_960_960 = "960:960"
+ field_1584_672 = "1584:672"
+
+
+class RunwayGen3aAspectRatio(str, Enum):
+ """Aspect ratios supported for Image to Video API when using gen3a_turbo model."""
+
+ field_768_1280 = "768:1280"
+ field_1280_768 = "1280:768"
+
+
+def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
+ """Returns the video URL from the task status response if it exists."""
+ if response.output and len(response.output) > 0:
+ return response.output[0]
+ return None
+
+
+# TODO: replace with updated image validation utils (upstream)
+def validate_input_image(image: torch.Tensor) -> bool:
+ """
+ Validate that the input image is within the Runway API size limits,
+ raising RunwayApiError otherwise.
+ See: https://docs.dev.runwayml.com/assets/inputs/#common-error-reasons
+ """
+ if image.shape[1] >= 8000 or image.shape[2] >= 8000:
+ raise RunwayApiError("Input image dimensions must be less than 8000x8000 pixels.")
+ return True
+
+
+def poll_until_finished(
+ auth_kwargs: dict[str, str],
+ api_endpoint: ApiEndpoint[Any, TaskStatusResponse],
+ estimated_duration: Optional[int] = None,
+ node_id: Optional[str] = None,
+) -> TaskStatusResponse:
+ """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response."""
+ return PollingOperation(
+ poll_endpoint=api_endpoint,
+ completed_statuses=[
+ TaskStatus.SUCCEEDED.value,
+ ],
+ failed_statuses=[
+ TaskStatus.FAILED.value,
+ TaskStatus.CANCELLED.value,
+ ],
+ status_extractor=lambda response: response.status.value,
+ auth_kwargs=auth_kwargs,
+ result_url_extractor=get_video_url_from_task_status,
+ estimated_duration=estimated_duration,
+ node_id=node_id,
+ progress_extractor=extract_progress_from_task_status,
+ ).execute()
+
+
+def extract_progress_from_task_status(
+ response: TaskStatusResponse,
+) -> Union[float, None]:
+ if hasattr(response, "progress") and response.progress is not None:
+ return response.progress * 100
+ return None
+
+
+def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
+ """Returns the image URL from the task status response if it exists."""
+ if response.output and len(response.output) > 0:
+ return response.output[0]
+ return None
+
+
+class RunwayVideoGenNode(ComfyNodeABC):
+ """Runway Video Node Base."""
+
+ RETURN_TYPES = ("VIDEO",)
+ FUNCTION = "api_call"
+ CATEGORY = "api node/video/Runway"
+ API_NODE = True
+
+ def validate_task_created(self, response: RunwayImageToVideoResponse) -> bool:
+ """
+ Validate the task creation response from the Runway API matches
+ expected format.
+ """
+ if not bool(response.id):
+ raise RunwayApiError("Invalid initial response from Runway API.")
+ return True
+
+ def validate_response(self, response: RunwayImageToVideoResponse) -> bool:
+ """
+ Validate the successful task status response from the Runway API
+ matches expected format.
+ """
+ if not response.output or len(response.output) == 0:
+ raise RunwayApiError(
+ "Runway task succeeded but no video data found in response."
+ )
+ return True
+
+ def get_response(
+ self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+ ) -> RunwayImageToVideoResponse:
+ """Poll the task status until it is finished then get the response."""
+ return poll_until_finished(
+ auth_kwargs,
+ ApiEndpoint(
+ path=f"{PATH_GET_TASK_STATUS}/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=TaskStatusResponse,
+ ),
+ estimated_duration=AVERAGE_DURATION_I2V_SECONDS,
+ node_id=node_id,
+ )
+
+ def generate_video(
+ self,
+ request: RunwayImageToVideoRequest,
+ auth_kwargs: dict[str, str],
+ node_id: Optional[str] = None,
+ ) -> tuple[VideoFromFile]:
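+ # Flow: submit the generation request synchronously, poll the task until it
+ # reaches a terminal state, validate the output, then download the video.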
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_IMAGE_TO_VIDEO,
+ method=HttpMethod.POST,
+ request_model=RunwayImageToVideoRequest,
+ response_model=RunwayImageToVideoResponse,
+ ),
+ request=request,
+ auth_kwargs=auth_kwargs,
+ )
+
+ initial_response = initial_operation.execute()
+ self.validate_task_created(initial_response)
+ task_id = initial_response.id
+
+ final_response = self.get_response(task_id, auth_kwargs, node_id)
+ self.validate_response(final_response)
+
+ video_url = get_video_url_from_task_status(final_response)
+ return (download_url_to_video_output(video_url),)
+
+
+class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode):
+ """Runway Image to Video Node using Gen3a Turbo model."""
+
+ DESCRIPTION = "Generate a video from a single starting frame using Gen3a Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo."
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": model_field_to_node_input(
+ IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
+ ),
+ "start_frame": (
+ IO.IMAGE,
+ {"tooltip": "Start frame to be used for the video"},
+ ),
+ "duration": model_field_to_node_input(
+ IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
+ ),
+ "ratio": model_field_to_node_input(
+ IO.COMBO,
+ RunwayImageToVideoRequest,
+ "ratio",
+ enum_type=RunwayGen3aAspectRatio,
+ ),
+ "seed": model_field_to_node_input(
+ IO.INT,
+ RunwayImageToVideoRequest,
+ "seed",
+ control_after_generate=True,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def api_call(
+ self,
+ prompt: str,
+ start_frame: torch.Tensor,
+ duration: str,
+ ratio: str,
+ seed: int,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ) -> tuple[VideoFromFile]:
+ # Validate inputs
+ validate_string(prompt, min_length=1)
+ validate_input_image(start_frame)
+
+ # Upload image
+ download_urls = upload_images_to_comfyapi(
+ start_frame,
+ max_images=1,
+ mime_type="image/png",
+ auth_kwargs=kwargs,
+ )
+ if len(download_urls) != 1:
+ raise RunwayApiError("Failed to upload one or more images to comfy api.")
+
+ return self.generate_video(
+ RunwayImageToVideoRequest(
+ promptText=prompt,
+ seed=seed,
+ model=Model("gen3a_turbo"),
+ duration=Duration(duration),
+ ratio=AspectRatio(ratio),
+ promptImage=RunwayPromptImageObject(
+ root=[
+ RunwayPromptImageDetailedObject(
+ uri=str(download_urls[0]), position="first"
+ )
+ ]
+ ),
+ ),
+ auth_kwargs=kwargs,
+ node_id=unique_id,
+ )
+
+
+class RunwayImageToVideoNodeGen4(RunwayVideoGenNode):
+ """Runway Image to Video Node using Gen4 Turbo model."""
+
+ DESCRIPTION = "Generate a video from a single starting frame using Gen4 Turbo model. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video."
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": model_field_to_node_input(
+ IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
+ ),
+ "start_frame": (
+ IO.IMAGE,
+ {"tooltip": "Start frame to be used for the video"},
+ ),
+ "duration": model_field_to_node_input(
+ IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
+ ),
+ "ratio": model_field_to_node_input(
+ IO.COMBO,
+ RunwayImageToVideoRequest,
+ "ratio",
+ enum_type=RunwayGen4TurboAspectRatio,
+ ),
+ "seed": model_field_to_node_input(
+ IO.INT,
+ RunwayImageToVideoRequest,
+ "seed",
+ control_after_generate=True,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def api_call(
+ self,
+ prompt: str,
+ start_frame: torch.Tensor,
+ duration: str,
+ ratio: str,
+ seed: int,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ) -> tuple[VideoFromFile]:
+ # Validate inputs
+ validate_string(prompt, min_length=1)
+ validate_input_image(start_frame)
+
+ # Upload image
+ download_urls = upload_images_to_comfyapi(
+ start_frame,
+ max_images=1,
+ mime_type="image/png",
+ auth_kwargs=kwargs,
+ )
+ if len(download_urls) != 1:
+ raise RunwayApiError("Failed to upload one or more images to comfy api.")
+
+ return self.generate_video(
+ RunwayImageToVideoRequest(
+ promptText=prompt,
+ seed=seed,
+ model=Model("gen4_turbo"),
+ duration=Duration(duration),
+ ratio=AspectRatio(ratio),
+ promptImage=RunwayPromptImageObject(
+ root=[
+ RunwayPromptImageDetailedObject(
+ uri=str(download_urls[0]), position="first"
+ )
+ ]
+ ),
+ ),
+ auth_kwargs=kwargs,
+ node_id=unique_id,
+ )
+
+
+class RunwayFirstLastFrameNode(RunwayVideoGenNode):
+ """Runway First-Last Frame Node."""
+
+ DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3."
+
+ def get_response(
+ self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+ ) -> RunwayImageToVideoResponse:
+ return poll_until_finished(
+ auth_kwargs,
+ ApiEndpoint(
+ path=f"{PATH_GET_TASK_STATUS}/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=TaskStatusResponse,
+ ),
+ estimated_duration=AVERAGE_DURATION_FLF_SECONDS,
+ node_id=node_id,
+ )
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": model_field_to_node_input(
+ IO.STRING, RunwayImageToVideoRequest, "promptText", multiline=True
+ ),
+ "start_frame": (
+ IO.IMAGE,
+ {"tooltip": "Start frame to be used for the video"},
+ ),
+ "end_frame": (
+ IO.IMAGE,
+ {
+ "tooltip": "End frame to be used for the video. Supported for gen3a_turbo only."
+ },
+ ),
+ "duration": model_field_to_node_input(
+ IO.COMBO, RunwayImageToVideoRequest, "duration", enum_type=Duration
+ ),
+ "ratio": model_field_to_node_input(
+ IO.COMBO,
+ RunwayImageToVideoRequest,
+ "ratio",
+ enum_type=RunwayGen3aAspectRatio,
+ ),
+ "seed": model_field_to_node_input(
+ IO.INT,
+ RunwayImageToVideoRequest,
+ "seed",
+ control_after_generate=True,
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(
+ self,
+ prompt: str,
+ start_frame: torch.Tensor,
+ end_frame: torch.Tensor,
+ duration: str,
+ ratio: str,
+ seed: int,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ) -> tuple[VideoFromFile]:
+ # Validate inputs
+ validate_string(prompt, min_length=1)
+ validate_input_image(start_frame)
+ validate_input_image(end_frame)
+
+ # Upload images
+ stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame)
+ download_urls = upload_images_to_comfyapi(
+ stacked_input_images,
+ max_images=2,
+ mime_type="image/png",
+ auth_kwargs=kwargs,
+ )
+ if len(download_urls) != 2:
+ raise RunwayApiError("Failed to upload one or more images to comfy api.")
+
+ return self.generate_video(
+ RunwayImageToVideoRequest(
+ promptText=prompt,
+ seed=seed,
+ model=Model("gen3a_turbo"),
+ duration=Duration(duration),
+ ratio=AspectRatio(ratio),
+ promptImage=RunwayPromptImageObject(
+ root=[
+ RunwayPromptImageDetailedObject(
+ uri=str(download_urls[0]), position="first"
+ ),
+ RunwayPromptImageDetailedObject(
+ uri=str(download_urls[1]), position="last"
+ ),
+ ]
+ ),
+ ),
+ auth_kwargs=kwargs,
+ node_id=unique_id,
+ )
+
+
+class RunwayTextToImageNode(ComfyNodeABC):
+ """Runway Text to Image Node."""
+
+ RETURN_TYPES = ("IMAGE",)
+ FUNCTION = "api_call"
+ CATEGORY = "api node/image/Runway"
+ API_NODE = True
+ DESCRIPTION = "Generate an image from a text prompt using Runway's Gen 4 model. You can also include reference images to guide the generation."
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": model_field_to_node_input(
+ IO.STRING, RunwayTextToImageRequest, "promptText", multiline=True
+ ),
+ "ratio": model_field_to_node_input(
+ IO.COMBO,
+ RunwayTextToImageRequest,
+ "ratio",
+ enum_type=RunwayTextToImageAspectRatioEnum,
+ ),
+ },
+ "optional": {
+ "reference_image": (
+ IO.IMAGE,
+ {"tooltip": "Optional reference image to guide the generation"},
+ )
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ def validate_task_created(self, response: RunwayTextToImageResponse) -> bool:
+ """
+ Validate the task creation response from the Runway API matches
+ expected format.
+ """
+        if not response.id:
+ raise RunwayApiError("Invalid initial response from Runway API.")
+ return True
+
+ def validate_response(self, response: TaskStatusResponse) -> bool:
+ """
+ Validate the successful task status response from the Runway API
+ matches expected format.
+ """
+ if not response.output or len(response.output) == 0:
+ raise RunwayApiError(
+ "Runway task succeeded but no image data found in response."
+ )
+ return True
+
+ def get_response(
+ self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None
+ ) -> TaskStatusResponse:
+ """Poll the task status until it is finished then get the response."""
+ return poll_until_finished(
+ auth_kwargs,
+ ApiEndpoint(
+ path=f"{PATH_GET_TASK_STATUS}/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=TaskStatusResponse,
+ ),
+ estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
+ node_id=node_id,
+ )
+
+ def api_call(
+ self,
+ prompt: str,
+ ratio: str,
+ reference_image: Optional[torch.Tensor] = None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ) -> tuple[torch.Tensor]:
+ # Validate inputs
+ validate_string(prompt, min_length=1)
+
+ # Prepare reference images if provided
+ reference_images = None
+ if reference_image is not None:
+ validate_input_image(reference_image)
+ download_urls = upload_images_to_comfyapi(
+ reference_image,
+ max_images=1,
+ mime_type="image/png",
+ auth_kwargs=kwargs,
+ )
+ if len(download_urls) != 1:
+ raise RunwayApiError("Failed to upload reference image to comfy api.")
+
+ reference_images = [ReferenceImage(uri=str(download_urls[0]))]
+
+ # Create request
+ request = RunwayTextToImageRequest(
+ promptText=prompt,
+ model=Model4.gen4_image,
+ ratio=ratio,
+ referenceImages=reference_images,
+ )
+
+ # Execute initial request
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path=PATH_TEXT_TO_IMAGE,
+ method=HttpMethod.POST,
+ request_model=RunwayTextToImageRequest,
+ response_model=RunwayTextToImageResponse,
+ ),
+ request=request,
+ auth_kwargs=kwargs,
+ )
+
+ initial_response = initial_operation.execute()
+ self.validate_task_created(initial_response)
+ task_id = initial_response.id
+
+ # Poll for completion
+ final_response = self.get_response(
+ task_id, auth_kwargs=kwargs, node_id=unique_id
+ )
+ self.validate_response(final_response)
+
+ # Download and return image
+ image_url = get_image_url_from_task_status(final_response)
+ return (download_url_to_image_tensor(image_url),)
+
+
+NODE_CLASS_MAPPINGS = {
+ "RunwayFirstLastFrameNode": RunwayFirstLastFrameNode,
+ "RunwayImageToVideoNodeGen3a": RunwayImageToVideoNodeGen3a,
+ "RunwayImageToVideoNodeGen4": RunwayImageToVideoNodeGen4,
+ "RunwayTextToImageNode": RunwayTextToImageNode,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "RunwayFirstLastFrameNode": "Runway First-Last-Frame to Video",
+ "RunwayImageToVideoNodeGen3a": "Runway Image to Video (Gen3a Turbo)",
+ "RunwayImageToVideoNodeGen4": "Runway Image to Video (Gen4 Turbo)",
+ "RunwayTextToImageNode": "Runway Text to Image",
+}
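
All four Runway nodes above follow the same submit-then-poll contract: create a task, poll the task-status endpoint until a terminal state, then download the result. The sketch below is a minimal stand-alone illustration of that loop; `create_task`, `fetch_status`, and the status strings are hypothetical stand-ins, not the actual ComfyUI client API.

```python
import time

def run_task(create_task, fetch_status, poll_interval=3.0, timeout=600.0):
    """Submit a task, then poll until it reaches a terminal state."""
    task_id = create_task()                     # POST to the generation endpoint
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        status, result = fetch_status(task_id)  # GET <task status path>/{task_id}
        if status == "SUCCEEDED":
            return result
        if status in ("FAILED", "CANCELLED"):
            raise RuntimeError(f"Task {task_id} ended with status {status}")
        time.sleep(poll_interval)
    raise TimeoutError(f"Task {task_id} did not finish within {timeout}s")
```

In the nodes themselves this loop lives inside `poll_until_finished`, which adds progress reporting and an estimated duration per node type.
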
diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py
new file mode 100644
index 000000000..02e421678
--- /dev/null
+++ b/comfy_api_nodes/nodes_stability.py
@@ -0,0 +1,614 @@
+from inspect import cleandoc
+from comfy.comfy_types.node_typing import IO
+from comfy_api_nodes.apis.stability_api import (
+ StabilityUpscaleConservativeRequest,
+ StabilityUpscaleCreativeRequest,
+ StabilityAsyncResponse,
+ StabilityResultsGetResponse,
+ StabilityStable3_5Request,
+ StabilityStableUltraRequest,
+ StabilityStableUltraResponse,
+ StabilityAspectRatio,
+ Stability_SD3_5_Model,
+ Stability_SD3_5_GenerationMode,
+ get_stability_style_presets,
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ PollingOperation,
+ EmptyRequest,
+)
+from comfy_api_nodes.apinode_utils import (
+ bytesio_to_image_tensor,
+ tensor_to_bytesio,
+ validate_string,
+)
+
+import torch
+import base64
+from io import BytesIO
+from enum import Enum
+
+
+class StabilityPollStatus(str, Enum):
+ finished = "finished"
+ in_progress = "in_progress"
+ failed = "failed"
+
+
+def get_async_dummy_status(x: StabilityResultsGetResponse):
+ if x.name is not None or x.errors is not None:
+ return StabilityPollStatus.failed
+ elif x.finish_reason is not None:
+ return StabilityPollStatus.finished
+ return StabilityPollStatus.in_progress
+
+
+class StabilityStableImageUltraNode:
+ """
+    Generates images synchronously from a text prompt, with control over aspect ratio and style, via Stability AI's Stable Image Ultra endpoint.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Stability AI"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
+ "What you wish to see in the output image. A strong, descriptive prompt that clearly defines" +
+ "elements, colors, and subjects will lead to better results. " +
+ "To control the weight of a given word use the format `(word:weight)`," +
+ "where `word` is the word you'd like to control the weight of and `weight`" +
+ "is a value between 0 and 1. For example: `The sky was a crisp (blue:0.3) and (green:0.8)`" +
+ "would convey a sky that was blue and green, but more green than blue."
+ },
+ ),
+ "aspect_ratio": ([x.value for x in StabilityAspectRatio],
+ {
+ "default": StabilityAspectRatio.ratio_1_1,
+ "tooltip": "Aspect ratio of generated image.",
+ },
+ ),
+ "style_preset": (get_stability_style_presets(),
+ {
+ "tooltip": "Optional desired style of generated image.",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 4294967294,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "optional": {
+ "image": (IO.IMAGE,),
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "A blurb of text describing what you do not wish to see in the output image. This is an advanced feature."
+ },
+ ),
+ "image_denoise": (
+ IO.FLOAT,
+ {
+ "default": 0.5,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ "tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int,
+ negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
+ **kwargs):
+ validate_string(prompt, strip_whitespace=False)
+ # prepare image binary if image present
+ image_binary = None
+ if image is not None:
+ image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
+ else:
+ image_denoise = None
+
+ if not negative_prompt:
+ negative_prompt = None
+ if style_preset == "None":
+ style_preset = None
+
+ files = {
+ "image": image_binary
+ }
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/stability/v2beta/stable-image/generate/ultra",
+ method=HttpMethod.POST,
+ request_model=StabilityStableUltraRequest,
+ response_model=StabilityStableUltraResponse,
+ ),
+ request=StabilityStableUltraRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ aspect_ratio=aspect_ratio,
+ seed=seed,
+ strength=image_denoise,
+ style_preset=style_preset,
+ ),
+ files=files,
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+ response_api = operation.execute()
+
+ if response_api.finish_reason != "SUCCESS":
+ raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
+
+ image_data = base64.b64decode(response_api.image)
+ returned_image = bytesio_to_image_tensor(BytesIO(image_data))
+
+ return (returned_image,)
+
+
+class StabilityStableImageSD_3_5Node:
+ """
+    Generates images synchronously from a text prompt via Stability AI's Stable Diffusion 3.5 endpoint, with an optional image-to-image mode.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Stability AI"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
+ },
+ ),
+ "model": ([x.value for x in Stability_SD3_5_Model],),
+ "aspect_ratio": ([x.value for x in StabilityAspectRatio],
+ {
+ "default": StabilityAspectRatio.ratio_1_1,
+ "tooltip": "Aspect ratio of generated image.",
+ },
+ ),
+ "style_preset": (get_stability_style_presets(),
+ {
+ "tooltip": "Optional desired style of generated image.",
+ },
+ ),
+ "cfg_scale": (
+ IO.FLOAT,
+ {
+ "default": 4.0,
+ "min": 1.0,
+ "max": 10.0,
+ "step": 0.1,
+ "tooltip": "How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 4294967294,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "optional": {
+ "image": (IO.IMAGE,),
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
+ },
+ ),
+ "image_denoise": (
+ IO.FLOAT,
+ {
+ "default": 0.5,
+ "min": 0.0,
+ "max": 1.0,
+ "step": 0.01,
+ "tooltip": "Denoise of input image; 0.0 yields image identical to input, 1.0 is as if no image was provided at all.",
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float,
+ negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None,
+ **kwargs):
+ validate_string(prompt, strip_whitespace=False)
+ # prepare image binary if image present
+ image_binary = None
+ mode = Stability_SD3_5_GenerationMode.text_to_image
+ if image is not None:
+ image_binary = tensor_to_bytesio(image, total_pixels=1504*1504).read()
+ mode = Stability_SD3_5_GenerationMode.image_to_image
+ aspect_ratio = None
+ else:
+ image_denoise = None
+
+ if not negative_prompt:
+ negative_prompt = None
+ if style_preset == "None":
+ style_preset = None
+
+ files = {
+ "image": image_binary
+ }
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/stability/v2beta/stable-image/generate/sd3",
+ method=HttpMethod.POST,
+ request_model=StabilityStable3_5Request,
+ response_model=StabilityStableUltraResponse,
+ ),
+ request=StabilityStable3_5Request(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ aspect_ratio=aspect_ratio,
+ seed=seed,
+ strength=image_denoise,
+ style_preset=style_preset,
+ cfg_scale=cfg_scale,
+ model=model,
+ mode=mode,
+ ),
+ files=files,
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+ response_api = operation.execute()
+
+ if response_api.finish_reason != "SUCCESS":
+ raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
+
+ image_data = base64.b64decode(response_api.image)
+ returned_image = bytesio_to_image_tensor(BytesIO(image_data))
+
+ return (returned_image,)
+
+
+class StabilityUpscaleConservativeNode:
+ """
+ Upscale image with minimal alterations to 4K resolution.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Stability AI"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE,),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
+ },
+ ),
+ "creativity": (
+ IO.FLOAT,
+ {
+ "default": 0.35,
+ "min": 0.2,
+ "max": 0.5,
+ "step": 0.01,
+ "tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 4294967294,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "optional": {
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None,
+ **kwargs):
+ validate_string(prompt, strip_whitespace=False)
+ image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
+
+ if not negative_prompt:
+ negative_prompt = None
+
+ files = {
+ "image": image_binary
+ }
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/stability/v2beta/stable-image/upscale/conservative",
+ method=HttpMethod.POST,
+ request_model=StabilityUpscaleConservativeRequest,
+ response_model=StabilityStableUltraResponse,
+ ),
+ request=StabilityUpscaleConservativeRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ creativity=round(creativity,2),
+ seed=seed,
+ ),
+ files=files,
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+ response_api = operation.execute()
+
+ if response_api.finish_reason != "SUCCESS":
+ raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
+
+ image_data = base64.b64decode(response_api.image)
+ returned_image = bytesio_to_image_tensor(BytesIO(image_data))
+
+ return (returned_image,)
+
+
+class StabilityUpscaleCreativeNode:
+ """
+    Upscale image to 4K resolution with creative reinterpretation of details; suited to heavily degraded inputs.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Stability AI"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE,),
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "What you wish to see in the output image. A strong, descriptive prompt that clearly defines elements, colors, and subjects will lead to better results."
+ },
+ ),
+ "creativity": (
+ IO.FLOAT,
+ {
+ "default": 0.3,
+ "min": 0.1,
+ "max": 0.5,
+ "step": 0.01,
+ "tooltip": "Controls the likelihood of creating additional details not heavily conditioned by the init image.",
+ },
+ ),
+ "style_preset": (get_stability_style_presets(),
+ {
+ "tooltip": "Optional desired style of generated image.",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 4294967294,
+ "control_after_generate": True,
+ "tooltip": "The random seed used for creating the noise.",
+ },
+ ),
+ },
+ "optional": {
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "default": "",
+ "forceInput": True,
+ "tooltip": "Keywords of what you do not wish to see in the output image. This is an advanced feature."
+ },
+ ),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None,
+ **kwargs):
+ validate_string(prompt, strip_whitespace=False)
+ image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read()
+
+ if not negative_prompt:
+ negative_prompt = None
+ if style_preset == "None":
+ style_preset = None
+
+ files = {
+ "image": image_binary
+ }
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/stability/v2beta/stable-image/upscale/creative",
+ method=HttpMethod.POST,
+ request_model=StabilityUpscaleCreativeRequest,
+ response_model=StabilityAsyncResponse,
+ ),
+ request=StabilityUpscaleCreativeRequest(
+ prompt=prompt,
+ negative_prompt=negative_prompt,
+ creativity=round(creativity,2),
+ style_preset=style_preset,
+ seed=seed,
+ ),
+ files=files,
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+ response_api = operation.execute()
+
+ operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=f"/proxy/stability/v2beta/results/{response_api.id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=StabilityResultsGetResponse,
+ ),
+ poll_interval=3,
+ completed_statuses=[StabilityPollStatus.finished],
+ failed_statuses=[StabilityPollStatus.failed],
+ status_extractor=lambda x: get_async_dummy_status(x),
+ auth_kwargs=kwargs,
+ )
+ response_poll: StabilityResultsGetResponse = operation.execute()
+
+ if response_poll.finish_reason != "SUCCESS":
+ raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
+
+ image_data = base64.b64decode(response_poll.result)
+ returned_image = bytesio_to_image_tensor(BytesIO(image_data))
+
+ return (returned_image,)
+
+
+class StabilityUpscaleFastNode:
+ """
+ Quickly upscales an image via Stability API call to 4x its original size; intended for upscaling low-quality/compressed images.
+ """
+
+ RETURN_TYPES = (IO.IMAGE,)
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "api_call"
+ API_NODE = True
+ CATEGORY = "api node/image/Stability AI"
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": (IO.IMAGE,),
+ },
+ "optional": {
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ },
+ }
+
+ def api_call(self, image: torch.Tensor,
+ **kwargs):
+ image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read()
+
+ files = {
+ "image": image_binary
+ }
+
+ operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/stability/v2beta/stable-image/upscale/fast",
+ method=HttpMethod.POST,
+ request_model=EmptyRequest,
+ response_model=StabilityStableUltraResponse,
+ ),
+ request=EmptyRequest(),
+ files=files,
+ content_type="multipart/form-data",
+ auth_kwargs=kwargs,
+ )
+ response_api = operation.execute()
+
+ if response_api.finish_reason != "SUCCESS":
+ raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
+
+ image_data = base64.b64decode(response_api.image)
+ returned_image = bytesio_to_image_tensor(BytesIO(image_data))
+
+ return (returned_image,)
+
+
+# A dictionary that contains all nodes you want to export with their names
+# NOTE: names should be globally unique
+NODE_CLASS_MAPPINGS = {
+ "StabilityStableImageUltraNode": StabilityStableImageUltraNode,
+ "StabilityStableImageSD_3_5Node": StabilityStableImageSD_3_5Node,
+ "StabilityUpscaleConservativeNode": StabilityUpscaleConservativeNode,
+ "StabilityUpscaleCreativeNode": StabilityUpscaleCreativeNode,
+ "StabilityUpscaleFastNode": StabilityUpscaleFastNode,
+}
+
+# A dictionary that contains the friendly/humanly readable titles for the nodes
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "StabilityStableImageUltraNode": "Stability AI Stable Image Ultra",
+ "StabilityStableImageSD_3_5Node": "Stability AI Stable Diffusion 3.5 Image",
+ "StabilityUpscaleConservativeNode": "Stability AI Upscale Conservative",
+ "StabilityUpscaleCreativeNode": "Stability AI Upscale Creative",
+ "StabilityUpscaleFastNode": "Stability AI Upscale Fast",
+}
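
Each Stability node above sends its pydantic request as `multipart/form-data`, with the scalar fields travelling as form fields and the raw image bytes under the `image` key. A hedged sketch of that wire shape — the URL and auth header below are placeholders, since `SynchronousOperation` handles the real proxy contract inside the nodes:

```python
from typing import Optional
import requests

def call_stability_endpoint(url: str, api_key: str, fields: dict, image_bytes: Optional[bytes]):
    """POST form fields plus an optional raw image as multipart/form-data."""
    resp = requests.post(
        url,
        headers={"Authorization": f"Bearer {api_key}"},  # placeholder auth scheme
        data={k: v for k, v in fields.items() if v is not None},  # drop unset fields
        files={"image": image_bytes} if image_bytes is not None else None,
        timeout=120,
    )
    resp.raise_for_status()
    return resp.json()
```

Unset optional fields (empty negative prompt, `"None"` style preset, missing image) are normalized to `None` before the request is built, which is why each `api_call` above clears them explicitly.
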
diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py
new file mode 100644
index 000000000..65f3b21f5
--- /dev/null
+++ b/comfy_api_nodes/nodes_tripo.py
@@ -0,0 +1,574 @@
+import os
+from folder_paths import get_output_directory
+from comfy_api_nodes.mapper_utils import model_field_to_node_input
+from comfy.comfy_types.node_typing import IO
+from comfy_api_nodes.apis import (
+ TripoOrientation,
+ TripoModelVersion,
+)
+from comfy_api_nodes.apis.tripo_api import (
+ TripoTaskType,
+ TripoStyle,
+ TripoFileReference,
+ TripoFileEmptyReference,
+ TripoUrlReference,
+ TripoTaskResponse,
+ TripoTaskStatus,
+ TripoTextToModelRequest,
+ TripoImageToModelRequest,
+ TripoMultiviewToModelRequest,
+ TripoTextureModelRequest,
+ TripoRefineModelRequest,
+ TripoAnimateRigRequest,
+ TripoAnimateRetargetRequest,
+ TripoConvertModelRequest,
+)
+
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ PollingOperation,
+ EmptyRequest,
+)
+from comfy_api_nodes.apinode_utils import (
+ upload_images_to_comfyapi,
+ download_url_to_bytesio,
+)
+
+
+def upload_image_to_tripo(image, **kwargs):
+ urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)
+ return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg"))
+
+def get_model_url_from_response(response: TripoTaskResponse) -> str:
+ if response.data is not None:
+ for key in ["pbr_model", "model", "base_model"]:
+ if getattr(response.data.output, key, None) is not None:
+ return getattr(response.data.output, key)
+ raise RuntimeError(f"Failed to get model url from response: {response}")
+
+
+def poll_until_finished(
+ kwargs: dict[str, str],
+ response: TripoTaskResponse,
+) -> tuple[str, str]:
+ """Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response."""
+ if response.code != 0:
+ raise RuntimeError(f"Failed to generate mesh: {response.error}")
+ task_id = response.data.task_id
+ response_poll = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path=f"/proxy/tripo/v2/openapi/task/{task_id}",
+ method=HttpMethod.GET,
+ request_model=EmptyRequest,
+ response_model=TripoTaskResponse,
+ ),
+ completed_statuses=[TripoTaskStatus.SUCCESS],
+ failed_statuses=[
+ TripoTaskStatus.FAILED,
+ TripoTaskStatus.CANCELLED,
+ TripoTaskStatus.UNKNOWN,
+ TripoTaskStatus.BANNED,
+ TripoTaskStatus.EXPIRED,
+ ],
+ status_extractor=lambda x: x.data.status,
+ auth_kwargs=kwargs,
+ node_id=kwargs["unique_id"],
+ result_url_extractor=get_model_url_from_response,
+ progress_extractor=lambda x: x.data.progress,
+ ).execute()
+ if response_poll.data.status == TripoTaskStatus.SUCCESS:
+ url = get_model_url_from_response(response_poll)
+ bytesio = download_url_to_bytesio(url)
+ # Save the downloaded model file
+ model_file = f"tripo_model_{task_id}.glb"
+ with open(os.path.join(get_output_directory(), model_file), "wb") as f:
+ f.write(bytesio.getvalue())
+ return model_file, task_id
+ raise RuntimeError(f"Failed to generate mesh: {response_poll}")
+
+class TripoTextToModelNode:
+ """
+ Generates 3D models synchronously based on a text prompt using Tripo's API.
+ """
+ AVERAGE_DURATION = 80
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": ("STRING", {"multiline": True}),
+ },
+ "optional": {
+ "negative_prompt": ("STRING", {"multiline": True}),
+ "model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion),
+ "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
+ "texture": ("BOOLEAN", {"default": True}),
+ "pbr": ("BOOLEAN", {"default": True}),
+ "image_seed": ("INT", {"default": 42}),
+ "model_seed": ("INT", {"default": 42}),
+ "texture_seed": ("INT", {"default": 42}),
+ "texture_quality": (["standard", "detailed"], {"default": "standard"}),
+ "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
+ "quad": ("BOOLEAN", {"default": False})
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
+ RETURN_NAMES = ("model_file", "model task_id")
+ FUNCTION = "generate_mesh"
+ CATEGORY = "api node/3d/Tripo"
+ API_NODE = True
+ OUTPUT_NODE = True
+
+ def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
+ style_enum = None if style == "None" else style
+ if not prompt:
+ raise RuntimeError("Prompt is required")
+ response = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/tripo/v2/openapi/task",
+ method=HttpMethod.POST,
+ request_model=TripoTextToModelRequest,
+ response_model=TripoTaskResponse,
+ ),
+ request=TripoTextToModelRequest(
+ type=TripoTaskType.TEXT_TO_MODEL,
+ prompt=prompt,
+ negative_prompt=negative_prompt if negative_prompt else None,
+ model_version=model_version,
+ style=style_enum,
+ texture=texture,
+ pbr=pbr,
+ image_seed=image_seed,
+ model_seed=model_seed,
+ texture_seed=texture_seed,
+ texture_quality=texture_quality,
+ face_limit=face_limit,
+ auto_size=True,
+ quad=quad
+ ),
+ auth_kwargs=kwargs,
+ ).execute()
+ return poll_until_finished(kwargs, response)
+
+class TripoImageToModelNode:
+ """
+ Generates 3D models synchronously based on a single image using Tripo's API.
+ """
+ AVERAGE_DURATION = 80
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("IMAGE",),
+ },
+ "optional": {
+ "model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion),
+ "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"),
+ "texture": ("BOOLEAN", {"default": True}),
+ "pbr": ("BOOLEAN", {"default": True}),
+ "model_seed": ("INT", {"default": 42}),
+ "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
+ "texture_seed": ("INT", {"default": 42}),
+ "texture_quality": (["standard", "detailed"], {"default": "standard"}),
+ "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
+ "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
+ "quad": ("BOOLEAN", {"default": False})
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
+ RETURN_NAMES = ("model_file", "model task_id")
+ FUNCTION = "generate_mesh"
+ CATEGORY = "api node/3d/Tripo"
+ API_NODE = True
+ OUTPUT_NODE = True
+
+ def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs):
+ style_enum = None if style == "None" else style
+ if image is None:
+ raise RuntimeError("Image is required")
+ tripo_file = upload_image_to_tripo(image, **kwargs)
+ response = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/tripo/v2/openapi/task",
+ method=HttpMethod.POST,
+ request_model=TripoImageToModelRequest,
+ response_model=TripoTaskResponse,
+ ),
+ request=TripoImageToModelRequest(
+ type=TripoTaskType.IMAGE_TO_MODEL,
+ file=tripo_file,
+ model_version=model_version,
+ style=style_enum,
+ texture=texture,
+ pbr=pbr,
+ model_seed=model_seed,
+ orientation=orientation,
+ texture_alignment=texture_alignment,
+ texture_seed=texture_seed,
+ texture_quality=texture_quality,
+ face_limit=face_limit,
+ auto_size=True,
+ quad=quad
+ ),
+ auth_kwargs=kwargs,
+ ).execute()
+ return poll_until_finished(kwargs, response)
+
+class TripoMultiviewToModelNode:
+ """
+ Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API.
+ """
+ AVERAGE_DURATION = 80
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "image": ("IMAGE",),
+ },
+ "optional": {
+ "image_left": ("IMAGE",),
+ "image_back": ("IMAGE",),
+ "image_right": ("IMAGE",),
+ "model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion),
+ "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation),
+ "texture": ("BOOLEAN", {"default": True}),
+ "pbr": ("BOOLEAN", {"default": True}),
+ "model_seed": ("INT", {"default": 42}),
+ "texture_seed": ("INT", {"default": 42}),
+ "texture_quality": (["standard", "detailed"], {"default": "standard"}),
+ "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
+ "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
+ "quad": ("BOOLEAN", {"default": False})
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
+ RETURN_NAMES = ("model_file", "model task_id")
+ FUNCTION = "generate_mesh"
+ CATEGORY = "api node/3d/Tripo"
+ API_NODE = True
+ OUTPUT_NODE = True
+
+ def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs):
+ if image is None:
+ raise RuntimeError("front image for multiview is required")
+ images = []
+ image_dict = {
+ "image": image,
+ "image_left": image_left,
+ "image_back": image_back,
+ "image_right": image_right
+ }
+ if image_left is None and image_back is None and image_right is None:
+ raise RuntimeError("At least one of left, back, or right image must be provided for multiview")
+ for image_name in ["image", "image_left", "image_back", "image_right"]:
+ image_ = image_dict[image_name]
+ if image_ is not None:
+ tripo_file = upload_image_to_tripo(image_, **kwargs)
+ images.append(tripo_file)
+ else:
+ images.append(TripoFileEmptyReference())
+ response = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/tripo/v2/openapi/task",
+ method=HttpMethod.POST,
+ request_model=TripoMultiviewToModelRequest,
+ response_model=TripoTaskResponse,
+ ),
+ request=TripoMultiviewToModelRequest(
+ type=TripoTaskType.MULTIVIEW_TO_MODEL,
+ files=images,
+ model_version=model_version,
+ orientation=orientation,
+ texture=texture,
+ pbr=pbr,
+ model_seed=model_seed,
+ texture_seed=texture_seed,
+ texture_quality=texture_quality,
+ texture_alignment=texture_alignment,
+ face_limit=face_limit,
+ quad=quad,
+ ),
+ auth_kwargs=kwargs,
+ ).execute()
+ return poll_until_finished(kwargs, response)
+
+class TripoTextureNode:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "model_task_id": ("MODEL_TASK_ID",),
+ },
+ "optional": {
+ "texture": ("BOOLEAN", {"default": True}),
+ "pbr": ("BOOLEAN", {"default": True}),
+ "texture_seed": ("INT", {"default": 42}),
+ "texture_quality": (["standard", "detailed"], {"default": "standard"}),
+ "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
+ RETURN_NAMES = ("model_file", "model task_id")
+ FUNCTION = "generate_mesh"
+ CATEGORY = "api node/3d/Tripo"
+ API_NODE = True
+ OUTPUT_NODE = True
+ AVERAGE_DURATION = 80
+
+ def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs):
+ response = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/tripo/v2/openapi/task",
+ method=HttpMethod.POST,
+ request_model=TripoTextureModelRequest,
+ response_model=TripoTaskResponse,
+ ),
+ request=TripoTextureModelRequest(
+ original_model_task_id=model_task_id,
+ texture=texture,
+ pbr=pbr,
+ texture_seed=texture_seed,
+ texture_quality=texture_quality,
+ texture_alignment=texture_alignment
+ ),
+ auth_kwargs=kwargs,
+ ).execute()
+ return poll_until_finished(kwargs, response)
+
+
+class TripoRefineNode:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "model_task_id": ("MODEL_TASK_ID", {
+ "tooltip": "Must be a v1.4 Tripo model"
+ }),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only."
+
+ RETURN_TYPES = ("STRING", "MODEL_TASK_ID",)
+ RETURN_NAMES = ("model_file", "model task_id")
+ FUNCTION = "generate_mesh"
+ CATEGORY = "api node/3d/Tripo"
+ API_NODE = True
+ OUTPUT_NODE = True
+ AVERAGE_DURATION = 240
+
+ def generate_mesh(self, model_task_id, **kwargs):
+ response = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/tripo/v2/openapi/task",
+ method=HttpMethod.POST,
+ request_model=TripoRefineModelRequest,
+ response_model=TripoTaskResponse,
+ ),
+ request=TripoRefineModelRequest(
+ draft_model_task_id=model_task_id
+ ),
+ auth_kwargs=kwargs,
+ ).execute()
+ return poll_until_finished(kwargs, response)
+
+
+class TripoRigNode:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "original_model_task_id": ("MODEL_TASK_ID",),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("STRING", "RIG_TASK_ID")
+ RETURN_NAMES = ("model_file", "rig task_id")
+ FUNCTION = "generate_mesh"
+ CATEGORY = "api node/3d/Tripo"
+ API_NODE = True
+ OUTPUT_NODE = True
+ AVERAGE_DURATION = 180
+
+ def generate_mesh(self, original_model_task_id, **kwargs):
+ response = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/tripo/v2/openapi/task",
+ method=HttpMethod.POST,
+ request_model=TripoAnimateRigRequest,
+ response_model=TripoTaskResponse,
+ ),
+ request=TripoAnimateRigRequest(
+ original_model_task_id=original_model_task_id,
+ out_format="glb",
+ spec="tripo"
+ ),
+ auth_kwargs=kwargs,
+ ).execute()
+ return poll_until_finished(kwargs, response)
+
+class TripoRetargetNode:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "original_model_task_id": ("RIG_TASK_ID",),
+ "animation": ([
+ "preset:idle",
+ "preset:walk",
+ "preset:climb",
+ "preset:jump",
+ "preset:slash",
+ "preset:shoot",
+ "preset:hurt",
+ "preset:fall",
+ "preset:turn",
+ ],),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = ("STRING", "RETARGET_TASK_ID")
+ RETURN_NAMES = ("model_file", "retarget task_id")
+ FUNCTION = "generate_mesh"
+ CATEGORY = "api node/3d/Tripo"
+ API_NODE = True
+ OUTPUT_NODE = True
+ AVERAGE_DURATION = 30
+
+ def generate_mesh(self, animation, original_model_task_id, **kwargs):
+ response = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/tripo/v2/openapi/task",
+ method=HttpMethod.POST,
+ request_model=TripoAnimateRetargetRequest,
+ response_model=TripoTaskResponse,
+ ),
+ request=TripoAnimateRetargetRequest(
+ original_model_task_id=original_model_task_id,
+ animation=animation,
+ out_format="glb",
+ bake_animation=True
+ ),
+ auth_kwargs=kwargs,
+ ).execute()
+ return poll_until_finished(kwargs, response)
+
+class TripoConversionNode:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",),
+ "format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],),
+ },
+ "optional": {
+ "quad": ("BOOLEAN", {"default": False}),
+ "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}),
+ "texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}),
+ "texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"})
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ @classmethod
+ def VALIDATE_INPUTS(cls, input_types):
+        # Only the connected type of original_model_task_id needs checking here;
+        # it accepts several task-id types, so we validate the match ourselves.
+ if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"):
+ return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type"
+ return True
+
+ RETURN_TYPES = ()
+ FUNCTION = "generate_mesh"
+ CATEGORY = "api node/3d/Tripo"
+ API_NODE = True
+ OUTPUT_NODE = True
+ AVERAGE_DURATION = 30
+
+ def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs):
+ if not original_model_task_id:
+ raise RuntimeError("original_model_task_id is required")
+ response = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/tripo/v2/openapi/task",
+ method=HttpMethod.POST,
+ request_model=TripoConvertModelRequest,
+ response_model=TripoTaskResponse,
+ ),
+ request=TripoConvertModelRequest(
+ original_model_task_id=original_model_task_id,
+ format=format,
+ quad=quad if quad else None,
+ face_limit=face_limit if face_limit != -1 else None,
+ texture_size=texture_size if texture_size != 4096 else None,
+ texture_format=texture_format if texture_format != "JPEG" else None
+ ),
+ auth_kwargs=kwargs,
+ ).execute()
+ return poll_until_finished(kwargs, response)
+
+NODE_CLASS_MAPPINGS = {
+ "TripoTextToModelNode": TripoTextToModelNode,
+ "TripoImageToModelNode": TripoImageToModelNode,
+ "TripoMultiviewToModelNode": TripoMultiviewToModelNode,
+ "TripoTextureNode": TripoTextureNode,
+ "TripoRefineNode": TripoRefineNode,
+ "TripoRigNode": TripoRigNode,
+ "TripoRetargetNode": TripoRetargetNode,
+ "TripoConversionNode": TripoConversionNode,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "TripoTextToModelNode": "Tripo: Text to Model",
+ "TripoImageToModelNode": "Tripo: Image to Model",
+ "TripoMultiviewToModelNode": "Tripo: Multiview to Model",
+ "TripoTextureNode": "Tripo: Texture model",
+ "TripoRefineNode": "Tripo: Refine Draft model",
+ "TripoRigNode": "Tripo: Rig model",
+ "TripoRetargetNode": "Tripo: Retarget rigged model",
+ "TripoConversionNode": "Tripo: Convert model",
+}
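
The Tripo nodes chain through typed task IDs: generation nodes emit a `MODEL_TASK_ID`, rigging turns that into a `RIG_TASK_ID`, and retargeting consumes the rig ID. A hedged sketch of driving that chain directly from Python — the auth values are placeholders; inside ComfyUI the hidden inputs supply them, and each call blocks on the remote task:

```python
auth = {"auth_token": "<token>", "comfy_api_key": "<key>"}

# text -> base mesh (returns the saved .glb name and a MODEL_TASK_ID)
model_file, model_task_id = TripoTextToModelNode().generate_mesh(
    prompt="a ceramic teapot", unique_id="1", **auth
)

# mesh -> rigged mesh (MODEL_TASK_ID in, RIG_TASK_ID out)
_, rig_task_id = TripoRigNode().generate_mesh(model_task_id, unique_id="2", **auth)

# rigged mesh -> animated mesh (RIG_TASK_ID in, RETARGET_TASK_ID out)
_, retarget_task_id = TripoRetargetNode().generate_mesh(
    "preset:walk", rig_task_id, unique_id="3", **auth
)
```

`TripoConversionNode` sits at the end of any of these chains, accepting all three ID types thanks to its multi-type input socket.
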
diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py
new file mode 100644
index 000000000..df846d5dd
--- /dev/null
+++ b/comfy_api_nodes/nodes_veo2.py
@@ -0,0 +1,308 @@
+import io
+import logging
+import base64
+import requests
+import torch
+from typing import Optional
+
+from comfy.comfy_types.node_typing import IO, ComfyNodeABC
+from comfy_api.input_impl.video_types import VideoFromFile
+from comfy_api_nodes.apis import (
+ Veo2GenVidRequest,
+ Veo2GenVidResponse,
+ Veo2GenVidPollRequest,
+ Veo2GenVidPollResponse
+)
+from comfy_api_nodes.apis.client import (
+ ApiEndpoint,
+ HttpMethod,
+ SynchronousOperation,
+ PollingOperation,
+)
+
+from comfy_api_nodes.apinode_utils import (
+ downscale_image_tensor,
+ tensor_to_base64_string
+)
+
+AVERAGE_DURATION_VIDEO_GEN = 32
+
+def convert_image_to_base64(image: torch.Tensor):
+ if image is None:
+ return None
+
+ scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
+ return tensor_to_base64_string(scaled_image)
+
+
+def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
+ if (
+ poll_response.response
+ and hasattr(poll_response.response, "videos")
+ and poll_response.response.videos
+ and len(poll_response.response.videos) > 0
+ ):
+ video = poll_response.response.videos[0]
+ else:
+ return None
+ if hasattr(video, "gcsUri") and video.gcsUri:
+ return str(video.gcsUri)
+ return None
+
+
+class VeoVideoGenerationNode(ComfyNodeABC):
+ """
+ Generates videos from text prompts using Google's Veo API.
+
+ This node can create videos from text descriptions and optional image inputs,
+ with control over parameters like aspect ratio, duration, and more.
+ """
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Text description of the video",
+ },
+ ),
+ "aspect_ratio": (
+ IO.COMBO,
+ {
+ "options": ["16:9", "9:16"],
+ "default": "16:9",
+ "tooltip": "Aspect ratio of the output video",
+ },
+ ),
+ },
+ "optional": {
+ "negative_prompt": (
+ IO.STRING,
+ {
+ "multiline": True,
+ "default": "",
+ "tooltip": "Negative text prompt to guide what to avoid in the video",
+ },
+ ),
+ "duration_seconds": (
+ IO.INT,
+ {
+ "default": 5,
+ "min": 5,
+ "max": 8,
+ "step": 1,
+ "display": "number",
+ "tooltip": "Duration of the output video in seconds",
+ },
+ ),
+ "enhance_prompt": (
+ IO.BOOLEAN,
+ {
+ "default": True,
+ "tooltip": "Whether to enhance the prompt with AI assistance",
+ }
+ ),
+ "person_generation": (
+ IO.COMBO,
+ {
+ "options": ["ALLOW", "BLOCK"],
+ "default": "ALLOW",
+ "tooltip": "Whether to allow generating people in the video",
+ },
+ ),
+ "seed": (
+ IO.INT,
+ {
+ "default": 0,
+ "min": 0,
+ "max": 0xFFFFFFFF,
+ "step": 1,
+ "display": "number",
+ "control_after_generate": True,
+ "tooltip": "Seed for video generation (0 for random)",
+ },
+ ),
+ "image": (IO.IMAGE, {
+ "default": None,
+ "tooltip": "Optional reference image to guide video generation",
+ }),
+ },
+ "hidden": {
+ "auth_token": "AUTH_TOKEN_COMFY_ORG",
+ "comfy_api_key": "API_KEY_COMFY_ORG",
+ "unique_id": "UNIQUE_ID",
+ },
+ }
+
+ RETURN_TYPES = (IO.VIDEO,)
+ FUNCTION = "generate_video"
+ CATEGORY = "api node/video/Veo"
+ DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
+ API_NODE = True
+
+ def generate_video(
+ self,
+ prompt,
+ aspect_ratio="16:9",
+ negative_prompt="",
+ duration_seconds=5,
+ enhance_prompt=True,
+ person_generation="ALLOW",
+ seed=0,
+ image=None,
+ unique_id: Optional[str] = None,
+ **kwargs,
+ ):
+ # Prepare the instances for the request
+ instances = []
+
+ instance = {
+ "prompt": prompt
+ }
+
+ # Add image if provided
+ if image is not None:
+ image_base64 = convert_image_to_base64(image)
+ if image_base64:
+ instance["image"] = {
+ "bytesBase64Encoded": image_base64,
+ "mimeType": "image/png"
+ }
+
+ instances.append(instance)
+
+ # Create parameters dictionary
+ parameters = {
+ "aspectRatio": aspect_ratio,
+ "personGeneration": person_generation,
+ "durationSeconds": duration_seconds,
+ "enhancePrompt": enhance_prompt,
+ }
+
+ # Add optional parameters if provided
+ if negative_prompt:
+ parameters["negativePrompt"] = negative_prompt
+ if seed > 0:
+ parameters["seed"] = seed
+
+ # Initial request to start video generation
+ initial_operation = SynchronousOperation(
+ endpoint=ApiEndpoint(
+ path="/proxy/veo/generate",
+ method=HttpMethod.POST,
+ request_model=Veo2GenVidRequest,
+ response_model=Veo2GenVidResponse
+ ),
+ request=Veo2GenVidRequest(
+ instances=instances,
+ parameters=parameters
+ ),
+ auth_kwargs=kwargs,
+ )
+
+ initial_response = initial_operation.execute()
+ operation_name = initial_response.name
+
+ logging.info(f"Veo generation started with operation name: {operation_name}")
+
+ # Define status extractor function
+ def status_extractor(response):
+ # Only return "completed" if the operation is done, regardless of success or failure
+ # We'll check for errors after polling completes
+ return "completed" if response.done else "pending"
+
+ # Define progress extractor function
+ def progress_extractor(response):
+ # Could be enhanced if the API provides progress information
+ return None
+
+ # Define the polling operation
+ poll_operation = PollingOperation(
+ poll_endpoint=ApiEndpoint(
+ path="/proxy/veo/poll",
+ method=HttpMethod.POST,
+ request_model=Veo2GenVidPollRequest,
+ response_model=Veo2GenVidPollResponse
+ ),
+ completed_statuses=["completed"],
+ failed_statuses=[], # No failed statuses, we'll handle errors after polling
+ status_extractor=status_extractor,
+ progress_extractor=progress_extractor,
+ request=Veo2GenVidPollRequest(
+ operationName=operation_name
+ ),
+ auth_kwargs=kwargs,
+ poll_interval=5.0,
+ result_url_extractor=get_video_url_from_response,
+ node_id=unique_id,
+ estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
+ )
+
+ # Execute the polling operation
+ poll_response = poll_operation.execute()
+
+ # Now check for errors in the final response
+ # Check for error in poll response
+ if hasattr(poll_response, 'error') and poll_response.error:
+ error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
+ logging.error(error_message)
+ raise Exception(error_message)
+
+ # Check for RAI filtered content
+ if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
+ poll_response.response.raiMediaFilteredCount > 0):
+
+ # Extract reason message if available
+ if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
+ poll_response.response.raiMediaFilteredReasons):
+ reason = poll_response.response.raiMediaFilteredReasons[0]
+ error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
+ else:
+ error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
+
+ logging.error(error_message)
+ raise Exception(error_message)
+
+ # Extract video data
+ video_data = None
+ if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
+ video = poll_response.response.videos[0]
+
+ # Check if video is provided as base64 or URL
+ if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
+ # Decode base64 string to bytes
+ video_data = base64.b64decode(video.bytesBase64Encoded)
+ elif hasattr(video, 'gcsUri') and video.gcsUri:
+ # Download from URL
+ video_url = video.gcsUri
+ video_response = requests.get(video_url)
+ video_data = video_response.content
+ else:
+ raise Exception("Video returned but no data or URL was provided")
+ else:
+ raise Exception("Video generation completed but no video was returned")
+
+ if not video_data:
+ raise Exception("No video data was returned")
+
+ logging.info("Video generation completed successfully")
+
+ # Convert video data to BytesIO object
+ video_io = io.BytesIO(video_data)
+
+ # Return VideoFromFile object
+ return (VideoFromFile(video_io),)
+
+
+# Register the node
+NODE_CLASS_MAPPINGS = {
+ "VeoVideoGenerationNode": VeoVideoGenerationNode,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "VeoVideoGenerationNode": "Google Veo2 Video Generation",
+}
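
`convert_image_to_base64` above delegates to helpers whose bodies are not shown in this diff. One plausible implementation of the tensor-to-PNG-to-base64 step, assuming ComfyUI's `[B, H, W, C]` float image layout, is:

```python
import base64, io
import numpy as np
import torch
from PIL import Image

def tensor_to_png_base64(image: torch.Tensor) -> str:
    # Assumes [B, H, W, C] floats in [0, 1]; encodes the first batch item.
    arr = (image[0].clamp(0, 1).cpu().numpy() * 255).astype(np.uint8)
    buf = io.BytesIO()
    Image.fromarray(arr).save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode("utf-8")
```

The real helper also downscales to at most 2048×2048 pixels first, which keeps the `bytesBase64Encoded` request field within reasonable payload limits.
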
diff --git a/comfy_api_nodes/redocly-dev.yaml b/comfy_api_nodes/redocly-dev.yaml
new file mode 100644
index 000000000..d9e3cab70
--- /dev/null
+++ b/comfy_api_nodes/redocly-dev.yaml
@@ -0,0 +1,10 @@
+# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
+# This is used for development purposes to generate stubs for unreleased API endpoints.
+apis:
+ filter:
+ root: openapi.yaml
+ decorators:
+ filter-in:
+ property: tags
+ value: ['API Nodes']
+ matchStrategy: all
diff --git a/comfy_api_nodes/redocly.yaml b/comfy_api_nodes/redocly.yaml
new file mode 100644
index 000000000..d102345b1
--- /dev/null
+++ b/comfy_api_nodes/redocly.yaml
@@ -0,0 +1,10 @@
+# This file is used to filter the Comfy Org OpenAPI spec for schemas related to API Nodes.
+
+apis:
+ filter:
+ root: openapi.yaml
+ decorators:
+ filter-in:
+ property: tags
+ value: ['API Nodes', 'Released']
+ matchStrategy: all
diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py
new file mode 100644
index 000000000..031b9fbd3
--- /dev/null
+++ b/comfy_api_nodes/util/validation_utils.py
@@ -0,0 +1,100 @@
+import logging
+from typing import Optional
+
+import torch
+from comfy_api.input.video_types import VideoInput
+
+
+def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]:
+ if len(image.shape) == 4:
+ return image.shape[1], image.shape[2]
+ elif len(image.shape) == 3:
+ return image.shape[0], image.shape[1]
+ else:
+ raise ValueError("Invalid image tensor shape.")
+
+
+def validate_image_dimensions(
+ image: torch.Tensor,
+ min_width: Optional[int] = None,
+ max_width: Optional[int] = None,
+ min_height: Optional[int] = None,
+ max_height: Optional[int] = None,
+):
+ height, width = get_image_dimensions(image)
+
+ if min_width is not None and width < min_width:
+ raise ValueError(f"Image width must be at least {min_width}px, got {width}px")
+ if max_width is not None and width > max_width:
+ raise ValueError(f"Image width must be at most {max_width}px, got {width}px")
+ if min_height is not None and height < min_height:
+ raise ValueError(
+ f"Image height must be at least {min_height}px, got {height}px"
+ )
+ if max_height is not None and height > max_height:
+ raise ValueError(f"Image height must be at most {max_height}px, got {height}px")
+
+
+def validate_image_aspect_ratio(
+ image: torch.Tensor,
+ min_aspect_ratio: Optional[float] = None,
+ max_aspect_ratio: Optional[float] = None,
+):
+    height, width = get_image_dimensions(image)  # helper returns (height, width)
+    aspect_ratio = width / height
+
+ if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio:
+ raise ValueError(
+ f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}"
+ )
+ if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio:
+ raise ValueError(
+ f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}"
+ )
+
+
+def validate_video_dimensions(
+ video: VideoInput,
+ min_width: Optional[int] = None,
+ max_width: Optional[int] = None,
+ min_height: Optional[int] = None,
+ max_height: Optional[int] = None,
+):
+ try:
+ width, height = video.get_dimensions()
+ except Exception as e:
+ logging.error("Error getting dimensions of video: %s", e)
+ return
+
+ if min_width is not None and width < min_width:
+ raise ValueError(f"Video width must be at least {min_width}px, got {width}px")
+ if max_width is not None and width > max_width:
+ raise ValueError(f"Video width must be at most {max_width}px, got {width}px")
+ if min_height is not None and height < min_height:
+ raise ValueError(
+ f"Video height must be at least {min_height}px, got {height}px"
+ )
+ if max_height is not None and height > max_height:
+ raise ValueError(f"Video height must be at most {max_height}px, got {height}px")
+
+
+def validate_video_duration(
+ video: VideoInput,
+ min_duration: Optional[float] = None,
+ max_duration: Optional[float] = None,
+):
+ try:
+ duration = video.get_duration()
+ except Exception as e:
+ logging.error("Error getting duration of video: %s", e)
+ return
+
+ epsilon = 0.0001
+ if min_duration is not None and min_duration - epsilon > duration:
+ raise ValueError(
+ f"Video duration must be at least {min_duration}s, got {duration}s"
+ )
+ if max_duration is not None and duration > max_duration + epsilon:
+ raise ValueError(
+ f"Video duration must be at most {max_duration}s, got {duration}s"
+ )
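
Usage of these helpers is uniform: each raises `ValueError` when a bound is violated, skips any bound left as `None`, and (for videos) silently returns if probing the input fails. For example:

```python
import torch
from comfy_api_nodes.util.validation_utils import (
    validate_image_dimensions,
    validate_image_aspect_ratio,
)

image = torch.rand(1, 512, 768, 3)  # [B, H, W, C] -> height=512, width=768

validate_image_dimensions(image, min_width=256, max_height=2048)  # passes
validate_image_aspect_ratio(image, min_aspect_ratio=0.5, max_aspect_ratio=2.0)  # 1.5, passes

try:
    validate_image_dimensions(image, min_width=1024)
except ValueError as e:
    print(e)  # Image width must be at least 1024px, got 768px
```
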
diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py
index 630f280fc..dbb37b89f 100644
--- a/comfy_execution/caching.py
+++ b/comfy_execution/caching.py
@@ -316,3 +316,156 @@ class LRUCache(BasicCache):
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
return self
+
+class DependencyAwareCache(BasicCache):
+ """
+ A cache implementation that tracks dependencies between nodes and manages
+ their execution and caching accordingly. It extends the BasicCache class.
+ Nodes are removed from this cache once all of their descendants have been
+ executed.
+ """
+
+ def __init__(self, key_class):
+ """
+ Initialize the DependencyAwareCache.
+
+ Args:
+ key_class: The class used for generating cache keys.
+ """
+ super().__init__(key_class)
+ self.descendants = {} # Maps node_id -> set of descendant node_ids
+ self.ancestors = {} # Maps node_id -> set of ancestor node_ids
+ self.executed_nodes = set() # Tracks nodes that have been executed
+
+ def set_prompt(self, dynprompt, node_ids, is_changed_cache):
+ """
+ Clear the entire cache and rebuild the dependency graph.
+
+ Args:
+ dynprompt: The dynamic prompt object containing node information.
+ node_ids: List of node IDs to initialize the cache for.
+ is_changed_cache: Flag indicating if the cache has changed.
+ """
+ # Clear all existing cache data
+ self.cache.clear()
+ self.subcaches.clear()
+ self.descendants.clear()
+ self.ancestors.clear()
+ self.executed_nodes.clear()
+
+ # Call the parent method to initialize the cache with the new prompt
+ super().set_prompt(dynprompt, node_ids, is_changed_cache)
+
+ # Rebuild the dependency graph
+ self._build_dependency_graph(dynprompt, node_ids)
+
+ def _build_dependency_graph(self, dynprompt, node_ids):
+ """
+ Build the dependency graph for all nodes.
+
+ Args:
+ dynprompt: The dynamic prompt object containing node information.
+ node_ids: List of node IDs to build the graph for.
+ """
+ self.descendants.clear()
+ self.ancestors.clear()
+ for node_id in node_ids:
+ self.descendants[node_id] = set()
+ self.ancestors[node_id] = set()
+
+ for node_id in node_ids:
+ inputs = dynprompt.get_node(node_id)["inputs"]
+ for input_data in inputs.values():
+ if is_link(input_data): # Check if the input is a link to another node
+ ancestor_id = input_data[0]
+ self.descendants[ancestor_id].add(node_id)
+ self.ancestors[node_id].add(ancestor_id)
+
+ def set(self, node_id, value):
+ """
+ Mark a node as executed and store its value in the cache.
+
+ Args:
+ node_id: The ID of the node to store.
+ value: The value to store for the node.
+ """
+ self._set_immediate(node_id, value)
+ self.executed_nodes.add(node_id)
+ self._cleanup_ancestors(node_id)
+
+ def get(self, node_id):
+ """
+ Retrieve the cached value for a node.
+
+ Args:
+ node_id: The ID of the node to retrieve.
+
+ Returns:
+ The cached value for the node.
+ """
+ return self._get_immediate(node_id)
+
+ def ensure_subcache_for(self, node_id, children_ids):
+ """
+ Ensure a subcache exists for a node and update dependencies.
+
+ Args:
+ node_id: The ID of the parent node.
+ children_ids: List of child node IDs to associate with the parent node.
+
+ Returns:
+ The subcache object for the node.
+ """
+ subcache = super()._ensure_subcache(node_id, children_ids)
+ for child_id in children_ids:
+ self.descendants[node_id].add(child_id)
+ self.ancestors[child_id].add(node_id)
+ return subcache
+
+ def _cleanup_ancestors(self, node_id):
+ """
+ Check if ancestors of a node can be removed from the cache.
+
+ Args:
+ node_id: The ID of the node whose ancestors are to be checked.
+ """
+ for ancestor_id in self.ancestors.get(node_id, []):
+ if ancestor_id in self.executed_nodes:
+ # Remove ancestor if all its descendants have been executed
+ if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]):
+ self._remove_node(ancestor_id)
+
+ def _remove_node(self, node_id):
+ """
+ Remove a node from the cache.
+
+ Args:
+ node_id: The ID of the node to remove.
+ """
+ cache_key = self.cache_key_set.get_data_key(node_id)
+ if cache_key in self.cache:
+ del self.cache[cache_key]
+ subcache_key = self.cache_key_set.get_subcache_key(node_id)
+ if subcache_key in self.subcaches:
+ del self.subcaches[subcache_key]
+
+ def clean_unused(self):
+ """
+ Clean up unused nodes. This is a no-op for this cache implementation.
+ """
+ pass
+
+ def recursive_debug_dump(self):
+ """
+ Dump the cache and dependency graph for debugging.
+
+ Returns:
+ A list containing the cache state and dependency graph.
+ """
+ result = super().recursive_debug_dump()
+ result.append({
+ "descendants": self.descendants,
+ "ancestors": self.ancestors,
+ "executed_nodes": list(self.executed_nodes),
+ })
+ return result
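
The eviction rule encoded by `_cleanup_ancestors` is easier to see outside the cache classes. Below is a minimal standalone sketch (illustrative names only, not the patch's API): a node's output is dropped as soon as every node that consumes it has executed.

```python
# Minimal sketch of dependency-aware eviction. The graph maps
# node_id -> list of input node_ids; `compute` is any per-node function.
def run_with_eviction(graph, execution_order, compute):
    descendants = {n: set() for n in graph}
    for node, inputs in graph.items():
        for parent in inputs:
            descendants[parent].add(node)

    cache, executed = {}, set()
    for node in execution_order:
        cache[node] = compute(node, [cache[p] for p in graph[node]])
        executed.add(node)
        # Mirror _cleanup_ancestors: free an executed parent once all of
        # its consumers have executed.
        for parent in graph[node]:
            if parent in executed and descendants[parent] <= executed:
                cache.pop(parent, None)
    return cache

# "a" feeds both "b" and "c", so it stays cached until "c" has run.
graph = {"a": [], "b": ["a"], "c": ["a"]}
out = run_with_eviction(graph, ["a", "b", "c"], lambda n, ins: sum(ins) + 1)
assert "a" not in out and out == {"b": 2, "c": 2}
```

This is the memory/speed trade-off the class docstring describes: peak cache size tracks the live width of the graph rather than its total node count.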
diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py
index 59b42b746..a2799b52e 100644
--- a/comfy_execution/graph.py
+++ b/comfy_execution/graph.py
@@ -1,6 +1,9 @@
-import nodes
+from __future__ import annotations
+from typing import Type, Literal
+import nodes
from comfy_execution.graph_utils import is_link
+from comfy.comfy_types.node_typing import ComfyNodeABC, InputTypeDict, InputTypeOptions
class DependencyCycleError(Exception):
pass
@@ -54,7 +57,22 @@ class DynamicPrompt:
def get_original_prompt(self):
return self.original_prompt
-def get_input_info(class_def, input_name, valid_inputs=None):
+def get_input_info(
+ class_def: Type[ComfyNodeABC],
+ input_name: str,
+ valid_inputs: InputTypeDict | None = None
+) -> tuple[str, Literal["required", "optional", "hidden"], InputTypeOptions] | tuple[None, None, None]:
+ """Get the input type, category, and extra info for a given input name.
+
+ Arguments:
+ class_def: The class definition of the node.
+ input_name: The name of the input to get info for.
+ valid_inputs: The valid inputs for the node, or None to use the class_def.INPUT_TYPES().
+
+ Returns:
+ tuple[str, str, dict] | tuple[None, None, None]: The input type, category, and extra info for the input name.
+ """
+
valid_inputs = valid_inputs or class_def.INPUT_TYPES()
input_info = None
input_category = None
@@ -126,7 +144,7 @@ class TopologicalSort:
from_node_id, from_socket = value
if subgraph_nodes is not None and from_node_id not in subgraph_nodes:
continue
- input_type, input_category, input_info = self.get_input_info(unique_id, input_name)
+ _, _, input_info = self.get_input_info(unique_id, input_name)
is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"]
if (include_lazy or not is_lazy) and not self.is_cached(from_node_id):
node_ids.append(from_node_id)
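
A hedged sketch of the tuple contract the new annotations describe — the node class and the re-implementation here are illustrative stand-ins, not the actual ComfyUI definitions:

```python
# Illustrative stand-in showing the (type, category, options) tuple shape
# that get_input_info returns; ExampleNode is hypothetical.
class ExampleNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"image": ("IMAGE", {})},
            "optional": {"mask": ("MASK", {"lazy": True})},
        }

def get_input_info_sketch(class_def, input_name, valid_inputs=None):
    valid_inputs = valid_inputs or class_def.INPUT_TYPES()
    for category in ("required", "optional", "hidden"):
        group = valid_inputs.get(category, {})
        if input_name in group:
            info = group[input_name]
            return info[0], category, info[1] if len(info) > 1 else None
    return None, None, None

assert get_input_info_sketch(ExampleNode, "mask") == ("MASK", "optional", {"lazy": True})
assert get_input_info_sketch(ExampleNode, "missing") == (None, None, None)
```

The `TopologicalSort` change above then discards the first two elements and keeps only the options dict to test the `lazy` flag.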
diff --git a/comfy_extras/nodes_ace.py b/comfy_extras/nodes_ace.py
new file mode 100644
index 000000000..cbfec15a2
--- /dev/null
+++ b/comfy_extras/nodes_ace.py
@@ -0,0 +1,49 @@
+import torch
+import comfy.model_management
+import node_helpers
+
+class TextEncodeAceStepAudio:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "clip": ("CLIP", ),
+ "tags": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+ "lyrics": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+ "lyrics_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
+ }}
+ RETURN_TYPES = ("CONDITIONING",)
+ FUNCTION = "encode"
+
+ CATEGORY = "conditioning"
+
+ def encode(self, clip, tags, lyrics, lyrics_strength):
+ tokens = clip.tokenize(tags, lyrics=lyrics)
+ conditioning = clip.encode_from_tokens_scheduled(tokens)
+ conditioning = node_helpers.conditioning_set_values(conditioning, {"lyrics_strength": lyrics_strength})
+ return (conditioning, )
+
+
+class EmptyAceStepLatentAudio:
+ def __init__(self):
+ self.device = comfy.model_management.intermediate_device()
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"seconds": ("FLOAT", {"default": 120.0, "min": 1.0, "max": 1000.0, "step": 0.1}),
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
+ }}
+ RETURN_TYPES = ("LATENT",)
+ FUNCTION = "generate"
+
+ CATEGORY = "latent/audio"
+
+ def generate(self, seconds, batch_size):
+ length = int(seconds * 44100 / 512 / 8)
+ latent = torch.zeros([batch_size, 8, 16, length], device=self.device)
+ return ({"samples": latent, "type": "audio"}, )
+
+
+NODE_CLASS_MAPPINGS = {
+ "TextEncodeAceStepAudio": TextEncodeAceStepAudio,
+ "EmptyAceStepLatentAudio": EmptyAceStepLatentAudio,
+}
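
The latent length in `EmptyAceStepLatentAudio.generate` comes from 44.1 kHz audio divided by the constants above (a 512-sample hop and a further 8x temporal packing, on my reading of the factors), so the default 120 seconds gives:

```python
# Reproduces the length arithmetic from EmptyAceStepLatentAudio.generate.
seconds = 120.0
length = int(seconds * 44100 / 512 / 8)
assert length == 1291  # default latent shape: [batch_size, 8, 16, 1291]
```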
diff --git a/comfy_extras/nodes_apg.py b/comfy_extras/nodes_apg.py
new file mode 100644
index 000000000..25b21b1b8
--- /dev/null
+++ b/comfy_extras/nodes_apg.py
@@ -0,0 +1,76 @@
+import torch
+
+def project(v0, v1):
+ v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
+ v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
+ v0_orthogonal = v0 - v0_parallel
+ return v0_parallel, v0_orthogonal
+
+class APG:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "model": ("MODEL",),
+ "eta": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "tooltip": "Controls the scale of the parallel guidance vector. Default CFG behavior at a setting of 1."}),
+                "norm_threshold": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 50.0, "step": 0.1, "tooltip": "Normalize the guidance vector to this value; normalization disabled at a setting of 0."}),
+ "momentum": ("FLOAT", {"default": 0.0, "min": -5.0, "max": 1.0, "step": 0.01, "tooltip":"Controls a running average of guidance during diffusion, disabled at a setting of 0."}),
+ }
+ }
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+ CATEGORY = "sampling/custom_sampling"
+
+ def patch(self, model, eta, norm_threshold, momentum):
+ running_avg = 0
+ prev_sigma = None
+
+ def pre_cfg_function(args):
+ nonlocal running_avg, prev_sigma
+
+ if len(args["conds_out"]) == 1: return args["conds_out"]
+
+ cond = args["conds_out"][0]
+ uncond = args["conds_out"][1]
+ sigma = args["sigma"][0]
+ cond_scale = args["cond_scale"]
+
+ if prev_sigma is not None and sigma > prev_sigma:
+ running_avg = 0
+ prev_sigma = sigma
+
+ guidance = cond - uncond
+
+ if momentum != 0:
+ if not torch.is_tensor(running_avg):
+ running_avg = guidance
+ else:
+ running_avg = momentum * running_avg + guidance
+ guidance = running_avg
+
+ if norm_threshold > 0:
+ guidance_norm = guidance.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+ scale = torch.minimum(
+ torch.ones_like(guidance_norm),
+ norm_threshold / guidance_norm
+ )
+ guidance = guidance * scale
+
+ guidance_parallel, guidance_orthogonal = project(guidance, cond)
+ modified_guidance = guidance_orthogonal + eta * guidance_parallel
+
+ modified_cond = (uncond + modified_guidance) + (cond - uncond) / cond_scale
+
+ return [modified_cond, uncond] + args["conds_out"][2:]
+
+ m = model.clone()
+ m.set_model_sampler_pre_cfg_function(pre_cfg_function)
+ return (m,)
+
+NODE_CLASS_MAPPINGS = {
+ "APG": APG,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "APG": "Adaptive Projected Guidance",
+}
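
The heart of the node is `project`, which splits the guidance vector into a component parallel to the (normalized) conditional prediction and an orthogonal remainder; `eta` then rescales only the parallel part. A quick standalone check of that decomposition, PyTorch only:

```python
import torch

# Same helper as in nodes_apg.py above.
def project(v0, v1):
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    return v0_parallel, v0_orthogonal

guidance = torch.randn(1, 4, 8, 8)
cond = torch.randn(1, 4, 8, 8)
par, orth = project(guidance, cond)

# The parts reassemble the input exactly...
assert torch.allclose(par + orth, guidance, atol=1e-6)
# ...and the remainder carries no component along cond's direction.
unit = torch.nn.functional.normalize(cond, dim=[-1, -2, -3])
assert abs((orth * unit).sum().item()) < 1e-4
```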
diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py
index 3cb918e09..49af1eae4 100644
--- a/comfy_extras/nodes_audio.py
+++ b/comfy_extras/nodes_audio.py
@@ -1,3 +1,6 @@
+from __future__ import annotations
+
+import av
import torchaudio
import torch
import comfy.model_management
@@ -5,11 +8,11 @@ import folder_paths
import os
import io
import json
-import struct
import random
import hashlib
import node_helpers
from comfy.cli_args import args
+from comfy.comfy_types import FileLocator
class EmptyLatentAudio:
def __init__(self):
@@ -87,60 +90,118 @@ class VAEDecodeAudio:
return ({"waveform": audio, "sample_rate": 44100}, )
-def create_vorbis_comment_block(comment_dict, last_block):
- vendor_string = b'ComfyUI'
- vendor_length = len(vendor_string)
+def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"):
+    filename_prefix += self.prefix_append
+    full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+    results: list[FileLocator] = []
+
+    metadata = {}
+    if not args.disable_metadata:
+        if prompt is not None:
+            metadata["prompt"] = json.dumps(prompt)
+        if extra_pnginfo is not None:
+            for x in extra_pnginfo:
+                metadata[x] = json.dumps(extra_pnginfo[x])
- comments = []
- for key, value in comment_dict.items():
- comment = f"{key}={value}".encode('utf-8')
-        comments.append(struct.pack('<I', len(comment)) + comment)
-
-    user_comment_list_length = len(comments)
-    user_comments = b''.join(comments)
-
-    comment_data = struct.pack('<I', vendor_length) + vendor_string + struct.pack('<I', user_comment_list_length) + user_comments
-    if last_block:
-        id = b'\x84'
-    else:
-        id = b'\x04'
-    comment_block = id + struct.pack('>I', len(comment_data))[1:] + comment_data
+ # Opus supported sample rates
+ OPUS_RATES = [8000, 12000, 16000, 24000, 48000]
- return comment_block
+ for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
+ filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
+ file = f"{filename_with_batch_num}_{counter:05}_.{format}"
+ output_path = os.path.join(full_output_folder, file)
-def insert_or_replace_vorbis_comment(flac_io, comment_dict):
- if len(comment_dict) == 0:
- return flac_io
+ # Use original sample rate initially
+ sample_rate = audio["sample_rate"]
- flac_io.seek(4)
+ # Handle Opus sample rate requirements
+ if format == "opus":
+ if sample_rate > 48000:
+ sample_rate = 48000
+ elif sample_rate not in OPUS_RATES:
+ # Find the next highest supported rate
+ for rate in sorted(OPUS_RATES):
+ if rate > sample_rate:
+ sample_rate = rate
+ break
+ if sample_rate not in OPUS_RATES: # Fallback if still not supported
+ sample_rate = 48000
- blocks = []
- last_block = False
+ # Resample if necessary
+ if sample_rate != audio["sample_rate"]:
+ waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate)
- while not last_block:
- header = flac_io.read(4)
- last_block = (header[0] & 0x80) != 0
- block_type = header[0] & 0x7F
- block_length = struct.unpack('>I', b'\x00' + header[1:])[0]
- block_data = flac_io.read(block_length)
+ # Create in-memory WAV buffer
+ wav_buffer = io.BytesIO()
+ torchaudio.save(wav_buffer, waveform, sample_rate, format="WAV")
+ wav_buffer.seek(0) # Rewind for reading
- if block_type == 4 or block_type == 1:
- pass
- else:
- header = bytes([(header[0] & (~0x80))]) + header[1:]
- blocks.append(header + block_data)
+ # Use PyAV to convert and add metadata
+ input_container = av.open(wav_buffer)
- blocks.append(create_vorbis_comment_block(comment_dict, last_block=True))
+ # Create output with specified format
+ output_buffer = io.BytesIO()
+ output_container = av.open(output_buffer, mode='w', format=format)
- new_flac_io = io.BytesIO()
- new_flac_io.write(b'fLaC')
- for block in blocks:
- new_flac_io.write(block)
+ # Set metadata on the container
+ for key, value in metadata.items():
+ output_container.metadata[key] = value
- new_flac_io.write(flac_io.read())
- return new_flac_io
+ # Set up the output stream with appropriate properties
+ input_container.streams.audio[0]
+ if format == "opus":
+ out_stream = output_container.add_stream("libopus", rate=sample_rate)
+ if quality == "64k":
+ out_stream.bit_rate = 64000
+ elif quality == "96k":
+ out_stream.bit_rate = 96000
+ elif quality == "128k":
+ out_stream.bit_rate = 128000
+ elif quality == "192k":
+ out_stream.bit_rate = 192000
+ elif quality == "320k":
+ out_stream.bit_rate = 320000
+ elif format == "mp3":
+ out_stream = output_container.add_stream("libmp3lame", rate=sample_rate)
+ if quality == "V0":
+ #TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool
+ out_stream.codec_context.qscale = 1
+ elif quality == "128k":
+ out_stream.bit_rate = 128000
+ elif quality == "320k":
+ out_stream.bit_rate = 320000
+ else: #format == "flac":
+ out_stream = output_container.add_stream("flac", rate=sample_rate)
+ # Copy frames from input to output
+ for frame in input_container.decode(audio=0):
+ frame.pts = None # Let PyAV handle timestamps
+ output_container.mux(out_stream.encode(frame))
+
+ # Flush encoder
+ output_container.mux(out_stream.encode(None))
+
+ # Close containers
+ output_container.close()
+ input_container.close()
+
+ # Write the output to file
+ output_buffer.seek(0)
+ with open(output_path, 'wb') as f:
+ f.write(output_buffer.getbuffer())
+
+ results.append({
+ "filename": file,
+ "subfolder": subfolder,
+ "type": self.type
+ })
+ counter += 1
+
+ return { "ui": { "audio": results } }
+
class SaveAudio:
def __init__(self):
self.output_dir = folder_paths.get_output_directory()
@@ -150,50 +211,70 @@ class SaveAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": { "audio": ("AUDIO", ),
- "filename_prefix": ("STRING", {"default": "audio/ComfyUI"})},
+ "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
+ },
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}
RETURN_TYPES = ()
- FUNCTION = "save_audio"
+ FUNCTION = "save_flac"
OUTPUT_NODE = True
CATEGORY = "audio"
- def save_audio(self, audio, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None):
- filename_prefix += self.prefix_append
- full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
- results = list()
+ def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None):
+ return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo)
- metadata = {}
- if not args.disable_metadata:
- if prompt is not None:
- metadata["prompt"] = json.dumps(prompt)
- if extra_pnginfo is not None:
- for x in extra_pnginfo:
- metadata[x] = json.dumps(extra_pnginfo[x])
+class SaveAudioMP3:
+ def __init__(self):
+ self.output_dir = folder_paths.get_output_directory()
+ self.type = "output"
+ self.prefix_append = ""
- for (batch_number, waveform) in enumerate(audio["waveform"].cpu()):
- filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
- file = f"{filename_with_batch_num}_{counter:05}_.flac"
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "audio": ("AUDIO", ),
+ "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
+ "quality": (["V0", "128k", "320k"], {"default": "V0"}),
+ },
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+ }
- buff = io.BytesIO()
- torchaudio.save(buff, waveform, audio["sample_rate"], format="FLAC")
+ RETURN_TYPES = ()
+ FUNCTION = "save_mp3"
- buff = insert_or_replace_vorbis_comment(buff, metadata)
+ OUTPUT_NODE = True
- with open(os.path.join(full_output_folder, file), 'wb') as f:
- f.write(buff.getbuffer())
+ CATEGORY = "audio"
- results.append({
- "filename": file,
- "subfolder": subfolder,
- "type": self.type
- })
- counter += 1
+ def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"):
+ return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
- return { "ui": { "audio": results } }
+class SaveAudioOpus:
+ def __init__(self):
+ self.output_dir = folder_paths.get_output_directory()
+ self.type = "output"
+ self.prefix_append = ""
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "audio": ("AUDIO", ),
+ "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}),
+ "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}),
+ },
+ "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
+ }
+
+ RETURN_TYPES = ()
+ FUNCTION = "save_opus"
+
+ OUTPUT_NODE = True
+
+ CATEGORY = "audio"
+
+ def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"):
+ return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality)
class PreviewAudio(SaveAudio):
def __init__(self):
@@ -245,7 +326,20 @@ NODE_CLASS_MAPPINGS = {
"VAEEncodeAudio": VAEEncodeAudio,
"VAEDecodeAudio": VAEDecodeAudio,
"SaveAudio": SaveAudio,
+ "SaveAudioMP3": SaveAudioMP3,
+ "SaveAudioOpus": SaveAudioOpus,
"LoadAudio": LoadAudio,
"PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio,
}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "EmptyLatentAudio": "Empty Latent Audio",
+ "VAEEncodeAudio": "VAE Encode Audio",
+ "VAEDecodeAudio": "VAE Decode Audio",
+ "PreviewAudio": "Preview Audio",
+ "LoadAudio": "Load Audio",
+ "SaveAudio": "Save Audio (FLAC)",
+ "SaveAudioMP3": "Save Audio (MP3)",
+ "SaveAudioOpus": "Save Audio (Opus)",
+}
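
One detail of `save_audio` worth calling out: libopus only encodes at a fixed set of sample rates, so the node downsamples anything above 48 kHz and otherwise bumps to the next supported rate before resampling with torchaudio. The selection rule in isolation:

```python
# Mirrors the rate-selection branch in save_audio for format == "opus".
OPUS_RATES = [8000, 12000, 16000, 24000, 48000]

def opus_rate(sample_rate: int) -> int:
    if sample_rate > 48000:
        return 48000
    if sample_rate in OPUS_RATES:
        return sample_rate
    for rate in sorted(OPUS_RATES):
        if rate > sample_rate:
            return rate
    return 48000  # fallback, matching the node's last resort

assert opus_rate(44100) == 48000  # bumped up, then resampled
assert opus_rate(96000) == 48000  # capped
assert opus_rate(16000) == 16000  # passed through
```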
diff --git a/comfy_extras/nodes_camera_trajectory.py b/comfy_extras/nodes_camera_trajectory.py
new file mode 100644
index 000000000..5e0e39f91
--- /dev/null
+++ b/comfy_extras/nodes_camera_trajectory.py
@@ -0,0 +1,218 @@
+import nodes
+import torch
+import numpy as np
+from einops import rearrange
+import comfy.model_management
+
+
+
+MAX_RESOLUTION = nodes.MAX_RESOLUTION
+
+CAMERA_DICT = {
+ "base_T_norm": 1.5,
+ "base_angle": np.pi/3,
+ "Static": { "angle":[0., 0., 0.], "T":[0., 0., 0.]},
+ "Pan Up": { "angle":[0., 0., 0.], "T":[0., -1., 0.]},
+ "Pan Down": { "angle":[0., 0., 0.], "T":[0.,1.,0.]},
+ "Pan Left": { "angle":[0., 0., 0.], "T":[-1.,0.,0.]},
+ "Pan Right": { "angle":[0., 0., 0.], "T": [1.,0.,0.]},
+ "Zoom In": { "angle":[0., 0., 0.], "T": [0.,0.,2.]},
+ "Zoom Out": { "angle":[0., 0., 0.], "T": [0.,0.,-2.]},
+ "Anti Clockwise (ACW)": { "angle": [0., 0., -1.], "T":[0., 0., 0.]},
+ "ClockWise (CW)": { "angle": [0., 0., 1.], "T":[0., 0., 0.]},
+}
+
+
+def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
+
+ def get_relative_pose(cam_params):
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+ """
+ abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
+ abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
+ cam_to_origin = 0
+ target_cam_c2w = np.array([
+ [1, 0, 0, 0],
+ [0, 1, 0, -cam_to_origin],
+ [0, 0, 1, 0],
+ [0, 0, 0, 1]
+ ])
+ abs2rel = target_cam_c2w @ abs_w2cs[0]
+ ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
+ ret_poses = np.array(ret_poses, dtype=np.float32)
+ return ret_poses
+
+ """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+ """
+ cam_params = [Camera(cam_param) for cam_param in cam_params]
+
+ sample_wh_ratio = width / height
+    pose_wh_ratio = original_pose_width / original_pose_height  # Aspect ratio of the original pose sequence
+
+ if pose_wh_ratio > sample_wh_ratio:
+ resized_ori_w = height * pose_wh_ratio
+ for cam_param in cam_params:
+ cam_param.fx = resized_ori_w * cam_param.fx / width
+ else:
+ resized_ori_h = width / pose_wh_ratio
+ for cam_param in cam_params:
+ cam_param.fy = resized_ori_h * cam_param.fy / height
+
+ intrinsic = np.asarray([[cam_param.fx * width,
+ cam_param.fy * height,
+ cam_param.cx * width,
+ cam_param.cy * height]
+ for cam_param in cam_params], dtype=np.float32)
+
+    K = torch.as_tensor(intrinsic)[None]  # [1, n_frame, 4]
+    c2ws = get_relative_pose(cam_params)  # Poses relative to the first frame (see get_relative_pose above)
+ c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
+ plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
+ plucker_embedding = plucker_embedding[None]
+ plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
+ return plucker_embedding
+
+class Camera(object):
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+ """
+ def __init__(self, entry):
+ fx, fy, cx, cy = entry[1:5]
+ self.fx = fx
+ self.fy = fy
+ self.cx = cx
+ self.cy = cy
+ c2w_mat = np.array(entry[7:]).reshape(4, 4)
+ self.c2w_mat = c2w_mat
+ self.w2c_mat = np.linalg.inv(c2w_mat)
+
+def ray_condition(K, c2w, H, W, device):
+ """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
+ """
+ # c2w: B, V, 4, 4
+ # K: B, V, 4
+
+ B = K.shape[0]
+
+ j, i = torch.meshgrid(
+ torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
+ torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
+ indexing='ij'
+ )
+ i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
+ j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
+
+ fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
+
+ zs = torch.ones_like(i) # [B, HxW]
+ xs = (i - cx) / fx * zs
+ ys = (j - cy) / fy * zs
+ zs = zs.expand_as(ys)
+
+ directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
+ directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
+
+ rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
+ rays_o = c2w[..., :3, 3] # B, V, 3
+ rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
+    # c2w @ directions
+ rays_dxo = torch.cross(rays_o, rays_d)
+ plucker = torch.cat([rays_dxo, rays_d], dim=-1)
+ plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
+ # plucker = plucker.permute(0, 1, 4, 2, 3)
+ return plucker
+
+def get_camera_motion(angle, T, speed, n=81):
+ def compute_R_form_rad_angle(angles):
+ theta_x, theta_y, theta_z = angles
+ Rx = np.array([[1, 0, 0],
+ [0, np.cos(theta_x), -np.sin(theta_x)],
+ [0, np.sin(theta_x), np.cos(theta_x)]])
+
+ Ry = np.array([[np.cos(theta_y), 0, np.sin(theta_y)],
+ [0, 1, 0],
+ [-np.sin(theta_y), 0, np.cos(theta_y)]])
+
+ Rz = np.array([[np.cos(theta_z), -np.sin(theta_z), 0],
+ [np.sin(theta_z), np.cos(theta_z), 0],
+ [0, 0, 1]])
+
+ R = np.dot(Rz, np.dot(Ry, Rx))
+ return R
+ RT = []
+ for i in range(n):
+ _angle = (i/n)*speed*(CAMERA_DICT["base_angle"])*angle
+ R = compute_R_form_rad_angle(_angle)
+ _T=(i/n)*speed*(CAMERA_DICT["base_T_norm"])*(T.reshape(3,1))
+ _RT = np.concatenate([R,_T], axis=1)
+ RT.append(_RT)
+ RT = np.stack(RT)
+ return RT
+
+class WanCameraEmbedding:
+ @classmethod
+ def INPUT_TYPES(cls):
+ return {
+ "required": {
+ "camera_pose":(["Static","Pan Up","Pan Down","Pan Left","Pan Right","Zoom In","Zoom Out","Anti Clockwise (ACW)", "ClockWise (CW)"],{"default":"Static"}),
+ "width": ("INT", {"default": 832, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
+ "height": ("INT", {"default": 480, "min": 16, "max": MAX_RESOLUTION, "step": 16}),
+ "length": ("INT", {"default": 81, "min": 1, "max": MAX_RESOLUTION, "step": 4}),
+ },
+ "optional":{
+ "speed":("FLOAT",{"default":1.0, "min": 0, "max": 10.0, "step": 0.1}),
+ "fx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
+ "fy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.000000001}),
+ "cx":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
+ "cy":("FLOAT",{"default":0.5, "min": 0, "max": 1, "step": 0.01}),
+ }
+
+ }
+
+ RETURN_TYPES = ("WAN_CAMERA_EMBEDDING","INT","INT","INT")
+ RETURN_NAMES = ("camera_embedding","width","height","length")
+ FUNCTION = "run"
+ CATEGORY = "camera"
+
+ def run(self, camera_pose, width, height, length, speed=1.0, fx=0.5, fy=0.5, cx=0.5, cy=0.5):
+ """
+        Use the camera trajectory (extrinsic parameters) to compute Plücker embeddings (Sitzmann et al., 2021)
+ Adapted from https://github.com/aigc-apps/VideoX-Fun/blob/main/comfyui/comfyui_nodes.py
+ """
+ motion_list = [camera_pose]
+ speed = speed
+ angle = np.array(CAMERA_DICT[motion_list[0]]["angle"])
+ T = np.array(CAMERA_DICT[motion_list[0]]["T"])
+ RT = get_camera_motion(angle, T, speed, length)
+
+ trajs=[]
+ for cp in RT.tolist():
+ traj=[fx,fy,cx,cy,0,0]
+ traj.extend(cp[0])
+ traj.extend(cp[1])
+ traj.extend(cp[2])
+ traj.extend([0,0,0,1])
+ trajs.append(traj)
+
+ cam_params = np.array([[float(x) for x in pose] for pose in trajs])
+ cam_params = np.concatenate([np.zeros_like(cam_params[:, :1]), cam_params], 1)
+ control_camera_video = process_pose_params(cam_params, width=width, height=height)
+ control_camera_video = control_camera_video.permute([3, 0, 1, 2]).unsqueeze(0).to(device=comfy.model_management.intermediate_device())
+
+ control_camera_video = torch.concat(
+ [
+ torch.repeat_interleave(control_camera_video[:, :, 0:1], repeats=4, dim=2),
+ control_camera_video[:, :, 1:]
+ ], dim=2
+ ).transpose(1, 2)
+
+ # Reshape, transpose, and view into desired shape
+ b, f, c, h, w = control_camera_video.shape
+ control_camera_video = control_camera_video.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
+ control_camera_video = control_camera_video.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
+
+ return (control_camera_video, width, height, length)
+
+
+NODE_CLASS_MAPPINGS = {
+ "WanCameraEmbedding": WanCameraEmbedding,
+}
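
To make the presets concrete: each `CAMERA_DICT` entry is a constant rotation/translation direction that `get_camera_motion` sweeps linearly over the frames into 3x4 [R|T] extrinsics. For "Zoom In" at `speed=1.0` with the file's `base_T_norm = 1.5`, frame `i` of `n` translates by `(i/n) * 1.5 * [0, 0, 2]` under an identity rotation:

```python
import numpy as np

# "Zoom In" from CAMERA_DICT: no rotation, translation along +z.
T = np.array([0., 0., 2.])
n, speed, base_T_norm = 81, 1.0, 1.5

i = n - 1  # last frame
_T = (i / n) * speed * base_T_norm * T.reshape(3, 1)
RT_last = np.concatenate([np.eye(3), _T], axis=1)

assert RT_last.shape == (3, 4)
assert abs(RT_last[2, 3] - 2.963) < 1e-3  # ~3 units forward by the end
```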
diff --git a/comfy_extras/nodes_cfg.py b/comfy_extras/nodes_cfg.py
new file mode 100644
index 000000000..1fb686644
--- /dev/null
+++ b/comfy_extras/nodes_cfg.py
@@ -0,0 +1,45 @@
+import torch
+
+# https://github.com/WeichenFan/CFG-Zero-star
+def optimized_scale(positive, negative):
+ positive_flat = positive.reshape(positive.shape[0], -1)
+ negative_flat = negative.reshape(negative.shape[0], -1)
+
+    # Calculate dot product
+    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
+
+    # Squared norm of the unconditional prediction
+    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
+
+ # st_star = v_cond^T * v_uncond / ||v_uncond||^2
+ st_star = dot_product / squared_norm
+
+ return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))
+
+class CFGZeroStar:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"model": ("MODEL",),
+ }}
+ RETURN_TYPES = ("MODEL",)
+ RETURN_NAMES = ("patched_model",)
+ FUNCTION = "patch"
+ CATEGORY = "advanced/guidance"
+
+ def patch(self, model):
+ m = model.clone()
+ def cfg_zero_star(args):
+ guidance_scale = args['cond_scale']
+ x = args['input']
+ cond_p = args['cond_denoised']
+ uncond_p = args['uncond_denoised']
+ out = args["denoised"]
+ alpha = optimized_scale(x - cond_p, x - uncond_p)
+
+ return out + uncond_p * (alpha - 1.0) + guidance_scale * uncond_p * (1.0 - alpha)
+ m.set_model_sampler_post_cfg_function(cfg_zero_star)
+ return (m, )
+
+NODE_CLASS_MAPPINGS = {
+ "CFGZeroStar": CFGZeroStar
+}
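
`optimized_scale` is the least-squares coefficient alpha = <v_cond, v_uncond> / ||v_uncond||^2 from the CFG-Zero* paper: alpha * v_uncond is the rescaling of the unconditional prediction closest to the conditional one. A numerical sanity check:

```python
import torch

# Same helper as in nodes_cfg.py above.
def optimized_scale(positive, negative):
    positive_flat = positive.reshape(positive.shape[0], -1)
    negative_flat = negative.reshape(negative.shape[0], -1)
    dot_product = torch.sum(positive_flat * negative_flat, dim=1, keepdim=True)
    squared_norm = torch.sum(negative_flat ** 2, dim=1, keepdim=True) + 1e-8
    st_star = dot_product / squared_norm
    return st_star.reshape([positive.shape[0]] + [1] * (positive.ndim - 1))

neg = torch.randn(2, 4, 8, 8)
alpha = optimized_scale(0.7 * neg, neg)

# When positive is an exact multiple of negative, alpha recovers the factor.
assert alpha.shape == (2, 1, 1, 1)
assert torch.allclose(alpha, torch.full_like(alpha, 0.7), atol=1e-4)
```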
diff --git a/comfy_extras/nodes_cond.py b/comfy_extras/nodes_cond.py
index 4c3a1d5bf..58c16f621 100644
--- a/comfy_extras/nodes_cond.py
+++ b/comfy_extras/nodes_cond.py
@@ -20,6 +20,30 @@ class CLIPTextEncodeControlnet:
c.append(n)
return (c, )
+class T5TokenizerOptions:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "clip": ("CLIP", ),
+ "min_padding": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
+ "min_length": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1}),
+ }
+ }
+
+ CATEGORY = "_for_testing/conditioning"
+ RETURN_TYPES = ("CLIP",)
+ FUNCTION = "set_options"
+
+ def set_options(self, clip, min_padding, min_length):
+ clip = clip.clone()
+ for t5_type in ["t5xxl", "pile_t5xl", "t5base", "mt5xl", "umt5xxl"]:
+ clip.set_tokenizer_option("{}_min_padding".format(t5_type), min_padding)
+ clip.set_tokenizer_option("{}_min_length".format(t5_type), min_length)
+
+ return (clip, )
+
NODE_CLASS_MAPPINGS = {
- "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet
+ "CLIPTextEncodeControlnet": CLIPTextEncodeControlnet,
+ "T5TokenizerOptions": T5TokenizerOptions,
}
diff --git a/comfy_extras/nodes_cosmos.py b/comfy_extras/nodes_cosmos.py
index d88773e25..bd35ddb06 100644
--- a/comfy_extras/nodes_cosmos.py
+++ b/comfy_extras/nodes_cosmos.py
@@ -1,6 +1,8 @@
import nodes
import torch
import comfy.model_management
+import comfy.utils
+
class EmptyCosmosLatentVideo:
@classmethod
@@ -16,8 +18,65 @@ class EmptyCosmosLatentVideo:
def generate(self, width, height, length, batch_size=1):
latent = torch.zeros([batch_size, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
- return ({"samples":latent}, )
+ return ({"samples": latent}, )
+
+
+def vae_encode_with_padding(vae, image, width, height, length, padding=0):
+ pixels = comfy.utils.common_upscale(image[..., :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+ pixel_len = min(pixels.shape[0], length)
+ padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7)
+ padded_pixels = torch.ones((padded_length, height, width, 3)) * 0.5
+ padded_pixels[:pixel_len] = pixels[:pixel_len]
+ latent_len = ((pixel_len - 1) // 8) + 1
+ latent_temp = vae.encode(padded_pixels)
+ return latent_temp[:, :, :latent_len]
+
+
+class CosmosImageToVideoLatent:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"vae": ("VAE", ),
+ "width": ("INT", {"default": 1280, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+ "height": ("INT", {"default": 704, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+ "length": ("INT", {"default": 121, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 8}),
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+ },
+ "optional": {"start_image": ("IMAGE", ),
+ "end_image": ("IMAGE", ),
+ }}
+
+
+ RETURN_TYPES = ("LATENT",)
+ FUNCTION = "encode"
+
+ CATEGORY = "conditioning/inpaint"
+
+ def encode(self, vae, width, height, length, batch_size, start_image=None, end_image=None):
+ latent = torch.zeros([1, 16, ((length - 1) // 8) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+ if start_image is None and end_image is None:
+ out_latent = {}
+ out_latent["samples"] = latent
+ return (out_latent,)
+
+ mask = torch.ones([latent.shape[0], 1, ((length - 1) // 8) + 1, latent.shape[-2], latent.shape[-1]], device=comfy.model_management.intermediate_device())
+
+ if start_image is not None:
+ latent_temp = vae_encode_with_padding(vae, start_image, width, height, length, padding=1)
+ latent[:, :, :latent_temp.shape[-3]] = latent_temp
+ mask[:, :, :latent_temp.shape[-3]] *= 0.0
+
+ if end_image is not None:
+ latent_temp = vae_encode_with_padding(vae, end_image, width, height, length, padding=0)
+ latent[:, :, -latent_temp.shape[-3]:] = latent_temp
+ mask[:, :, -latent_temp.shape[-3]:] *= 0.0
+
+ out_latent = {}
+ out_latent["samples"] = latent.repeat((batch_size, ) + (1,) * (latent.ndim - 1))
+ out_latent["noise_mask"] = mask.repeat((batch_size, ) + (1,) * (mask.ndim - 1))
+ return (out_latent,)
+
NODE_CLASS_MAPPINGS = {
"EmptyCosmosLatentVideo": EmptyCosmosLatentVideo,
+ "CosmosImageToVideoLatent": CosmosImageToVideoLatent,
}
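
The bookkeeping in `vae_encode_with_padding` keeps frame counts aligned to the Cosmos VAE's 8x temporal compression: `pixel_len` input frames are padded to `(((pixel_len - 1) // 8) + 1 + padding) * 8 - 7` frames (grey padding at 0.5), and only the first `((pixel_len - 1) // 8) + 1` latent frames are kept. The arithmetic alone:

```python
# Length bookkeeping from vae_encode_with_padding; no VAE required.
length, padding = 121, 1
for pixel_len in (1, 8, 9):
    padded_length = min(length, (((pixel_len - 1) // 8) + 1 + padding) * 8 - 7)
    latent_len = ((pixel_len - 1) // 8) + 1
    print(pixel_len, padded_length, latent_len)
# 1 -> pad to 9 frames,  keep 1 latent frame
# 8 -> pad to 9 frames,  keep 1 latent frame
# 9 -> pad to 17 frames, keep 2 latent frames
```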
diff --git a/comfy_extras/nodes_custom_sampler.py b/comfy_extras/nodes_custom_sampler.py
index c7ff9a4d8..3e5be3d3c 100644
--- a/comfy_extras/nodes_custom_sampler.py
+++ b/comfy_extras/nodes_custom_sampler.py
@@ -1,3 +1,4 @@
+import math
import comfy.samplers
import comfy.sample
from comfy.k_diffusion import sampling as k_diffusion_sampling
@@ -231,6 +232,73 @@ class FlipSigmas:
sigmas[0] = 0.0001
return (sigmas,)
+class SetFirstSigma:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {"sigmas": ("SIGMAS", ),
+ "sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}),
+ }
+ }
+ RETURN_TYPES = ("SIGMAS",)
+ CATEGORY = "sampling/custom_sampling/sigmas"
+
+ FUNCTION = "set_first_sigma"
+
+ def set_first_sigma(self, sigmas, sigma):
+ sigmas = sigmas.clone()
+ sigmas[0] = sigma
+ return (sigmas, )
+
+class ExtendIntermediateSigmas:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required":
+ {"sigmas": ("SIGMAS", ),
+ "steps": ("INT", {"default": 2, "min": 1, "max": 100}),
+ "start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}),
+ "end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}),
+ "spacing": (['linear', 'cosine', 'sine'],),
+ }
+ }
+ RETURN_TYPES = ("SIGMAS",)
+ CATEGORY = "sampling/custom_sampling/sigmas"
+
+ FUNCTION = "extend"
+
+ def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str):
+ if start_at_sigma < 0:
+ start_at_sigma = float("inf")
+
+ interpolator = {
+ 'linear': lambda x: x,
+ 'cosine': lambda x: torch.sin(x*math.pi/2),
+ 'sine': lambda x: 1 - torch.cos(x*math.pi/2)
+ }[spacing]
+
+ # linear space for our interpolation function
+ x = torch.linspace(0, 1, steps + 1, device=sigmas.device)[1:-1]
+ computed_spacing = interpolator(x)
+
+ extended_sigmas = []
+ for i in range(len(sigmas) - 1):
+ sigma_current = sigmas[i]
+ sigma_next = sigmas[i+1]
+
+ extended_sigmas.append(sigma_current)
+
+ if end_at_sigma <= sigma_current <= start_at_sigma:
+ interpolated_steps = computed_spacing * (sigma_next - sigma_current) + sigma_current
+ extended_sigmas.extend(interpolated_steps.tolist())
+
+ # Add the last sigma value
+ if len(sigmas) > 0:
+ extended_sigmas.append(sigmas[-1])
+
+ extended_sigmas = torch.FloatTensor(extended_sigmas)
+
+ return (extended_sigmas,)
+
class KSamplerSelect:
@classmethod
def INPUT_TYPES(s):
@@ -436,7 +504,7 @@ class SamplerCustom:
return {"required":
{"model": ("MODEL",),
"add_noise": ("BOOLEAN", {"default": True}),
- "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
+ "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True}),
"cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}),
"positive": ("CONDITIONING", ),
"negative": ("CONDITIONING", ),
@@ -587,10 +655,16 @@ class DisableNoise:
class RandomNoise(DisableNoise):
@classmethod
def INPUT_TYPES(s):
- return {"required":{
- "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
- }
- }
+ return {
+ "required": {
+ "noise_seed": ("INT", {
+ "default": 0,
+ "min": 0,
+ "max": 0xffffffffffffffff,
+ "control_after_generate": True,
+ }),
+ }
+ }
def get_noise(self, noise_seed):
return (Noise_RandomNoise(noise_seed),)
@@ -710,6 +784,8 @@ NODE_CLASS_MAPPINGS = {
"SplitSigmas": SplitSigmas,
"SplitSigmasDenoise": SplitSigmasDenoise,
"FlipSigmas": FlipSigmas,
+ "SetFirstSigma": SetFirstSigma,
+ "ExtendIntermediateSigmas": ExtendIntermediateSigmas,
"CFGGuider": CFGGuider,
"DualCFGGuider": DualCFGGuider,
diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py
index 2ae23f735..ad6c15f37 100644
--- a/comfy_extras/nodes_flux.py
+++ b/comfy_extras/nodes_flux.py
@@ -38,7 +38,26 @@ class FluxGuidance:
return (c, )
+class FluxDisableGuidance:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "conditioning": ("CONDITIONING", ),
+ }}
+
+ RETURN_TYPES = ("CONDITIONING",)
+ FUNCTION = "append"
+
+ CATEGORY = "advanced/conditioning/flux"
+ DESCRIPTION = "This node completely disables the guidance embed on Flux and Flux like models"
+
+ def append(self, conditioning):
+ c = node_helpers.conditioning_set_values(conditioning, {"guidance": None})
+ return (c, )
+
+
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeFlux": CLIPTextEncodeFlux,
"FluxGuidance": FluxGuidance,
+ "FluxDisableGuidance": FluxDisableGuidance,
}
diff --git a/comfy_extras/nodes_fresca.py b/comfy_extras/nodes_fresca.py
new file mode 100644
index 000000000..ee310c874
--- /dev/null
+++ b/comfy_extras/nodes_fresca.py
@@ -0,0 +1,100 @@
+# Code based on https://github.com/WikiChao/FreSca (MIT License)
+import torch
+import torch.fft as fft
+
+
+def Fourier_filter(x, scale_low=1.0, scale_high=1.5, freq_cutoff=20):
+ """
+ Apply frequency-dependent scaling to an image tensor using Fourier transforms.
+
+ Parameters:
+ x: Input tensor of shape (B, C, H, W)
+ scale_low: Scaling factor for low-frequency components (default: 1.0)
+ scale_high: Scaling factor for high-frequency components (default: 1.5)
+ freq_cutoff: Number of frequency indices around center to consider as low-frequency (default: 20)
+
+ Returns:
+ x_filtered: Filtered version of x in spatial domain with frequency-specific scaling applied.
+ """
+ # Preserve input dtype and device
+ dtype, device = x.dtype, x.device
+
+ # Convert to float32 for FFT computations
+ x = x.to(torch.float32)
+
+ # 1) Apply FFT and shift low frequencies to center
+ x_freq = fft.fftn(x, dim=(-2, -1))
+ x_freq = fft.fftshift(x_freq, dim=(-2, -1))
+
+ # Initialize mask with high-frequency scaling factor
+ mask = torch.ones(x_freq.shape, device=device) * scale_high
+ m = mask
+ for d in range(len(x_freq.shape) - 2):
+ dim = d + 2
+ cc = x_freq.shape[dim] // 2
+ f_c = min(freq_cutoff, cc)
+ m = m.narrow(dim, cc - f_c, f_c * 2)
+
+ # Apply low-frequency scaling factor to center region
+ m[:] = scale_low
+
+ # 3) Apply frequency-specific scaling
+ x_freq = x_freq * mask
+
+ # 4) Convert back to spatial domain
+ x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
+ x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
+
+ # 5) Restore original dtype
+ x_filtered = x_filtered.to(dtype)
+
+ return x_filtered
+
+
+class FreSca:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "model": ("MODEL",),
+ "scale_low": ("FLOAT", {"default": 1.0, "min": 0, "max": 10, "step": 0.01,
+ "tooltip": "Scaling factor for low-frequency components"}),
+ "scale_high": ("FLOAT", {"default": 1.25, "min": 0, "max": 10, "step": 0.01,
+ "tooltip": "Scaling factor for high-frequency components"}),
+ "freq_cutoff": ("INT", {"default": 20, "min": 1, "max": 10000, "step": 1,
+ "tooltip": "Number of frequency indices around center to consider as low-frequency"}),
+ }
+ }
+ RETURN_TYPES = ("MODEL",)
+ FUNCTION = "patch"
+ CATEGORY = "_for_testing"
+ DESCRIPTION = "Applies frequency-dependent scaling to the guidance"
+ def patch(self, model, scale_low, scale_high, freq_cutoff):
+ def custom_cfg_function(args):
+ cond = args["conds_out"][0]
+ uncond = args["conds_out"][1]
+
+ guidance = cond - uncond
+ filtered_guidance = Fourier_filter(
+ guidance,
+ scale_low=scale_low,
+ scale_high=scale_high,
+ freq_cutoff=freq_cutoff,
+ )
+ filtered_cond = filtered_guidance + uncond
+
+ return [filtered_cond, uncond]
+
+ m = model.clone()
+ m.set_model_sampler_pre_cfg_function(custom_cfg_function)
+
+ return (m,)
+
+
+NODE_CLASS_MAPPINGS = {
+ "FreSca": FreSca,
+}
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+ "FreSca": "FreSca",
+}
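
`Fourier_filter` is self-contained enough to probe directly. Below is a condensed restatement (assumption: same masking semantics as the patch, with the multi-dim `narrow` loop replaced by explicit 2-D slicing) and two properties it should satisfy:

```python
import torch
import torch.fft as fft

def fourier_filter(x, scale_low, scale_high, freq_cutoff):
    # Condensed 2-D version of the patch's Fourier_filter.
    x32 = x.to(torch.float32)
    x_freq = fft.fftshift(fft.fftn(x32, dim=(-2, -1)), dim=(-2, -1))
    mask = torch.full_like(x_freq.real, scale_high)
    h, w = x_freq.shape[-2], x_freq.shape[-1]
    ch, cw = h // 2, w // 2
    fh, fw = min(freq_cutoff, ch), min(freq_cutoff, cw)
    mask[..., ch - fh:ch + fh, cw - fw:cw + fw] = scale_low
    out = fft.ifftn(fft.ifftshift(x_freq * mask, dim=(-2, -1)), dim=(-2, -1)).real
    return out.to(x.dtype)

x = torch.randn(1, 4, 32, 32)

# Equal scales make the mask constant: the filter is plain scaling.
assert torch.allclose(fourier_filter(x, 2.0, 2.0, 8), 2 * x, atol=1e-4)

# Zeroing the high band keeps the per-channel mean (the DC term sits
# inside the low-frequency window).
y = fourier_filter(x, 1.0, 0.0, 4)
assert torch.allclose(y.mean(dim=(-1, -2)), x.mean(dim=(-1, -2)), atol=1e-4)
```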
diff --git a/comfy_extras/nodes_hidream.py b/comfy_extras/nodes_hidream.py
new file mode 100644
index 000000000..dfb98597b
--- /dev/null
+++ b/comfy_extras/nodes_hidream.py
@@ -0,0 +1,55 @@
+import folder_paths
+import comfy.sd
+import comfy.model_management
+
+
+class QuadrupleCLIPLoader:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": { "clip_name1": (folder_paths.get_filename_list("text_encoders"), ),
+ "clip_name2": (folder_paths.get_filename_list("text_encoders"), ),
+ "clip_name3": (folder_paths.get_filename_list("text_encoders"), ),
+ "clip_name4": (folder_paths.get_filename_list("text_encoders"), )
+ }}
+ RETURN_TYPES = ("CLIP",)
+ FUNCTION = "load_clip"
+
+ CATEGORY = "advanced/loaders"
+
+ DESCRIPTION = "[Recipes]\n\nhidream: long clip-l, long clip-g, t5xxl, llama_8b_3.1_instruct"
+
+ def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4):
+ clip_path1 = folder_paths.get_full_path_or_raise("text_encoders", clip_name1)
+ clip_path2 = folder_paths.get_full_path_or_raise("text_encoders", clip_name2)
+ clip_path3 = folder_paths.get_full_path_or_raise("text_encoders", clip_name3)
+ clip_path4 = folder_paths.get_full_path_or_raise("text_encoders", clip_name4)
+ clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3, clip_path4], embedding_directory=folder_paths.get_folder_paths("embeddings"))
+ return (clip,)
+
+class CLIPTextEncodeHiDream:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "clip": ("CLIP", ),
+ "clip_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+ "clip_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+ "t5xxl": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+ "llama": ("STRING", {"multiline": True, "dynamicPrompts": True})
+ }}
+ RETURN_TYPES = ("CONDITIONING",)
+ FUNCTION = "encode"
+
+ CATEGORY = "advanced/conditioning"
+
+ def encode(self, clip, clip_l, clip_g, t5xxl, llama):
+
+ tokens = clip.tokenize(clip_g)
+ tokens["l"] = clip.tokenize(clip_l)["l"]
+ tokens["t5xxl"] = clip.tokenize(t5xxl)["t5xxl"]
+ tokens["llama"] = clip.tokenize(llama)["llama"]
+ return (clip.encode_from_tokens_scheduled(tokens), )
+
+NODE_CLASS_MAPPINGS = {
+ "QuadrupleCLIPLoader": QuadrupleCLIPLoader,
+ "CLIPTextEncodeHiDream": CLIPTextEncodeHiDream,
+}
diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py
index d6408269f..d7278e7a7 100644
--- a/comfy_extras/nodes_hunyuan.py
+++ b/comfy_extras/nodes_hunyuan.py
@@ -1,4 +1,5 @@
import nodes
+import node_helpers
import torch
import comfy.model_management
@@ -38,7 +39,85 @@ class EmptyHunyuanLatentVideo:
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
return ({"samples":latent}, )
+PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = (
+ "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: "
+ "1. The main content and theme of the video."
+ "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
+ "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
+ "4. background environment, light, style and atmosphere."
+ "5. camera angles, movements, and transitions used in the video:<|eot_id|>\n\n"
+ "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
+)
+
+class TextEncodeHunyuanVideo_ImageToVideo:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {
+ "clip": ("CLIP", ),
+ "clip_vision_output": ("CLIP_VISION_OUTPUT", ),
+ "prompt": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+ "image_interleave": ("INT", {"default": 2, "min": 1, "max": 512, "tooltip": "How much the image influences things vs the text prompt. Higher number means more influence from the text prompt."}),
+ }}
+ RETURN_TYPES = ("CONDITIONING",)
+ FUNCTION = "encode"
+
+ CATEGORY = "advanced/conditioning"
+
+ def encode(self, clip, clip_vision_output, prompt, image_interleave):
+ tokens = clip.tokenize(prompt, llama_template=PROMPT_TEMPLATE_ENCODE_VIDEO_I2V, image_embeds=clip_vision_output.mm_projected, image_interleave=image_interleave)
+ return (clip.encode_from_tokens_scheduled(tokens), )
+
+class HunyuanImageToVideo:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"positive": ("CONDITIONING", ),
+ "vae": ("VAE", ),
+ "width": ("INT", {"default": 848, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+ "height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
+ "length": ("INT", {"default": 53, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
+ "guidance_type": (["v1 (concat)", "v2 (replace)", "custom"], )
+ },
+ "optional": {"start_image": ("IMAGE", ),
+ }}
+
+ RETURN_TYPES = ("CONDITIONING", "LATENT")
+ RETURN_NAMES = ("positive", "latent")
+ FUNCTION = "encode"
+
+ CATEGORY = "conditioning/video_models"
+
+ def encode(self, positive, vae, width, height, length, batch_size, guidance_type, start_image=None):
+ latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
+ out_latent = {}
+
+ if start_image is not None:
+ start_image = comfy.utils.common_upscale(start_image[:length, :, :, :3].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
+
+ concat_latent_image = vae.encode(start_image)
+ mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
+ mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
+
+ if guidance_type == "v1 (concat)":
+ cond = {"concat_latent_image": concat_latent_image, "concat_mask": mask}
+ elif guidance_type == "v2 (replace)":
+ cond = {'guiding_frame_index': 0}
+ latent[:, :, :concat_latent_image.shape[2]] = concat_latent_image
+ out_latent["noise_mask"] = mask
+ elif guidance_type == "custom":
+ cond = {"ref_latent": concat_latent_image}
+
+ positive = node_helpers.conditioning_set_values(positive, cond)
+
+ out_latent["samples"] = latent
+ return (positive, out_latent)
+
+
+
NODE_CLASS_MAPPINGS = {
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
+ "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
+ "HunyuanImageToVideo": HunyuanImageToVideo,
}
diff --git a/comfy_extras/nodes_hunyuan3d.py b/comfy_extras/nodes_hunyuan3d.py
new file mode 100644
index 000000000..51e45336a
--- /dev/null
+++ b/comfy_extras/nodes_hunyuan3d.py
@@ -0,0 +1,634 @@
+import torch
+import os
+import json
+import struct
+import numpy as np
+from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_from_grid_torch
+import folder_paths
+import comfy.model_management
+from comfy.cli_args import args
+
+
+class EmptyLatentHunyuan3Dv2:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}),
+ "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}),
+ }}
+ RETURN_TYPES = ("LATENT",)
+ FUNCTION = "generate"
+
+ CATEGORY = "latent/3d"
+
+ def generate(self, resolution, batch_size):
+ latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device())
+ return ({"samples": latent, "type": "hunyuan3dv2"}, )
+
+
+class Hunyuan3Dv2Conditioning:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",),
+ }}
+
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
+ RETURN_NAMES = ("positive", "negative")
+
+ FUNCTION = "encode"
+
+ CATEGORY = "conditioning/video_models"
+
+ def encode(self, clip_vision_output):
+ embeds = clip_vision_output.last_hidden_state
+ positive = [[embeds, {}]]
+ negative = [[torch.zeros_like(embeds), {}]]
+ return (positive, negative)
+
+
+class Hunyuan3Dv2ConditioningMultiView:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {},
+ "optional": {"front": ("CLIP_VISION_OUTPUT",),
+ "left": ("CLIP_VISION_OUTPUT",),
+ "back": ("CLIP_VISION_OUTPUT",),
+ "right": ("CLIP_VISION_OUTPUT",), }}
+
+ RETURN_TYPES = ("CONDITIONING", "CONDITIONING")
+ RETURN_NAMES = ("positive", "negative")
+
+ FUNCTION = "encode"
+
+ CATEGORY = "conditioning/video_models"
+
+ def encode(self, front=None, left=None, back=None, right=None):
+ all_embeds = [front, left, back, right]
+ out = []
+ pos_embeds = None
+ for i, e in enumerate(all_embeds):
+ if e is not None:
+ if pos_embeds is None:
+ pos_embeds = get_1d_sincos_pos_embed_from_grid_torch(e.last_hidden_state.shape[-1], torch.arange(4))
+ out.append(e.last_hidden_state + pos_embeds[i].reshape(1, 1, -1))
+
+ embeds = torch.cat(out, dim=1)
+ positive = [[embeds, {}]]
+ negative = [[torch.zeros_like(embeds), {}]]
+ return (positive, negative)
+
+
+class VOXEL:
+ def __init__(self, data):
+ self.data = data
+
+
+class VAEDecodeHunyuan3D:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"samples": ("LATENT", ),
+ "vae": ("VAE", ),
+ "num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}),
+ "octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}),
+ }}
+ RETURN_TYPES = ("VOXEL",)
+ FUNCTION = "decode"
+
+ CATEGORY = "latent/3d"
+
+ def decode(self, vae, samples, num_chunks, octree_resolution):
+ voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution}))
+ return (voxels, )
+
+
+def voxel_to_mesh(voxels, threshold=0.5, device=None):
+ if device is None:
+ device = torch.device("cpu")
+ voxels = voxels.to(device)
+
+ binary = (voxels > threshold).float()
+ padded = torch.nn.functional.pad(binary, (1, 1, 1, 1, 1, 1), 'constant', 0)
+
+ D, H, W = binary.shape
+
+ neighbors = torch.tensor([
+ [0, 0, 1],
+ [0, 0, -1],
+ [0, 1, 0],
+ [0, -1, 0],
+ [1, 0, 0],
+ [-1, 0, 0]
+ ], device=device)
+
+ z, y, x = torch.meshgrid(
+ torch.arange(D, device=device),
+ torch.arange(H, device=device),
+ torch.arange(W, device=device),
+ indexing='ij'
+ )
+ voxel_indices = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
+
+ solid_mask = binary.flatten() > 0
+ solid_indices = voxel_indices[solid_mask]
+
+ corner_offsets = [
+ torch.tensor([
+ [0, 0, 1], [0, 1, 1], [1, 1, 1], [1, 0, 1]
+ ], device=device),
+ torch.tensor([
+ [0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]
+ ], device=device),
+ torch.tensor([
+ [0, 1, 0], [1, 1, 0], [1, 1, 1], [0, 1, 1]
+ ], device=device),
+ torch.tensor([
+ [0, 0, 0], [0, 0, 1], [1, 0, 1], [1, 0, 0]
+ ], device=device),
+ torch.tensor([
+ [1, 0, 1], [1, 1, 1], [1, 1, 0], [1, 0, 0]
+ ], device=device),
+ torch.tensor([
+ [0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]
+ ], device=device)
+ ]
+
+ all_vertices = []
+ all_indices = []
+
+ vertex_count = 0
+
+ for face_idx, offset in enumerate(neighbors):
+ neighbor_indices = solid_indices + offset
+
+ padded_indices = neighbor_indices + 1
+
+ is_exposed = padded[
+ padded_indices[:, 0],
+ padded_indices[:, 1],
+ padded_indices[:, 2]
+ ] == 0
+
+ if not is_exposed.any():
+ continue
+
+ exposed_indices = solid_indices[is_exposed]
+
+ corners = corner_offsets[face_idx].unsqueeze(0)
+
+ face_vertices = exposed_indices.unsqueeze(1) + corners
+
+ all_vertices.append(face_vertices.reshape(-1, 3))
+
+ num_faces = exposed_indices.shape[0]
+ face_indices = torch.arange(
+ vertex_count,
+ vertex_count + 4 * num_faces,
+ device=device
+ ).reshape(-1, 4)
+
+ all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 1], face_indices[:, 2]], dim=1))
+ all_indices.append(torch.stack([face_indices[:, 0], face_indices[:, 2], face_indices[:, 3]], dim=1))
+
+ vertex_count += 4 * num_faces
+
+ if len(all_vertices) > 0:
+ vertices = torch.cat(all_vertices, dim=0)
+ faces = torch.cat(all_indices, dim=0)
+ else:
+ vertices = torch.zeros((1, 3))
+ faces = torch.zeros((1, 3))
+
+ v_min = 0
+ v_max = max(voxels.shape)
+
+ vertices = vertices - (v_min + v_max) / 2
+
+ scale = (v_max - v_min) / 2
+ if scale > 0:
+ vertices = vertices / scale
+
+ vertices = torch.fliplr(vertices)
+ return vertices, faces
+
+def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None):
+ if device is None:
+ device = torch.device("cpu")
+ voxels = voxels.to(device)
+
+ D, H, W = voxels.shape
+
+ padded = torch.nn.functional.pad(voxels, (1, 1, 1, 1, 1, 1), 'constant', 0)
+ z, y, x = torch.meshgrid(
+ torch.arange(D, device=device),
+ torch.arange(H, device=device),
+ torch.arange(W, device=device),
+ indexing='ij'
+ )
+ cell_positions = torch.stack([z.flatten(), y.flatten(), x.flatten()], dim=1)
+
+ corner_offsets = torch.tensor([
+ [0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0],
+ [0, 0, 1], [1, 0, 1], [0, 1, 1], [1, 1, 1]
+ ], device=device)
+
+ corner_values = torch.zeros((cell_positions.shape[0], 8), device=device)
+ for c, (dz, dy, dx) in enumerate(corner_offsets):
+ corner_values[:, c] = padded[
+ cell_positions[:, 0] + dz,
+ cell_positions[:, 1] + dy,
+ cell_positions[:, 2] + dx
+ ]
+
+ corner_signs = corner_values > threshold
+ has_inside = torch.any(corner_signs, dim=1)
+ has_outside = torch.any(~corner_signs, dim=1)
+ contains_surface = has_inside & has_outside
+
+ active_cells = cell_positions[contains_surface]
+ active_signs = corner_signs[contains_surface]
+ active_values = corner_values[contains_surface]
+
+ if active_cells.shape[0] == 0:
+ return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
+
+ edges = torch.tensor([
+ [0, 1], [0, 2], [0, 4], [1, 3],
+ [1, 5], [2, 3], [2, 6], [3, 7],
+ [4, 5], [4, 6], [5, 7], [6, 7]
+ ], device=device)
+
+ cell_vertices = {}
+ progress = comfy.utils.ProgressBar(100)
+
+ for edge_idx, (e1, e2) in enumerate(edges):
+ progress.update(1)
+ crossing = active_signs[:, e1] != active_signs[:, e2]
+ if not crossing.any():
+ continue
+
+ cell_indices = torch.nonzero(crossing, as_tuple=True)[0]
+
+ v1 = active_values[cell_indices, e1]
+ v2 = active_values[cell_indices, e2]
+
+ t = torch.zeros_like(v1, device=device)
+ denom = v2 - v1
+ valid = denom != 0
+ t[valid] = (threshold - v1[valid]) / denom[valid]
+ t[~valid] = 0.5
+
+ p1 = corner_offsets[e1].float()
+ p2 = corner_offsets[e2].float()
+
+ intersection = p1.unsqueeze(0) + t.unsqueeze(1) * (p2.unsqueeze(0) - p1.unsqueeze(0))
+
+ for i, point in zip(cell_indices.tolist(), intersection):
+ if i not in cell_vertices:
+ cell_vertices[i] = []
+ cell_vertices[i].append(point)
+
+ # Calculate the final vertices as the average of intersection points for each cell
+ vertices = []
+ vertex_lookup = {}
+
+ vert_progress_mod = round(len(cell_vertices)/50)
+
+ for i, points in cell_vertices.items():
+ if not i % vert_progress_mod:
+ progress.update(1)
+
+ if points:
+ vertex = torch.stack(points).mean(dim=0)
+ vertex = vertex + active_cells[i].float()
+ vertex_lookup[tuple(active_cells[i].tolist())] = len(vertices)
+ vertices.append(vertex)
+
+ if not vertices:
+ return torch.zeros((0, 3), device=device), torch.zeros((0, 3), dtype=torch.long, device=device)
+
+ final_vertices = torch.stack(vertices)
+
+ inside_corners_mask = active_signs
+ outside_corners_mask = ~active_signs
+
+ inside_counts = inside_corners_mask.sum(dim=1, keepdim=True).float()
+ outside_counts = outside_corners_mask.sum(dim=1, keepdim=True).float()
+
+ inside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
+ outside_pos = torch.zeros((active_cells.shape[0], 3), device=device)
+
+ for i in range(8):
+ mask_inside = inside_corners_mask[:, i].unsqueeze(1)
+ mask_outside = outside_corners_mask[:, i].unsqueeze(1)
+ inside_pos += corner_offsets[i].float().unsqueeze(0) * mask_inside
+ outside_pos += corner_offsets[i].float().unsqueeze(0) * mask_outside
+
+ inside_pos /= inside_counts
+ outside_pos /= outside_counts
+ gradients = inside_pos - outside_pos
+
+ pos_dirs = torch.tensor([
+ [1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1]
+ ], device=device)
+
+ cross_products = [
+ torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float())
+ for i in range(3) for j in range(i+1, 3)
+ ]
+
+ faces = []
+ all_keys = set(vertex_lookup.keys())
+
+ face_progress_mod = round(len(active_cells)/38*3)
+
+ for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]):
+ dir_i = pos_dirs[i]
+ dir_j = pos_dirs[j]
+ cross_product = cross_products[pair_idx]
+
+ ni_positions = active_cells + dir_i
+ nj_positions = active_cells + dir_j
+ diag_positions = active_cells + dir_i + dir_j
+
+ alignments = torch.matmul(gradients, cross_product)
+
+ valid_quads = []
+ quad_indices = []
+
+ for idx, active_cell in enumerate(active_cells):
+ if not idx % face_progress_mod:
+ progress.update(1)
+ cell_key = tuple(active_cell.tolist())
+ ni_key = tuple(ni_positions[idx].tolist())
+ nj_key = tuple(nj_positions[idx].tolist())
+ diag_key = tuple(diag_positions[idx].tolist())
+
+ if cell_key in all_keys and ni_key in all_keys and nj_key in all_keys and diag_key in all_keys:
+ v0 = vertex_lookup[cell_key]
+ v1 = vertex_lookup[ni_key]
+ v2 = vertex_lookup[nj_key]
+ v3 = vertex_lookup[diag_key]
+
+ valid_quads.append((v0, v1, v2, v3))
+ quad_indices.append(idx)
+
+ for q_idx, (v0, v1, v2, v3) in enumerate(valid_quads):
+ cell_idx = quad_indices[q_idx]
+ if alignments[cell_idx] > 0:
+ faces.append(torch.tensor([v0, v1, v3], device=device, dtype=torch.long))
+ faces.append(torch.tensor([v0, v3, v2], device=device, dtype=torch.long))
+ else:
+ faces.append(torch.tensor([v0, v3, v1], device=device, dtype=torch.long))
+ faces.append(torch.tensor([v0, v2, v3], device=device, dtype=torch.long))
+
+ if faces:
+ faces = torch.stack(faces)
+ else:
+ faces = torch.zeros((0, 3), dtype=torch.long, device=device)
+
+ v_min = 0
+ v_max = max(D, H, W)
+
+ final_vertices = final_vertices - (v_min + v_max) / 2
+
+ scale = (v_max - v_min) / 2
+ if scale > 0:
+ final_vertices = final_vertices / scale
+
+ final_vertices = torch.fliplr(final_vertices)
+
+ return final_vertices, faces
+
+class MESH:
+ def __init__(self, vertices, faces):
+ self.vertices = vertices
+ self.faces = faces
+
+
+class VoxelToMeshBasic:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"voxel": ("VOXEL", ),
+ "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
+ }}
+ RETURN_TYPES = ("MESH",)
+ FUNCTION = "decode"
+
+ CATEGORY = "3d"
+
+ def decode(self, voxel, threshold):
+ vertices = []
+ faces = []
+ for x in voxel.data:
+ v, f = voxel_to_mesh(x, threshold=threshold, device=None)
+ vertices.append(v)
+ faces.append(f)
+
+ return (MESH(torch.stack(vertices), torch.stack(faces)), )
+
+class VoxelToMesh:
+ @classmethod
+ def INPUT_TYPES(s):
+ return {"required": {"voxel": ("VOXEL", ),
+ "algorithm": (["surface net", "basic"], ),
+ "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}),
+ }}
+ RETURN_TYPES = ("MESH",)
+ FUNCTION = "decode"
+
+ CATEGORY = "3d"
+
+ def decode(self, voxel, algorithm, threshold):
+ vertices = []
+ faces = []
+
+ if algorithm == "basic":
+ mesh_function = voxel_to_mesh
+ elif algorithm == "surface net":
+ mesh_function = voxel_to_mesh_surfnet
+
+ for x in voxel.data:
+ v, f = mesh_function(x, threshold=threshold, device=None)
+ vertices.append(v)
+ faces.append(f)
+
+ return (MESH(torch.stack(vertices), torch.stack(faces)), )
+
+
+def save_glb(vertices, faces, filepath, metadata=None):
+ """
+ Save PyTorch tensor vertices and faces as a GLB file without external dependencies.
+
+ Parameters:
+ vertices: torch.Tensor of shape (N, 3) - The vertex coordinates
+ faces: torch.Tensor of shape (M, 3) - The face indices (triangle faces)
+ filepath: str - Output filepath (should end with .glb)
+ """
+
+ # Convert tensors to numpy arrays
+ vertices_np = vertices.cpu().numpy().astype(np.float32)
+ faces_np = faces.cpu().numpy().astype(np.uint32)
+
+ vertices_buffer = vertices_np.tobytes()
+ indices_buffer = faces_np.tobytes()
+
+ def pad_to_4_bytes(buffer):
+ padding_length = (4 - (len(buffer) % 4)) % 4
+ return buffer + b'\x00' * padding_length
+
+ vertices_buffer_padded = pad_to_4_bytes(vertices_buffer)
+ indices_buffer_padded = pad_to_4_bytes(indices_buffer)
+
+ buffer_data = vertices_buffer_padded + indices_buffer_padded
+
+ vertices_byte_length = len(vertices_buffer)
+ vertices_byte_offset = 0
+ indices_byte_length = len(indices_buffer)
+ indices_byte_offset = len(vertices_buffer_padded)
+
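+    # Single shared binary buffer: vertex bytes first, then index bytes; the
+    # bufferViews in the glTF JSON below point at these offsets and lengths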
+ gltf = {
+ "asset": {"version": "2.0", "generator": "ComfyUI"},
+ "buffers": [
+ {
+ "byteLength": len(buffer_data)
+ }
+ ],
+ "bufferViews": [
+ {
+ "buffer": 0,
+ "byteOffset": vertices_byte_offset,
+ "byteLength": vertices_byte_length,
+ "target": 34962 # ARRAY_BUFFER
+ },
+ {
+ "buffer": 0,
+ "byteOffset": indices_byte_offset,
+ "byteLength": indices_byte_length,
+ "target": 34963 # ELEMENT_ARRAY_BUFFER
+ }
+ ],
+ "accessors": [
+ {
+ "bufferView": 0,
+ "byteOffset": 0,
+ "componentType": 5126, # FLOAT
+ "count": len(vertices_np),
+ "type": "VEC3",
+ "max": vertices_np.max(axis=0).tolist(),
+ "min": vertices_np.min(axis=0).tolist()
+ },
+ {
+ "bufferView": 1,
+ "byteOffset": 0,
+ "componentType": 5125, # UNSIGNED_INT
+ "count": faces_np.size,
+ "type": "SCALAR"
+ }
+ ],
+ "meshes": [
+ {
+ "primitives": [
+ {
+ "attributes": {
+ "POSITION": 0
+ },
+ "indices": 1,
+ "mode": 4 # TRIANGLES
+ }
+ ]
+ }
+ ],
+ "nodes": [
+ {
+ "mesh": 0
+ }
+ ],
+ "scenes": [
+ {
+ "nodes": [0]
+ }
+ ],
+ "scene": 0
+ }
+
+ if metadata is not None:
+ gltf["asset"]["extras"] = metadata
+
+ # Convert the JSON to bytes
+ gltf_json = json.dumps(gltf).encode('utf8')
+
+ def pad_json_to_4_bytes(buffer):
+ padding_length = (4 - (len(buffer) % 4)) % 4
+ return buffer + b' ' * padding_length
+
+ gltf_json_padded = pad_json_to_4_bytes(gltf_json)
+
+    # Create the 12-byte GLB header: magic "glTF", version 2, and the total file
+    # length (header + two 8-byte chunk headers + padded JSON + binary payload)
+    glb_header = struct.pack('<4sII', b'glTF', 2, 12 + 8 + len(gltf_json_padded) + 8 + len(buffer_data))
+
+    # Create the JSON chunk header (first chunk; type 0x4E4F534A, ASCII "JSON")
+    json_chunk_header = struct.pack('<II', len(gltf_json_padded), 0x4E4F534A)
+
+    # Create the binary chunk header (second chunk; type 0x004E4942, ASCII "BIN")
+    bin_chunk_header = struct.pack('<II', len(buffer_data), 0x004E4942)
+
+    # Write the GLB file: header, JSON chunk, then binary chunk
+    with open(filepath, 'wb') as f:
+        f.write(glb_header)
+        f.write(json_chunk_header)
+        f.write(gltf_json_padded)
+        f.write(bin_chunk_header)
+        f.write(buffer_data)
+
+    return filepath
+
+
+class SVG:
+    """
+    Stores a batch of SVG images as a list of BytesIO objects.
+    """
+    def __init__(self, data: list[BytesIO]):
+        self.data = data
+
+    def combine(self, other: 'SVG') -> 'SVG':
+        return SVG(self.data + other.data)
+
+ @staticmethod
+ def combine_all(svgs: list['SVG']) -> 'SVG':
+ all_svgs_list: list[BytesIO] = []
+ for svg_item in svgs:
+ all_svgs_list.extend(svg_item.data)
+ return SVG(all_svgs_list)
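+
+# Example sketch: merge the SVG batches from two upstream nodes into one object
+#   merged = SVG.combine_all([svg_a, svg_b])  # svg_a, svg_b: SVG instances (assumed)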
+
+class SaveSVGNode:
+ """
+ Save SVG files on disk.
+ """
+
+ def __init__(self):
+ self.output_dir = folder_paths.get_output_directory()
+ self.type = "output"
+ self.prefix_append = ""
+
+ RETURN_TYPES = ()
+ DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
+ FUNCTION = "save_svg"
+ CATEGORY = "image/save" # Changed
+ OUTPUT_NODE = True
+
+ @classmethod
+ def INPUT_TYPES(s):
+ return {
+ "required": {
+ "svg": ("SVG",), # Changed
+ "filename_prefix": ("STRING", {"default": "svg/ComfyUI", "tooltip": "The prefix for the file to save. This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."})
+ },
+ "hidden": {
+ "prompt": "PROMPT",
+ "extra_pnginfo": "EXTRA_PNGINFO"
+ }
+ }
+
+ def save_svg(self, svg: SVG, filename_prefix="svg/ComfyUI", prompt=None, extra_pnginfo=None):
+ filename_prefix += self.prefix_append
+ full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir)
+ results = list()
+
+ # Prepare metadata JSON
+ metadata_dict = {}
+ if prompt is not None:
+ metadata_dict["prompt"] = prompt
+ if extra_pnginfo is not None:
+ metadata_dict.update(extra_pnginfo)
+
+ # Convert metadata to JSON string
+ metadata_json = json.dumps(metadata_dict, indent=2) if metadata_dict else None
+
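+        # Each BytesIO in svg.data is written out as its own numbered .svg file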
+ for batch_number, svg_bytes in enumerate(svg.data):
+ filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
+ file = f"{filename_with_batch_num}_{counter:05}_.svg"
+
+ # Read SVG content
+ svg_bytes.seek(0)
+ svg_content = svg_bytes.read().decode('utf-8')
+
+ # Inject metadata if available
+ if metadata_json:
+ # Create metadata element with CDATA section
+ metadata_element = f"""
+
+
+ """
+ # Insert metadata after opening svg tag using regex with a replacement function
+ def replacement(match):
+                    # match.group(1) contains the captured opening <svg> tag