diff --git a/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat new file mode 100644 index 000000000..ed00583b6 --- /dev/null +++ b/.ci/windows_nvidia_base_files/advanced/run_nvidia_gpu_disable_api_nodes.bat @@ -0,0 +1,3 @@ +..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes +echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest. +pause diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index 3cf2717b7..6556677e0 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -8,13 +8,15 @@ body: Before submitting a **Bug Report**, please ensure the following: - **1:** You are running the latest version of ComfyUI. - - **2:** You have looked at the existing bug reports and made sure this isn't already reported. + - **2:** You have your ComfyUI logs and relevant workflow on hand and will post them in this bug report. - **3:** You confirmed that the bug is not caused by a custom node. You can disable all custom nodes by passing - `--disable-all-custom-nodes` command line argument. + `--disable-all-custom-nodes` command line argument. If you have custom nodes, try updating them to the latest version. - **4:** This is an actual bug in ComfyUI, not just a support question. A bug is when you can specify exact steps to replicate what went wrong and others will be able to repeat your steps and see the same issue happen. - If unsure, ask on the [ComfyUI Matrix Space](https://app.element.io/#/room/%23comfyui_space%3Amatrix.org) or the [Comfy Org Discord](https://discord.gg/comfyorg) first. + ## Very Important + + Please make sure that you post ALL your ComfyUI logs in the bug report. A bug report without logs will likely be ignored.
- type: checkboxes id: custom-nodes-test attributes: diff --git a/.github/PULL_REQUEST_TEMPLATE/api-node.md b/.github/PULL_REQUEST_TEMPLATE/api-node.md new file mode 100644 index 000000000..c1f1bafb1 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE/api-node.md @@ -0,0 +1,21 @@ + + +## API Node PR Checklist + +### Scope +- [ ] **Is API Node Change** + +### Pricing & Billing +- [ ] **Need pricing update** +- [ ] **No pricing update** + +If **Need pricing update**: +- [ ] Metronome rate cards updated +- [ ] Auto‑billing tests updated and passing + +### QA +- [ ] **QA done** +- [ ] **QA not required** + +### Comms +- [ ] Informed **Kosinkadink** diff --git a/.github/workflows/api-node-template.yml b/.github/workflows/api-node-template.yml new file mode 100644 index 000000000..fdb81c0c5 --- /dev/null +++ b/.github/workflows/api-node-template.yml @@ -0,0 +1,58 @@ +name: Append API Node PR template + +on: + pull_request_target: + types: [opened, reopened, synchronize, ready_for_review] + paths: + - 'comfy_api_nodes/**' # only run if these files changed + +permissions: + contents: read + pull-requests: write + +jobs: + inject: + runs-on: ubuntu-latest + steps: + - name: Ensure template exists and append to PR body + uses: actions/github-script@v7 + with: + script: | + const { owner, repo } = context.repo; + const number = context.payload.pull_request.number; + const templatePath = '.github/PULL_REQUEST_TEMPLATE/api-node.md'; + const marker = ''; + + const { data: pr } = await github.rest.pulls.get({ owner, repo, pull_number: number }); + + let templateText; + try { + const res = await github.rest.repos.getContent({ + owner, + repo, + path: templatePath, + ref: pr.base.ref + }); + const buf = Buffer.from(res.data.content, res.data.encoding || 'base64'); + templateText = buf.toString('utf8'); + } catch (e) { + core.setFailed(`Required PR template not found at "${templatePath}" on ${pr.base.ref}. Please add it to the repo.`); + return; + } + + // Enforce the presence of the marker inside the template (for idempotence) + if (!templateText.includes(marker)) { + core.setFailed(`Template at "${templatePath}" does not contain the required marker:\n${marker}\nAdd it so we can detect duplicates safely.`); + return; + } + + // If the PR already contains the marker, do not append again. + const body = pr.body || ''; + if (body.includes(marker)) { + core.info('Template already present in PR body; nothing to inject.'); + return; + } + + const newBody = (body ? body + '\n\n' : '') + templateText + '\n'; + await github.rest.pulls.update({ owner, repo, pull_number: number, body: newBody }); + core.notice('API Node template appended to PR description.'); diff --git a/CODEOWNERS b/CODEOWNERS index f4f456f32..237b176ce 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,6 +1,2 @@ # Admins -# From upstream -* @comfyanonymous -* @kosinkadink -# For the fork -* @doctorpangloss +* @comfyanonymous @kosinkadink @guill @doctorpangloss diff --git a/QUANTIZATION.md b/QUANTIZATION.md new file mode 100644 index 000000000..1693e13f3 --- /dev/null +++ b/QUANTIZATION.md @@ -0,0 +1,168 @@ +# The Comfy guide to Quantization + + +## How does quantization work? + +Quantization aims to map a high-precision value x_f to a lower precision format with minimal loss in accuracy. These smaller formats then serve to reduce the models memory footprint and increase throughput by using specialized hardware. 
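As a quick, self-contained illustration of the memory claim above (this is not ComfyUI code, and it assumes a PyTorch build with native float8 dtypes), casting a weight matrix from FP16 to FP8 halves its storage:

```python
import torch

# Standalone sketch, assuming torch.float8_e4m3fn is available (PyTorch >= 2.1).
w_fp16 = torch.randn(4096, 4096, dtype=torch.float16)
w_fp8 = w_fp16.to(torch.float8_e4m3fn)  # plain cast; scaling is discussed below

print(w_fp16.numel() * w_fp16.element_size() / 2**20)  # 32.0 MiB
print(w_fp8.numel() * w_fp8.element_size() / 2**20)    # 16.0 MiB
```

The throughput side of the claim additionally depends on hardware kernels that operate on the low-precision format directly, which is where the scaling machinery described next comes in.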
+ +When simply converting a value from FP16 to FP8 using the round-to-nearest method, we might hit two issues: +- The dynamic range of FP16 (-65,504, 65,504) far exceeds that of FP8 formats like E4M3 (-448, 448) or E5M2 (-57,344, 57,344), potentially resulting in clipped values +- The original values are concentrated in a small range (e.g. (-1, 1)), leaving many FP8 bits "unused" + +By applying a scaling factor, we map these values into the range of the quantized dtype, making use of its full spectrum. One of the simplest and most common approaches is per-tensor absolute-maximum scaling. + +``` +absmax = max(abs(tensor)) +scale = absmax / max_dynamic_range_low_precision + +# Quantization +tensor_q = (tensor / scale).to(low_precision_dtype) + +# De-Quantization +tensor_dq = tensor_q.to(fp16) * scale + +tensor_dq ~ tensor +``` + +Because additional information (the scaling factor) is needed to "interpret" the quantized values, we describe these as derived datatypes. + + +## Quantization in Comfy + +``` +QuantizedTensor (torch.Tensor subclass) + ↓ __torch_dispatch__ +Two-Level Registry (generic + layout handlers) + ↓ +MixedPrecisionOps + Metadata Detection +``` + +### Representation + +To represent these derived datatypes, ComfyUI uses a subclass of torch.Tensor: the `QuantizedTensor` class found in `comfy/quant_ops.py`. + +A `Layout` class defines how a specific quantization format behaves: +- Required parameters +- Quantize method +- De-Quantize method + +```python +from comfy.quant_ops import QuantizedLayout + +class MyLayout(QuantizedLayout): + @classmethod + def quantize(cls, tensor, **kwargs): + # Convert to quantized format + qdata = ... + params = {'scale': ..., 'orig_dtype': tensor.dtype} + return qdata, params + + @staticmethod + def dequantize(qdata, scale, orig_dtype, **kwargs): + return qdata.to(orig_dtype) * scale +``` + +To run operations on these QuantizedTensors, we use two registry systems to define the supported operations. +The first is a **generic registry** that handles operations common to all quantized formats (e.g., `.to()`, `.clone()`, `.reshape()`). + +The second registry is layout-specific and allows implementing fast paths such as nn.Linear. +```python +from comfy.quant_ops import register_layout_op + +@register_layout_op(torch.ops.aten.linear.default, MyLayout) +def my_linear(func, args, kwargs): + # Extract tensors, call optimized kernel + ... +``` +When `torch.nn.functional.linear()` is called with QuantizedTensor arguments, `__torch_dispatch__` automatically routes to the registered implementation. +For any unsupported operation, QuantizedTensor will fall back to calling `dequantize` and dispatch to the high-precision implementation. + + +### Mixed Precision + +The `MixedPrecisionOps` class (lines 542-648 in `comfy/ops.py`) enables per-layer quantization decisions, allowing different layers in a model to use different precisions. This is activated when a model config contains a `layer_quant_config` dictionary that specifies which layers should be quantized and how.
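For orientation, a sketch of what such a per-layer mapping could look like is shown below. The actual schema is defined by the model configs and the formats registered in `QUANT_ALGOS`; the key names here are purely illustrative.

```python
# Hypothetical illustration only -- the real schema comes from the model config
# and QUANT_ALGOS, not from this snippet.
layer_quant_config = {
    # compute-heavy projections quantized to FP8
    "double_blocks.0.img_mlp.0": {"format": "float8_e4m3fn"},
    "double_blocks.0.img_mlp.2": {"format": "float8_e4m3fn"},
    # layers not listed here are loaded as regular tensors in the compute dtype
}
```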
+ +**Architecture:** + +```python +class MixedPrecisionOps(disable_weight_init): + _layer_quant_config = {} # Maps layer names to quantization configs + _compute_dtype = torch.bfloat16 # Default compute / dequantize precision +``` + +**Key mechanism:** + +The custom `Linear._load_from_state_dict()` method inspects each layer during model loading: +- If the layer name is **not** in `_layer_quant_config`: load weight as regular tensor in `_compute_dtype` +- If the layer name **is** in `_layer_quant_config`: + - Load weight as `QuantizedTensor` with the specified layout (e.g., `TensorCoreFP8Layout`) + - Load associated quantization parameters (scales, block_size, etc.) + +**Why it's needed:** + +Not all layers tolerate quantization equally. Sensitive operations like final projections can be kept in higher precision, while compute-heavy matmuls are quantized. This provides most of the performance benefits while maintaining quality. + +The system is selected in `pick_operations()` when `model_config.layer_quant_config` is present, making it the highest-priority operation mode. + + +## Checkpoint Format + +Quantized checkpoints are stored as standard safetensors files with quantized weight tensors and associated scaling parameters, plus a `_quantization_metadata` JSON entry describing the quantization scheme. + +The quantized checkpoint will contain the same layers as the original checkpoint but: +- The weights are stored as quantized values, sometimes using a different storage datatype. E.g. uint8 container for fp8. +- For each quantized weight a number of additional scaling parameters are stored alongside depending on the recipe. +- We store a metadata.json in the metadata of the final safetensor containing the `_quantization_metadata` describing which layers are quantized and what layout has been used. + +### Scaling Parameters details +We define 4 possible scaling parameters that should cover most recipes in the near-future: +- **weight_scale**: quantization scalers for the weights +- **weight_scale_2**: global scalers in the context of double scaling +- **pre_quant_scale**: scalers used for smoothing salient weights +- **input_scale**: quantization scalers for the activations + +| Format | Storage dtype | weight_scale | weight_scale_2 | pre_quant_scale | input_scale | +|--------|---------------|--------------|----------------|-----------------|-------------| +| float8_e4m3fn | float32 | float32 (scalar) | - | - | float32 (scalar) | + +You can find the defined formats in `comfy/quant_ops.py` (QUANT_ALGOS). + +### Quantization Metadata + +The metadata stored alongside the checkpoint contains: +- **format_version**: String to define a version of the standard +- **layers**: A dictionary mapping layer names to their quantization format. The format string maps to the definitions found in `QUANT_ALGOS`. + +Example: +```json +{ + "_quantization_metadata": { + "format_version": "1.0", + "layers": { + "model.layers.0.mlp.up_proj": "float8_e4m3fn", + "model.layers.0.mlp.down_proj": "float8_e4m3fn", + "model.layers.1.mlp.up_proj": "float8_e4m3fn" + } + } +} +``` + + +## Creating Quantized Checkpoints + +To create compatible checkpoints, use any quantization tool provided the output follows the checkpoint format described above and uses a layout defined in `QUANT_ALGOS`. + +### Weight Quantization + +Weight quantization is straightforward - compute the scaling factor directly from the weight tensor using the absolute maximum method described earlier. 
Each layer's weights are quantized independently and stored with their corresponding `weight_scale` parameter. + +### Calibration (for Activation Quantization) + +Activation quantization (e.g., for FP8 Tensor Core operations) requires `input_scale` parameters that cannot be determined from static weights alone. Since activation values depend on actual inputs, we use **post-training calibration (PTQ)**: + +1. **Collect statistics**: Run inference on N representative samples +2. **Track activations**: Record the absolute maximum (`amax`) of inputs to each quantized layer +3. **Compute scales**: Derive `input_scale` from collected statistics +4. **Store in checkpoint**: Save `input_scale` parameters alongside weights + +The calibration dataset should be representative of your target use case. For diffusion models, this typically means a diverse set of prompts and generation parameters. \ No newline at end of file diff --git a/README.md b/README.md index 98490ab07..eb10c5c5a 100644 --- a/README.md +++ b/README.md @@ -359,6 +359,8 @@ There are two kinds of custom nodes: vanilla custom nodes, which generally expec ComfyUI-Manager is a popular extension to help you install and manage other custom nodes. To install it, you will need `git` on your system. +#### Manual Install + The installation process for ComfyUI-Manager requires two steps: installing its Python dependencies, and then cloning its code into the `custom_nodes` directory. 1. **Install dependencies.** @@ -381,6 +383,34 @@ The installation process for ComfyUI-Manager requires two steps: installing its 3. **Restart ComfyUI.** After the cloning is complete, restart ComfyUI. You should now see a "Manager" button in the menu. +### PyPi Install + +[ComfyUI-Manager](https://github.com/Comfy-Org/ComfyUI-Manager/tree/manager-v4) + +**ComfyUI-Manager** is an extension that allows you to easily install, update, and manage custom nodes for ComfyUI. + +### Setup + +1. Install the manager dependencies: + ```bash + pip install -r manager_requirements.txt + ``` + +2. Enable the manager with the `--enable-manager` flag when running ComfyUI: + ```bash + python main.py --enable-manager + ``` + +### Command Line Options + +| Flag | Description | +|------|-------------| +| `--enable-manager` | Enable ComfyUI-Manager | +| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) | +| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) | + + + ### Vanilla Custom Nodes Clone the repository containing the custom nodes into `custom_nodes/` in your working directory and install its requirements, or use the manager. diff --git a/comfy/__init__.py b/comfy/__init__.py index ba0c5a01a..848463ce4 100644 --- a/comfy/__init__.py +++ b/comfy/__init__.py @@ -1,6 +1,6 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. 
-__version__ = "0.3.66" +__version__ = "0.3.76" # This deals with workspace issues from comfy_compatibility.workspace import auto_patch_workspace_and_restart diff --git a/comfy/app/frontend_management.py b/comfy/app/frontend_management.py index a5b937163..0163c65e7 100644 --- a/comfy/app/frontend_management.py +++ b/comfy/app/frontend_management.py @@ -11,13 +11,16 @@ import importlib.metadata from dataclasses import dataclass from functools import cached_property from pathlib import Path -from typing import TypedDict, Optional +from typing import Dict, TypedDict, Optional +from aiohttp import web +from importlib.metadata import version import requests from typing_extensions import NotRequired from ..cli_args import DEFAULT_VERSION_STRING from ..cmd.folder_paths import add_model_folder_path # pylint: disable=import-error + logger = logging.getLogger(__name__) REQUEST_TIMEOUT = 10 # seconds @@ -172,7 +175,53 @@ class FrontendManager: return "" @classmethod - def templates_path(cls) -> str: + def template_asset_map(cls) -> Optional[Dict[str, str]]: + """Return a mapping of template asset names to their absolute paths.""" + try: + from comfyui_workflow_templates import ( + get_asset_path, + iter_templates, + ) + except ImportError: + logging.error( + f""" +********** ERROR *********** + +comfyui-workflow-templates is not installed. + +{frontend_install_warning_message()} + +********** ERROR *********** +""".strip() + ) + return None + + try: + template_entries = list(iter_templates()) + except Exception as exc: + logging.error(f"Failed to enumerate workflow templates: {exc}") + return None + + asset_map: Dict[str, str] = {} + try: + for entry in template_entries: + for asset in entry.assets: + asset_map[asset.filename] = get_asset_path( + entry.template_id, asset.filename + ) + except Exception as exc: + logging.error(f"Failed to resolve template asset paths: {exc}") + return None + + if not asset_map: + logging.error("No workflow template assets found. Did the packages install correctly?") + return None + + return asset_map + + @classmethod + def legacy_templates_path(cls) -> Optional[str]: + """Return the legacy templates directory shipped inside the meta package.""" try: import comfyui_workflow_templates @@ -299,3 +348,18 @@ class FrontendManager: logger.info("Falling back to the default frontend.") check_frontend_version() return cls.default_frontend_path() + + @classmethod + def template_asset_handler(cls): + assets = cls.template_asset_map() + if not assets: + return None + + async def serve_template(request: web.Request) -> web.StreamResponse: + rel_path = request.match_info.get("path", "") + target = assets.get(rel_path) + if target is None: + raise web.HTTPNotFound() + return web.FileResponse(target) + + return serve_template diff --git a/comfy/app/subgraph_manager.py b/comfy/app/subgraph_manager.py new file mode 100644 index 000000000..dbe404541 --- /dev/null +++ b/comfy/app/subgraph_manager.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from typing import TypedDict +import os +import folder_paths +import glob +from aiohttp import web +import hashlib + + +class Source: + custom_node = "custom_node" + +class SubgraphEntry(TypedDict): + source: str + """ + Source of subgraph - custom_nodes vs templates. + """ + path: str + """ + Relative path of the subgraph file. + For custom nodes, will be the relative directory like /subgraphs/.json + """ + name: str + """ + Name of subgraph file. 
+ """ + info: CustomNodeSubgraphEntryInfo + """ + Additional info about subgraph; in the case of custom_nodes, will contain nodepack name + """ + data: str + +class CustomNodeSubgraphEntryInfo(TypedDict): + node_pack: str + """Node pack name.""" + +class SubgraphManager: + def __init__(self): + self.cached_custom_node_subgraphs: dict[SubgraphEntry] | None = None + + async def load_entry_data(self, entry: SubgraphEntry): + with open(entry['path'], 'r') as f: + entry['data'] = f.read() + return entry + + async def sanitize_entry(self, entry: SubgraphEntry | None, remove_data=False) -> SubgraphEntry | None: + if entry is None: + return None + entry = entry.copy() + entry.pop('path', None) + if remove_data: + entry.pop('data', None) + return entry + + async def sanitize_entries(self, entries: dict[str, SubgraphEntry], remove_data=False) -> dict[str, SubgraphEntry]: + entries = entries.copy() + for key in list(entries.keys()): + entries[key] = await self.sanitize_entry(entries[key], remove_data) + return entries + + async def get_custom_node_subgraphs(self, loadedModules, force_reload=False): + # if not forced to reload and cached, return cache + if not force_reload and self.cached_custom_node_subgraphs is not None: + return self.cached_custom_node_subgraphs + # Load subgraphs from custom nodes + subfolder = "subgraphs" + subgraphs_dict: dict[SubgraphEntry] = {} + + for folder in folder_paths.get_folder_paths("custom_nodes"): + pattern = os.path.join(folder, f"*/{subfolder}/*.json") + matched_files = glob.glob(pattern) + for file in matched_files: + # replace backslashes with forward slashes + file = file.replace('\\', '/') + info: CustomNodeSubgraphEntryInfo = { + "node_pack": "custom_nodes." + file.split('/')[-3] + } + source = Source.custom_node + # hash source + path to make sure id will be as unique as possible, but + # reproducible across backend reloads + id = hashlib.sha256(f"{source}{file}".encode()).hexdigest() + entry: SubgraphEntry = { + "source": Source.custom_node, + "name": os.path.splitext(os.path.basename(file))[0], + "path": file, + "info": info, + } + subgraphs_dict[id] = entry + self.cached_custom_node_subgraphs = subgraphs_dict + return subgraphs_dict + + async def get_custom_node_subgraph(self, id: str, loadedModules): + subgraphs = await self.get_custom_node_subgraphs(loadedModules) + entry: SubgraphEntry = subgraphs.get(id, None) + if entry is not None and entry.get('data', None) is None: + await self.load_entry_data(entry) + return entry + + def add_routes(self, routes, loadedModules): + @routes.get("/global_subgraphs") + async def get_global_subgraphs(request): + subgraphs_dict = await self.get_custom_node_subgraphs(loadedModules) + # NOTE: we may want to include other sources of global subgraphs such as templates in the future; + # that's the reasoning for the current implementation + return web.json_response(await self.sanitize_entries(subgraphs_dict, remove_data=True)) + + @routes.get("/global_subgraphs/{id}") + async def get_global_subgraph(request): + id = request.match_info.get("id", None) + subgraph = await self.get_custom_node_subgraph(id, loadedModules) + return web.json_response(await self.sanitize_entry(subgraph)) diff --git a/comfy/app/user_manager.py b/comfy/app/user_manager.py index 35a130267..eab310575 100644 --- a/comfy/app/user_manager.py +++ b/comfy/app/user_manager.py @@ -60,6 +60,9 @@ class UserManager(): user = "default" if args.multi_user and "comfy-user" in request.headers: user = request.headers["comfy-user"] + # Block System Users (use same 
error message to prevent probing) + if user.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise KeyError("Unknown user: " + user) if user not in self.users: raise KeyError("Unknown user: " + user) @@ -67,15 +70,16 @@ class UserManager(): return user def get_request_user_filepath(self, request, file, type="userdata", create_dir=True): - user_directory = folder_paths.get_user_directory() - if type == "userdata": - root_dir = user_directory + root_dir = folder_paths.get_user_directory() else: raise KeyError("Unknown filepath type:" + type) user = self.get_request_user_id(request) - path = user_root = os.path.abspath(os.path.join(root_dir, user)) + user_root = folder_paths.get_public_user_directory(user) + if user_root is None: + return None + path = user_root # prevent leaving /{type} if os.path.commonpath((root_dir, user_root)) != root_dir: @@ -102,7 +106,11 @@ class UserManager(): name = name.strip() if not name: raise ValueError("username not provided") + if name.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise ValueError("System User prefix not allowed") user_id = re.sub("[^a-zA-Z0-9-_]+", '-', name) + if user_id.startswith(folder_paths.SYSTEM_USER_PREFIX): + raise ValueError("System User prefix not allowed") user_id = user_id + "_" + str(uuid.uuid4()) self.users[user_id] = name @@ -133,7 +141,10 @@ class UserManager(): if username in self.users.values(): return web.json_response({"error": "Duplicate username."}, status=400) - user_id = self.add_user(username) + try: + user_id = self.add_user(username) + except ValueError as e: + return web.json_response({"error": str(e)}, status=400) return web.json_response(user_id) @routes.get("/userdata") @@ -425,7 +436,7 @@ class UserManager(): return source dest = get_user_data_path(request, check_exists=False, param="dest") - if not isinstance(source, str): + if not isinstance(dest, str): return dest overwrite = request.query.get("overwrite", 'true') != "false" diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index 8967632b6..6ada7e1b9 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -415,7 +415,8 @@ class ControlNet(nn.Module): out_middle = [] if self.num_classes is not None: - assert y.shape[0] == x.shape[0], "There may be a mismatch between the ControlNet and Diffusion models being used" + if y is None: + raise ValueError("y is None, did you try using a controlnet for SDXL on SD1?") emb = emb + self.label_emb(y) h = x diff --git a/comfy/cli_args.py b/comfy/cli_args.py index d8857597b..39dece44a 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -105,6 +105,7 @@ def _create_parser() -> EnhancedConfigArgParser: cache_group.add_argument("--cache-classic", action="store_true", help="WARNING: Unused. Use the old style (aggressive) caching.") cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.") cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.") + cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB") attn_group = parser.add_mutually_exclusive_group() attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. 
Ignored when xformers is used.") @@ -120,6 +121,10 @@ def _create_parser() -> EnhancedConfigArgParser: upcast = parser.add_mutually_exclusive_group() upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.") upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.") + parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.") + manager_group = parser.add_mutually_exclusive_group() + manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.") + manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager") vram_group = parser.add_mutually_exclusive_group() vram_group.add_argument("--gpu-only", action="store_true", help="Store and run everything (text encoders/CLIP models, etc... on the GPU).") @@ -132,7 +137,8 @@ def _create_parser() -> EnhancedConfigArgParser: vram_group.add_argument("--cpu", action="store_true", help="To use the CPU for everything (slow).") parser.add_argument("--reserve-vram", type=float, default=None, help="Set the amount of vram in GB you want to reserve for use by your OS/other software. By default some amount is reserved depending on your OS.") - parser.add_argument("--async-offload", action="store_true", help="Use async weight offloading.") + parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") + parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") parser.add_argument("--default-hashing-function", type=str, choices=['md5', 'sha1', 'sha256', 'sha512'], default='sha256', help="Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256.") parser.add_argument("--disable-smart-memory", action="store_true", @@ -141,6 +147,7 @@ def _create_parser() -> EnhancedConfigArgParser: help="Make pytorch use slower deterministic algorithms when it can. Note that this might not make images deterministic in all cases.") parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help=f"Enable some untested and potentially quality deteriorating optimizations. Pass a list specific optimizations if you only want to enable specific ones. 
Current valid optimizations: {' '.join([f.value for f in PerformanceFeature])}", default=set()) + parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.") parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.") parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.") @@ -155,7 +162,7 @@ def _create_parser() -> EnhancedConfigArgParser: parser.add_argument("--disable-all-custom-nodes", action="store_true", help="Disable loading all custom nodes.") parser.add_argument("--whitelist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to load even when --disable-all-custom-nodes is enabled.") parser.add_argument("--blacklist-custom-nodes", type=str, action=FlattenAndAppendAction, nargs='+', default=[], help="Specify custom node folders to never load. Accepts shell-style globs.") - parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes.") + parser.add_argument("--disable-api-nodes", action="store_true", help="Disable loading all api nodes. Also prevents the frontend from communicating with the internet.") parser.add_argument("--enable-eval", action="store_true", help="Enable nodes that can evaluate Python code in workflows.") parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.") diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index ee377936f..1f7054e3a 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -73,6 +73,7 @@ class Configuration(dict): temp_directory (Optional[str]): Temporary directory for processing. input_directory (Optional[str]): Directory for input files. When this is a relative path, it will be looked up relative to the cwd (current working directory) and all of the base_paths. auto_launch (bool): Auto-launch UI in the default browser. Defaults to False. + disable_auto_launch (bool): Disable auto launching the browser. cuda_device (Optional[int]): CUDA device ID. None means default device. cuda_malloc (bool): Enable cudaMallocAsync. Defaults to True in applicable setups. disable_cuda_malloc (bool): Disable cudaMallocAsync. @@ -100,6 +101,7 @@ class Configuration(dict): disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs. preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "auto". cache_lru (int): Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM. + cache_ram (float): Use RAM pressure caching with the specified headroom threshold. use_split_cross_attention (bool): Use split cross-attention optimization. use_quad_cross_attention (bool): Use sub-quadratic cross-attention optimization. use_pytorch_cross_attention (bool): Use PyTorch's cross-attention function. @@ -147,14 +149,19 @@ class Configuration(dict): user_directory (Optional[str]): Set the ComfyUI user directory with an absolute path. log_stdout (bool): Send normal process output to stdout instead of stderr (default) panic_when (list[str]): List of fully qualified exception class names to panic (sys.exit(1)) when a workflow raises it. + enable_manager (bool): Enable the ComfyUI-Manager feature. + disable_manager_ui (bool): Disables only the ComfyUI-Manager UI. + enable_manager_legacy_ui (bool): Enables the legacy UI of ComfyUI-Manager. enable_compress_response_body (bool): Enable compressing response body. 
workflows (list[str]): Execute the API workflow(s) specified in the provided files. For each workflow, its outputs will be printed to a line to standard out. Application logging will be redirected to standard error. Use `-` to signify standard in. + disable_pinned_memory (bool): Disable pinned memory use. fp8_e8m0fnu_unet (bool): Store unet weights in fp8_e8m0fnu. bf16_text_enc (bool): Store text encoder weights in bf16. supports_fp8_compute (bool): ComfyUI will act like if the device supports fp8 compute. cache_classic (bool): WARNING: Unused. Use the old style (aggressive) caching. cache_none (bool): Reduced RAM/VRAM usage at the expense of executing every node for each run. - async_offload (bool): Use async weight offloading. + async_offload (Optional[int]): Use async weight offloading. An optional argument controls the amount of offload streams. + disable_async_offload (bool): Disable async weight offloading. force_non_blocking (bool): Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows. default_hashing_function (str): Allows you to choose the hash function to use for duplicate filename / contents comparison. Default is sha256. mmap_torch_files (bool): Use mmap when loading ckpt/pt files. @@ -189,6 +196,7 @@ class Configuration(dict): self.temp_directory: Optional[str] = None self.input_directory: Optional[str] = None self.auto_launch: bool = False + self.disable_auto_launch: bool = False self.cuda_device: Optional[int] = None self.cuda_malloc: bool = True self.disable_cuda_malloc: bool = True @@ -272,13 +280,19 @@ class Configuration(dict): self.user_directory: Optional[str] = None self.panic_when: list[str] = [] self.workflows: list[str] = [] + self.enable_manager: bool = False + self.disable_manager_ui: bool = False + self.enable_manager_legacy_ui: bool = False + self.disable_pinned_memory: bool = False self.fp8_e8m0fnu_unet: bool = False self.bf16_text_enc: bool = False self.supports_fp8_compute: bool = False self.cache_classic: bool = False self.cache_none: bool = False - self.async_offload: bool = False + self.cache_ram: float = 0.0 + self.async_offload: Optional[int] = None + self.disable_async_offload: bool = False self.force_non_blocking: bool = False self.default_hashing_function: str = 'sha256' self.mmap_torch_files: bool = False @@ -289,7 +303,7 @@ class Configuration(dict): self.comfy_api_base: str = "https://api.comfy.org" self.database_url: str = db_config() self.default_device: Optional[int] = None - self.block_runtime_package_installation = None + self.block_runtime_package_installation: bool = False self.enable_eval: Optional[bool] = False self.enable_video_to_image_fallback: bool = False diff --git a/comfy/cmd/cuda_malloc.py b/comfy/cmd/cuda_malloc.py index 78e570a37..b491d31e3 100644 --- a/comfy/cmd/cuda_malloc.py +++ b/comfy/cmd/cuda_malloc.py @@ -71,18 +71,23 @@ def cuda_malloc_supported(): return True +# todo: is this really how we want to get the torch version? 
+version = "" + +try: + torch_spec = importlib.util.find_spec("torch") + for folder in torch_spec.submodule_search_locations: + ver_file = os.path.join(folder, "version.py") + if os.path.isfile(ver_file): + spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + version = module.__version__ +except: + pass + if not args.cuda_malloc: try: - version = "" - torch_spec = importlib.util.find_spec("torch") - for folder in torch_spec.submodule_search_locations: - ver_file = os.path.join(folder, "version.py") - if os.path.isfile(ver_file): - spec = importlib.util.spec_from_file_location("torch_version_import", ver_file) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - version = module.__version__ - if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc args.cuda_malloc = cuda_malloc_supported() @@ -97,3 +102,6 @@ if args.cuda_malloc and not args.disable_cuda_malloc: env_var += ",backend:cudaMallocAsync" os.environ['PYTORCH_CUDA_ALLOC_CONF'] = env_var + +def get_torch_version_noimport(): + return str(version) diff --git a/comfy/cmd/execution.py b/comfy/cmd/execution.py index 8d3ba84ae..e75a95be7 100644 --- a/comfy/cmd/execution.py +++ b/comfy/cmd/execution.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing_extensions import NotRequired, TypedDict, NamedTuple + from .main_pre import tracer import asyncio @@ -21,14 +23,24 @@ from typing import List, Optional, Tuple, Literal import torch from opentelemetry.trace import get_current_span, StatusCode, Status -from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, \ - make_locked_method_func -from comfy_api.latest import io +from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func +from comfy_api.latest import io, _io from comfy_compatibility.vanilla import vanilla_environment_node_execution_hooks -from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID, \ - DependencyAwareCache, \ - BasicCache -from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker +from comfy_execution.caching import ( + BasicCache, + CacheKeySetID, + CacheKeySetInputSignature, + NullCache, + HierarchicalCache, + LRUCache, + RAMPressureCache, +) +from comfy_execution.graph import ( + DynamicPrompt, + ExecutionBlocker, + ExecutionList, + get_input_info, +) from comfy_execution.graph_types import FrozenTopologicalSort from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, \ @@ -94,7 +106,7 @@ class IsChangedCache: return self.is_changed[node_id] # Intentionally do not use cached outputs here. 
We only want constants in IS_CHANGED - input_data_all, _, hidden_inputs = get_input_data(node["inputs"], class_def, node_id, None) + input_data_all, _, v3_data = get_input_data(node["inputs"], class_def, node_id, None) try: is_changed = await _async_map_node_over_list(self.prompt_id, node_id, class_def, input_data_all, is_changed_name) is_changed = await resolve_map_node_over_list_results(is_changed) @@ -106,48 +118,58 @@ class IsChangedCache: return self.is_changed[node_id] +class CacheEntry(NamedTuple): + ui: dict + outputs: list + + class CacheType(Enum): CLASSIC = 0 LRU = 1 - DEPENDENCY_AWARE = 2 + NONE = 2 + RAM_PRESSURE = 3 class CacheSet: - def __init__(self, cache_type=None, cache_size=None): - if cache_type == CacheType.DEPENDENCY_AWARE: - self.init_dependency_aware_cache() + def __init__(self, cache_type=None, cache_args: Optional[CacheArgs] = None): + if cache_args is None: + cache_args = {} + if cache_type == CacheType.NONE: + self.init_null_cache() logger.info("Disabling intermediate node cache.") + elif cache_type == CacheType.RAM_PRESSURE: + cache_ram = cache_args.get("ram", 16.0) + self.init_ram_cache(cache_ram) + logging.info("Using RAM pressure cache.") elif cache_type == CacheType.LRU: - if cache_size is None: - cache_size = 0 + cache_size = cache_args.get("lru", 0) self.init_lru_cache(cache_size) logger.info("Using LRU cache") else: self.init_classic_cache() - self.all = [self.outputs, self.ui, self.objects] + self.all = [self.outputs, self.objects] # Performs like the old cache -- dump data ASAP def init_classic_cache(self): self.outputs = HierarchicalCache(CacheKeySetInputSignature) - self.ui = HierarchicalCache(CacheKeySetInputSignature) self.objects = HierarchicalCache(CacheKeySetID) def init_lru_cache(self, cache_size): self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) - self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.objects = HierarchicalCache(CacheKeySetID) - # only hold cached items while the decendents have not executed - def init_dependency_aware_cache(self): - self.outputs = DependencyAwareCache(CacheKeySetInputSignature) - self.ui = DependencyAwareCache(CacheKeySetInputSignature) - self.objects = DependencyAwareCache(CacheKeySetID) + def init_ram_cache(self, min_headroom): + self.outputs = RAMPressureCache(CacheKeySetInputSignature) + self.objects = HierarchicalCache(CacheKeySetID) + + def init_null_cache(self): + self.outputs = NullCache() + self.objects = NullCache() def recursive_debug_dump(self): result = { "outputs": self.outputs.recursive_debug_dump(), - "ui": self.ui.recursive_debug_dump(), } return result @@ -155,12 +177,14 @@ class CacheSet: SENSITIVE_EXTRA_DATA_KEYS = ("auth_token_comfy_org", "api_key_comfy_org") -def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, extra_data=None): +def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=None, extra_data=None): if extra_data is None: extra_data = {} is_v3 = issubclass(class_def, _ComfyNodeInternal) + v3_data: io.V3Data = {} + schema = None if is_v3: - valid_inputs, schema = class_def.INPUT_TYPES(include_hidden=False, return_schema=True) + valid_inputs, schema, v3_data = class_def.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) else: valid_inputs = class_def.INPUT_TYPES() input_data_all = {} @@ -177,17 +201,17 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e if is_link(input_data) and (not input_info or not input_info.get("rawLink", False)): 
input_unique_id = input_data[0] output_index = input_data[1] - if outputs is None: + if execution_list is None: mark_missing() continue # This might be a lazily-evaluated input - cached_output = outputs.get(input_unique_id) - if cached_output is None: + cached = execution_list.get_cache(input_unique_id, unique_id) + if cached is None or cached.outputs is None: mark_missing() continue - if output_index >= len(cached_output): + if output_index >= len(cached.outputs): mark_missing() continue - obj = cached_output[output_index] + obj = cached.outputs[output_index] input_data_all[x] = obj elif input_category is not None: input_data_all[x] = [input_data] @@ -223,7 +247,8 @@ def get_input_data(inputs, class_def, unique_id, outputs=None, dynprompt=None, e input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)] if h[x] == "API_KEY_COMFY_ORG": input_data_all[x] = [extra_data.get("api_key_comfy_org", None)] - return input_data_all, missing_keys, hidden_inputs_v3 + v3_data["hidden_inputs"] = hidden_inputs_v3 + return input_data_all, missing_keys, v3_data def map_node_over_list(obj, input_data_all: typing.Dict[str, typing.Any], func: str, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None): @@ -244,12 +269,12 @@ async def resolve_map_node_over_list_results(results): @tracer.start_as_current_span("Execute Node") -async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None, execution_list=None, executed=None): +async def _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None, execution_list=None, executed=None): with context_set_execution_list_and_inputs(FrozenTopologicalSort.from_topological_sort(execution_list) if execution_list is not None else None, frozenset(executed) if executed is not None else None): - return await __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb, hidden_inputs) + return await __async_map_node_over_list(prompt_id=prompt_id, unique_id=unique_id, obj=obj, input_data_all=input_data_all, func=func, allow_interrupt=allow_interrupt, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data) -async def __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None): +async def __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, func, allow_interrupt=False, execution_block_cb=None, pre_execute_cb=None, v3_data=None): span = get_current_span() class_type = obj.__class__.__name__ span.set_attribute("class_type", class_type) @@ -312,13 +337,16 @@ async def __async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, if is_class(obj): type_obj = obj obj.VALIDATE_CLASS() - class_clone = obj.PREPARE_CLASS_CLONE(hidden_inputs) + class_clone = obj.PREPARE_CLASS_CLONE(v3_data) # otherwise, use class instance to populate/reuse some fields else: type_obj = type(obj) type_obj.VALIDATE_CLASS() - class_clone = type_obj.PREPARE_CLASS_CLONE(hidden_inputs) + class_clone = type_obj.PREPARE_CLASS_CLONE(v3_data) f = make_locked_method_func(type_obj, func, class_clone) + # in case of dynamic inputs, restructure inputs to expected nested dict + if v3_data is not None: + inputs = _io.build_nested_inputs(inputs, v3_data) # V1 else: f = 
getattr(obj, func) @@ -376,8 +404,8 @@ def merge_result_data(results, obj): return output -async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, hidden_inputs=None, inputs=None, execution_list=None, executed=None): - return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs, execution_list=execution_list, executed=executed) +async def get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=None, pre_execute_cb=None, v3_data=None, execution_list=None, executed=None): + return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data, execution_list=execution_list, executed=executed) has_pending_task = any(isinstance(r, asyncio.Task) and not r.done() for r in return_values) if has_pending_task: return return_values, {}, False, has_pending_task @@ -484,7 +512,7 @@ async def execute(server: ExecutorToClientProgress, dynprompt: DynamicPrompt, ca return await _execute(server, dynprompt, caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) -async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list: ExecutionList, pending_subgraph_results, pending_async_nodes) -> RecursiveExecutionTuple: +async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_item: str, extra_data, executed, prompt_id, execution_list: ExecutionList, pending_subgraph_results, pending_async_nodes, ui_outputs) -> RecursiveExecutionTuple: unique_id = current_item real_node_id = dynprompt.get_real_node_id(unique_id) display_node_id = dynprompt.get_display_node_id(unique_id) @@ -492,11 +520,15 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i inputs = dynprompt.get_node(unique_id)['inputs'] class_type = dynprompt.get_node(unique_id)['class_type'] class_def = get_nodes().NODE_CLASS_MAPPINGS[class_type] - if caches.outputs.get(unique_id) is not None: + cached = caches.outputs.get(unique_id) + if cached is not None: if server.client_id is not None: - cached_output = caches.ui.get(unique_id) or {} - server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": cached_output.get("output", None), "prompt_id": prompt_id}, server.client_id) + cached_ui = cached.ui or {} + server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output", None), "prompt_id": prompt_id}, server.client_id) + if cached.ui is not None: + ui_outputs[unique_id] = cached.ui get_progress_state().finish_progress(unique_id) + execution_list.cache_update(unique_id, cached) return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None) input_data_all = None @@ -526,8 +558,8 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i for r in result: if is_link(r): source_node, source_output = r[0], r[1] - node_output = caches.outputs.get(source_node)[source_output] - for o in node_output: + node_cached = execution_list.get_cache(source_node, unique_id) + for o in node_cached.outputs[source_output]: resolved_output.append(o) else: @@ -535,10 +567,11 @@ async def _execute(server, 
dynprompt: DynamicPrompt, caches: CacheSet, current_i resolved_outputs.append(tuple(resolved_output)) output_data = merge_result_data(resolved_outputs, class_def) output_ui = [] + del pending_subgraph_results[unique_id] has_subgraph = False else: get_progress_state().start_progress(unique_id) - input_data_all, missing_keys, hidden_inputs = get_input_data(inputs, class_def, unique_id, caches.outputs, dynprompt, extra_data) + input_data_all, missing_keys, v3_data = get_input_data(inputs, class_def, unique_id, execution_list, dynprompt, extra_data) if server.client_id is not None: server.last_node_id = display_node_id server.send_sync("executing", {"node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id}, server.client_id) @@ -553,7 +586,7 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i else: lazy_status_present = getattr(obj, "check_lazy_status", None) is not None if lazy_status_present: - required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, hidden_inputs=hidden_inputs, execution_list=execution_list, executed=executed) + required_inputs = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, "check_lazy_status", allow_interrupt=True, execution_list=execution_list, executed=executed, v3_data=v3_data) required_inputs = await resolve_map_node_over_list_results(required_inputs) required_inputs = set(sum([r for r in required_inputs if isinstance(r, list)], [])) required_inputs = [x for x in required_inputs if isinstance(x, str) and ( @@ -587,7 +620,7 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i # TODO - How to handle this with async functions without contextvars (which requires Python 3.12)? 
GraphBuilder.set_default_prefix(unique_id, call_index, 0) - output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, hidden_inputs=hidden_inputs, inputs=inputs, execution_list=execution_list, executed=executed) + output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, execution_list=execution_list, executed=executed, v3_data=v3_data) if has_pending_tasks: pending_async_nodes[unique_id] = output_data unblock = execution_list.add_external_block(unique_id) @@ -600,7 +633,7 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i asyncio.create_task(await_completion()) return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None) if len(output_ui) > 0: - caches.ui.set(unique_id, { + ui_outputs[unique_id] = { "meta": { "node_id": unique_id, "display_node": display_node_id, @@ -608,7 +641,7 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i "real_node_id": real_node_id, }, "output": output_ui - }) + } if server.client_id is not None: server.send_sync("executed", {"node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id}, server.client_id) @@ -622,10 +655,6 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i if new_graph is None: cached_outputs.append((False, node_outputs)) else: - # Check for conflicts - for node_id in new_graph.keys(): - if dynprompt.has_node(node_id): - raise DuplicateNodeError(f"Attempt to add duplicate node {node_id}. Ensure node ids are unique and deterministic or use graph_utils.GraphBuilder.") for node_id, node_info in new_graph.items(): new_node_ids.append(node_id) display_id = node_info.get("override_display_id", unique_id) @@ -646,11 +675,16 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i subcache.clean_unused() for node_id in new_output_ids: execution_list.add_node(node_id) + execution_list.cache_link(node_id, unique_id) for link in new_output_links: execution_list.add_strong_link(link[0], link[1], unique_id) pending_subgraph_results[unique_id] = cached_outputs return RecursiveExecutionTuple(ExecutionResult.PENDING, None, None) - caches.outputs.set(unique_id, output_data) + + cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) + execution_list.cache_update(unique_id, cache_entry) + caches.outputs.set(unique_id, cache_entry) + except interruption.InterruptProcessingException as iex: logger.info("Processing interrupted") @@ -702,10 +736,17 @@ async def _execute(server, dynprompt: DynamicPrompt, caches: CacheSet, current_i return RecursiveExecutionTuple(ExecutionResult.SUCCESS, None, None) +class CacheArgs(TypedDict): + ram: NotRequired[int] + lru: NotRequired[float] + + class PromptExecutor: - def __init__(self, server: ExecutorToClientProgress, cache_type: CacheType | Literal[False] = False, cache_size: int | None = None): + def __init__(self, server: ExecutorToClientProgress, cache_type: CacheType | Literal[False] = False, cache_args: Optional[CacheArgs] = None): + self.status_messages = [] + self.caches: Optional[CacheSet] = None self.success = None - self.cache_size = cache_size + self.cache_args = cache_args self.cache_type = cache_type self.server = server self.raise_exceptions = False @@ -714,7 +755,7 
@@ class PromptExecutor: def reset(self): self.success = True - self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size) + self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args) self.status_messages = [] def add_message(self, event, data: dict, broadcast: bool): @@ -819,6 +860,7 @@ class PromptExecutor: broadcast=False) pending_subgraph_results = {} pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results + ui_node_outputs = {} executed = set() execution_list = ExecutionList(dynamic_prompt, self.caches.outputs) current_outputs = self.caches.outputs.all_node_ids() @@ -833,7 +875,7 @@ class PromptExecutor: break assert node_id is not None, "Node ID should not be None at this point" - result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes) + result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) self.success = result != ExecutionResult.FAILURE if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) @@ -842,18 +884,16 @@ class PromptExecutor: execution_list.unstage_node_execution() else: # result == ExecutionResult.SUCCESS: execution_list.complete_node_execution() + self.caches.outputs.poll(ram_headroom=self.cache_args["ram"]) else: # Only execute when the while-loop ends without break self.add_message("execution_success", {"prompt_id": prompt_id}, broadcast=False) ui_outputs = {} meta_outputs = {} - all_node_ids = self.caches.ui.all_node_ids() - for node_id in all_node_ids: - ui_info = self.caches.ui.get(node_id) - if ui_info is not None: - ui_outputs[node_id] = ui_info["output"] - meta_outputs[node_id] = ui_info["meta"] + for node_id, ui_info in ui_node_outputs.items(): + ui_outputs[node_id] = ui_info["output"] + meta_outputs[node_id] = ui_info["meta"] self.history_result = { "outputs": ui_outputs, "meta": meta_outputs, @@ -883,9 +923,6 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing class_type = prompt[unique_id]['class_type'] obj_class = get_nodes().NODE_CLASS_MAPPINGS[class_type] - class_inputs = obj_class.INPUT_TYPES() - valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {}))) - error: ValidationErrorDict errors = [] valid = True @@ -893,9 +930,11 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing validate_function_inputs = [] validate_has_kwargs = False if issubclass(obj_class, _ComfyNodeInternal): + class_inputs, _, _ = obj_class.INPUT_TYPES(include_hidden=False, return_schema=True, live_inputs=inputs) validate_function_name = "validate_inputs" validate_function = first_real_override(obj_class, validate_function_name) else: + class_inputs = obj_class.INPUT_TYPES() validate_function_name = "VALIDATE_INPUTS" validate_function = getattr(obj_class, validate_function_name, None) if validate_function is not None: @@ -904,6 +943,8 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing validate_has_kwargs = argspec.varkw is not None received_types = {} + valid_inputs = set(class_inputs.get('required', {})).union(set(class_inputs.get('optional', {}))) + for x in valid_inputs: input_type, input_category, extra_info = get_input_info(obj_class, 
x, class_inputs) assert extra_info is not None @@ -1085,7 +1126,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing continue if len(validate_function_inputs) > 0 or validate_has_kwargs: - input_data_all, _, hidden_inputs = get_input_data(inputs, obj_class, unique_id) + input_data_all, _, v3_data = get_input_data(inputs, obj_class, unique_id) input_filtered = {} for x in input_data_all: if x in validate_function_inputs or validate_has_kwargs: @@ -1093,7 +1134,7 @@ async def validate_inputs(prompt_id: typing.Any, prompt, item, validated: typing if 'input_types' in validate_function_inputs: input_filtered['input_types'] = [received_types] - ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, hidden_inputs=hidden_inputs) + ret = await _async_map_node_over_list(prompt_id, unique_id, obj_class, input_filtered, validate_function_name, v3_data=v3_data) ret = await resolve_map_node_over_list_results(ret) for x in input_filtered: for i, r in enumerate(ret): @@ -1320,8 +1361,7 @@ class PromptQueue(AbstractPromptQueue): self.server.queue_updated() return copy.deepcopy(item_with_future.queue_tuple), task_id - def task_done(self, item_id: str, outputs: HistoryResultDict, - status: Optional[ExecutionStatus], error_details: Optional[ExecutionErrorMessage] = None): + def task_done(self, item_id: str, outputs: HistoryResultDict, status: Optional[ExecutionStatus], error_details: typing.Optional[ExecutionErrorMessage] = None, process_item: typing.Optional[typing.Callable[[QueueTuple], QueueItem]] = None): history_result = outputs with self.mutex: queue_item = self.currently_running.pop(item_id) @@ -1331,16 +1371,14 @@ class PromptQueue(AbstractPromptQueue): status_dict = None if status is not None: - status_dict: Optional[ExecutionStatusAsDict] = status.as_dict(error_details=error_details) + status_dict: Optional[ExecutionStatusAsDict] = status.as_dict(error_details=None) outputs_ = history_result["outputs"] - # Remove sensitive data from extra_data before storing in history - for sensitive_val in SENSITIVE_EXTRA_DATA_KEYS: - if sensitive_val in prompt[3]: - prompt[3].pop(sensitive_val) + if process_item is not None: + prompt = process_item(prompt) history_entry: HistoryEntry = { - "prompt": prompt, + "prompt": prompt.queue_tuple if isinstance(prompt, QueueItem) else prompt, "outputs": copy.deepcopy(outputs_), } if status_dict is not None: diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index 3544ac202..9a6be8550 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -24,6 +24,12 @@ _module_properties = create_module_properties() logger = logging.getLogger(__name__) +# todo: investigate what this is actually trying to do +# System User Protection - Protects system directories from HTTP endpoint access +# System Users are internal-only users that cannot be accessed via HTTP endpoints. +# They use the '__' prefix convention (similar to Python's private member convention). +_SYSTEM_USER_PREFIX = "__" + @_module_properties.getter def _supported_pt_extensions() -> set[str]: @@ -58,6 +64,65 @@ def _resolve_path_with_compatibility(path: Path | str) -> PurePosixPath | Path: return Path(path).resolve() +def get_system_user_directory(name: str = "system") -> str: + """ + Get the path to a System User directory. + + System User directories (prefixed with '__') are only accessible via internal API, + not through HTTP endpoints. 
Use this for storing system-internal data that + should not be exposed to users. + + Args: + name: System user name (e.g., "system", "cache"). Must be alphanumeric + with underscores allowed, but cannot start with underscore. + + Returns: + Absolute path to the system user directory. + + Raises: + ValueError: If name is empty, invalid, or starts with underscore. + + Example: + >>> get_system_user_directory("cache") + '/path/to/user/__cache' + """ + if not name or not isinstance(name, str): + raise ValueError("System user name cannot be empty") + if not name.replace("_", "").isalnum(): + raise ValueError(f"Invalid system user name: '{name}'") + if name.startswith("_"): + raise ValueError("System user name should not start with underscore") + return os.path.join(get_user_directory(), f"{_SYSTEM_USER_PREFIX}{name}") + + +def get_public_user_directory(user_id: str) -> str | None: + """ + Get the path to a Public User directory for HTTP endpoint access. + + This function provides structural security by returning None for any + System User (prefixed with '__'). All HTTP endpoints should use this + function instead of directly constructing user paths. + + Args: + user_id: User identifier from HTTP request. + + Returns: + Absolute path to the user directory, or None if user_id is invalid + or refers to a System User. + + Example: + >>> get_public_user_directory("default") + '/path/to/user/default' + >>> get_public_user_directory("__system") + None + """ + if not user_id or not isinstance(user_id, str): + return None + if user_id.startswith(_SYSTEM_USER_PREFIX): + return None + return os.path.join(get_user_directory(), user_id) + + def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optional[Configuration] = None, create_all_directories=False, replace_existing=True, base_paths_from_configuration=True): """ Populates the folder names and paths object with the default, upstream model directories and custom_nodes directory. @@ -111,6 +176,7 @@ def init_default_paths(folder_names_and_paths: FolderNames, configuration: Optio ModelPaths(["huggingface"], supported_extensions=set()), ModelPaths(["model_patches"], supported_extensions=set(supported_pt_extensions)), ModelPaths(["audio_encoders"], supported_extensions=set(supported_pt_extensions)), + ModelPaths(["latent_upscale_models"], supported_extensions=set(supported_pt_extensions)), hf_cache_paths, hf_xet, ] diff --git a/comfy/cmd/folder_paths.pyi b/comfy/cmd/folder_paths.pyi index c220e6df3..50dac2607 100644 --- a/comfy/cmd/folder_paths.pyi +++ b/comfy/cmd/folder_paths.pyi @@ -15,6 +15,7 @@ output_directory: str temp_directory: str input_directory: str supported_pt_extensions: set[str] +extension_mimetypes_cache: dict[str, str] # Functions @@ -39,7 +40,7 @@ def get_output_directory() -> str: ... def get_temp_directory() -> str: ... -def get_input_directory() -> str: ... +def get_input_directory(mkdirs: bool = ...) -> str: ... def get_user_directory() -> str: ... @@ -108,3 +109,9 @@ def filter_files_content_types(files: List[str], content_types: List[Literal["im def get_input_subfolders() -> list[str]: ... + + +def get_system_user_directory(name: str = ...) -> str: ... + + +def get_public_user_directory(user_id: str) -> Optional[str]: ... 
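A minimal usage sketch of the two user-directory helpers introduced above, illustrative only and not part of the patch itself; it assumes the `comfy.cmd.folder_paths` module is importable in your environment and that `get_user_directory()` has been configured. The `"cache"` and `"default"` names mirror the docstring examples.

```python
# Sketch of the intended System User / Public User split, based on the helpers
# added to comfy/cmd/folder_paths.py in this diff.
from comfy.cmd import folder_paths

# Internal-only storage: resolves to <user_dir>/__cache. The "__" prefix keeps it
# out of reach of the HTTP endpoints.
cache_dir = folder_paths.get_system_user_directory("cache")

# HTTP handlers resolve user paths through the public helper instead, which
# returns None for empty ids and for any "__"-prefixed System User.
assert folder_paths.get_public_user_directory("default") is not None
assert folder_paths.get_public_user_directory("__cache") is None   # blocked
assert folder_paths.get_public_user_directory("") is None          # invalid id
```

The protection is structural rather than procedural: endpoints never check the prefix themselves, they simply refuse to serve any `user_id` for which `get_public_user_directory` returns `None`.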
diff --git a/comfy/cmd/latent_preview.py b/comfy/cmd/latent_preview.py index 1a2c01162..79b42200d 100644 --- a/comfy/cmd/latent_preview.py +++ b/comfy/cmd/latent_preview.py @@ -15,14 +15,24 @@ from ..component_model.executor_types import UnencodedPreviewImageMessage from ..execution_context import current_execution_context from ..model_downloader import get_or_download, KNOWN_APPROX_VAES from ..taesd.taesd import TAESD +from ..sd import VAE +from ..utils import load_torch_file MAX_PREVIEW_RESOLUTION = args.preview_size +VIDEO_TAES = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + logger = logging.getLogger(__name__) -def preview_to_image(latent_image) -> Image: - latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 - .mul(0xFF) # to 0..255 - ) + +def preview_to_image(latent_image, do_scale=True) -> Image.Image: + if do_scale: + latents_ubyte = (((latent_image + 1.0) / 2.0).clamp(0, 1) # change scale from -1..1 to 0..1 + .mul(0xFF) # to 0..255 + ) + else: + latents_ubyte = (latent_image.clamp(0, 1) + .mul(0xFF) # to 0..255 + ) if model_management.directml_device is not None: latents_ubyte = latents_ubyte.to(dtype=torch.uint8) latents_ubyte = latents_ubyte.to(device="cpu", dtype=torch.uint8, non_blocking=model_management.device_supports_non_blocking(latent_image.device)) @@ -31,7 +41,7 @@ def preview_to_image(latent_image) -> Image: class LatentPreviewer: - def decode_latent_to_preview(self, x0) -> Image: + def decode_latent_to_preview(self, x0) -> Image.Image: raise NotImplementedError def decode_latent_to_preview_image(self, preview_format, x0) -> UnencodedPreviewImageMessage: @@ -49,14 +59,23 @@ class TAESDPreviewerImpl(LatentPreviewer): return preview_to_image(x_sample) +class TAEHVPreviewerImpl(TAESDPreviewerImpl): + def decode_latent_to_preview(self, x0): + x_sample = self.taesd.decode(x0[:1, :, :1])[0][0] + return preview_to_image(x_sample, do_scale=False) + + class Latent2RGBPreviewer(LatentPreviewer): - def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None): + def __init__(self, latent_rgb_factors, latent_rgb_factors_bias=None, latent_rgb_factors_reshape=None): self.latent_rgb_factors = torch.tensor(latent_rgb_factors, device="cpu").transpose(0, 1) self.latent_rgb_factors_bias = None if latent_rgb_factors_bias is not None: self.latent_rgb_factors_bias = torch.tensor(latent_rgb_factors_bias, device="cpu") + self.latent_rgb_factors_reshape = latent_rgb_factors_reshape def decode_latent_to_preview(self, x0): + if self.latent_rgb_factors_reshape is not None: + x0 = self.latent_rgb_factors_reshape(x0) self.latent_rgb_factors = self.latent_rgb_factors.to(dtype=x0.dtype, device=x0.device) if self.latent_rgb_factors_bias is not None: self.latent_rgb_factors_bias = self.latent_rgb_factors_bias.to(dtype=x0.dtype, device=x0.device) @@ -91,14 +110,19 @@ def get_previewer(device, latent_format): if method == LatentPreviewMethod.TAESD: if taesd_decoder_path: - taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device) - previewer = TAESDPreviewerImpl(taesd) + if latent_format.taesd_decoder_name in VIDEO_TAES: + taesd = VAE(load_torch_file(taesd_decoder_path)) + taesd.first_stage_model.show_progress_bar = False + previewer = TAEHVPreviewerImpl(taesd) + else: + taesd = TAESD(None, taesd_decoder_path, latent_channels=latent_format.latent_channels).to(device) + previewer = TAESDPreviewerImpl(taesd) else: logger.warning("Warning: TAESD previews enabled, but could not find 
models/vae_approx/{}".format(latent_format.taesd_decoder_name)) if previewer is None: if latent_format.latent_rgb_factors is not None: - previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias) + previewer = Latent2RGBPreviewer(latent_format.latent_rgb_factors, latent_format.latent_rgb_factors_bias, latent_format.latent_rgb_factors_reshape) return previewer diff --git a/comfy/cmd/main.py b/comfy/cmd/main.py index cade11350..06f6e6a24 100644 --- a/comfy/cmd/main.py +++ b/comfy/cmd/main.py @@ -1,30 +1,29 @@ -from .main_pre import tracer - import asyncio import contextvars import gc - import logging import os import shutil +import sys import threading import time from pathlib import Path from typing import Optional -from ..cli_args_types import Configuration -from ..component_model.file_counter import cleanup_temp as fc_cleanup_temp -from ..execution_context import current_execution_context + +from comfy.component_model.abstract_prompt_queue import AbstractPromptQueue from . import hook_breaker_ac10a0 from .extra_model_paths import load_extra_path_config from .. import model_management from ..analytics.analytics import initialize_event_tracking +from ..cli_args_types import Configuration from ..cmd import cuda_malloc from ..cmd import folder_paths from ..cmd import server as server_module -from ..component_model.abstract_prompt_queue import AbstractPromptQueue from ..component_model.entrypoints_common import configure_application_paths, executor_from_args +from ..component_model.file_counter import cleanup_temp as fc_cleanup_temp from ..distributed.distributed_prompt_queue import DistributedPromptQueue from ..distributed.server_stub import ServerStub +from ..execution_context import current_execution_context from ..nodes.package import import_all_nodes_in_workspace from ..nodes_context import get_nodes @@ -44,22 +43,27 @@ def cuda_malloc_warning(): "\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n") -def prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer): - asyncio.run(_prompt_worker(q, server_instance)) +def handle_comfyui_manager_unavailable(args: Configuration): + if not args.windows_standalone_build: + logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") + args.enable_manager = False async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module.PromptServer): from ..cmd import execution from ..component_model import queue_types from .. import model_management + args = current_execution_context().configuration cache_type = execution.CacheType.CLASSIC if args.cache_lru > 0: cache_type = execution.CacheType.LRU + elif args.cache_ram > 0: + cache_type = execution.CacheType.RAM_PRESSURE elif args.cache_none: - cache_type = execution.CacheType.DEPENDENCY_AWARE + cache_type = execution.CacheType.NONE - e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru) + e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={"lru": args.cache_lru, "ram": args.cache_ram}) last_gc_collect = 0 need_gc = False gc_collect_interval = 10.0 @@ -76,6 +80,15 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module. 
prompt_id = item[1] server_instance.last_prompt_id = prompt_id + sensitive = item[5] + extra_data = item[3].copy() + for k in sensitive: + extra_data[k] = sensitive[k] + + e.execute(item[2], prompt_id, extra_data, item[4]) + + remove_sensitive = lambda prompt: prompt[:5] + prompt[6:] + await e.execute_async(item[2], prompt_id, item[3], item[4]) need_gc = True @@ -96,13 +109,16 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module. status_str='success' if e.success else 'error', completed=e.success, messages=messages), - error_details=error_details) + error_details=error_details, + process_item=remove_sensitive, + ) + if server_instance.client_id is not None: - server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, - server_instance.client_id) + server_instance.send_sync("executing", {"node": None, "prompt_id": prompt_id}, server_instance.client_id) current_time = time.perf_counter() execution_time = current_time - execution_start_time + # Log Time in a more readable way after 10 minutes if execution_time > 600: execution_time = time.strftime("%H:%M:%S", time.gmtime(execution_time)) @@ -133,7 +149,7 @@ async def _prompt_worker(q: AbstractPromptQueue, server_instance: server_module. hook_breaker_ac10a0.restore_functions() -async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): +async def run(server_instance, address='', port=8188, call_on_start=None): addresses = [] for addr in address.split(","): addresses.append((addr, port)) @@ -194,6 +210,23 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None): logger.info(f"Setting user directory to: {user_dir}") folder_paths.set_user_directory(user_dir) + # todo: the manager code has to live inside vanilla_node_importing, it has to deal with a git repo already being in custom_nodes + # if args.enable_manager: + # if importlib.util.find_spec("comfyui_manager"): + # import comfyui_manager + # + # if not comfyui_manager.__file__ or not comfyui_manager.__file__.endswith('__init__.py'): + # handle_comfyui_manager_unavailable(args) + # else: + # handle_comfyui_manager_unavailable(args) + # + # if args.enable_manager: + # try: + # import comfyui_manager + # comfyui_manager.prestartup() + # except: + # pass + # configure extra model paths earlier try: extra_model_paths_config_path = os.path.join(os_getcwd, "extra_model_paths.yaml") @@ -224,6 +257,15 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None): loop = asyncio.get_event_loop() server = server_module.PromptServer(loop) + + # todo: the manager code has to live inside vanilla_node_importing, it has to deal with a git repo already being in custom_nodes + # if args.enable_manager and not args.disable_manager_ui: + # try: + # import comfyui_manager + # comfyui_manager.start() + # except: + # pass + if args.external_address is not None: server.external_address = args.external_address @@ -317,8 +359,7 @@ async def __start_comfyui(from_script_dir: Optional[Path] = None): try: await server.setup() - await run(server, address=first_listen_addr, port=args.port, verbose=not args.dont_print_server, - call_on_start=call_on_start) + await run(server, address=first_listen_addr, port=args.port, call_on_start=call_on_start) except (asyncio.CancelledError, KeyboardInterrupt): logger.debug("Stopped server") finally: diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index a98b900b0..08f74e0c2 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -12,6 +12,8 @@ import socket import struct import 
sys import traceback +import time + import typing import urllib import uuid @@ -39,7 +41,7 @@ from .. import node_helpers from .. import utils from ..api_server.routes.internal.internal_routes import InternalRoutes from ..app.custom_node_manager import CustomNodeManager -from ..app.frontend_management import FrontendManager +from ..app.frontend_management import FrontendManager, parse_version from ..app.model_manager import ModelFileManager from ..app.user_manager import UserManager from ..cli_args import args @@ -70,6 +72,13 @@ class HeuristicPath(NamedTuple): # Import cache control middleware from ..middleware.cache_middleware import cache_control +# todo: what is this really trying to do? +LOADED_MODULE_DIRS = {} + +# todo: is this really how we want to enable the manager? +if args.enable_manager: + import comfyui_manager + async def send_socket_catch_exception(function, message): try: await function(message) @@ -144,7 +153,7 @@ def create_cors_middleware(allowed_origin: str): response = await handler(request) response.headers['Access-Control-Allow-Origin'] = allowed_origin - response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS' + response.headers['Access-Control-Allow-Methods'] = 'POST, GET, DELETE, PUT, OPTIONS, PATCH' response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, traceparent, tracestate' response.headers['Access-Control-Allow-Credentials'] = 'true' return response @@ -215,9 +224,23 @@ def create_origin_only_middleware(): return origin_only_middleware +def create_block_external_middleware(): + @web.middleware + async def block_external_middleware(request: web.Request, handler): + if request.method == "OPTIONS": + # Pre-flight request. Reply successfully: + response = web.Response() + else: + response = await handler(request) + + response.headers['Content-Security-Policy'] = "default-src 'self'; script-src 'self' 'unsafe-inline' 'unsafe-eval' blob:; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; font-src 'self'; connect-src 'self'; frame-src 'self'; object-src 'self';" + return response + + return block_external_middleware + + class PromptServer(ExecutorToClientProgress): instance: Optional['PromptServer'] = None - def __init__(self, loop): # todo: this really needs to be set up differently, because sometimes the prompt server will not be initialized PromptServer.instance = self @@ -230,6 +253,7 @@ class PromptServer(ExecutorToClientProgress): self.user_manager = UserManager() self.model_file_manager = ModelFileManager() self.custom_node_manager = CustomNodeManager() + self.subgraph_manager = SubgraphManager() self.internal_routes = InternalRoutes(self) # todo: this is probably read by custom nodes elsewhere self.supports: List[str] = ["custom_nodes_from_web"] @@ -251,6 +275,12 @@ class PromptServer(ExecutorToClientProgress): else: middlewares.append(create_origin_only_middleware()) + if args.disable_api_nodes: + middlewares.append(create_block_external_middleware()) + + if args.enable_manager: + middlewares.append(comfyui_manager.create_middleware()) + max_upload_size = round(args.max_upload_size * 1024 * 1024) self.app: web.Application = web.Application(client_max_size=max_upload_size, handler_args={'max_field_size': 16380}, @@ -634,7 +664,7 @@ class PromptServer(ExecutorToClientProgress): system_stats = { "system": { - "os": os.name, + "os": sys.platform, "ram_total": ram_total, "ram_free": ram_free, "comfyui_version": __version__, @@ -746,8 +776,9 @@ class PromptServer(ExecutorToClientProgress): async 
def get_queue(request): queue_info = {} current_queue = self.prompt_queue.get_current_queue_volatile() - queue_info['queue_running'] = current_queue[0] - queue_info['queue_pending'] = current_queue[1] + remove_sensitive = lambda queue: [x[:5] for x in queue] + queue_info['queue_running'] = remove_sensitive(current_queue[0]) + queue_info['queue_pending'] = remove_sensitive(current_queue[1]) return web.json_response(queue_info) @routes.post("/prompt") @@ -782,8 +813,13 @@ class PromptServer(ExecutorToClientProgress): extra_data["client_id"] = json_data["client_id"] if valid[0]: outputs_to_execute = valid[2] + sensitive = {} + for sensitive_val in execution.SENSITIVE_EXTRA_DATA_KEYS: + if sensitive_val in extra_data: + sensitive[sensitive_val] = extra_data.pop(sensitive_val) + extra_data["create_time"] = int(time.time() * 1000) # timestamp in milliseconds self.prompt_queue.put( - QueueItem(queue_tuple=(number, prompt_id, prompt, extra_data, outputs_to_execute), + QueueItem(queue_tuple=(number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive), completed=None)) response = {"prompt_id": prompt_id, "number": number, "node_errors": valid[3]} return web.json_response(response) @@ -1112,6 +1148,7 @@ class PromptServer(ExecutorToClientProgress): self.model_file_manager.add_routes(self.routes) # todo: needs to use module directories self.custom_node_manager.add_routes(self.routes, self.app, {}) + self.subgraph_manager.add_routes(self.routes, LOADED_MODULE_DIRS.items()) self.app.add_subapp('/internal', self.internal_routes.get_app()) # Prefix every route with /api for easier matching for delegation. @@ -1132,11 +1169,31 @@ class PromptServer(ExecutorToClientProgress): for name, dir in self.nodes.EXTENSION_WEB_DIRS.items(): self.app.add_routes([web.static('/extensions/' + name, dir, follow_symlinks=True)]) - workflow_templates_path = FrontendManager.templates_path() - if workflow_templates_path: - self.app.add_routes([ - web.static('/templates', workflow_templates_path) - ]) + installed_templates_version = FrontendManager.get_installed_templates_version() + use_legacy_templates = True + if installed_templates_version: + try: + use_legacy_templates = ( + parse_version(installed_templates_version) + < parse_version("0.3.0") + ) + except Exception as exc: + logging.warning( + "Unable to parse templates version '%s': %s", + installed_templates_version, + exc, + ) + + if use_legacy_templates: + workflow_templates_path = FrontendManager.legacy_templates_path() + if workflow_templates_path: + self.app.add_routes([ + web.static('/templates', workflow_templates_path) + ]) + else: + handler = FrontendManager.template_asset_handler() + if handler: + self.app.router.add_get("/templates/{path:.*}", handler) # Serve embedded documentation from the package embedded_docs_path = FrontendManager.embedded_docs_path() diff --git a/comfy/component_model/abstract_prompt_queue.py b/comfy/component_model/abstract_prompt_queue.py index 74cde5fcf..14ddc469b 100644 --- a/comfy/component_model/abstract_prompt_queue.py +++ b/comfy/component_model/abstract_prompt_queue.py @@ -3,7 +3,7 @@ from __future__ import annotations import typing from abc import ABCMeta, abstractmethod -from .executor_types import HistoryResultDict +from .executor_types import HistoryResultDict, ExecutionErrorMessage from .queue_types import QueueTuple, HistoryEntry, QueueItem, Flags, ExecutionStatus, TaskInvocation, AbstractPromptQueueGetCurrentQueueItems @@ -43,10 +43,11 @@ class AbstractPromptQueue(metaclass=ABCMeta): pass @abstractmethod - def 
task_done(self, item_id: str, outputs: HistoryResultDict, - status: typing.Optional[ExecutionStatus]): + def task_done(self, item_id: str, outputs: HistoryResultDict, status: typing.Optional[ExecutionStatus], error_details: typing.Optional[ExecutionErrorMessage] = None, process_item: typing.Optional[typing.Callable[[QueueTuple], QueueItem]] = None): """ Signals to the user interface that the task with the specified id is completed + :param error_details: + :param process_item: :param item_id: the ID of the task that should be marked as completed :param outputs: an opaque dictionary of outputs :param status: diff --git a/comfy/component_model/queue_types.py b/comfy/component_model/queue_types.py index f5189cc64..498c57e31 100644 --- a/comfy/component_model/queue_types.py +++ b/comfy/component_model/queue_types.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import copy +import typing from enum import Enum from typing import NamedTuple, Optional, List, Literal, Sequence from typing import Tuple @@ -10,7 +11,12 @@ from typing_extensions import NotRequired, TypedDict from .outputs_types import OutputsDict -QueueTuple = Tuple[float, str, dict, dict, list] +if typing.TYPE_CHECKING: + from .executor_types import ExecutionErrorMessage +# todo: migrate this and the tree of objects here to a NamedTuple +# number, prompt_id, prompt, extra_data, outputs_to_execute, sensitive +# todo: sensitive dictionary data is actually a JSON value +QueueTuple = Tuple[float, str, dict, dict, list, Optional[dict[str, str]]] MAXIMUM_HISTORY_SIZE = 10000 @@ -63,6 +69,7 @@ class ExecutionStatusAsDict(TypedDict): status_str: Literal['success', 'error'] completed: bool messages: List[str] + error_details: NotRequired[ExecutionErrorMessage] class Flags(TypedDict, total=False): @@ -98,7 +105,8 @@ class NamedQueueTuple(dict): prompt_id=queue_tuple[1], prompt=queue_tuple[2], extra_data=queue_tuple[3] if len(queue_tuple) > 3 else None, - good_outputs=queue_tuple[4] if len(queue_tuple) > 4 else None + good_outputs=queue_tuple[4] if len(queue_tuple) > 4 else None, + sensitive=queue_tuple[5] if len(queue_tuple) > 5 else None, ) # Store the original tuple in a slot, making it invisible to json.dumps. 
self.queue_tuple = queue_tuple @@ -127,6 +135,12 @@ class NamedQueueTuple(dict): return self.queue_tuple[4] return None + @property + def sensitive(self) -> Optional[dict]: + if len(self.queue_tuple) > 5: + return self.queue_tuple[5] + return None + class QueueItem(NamedQueueTuple): """ diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 36174af05..97fbe2ccb 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -54,26 +54,36 @@ class ContextHandlerABC(ABC): class IndexListContextWindow(ContextWindowABC): - def __init__(self, index_list: list[int], dim: int = 0): + def __init__(self, index_list: list[int], dim: int = 0, total_frames: int=0): self.index_list = index_list self.context_length = len(index_list) self.dim = dim + self.total_frames = total_frames + self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) - def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor: + def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: if dim is None: dim = self.dim if dim == 0 and full.shape[dim] == 1: return full - idx = [slice(None)] * dim + [self.index_list] - return full[idx].to(device) + idx = tuple([slice(None)] * dim + [self.index_list]) + window = full[idx] + if retain_index_list: + idx = tuple([slice(None)] * dim + [retain_index_list]) + window[idx] = full[idx] + return window.to(device) def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor: if dim is None: dim = self.dim - idx = [slice(None)] * dim + [self.index_list] + idx = tuple([slice(None)] * dim + [self.index_list]) full[idx] += to_add return full + def get_region_index(self, num_regions: int) -> int: + region_idx = int(self.center_ratio * num_regions) + return min(max(region_idx, 0), num_regions - 1) + class IndexListCallbacks: EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" @@ -101,7 +111,8 @@ ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_co class IndexListContextHandler(ContextHandlerABC): - def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int = 1, context_overlap: int = 0, context_stride: int = 1, closed_loop=False, dim=0): + def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int = 1, context_overlap: int = 0, context_stride: int = 1, + closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -110,13 +121,18 @@ class IndexListContextHandler(ContextHandlerABC): self.closed_loop = closed_loop self.dim = dim self._step = 0 + self.freenoise = freenoise + self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] + self.split_conds_to_windows = split_conds_to_windows self.callbacks = {} def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: # for now, assume first dim is batch - should have stored on BaseModel in actual implementation if x_in.size(self.dim) > self.context_length: - logger.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.") + logger.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for 
{x_in.size(self.dim)} frames.") + if self.cond_retain_index_list: + logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") return True return False @@ -130,6 +146,11 @@ class IndexListContextHandler(ContextHandlerABC): return None # reuse or resize cond items to match context requirements resized_cond = [] + # if multiple conds, split based on primary region + if self.split_conds_to_windows and len(cond_in) > 1: + region = window.get_region_index(len(cond_in)) + logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}") + cond_in = [cond_in[region]] # cond object is a list containing a dict - outer list is irrelevant, so just loop through it for actual_cond in cond_in: resized_actual_cond = actual_cond.copy() @@ -153,12 +174,19 @@ class IndexListContextHandler(ContextHandlerABC): # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) for cond_key, cond_value in new_cond_item.items(): if isinstance(cond_value, torch.Tensor): - if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim): + if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ + (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): new_cond_item[cond_key] = window.get_tensor(cond_value, device) + # Handle audio_embed (temporal dim is 1) + elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + audio_cond = cond_value.cond + if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1)) # if has cond that is a Tensor, check if needs to be subset elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): - if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim): - new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device)) + if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ + (cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list)) elif cond_key == "num_video_frames": # for SVD new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond) new_cond_item[cond_key].cond = window.context_length @@ -171,7 +199,7 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001) + mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: raise Exception("No sample_sigmas matched current timestep; something went wrong.") @@ -180,7 +208,7 @@ class IndexListContextHandler(ContextHandlerABC): def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: full_length = x_in.size(self.dim) # TODO: choose dim based on model context_windows = self.context_schedule.func(full_length, self, model_options) - context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows] + 
context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows] return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): @@ -256,8 +284,8 @@ class IndexListContextHandler(ContextHandlerABC): prev_weight = (bias_total / (bias_total + bias)) new_weight = (bias / (bias_total + bias)) # account for dims of tensors - idx_window = [slice(None)] * self.dim + [idx] - pos_window = [slice(None)] * self.dim + [pos] + idx_window = tuple([slice(None)] * self.dim + [idx]) + pos_window = tuple([slice(None)] * self.dim + [pos]) # apply new values conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight biases_final[i][idx] = bias_total + bias @@ -293,6 +321,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher): ) +def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs): + model_options = extra_args.get("model_options", None) + if model_options is None: + raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.") + handler: IndexListContextHandler = model_options.get("context_handler", None) + if handler is None: + raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") + if not handler.freenoise: + return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) + noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + + return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) + + +def create_sampler_sample_wrapper(model: ModelPatcher): + model.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, + "ContextWindows_sampler_sample", + _sampler_sample_wrapper + ) + + def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: total_dims = len(x_in.shape) weights_tensor = torch.Tensor(weights).to(device=device) @@ -552,3 +602,29 @@ def shift_window_to_end(window: list[int], num_frames: int): for i in range(len(window)): # 2) add end_delta to each val to slide windows to end window[i] = window[i] + end_delta + + +# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465 +def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int): + logging.info("Context windows: Applying FreeNoise") + generator = torch.Generator(device='cpu').manual_seed(seed) + latent_video_length = noise.shape[dim] + delta = context_length - context_overlap + + for start_idx in range(0, latent_video_length - context_length, delta): + place_idx = start_idx + context_length + + actual_delta = min(delta, latent_video_length - place_idx) + if actual_delta <= 0: + break + + list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx + + source_slice = [slice(None)] * noise.ndim + source_slice[dim] = list_idx + target_slice = [slice(None)] * noise.ndim + target_slice[dim] = slice(place_idx, place_idx + actual_delta) + + noise[tuple(target_slice)] = noise[tuple(source_slice)] + + return noise diff --git a/comfy/controlnet.py b/comfy/controlnet.py index cbb598fc8..41d7223aa 100644 --- a/comfy/controlnet.py +++ 
b/comfy/controlnet.py @@ -316,11 +316,13 @@ class ControlLoraOps: self.bias = None def forward(self, input): - weight, bias = ops.cast_bias_weight(self, input) + weight, bias, offload_stream = ops.cast_bias_weight(self, input, offloadable=True) if self.up is not None: - return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) + x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias) else: - return torch.nn.functional.linear(input, weight, bias) + x = torch.nn.functional.linear(input, weight, bias) + comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream) + return x class Conv2d(torch.nn.Module, ops.CastWeightBiasOp): def __init__( @@ -355,12 +357,13 @@ class ControlLoraOps: self.down = None def forward(self, input): - weight, bias = ops.cast_bias_weight(self, input) + weight, bias, offload_stream = ops.cast_bias_weight(self, input, offloadable=True) if self.up is not None: - return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) + x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups) else: - return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) - + x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups) + comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream) + return x class ControlLora(ControlNet): def __init__(self, control_weights, global_average_pooling=False, model_options={}): # TODO? model_options diff --git a/comfy/distributed/distributed_prompt_queue.py b/comfy/distributed/distributed_prompt_queue.py index e3af6ef5c..df2ef990d 100644 --- a/comfy/distributed/distributed_prompt_queue.py +++ b/comfy/distributed/distributed_prompt_queue.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing + from ..cmd.main_pre import tracer import asyncio @@ -22,7 +24,7 @@ from .server_stub import ServerStub from ..auth.permissions import jwt_decode from ..cmd.server import PromptServer from ..component_model.abstract_prompt_queue import AsyncAbstractPromptQueue, AbstractPromptQueue -from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict +from ..component_model.executor_types import ExecutorToClientProgress, SendSyncEvent, SendSyncData, HistoryResultDict, ExecutionErrorMessage from ..component_model.queue_types import Flags, HistoryEntry, QueueTuple, QueueItem, ExecutionStatus, TaskInvocation, \ ExecutionError @@ -163,7 +165,8 @@ class DistributedPromptQueue(AbstractPromptQueue, AsyncAbstractPromptQueue): return item, item[1] - def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus], error_details: Optional['ExecutionErrorMessage'] = None): + def task_done(self, item_id: int, outputs: dict, status: Optional[ExecutionStatus], error_details: Optional[ExecutionErrorMessage] = None, process_item: typing.Optional[typing.Callable[[QueueTuple], QueueItem]] = None): + # todo: should we do the process_item? 
it's clearing sensitive data. but what is the idea? why do things this way, it's crazy # callee: executed on the worker thread if "outputs" in outputs: outputs: HistoryResultDict diff --git a/comfy/latent_formats.py b/comfy/latent_formats.py index 4c94384d6..ac40d8350 100644 --- a/comfy/latent_formats.py +++ b/comfy/latent_formats.py @@ -7,6 +7,7 @@ class LatentFormat: latent_dimensions = 2 latent_rgb_factors = None latent_rgb_factors_bias = None + latent_rgb_factors_reshape = None taesd_decoder_name = None def process_in(self, latent): @@ -191,6 +192,55 @@ class Flux(SD3): return (latent / self.scale_factor) + self.shift_factor +class Flux2(LatentFormat): + latent_channels = 128 + + def __init__(self): + self.latent_rgb_factors = [ + [0.0058, 0.0113, 0.0073], + [0.0495, 0.0443, 0.0836], + [-0.0099, 0.0096, 0.0644], + [0.2144, 0.3009, 0.3652], + [0.0166, -0.0039, -0.0054], + [0.0157, 0.0103, -0.0160], + [-0.0398, 0.0902, -0.0235], + [-0.0052, 0.0095, 0.0109], + [-0.3527, -0.2712, -0.1666], + [-0.0301, -0.0356, -0.0180], + [-0.0107, 0.0078, 0.0013], + [0.0746, 0.0090, -0.0941], + [0.0156, 0.0169, 0.0070], + [-0.0034, -0.0040, -0.0114], + [0.0032, 0.0181, 0.0080], + [-0.0939, -0.0008, 0.0186], + [0.0018, 0.0043, 0.0104], + [0.0284, 0.0056, -0.0127], + [-0.0024, -0.0022, -0.0030], + [0.1207, -0.0026, 0.0065], + [0.0128, 0.0101, 0.0142], + [0.0137, -0.0072, -0.0007], + [0.0095, 0.0092, -0.0059], + [0.0000, -0.0077, -0.0049], + [-0.0465, -0.0204, -0.0312], + [0.0095, 0.0012, -0.0066], + [0.0290, -0.0034, 0.0025], + [0.0220, 0.0169, -0.0048], + [-0.0332, -0.0457, -0.0468], + [-0.0085, 0.0389, 0.0609], + [-0.0076, 0.0003, -0.0043], + [-0.0111, -0.0460, -0.0614], + ] + + self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851] + self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2) + + def process_in(self, latent): + return latent + + def process_out(self, latent): + return latent + + class Mochi(LatentFormat): latent_channels = 12 latent_dimensions = 3 @@ -240,210 +290,214 @@ class LTXV(LatentFormat): def __init__(self): self.latent_rgb_factors = [ - [ 1.1202e-02, -6.3815e-04, -1.0021e-02], - [ 8.6031e-02, 6.5813e-02, 9.5409e-04], + [1.1202e-02, -6.3815e-04, -1.0021e-02], + [8.6031e-02, 6.5813e-02, 9.5409e-04], [-1.2576e-02, -7.5734e-03, -4.0528e-03], - [ 9.4063e-03, -2.1688e-03, 2.6093e-03], - [ 3.7636e-03, 1.2765e-02, 9.1548e-03], - [ 2.1024e-02, -5.2973e-03, 3.4373e-03], + [9.4063e-03, -2.1688e-03, 2.6093e-03], + [3.7636e-03, 1.2765e-02, 9.1548e-03], + [2.1024e-02, -5.2973e-03, 3.4373e-03], [-8.8896e-03, -1.9703e-02, -1.8761e-02], - [-1.3160e-02, -1.0523e-02, 1.9709e-03], + [-1.3160e-02, -1.0523e-02, 1.9709e-03], [-1.5152e-03, -6.9891e-03, -7.5810e-03], - [-1.7247e-03, 4.6560e-04, -3.3839e-03], - [ 1.3617e-02, 4.7077e-03, -2.0045e-03], - [ 1.0256e-02, 7.7318e-03, 1.3948e-02], - [-1.6108e-02, -6.2151e-03, 1.1561e-03], - [ 7.3407e-03, 1.5628e-02, 4.4865e-04], - [ 9.5357e-04, -2.9518e-03, -1.4760e-02], - [ 1.9143e-02, 1.0868e-02, 1.2264e-02], - [ 4.4575e-03, 3.6682e-05, -6.8508e-03], - [-4.5681e-04, 3.2570e-03, 7.7929e-03], - [ 3.3902e-02, 3.3405e-02, 3.7454e-02], + [-1.7247e-03, 4.6560e-04, -3.3839e-03], + [1.3617e-02, 4.7077e-03, -2.0045e-03], + [1.0256e-02, 7.7318e-03, 1.3948e-02], + [-1.6108e-02, -6.2151e-03, 1.1561e-03], + [7.3407e-03, 1.5628e-02, 4.4865e-04], + [9.5357e-04, -2.9518e-03, -1.4760e-02], + [1.9143e-02, 1.0868e-02, 1.2264e-02], + [4.4575e-03, 
3.6682e-05, -6.8508e-03], + [-4.5681e-04, 3.2570e-03, 7.7929e-03], + [3.3902e-02, 3.3405e-02, 3.7454e-02], [-2.3001e-02, -2.4877e-03, -3.1033e-03], - [ 5.0265e-02, 3.8841e-02, 3.3539e-02], - [-4.1018e-03, -1.1095e-03, 1.5859e-03], + [5.0265e-02, 3.8841e-02, 3.3539e-02], + [-4.1018e-03, -1.1095e-03, 1.5859e-03], [-1.2689e-01, -1.3107e-01, -2.1005e-01], - [ 2.6276e-02, 1.4189e-02, -3.5963e-03], - [-4.8679e-03, 8.8486e-03, 7.8029e-03], + [2.6276e-02, 1.4189e-02, -3.5963e-03], + [-4.8679e-03, 8.8486e-03, 7.8029e-03], [-1.6610e-03, -4.8597e-03, -5.2060e-03], - [-2.1010e-03, 2.3610e-03, 9.3796e-03], + [-2.1010e-03, 2.3610e-03, 9.3796e-03], [-2.2482e-02, -2.1305e-02, -1.5087e-02], [-1.5753e-02, -1.0646e-02, -6.5083e-03], - [-4.6975e-03, 5.0288e-03, -6.7390e-03], - [ 1.1951e-02, 2.0712e-02, 1.6191e-02], + [-4.6975e-03, 5.0288e-03, -6.7390e-03], + [1.1951e-02, 2.0712e-02, 1.6191e-02], [-6.3704e-03, -8.4827e-03, -9.5483e-03], - [ 7.2610e-03, -9.9326e-03, -2.2978e-02], - [-9.1904e-04, 6.2882e-03, 9.5720e-03], + [7.2610e-03, -9.9326e-03, -2.2978e-02], + [-9.1904e-04, 6.2882e-03, 9.5720e-03], [-3.7178e-02, -3.7123e-02, -5.6713e-02], [-1.3373e-01, -1.0720e-01, -5.3801e-02], - [-5.3702e-03, 8.1256e-03, 8.8397e-03], + [-5.3702e-03, 8.1256e-03, 8.8397e-03], [-1.5247e-01, -2.1437e-01, -2.1843e-01], - [ 3.1441e-02, 7.0335e-03, -9.7541e-03], - [ 2.1528e-03, -8.9817e-03, -2.1023e-02], - [ 3.8461e-03, -5.8957e-03, -1.5014e-02], + [3.1441e-02, 7.0335e-03, -9.7541e-03], + [2.1528e-03, -8.9817e-03, -2.1023e-02], + [3.8461e-03, -5.8957e-03, -1.5014e-02], [-4.3470e-03, -1.2940e-02, -1.5972e-02], [-5.4781e-03, -1.0842e-02, -3.0204e-03], - [-6.5347e-03, 3.0806e-03, -1.0163e-02], + [-6.5347e-03, 3.0806e-03, -1.0163e-02], [-5.0414e-03, -7.1503e-03, -8.9686e-04], - [-8.5851e-03, -2.4351e-03, 1.0674e-03], - [-9.0016e-03, -9.6493e-03, 1.5692e-03], - [ 5.0914e-03, 1.2099e-02, 1.9968e-02], - [ 1.3758e-02, 1.1669e-02, 8.1958e-03], + [-8.5851e-03, -2.4351e-03, 1.0674e-03], + [-9.0016e-03, -9.6493e-03, 1.5692e-03], + [5.0914e-03, 1.2099e-02, 1.9968e-02], + [1.3758e-02, 1.1669e-02, 8.1958e-03], [-1.0518e-02, -1.1575e-02, -4.1307e-03], [-2.8410e-02, -3.1266e-02, -2.2149e-02], - [ 2.9336e-03, 3.6511e-02, 1.8717e-02], + [2.9336e-03, 3.6511e-02, 1.8717e-02], [-1.6703e-02, -1.6696e-02, -4.4529e-03], - [ 4.8818e-02, 4.0063e-02, 8.7410e-03], - [-1.5066e-02, -5.7328e-04, 2.9785e-03], - [-1.7613e-02, -8.1034e-03, 1.3086e-02], - [-9.2633e-03, 1.0803e-02, -6.3489e-03], - [ 3.0851e-03, 4.7750e-04, 1.2347e-02], + [4.8818e-02, 4.0063e-02, 8.7410e-03], + [-1.5066e-02, -5.7328e-04, 2.9785e-03], + [-1.7613e-02, -8.1034e-03, 1.3086e-02], + [-9.2633e-03, 1.0803e-02, -6.3489e-03], + [3.0851e-03, 4.7750e-04, 1.2347e-02], [-2.2785e-02, -2.3043e-02, -2.6005e-02], [-2.4787e-02, -1.5389e-02, -2.2104e-02], - [-2.3572e-02, 1.0544e-03, 1.2361e-02], + [-2.3572e-02, 1.0544e-03, 1.2361e-02], [-7.8915e-03, -1.2271e-03, -6.0968e-03], - [-1.1478e-02, -1.2543e-03, 6.2679e-03], - [-5.4229e-02, 2.6644e-02, 6.3394e-03], - [ 4.4216e-03, -7.3338e-03, -1.0464e-02], - [-4.5013e-03, 1.6082e-03, 1.4420e-02], - [ 1.3673e-02, 8.8877e-03, 4.1253e-03], - [-1.0145e-02, 9.0072e-03, 1.5695e-02], - [-5.6234e-03, 1.1847e-03, 8.1261e-03], - [-3.7171e-03, -5.3538e-03, 1.2590e-03], - [ 2.9476e-02, 2.1424e-02, 3.0424e-02], + [-1.1478e-02, -1.2543e-03, 6.2679e-03], + [-5.4229e-02, 2.6644e-02, 6.3394e-03], + [4.4216e-03, -7.3338e-03, -1.0464e-02], + [-4.5013e-03, 1.6082e-03, 1.4420e-02], + [1.3673e-02, 8.8877e-03, 4.1253e-03], + [-1.0145e-02, 9.0072e-03, 1.5695e-02], + [-5.6234e-03, 
1.1847e-03, 8.1261e-03], + [-3.7171e-03, -5.3538e-03, 1.2590e-03], + [2.9476e-02, 2.1424e-02, 3.0424e-02], [-3.4925e-02, -2.4340e-02, -2.5316e-02], [-3.4127e-02, -2.2406e-02, -1.0589e-02], [-1.7342e-02, -1.3249e-02, -1.0719e-02], [-2.1478e-03, -8.6051e-03, -2.9878e-03], - [ 1.2089e-03, -4.2391e-03, -6.8569e-03], - [ 9.0411e-04, -6.6886e-03, -6.7547e-05], - [ 1.6048e-02, -1.0057e-02, -2.8929e-02], - [ 1.2290e-03, 1.0163e-02, 1.8861e-02], - [ 1.7264e-02, 2.7257e-04, 1.3785e-02], - [-1.3482e-02, -3.6427e-03, 6.7481e-04], - [ 4.6782e-03, -5.2423e-03, 2.4467e-03], + [1.2089e-03, -4.2391e-03, -6.8569e-03], + [9.0411e-04, -6.6886e-03, -6.7547e-05], + [1.6048e-02, -1.0057e-02, -2.8929e-02], + [1.2290e-03, 1.0163e-02, 1.8861e-02], + [1.7264e-02, 2.7257e-04, 1.3785e-02], + [-1.3482e-02, -3.6427e-03, 6.7481e-04], + [4.6782e-03, -5.2423e-03, 2.4467e-03], [-5.9113e-03, -6.2244e-03, -1.8162e-03], - [ 1.5496e-02, 1.4582e-02, 1.9514e-03], - [ 7.4958e-03, 1.5886e-03, -8.2305e-03], - [ 1.9086e-02, 1.6360e-03, -3.9674e-03], + [1.5496e-02, 1.4582e-02, 1.9514e-03], + [7.4958e-03, 1.5886e-03, -8.2305e-03], + [1.9086e-02, 1.6360e-03, -3.9674e-03], [-5.7021e-03, -2.7307e-03, -4.1066e-03], - [ 1.7450e-03, 1.4602e-02, 2.5794e-02], - [-8.2788e-04, 2.2902e-03, 4.5161e-03], - [ 1.1632e-02, 8.9193e-03, -7.2813e-03], - [ 7.5721e-03, 2.6784e-03, 1.1393e-02], - [ 5.1939e-03, 3.6903e-03, 1.4049e-02], + [1.7450e-03, 1.4602e-02, 2.5794e-02], + [-8.2788e-04, 2.2902e-03, 4.5161e-03], + [1.1632e-02, 8.9193e-03, -7.2813e-03], + [7.5721e-03, 2.6784e-03, 1.1393e-02], + [5.1939e-03, 3.6903e-03, 1.4049e-02], [-1.8383e-02, -2.2529e-02, -2.4477e-02], - [ 5.8842e-04, -5.7874e-03, -1.4770e-02], + [5.8842e-04, -5.7874e-03, -1.4770e-02], [-1.6125e-02, -8.6101e-03, -1.4533e-02], - [ 2.0540e-02, 2.0729e-02, 6.4338e-03], - [ 3.3587e-03, -1.1226e-02, -1.6444e-02], - [-1.4742e-03, -1.0489e-02, 1.7097e-03], - [ 2.8130e-02, 2.3546e-02, 3.2791e-02], + [2.0540e-02, 2.0729e-02, 6.4338e-03], + [3.3587e-03, -1.1226e-02, -1.6444e-02], + [-1.4742e-03, -1.0489e-02, 1.7097e-03], + [2.8130e-02, 2.3546e-02, 3.2791e-02], [-1.8532e-02, -1.2842e-02, -8.7756e-03], [-8.0533e-03, -1.0771e-02, -1.7536e-02], - [-3.9009e-03, 1.6150e-02, 3.3359e-02], + [-3.9009e-03, 1.6150e-02, 3.3359e-02], [-7.4554e-03, -1.4154e-02, -6.1910e-03], - [ 3.4734e-03, -1.1370e-02, -1.0581e-02], - [ 1.1476e-02, 3.9281e-03, 2.8231e-03], - [ 7.1639e-03, -1.4741e-03, -3.8066e-03], - [ 2.2250e-03, -8.7552e-03, -9.5719e-03], - [ 2.4146e-02, 2.1696e-02, 2.8056e-02], + [3.4734e-03, -1.1370e-02, -1.0581e-02], + [1.1476e-02, 3.9281e-03, 2.8231e-03], + [7.1639e-03, -1.4741e-03, -3.8066e-03], + [2.2250e-03, -8.7552e-03, -9.5719e-03], + [2.4146e-02, 2.1696e-02, 2.8056e-02], [-5.4365e-03, -2.4291e-02, -1.7802e-02], - [ 7.4263e-03, 1.0510e-02, 1.2705e-02], - [ 6.2669e-03, 6.2658e-03, 1.9211e-02], - [ 1.6378e-02, 9.4933e-03, 6.6971e-03], - [ 1.7173e-02, 2.3601e-02, 2.3296e-02], + [7.4263e-03, 1.0510e-02, 1.2705e-02], + [6.2669e-03, 6.2658e-03, 1.9211e-02], + [1.6378e-02, 9.4933e-03, 6.6971e-03], + [1.7173e-02, 2.3601e-02, 2.3296e-02], [-1.4568e-02, -9.8279e-03, -1.1556e-02], - [ 1.4431e-02, 1.4430e-02, 6.6362e-03], - [-6.8230e-03, 1.8863e-02, 1.4555e-02], - [ 6.1156e-03, 3.4700e-03, -2.6662e-03], + [1.4431e-02, 1.4430e-02, 6.6362e-03], + [-6.8230e-03, 1.8863e-02, 1.4555e-02], + [6.1156e-03, 3.4700e-03, -2.6662e-03], [-2.6983e-03, -5.9402e-03, -9.2276e-03], - [ 1.0235e-02, 7.4173e-03, -7.6243e-03], - [-1.3255e-02, 1.9322e-02, -9.2153e-04], - [ 2.4222e-03, -4.8039e-03, -1.5759e-02], - [ 2.6244e-02, 
2.5951e-02, 2.0249e-02], - [ 1.5711e-02, 1.8498e-02, 2.7407e-03], - [-2.1714e-03, 4.7214e-03, -2.2443e-02], - [-7.4747e-03, 7.4166e-03, 1.4430e-02], - [-8.3906e-03, -7.9776e-03, 9.7927e-03], - [ 3.8321e-02, 9.6622e-03, -1.9268e-02], - [-1.4605e-02, -6.7032e-03, 3.9675e-03] + [1.0235e-02, 7.4173e-03, -7.6243e-03], + [-1.3255e-02, 1.9322e-02, -9.2153e-04], + [2.4222e-03, -4.8039e-03, -1.5759e-02], + [2.6244e-02, 2.5951e-02, 2.0249e-02], + [1.5711e-02, 1.8498e-02, 2.7407e-03], + [-2.1714e-03, 4.7214e-03, -2.2443e-02], + [-7.4747e-03, 7.4166e-03, 1.4430e-02], + [-8.3906e-03, -7.9776e-03, 9.7927e-03], + [3.8321e-02, 9.6622e-03, -1.9268e-02], + [-1.4605e-02, -6.7032e-03, 3.9675e-03] ] self.latent_rgb_factors_bias = [-0.0571, -0.1657, -0.2512] + class HunyuanVideo(LatentFormat): latent_channels = 16 latent_dimensions = 3 scale_factor = 0.476986 latent_rgb_factors = [ - [-0.0395, -0.0331, 0.0445], - [ 0.0696, 0.0795, 0.0518], - [ 0.0135, -0.0945, -0.0282], - [ 0.0108, -0.0250, -0.0765], - [-0.0209, 0.0032, 0.0224], + [-0.0395, -0.0331, 0.0445], + [0.0696, 0.0795, 0.0518], + [0.0135, -0.0945, -0.0282], + [0.0108, -0.0250, -0.0765], + [-0.0209, 0.0032, 0.0224], [-0.0804, -0.0254, -0.0639], - [-0.0991, 0.0271, -0.0669], + [-0.0991, 0.0271, -0.0669], [-0.0646, -0.0422, -0.0400], [-0.0696, -0.0595, -0.0894], [-0.0799, -0.0208, -0.0375], - [ 0.1166, 0.1627, 0.0962], - [ 0.1165, 0.0432, 0.0407], + [0.1166, 0.1627, 0.0962], + [0.1165, 0.0432, 0.0407], [-0.2315, -0.1920, -0.1355], - [-0.0270, 0.0401, -0.0821], + [-0.0270, 0.0401, -0.0821], [-0.0616, -0.0997, -0.0727], - [ 0.0249, -0.0469, -0.1703] + [0.0249, -0.0469, -0.1703] ] - latent_rgb_factors_bias = [ 0.0259, -0.0192, -0.0761] + latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761] + taesd_decoder_name = "taehv" + class Cosmos1CV8x8x8(LatentFormat): latent_channels = 16 latent_dimensions = 3 latent_rgb_factors = [ - [ 0.1817, 0.2284, 0.2423], + [0.1817, 0.2284, 0.2423], [-0.0586, -0.0862, -0.3108], [-0.4703, -0.4255, -0.3995], - [ 0.0803, 0.1963, 0.1001], - [-0.0820, -0.1050, 0.0400], - [ 0.2511, 0.3098, 0.2787], + [0.0803, 0.1963, 0.1001], + [-0.0820, -0.1050, 0.0400], + [0.2511, 0.3098, 0.2787], [-0.1830, -0.2117, -0.0040], [-0.0621, -0.2187, -0.0939], - [ 0.3619, 0.1082, 0.1455], - [ 0.3164, 0.3922, 0.2575], - [ 0.1152, 0.0231, -0.0462], + [0.3619, 0.1082, 0.1455], + [0.3164, 0.3922, 0.2575], + [0.1152, 0.0231, -0.0462], [-0.1434, -0.3609, -0.3665], - [ 0.0635, 0.1471, 0.1680], + [0.0635, 0.1471, 0.1680], [-0.3635, -0.1963, -0.3248], - [-0.1865, 0.0365, 0.2346], - [ 0.0447, 0.0994, 0.0881] + [-0.1865, 0.0365, 0.2346], + [0.0447, 0.0994, 0.0881] ] latent_rgb_factors_bias = [-0.1223, -0.1889, -0.1976] + class Wan21(LatentFormat): latent_channels = 16 latent_dimensions = 3 latent_rgb_factors = [ - [-0.1299, -0.1692, 0.2932], - [ 0.0671, 0.0406, 0.0442], - [ 0.3568, 0.2548, 0.1747], - [ 0.0372, 0.2344, 0.1420], - [ 0.0313, 0.0189, -0.0328], - [ 0.0296, -0.0956, -0.0665], - [-0.3477, -0.4059, -0.2925], - [ 0.0166, 0.1902, 0.1975], - [-0.0412, 0.0267, -0.1364], - [-0.1293, 0.0740, 0.1636], - [ 0.0680, 0.3019, 0.1128], - [ 0.0032, 0.0581, 0.0639], - [-0.1251, 0.0927, 0.1699], - [ 0.0060, -0.0633, 0.0005], - [ 0.3477, 0.2275, 0.2950], - [ 0.1984, 0.0913, 0.1861] - ] + [-0.1299, -0.1692, 0.2932], + [0.0671, 0.0406, 0.0442], + [0.3568, 0.2548, 0.1747], + [0.0372, 0.2344, 0.1420], + [0.0313, 0.0189, -0.0328], + [0.0296, -0.0956, -0.0665], + [-0.3477, -0.4059, -0.2925], + [0.0166, 0.1902, 0.1975], + [-0.0412, 0.0267, -0.1364], + [-0.1293, 0.0740, 0.1636], + 
[0.0680, 0.3019, 0.1128], + [0.0032, 0.0581, 0.0639], + [-0.1251, 0.0927, 0.1699], + [0.0060, -0.0633, 0.0005], + [0.3477, 0.2275, 0.2950], + [0.1984, 0.0913, 0.1861] + ] latent_rgb_factors_bias = [-0.1835, -0.0868, -0.3360] @@ -458,8 +512,7 @@ class Wan21(LatentFormat): 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160 ]).view(1, self.latent_channels, 1, 1, 1) - - self.taesd_decoder_name = None #TODO + self.taesd_decoder_name = "lighttaew2_1" def process_in(self, latent): latents_mean = self.latents_mean.to(latent.device, latent.dtype) @@ -471,81 +524,84 @@ class Wan21(LatentFormat): latents_std = self.latents_std.to(latent.device, latent.dtype) return latent * latents_std / self.scale_factor + latents_mean + class Wan22(Wan21): latent_channels = 48 latent_dimensions = 3 latent_rgb_factors = [ - [ 0.0119, 0.0103, 0.0046], - [-0.1062, -0.0504, 0.0165], - [ 0.0140, 0.0409, 0.0491], - [-0.0813, -0.0677, 0.0607], - [ 0.0656, 0.0851, 0.0808], - [ 0.0264, 0.0463, 0.0912], - [ 0.0295, 0.0326, 0.0590], - [-0.0244, -0.0270, 0.0025], - [ 0.0443, -0.0102, 0.0288], - [-0.0465, -0.0090, -0.0205], - [ 0.0359, 0.0236, 0.0082], - [-0.0776, 0.0854, 0.1048], - [ 0.0564, 0.0264, 0.0561], - [ 0.0006, 0.0594, 0.0418], - [-0.0319, -0.0542, -0.0637], - [-0.0268, 0.0024, 0.0260], - [ 0.0539, 0.0265, 0.0358], - [-0.0359, -0.0312, -0.0287], - [-0.0285, -0.1032, -0.1237], - [ 0.1041, 0.0537, 0.0622], - [-0.0086, -0.0374, -0.0051], - [ 0.0390, 0.0670, 0.2863], - [ 0.0069, 0.0144, 0.0082], - [ 0.0006, -0.0167, 0.0079], - [ 0.0313, -0.0574, -0.0232], - [-0.1454, -0.0902, -0.0481], - [ 0.0714, 0.0827, 0.0447], - [-0.0304, -0.0574, -0.0196], - [ 0.0401, 0.0384, 0.0204], - [-0.0758, -0.0297, -0.0014], - [ 0.0568, 0.1307, 0.1372], - [-0.0055, -0.0310, -0.0380], - [ 0.0239, -0.0305, 0.0325], - [-0.0663, -0.0673, -0.0140], - [-0.0416, -0.0047, -0.0023], - [ 0.0166, 0.0112, -0.0093], - [-0.0211, 0.0011, 0.0331], - [ 0.1833, 0.1466, 0.2250], - [-0.0368, 0.0370, 0.0295], - [-0.3441, -0.3543, -0.2008], - [-0.0479, -0.0489, -0.0420], - [-0.0660, -0.0153, 0.0800], - [-0.0101, 0.0068, 0.0156], - [-0.0690, -0.0452, -0.0927], - [-0.0145, 0.0041, 0.0015], - [ 0.0421, 0.0451, 0.0373], - [ 0.0504, -0.0483, -0.0356], - [-0.0837, 0.0168, 0.0055] - ] + [0.0119, 0.0103, 0.0046], + [-0.1062, -0.0504, 0.0165], + [0.0140, 0.0409, 0.0491], + [-0.0813, -0.0677, 0.0607], + [0.0656, 0.0851, 0.0808], + [0.0264, 0.0463, 0.0912], + [0.0295, 0.0326, 0.0590], + [-0.0244, -0.0270, 0.0025], + [0.0443, -0.0102, 0.0288], + [-0.0465, -0.0090, -0.0205], + [0.0359, 0.0236, 0.0082], + [-0.0776, 0.0854, 0.1048], + [0.0564, 0.0264, 0.0561], + [0.0006, 0.0594, 0.0418], + [-0.0319, -0.0542, -0.0637], + [-0.0268, 0.0024, 0.0260], + [0.0539, 0.0265, 0.0358], + [-0.0359, -0.0312, -0.0287], + [-0.0285, -0.1032, -0.1237], + [0.1041, 0.0537, 0.0622], + [-0.0086, -0.0374, -0.0051], + [0.0390, 0.0670, 0.2863], + [0.0069, 0.0144, 0.0082], + [0.0006, -0.0167, 0.0079], + [0.0313, -0.0574, -0.0232], + [-0.1454, -0.0902, -0.0481], + [0.0714, 0.0827, 0.0447], + [-0.0304, -0.0574, -0.0196], + [0.0401, 0.0384, 0.0204], + [-0.0758, -0.0297, -0.0014], + [0.0568, 0.1307, 0.1372], + [-0.0055, -0.0310, -0.0380], + [0.0239, -0.0305, 0.0325], + [-0.0663, -0.0673, -0.0140], + [-0.0416, -0.0047, -0.0023], + [0.0166, 0.0112, -0.0093], + [-0.0211, 0.0011, 0.0331], + [0.1833, 0.1466, 0.2250], + [-0.0368, 0.0370, 0.0295], + [-0.3441, -0.3543, -0.2008], + [-0.0479, -0.0489, -0.0420], + [-0.0660, -0.0153, 0.0800], + [-0.0101, 0.0068, 0.0156], + [-0.0690, -0.0452, -0.0927], + 
[-0.0145, 0.0041, 0.0015], + [0.0421, 0.0451, 0.0373], + [0.0504, -0.0483, -0.0356], + [-0.0837, 0.0168, 0.0055] + ] latent_rgb_factors_bias = [0.0317, -0.0878, -0.1388] def __init__(self): self.scale_factor = 1.0 + self.taesd_decoder_name = "lighttaew2_2" self.latents_mean = torch.tensor([ - -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, - -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, - -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, - -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, - -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, - 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667, - ]).view(1, self.latent_channels, 1, 1, 1) + -0.2289, -0.0052, -0.1323, -0.2339, -0.2799, 0.0174, 0.1838, 0.1557, + -0.1382, 0.0542, 0.2813, 0.0891, 0.1570, -0.0098, 0.0375, -0.1825, + -0.2246, -0.1207, -0.0698, 0.5109, 0.2665, -0.2108, -0.2158, 0.2502, + -0.2055, -0.0322, 0.1109, 0.1567, -0.0729, 0.0899, -0.2799, -0.1230, + -0.0313, -0.1649, 0.0117, 0.0723, -0.2839, -0.2083, -0.0520, 0.3748, + 0.0152, 0.1957, 0.1433, -0.2944, 0.3573, -0.0548, -0.1681, -0.0667, + ]).view(1, self.latent_channels, 1, 1, 1) self.latents_std = torch.tensor([ - 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, - 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, - 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, - 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, - 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, - 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 - ]).view(1, self.latent_channels, 1, 1, 1) + 0.4765, 1.0364, 0.4514, 1.1677, 0.5313, 0.4990, 0.4818, 0.5013, + 0.8158, 1.0344, 0.5894, 1.0901, 0.6885, 0.6165, 0.8454, 0.4978, + 0.5759, 0.3523, 0.7135, 0.6804, 0.5833, 1.4146, 0.8986, 0.5659, + 0.7069, 0.5338, 0.4889, 0.4917, 0.4069, 0.4999, 0.6866, 0.4093, + 0.5709, 0.6065, 0.6415, 0.4944, 0.5726, 1.2042, 0.5458, 1.6887, + 0.3971, 1.0600, 0.3943, 0.5537, 0.5444, 0.4089, 0.7468, 0.7744 + ]).view(1, self.latent_channels, 1, 1, 1) + class HunyuanImage21(LatentFormat): latent_channels = 64 @@ -554,105 +610,173 @@ class HunyuanImage21(LatentFormat): latent_rgb_factors = [ [-0.0154, -0.0397, -0.0521], - [ 0.0005, 0.0093, 0.0006], + [0.0005, 0.0093, 0.0006], [-0.0805, -0.0773, -0.0586], [-0.0494, -0.0487, -0.0498], [-0.0212, -0.0076, -0.0261], [-0.0179, -0.0417, -0.0505], - [ 0.0158, 0.0310, 0.0239], - [ 0.0409, 0.0516, 0.0201], - [ 0.0350, 0.0553, 0.0036], + [0.0158, 0.0310, 0.0239], + [0.0409, 0.0516, 0.0201], + [0.0350, 0.0553, 0.0036], [-0.0447, -0.0327, -0.0479], [-0.0038, -0.0221, -0.0365], [-0.0423, -0.0718, -0.0654], - [ 0.0039, 0.0368, 0.0104], - [ 0.0655, 0.0217, 0.0122], - [ 0.0490, 0.1638, 0.2053], - [ 0.0932, 0.0829, 0.0650], + [0.0039, 0.0368, 0.0104], + [0.0655, 0.0217, 0.0122], + [0.0490, 0.1638, 0.2053], + [0.0932, 0.0829, 0.0650], [-0.0186, -0.0209, -0.0135], [-0.0080, -0.0076, -0.0148], - [-0.0284, -0.0201, 0.0011], + [-0.0284, -0.0201, 0.0011], [-0.0642, -0.0294, -0.0777], - [-0.0035, 0.0076, -0.0140], - [ 0.0519, 0.0731, 0.0887], - [-0.0102, 0.0095, 0.0704], - [ 0.0068, 0.0218, -0.0023], + [-0.0035, 0.0076, -0.0140], + [0.0519, 0.0731, 0.0887], + [-0.0102, 0.0095, 0.0704], + [0.0068, 0.0218, -0.0023], [-0.0726, -0.0486, -0.0519], - [ 0.0260, 0.0295, 0.0263], - [ 0.0250, 0.0333, 0.0341], - [ 0.0168, -0.0120, -0.0174], - [ 0.0226, 0.1037, 0.0114], - [ 0.2577, 0.1906, 0.1604], + [0.0260, 0.0295, 0.0263], + 
[0.0250, 0.0333, 0.0341], + [0.0168, -0.0120, -0.0174], + [0.0226, 0.1037, 0.0114], + [0.2577, 0.1906, 0.1604], [-0.0646, -0.0137, -0.0018], - [-0.0112, 0.0309, 0.0358], - [-0.0347, 0.0146, -0.0481], - [ 0.0234, 0.0179, 0.0201], - [ 0.0157, 0.0313, 0.0225], - [ 0.0423, 0.0675, 0.0524], - [-0.0031, 0.0027, -0.0255], - [ 0.0447, 0.0555, 0.0330], - [-0.0152, 0.0103, 0.0299], + [-0.0112, 0.0309, 0.0358], + [-0.0347, 0.0146, -0.0481], + [0.0234, 0.0179, 0.0201], + [0.0157, 0.0313, 0.0225], + [0.0423, 0.0675, 0.0524], + [-0.0031, 0.0027, -0.0255], + [0.0447, 0.0555, 0.0330], + [-0.0152, 0.0103, 0.0299], [-0.0755, -0.0489, -0.0635], - [ 0.0853, 0.0788, 0.1017], + [0.0853, 0.0788, 0.1017], [-0.0272, -0.0294, -0.0471], - [ 0.0440, 0.0400, -0.0137], - [ 0.0335, 0.0317, -0.0036], + [0.0440, 0.0400, -0.0137], + [0.0335, 0.0317, -0.0036], [-0.0344, -0.0621, -0.0984], [-0.0127, -0.0630, -0.0620], - [-0.0648, 0.0360, 0.0924], + [-0.0648, 0.0360, 0.0924], [-0.0781, -0.0801, -0.0409], - [ 0.0363, 0.0613, 0.0499], - [ 0.0238, 0.0034, 0.0041], - [-0.0135, 0.0258, 0.0310], - [ 0.0614, 0.1086, 0.0589], - [ 0.0428, 0.0350, 0.0205], - [ 0.0153, 0.0173, -0.0018], + [0.0363, 0.0613, 0.0499], + [0.0238, 0.0034, 0.0041], + [-0.0135, 0.0258, 0.0310], + [0.0614, 0.1086, 0.0589], + [0.0428, 0.0350, 0.0205], + [0.0153, 0.0173, -0.0018], [-0.0288, -0.0455, -0.0091], - [ 0.0344, 0.0109, -0.0157], + [0.0344, 0.0109, -0.0157], [-0.0205, -0.0247, -0.0187], - [ 0.0487, 0.0126, 0.0064], - [-0.0220, -0.0013, 0.0074], + [0.0487, 0.0126, 0.0064], + [-0.0220, -0.0013, 0.0074], [-0.0203, -0.0094, -0.0048], - [-0.0719, 0.0429, -0.0442], - [ 0.1042, 0.0497, 0.0356], + [-0.0719, 0.0429, -0.0442], + [0.1042, 0.0497, 0.0356], [-0.0659, -0.0578, -0.0280], [-0.0060, -0.0322, -0.0234]] latent_rgb_factors_bias = [0.0007, -0.0256, -0.0206] + class HunyuanImage21Refiner(LatentFormat): latent_channels = 64 latent_dimensions = 3 scale_factor = 1.03682 + def process_in(self, latent): + out = latent * self.scale_factor + out = torch.cat((out[:, :, :1], out), dim=2) + out = out.permute(0, 2, 1, 3, 4) + b, f_times_2, c, h, w = out.shape + out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) + out = out.permute(0, 2, 1, 3, 4).contiguous() + return out + + def process_out(self, latent): + z = latent / self.scale_factor + z = z.permute(0, 2, 1, 3, 4) + b, f, c, h, w = z.shape + z = z.reshape(b, f, 2, c // 2, h, w) + z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) + z = z.permute(0, 2, 1, 3, 4) + z = z[:, :, 1:] + return z + + +class HunyuanVideo15(LatentFormat): + latent_rgb_factors = [ + [0.0568, -0.0521, -0.0131], + [0.0014, 0.0735, 0.0326], + [0.0186, 0.0531, -0.0138], + [-0.0031, 0.0051, 0.0288], + [0.0110, 0.0556, 0.0432], + [-0.0041, -0.0023, -0.0485], + [0.0530, 0.0413, 0.0253], + [0.0283, 0.0251, 0.0339], + [0.0277, -0.0372, -0.0093], + [0.0393, 0.0944, 0.1131], + [0.0020, 0.0251, 0.0037], + [-0.0017, 0.0012, 0.0234], + [0.0468, 0.0436, 0.0203], + [0.0354, 0.0439, -0.0233], + [0.0090, 0.0123, 0.0346], + [0.0382, 0.0029, 0.0217], + [0.0261, -0.0300, 0.0030], + [-0.0088, -0.0220, -0.0283], + [-0.0272, -0.0121, -0.0363], + [-0.0664, -0.0622, 0.0144], + [0.0414, 0.0479, 0.0529], + [0.0355, 0.0612, -0.0247], + [0.0147, 0.0264, 0.0174], + [0.0438, 0.0038, 0.0542], + [0.0431, -0.0573, -0.0033], + [-0.0162, -0.0211, -0.0406], + [-0.0487, -0.0295, -0.0393], + [0.0005, -0.0109, 0.0253], + [0.0296, 0.0591, 0.0353], + [0.0119, 0.0181, -0.0306], + [-0.0085, -0.0362, 0.0229], + [0.0005, -0.0106, 0.0242] + ] + + latent_rgb_factors_bias = 
[0.0456, -0.0202, -0.0644] + latent_channels = 32 + latent_dimensions = 3 + scale_factor = 1.03682 + taesd_decoder_name = "lighttaehy1_5" + + class Hunyuan3Dv2(LatentFormat): latent_channels = 64 latent_dimensions = 1 scale_factor = 0.9990943042622529 + class Hunyuan3Dv2_1(LatentFormat): scale_factor = 1.0039506158752403 latent_channels = 64 latent_dimensions = 1 + class Hunyuan3Dv2mini(LatentFormat): latent_channels = 64 latent_dimensions = 1 scale_factor = 1.0188137142395404 + class ACEAudio(LatentFormat): latent_channels = 8 latent_dimensions = 2 + class ChromaRadiance(LatentFormat): latent_channels = 3 def __init__(self): self.latent_rgb_factors = [ # R G B - [ 1.0, 0.0, 0.0 ], - [ 0.0, 1.0, 0.0 ], - [ 0.0, 0.0, 1.0 ] + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [0.0, 0.0, 1.0] ] def process_in(self, latent): diff --git a/comfy/ldm/chroma/layers.py b/comfy/ldm/chroma/layers.py index 7b8a810d7..6a1ba6b7d 100644 --- a/comfy/ldm/chroma/layers.py +++ b/comfy/ldm/chroma/layers.py @@ -1,29 +1,33 @@ import torch from torch import Tensor, nn -from ..flux.math import attention -from ..flux.layers import MLPEmbedder, RMSNorm, QKNorm, SelfAttention, ModulationOut +from comfy.ldm.flux.layers import ( + MLPEmbedder, + RMSNorm, + ModulationOut, +) +# TODO: remove this in a few months +SingleStreamBlock = None +DoubleStreamBlock = None class ChromaModulationOut(ModulationOut): @classmethod def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut: return cls( - shift=tensor[:, offset : offset + 1, :], - scale=tensor[:, offset + 1 : offset + 2, :], - gate=tensor[:, offset + 2 : offset + 3, :], + shift=tensor[:, offset: offset + 1, :], + scale=tensor[:, offset + 1: offset + 2, :], + gate=tensor[:, offset + 2: offset + 3, :], ) - - class Approximator(nn.Module): - def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers = 5, dtype=None, device=None, operations=None): + def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=5, dtype=None, device=None, operations=None): super().__init__() self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) - self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) - self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range( n_layers)]) + self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)]) + self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for x in range(n_layers)]) self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device) @property @@ -42,124 +46,6 @@ class Approximator(nn.Module): return x -class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): - super().__init__() - - mlp_hidden_dim = int(hidden_size * mlp_ratio) - self.num_heads = num_heads - self.hidden_size = hidden_size - self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) - - self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, 
device=device) - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) - - self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) - - self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) - self.flipped_img_txt = flipped_img_txt - - def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}): - (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec - - # prepare image for attention - img_modulated = torch.addcmul(img_mod1.shift, 1 + img_mod1.scale, self.img_norm1(img)) - img_qkv = self.img_attn.qkv(img_modulated) - img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) - - # prepare txt for attention - txt_modulated = torch.addcmul(txt_mod1.shift, 1 + txt_mod1.scale, self.txt_norm1(txt)) - txt_qkv = self.txt_attn.qkv(txt_modulated) - txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) - - # run actual attention - attn = attention(torch.cat((txt_q, img_q), dim=2), - torch.cat((txt_k, img_k), dim=2), - torch.cat((txt_v, img_v), dim=2), - pe=pe, mask=attn_mask, transformer_options=transformer_options) - - txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] - - # calculate the img bloks - img.addcmul_(img_mod1.gate, self.img_attn.proj(img_attn)) - img.addcmul_(img_mod2.gate, self.img_mlp(torch.addcmul(img_mod2.shift, 1 + img_mod2.scale, self.img_norm2(img)))) - - # calculate the txt bloks - txt.addcmul_(txt_mod1.gate, self.txt_attn.proj(txt_attn)) - txt.addcmul_(txt_mod2.gate, self.txt_mlp(torch.addcmul(txt_mod2.shift, 1 + txt_mod2.scale, self.txt_norm2(txt)))) - - if txt.dtype == torch.float16: - txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) - - return img, txt - - -class SingleStreamBlock(nn.Module): - """ - A DiT block with parallel linear layers as described in - https://arxiv.org/abs/2302.05442 and adapted modulation interface. 
- """ - - def __init__( - self, - hidden_size: int, - num_heads: int, - mlp_ratio: float = 4.0, - qk_scale: float = None, - dtype=None, - device=None, - operations=None - ): - super().__init__() - self.hidden_dim = hidden_size - self.num_heads = num_heads - head_dim = hidden_size // num_heads - self.scale = qk_scale or head_dim**-0.5 - - self.mlp_hidden_dim = int(hidden_size * mlp_ratio) - # qkv and mlp_in - self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) - # proj and mlp_out - self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) - - self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) - - self.hidden_size = hidden_size - self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - - self.mlp_act = nn.GELU(approximate="tanh") - - def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None, transformer_options={}) -> Tensor: - mod = vec - x_mod = torch.addcmul(mod.shift, 1 + mod.scale, self.pre_norm(x)) - qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) - - q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k = self.norm(q, k, v) - - # compute attention - attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) - # compute activation in mlp stream, cat again and run second linear layer - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) - x.addcmul_(mod.gate, output) - if x.dtype == torch.float16: - x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) - return x - - class LastLayer(nn.Module): def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): super().__init__() diff --git a/comfy/ldm/chroma/model.py b/comfy/ldm/chroma/model.py index e835ad912..fbfb52f34 100644 --- a/comfy/ldm/chroma/model.py +++ b/comfy/ldm/chroma/model.py @@ -8,12 +8,15 @@ from einops import rearrange, repeat from ..common_dit import pad_to_patch_size from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP -from ..flux.layers import EmbedND, timestep_embedding +from ..flux.layers import ( + EmbedND, + timestep_embedding, + DoubleStreamBlock, + SingleStreamBlock, +) from .layers import ( - DoubleStreamBlock, LastLayer, - SingleStreamBlock, Approximator, ChromaModulationOut, ) @@ -37,6 +40,8 @@ class ChromaParams: out_dim: int hidden_dim: int n_layers: int + txt_ids_dims: list + vec_in_dim: int class Chroma(nn.Module): @@ -84,6 +89,7 @@ class Chroma(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=False, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -92,7 +98,7 @@ class Chroma(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=False, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) @@ -173,7 +179,10 @@ class Chroma(nn.Module): pe = self.pe_embedder(ids) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in 
enumerate(self.double_blocks): + transformer_options["block_index"] = i if i not in self.skip_mmdit: double_mod = ( self.get_modulations(mod_vectors, "double_img", idx=i), @@ -216,7 +225,10 @@ class Chroma(nn.Module): img = torch.cat((txt, img), 1) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if i not in self.skip_dit: single_mod = self.get_modulations(mod_vectors, "single", idx=i) if ("single_block", i) in blocks_replace: diff --git a/comfy/ldm/chroma_radiance/model.py b/comfy/ldm/chroma_radiance/model.py index 015aa7629..6553a160c 100644 --- a/comfy/ldm/chroma_radiance/model.py +++ b/comfy/ldm/chroma_radiance/model.py @@ -10,10 +10,10 @@ from torch import Tensor, nn from einops import repeat from ..common_dit import pad_to_patch_size -from ..flux.layers import EmbedND +from ..flux.layers import EmbedND, DoubleStreamBlock, SingleStreamBlock from ..chroma.model import Chroma, ChromaParams -from ..chroma.layers import DoubleStreamBlock, SingleStreamBlock, Approximator +from ..chroma.layers import Approximator from .layers import ( NerfEmbedder, NerfGLUBlock, @@ -94,6 +94,7 @@ class ChromaRadiance(Chroma): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=False, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -106,6 +107,7 @@ class ChromaRadiance(Chroma): self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, + modulation=False, dtype=dtype, device=device, operations=operations, ) for _ in range(params.depth_single_blocks) diff --git a/comfy/ldm/flux/layers.py b/comfy/ldm/flux/layers.py index 14c6affe1..691dee419 100644 --- a/comfy/ldm/flux/layers.py +++ b/comfy/ldm/flux/layers.py @@ -47,15 +47,44 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10 return embedding class MLPEmbedder(nn.Module): - def __init__(self, in_dim: int, hidden_dim: int, dtype=None, device=None, operations=None): + def __init__(self, in_dim: int, hidden_dim: int, bias=True, dtype=None, device=None, operations=None): super().__init__() - self.in_layer = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.in_layer = operations.Linear(in_dim, hidden_dim, bias=bias, dtype=dtype, device=device) self.silu = nn.SiLU() - self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=True, dtype=dtype, device=device) + self.out_layer = operations.Linear(hidden_dim, hidden_dim, bias=bias, dtype=dtype, device=device) def forward(self, x: Tensor) -> Tensor: return self.out_layer(self.silu(self.in_layer(x))) +class YakMLP(nn.Module): + def __init__(self, hidden_size: int, intermediate_size: int, dtype=None, device=None, operations=None): + super().__init__() + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.gate_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device) + self.up_proj = operations.Linear(self.hidden_size, self.intermediate_size, bias=True, dtype=dtype, device=device) + self.down_proj = operations.Linear(self.intermediate_size, self.hidden_size, bias=True, dtype=dtype, device=device) + self.act_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + +def build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=False, yak_mlp=False, dtype=None, 
device=None, operations=None): + if yak_mlp: + return YakMLP(hidden_size, mlp_hidden_dim, dtype=dtype, device=device, operations=operations) + if mlp_silu_act: + return nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim * 2, bias=False, dtype=dtype, device=device), + SiLUActivation(), + operations.Linear(mlp_hidden_dim, hidden_size, bias=False, dtype=dtype, device=device), + ) + else: + return nn.Sequential( + operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), + nn.GELU(approximate="tanh"), + operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), + ) class RMSNorm(torch.nn.Module): def __init__(self, dim: int, dtype=None, device=None, operations=None): @@ -79,14 +108,14 @@ class QKNorm(torch.nn.Module): class SelfAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, dtype=None, device=None, operations=None): + def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False, proj_bias: bool = True, dtype=None, device=None, operations=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) - self.proj = operations.Linear(dim, dim, dtype=dtype, device=device) + self.proj = operations.Linear(dim, dim, bias=proj_bias, dtype=dtype, device=device) @dataclass @@ -97,11 +126,11 @@ class ModulationOut: class Modulation(nn.Module): - def __init__(self, dim: int, double: bool, dtype=None, device=None, operations=None): + def __init__(self, dim: int, double: bool, bias=True, dtype=None, device=None, operations=None): super().__init__() self.is_double = double self.multiplier = 6 if double else 3 - self.lin = operations.Linear(dim, self.multiplier * dim, bias=True, dtype=dtype, device=device) + self.lin = operations.Linear(dim, self.multiplier * dim, bias=bias, dtype=dtype, device=device) def forward(self, vec: Tensor) -> tuple: if vec.ndim == 2: @@ -128,80 +157,110 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None): return tensor +class SiLUActivation(nn.Module): + def __init__(self): + super().__init__() + self.gate_fn = nn.SiLU() + + def forward(self, x: Tensor) -> Tensor: + x1, x2 = x.chunk(2, dim=-1) + return self.gate_fn(x1) * x2 + + class DoubleStreamBlock(nn.Module): - def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, modulation=True, mlp_silu_act=False, proj_bias=True, yak_mlp=False, dtype=None, device=None, operations=None): super().__init__() mlp_hidden_dim = int(hidden_size * mlp_ratio) self.num_heads = num_heads self.hidden_size = hidden_size - self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.modulation = modulation + + if self.modulation: + self.img_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, 
qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.img_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) - self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.img_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) + + if self.modulation: + self.txt_mod = Modulation(hidden_size, double=True, dtype=dtype, device=device, operations=operations) + self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations) + self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias, dtype=dtype, device=device, operations=operations) self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.txt_mlp = nn.Sequential( - operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device), - nn.GELU(approximate="tanh"), - operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device), - ) + + self.txt_mlp = build_mlp(hidden_size, mlp_hidden_dim, mlp_silu_act=mlp_silu_act, yak_mlp=yak_mlp, dtype=dtype, device=device, operations=operations) + self.flipped_img_txt = flipped_img_txt def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None, transformer_options={}): - img_mod1, img_mod2 = self.img_mod(vec) - txt_mod1, txt_mod2 = self.txt_mod(vec) + if self.modulation: + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + else: + (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec # prepare image for attention img_modulated = self.img_norm1(img) img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img) img_qkv = self.img_attn.qkv(img_modulated) - img_qkv = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - img_q, img_k, img_v = torch.unbind(img_qkv, dim=0) + del img_modulated + img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + del img_qkv img_q, img_k = self.img_attn.norm(img_q, img_k, img_v) # prepare txt for attention txt_modulated = self.txt_norm1(txt) txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt) txt_qkv = self.txt_attn.qkv(txt_modulated) - txt_qkv = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - txt_q, txt_k, txt_v = torch.unbind(txt_qkv, dim=0) + del txt_modulated + txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + del txt_qkv txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v) if self.flipped_img_txt: + q = torch.cat((img_q, txt_q), dim=2) + del img_q, txt_q + k = torch.cat((img_k, txt_k), dim=2) + del img_k, txt_k + v = torch.cat((img_v, txt_v), dim=2) + del 
img_v, txt_v # run actual attention - attn = attention(torch.cat((img_q, txt_q), dim=2), - torch.cat((img_k, txt_k), dim=2), - torch.cat((img_v, txt_v), dim=2), + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1]:] else: + q = torch.cat((txt_q, img_q), dim=2) + del txt_q, img_q + k = torch.cat((txt_k, img_k), dim=2) + del txt_k, img_k + v = torch.cat((txt_v, img_v), dim=2) + del txt_v, img_v # run actual attention - attn = attention(torch.cat((txt_q, img_q), dim=2), - torch.cat((txt_k, img_k), dim=2), - torch.cat((txt_v, img_v), dim=2), + attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:] - # calculate the img bloks - img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) - img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) + # calculate the img blocks + # todo: do we have to re-investigate this += versus img = img + ... op? + img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img) + del img_attn + img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img) - # calculate the txt bloks - txt = txt + apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt) - txt = txt + apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt) + # calculate the txt blocks + # todo: do we have to re-investigate this += versus txt = txt + ... op? 
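+        # note: the in-place += avoids allocating an extra full-size tensor for each residual add; it assumes img/txt are not aliased or saved for autograd (inference-only path), otherwise the out-of-place img = img + ... / txt = txt + ... form is required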
+ txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt) + del txt_attn + txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt) if txt.dtype == torch.float16: txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504) @@ -221,6 +280,10 @@ class SingleStreamBlock(nn.Module): num_heads: int, mlp_ratio: float = 4.0, qk_scale: float = None, + modulation=True, + mlp_silu_act=False, + bias=True, + yak_mlp=False, dtype=None, device=None, operations=None @@ -232,31 +295,57 @@ class SingleStreamBlock(nn.Module): self.scale = qk_scale or head_dim**-0.5 self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.mlp_hidden_dim_first = self.mlp_hidden_dim + self.yak_mlp = yak_mlp + if mlp_silu_act: + self.mlp_hidden_dim_first = int(hidden_size * mlp_ratio * 2) + self.mlp_act = SiLUActivation() + else: + self.mlp_act = nn.GELU(approximate="tanh") + + if self.yak_mlp: + self.mlp_hidden_dim_first *= 2 + self.mlp_act = nn.SiLU() + # qkv and mlp_in - self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device) + self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim_first, bias=bias, dtype=dtype, device=device) # proj and mlp_out - self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device) + self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, bias=bias, dtype=dtype, device=device) self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations) self.hidden_size = hidden_size self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.mlp_act = nn.GELU(approximate="tanh") - self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) + if modulation: + self.modulation = Modulation(hidden_size, double=False, dtype=dtype, device=device, operations=operations) + else: + self.modulation = None def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None, transformer_options={}) -> Tensor: - mod, _ = self.modulation(vec) - qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1) + if self.modulation: + mod, _ = self.modulation(vec) + else: + mod = vec + + qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1) qkv = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = torch.unbind(qkv, dim=0) + del qkv + q, k = self.norm(q, k, v) # compute attention attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) + del q, k, v # compute activation in mlp stream, cat again and run second linear layer - output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) + if self.yak_mlp: + mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2] + else: + mlp = self.mlp_act(mlp) + output = self.linear2(torch.cat((attn, mlp), 2)) x = x + apply_mod(output, mod.gate, None, modulation_dims) if x.dtype == torch.float16: x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) @@ -264,11 +353,11 @@ class SingleStreamBlock(nn.Module): 
class LastLayer(nn.Module): - def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int, bias=True, dtype=None, device=None, operations=None): super().__init__() self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) - self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) - self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)) + self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=bias, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential(nn.SiLU(), operations.Linear(hidden_size, 2 * hidden_size, bias=bias, dtype=dtype, device=device)) def forward(self, x: Tensor, vec: Tensor, modulation_dims=None) -> Tensor: if vec.ndim == 2: diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py index 562c2c113..9a9784fd7 100644 --- a/comfy/ldm/flux/math.py +++ b/comfy/ldm/flux/math.py @@ -7,15 +7,8 @@ from ... import model_management def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor: - q_shape = q.shape - k_shape = k.shape - if pe is not None: - q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2) - k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2) - q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v) - k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v) - + q, k = apply_rope(q, k, pe) heads = q.shape[1] x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options) return x diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index eb7dff4dd..876f22818 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -13,6 +13,8 @@ from .layers import ( MLPEmbedder, SingleStreamBlock, timestep_embedding, + Modulation, + RMSNorm ) from .. 
import common_dit @@ -33,6 +35,14 @@ class FluxParams: patch_size: int qkv_bias: bool guidance_embed: bool + txt_ids_dims: list + global_modulation: bool = False + mlp_silu_act: bool = False + ops_bias: bool = True + default_ref_method: str = "offset" + ref_index_scale: float = 1.0 + yak_mlp: bool = False + txt_norm: bool = False class Flux(nn.Module): @@ -60,13 +70,22 @@ class Flux(nn.Module): self.hidden_size = params.hidden_size self.num_heads = params.num_heads self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim) - self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device) - self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) - self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + self.img_in = operations.Linear(self.in_channels, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + if params.vec_in_dim is not None: + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size, dtype=dtype, device=device, operations=operations) + else: + self.vector_in = None + self.guidance_in = ( - MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) if params.guidance_embed else nn.Identity() ) - self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, dtype=dtype, device=device) + self.txt_in = operations.Linear(params.context_in_dim, self.hidden_size, bias=params.ops_bias, dtype=dtype, device=device) + + if params.txt_norm: + self.txt_norm = RMSNorm(params.context_in_dim, dtype=dtype, device=device, operations=operations) + else: + self.txt_norm = None self.double_blocks = nn.ModuleList( [ @@ -75,6 +94,10 @@ class Flux(nn.Module): self.num_heads, mlp_ratio=params.mlp_ratio, qkv_bias=params.qkv_bias, + modulation=params.global_modulation is False, + mlp_silu_act=params.mlp_silu_act, + proj_bias=params.ops_bias, + yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations ) for _ in range(params.depth) @@ -83,13 +106,30 @@ class Flux(nn.Module): self.single_blocks = nn.ModuleList( [ - SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, dtype=dtype, device=device, operations=operations) + SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio, modulation=params.global_modulation is False, mlp_silu_act=params.mlp_silu_act, bias=params.ops_bias, yak_mlp=params.yak_mlp, dtype=dtype, device=device, operations=operations) for _ in range(params.depth_single_blocks) ] ) if final_layer: - self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, dtype=dtype, device=device, operations=operations) + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels, bias=params.ops_bias, dtype=dtype, device=device, operations=operations) + + if params.global_modulation: + self.double_stream_modulation_img = Modulation( + self.hidden_size, + double=True, + bias=False, + dtype=dtype, device=device, operations=operations + ) + self.double_stream_modulation_txt = Modulation( + self.hidden_size, + double=True, + bias=False, + 
dtype=dtype, device=device, operations=operations + ) + self.single_stream_modulation = Modulation( + self.hidden_size, double=False, bias=False, dtype=dtype, device=device, operations=operations + ) def forward_orig( self, @@ -122,9 +162,19 @@ class Flux(nn.Module): if guidance is not None: vec = vec + self.guidance_in(timestep_embedding(guidance, 256).to(img.dtype)) - vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + if self.vector_in is not None: + if y is None: + y = torch.zeros((img.shape[0], self.params.vec_in_dim), device=img.device, dtype=img.dtype) + vec = vec + self.vector_in(y[:, :self.params.vec_in_dim]) + + if self.txt_norm is not None: + txt = self.txt_norm(txt) txt = self.txt_in(txt) + vec_orig = vec + if self.params.global_modulation: + vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig)) + if "post_input" in patches: for p in patches["post_input"]: out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) @@ -140,7 +190,10 @@ class Flux(nn.Module): pe = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap_1(args): out = {} @@ -181,7 +234,13 @@ class Flux(nn.Module): img = torch.cat((txt, img), 1) + if self.params.global_modulation: + vec, _ = self.single_stream_modulation(vec_orig) + + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap_2(args): out = {} @@ -211,10 +270,10 @@ class Flux(nn.Module): img = img[:, txt.shape[1]:, ...] 
- img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels) + img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) return img - def process_img(self, x, index=0, h_offset=0, w_offset=0): + def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): bs, c, h, w = x.shape patch_size = self.patch_size x = common_dit.pad_to_patch_size(x, (patch_size, patch_size)) @@ -226,10 +285,22 @@ class Flux(nn.Module): h_offset = ((h_offset + (patch_size // 2)) // patch_size) w_offset = ((w_offset + (patch_size // 2)) // patch_size) - img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype) + steps_h = h_len + steps_w = w_len + + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + index += rope_options.get("shift_t", 0.0) + h_offset += rope_options.get("shift_y", 0.0) + w_offset += rope_options.get("shift_x", 0.0) + + img_ids = torch.zeros((steps_h, steps_w, len(self.params.axes_dim)), device=x.device, dtype=torch.float32) img_ids[:, :, 0] = img_ids[:, :, 1] + index - img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) + img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=steps_h, device=x.device, dtype=torch.float32).unsqueeze(1) + img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=steps_w, device=x.device, dtype=torch.float32).unsqueeze(0) return img, repeat(img_ids, "h w c -> b (h w) c", b=bs) def forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None, control=None, transformer_options=None, **kwargs): @@ -249,16 +320,16 @@ class Flux(nn.Module): h_len = ((h_orig + (patch_size // 2)) // patch_size) w_len = ((w_orig + (patch_size // 2)) // patch_size) - img, img_ids = self.process_img(x) + img, img_ids = self.process_img(x, transformer_options=transformer_options) img_tokens = img.shape[1] if ref_latents is not None: h = 0 w = 0 index = 0 - ref_latents_method = kwargs.get("ref_latents_method", "offset") + ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) for ref in ref_latents: if ref_latents_method == "index": - index += 1 + index += self.params.ref_index_scale h_offset = 0 w_offset = 0 elif ref_latents_method == "uxo": @@ -282,7 +353,12 @@ class Flux(nn.Module): img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) + txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) + + if len(self.params.txt_ids_dims) > 0: + for i in self.params.txt_ids_dims: + txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) + out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = out[:, :img_tokens] - return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:, :, :h_orig, :w_orig] + return rearrange(out, "b (h w) (c ph pw) -> b c 
(h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:, :, :h_orig, :w_orig] diff --git a/comfy/ldm/hunyuan_video/model.py b/comfy/ldm/hunyuan_video/model.py index b24ce6987..17f62db66 100644 --- a/comfy/ldm/hunyuan_video/model.py +++ b/comfy/ldm/hunyuan_video/model.py @@ -33,6 +33,8 @@ class HunyuanVideoParams: guidance_embed: bool byt5: bool meanflow: bool + use_cond_type_embedding: bool + vision_in_dim: int class SelfAttentionRef(nn.Module): @@ -153,7 +155,10 @@ class TokenRefiner(nn.Module): t = self.t_embedder(timestep_embedding(timesteps, 256, time_factor=1.0).to(x.dtype)) # m = mask.float().unsqueeze(-1) # c = (x.float() * m).sum(dim=1) / m.sum(dim=1) #TODO: the following works when the x.shape is the same length as the tokens but might break otherwise - c = x.sum(dim=1) / x.shape[1] + if x.dtype == torch.float16: + c = x.float().sum(dim=1) / x.shape[1] + else: + c = x.sum(dim=1) / x.shape[1] c = t + self.c_embedder(c.to(x.dtype)) x = self.input_embedder(x) @@ -193,11 +198,15 @@ class HunyuanVideo(nn.Module): def __init__(self, image_model=None, final_layer=True, dtype=None, device=None, operations=None, **kwargs): super().__init__() self.dtype = dtype + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + params = HunyuanVideoParams(**kwargs) self.params = params self.patch_size = params.patch_size self.in_channels = params.in_channels self.out_channels = params.out_channels + self.use_cond_type_embedding = params.use_cond_type_embedding + self.vision_in_dim = params.vision_in_dim if params.hidden_size % params.num_heads != 0: raise ValueError( f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" @@ -263,6 +272,18 @@ class HunyuanVideo(nn.Module): if final_layer: self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations) + # HunyuanVideo 1.5 specific modules + if self.vision_in_dim is not None: + from ..wan.model import MLPProj + self.vision_in = MLPProj(in_dim=self.vision_in_dim, out_dim=self.hidden_size, operation_settings=operation_settings) + else: + self.vision_in = None + if self.use_cond_type_embedding: + # 0: text_encoder feature 1: byt5 feature 2: vision_encoder feature + self.cond_type_embedding = nn.Embedding(3, self.hidden_size) + else: + self.cond_type_embedding = None + def forward_orig( self, img: Tensor, @@ -273,7 +294,7 @@ class HunyuanVideo(nn.Module): timesteps: Tensor, y: Tensor = None, txt_byt5=None, - guidance: Tensor = None, + clip_fea=None,guidance: Tensor = None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, @@ -330,12 +351,31 @@ class HunyuanVideo(nn.Module): txt = self.txt_in(txt, timesteps, txt_mask, transformer_options=transformer_options) + if self.cond_type_embedding is not None: + self.cond_type_embedding.to(txt.device) + cond_emb = self.cond_type_embedding(torch.zeros_like(txt[:, :, 0], device=txt.device, dtype=torch.long)) + txt = txt + cond_emb.to(txt.dtype) + if self.byt5_in is not None and txt_byt5 is not None: txt_byt5 = self.byt5_in(txt_byt5) + if self.cond_type_embedding is not None: + cond_emb = self.cond_type_embedding(torch.ones_like(txt_byt5[:, :, 0], device=txt_byt5.device, dtype=torch.long)) + txt_byt5 = txt_byt5 + cond_emb.to(txt_byt5.dtype) + txt = torch.cat((txt_byt5, txt), dim=1) # byt5 first for HunyuanVideo1.5 + else: + txt = torch.cat((txt, txt_byt5), dim=1) txt_byt5_ids = torch.zeros((txt_ids.shape[0], txt_byt5.shape[1], txt_ids.shape[-1]), 
device=txt_ids.device, dtype=txt_ids.dtype) - txt = torch.cat((txt, txt_byt5), dim=1) txt_ids = torch.cat((txt_ids, txt_byt5_ids), dim=1) + if clip_fea is not None: + txt_vision_states = self.vision_in(clip_fea) + if self.cond_type_embedding is not None: + cond_emb = self.cond_type_embedding(2 * torch.ones_like(txt_vision_states[:, :, 0], dtype=torch.long, device=txt_vision_states.device)) + txt_vision_states = txt_vision_states + cond_emb + txt = torch.cat((txt_vision_states.to(txt.dtype), txt), dim=1) + extra_txt_ids = torch.zeros((txt_ids.shape[0], txt_vision_states.shape[1], txt_ids.shape[-1]), device=txt_ids.device, dtype=txt_ids.dtype) + txt_ids = torch.cat((txt_ids, extra_txt_ids), dim=1) + ids = torch.cat((img_ids, txt_ids), dim=1) pe = self.pe_embedder(ids) @@ -348,7 +388,10 @@ class HunyuanVideo(nn.Module): attn_mask = None blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.double_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.double_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap_2(args): out = {} @@ -370,7 +413,10 @@ class HunyuanVideo(nn.Module): img = torch.cat((img, txt), 1) + transformer_options["total_blocks"] = len(self.single_blocks) + transformer_options["block_type"] = "single" for i, block in enumerate(self.single_blocks): + transformer_options["block_index"] = i if ("single_block", i) in blocks_replace: def block_wrap(args): out = {} @@ -429,16 +475,16 @@ class HunyuanVideo(nn.Module): img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) return repeat(img_ids, "h w c -> b (h w) c", b=bs) - def forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options=None, **kwargs): + def forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options=None, **kwargs): if transformer_options is None: transformer_options = {} return WrapperExecutor.new_class_executor( self._forward, self, get_all_wrappers(WrappersMP.DIFFUSION_MODEL, transformer_options) - ).execute(x, timestep, context, y, txt_byt5, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs) + ).execute(x, timestep, context, y, txt_byt5, clip_fea, guidance, attention_mask, guiding_frame_index, ref_latent, disable_time_r, control, transformer_options, **kwargs) - def _forward(self, x, timestep, context, y=None, txt_byt5=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options=None, **kwargs): + def _forward(self, x, timestep, context, y=None, txt_byt5=None, clip_fea=None, guidance=None, attention_mask=None, guiding_frame_index=None, ref_latent=None, disable_time_r=False, control=None, transformer_options=None, **kwargs): if transformer_options is None: transformer_options = {} bs = x.shape[0] @@ -448,5 +494,5 @@ class HunyuanVideo(nn.Module): else: img_ids = self.img_ids_2d(x) txt_ids = torch.zeros((bs, context.shape[1], 2), device=x.device, dtype=x.dtype) - out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, guidance, guiding_frame_index, ref_latent, 
disable_time_r=disable_time_r, control=control, transformer_options=transformer_options) + out = self.forward_orig(x, img_ids, context, txt_ids, attention_mask, timestep, y, txt_byt5, clip_fea, guidance, guiding_frame_index, ref_latent, disable_time_r=disable_time_r, control=control, transformer_options=transformer_options) return out diff --git a/comfy/ldm/hunyuan_video/upsampler.py b/comfy/ldm/hunyuan_video/upsampler.py new file mode 100644 index 000000000..85f515f67 --- /dev/null +++ b/comfy/ldm/hunyuan_video/upsampler.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d +from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm +from comfy import model_management, model_patcher + +class SRResidualCausalBlock3D(nn.Module): + def __init__(self, channels: int): + super().__init__() + self.block = nn.Sequential( + VideoConv3d(channels, channels, kernel_size=3), + nn.SiLU(inplace=True), + VideoConv3d(channels, channels, kernel_size=3), + nn.SiLU(inplace=True), + VideoConv3d(channels, channels, kernel_size=3), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.block(x) + +class SRModel3DV2(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + hidden_channels: int = 64, + num_blocks: int = 6, + global_residual: bool = False, + ): + super().__init__() + self.in_conv = VideoConv3d(in_channels, hidden_channels, kernel_size=3) + self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)]) + self.out_conv = VideoConv3d(hidden_channels, out_channels, kernel_size=3) + self.global_residual = bool(global_residual) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + y = self.in_conv(x) + for blk in self.blocks: + y = blk(y) + y = self.out_conv(y) + if self.global_residual and (y.shape == residual.shape): + y = y + residual + return y + + +class Upsampler(nn.Module): + def __init__( + self, + z_channels: int, + out_channels: int, + block_out_channels: tuple[int, ...], + num_res_blocks: int = 2, + ): + super().__init__() + self.num_res_blocks = num_res_blocks + self.block_out_channels = block_out_channels + self.z_channels = z_channels + + ch = block_out_channels[0] + self.conv_in = VideoConv3d(z_channels, ch, kernel_size=3) + + self.up = nn.ModuleList() + + for i, tgt in enumerate(block_out_channels): + stage = nn.Module() + stage.block = nn.ModuleList([ResnetBlock(in_channels=ch if j == 0 else tgt, + out_channels=tgt, + temb_channels=0, + conv_shortcut=False, + conv_op=VideoConv3d, norm_op=RMS_norm) + for j in range(num_res_blocks + 1)]) + ch = tgt + self.up.append(stage) + + self.norm_out = RMS_norm(ch) + self.conv_out = VideoConv3d(ch, out_channels, kernel_size=3) + + def forward(self, z): + """ + Args: + z: (B, C, T, H, W) + """ + # z to block_in + repeats = self.block_out_channels[0] // (self.z_channels) + x = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1) + + # upsampling + for stage in self.up: + for blk in stage.block: + x = blk(x) + + out = self.conv_out(F.silu(self.norm_out(x))) + return out + +UPSAMPLERS = { + "720p": SRModel3DV2, + "1080p": Upsampler, +} + +class HunyuanVideo15SRModel(): + def __init__(self, model_type, config): + self.load_device = model_management.vae_device() + offload_device = model_management.vae_offload_device() + self.dtype = model_management.vae_dtype(self.load_device) + self.model_class = UPSAMPLERS.get(model_type) +
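+        # note: UPSAMPLERS.get() returns None for an unrecognized model_type ("720p" and "1080p" are the supported keys), which would fail on the call below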
self.model = self.model_class(**config).eval() + + self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) + + def load_sd(self, sd): + return self.model.load_state_dict(sd, strict=True) + + def get_sd(self): + return self.model.state_dict() + + def resample_latent(self, latent): + model_management.load_model_gpu(self.patcher) + return self.model(latent.to(self.load_device)) diff --git a/comfy/ldm/hunyuan_video/vae_refiner.py b/comfy/ldm/hunyuan_video/vae_refiner.py index 1fda9f2b9..c22cd15c9 100644 --- a/comfy/ldm/hunyuan_video/vae_refiner.py +++ b/comfy/ldm/hunyuan_video/vae_refiner.py @@ -1,8 +1,11 @@ import torch import torch.nn as nn import torch.nn.functional as F -from ..modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d, Normalize + from ..models.autoencoder import DiagonalGaussianRegularizer +from ..modules.diffusionmodules.model import ResnetBlock, AttnBlock, CarriedConv3d, Normalize, conv_carry_causal_3d, \ + torch_cat_if_needed +from ...model_management import cast_to from ...ops import disable_weight_init as ops @@ -14,11 +17,11 @@ class RMS_norm(nn.Module): self.gamma = nn.Parameter(torch.empty(shape)) def forward(self, x): - return F.normalize(x, dim=1) * self.scale * self.gamma + return F.normalize(x, dim=1) * self.scale * cast_to(self.gamma, dtype=x.dtype, device=x.device) class DnSmpl(nn.Module): - def __init__(self, ic, oc, tds=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tds, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tds else 1 * 2 * 2 assert oc % fct == 0 @@ -28,11 +31,11 @@ class DnSmpl(nn.Module): self.tds = tds self.gs = fct * ic // oc - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): r1 = 2 if self.tds else 1 - h = self.conv(x) + h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) - if self.tds and self.refiner_vae: + if self.tds and self.refiner_vae and conv_carry_in is None: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape hf = hf.reshape(b, c, f, ht // 2, 2, wd // 2, 2) @@ -40,14 +43,7 @@ class DnSmpl(nn.Module): hf = hf.reshape(b, 2 * 2 * c, f, ht // 2, wd // 2) hf = torch.cat([hf, hf], dim=1) - hn = h[:, :, 1:, :, :] - b, c, frms, ht, wd = hn.shape - nf = frms // r1 - hn = hn.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) - hn = hn.permute(0, 3, 5, 7, 1, 2, 4, 6) - hn = hn.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) - - h = torch.cat([hf, hn], dim=2) + h = h[:, :, 1:, :, :] xf = x[:, :, :1, :, :] b, ci, f, ht, wd = xf.shape @@ -55,38 +51,36 @@ class DnSmpl(nn.Module): xf = xf.permute(0, 4, 6, 1, 2, 3, 5) xf = xf.reshape(b, 2 * 2 * ci, f, ht // 2, wd // 2) B, C, T, H, W = xf.shape - xf = xf.view(B, h.shape[1], self.gs // 2, T, H, W).mean(dim=2) + xf = xf.view(B, hf.shape[1], self.gs // 2, T, H, W).mean(dim=2) - xn = x[:, :, 1:, :, :] - b, ci, frms, ht, wd = xn.shape - nf = frms // r1 - xn = xn.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) - xn = xn.permute(0, 3, 5, 7, 1, 2, 4, 6) - xn = xn.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) - B, C, T, H, W = xn.shape - xn = xn.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) - sc = torch.cat([xf, xn], dim=2) - else: - b, c, frms, ht, wd = h.shape + x = x[:, :, 1:, :, :] - nf = frms // r1 - h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) - h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) - h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) + if h.shape[2] == 0: + return hf + xf - b, ci, frms, ht, wd = x.shape - nf = frms // r1 - sc = 
x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) - sc = sc.permute(0, 3, 5, 7, 1, 2, 4, 6) - sc = sc.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) - B, C, T, H, W = sc.shape - sc = sc.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + b, c, frms, ht, wd = h.shape + nf = frms // r1 + h = h.reshape(b, c, nf, r1, ht // 2, 2, wd // 2, 2) + h = h.permute(0, 3, 5, 7, 1, 2, 4, 6) + h = h.reshape(b, r1 * 2 * 2 * c, nf, ht // 2, wd // 2) - return h + sc + b, ci, frms, ht, wd = x.shape + nf = frms // r1 + x = x.reshape(b, ci, nf, r1, ht // 2, 2, wd // 2, 2) + x = x.permute(0, 3, 5, 7, 1, 2, 4, 6) + x = x.reshape(b, r1 * 2 * 2 * ci, nf, ht // 2, wd // 2) + B, C, T, H, W = x.shape + x = x.view(B, h.shape[1], self.gs, T, H, W).mean(dim=2) + + if self.tds and self.refiner_vae and conv_carry_in is None: + h = torch.cat([hf, h], dim=2) + x = torch.cat([xf, x], dim=2) + + return h + x class UpSmpl(nn.Module): - def __init__(self, ic, oc, tus=True, refiner_vae=True, op=VideoConv3d): + def __init__(self, ic, oc, tus, refiner_vae, op): super().__init__() fct = 2 * 2 * 2 if tus else 1 * 2 * 2 self.conv = op(ic, oc * fct, kernel_size=3, stride=1, padding=1) @@ -95,11 +89,11 @@ class UpSmpl(nn.Module): self.tus = tus self.rp = fct * oc // ic - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): r1 = 2 if self.tus else 1 - h = self.conv(x) + h = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) - if self.tus and self.refiner_vae: + if self.tus and self.refiner_vae and conv_carry_in is None: hf = h[:, :, :1, :, :] b, c, f, ht, wd = hf.shape nc = c // (2 * 2) @@ -108,14 +102,7 @@ class UpSmpl(nn.Module): hf = hf.reshape(b, nc, f, ht * 2, wd * 2) hf = hf[:, : hf.shape[1] // 2] - hn = h[:, :, 1:, :, :] - b, c, frms, ht, wd = hn.shape - nc = c // (r1 * 2 * 2) - hn = hn.reshape(b, r1, 2, 2, nc, frms, ht, wd) - hn = hn.permute(0, 4, 5, 1, 6, 2, 7, 3) - hn = hn.reshape(b, nc, frms * r1, ht * 2, wd * 2) - - h = torch.cat([hf, hn], dim=2) + h = h[:, :, 1:, :, :] xf = x[:, :, :1, :, :] b, ci, f, ht, wd = xf.shape @@ -126,29 +113,26 @@ class UpSmpl(nn.Module): xf = xf.permute(0, 3, 4, 5, 1, 6, 2) xf = xf.reshape(b, nc, f, ht * 2, wd * 2) - xn = x[:, :, 1:, :, :] - xn = xn.repeat_interleave(repeats=self.rp, dim=1) - b, c, frms, ht, wd = xn.shape - nc = c // (r1 * 2 * 2) - xn = xn.reshape(b, r1, 2, 2, nc, frms, ht, wd) - xn = xn.permute(0, 4, 5, 1, 6, 2, 7, 3) - xn = xn.reshape(b, nc, frms * r1, ht * 2, wd * 2) - sc = torch.cat([xf, xn], dim=2) - else: - b, c, frms, ht, wd = h.shape - nc = c // (r1 * 2 * 2) - h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) - h = h.permute(0, 4, 5, 1, 6, 2, 7, 3) - h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2) + x = x[:, :, 1:, :, :] - sc = x.repeat_interleave(repeats=self.rp, dim=1) - b, c, frms, ht, wd = sc.shape - nc = c // (r1 * 2 * 2) - sc = sc.reshape(b, r1, 2, 2, nc, frms, ht, wd) - sc = sc.permute(0, 4, 5, 1, 6, 2, 7, 3) - sc = sc.reshape(b, nc, frms * r1, ht * 2, wd * 2) + b, c, frms, ht, wd = h.shape + nc = c // (r1 * 2 * 2) + h = h.reshape(b, r1, 2, 2, nc, frms, ht, wd) + h = h.permute(0, 4, 5, 1, 6, 2, 7, 3) + h = h.reshape(b, nc, frms * r1, ht * 2, wd * 2) - return h + sc + x = x.repeat_interleave(repeats=self.rp, dim=1) + b, c, frms, ht, wd = x.shape + nc = c // (r1 * 2 * 2) + x = x.reshape(b, r1, 2, 2, nc, frms, ht, wd) + x = x.permute(0, 4, 5, 1, 6, 2, 7, 3) + x = x.reshape(b, nc, frms * r1, ht * 2, wd * 2) + + if self.tus and self.refiner_vae and conv_carry_in is None: + h = torch.cat([hf, h], dim=2) + x = torch.cat([xf, 
x], dim=2) + + return h + x class Encoder(nn.Module): @@ -162,7 +146,7 @@ class Encoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = VideoConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -190,9 +174,9 @@ class Encoder(nn.Module): self.down.append(stage) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.norm_out = norm_op(ch) self.conv_out = conv_op(ch, z_channels << 1, 3, 1, 1) @@ -203,31 +187,48 @@ class Encoder(nn.Module): if not self.refiner_vae and x.shape[2] == 1: x = x.expand(-1, -1, self.ffactor_temporal, -1, -1) - x = self.conv_in(x) + if self.refiner_vae: + xl = [x[:, :, :1, :, :]] + if x.shape[2] > self.ffactor_temporal: + xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // self.ffactor_temporal) * self.ffactor_temporal, :, :], self.ffactor_temporal * 2, dim=2) + x = xl + else: + x = [x] + out = [] - for stage in self.down: - for blk in stage.block: - x = blk(x) - if hasattr(stage, 'downsample'): - x = stage.downsample(x) + conv_carry_in = None - x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + + x1 = [x1] + x1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) + + for stage in self.down: + for blk in stage.block: + x1 = blk(x1, None, conv_carry_in, conv_carry_out) + if hasattr(stage, 'downsample'): + x1 = stage.downsample(x1, conv_carry_in, conv_carry_out) + + out.append(x1) + conv_carry_in = conv_carry_out + + out = torch_cat_if_needed(out, dim=2) + + x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(out))) + del out b, c, t, h, w = x.shape grp = c // (self.z_channels << 1) skip = x.view(b, c // grp, grp, t, h, w).mean(2) - out = self.conv_out(F.silu(self.norm_out(x))) + skip + out = conv_carry_causal_3d([F.silu(self.norm_out(x))], self.conv_out) + skip if self.refiner_vae: out = self.regul(out)[0] - out = torch.cat((out[:, :, :1], out), dim=2) - out = out.permute(0, 2, 1, 3, 4) - b, f_times_2, c, h, w = out.shape - out = out.reshape(b, f_times_2 // 2, 2 * c, h, w) - out = out.permute(0, 2, 1, 3, 4).contiguous() - return out @@ -242,7 +243,7 @@ class Decoder(nn.Module): self.refiner_vae = refiner_vae if self.refiner_vae: - conv_op = VideoConv3d + conv_op = CarriedConv3d norm_op = RMS_norm else: conv_op = ops.Conv3d @@ -252,9 +253,9 @@ class Decoder(nn.Module): self.conv_in = conv_op(z_channels, ch, kernel_size=3, stride=1, padding=1) self.mid = nn.Module() - self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_1 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.mid.attn_1 = AttnBlock(ch, conv_op=ops.Conv3d, norm_op=norm_op) - self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, temb_channels=0, conv_op=conv_op, norm_op=norm_op) + self.mid.block_2 = ResnetBlock(in_channels=ch, out_channels=ch, conv_op=conv_op, norm_op=norm_op) self.up = nn.ModuleList() depth = (ffactor_spatial >> 
1).bit_length() @@ -278,24 +279,34 @@ class Decoder(nn.Module): self.conv_out = conv_op(ch, out_channels, 3, stride=1, padding=1) def forward(self, z): - if self.refiner_vae: - z = z.permute(0, 2, 1, 3, 4) - b, f, c, h, w = z.shape - z = z.reshape(b, f, 2, c // 2, h, w) - z = z.permute(0, 1, 2, 3, 4, 5).reshape(b, f * 2, c // 2, h, w) - z = z.permute(0, 2, 1, 3, 4) - z = z[:, :, 1:] - - x = self.conv_in(z) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) + x = conv_carry_causal_3d([z], self.conv_in) + z.repeat_interleave(self.block_out_channels[0] // self.z_channels, 1) x = self.mid.block_2(self.mid.attn_1(self.mid.block_1(x))) - for stage in self.up: - for blk in stage.block: - x = blk(x) - if hasattr(stage, 'upsample'): - x = stage.upsample(x) + if self.refiner_vae: + x = torch.split(x, 2, dim=2) + else: + x = [x] + out = [] - out = self.conv_out(F.silu(self.norm_out(x))) + conv_carry_in = None + + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + for stage in self.up: + for blk in stage.block: + x1 = blk(x1, None, conv_carry_in, conv_carry_out) + if hasattr(stage, 'upsample'): + x1 = stage.upsample(x1, conv_carry_in, conv_carry_out) + + x1 = [F.silu(self.norm_out(x1))] + x1 = conv_carry_causal_3d(x1, self.conv_out, conv_carry_in, conv_carry_out) + out.append(x1) + conv_carry_in = conv_carry_out + del x + + out = torch_cat_if_needed(out, dim=2) if not self.refiner_vae: if z.shape[-3] == 1: diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py new file mode 100644 index 000000000..1509de2f8 --- /dev/null +++ b/comfy/ldm/kandinsky5/model.py @@ -0,0 +1,413 @@ +import torch +from torch import nn +import math + +import comfy.ldm.common_dit +from comfy.ldm.modules.attention import optimized_attention +from comfy.ldm.flux.math import apply_rope1 +from comfy.ldm.flux.layers import EmbedND + +def attention(q, k, v, heads, transformer_options={}): + return optimized_attention( + q.transpose(1, 2), + k.transpose(1, 2), + v.transpose(1, 2), + heads=heads, + skip_reshape=True, + transformer_options=transformer_options + ) + +def apply_scale_shift_norm(norm, x, scale, shift): + return torch.addcmul(shift, norm(x), scale + 1.0) + +def apply_gate_sum(x, out, gate): + return torch.addcmul(x, gate, out) + +def get_shift_scale_gate(params): + shift, scale, gate = torch.chunk(params, 3, dim=-1) + return tuple(x.unsqueeze(1) for x in (shift, scale, gate)) + +def get_freqs(dim, max_period=10000.0): + return torch.exp(-math.log(max_period) * torch.arange(start=0, end=dim, dtype=torch.float32) / dim) + + +class TimeEmbeddings(nn.Module): + def __init__(self, model_dim, time_dim, max_period=10000.0, operation_settings=None): + super().__init__() + assert model_dim % 2 == 0 + self.model_dim = model_dim + self.max_period = max_period + self.register_buffer("freqs", get_freqs(model_dim // 2, max_period), persistent=False) + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(model_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.activation = nn.SiLU() + self.out_layer = operations.Linear(time_dim, time_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, timestep, dtype): + args = torch.outer(timestep, self.freqs.to(device=timestep.device)) + time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1).to(dtype) + time_embed = 
self.out_layer(self.activation(self.in_layer(time_embed))) + return time_embed + + +class TextEmbeddings(nn.Module): + def __init__(self, text_dim, model_dim, operation_settings=None): + super().__init__() + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(text_dim, model_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.norm = operations.LayerNorm(model_dim, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, text_embed): + text_embed = self.in_layer(text_embed) + return self.norm(text_embed).type_as(text_embed) + + +class VisualEmbeddings(nn.Module): + def __init__(self, visual_dim, model_dim, patch_size, operation_settings=None): + super().__init__() + self.patch_size = patch_size + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(visual_dim, model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + x = x.movedim(1, -1) # B C T H W -> B T H W C + B, T, H, W, dim = x.shape + pt, ph, pw = self.patch_size + + x = x.view( + B, + T // pt, pt, + H // ph, ph, + W // pw, pw, + dim, + ).permute(0, 1, 3, 5, 2, 4, 6, 7).flatten(4, 7) + + return self.in_layer(x) + + +class Modulation(nn.Module): + def __init__(self, time_dim, model_dim, num_params, operation_settings=None): + super().__init__() + self.activation = nn.SiLU() + self.out_layer = operation_settings.get("operations").Linear(time_dim, num_params * model_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, x): + return self.out_layer(self.activation(x)) + + +class SelfAttention(nn.Module): + def __init__(self, num_channels, head_dim, operation_settings=None): + super().__init__() + assert num_channels % head_dim == 0 + self.num_heads = num_channels // head_dim + self.head_dim = head_dim + + operations = operation_settings.get("operations") + self.to_query = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.to_key = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.to_value = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.query_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.key_norm = operations.RMSNorm(head_dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + self.out_layer = operations.Linear(num_channels, num_channels, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.num_chunks = 2 + + def _compute_qk(self, x, freqs, proj_fn, norm_fn): + result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1) + return apply_rope1(norm_fn(result), freqs) + + def _forward(self, x, freqs, transformer_options={}): + q = self._compute_qk(x, freqs, self.to_query, self.query_norm) + k = self._compute_qk(x, freqs, self.to_key, self.key_norm) + v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + def _forward_chunked(self, x, freqs, transformer_options={}): + def process_chunks(proj_fn, 
norm_fn): + x_chunks = torch.chunk(x, self.num_chunks, dim=1) + freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1) + chunks = [] + for x_chunk, freqs_chunk in zip(x_chunks, freqs_chunks): + chunks.append(self._compute_qk(x_chunk, freqs_chunk, proj_fn, norm_fn)) + return torch.cat(chunks, dim=1) + + q = process_chunks(self.to_query, self.query_norm) + k = process_chunks(self.to_key, self.key_norm) + v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + def forward(self, x, freqs, transformer_options={}): + if x.shape[1] > 8192: + return self._forward_chunked(x, freqs, transformer_options=transformer_options) + else: + return self._forward(x, freqs, transformer_options=transformer_options) + + +class CrossAttention(SelfAttention): + def get_qkv(self, x, context): + q = self.to_query(x).view(*x.shape[:-1], self.num_heads, -1) + k = self.to_key(context).view(*context.shape[:-1], self.num_heads, -1) + v = self.to_value(context).view(*context.shape[:-1], self.num_heads, -1) + return q, k, v + + def forward(self, x, context, transformer_options={}): + q, k, v = self.get_qkv(x, context) + out = attention(self.query_norm(q), self.key_norm(k), v, self.num_heads, transformer_options=transformer_options) + return self.out_layer(out) + + +class FeedForward(nn.Module): + def __init__(self, dim, ff_dim, operation_settings=None): + super().__init__() + operations = operation_settings.get("operations") + self.in_layer = operations.Linear(dim, ff_dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.activation = nn.GELU() + self.out_layer = operations.Linear(ff_dim, dim, bias=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.num_chunks = 4 + + def _forward(self, x): + return self.out_layer(self.activation(self.in_layer(x))) + + def _forward_chunked(self, x): + chunks = torch.chunk(x, self.num_chunks, dim=1) + output_chunks = [] + for chunk in chunks: + output_chunks.append(self._forward(chunk)) + return torch.cat(output_chunks, dim=1) + + def forward(self, x): + if x.shape[1] > 8192: + return self._forward_chunked(x) + else: + return self._forward(x) + + +class OutLayer(nn.Module): + def __init__(self, model_dim, time_dim, visual_dim, patch_size, operation_settings=None): + super().__init__() + self.patch_size = patch_size + self.modulation = Modulation(time_dim, model_dim, 2, operation_settings=operation_settings) + operations = operation_settings.get("operations") + self.norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.out_layer = operations.Linear(model_dim, math.prod(patch_size) * visual_dim, bias=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, visual_embed, time_embed): + B, T, H, W, _ = visual_embed.shape + shift, scale = torch.chunk(self.modulation(time_embed), 2, dim=-1) + scale = scale[:, None, None, None, :] + shift = shift[:, None, None, None, :] + visual_embed = apply_scale_shift_norm(self.norm, visual_embed, scale, shift) + x = self.out_layer(visual_embed) + + out_dim = x.shape[-1] // (self.patch_size[0] * self.patch_size[1] * self.patch_size[2]) + x = x.view( + B, T, H, W, + out_dim, + self.patch_size[0], self.patch_size[1], self.patch_size[2] + ) + return x.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(2, 
3).flatten(3, 4).flatten(4, 5) + + +class TransformerEncoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None): + super().__init__() + self.text_modulation = Modulation(time_dim, model_dim, 6, operation_settings=operation_settings) + operations = operation_settings.get("operations") + + self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) + + def forward(self, x, time_embed, freqs, transformer_options={}): + self_attn_params, ff_params = torch.chunk(self.text_modulation(time_embed), 2, dim=-1) + shift, scale, gate = get_shift_scale_gate(self_attn_params) + out = apply_scale_shift_norm(self.self_attention_norm, x, scale, shift) + out = self.self_attention(out, freqs, transformer_options=transformer_options) + x = apply_gate_sum(x, out, gate) + + shift, scale, gate = get_shift_scale_gate(ff_params) + out = apply_scale_shift_norm(self.feed_forward_norm, x, scale, shift) + out = self.feed_forward(out) + x = apply_gate_sum(x, out, gate) + return x + + +class TransformerDecoderBlock(nn.Module): + def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=None): + super().__init__() + self.visual_modulation = Modulation(time_dim, model_dim, 9, operation_settings=operation_settings) + + operations = operation_settings.get("operations") + self.self_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.self_attention = SelfAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.cross_attention_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.cross_attention = CrossAttention(model_dim, head_dim, operation_settings=operation_settings) + + self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) + + def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}): + self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1) + # self attention + shift, scale, gate = get_shift_scale_gate(self_attn_params) + visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) + visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + # cross attention + shift, scale, gate = get_shift_scale_gate(cross_attn_params) + visual_out = apply_scale_shift_norm(self.cross_attention_norm, visual_embed, scale, shift) + visual_out = self.cross_attention(visual_out, text_embed, transformer_options=transformer_options) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + # feed forward + shift, scale, gate = get_shift_scale_gate(ff_params) + visual_out = 
apply_scale_shift_norm(self.feed_forward_norm, visual_embed, scale, shift) + visual_out = self.feed_forward(visual_out) + visual_embed = apply_gate_sum(visual_embed, visual_out, gate) + return visual_embed + + +class Kandinsky5(nn.Module): + def __init__( + self, + in_visual_dim=16, out_visual_dim=16, in_text_dim=3584, in_text_dim2=768, time_dim=512, + model_dim=1792, ff_dim=7168, visual_embed_dim=132, patch_size=(1, 2, 2), num_text_blocks=2, num_visual_blocks=32, + axes_dims=(16, 24, 24), rope_scale_factor=(1.0, 2.0, 2.0), + dtype=None, device=None, operations=None, **kwargs + ): + super().__init__() + head_dim = sum(axes_dims) + self.rope_scale_factor = rope_scale_factor + self.in_visual_dim = in_visual_dim + self.model_dim = model_dim + self.patch_size = patch_size + self.visual_embed_dim = visual_embed_dim + self.dtype = dtype + self.device = device + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.time_embeddings = TimeEmbeddings(model_dim, time_dim, operation_settings=operation_settings) + self.text_embeddings = TextEmbeddings(in_text_dim, model_dim, operation_settings=operation_settings) + self.pooled_text_embeddings = TextEmbeddings(in_text_dim2, time_dim, operation_settings=operation_settings) + self.visual_embeddings = VisualEmbeddings(visual_embed_dim, model_dim, patch_size, operation_settings=operation_settings) + + self.text_transformer_blocks = nn.ModuleList( + [TransformerEncoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_text_blocks)] + ) + + self.visual_transformer_blocks = nn.ModuleList( + [TransformerDecoderBlock(model_dim, time_dim, ff_dim, head_dim, operation_settings=operation_settings) for _ in range(num_visual_blocks)] + ) + + self.out_layer = OutLayer(model_dim, time_dim, out_visual_dim, patch_size, operation_settings=operation_settings) + + self.rope_embedder_3d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=axes_dims) + self.rope_embedder_1d = EmbedND(dim=head_dim, theta=10000.0, axes_dim=[head_dim]) + + def rope_encode_1d(self, seq_len, seq_start=0, steps=None, device=None, dtype=None, transformer_options={}): + steps = seq_len if steps is None else steps + seq_ids = torch.linspace(seq_start, seq_start + (seq_len - 1), steps=steps, device=device, dtype=dtype) + seq_ids = seq_ids.reshape(-1, 1).unsqueeze(0) # Shape: (1, steps, 1) + freqs = self.rope_embedder_1d(seq_ids).movedim(1, 2) + return freqs + + def rope_encode_3d(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): + + patch_size = self.patch_size + t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) + h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) + w_len = ((w + (patch_size[2] // 2)) // patch_size[2]) + + if steps_t is None: + steps_t = t_len + if steps_h is None: + steps_h = h_len + if steps_w is None: + steps_w = w_len + + h_start = 0 + w_start = 0 + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0 + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + t_start += rope_options.get("shift_t", 0.0) + h_start += rope_options.get("shift_y", 0.0) + w_start += rope_options.get("shift_x", 0.0) + else: + rope_scale_factor = self.rope_scale_factor + if self.model_dim == 4096: # pro video model uses different rope scaling at higher resolutions + if h * w 
>= 14080: + rope_scale_factor = (1.0, 3.16, 3.16) + + t_len = (t_len - 1.0) / rope_scale_factor[0] + 1.0 + h_len = (h_len - 1.0) / rope_scale_factor[1] + 1.0 + w_len = (w_len - 1.0) / rope_scale_factor[2] + 1.0 + + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) + img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) + + freqs = self.rope_embedder_3d(img_ids).movedim(1, 2) + return freqs + + def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_options={}, **kwargs): + patches_replace = transformer_options.get("patches_replace", {}) + context = self.text_embeddings(context) + time_embed = self.time_embeddings(timestep, x.dtype) + self.pooled_text_embeddings(y) + + for block in self.text_transformer_blocks: + context = block(context, time_embed, freqs_text, transformer_options=transformer_options) + + visual_embed = self.visual_embeddings(x) + visual_shape = visual_embed.shape[:-1] + visual_embed = visual_embed.flatten(1, -2) + + blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.visual_transformer_blocks) + transformer_options["block_type"] = "double" + for i, block in enumerate(self.visual_transformer_blocks): + transformer_options["block_index"] = i + if ("double_block", i) in blocks_replace: + def block_wrap(args): + return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options")) + visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"] + else: + visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options) + + visual_embed = visual_embed.reshape(*visual_shape, -1) + return self.out_layer(visual_embed, time_embed) + + def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + original_dims = x.ndim + if original_dims == 4: + x = x.unsqueeze(2) + bs, c, t_len, h, w = x.shape + x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size) + + if time_dim_replace is not None: + time_dim_replace = comfy.ldm.common_dit.pad_to_patch_size(time_dim_replace, self.patch_size) + x[:, :time_dim_replace.shape[1], :time_dim_replace.shape[2]] = time_dim_replace + + freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) + freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options) + + out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs) + if original_dims == 4: + out = out.squeeze(2) + return out + + def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs): + return comfy.patcher_extension.WrapperExecutor.new_class_executor( + self._forward, + self, + 
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) + ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs) diff --git a/comfy/ldm/lightricks/model.py b/comfy/ldm/lightricks/model.py index 6df410d35..6142b5cba 100644 --- a/comfy/ldm/lightricks/model.py +++ b/comfy/ldm/lightricks/model.py @@ -1,12 +1,12 @@ -import torch -from torch import nn - -from ..common_dit import rms_norm -from einops import rearrange import math from typing import Dict, Optional, Tuple +import torch +from torch import nn + from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords +from ..common_dit import rms_norm +from ..flux.math import apply_rope1 from ..modules.attention import optimized_attention, optimized_attention_masked from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP @@ -181,10 +181,11 @@ class AdaLayerNormSingle(nn.Module): added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, batch_size: Optional[int] = None, hidden_dtype: Optional[torch.dtype] = None, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: #, torch.Tensor, torch.Tensor, torch.Tensor]: # No modulation happening here. added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None} embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype) + # todo: whats going on with the signature? return self.linear(self.silu(embedded_timestep)), embedded_timestep @@ -240,20 +241,6 @@ class FeedForward(nn.Module): return self.net(x) -def apply_rotary_emb(input_tensor, freqs_cis): # TODO: remove duplicate funcs and pick the best/fastest one - cos_freqs = freqs_cis[0] - sin_freqs = freqs_cis[1] - - t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) - t1, t2 = t_dup.unbind(dim=-1) - t_dup = torch.stack((-t2, t1), dim=-1) - input_tensor_rot = rearrange(t_dup, "... d r -> ... 
(d r)") - - out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs - - return out - - class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None): super().__init__() @@ -285,8 +272,8 @@ class CrossAttention(nn.Module): k = self.k_norm(k) if pe is not None: - q = apply_rotary_emb(q, pe) - k = apply_rotary_emb(k, pe) + q = apply_rope1(q.unsqueeze(1), pe).squeeze(1) + k = apply_rope1(k.unsqueeze(1), pe).squeeze(1) if mask is None: out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options) @@ -312,12 +299,17 @@ class BasicTransformerBlock(nn.Module): transformer_options = {} shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2) - x += self.attn1(rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa + attn1_input = rms_norm(x) + attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa) + attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options) + x.addcmul_(attn1_input, gate_msa) + del attn1_input x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options) - y = rms_norm(x) * (1 + scale_mlp) + shift_mlp - x += self.ff(y) * gate_mlp + y = rms_norm(x) + y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp) + x.addcmul_(self.ff(y), gate_mlp) return x @@ -336,41 +328,35 @@ def get_fractional_positions(indices_grid, max_pos): def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=None): if max_pos is None: max_pos = [20, 2048, 2048] - dtype = torch.float32 # self.dtype + dtype = torch.float32 + device = indices_grid.device + # Get fractional positions and compute frequency indices fractional_positions = get_fractional_positions(indices_grid, max_pos) + indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2 - start = 1 - end = theta - device = fractional_positions.device + # Compute frequencies and apply cos/sin + freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2) + cos_vals = freqs.cos().repeat_interleave(2, dim=-1) + sin_vals = freqs.sin().repeat_interleave(2, dim=-1) - indices = theta ** ( - torch.linspace( - math.log(start, theta), - math.log(end, theta), - dim // 6, - device=device, - dtype=dtype, - ) - ) - indices = indices.to(dtype=dtype) - - indices = indices * math.pi / 2 - - freqs = ( - (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) - .transpose(-1, -2) - .flatten(2) - ) - - cos_freq = freqs.cos().repeat_interleave(2, dim=-1) - sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + # Pad if dim is not divisible by 6 if dim % 6 != 0: - cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) - sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) - cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) - sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) - return cos_freq.to(out_dtype), sin_freq.to(out_dtype) + padding_size = dim % 6 + cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1) + sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1) + + # Reshape and extract one value per pair (since 
repeat_interleave duplicates each value) + cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] + sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2] + + # Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension + freqs_cis = torch.stack([ + torch.stack([cos_vals, -sin_vals], dim=-1), + torch.stack([sin_vals, cos_vals], dim=-1) + ], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2] + + return freqs_cis class LTXVModel(torch.nn.Module): @@ -515,7 +501,7 @@ class LTXVModel(torch.nn.Module): shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] x = self.norm_out(x) # Modulation - x = x * (1 + scale) + shift + x = torch.addcmul(x, x, scale).add_(shift) x = self.proj_out(x) x = self.patchifier.unpatchify( diff --git a/comfy/ldm/lumina/controlnet.py b/comfy/ldm/lumina/controlnet.py new file mode 100644 index 000000000..fd7ce3b5c --- /dev/null +++ b/comfy/ldm/lumina/controlnet.py @@ -0,0 +1,113 @@ +import torch +from torch import nn + +from .model import JointTransformerBlock + +class ZImageControlTransformerBlock(JointTransformerBlock): + def __init__( + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + modulation=True, + block_id=0, + operation_settings=None, + ): + super().__init__(layer_id, dim, n_heads, n_kv_heads, multiple_of, ffn_dim_multiplier, norm_eps, qk_norm, modulation, z_image_modulation=True, operation_settings=operation_settings) + self.block_id = block_id + if block_id == 0: + self.before_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + self.after_proj = operation_settings.get("operations").Linear(self.dim, self.dim, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) + + def forward(self, c, x, **kwargs): + if self.block_id == 0: + c = self.before_proj(c) + x + c = super().forward(c, **kwargs) + c_skip = self.after_proj(c) + return c_skip, c + +class ZImage_Control(torch.nn.Module): + def __init__( + self, + dim: int = 3840, + n_heads: int = 30, + n_kv_heads: int = 30, + multiple_of: int = 256, + ffn_dim_multiplier: float = (8.0 / 3.0), + norm_eps: float = 1e-5, + qk_norm: bool = True, + dtype=None, + device=None, + operations=None, + **kwargs + ): + super().__init__() + operation_settings = {"operations": operations, "device": device, "dtype": dtype} + + self.additional_in_dim = 0 + self.control_in_dim = 16 + n_refiner_layers = 2 + self.n_control_layers = 6 + self.control_layers = nn.ModuleList( + [ + ZImageControlTransformerBlock( + i, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + block_id=i, + operation_settings=operation_settings, + ) + for i in range(self.n_control_layers) + ] + ) + + all_x_embedder = {} + patch_size = 2 + f_patch_size = 1 + x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype) + all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder + + self.control_all_x_embedder = nn.ModuleDict(all_x_embedder) + self.control_noise_refiner = nn.ModuleList( + [ + JointTransformerBlock( + layer_id, + dim, + n_heads, + n_kv_heads, + multiple_of, + ffn_dim_multiplier, + norm_eps, + qk_norm, + modulation=True, + z_image_modulation=True, + operation_settings=operation_settings, + ) + for layer_id 
in range(n_refiner_layers) + ] + ) + + def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input): + patch_size = 2 + f_patch_size = 1 + pH = pW = patch_size + B, C, H, W = control_context.shape + control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) + + x_attn_mask = None + for layer in self.control_noise_refiner: + control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input) + return control_context + + def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input): + return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input) diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index b8b255478..8716f0688 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -12,26 +12,34 @@ from ..modules.diffusionmodules.mmdit import TimestepEmbedder from ..modules.attention import optimized_attention_masked from ..flux.layers import EmbedND from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP +from ..flux.math import apply_rope def modulate(x, scale): return x * (1 + scale.unsqueeze(1)) + ############################################################################# # Core NextDiT Model # ############################################################################# +def clamp_fp16(x): + if x.dtype == torch.float16: + return torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504) + return x + class JointAttention(nn.Module): """Multi-head attention module.""" def __init__( - self, - dim: int, - n_heads: int, - n_kv_heads: Optional[int], - qk_norm: bool, - operation_settings={}, + self, + dim: int, + n_heads: int, + n_kv_heads: Optional[int], + qk_norm: bool, + out_bias: bool = False, + operation_settings=None, ): """ Initialize the Attention module. @@ -43,6 +51,8 @@ class JointAttention(nn.Module): """ super().__init__() + if operation_settings is None: + operation_settings = {} self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads self.n_local_heads = n_heads self.n_local_kv_heads = self.n_kv_heads @@ -59,7 +69,7 @@ class JointAttention(nn.Module): self.out = operation_settings.get("operations").Linear( n_heads * self.head_dim, dim, - bias=False, + bias=out_bias, device=operation_settings.get("device"), dtype=operation_settings.get("dtype"), ) @@ -70,41 +80,12 @@ class JointAttention(nn.Module): else: self.q_norm = self.k_norm = nn.Identity() - @staticmethod - def apply_rotary_emb( - x_in: torch.Tensor, - freqs_cis: torch.Tensor, - ) -> torch.Tensor: - """ - Apply rotary embeddings to input tensors using the given frequency - tensor. - - This function applies rotary embeddings to the given query 'xq' and - key 'xk' tensors using the provided frequency tensor 'freqs_cis'. The - input tensors are reshaped as complex numbers, and the frequency tensor - is reshaped for broadcasting compatibility. The resulting tensors - contain rotary embeddings and are returned as real tensors. - - Args: - x_in (torch.Tensor): Query or Key tensor to apply rotary embeddings. - freqs_cis (torch.Tensor): Precomputed frequency tensor for complex - exponentials. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor - and key tensor with rotary embeddings. 
- """ - - t_ = x_in.reshape(*x_in.shape[:-1], -1, 1, 2) - t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1] - return t_out.reshape(*x_in.shape) - def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - freqs_cis: torch.Tensor, - transformer_options={}, + self, + x: torch.Tensor, + x_mask: torch.Tensor, + freqs_cis: torch.Tensor, + transformer_options=None, ) -> torch.Tensor: """ @@ -116,6 +97,8 @@ class JointAttention(nn.Module): Returns: """ + if transformer_options is None: + transformer_options = {} bsz, seqlen, _ = x.shape xq, xk, xv = torch.split( @@ -134,8 +117,7 @@ class JointAttention(nn.Module): xq = self.q_norm(xq) xk = self.k_norm(xk) - xq = JointAttention.apply_rotary_emb(xq, freqs_cis=freqs_cis) - xk = JointAttention.apply_rotary_emb(xk, freqs_cis=freqs_cis) + xq, xk = apply_rope(xq, xk, freqs_cis) n_rep = self.n_local_heads // self.n_local_kv_heads if n_rep >= 1: @@ -148,12 +130,12 @@ class JointAttention(nn.Module): class FeedForward(nn.Module): def __init__( - self, - dim: int, - hidden_dim: int, - multiple_of: int, - ffn_dim_multiplier: Optional[float], - operation_settings={}, + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: Optional[float], + operation_settings=None, ): """ Initialize the FeedForward module. @@ -169,6 +151,8 @@ class FeedForward(nn.Module): """ super().__init__() # custom dim factor multiplier + if operation_settings is None: + operation_settings = {} if ffn_dim_multiplier is not None: hidden_dim = int(ffn_dim_multiplier * hidden_dim) hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) @@ -197,7 +181,7 @@ class FeedForward(nn.Module): # @torch.compile def _forward_silu_gating(self, x1, x3): - return F.silu(x1) * x3 + return clamp_fp16(F.silu(x1) * x3) def forward(self, x): return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x))) @@ -205,17 +189,19 @@ class FeedForward(nn.Module): class JointTransformerBlock(nn.Module): def __init__( - self, - layer_id: int, - dim: int, - n_heads: int, - n_kv_heads: int, - multiple_of: int, - ffn_dim_multiplier: float, - norm_eps: float, - qk_norm: bool, - modulation=True, - operation_settings={}, + self, + layer_id: int, + dim: int, + n_heads: int, + n_kv_heads: int, + multiple_of: int, + ffn_dim_multiplier: float, + norm_eps: float, + qk_norm: bool, + modulation=True, + z_image_modulation=False, + attn_out_bias=False, + operation_settings=None, ) -> None: """ Initialize a TransformerBlock. 
@@ -233,12 +219,14 @@ class JointTransformerBlock(nn.Module): """ super().__init__() + if operation_settings is None: + operation_settings = {} self.dim = dim self.head_dim = dim // n_heads - self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, operation_settings=operation_settings) + self.attention = JointAttention(dim, n_heads, n_kv_heads, qk_norm, out_bias=attn_out_bias, operation_settings=operation_settings) self.feed_forward = FeedForward( dim=dim, - hidden_dim=4 * dim, + hidden_dim=dim, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier, operation_settings=operation_settings, @@ -252,24 +240,35 @@ class JointTransformerBlock(nn.Module): self.modulation = modulation if modulation: - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - operation_settings.get("operations").Linear( - min(dim, 1024), - 4 * dim, - bias=True, - device=operation_settings.get("device"), - dtype=operation_settings.get("dtype"), - ), - ) + if z_image_modulation: + self.adaLN_modulation = nn.Sequential( + operation_settings.get("operations").Linear( + min(dim, 256), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + else: + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024), + 4 * dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) def forward( - self, - x: torch.Tensor, - x_mask: torch.Tensor, - freqs_cis: torch.Tensor, - adaln_input: Optional[torch.Tensor]=None, - transformer_options={}, + self, + x: torch.Tensor, + x_mask: torch.Tensor, + freqs_cis: torch.Tensor, + adaln_input: Optional[torch.Tensor] = None, + transformer_options=None, ): """ Perform a forward pass through the TransformerBlock. @@ -283,32 +282,34 @@ class JointTransformerBlock(nn.Module): feedforward layers. """ + if transformer_options is None: + transformer_options = {} if self.modulation: assert adaln_input is not None scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) x = x + gate_msa.unsqueeze(1).tanh() * self.attention_norm2( - self.attention( + clamp_fp16(self.attention( modulate(self.attention_norm1(x), scale_msa), x_mask, freqs_cis, transformer_options=transformer_options, - ) + )) ) x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2( - self.feed_forward( + clamp_fp16(self.feed_forward( modulate(self.ffn_norm1(x), scale_mlp), - ) + )) ) else: assert adaln_input is None x = x + self.attention_norm2( - self.attention( + clamp_fp16(self.attention( self.attention_norm1(x), x_mask, freqs_cis, transformer_options=transformer_options, - ) + )) ) x = x + self.ffn_norm2( self.feed_forward( @@ -323,8 +324,10 @@ class FinalLayer(nn.Module): The final layer of NextDiT. 
""" - def __init__(self, hidden_size, patch_size, out_channels, operation_settings={}): + def __init__(self, hidden_size, patch_size, out_channels, z_image_modulation=False, operation_settings=None): super().__init__() + if operation_settings is None: + operation_settings = {} self.norm_final = operation_settings.get("operations").LayerNorm( hidden_size, elementwise_affine=False, @@ -340,10 +343,15 @@ class FinalLayer(nn.Module): dtype=operation_settings.get("dtype"), ) + if z_image_modulation: + min_mod = 256 + else: + min_mod = 1024 + self.adaLN_modulation = nn.Sequential( nn.SiLU(), operation_settings.get("operations").Linear( - min(hidden_size, 1024), + min(hidden_size, min_mod), hidden_size, bias=True, device=operation_settings.get("device"), @@ -364,25 +372,30 @@ class NextDiT(nn.Module): """ def __init__( - self, - patch_size: int = 2, - in_channels: int = 4, - dim: int = 4096, - n_layers: int = 32, - n_refiner_layers: int = 2, - n_heads: int = 32, - n_kv_heads: Optional[int] = None, - multiple_of: int = 256, - ffn_dim_multiplier: Optional[float] = None, - norm_eps: float = 1e-5, - qk_norm: bool = False, - cap_feat_dim: int = 5120, - axes_dims: List[int] = (16, 56, 56), - axes_lens: List[int] = (1, 512, 512), - image_model=None, - device=None, - dtype=None, - operations=None, + self, + patch_size: int = 2, + in_channels: int = 4, + dim: int = 4096, + n_layers: int = 32, + n_refiner_layers: int = 2, + n_heads: int = 32, + n_kv_heads: Optional[int] = None, + multiple_of: int = 256, + ffn_dim_multiplier: float = 4.0, + norm_eps: float = 1e-5, + qk_norm: bool = False, + cap_feat_dim: int = 5120, + axes_dims: List[int] = (16, 56, 56), + axes_lens: List[int] = (1, 512, 512), + rope_theta=10000.0, + z_image_modulation=False, + time_scale=1.0, + pad_tokens_multiple=None, + clip_text_dim=None, + image_model=None, + device=None, + dtype=None, + operations=None, ) -> None: super().__init__() self.dtype = dtype @@ -390,6 +403,8 @@ class NextDiT(nn.Module): self.in_channels = in_channels self.out_channels = in_channels self.patch_size = patch_size + self.time_scale = time_scale + self.pad_tokens_multiple = pad_tokens_multiple self.x_embedder = operation_settings.get("operations").Linear( in_features=patch_size * patch_size * in_channels, @@ -411,6 +426,7 @@ class NextDiT(nn.Module): norm_eps, qk_norm, modulation=True, + z_image_modulation=z_image_modulation, operation_settings=operation_settings, ) for layer_id in range(n_refiner_layers) @@ -434,7 +450,7 @@ class NextDiT(nn.Module): ] ) - self.t_embedder = TimestepEmbedder(min(dim, 1024), **operation_settings) + self.t_embedder = TimestepEmbedder(min(dim, 1024), output_size=256 if z_image_modulation else None, **operation_settings) self.cap_embedder = nn.Sequential( operation_settings.get("operations").RMSNorm(cap_feat_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), operation_settings.get("operations").Linear( @@ -446,6 +462,31 @@ class NextDiT(nn.Module): ), ) + self.clip_text_pooled_proj = None + + if clip_text_dim is not None: + self.clip_text_dim = clip_text_dim + self.clip_text_pooled_proj = nn.Sequential( + operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")), + operation_settings.get("operations").Linear( + clip_text_dim, + clip_text_dim, + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), 
+ ) + self.time_text_embed = nn.Sequential( + nn.SiLU(), + operation_settings.get("operations").Linear( + min(dim, 1024) + clip_text_dim, + min(dim, 1024), + bias=True, + device=operation_settings.get("device"), + dtype=operation_settings.get("dtype"), + ), + ) + self.layers = nn.ModuleList( [ JointTransformerBlock( @@ -457,23 +498,29 @@ class NextDiT(nn.Module): ffn_dim_multiplier, norm_eps, qk_norm, + z_image_modulation=z_image_modulation, + attn_out_bias=False, operation_settings=operation_settings, ) for layer_id in range(n_layers) ] ) self.norm_final = operation_settings.get("operations").RMSNorm(dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) - self.final_layer = FinalLayer(dim, patch_size, self.out_channels, operation_settings=operation_settings) + self.final_layer = FinalLayer(dim, patch_size, self.out_channels, z_image_modulation=z_image_modulation, operation_settings=operation_settings) + + if self.pad_tokens_multiple is not None: + self.x_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) + self.cap_pad_token = nn.Parameter(torch.empty((1, dim), device=device, dtype=dtype)) assert (dim // n_heads) == sum(axes_dims) self.axes_dims = axes_dims self.axes_lens = axes_lens - self.rope_embedder = EmbedND(dim=dim // n_heads, theta=10000.0, axes_dim=axes_dims) + self.rope_embedder = EmbedND(dim=dim // n_heads, theta=rope_theta, axes_dim=axes_dims) self.dim = dim self.n_heads = n_heads def unpatchify( - self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False + self, x: torch.Tensor, img_size: List[Tuple[int, int]], cap_size: List[int], return_tensor=False ) -> List[torch.Tensor]: """ x: (N, T, patch_size**2 * C) @@ -498,101 +545,61 @@ class NextDiT(nn.Module): return imgs def patchify_and_embed( - self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options={} + self, x: List[torch.Tensor] | torch.Tensor, cap_feats: torch.Tensor, cap_mask: torch.Tensor, t: torch.Tensor, num_tokens, transformer_options=None ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], List[int], torch.Tensor]: + if transformer_options is None: + transformer_options = {} bsz = len(x) pH = pW = self.patch_size device = x[0].device - dtype = x[0].dtype - if cap_mask is not None: - l_effective_cap_len = cap_mask.sum(dim=1).tolist() - else: - l_effective_cap_len = [num_tokens] * bsz + if self.pad_tokens_multiple is not None: + pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple + cap_feats = torch.cat((cap_feats, self.cap_pad_token.to(device=cap_feats.device, dtype=cap_feats.dtype, copy=True).unsqueeze(0).repeat(cap_feats.shape[0], pad_extra, 1)), dim=1) - if cap_mask is not None and not torch.is_floating_point(cap_mask): - cap_mask = (cap_mask - 1).to(dtype) * torch.finfo(dtype).max + cap_pos_ids = torch.zeros(bsz, cap_feats.shape[1], 3, dtype=torch.float32, device=device) + cap_pos_ids[:, :, 0] = torch.arange(cap_feats.shape[1], dtype=torch.float32, device=device) + 1.0 - img_sizes = [(img.size(1), img.size(2)) for img in x] - l_effective_img_len = [(H // pH) * (W // pW) for (H, W) in img_sizes] + B, C, H, W = x.shape + x = self.x_embedder(x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2)) - max_seq_len = max( - (cap_len+img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)) - ) - max_cap_len = 
max(l_effective_cap_len) - max_img_len = max(l_effective_img_len) + rope_options = transformer_options.get("rope_options", None) + h_scale = 1.0 + w_scale = 1.0 + h_start = 0 + w_start = 0 + if rope_options is not None: + h_scale = rope_options.get("scale_y", 1.0) + w_scale = rope_options.get("scale_x", 1.0) - position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device) + h_start = rope_options.get("shift_y", 0.0) + w_start = rope_options.get("shift_x", 0.0) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - H, W = img_sizes[i] - H_tokens, W_tokens = H // pH, W // pW - assert H_tokens * W_tokens == img_len + H_tokens, W_tokens = H // pH, W // pW + x_pos_ids = torch.zeros((bsz, x.shape[1], 3), dtype=torch.float32, device=device) + x_pos_ids[:, :, 0] = cap_feats.shape[1] + 1 + x_pos_ids[:, :, 1] = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten() + x_pos_ids[:, :, 2] = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device) - position_ids[i, cap_len:cap_len+img_len, 0] = cap_len - row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten() - col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten() - position_ids[i, cap_len:cap_len+img_len, 1] = row_ids - position_ids[i, cap_len:cap_len+img_len, 2] = col_ids + if self.pad_tokens_multiple is not None: + pad_extra = (-x.shape[1]) % self.pad_tokens_multiple + x = torch.cat((x, self.x_pad_token.to(device=x.device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(x.shape[0], pad_extra, 1)), dim=1) + x_pos_ids = torch.nn.functional.pad(x_pos_ids, (0, 0, 0, pad_extra)) - freqs_cis = self.rope_embedder(position_ids).movedim(1, 2).to(dtype) - - # build freqs_cis for cap and image individually - cap_freqs_cis_shape = list(freqs_cis.shape) - # cap_freqs_cis_shape[1] = max_cap_len - cap_freqs_cis_shape[1] = cap_feats.shape[1] - cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - img_freqs_cis_shape = list(freqs_cis.shape) - img_freqs_cis_shape[1] = max_img_len - img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype) - - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len] - img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len:cap_len+img_len] + freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2) # refine context for layer in self.context_refiner: - cap_feats = layer(cap_feats, cap_mask, cap_freqs_cis, transformer_options=transformer_options) + cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options) - # refine image - flat_x = [] - for i in range(bsz): - img = x[i] - C, H, W = img.size() - img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1) - flat_x.append(img) - x = flat_x - padded_img_embed = torch.zeros(bsz, max_img_len, x[0].shape[-1], device=device, dtype=x[0].dtype) - padded_img_mask = torch.zeros(bsz, max_img_len, dtype=dtype, device=device) - for i in range(bsz): - padded_img_embed[i, :l_effective_img_len[i]] = x[i] - padded_img_mask[i, l_effective_img_len[i]:] = 
-torch.finfo(dtype).max - - padded_img_embed = self.x_embedder(padded_img_embed) - padded_img_mask = padded_img_mask.unsqueeze(1) + padded_img_mask = None for layer in self.noise_refiner: - padded_img_embed = layer(padded_img_embed, padded_img_mask, img_freqs_cis, t, transformer_options=transformer_options) - - if cap_mask is not None: - mask = torch.zeros(bsz, max_seq_len, dtype=dtype, device=device) - mask[:, :max_cap_len] = cap_mask[:, :max_cap_len] - else: - mask = None - - padded_full_embed = torch.zeros(bsz, max_seq_len, self.dim, device=device, dtype=x[0].dtype) - for i in range(bsz): - cap_len = l_effective_cap_len[i] - img_len = l_effective_img_len[i] - - padded_full_embed[i, :cap_len] = cap_feats[i, :cap_len] - padded_full_embed[i, cap_len:cap_len+img_len] = padded_img_embed[i, :img_len] + x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options) + padded_full_embed = torch.cat((cap_feats, x), dim=1) + mask = None + img_sizes = [(H, W)] * bsz + l_effective_cap_len = [cap_feats.shape[1]] * bsz return padded_full_embed, mask, img_sizes, l_effective_cap_len, freqs_cis def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): @@ -603,7 +610,9 @@ class NextDiT(nn.Module): ).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs) # def forward(self, x, t, cap_feats, cap_mask): - def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs): + def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, transformer_options=None, **kwargs): + if transformer_options is None: + transformer_options = {} t = 1.0 - timesteps cap_feats = context cap_mask = attention_mask @@ -615,21 +624,36 @@ class NextDiT(nn.Module): y: (N,) tensor of text tokens/features """ - t = self.t_embedder(t, dtype=x.dtype) # (N, D) + t = self.t_embedder(t * self.time_scale, dtype=x.dtype) # (N, D) adaln_input = t cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. 
redundant compute - transformer_options = kwargs.get("transformer_options", {}) + if self.clip_text_pooled_proj is not None: + pooled = kwargs.get("clip_text_pooled", None) + if pooled is not None: + pooled = self.clip_text_pooled_proj(pooled) + else: + pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype) + + adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1)) + + patches = transformer_options.get("patches", {}) x_is_tensor = isinstance(x, torch.Tensor) - x, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) - freqs_cis = freqs_cis.to(x.device) + img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) + freqs_cis = freqs_cis.to(img.device) - for layer in self.layers: - x = layer(x, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + for i, layer in enumerate(self.layers): + img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options) + if "double_block" in patches: + for p in patches["double_block"]: + out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options}) + if "img" in out: + img[:, cap_size[0]:] = out["img"] + if "txt" in out: + img[:, :cap_size[0]] = out["txt"] - x = self.final_layer(x, adaln_input) - x = self.unpatchify(x, img_size, cap_size, return_tensor=x_is_tensor)[:,:,:h,:w] - - return -x + img = self.final_layer(img, adaln_input) + img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w] + return -img diff --git a/comfy/ldm/models/autoencoder.py b/comfy/ldm/models/autoencoder.py index 0eb470827..0f4c7a7c4 100644 --- a/comfy/ldm/models/autoencoder.py +++ b/comfy/ldm/models/autoencoder.py @@ -1,10 +1,12 @@ import logging import math from contextlib import contextmanager -from typing import Any, Dict, Tuple, Union, Callable +from typing import Any, Dict, Tuple, Union, Callable, Optional import torch +from einops import rearrange +from ...model_management import cast_to from ..modules.distributions.distributions import DiagonalGaussianDistribution from ..modules.ema import LitEma from ..util import instantiate_from_config, get_obj_from_str @@ -12,6 +14,7 @@ from ... 
import ops logger = logging.getLogger(__name__) + class DiagonalGaussianRegularizer(torch.nn.Module): def __init__(self, sample: bool = False): super().__init__() @@ -20,7 +23,7 @@ class DiagonalGaussianRegularizer(torch.nn.Module): def get_trainable_parameters(self) -> Any: yield from () - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Optional[dict]]: posterior = DiagonalGaussianDistribution(z) if self.sample: z = posterior.sample() @@ -28,13 +31,15 @@ class DiagonalGaussianRegularizer(torch.nn.Module): z = posterior.mode() return z, None + class EmptyRegularizer(torch.nn.Module): def __init__(self): super().__init__() - def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, dict]: + def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Optional[dict]]: return z, None + class AbstractAutoencoder(torch.nn.Module): """ This is the base class for all autoencoders, including image autoencoders, image autoencoders with discriminators, @@ -181,8 +186,27 @@ class AutoencodingEngineLegacy(AutoencodingEngine): self.post_quant_conv = conv_op(embed_dim, ddconfig["z_channels"], 1) self.embed_dim = embed_dim + if ddconfig.get("batch_norm_latent", False): + self.bn_eps = 1e-4 + self.bn_momentum = 0.1 + self.ps = [2, 2] + self.bn = torch.nn.BatchNorm2d(math.prod(self.ps) * ddconfig["z_channels"], + eps=self.bn_eps, + momentum=self.bn_momentum, + affine=False, + track_running_stats=True, + ) + self.bn.eval() + else: + self.bn = None + + def get_autoencoder_params(self) -> list: + params = super().get_autoencoder_params() + return params + def encode( - self, x: torch.Tensor, return_reg_log: bool = False + self, x: torch.Tensor, return_reg_log: bool = False, + unregularized: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, dict]]: if self.max_batch_size is None: z = self.encoder(x) @@ -199,11 +223,36 @@ class AutoencodingEngineLegacy(AutoencodingEngine): z = torch.cat(z, 0) z, reg_log = self.regularization(z) + + if self.bn is not None: + z = rearrange(z, + "... c (i pi) (j pj) -> ... (c pi pj) i j", + pi=self.ps[0], + pj=self.ps[1], + ) + + z = torch.nn.functional.batch_norm(z, + cast_to(self.bn.running_mean, dtype=z.dtype, device=z.device), + cast_to(self.bn.running_var, dtype=z.dtype, device=z.device), + momentum=self.bn_momentum, + eps=self.bn_eps) + if return_reg_log: return z, reg_log return z def decode(self, z: torch.Tensor, **decoder_kwargs) -> torch.Tensor: + if self.bn is not None: + s = torch.sqrt(cast_to(self.bn.running_var.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + self.bn_eps) + m = cast_to(self.bn.running_mean.view(1, -1, 1, 1), dtype=z.dtype, device=z.device) + z = z * s + m + z = rearrange( + z, + "... (c pi pj) i j -> ... 
c (i pi) (j pj)", + pi=self.ps[0], + pj=self.ps[1], + ) + if self.max_batch_size is None: dec = self.post_quant_conv(z) dec = self.decoder(dec, **decoder_kwargs) diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index e1c729f71..9d792711e 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -557,6 +557,7 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha @wrap_attn def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + exception_fallback = False if skip_reshape: b, _, _, dim_head = q.shape tensor_layout = "HND" @@ -581,6 +582,8 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= out = sageattn(q, k, v, attn_mask=mask, is_causal=False, tensor_layout=tensor_layout) except Exception as e: logger.error("Error running sage attention: {}, using pytorch attention instead.".format(e)) + exception_fallback = True + if exception_fallback: if tensor_layout == "NHD": q, k, v = map( lambda t: t.transpose(1, 2), diff --git a/comfy/ldm/modules/diffusionmodules/mmdit.py b/comfy/ldm/modules/diffusionmodules/mmdit.py index 46e9ee4d6..e6e9733fe 100644 --- a/comfy/ldm/modules/diffusionmodules/mmdit.py +++ b/comfy/ldm/modules/diffusionmodules/mmdit.py @@ -213,12 +213,14 @@ class TimestepEmbedder(nn.Module): Embeds scalar timesteps into vector representations. """ - def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + def __init__(self, hidden_size, frequency_embedding_size=256, output_size=None, dtype=None, device=None, operations=None): super().__init__() + if output_size is None: + output_size = hidden_size self.mlp = nn.Sequential( operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), nn.SiLU(), - operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + operations.Linear(hidden_size, output_size, bias=True, dtype=dtype, device=device), ) self.frequency_embedding_size = frequency_embedding_size diff --git a/comfy/ldm/modules/diffusionmodules/model.py b/comfy/ldm/modules/diffusionmodules/model.py index d87af7cc9..d4667825b 100644 --- a/comfy/ldm/modules/diffusionmodules/model.py +++ b/comfy/ldm/modules/diffusionmodules/model.py @@ -18,6 +18,13 @@ if model_management.xformers_enabled_vae(): import xformers.ops # pylint: disable=import-error +def torch_cat_if_needed(xl, dim): + if len(xl) > 1: + return torch.cat(xl, dim) + else: + return xl[0] + + def get_timestep_embedding(timesteps, embedding_dim): """ This matches the implementation in Denoising Diffusion Probabilistic Models: @@ -41,13 +48,43 @@ def get_timestep_embedding(timesteps, embedding_dim): def nonlinearity(x): # swish - return torch.nn.functional.silu(x) + return torch.nn.functional.silu(x) def Normalize(in_channels, num_groups=32): return ops.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) +class CarriedConv3d(nn.Module): + def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding=0, **kwargs): + super().__init__() + self.conv = ops.Conv3d(n_channels, out_channels, kernel_size, stride=stride, dilation=dilation, **kwargs) + + def forward(self, x): + return self.conv(x) + + +def conv_carry_causal_3d(xl, op, conv_carry_in=None, conv_carry_out=None): + x = xl[0] + xl.clear() + + if isinstance(op, CarriedConv3d): + if conv_carry_in is None: + x = torch.nn.functional.pad(x, 
(1, 1, 1, 1, 2, 0), mode='replicate') + else: + carry_len = conv_carry_in[0].shape[2] + x = torch.nn.functional.pad(x, (1, 1, 1, 1, 2 - carry_len, 0), mode='replicate') + x = torch.cat([conv_carry_in.pop(0), x], dim=2) + + if conv_carry_out is not None: + to_push = x[:, :, -2:, :, :].clone() + conv_carry_out.append(to_push) + + out = op(x) + + return out + + class VideoConv3d(nn.Module): def __init__(self, n_channels, out_channels, kernel_size, stride=1, dilation=1, padding_mode='replicate', padding=1, **kwargs): super().__init__() @@ -96,29 +133,24 @@ class Upsample(nn.Module): stride=1, padding=1) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): scale_factor = self.scale_factor if isinstance(scale_factor, (int, float)): scale_factor = (scale_factor,) * (x.ndim - 2) if x.ndim == 5 and scale_factor[0] > 1.0: - t = x.shape[2] - if t > 1: - a, b = x.split((1, t - 1), dim=2) - del x - b = interpolate_up(b, scale_factor) - else: - a = x - - a = interpolate_up(a.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2) - if t > 1: - x = torch.cat((a, b), dim=2) - else: - x = a + results = [] + if conv_carry_in is None: + first = x[:, :, :1, :, :] + results.append(interpolate_up(first.squeeze(2), scale_factor=scale_factor[1:]).unsqueeze(2)) + x = x[:, :, 1:, :, :] + if x.shape[2] > 0: + results.append(interpolate_up(x, scale_factor)) + x = torch_cat_if_needed(results, dim=2) else: x = interpolate_up(x, scale_factor) if self.with_conv: - x = self.conv(x) + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) return x @@ -134,17 +166,20 @@ class Downsample(nn.Module): stride=stride, padding=0) - def forward(self, x): + def forward(self, x, conv_carry_in=None, conv_carry_out=None): if self.with_conv: - if x.ndim == 4: + if isinstance(self.conv, CarriedConv3d): + x = conv_carry_causal_3d([x], self.conv, conv_carry_in, conv_carry_out) + elif x.ndim == 4: pad = (0, 1, 0, 1) mode = "constant" x = torch.nn.functional.pad(x, pad, mode=mode, value=0) + x = self.conv(x) elif x.ndim == 5: pad = (1, 1, 1, 1, 2, 0) mode = "replicate" x = torch.nn.functional.pad(x, pad, mode=mode) - x = self.conv(x) + x = self.conv(x) else: x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) return x @@ -190,23 +225,23 @@ class ResnetBlock(nn.Module): stride=1, padding=0) - def forward(self, x, temb=None): + def forward(self, x, temb=None, conv_carry_in=None, conv_carry_out=None): h = x h = self.norm1(h) - h = self.swish(h) - h = self.conv1(h) + h = [self.swish(h)] + h = conv_carry_causal_3d(h, self.conv1, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if temb is not None: h = h + self.temb_proj(self.swish(temb))[:, :, None, None] h = self.norm2(h) h = self.swish(h) - h = self.dropout(h) - h = self.conv2(h) + h = [self.dropout(h)] + h = conv_carry_causal_3d(h, self.conv2, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) if self.in_channels != self.out_channels: if self.use_conv_shortcut: - x = self.conv_shortcut(x) + x = conv_carry_causal_3d([x], self.conv_shortcut, conv_carry_in=conv_carry_in, conv_carry_out=conv_carry_out) else: x = self.nin_shortcut(x) @@ -290,6 +325,7 @@ def pytorch_attention(q, k, v): orig_shape = q.shape B = orig_shape[0] C = orig_shape[1] + oom_fallback = False q, k, v = map( lambda t: t.view(B, 1, C, -1).transpose(2, 3).contiguous(), (q, k, v), @@ -300,6 +336,8 @@ def pytorch_attention(q, k, v): out = out.transpose(2, 3).reshape(orig_shape) except model_management.OOM_EXCEPTION: 
logger.warning("scaled_dot_product_attention OOMed: switched to slice attention") + oom_fallback = True + if oom_fallback: out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(orig_shape) return out @@ -529,9 +567,14 @@ class Encoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels + self.carried = False if conv3d: - conv_op = VideoConv3d + if not attn_resolutions: + conv_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -544,6 +587,7 @@ class Encoder(nn.Module): stride=1, padding=1) + self.time_compress = 1 curr_res = resolution in_ch_mult = (1,) + tuple(ch_mult) self.in_ch_mult = in_ch_mult @@ -570,10 +614,15 @@ class Encoder(nn.Module): if time_compress is not None: if (self.num_resolutions - 1 - i_level) > math.log2(time_compress): stride = (1, 2, 2) + else: + self.time_compress *= 2 down.downsample = Downsample(block_in, resamp_with_conv, stride=stride, conv_op=conv_op) curr_res = curr_res // 2 self.down.append(down) + if time_compress is not None: + self.time_compress = time_compress + # middle self.mid = nn.Module() self.mid.block_1 = ResnetBlock(in_channels=block_in, @@ -599,15 +648,42 @@ class Encoder(nn.Module): def forward(self, x): # timestep embedding temb = None - # downsampling - h = self.conv_in(x) - for i_level in range(self.num_resolutions): - for i_block in range(self.num_res_blocks): - h = self.down[i_level].block[i_block](h, temb) - if len(self.down[i_level].attn) > 0: - h = self.down[i_level].attn[i_block](h) - if i_level != self.num_resolutions - 1: - h = self.down[i_level].downsample(h) + + if self.carried: + xl = [x[:, :, :1, :, :]] + if x.shape[2] > self.time_compress: + tc = self.time_compress + xl += torch.split(x[:, :, 1: 1 + ((x.shape[2] - 1) // tc) * tc, :, :], tc * 2, dim=2) + x = xl + else: + x = [x] + out = [] + + conv_carry_in = None + + for i, x1 in enumerate(x): + conv_carry_out = [] + if i == len(x) - 1: + conv_carry_out = None + + # downsampling + x1 = [x1] + h1 = conv_carry_causal_3d(x1, self.conv_in, conv_carry_in, conv_carry_out) + + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h1 = self.down[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out) + if len(self.down[i_level].attn) > 0: + assert i == 0 # carried should not happen if attn exists + h1 = self.down[i_level].attn[i_block](h1) + if i_level != self.num_resolutions - 1: + h1 = self.down[i_level].downsample(h1, conv_carry_in, conv_carry_out) + + out.append(h1) + conv_carry_in = conv_carry_out + + h = torch_cat_if_needed(out, dim=2) + del out # middle h = self.mid.block_1(h, temb) @@ -616,15 +692,15 @@ class Encoder(nn.Module): # end h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h = [nonlinearity(h)] + h = conv_carry_causal_3d(h, self.conv_out) return h class Decoder(nn.Module): def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, - resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + resolution, z_channels, tanh_out=False, use_linear_attn=False, conv_out_op=ops.Conv2d, resnet_op=ResnetBlock, attn_op=AttnBlock, @@ -638,12 +714,18 @@ class Decoder(nn.Module): self.num_res_blocks = num_res_blocks self.resolution = resolution self.in_channels = in_channels - self.give_pre_end = give_pre_end 
self.tanh_out = tanh_out + self.carried = False if conv3d: - conv_op = VideoConv3d - conv_out_op = VideoConv3d + if not attn_resolutions and resnet_op == ResnetBlock: + conv_op = CarriedConv3d + conv_out_op = CarriedConv3d + self.carried = True + else: + conv_op = VideoConv3d + conv_out_op = VideoConv3d + mid_attn_conv_op = ops.Conv3d else: conv_op = ops.Conv2d @@ -718,29 +800,43 @@ class Decoder(nn.Module): temb = None # z to block_in - h = self.conv_in(z) + h = conv_carry_causal_3d([z], self.conv_in) # middle h = self.mid.block_1(h, temb, **kwargs) h = self.mid.attn_1(h, **kwargs) h = self.mid.block_2(h, temb, **kwargs) + if self.carried: + h = torch.split(h, 2, dim=2) + else: + h = [h] + out = [] + + conv_carry_in = None + # upsampling - for i_level in reversed(range(self.num_resolutions)): - for i_block in range(self.num_res_blocks + 1): - h = self.up[i_level].block[i_block](h, temb, **kwargs) - if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h, **kwargs) - if i_level != 0: - h = self.up[i_level].upsample(h) + for i, h1 in enumerate(h): + conv_carry_out = [] + if i == len(h) - 1: + conv_carry_out = None + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h1 = self.up[i_level].block[i_block](h1, temb, conv_carry_in, conv_carry_out, **kwargs) + if len(self.up[i_level].attn) > 0: + assert i == 0 # carried should not happen if attn exists + h1 = self.up[i_level].attn[i_block](h1, **kwargs) + if i_level != 0: + h1 = self.up[i_level].upsample(h1, conv_carry_in, conv_carry_out) - # end - if self.give_pre_end: - return h + h1 = self.norm_out(h1) + h1 = [nonlinearity(h1)] + h1 = conv_carry_causal_3d(h1, self.conv_out, conv_carry_in, conv_carry_out) + if self.tanh_out: + h1 = torch.tanh(h1) + out.append(h1) + conv_carry_in = conv_carry_out - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h, **kwargs) - if self.tanh_out: - h = torch.tanh(h) - return h + out = torch_cat_if_needed(out, dim=2) + + return out diff --git a/comfy/ldm/qwen_image/controlnet.py b/comfy/ldm/qwen_image/controlnet.py index 92ac3cf0a..a6d408104 100644 --- a/comfy/ldm/qwen_image/controlnet.py +++ b/comfy/ldm/qwen_image/controlnet.py @@ -44,7 +44,7 @@ class QwenImageControlNetModel(QwenImageTransformer2DModel): txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) + self.controlnet_x_embedder(hint) diff --git a/comfy/ldm/qwen_image/model.py b/comfy/ldm/qwen_image/model.py index 76f66f7c0..f19972683 100644 --- a/comfy/ldm/qwen_image/model.py +++ b/comfy/ldm/qwen_image/model.py @@ -10,6 +10,7 @@ from ..flux.layers import EmbedND from ..lightricks.model import TimestepEmbedding, Timesteps from ..modules.attention import optimized_attention_masked from ...patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP +from ..flux.math import apply_rope1 class GELU(nn.Module): @@ -137,33 +138,34 @@ class Attention(nn.Module): ) -> Tuple[torch.Tensor, torch.Tensor]: if transformer_options is None: transformer_options = {} + batch_size = 
hidden_states.shape[0] + seq_img = hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1] - img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1)) - img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1)) - img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1)) + # Project and reshape to BHND format (batch, heads, seq, dim) + img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() + img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2) - txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) - txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) - txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1)) + txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous() + txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2) img_query = self.norm_q(img_query) img_key = self.norm_k(img_key) txt_query = self.norm_added_q(txt_query) txt_key = self.norm_added_k(txt_key) - joint_query = torch.cat([txt_query, img_query], dim=1) - joint_key = torch.cat([txt_key, img_key], dim=1) - joint_value = torch.cat([txt_value, img_value], dim=1) + joint_query = torch.cat([txt_query, img_query], dim=2) + joint_key = torch.cat([txt_key, img_key], dim=2) + joint_value = torch.cat([txt_value, img_value], dim=2) - joint_query = apply_rotary_emb(joint_query, image_rotary_emb) - joint_key = apply_rotary_emb(joint_key, image_rotary_emb) + joint_query = apply_rope1(joint_query, image_rotary_emb) + joint_key = apply_rope1(joint_key, image_rotary_emb) - joint_query = joint_query.flatten(start_dim=2) - joint_key = joint_key.flatten(start_dim=2) - joint_value = joint_value.flatten(start_dim=2) - - joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options) + joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, + attention_mask, transformer_options=transformer_options, + skip_reshape=True) txt_attn_output = joint_hidden_states[:, :seq_txt, :] img_attn_output = joint_hidden_states[:, seq_txt:, :] @@ -239,10 +241,10 @@ class QwenImageTransformerBlock(nn.Module): img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) - img_normed = self.img_norm1(hidden_states) - img_modulated, img_gate1 = self._modulate(img_normed, img_mod1) - txt_normed = self.txt_norm1(encoder_hidden_states) - txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) + img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1) + del img_mod1 + txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1) + del txt_mod1 img_attn_output, txt_attn_output = self.attn( hidden_states=img_modulated, @@ -251,16 +253,20 @@ class QwenImageTransformerBlock(nn.Module): image_rotary_emb=image_rotary_emb, transformer_options=transformer_options, ) + del img_modulated + del txt_modulated hidden_states = hidden_states + img_gate1 * img_attn_output encoder_hidden_states = 
encoder_hidden_states + txt_gate1 * txt_attn_output + del img_attn_output + del txt_attn_output + del img_gate1 + del txt_gate1 - img_normed2 = self.img_norm2(hidden_states) - img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2) + img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2) hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2)) - txt_normed2 = self.txt_norm2(encoder_hidden_states) - txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2) + txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2) encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2)) return encoder_hidden_states, hidden_states @@ -421,7 +427,7 @@ class QwenImageTransformer2DModel(nn.Module): txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) ids = torch.cat((txt_ids, img_ids), dim=1) - image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype) + image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous() del ids, txt_ids, img_ids hidden_states = self.img_in(hidden_states) @@ -441,7 +447,10 @@ class QwenImageTransformer2DModel(nn.Module): patches = transformer_options.get("patches", {}) blocks_replace = patches_replace.get("dit", {}) + transformer_options["total_blocks"] = len(self.transformer_blocks) + transformer_options["block_type"] = "double" for i, block in enumerate(self.transformer_blocks): + transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): out = {} diff --git a/comfy/ldm/wan/model.py b/comfy/ldm/wan/model.py index 5f1d4463e..975e4af2d 100644 --- a/comfy/ldm/wan/model.py +++ b/comfy/ldm/wan/model.py @@ -246,6 +246,7 @@ class WanAttentionBlock(nn.Module): # assert e[0].dtype == torch.float32 # self-attention + x = x.contiguous() # otherwise implicit in LayerNorm y = self.self_attn( torch.addcmul(repeat_e(e[0], x), self.norm1(x), 1 + repeat_e(e[1], x)), freqs, transformer_options=transformer_options) @@ -615,7 +616,7 @@ class WanModel(torch.nn.Module): x = self.unpatchify(x, grid_sizes) return x - def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None): + def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}): patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) h_len = ((h + (patch_size[1] // 2)) // patch_size[1]) @@ -628,10 +629,22 @@ class WanModel(torch.nn.Module): if steps_w is None: steps_w = w_len + h_start = 0 + w_start = 0 + rope_options = transformer_options.get("rope_options", None) + if rope_options is not None: + t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0 + h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0 + w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0 + + t_start += rope_options.get("shift_t", 0.0) + h_start += rope_options.get("shift_y", 0.0) + w_start += rope_options.get("shift_x", 0.0) + img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype) img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1) - 
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) - img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) + img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1) + img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1) img_ids = img_ids.reshape(1, -1, img_ids.shape[-1]) freqs = self.rope_embedder(img_ids).movedim(1, 2) @@ -661,7 +674,7 @@ class WanModel(torch.nn.Module): if self.ref_conv is not None and "reference_latent" in kwargs: t_len += 1 - freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype) + freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options) return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w] def unpatchify(self, x, grid_sizes): diff --git a/comfy/lora.py b/comfy/lora.py index ff87dd3c0..24592d227 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -326,6 +326,22 @@ def model_lora_keys_unet(model, key_map=None): key_map["transformer.{}".format(key_lora)] = k key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = k # SimpleTuner lycoris format + if isinstance(model, comfy.model_base.Lumina2): + diffusers_keys = comfy.utils.z_image_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = k[:-len(".weight")] + key_map["diffusion_model.{}".format(key_lora)] = to + key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to + + if isinstance(model, comfy.model_base.Kandinsky5): + for k in sdk: + if k.startswith("diffusion_model.") and k.endswith(".weight"): + key_lora = k[len("diffusion_model."):-len(".weight")] + key_map["{}".format(key_lora)] = k + key_map["transformer.{}".format(key_lora)] = k + return key_map diff --git a/comfy/model_base.py b/comfy/model_base.py index 0de79ae78..92e9fbc19 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -53,6 +53,7 @@ from .ldm.modules.encoders.noise_aug_modules import CLIPEmbeddingNoiseAugmentati from .ldm.chroma_radiance import model as chroma_radiance from .ldm.omnigen.omnigen2 import OmniGen2Transformer2DModel from .ldm.pixart.pixartms import PixArtMS +from .ldm.kandinsky5.model import Kandinsky5 from .ldm.qwen_image.model import QwenImageTransformer2DModel from .ldm.wan.model import WanModel, VaceWanModel, CameraWanModel, WanModel_S2V, HumoWanModel from .ldm.wan.model_animate import AnimateWanModel @@ -149,7 +150,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: fp8 = model_config.optimizations.get("fp8", False) - operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8) + operations = ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config) else: operations = model_config.custom_operations self.operations = operations @@ -216,8 +217,14 @@ class BaseModel(torch.nn.Module): extra_conds[o] = extra t = self.process_timestep(t, x=x, **extra_conds) 
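# --- Illustrative sketch (not part of the patch) ---------------------------
# Several models above (WanModel.rope_encode, the NextDiT patchify path) now
# read an optional "rope_options" dict from transformer_options and remap
# their RoPE position grids as  position -> position * scale + shift  per
# axis.  A minimal sketch of that mapping for one axis, assuming the same
# convention as rope_encode above; the helper name is hypothetical.
import torch

def scaled_axis_positions(length, scale=1.0, shift=0.0, device=None, dtype=torch.float32):
    # `scale` stretches the spacing between tokens, `shift` offsets the whole
    # axis; scale=1, shift=0 reproduces the default 0..length-1 grid.
    return torch.linspace(shift, shift + (length - 1.0) * scale, steps=length,
                          device=device, dtype=dtype)

print(scaled_axis_positions(4))                       # tensor([0., 1., 2., 3.])
print(scaled_axis_positions(4, scale=0.5, shift=2.0)) # tensor([2.0000, 2.5000, 3.0000, 3.5000])
# ---------------------------------------------------------------------------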
- model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float() - return self.model_sampling.calculate_denoised(sigma, model_output, x) + if "latent_shapes" in extra_conds: + xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes")) + + model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds) + if len(model_output) > 1 and not torch.is_tensor(model_output): + model_output, _ = utils.pack_latents(model_output) + + return self.model_sampling.calculate_denoised(sigma, model_output.float(), x) def process_timestep(self, timestep, **kwargs): return timestep @@ -343,10 +350,6 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) unet_state_dict = self.diffusion_model.state_dict() - - if self.model_config.scaled_fp8 is not None: - unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) - unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) if self.model_type == ModelType.V_PREDICTION: @@ -921,12 +924,13 @@ class Flux(BaseModel): attention_mask = kwargs.get("attention_mask", None) if attention_mask is not None: shape = kwargs["noise"].shape - mask_ref_size = kwargs["attention_mask_img_shape"] - # the model will pad to the patch size, and then divide - # essentially dividing and rounding up - (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) - attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) - out['attention_mask'] = conds.CONDRegular(attention_mask) + mask_ref_size = kwargs.get("attention_mask_img_shape", None) + if mask_ref_size is not None: + # the model will pad to the patch size, and then divide + # essentially dividing and rounding up + (h_tok, w_tok) = (math.ceil(shape[2] / self.diffusion_model.patch_size), math.ceil(shape[3] / self.diffusion_model.patch_size)) + attention_mask = utils.upscale_dit_mask(attention_mask, mask_ref_size, (h_tok, w_tok)) + out['attention_mask'] = conds.CONDRegular(attention_mask) guidance = kwargs.get("guidance", 3.5) if guidance is not None: @@ -948,7 +952,19 @@ class Flux(BaseModel): out = {} ref_latents = kwargs.get("reference_latents", None) if ref_latents is not None: - out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()), ref_latents)) // 16]) + out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))]) + return out + + +class Flux2(Flux): + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + target_text_len = 512 + if cross_attn.shape[1] < target_text_len: + cross_attn = torch.nn.functional.pad(cross_attn, (0, 0, target_text_len - cross_attn.shape[1], 0)) + out['c_crossattn'] = conds.CONDRegular(cross_attn) return out @@ -1135,6 +1151,12 @@ class Lumina2(BaseModel): cross_attn = kwargs.get("cross_attn", None) if cross_attn is not None: out['c_crossattn'] = conds.CONDRegular(cross_attn) + if 'num_tokens' not in out: + out['num_tokens'] = conds.CONDConstant(cross_attn.shape[1]) + + clip_text_pooled = kwargs["pooled_output"] # Newbie + if clip_text_pooled is not None: + out['clip_text_pooled'] = conds.CONDRegular(clip_text_pooled) return out @@ -1580,3 +1602,144 @@ class 
HunyuanImage21Refiner(HunyuanImage21): out = super().extra_conds(**kwargs) out['disable_time_r'] = conds.CONDConstant(True) return out + + +class HunyuanVideo15(HunyuanVideo): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + extra_channels = self.diffusion_model.img_in.proj.weight.shape[1] - noise.shape[1] - 1 # noise 32 img cond 32 + mask 1 + if extra_channels == 0: + return None + + image = kwargs.get("concat_latent_image", None) + device = kwargs["device"] + + if image is None: + shape_image = list(noise.shape) + shape_image[1] = extra_channels + image = torch.zeros(shape_image, dtype=noise.dtype, layout=noise.layout, device=noise.device) + else: + latent_dim = self.latent_format.latent_channels + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + for i in range(0, image.shape[1], latent_dim): + image[:, i: i + latent_dim] = self.process_latent_in(image[:, i: i + latent_dim]) + image = utils.resize_to_batch_size(image, noise.shape[0]) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :1] + else: + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + return torch.cat((image, mask), dim=1) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + if torch.numel(attention_mask) != attention_mask.sum(): + out['attention_mask'] = conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = conds.CONDRegular(cross_attn) + + conditioning_byt5small = kwargs.get("conditioning_byt5small", None) + if conditioning_byt5small is not None: + out['txt_byt5'] = conds.CONDRegular(conditioning_byt5small) + + guidance = kwargs.get("guidance", 6.0) + if guidance is not None: + out['guidance'] = conds.CONDRegular(torch.FloatTensor([guidance])) + + clip_vision_output = kwargs.get("clip_vision_output", None) + if clip_vision_output is not None: + out['clip_fea'] = conds.CONDRegular(clip_vision_output.last_hidden_state) + + return out + + +class HunyuanVideo15_SR_Distilled(HunyuanVideo15): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + image = kwargs.get("concat_latent_image", None) + noise_augmentation = kwargs.get("noise_augmentation", 0.0) + device = kwargs["device"] + + if image is None: + image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=model_management.intermediate_device()) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + # image = self.process_latent_in(image) # scaling wasn't applied in reference code + image = utils.resize_to_batch_size(image, noise.shape[0]) + lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1) + if noise_augmentation > 0: + generator 
= torch.Generator(device="cpu") + generator.manual_seed(kwargs.get("seed", 0) - 10) + noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device) + image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice] + else: + image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice] + return image + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + out['disable_time_r'] = conds.CONDConstant(False) + return out + + +class Kandinsky5(BaseModel): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device, unet_model=Kandinsky5) + + def encode_adm(self, **kwargs): + return kwargs["pooled_output"] + + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + device = kwargs["device"] + image = torch.zeros_like(noise) + + mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None)) + if mask is None: + mask = torch.zeros_like(noise)[:, :1] + else: + mask = 1.0 - mask + mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + if mask.shape[-3] < noise.shape[-3]: + mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0) + mask = utils.resize_to_batch_size(mask, noise.shape[0]) + + return torch.cat((image, mask), dim=1) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + attention_mask = kwargs.get("attention_mask", None) + if attention_mask is not None: + out['attention_mask'] = conds.CONDRegular(attention_mask) + cross_attn = kwargs.get("cross_attn", None) + if cross_attn is not None: + out['c_crossattn'] = conds.CONDRegular(cross_attn) + + time_dim_replace = kwargs.get("time_dim_replace", None) + if time_dim_replace is not None: + out['time_dim_replace'] = conds.CONDRegular(self.process_latent_in(time_dim_replace)) + + return out + + +class Kandinsky5Image(Kandinsky5): + def __init__(self, model_config, model_type=ModelType.FLOW, device=None): + super().__init__(model_config, model_type, device=device) + + def concat_cond(self, **kwargs): + return None diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 8e481b4db..88443ac1a 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -8,6 +8,7 @@ import torch from . import supported_models, utils from . 
import supported_models_base from .gguf import GGMLOps +from .utils import detect_layer_quantization logger = logging.getLogger(__name__) @@ -180,30 +181,71 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): guidance_keys = list(filter(lambda a: a.startswith("{}guidance_in.".format(key_prefix)), state_dict_keys)) dit_config["guidance_embed"] = len(guidance_keys) > 0 + + # HunyuanVideo 1.5 + if '{}cond_type_embedding.weight'.format(key_prefix) in state_dict_keys: + dit_config["use_cond_type_embedding"] = True + else: + dit_config["use_cond_type_embedding"] = False + if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys: + dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0] + else: + dit_config["vision_in_dim"] = None return dit_config if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): # Flux, Chroma or Chroma Radiance (has no img_in.weight) dit_config = {} - dit_config["image_model"] = "flux" + if '{}double_stream_modulation_img.lin.weight'.format(key_prefix) in state_dict_keys: + dit_config["image_model"] = "flux2" + dit_config["axes_dim"] = [32, 32, 32, 32] + dit_config["num_heads"] = 48 + dit_config["mlp_ratio"] = 3.0 + dit_config["theta"] = 2000 + dit_config["out_channels"] = 128 + dit_config["global_modulation"] = True + dit_config["mlp_silu_act"] = True + dit_config["qkv_bias"] = False + dit_config["ops_bias"] = False + dit_config["default_ref_method"] = "index" + dit_config["ref_index_scale"] = 10.0 + dit_config["txt_ids_dims"] = [3] + patch_size = 1 + else: + dit_config["image_model"] = "flux" + dit_config["axes_dim"] = [16, 56, 56] + dit_config["num_heads"] = 24 + dit_config["mlp_ratio"] = 4.0 + dit_config["theta"] = 10000 + dit_config["out_channels"] = 16 + dit_config["qkv_bias"] = True + dit_config["txt_ids_dims"] = [] + patch_size = 2 + dit_config["in_channels"] = 16 - patch_size = 2 + dit_config["hidden_size"] = 3072 + dit_config["context_in_dim"] = 4096 + dit_config["patch_size"] = patch_size in_key = "{}img_in.weight".format(key_prefix) if in_key in state_dict_keys: - dit_config["in_channels"] = state_dict[in_key].shape[1] // (patch_size * patch_size) - dit_config["out_channels"] = 16 + w = state_dict[in_key] + dit_config["in_channels"] = w.shape[1] // (patch_size * patch_size) + dit_config["hidden_size"] = w.shape[0] + + txt_in_key = "{}txt_in.weight".format(key_prefix) + if txt_in_key in state_dict_keys: + w = state_dict[txt_in_key] + dit_config["context_in_dim"] = w.shape[1] + dit_config["hidden_size"] = w.shape[0] + vec_in_key = '{}vector_in.in_layer.weight'.format(key_prefix) if vec_in_key in state_dict_keys: dit_config["vec_in_dim"] = state_dict[vec_in_key].shape[1] - dit_config["context_in_dim"] = 4096 - dit_config["hidden_size"] = 3072 - dit_config["mlp_ratio"] = 4.0 - dit_config["num_heads"] = 24 + else: + dit_config["vec_in_dim"] = None + dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.') dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.') - dit_config["axes_dim"] = [16, 56, 56] - dit_config["theta"] = 10000 - dit_config["qkv_bias"] = True if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: # 
Chroma dit_config["image_model"] = "chroma" dit_config["in_channels"] = 64 @@ -226,6 +268,11 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["nerf_embedder_dtype"] = torch.float32 else: dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys + dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys + dit_config["txt_norm"] = "{}txt_norm.scale".format(key_prefix) in state_dict_keys + if dit_config["yak_mlp"] and dit_config["txt_norm"]: # Ovis model + dit_config["txt_ids_dims"] = [1, 2] + return dit_config if '{}t5_yproj.weight'.format(key_prefix) in state_dict_keys: # Genmo mochi preview @@ -372,14 +419,34 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["image_model"] = "lumina2" dit_config["patch_size"] = 2 dit_config["in_channels"] = 16 - dit_config["dim"] = 2304 - dit_config["cap_feat_dim"] = state_dict['{}cap_embedder.1.weight'.format(key_prefix)].shape[1] + w = state_dict['{}cap_embedder.1.weight'.format(key_prefix)] + dit_config["dim"] = w.shape[0] + dit_config["cap_feat_dim"] = w.shape[1] dit_config["n_layers"] = count_blocks(state_dict_keys, '{}layers.'.format(key_prefix) + '{}.') - dit_config["n_heads"] = 24 - dit_config["n_kv_heads"] = 8 dit_config["qk_norm"] = True - dit_config["axes_dims"] = [32, 32, 32] - dit_config["axes_lens"] = [300, 512, 512] + + if dit_config["dim"] == 2304: # Original Lumina 2 + dit_config["n_heads"] = 24 + dit_config["n_kv_heads"] = 8 + dit_config["axes_dims"] = [32, 32, 32] + dit_config["axes_lens"] = [300, 512, 512] + dit_config["rope_theta"] = 10000.0 + dit_config["ffn_dim_multiplier"] = 4.0 + ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None) + if ctd_weight is not None: + dit_config["clip_text_dim"] = ctd_weight.shape[0] + elif dit_config["dim"] == 3840: # Z image + dit_config["n_heads"] = 30 + dit_config["n_kv_heads"] = 30 + dit_config["axes_dims"] = [32, 48, 48] + dit_config["axes_lens"] = [1536, 512, 512] + dit_config["rope_theta"] = 256.0 + dit_config["ffn_dim_multiplier"] = (8.0 / 3.0) + dit_config["z_image_modulation"] = True + dit_config["time_scale"] = 1000.0 + if '{}cap_pad_token'.format(key_prefix) in state_dict_keys: + dit_config["pad_tokens_multiple"] = 32 + return dit_config if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1 @@ -556,6 +623,24 @@ def detect_unet_config(state_dict, key_prefix, metadata=None): dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.') return dit_config + if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5 + dit_config = {} + model_dim = state_dict['{}visual_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["model_dim"] = model_dim + if model_dim in [4096, 2560]: # pro video and lite image + dit_config["axes_dims"] = (32, 48, 48) + if model_dim == 2560: # lite image + dit_config["rope_scale_factor"] = (1.0, 1.0, 1.0) + elif model_dim == 1792: # lite video + dit_config["axes_dims"] = (16, 24, 24) + dit_config["time_dim"] = state_dict['{}time_embeddings.in_layer.bias'.format(key_prefix)].shape[0] + dit_config["image_model"] = "kandinsky5" + dit_config["ff_dim"] = state_dict['{}visual_transformer_blocks.0.feed_forward.in_layer.weight'.format(key_prefix)].shape[0] + dit_config["visual_embed_dim"] = 
state_dict['{}visual_embeddings.in_layer.weight'.format(key_prefix)].shape[1] + dit_config["num_text_blocks"] = count_blocks(state_dict_keys, '{}text_transformer_blocks.'.format(key_prefix) + '{}.') + dit_config["num_visual_blocks"] = count_blocks(state_dict_keys, '{}visual_transformer_blocks.'.format(key_prefix) + '{}.') + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -699,16 +784,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal if model_config is None and use_base_if_no_match: model_config = supported_models_base.BASE(unet_config) - scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix) - if scaled_fp8_key in state_dict: - scaled_fp8_weight = state_dict.pop(scaled_fp8_key) - model_config.scaled_fp8 = scaled_fp8_weight.dtype - if model_config.scaled_fp8 == torch.float32: - model_config.scaled_fp8 = torch.float8_e4m3fn - if scaled_fp8_weight.nelement() == 2: - model_config.optimizations["fp8"] = False - else: - model_config.optimizations["fp8"] = True + # Detect per-layer quantization (mixed precision) + quant_config = detect_layer_quantization(state_dict, unet_key_prefix) + if quant_config: + model_config.quant_config = quant_config + logging.info("Detected mixed precision quantization") if metadata is not None and "format" in metadata and metadata["format"] == "gguf": model_config.custom_operations = GGMLOps diff --git a/comfy/model_management.py b/comfy/model_management.py index bf3aea14e..b47190c01 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -114,6 +114,7 @@ if args.deterministic: directml_device = None if args.directml is not None: + logging.warning("WARNING: torch-directml barely works, is very slow, has not been updated in over 1 year and might be removed soon, please don't use it, there are better options.") import torch_directml # pylint: disable=import-error device_index = args.directml @@ -380,15 +381,20 @@ except: pass SUPPORT_FP8_OPS = args.supports_fp8_compute + +AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"] + try: if is_amd(): - torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD - logger.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") + arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)): + torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD + logger.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.") try: rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2])) except: rocm_version = (6, -1) - arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName + logger.debug("AMD arch: {}".format(arch)) logger.debug("ROCm version: {}".format(rocm_version)) if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: @@ -557,6 +563,7 @@ class LoadedModel: if use_more_vram == 0: use_more_vram = 1e32 self.model_use_more_vram(use_more_vram, force_patch_weights=force_patch_weights) + real_model = self.model.model if is_intel_xpu() and not args.disable_ipex_optimize and 'ipex' in globals() and real_model is not None: @@ -811,8 +818,11 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0 loaded_memory = loaded_model.model_loaded_memory() current_free_mem = get_free_memory(torch_dev) + loaded_memory - lowvram_model_memory = 
max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) - lowvram_model_memory = max(0.1, lowvram_model_memory - loaded_memory) + lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory())) + lowvram_model_memory = lowvram_model_memory - loaded_memory + + if lowvram_model_memory == 0: + lowvram_model_memory = 0.1 if vram_set_state == VRAMState.NO_VRAM: lowvram_model_memory = 0.1 @@ -1149,13 +1159,6 @@ def device_supports_non_blocking(device): return True -def device_should_use_non_blocking(device): - if not device_supports_non_blocking(device): - return False - return False - # return True #TODO: figure out why this causes memory issues on Nvidia and possibly others - - def force_channels_last(): if args.force_channels_last: return True @@ -1165,57 +1168,77 @@ def force_channels_last(): STREAMS = {} -NUM_STREAMS = 1 -if args.async_offload: - NUM_STREAMS = 2 +NUM_STREAMS = 0 +if args.async_offload is not None: + NUM_STREAMS = args.async_offload +else: + # Enable by default on Nvidia + if is_nvidia(): + NUM_STREAMS = 2 + +if args.disable_async_offload: + NUM_STREAMS = 0 + +if NUM_STREAMS > 0: logger.debug("Using async weight offloading with {} streams".format(NUM_STREAMS)) + +def current_stream(device): + if device is None: + return None + if is_device_cuda(device): + return torch.cuda.current_stream() + elif is_device_xpu(device): + return torch.xpu.current_stream() + else: + return None + + stream_counters = {} def get_offload_stream(device): stream_counter = stream_counters.get(device, 0) - if NUM_STREAMS <= 1: + if NUM_STREAMS == 0: + return None + + if torch.compiler.is_compiling(): return None if device in STREAMS: ss = STREAMS[device] - s = ss[stream_counter] + # Sync the oldest stream in the queue with the current + ss[stream_counter].wait_stream(current_stream(device)) stream_counter = (stream_counter + 1) % len(ss) - if is_device_cuda(device): - ss[stream_counter].wait_stream(torch.cuda.current_stream()) - elif is_device_xpu(device): - ss[stream_counter].wait_stream(torch.xpu.current_stream()) stream_counters[device] = stream_counter - return s + return ss[stream_counter] elif is_device_cuda(device): ss = [] for k in range(NUM_STREAMS): - ss.append(torch.cuda.Stream(device=device, priority=0)) + s1 = torch.cuda.Stream(device=device, priority=0) + s1.as_context = torch.cuda.stream + ss.append(s1) STREAMS[device] = ss s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter return s elif is_device_xpu(device): ss = [] for k in range(NUM_STREAMS): - ss.append(torch.xpu.Stream(device=device, priority=0)) + s1 = torch.xpu.Stream(device=device, priority=0) + s1.as_context = torch.xpu.stream + ss.append(s1) STREAMS[device] = ss s = ss[stream_counter] - stream_counter = (stream_counter + 1) % len(ss) stream_counters[device] = stream_counter return s return None def sync_stream(device, stream): - if stream is None: + if stream is None or current_stream(device) is None: return - if is_device_cuda(device): - torch.cuda.current_stream().wait_stream(stream) - elif is_device_xpu(device): - torch.xpu.current_stream().wait_stream(stream) + current_stream(device).wait_stream(stream) def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None): @@ -1224,12 +1247,18 @@ def cast_to(weight, 
dtype=None, device=None, non_blocking=False, copy=False, str if dtype is None or weight.dtype == dtype: return weight if stream is not None: - with stream: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + with wf_context: return weight.to(dtype=dtype, copy=copy) return weight.to(dtype=dtype, copy=copy) if stream is not None: - with stream: + wf_context = stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(stream) + with wf_context: r = torch.empty_like(weight, dtype=dtype, device=device) r.copy_(weight, non_blocking=non_blocking) else: @@ -1243,6 +1272,85 @@ def cast_to_device(tensor, device, dtype, copy=False): return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) +PINNED_MEMORY = {} +TOTAL_PINNED_MEMORY = 0 +MAX_PINNED_MEMORY = -1 +if not args.disable_pinned_memory: + if is_nvidia() or is_amd(): + if WINDOWS: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.45 # Windows limit is apparently 50% + else: + MAX_PINNED_MEMORY = get_total_memory(torch.device("cpu")) * 0.95 + logging.info("Enabled pinned memory {}".format(MAX_PINNED_MEMORY // (1024 * 1024))) + +PINNING_ALLOWED_TYPES = set(["Parameter", "QuantizedTensor"]) + + +def pin_memory(tensor): + global TOTAL_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: + return False + + if type(tensor).__name__ not in PINNING_ALLOWED_TYPES: + return False + + if not is_device_cpu(tensor.device): + return False + + if tensor.is_pinned(): + # NOTE: Cuda does detect when a tensor is already pinned and would + # error below, but there are proven cases where this also queues an error + # on the GPU async. So dont trust the CUDA API and guard here + return False + + if not tensor.is_contiguous(): + return False + + size = tensor.numel() * tensor.element_size() + if (TOTAL_PINNED_MEMORY + size) > MAX_PINNED_MEMORY: + return False + + ptr = tensor.data_ptr() + if ptr == 0: + return False + + if torch.cuda.cudart().cudaHostRegister(ptr, size, 1) == 0: + PINNED_MEMORY[ptr] = size + TOTAL_PINNED_MEMORY += size + return True + + return False + + +def unpin_memory(tensor): + global TOTAL_PINNED_MEMORY + if MAX_PINNED_MEMORY <= 0: + return False + + if not is_device_cpu(tensor.device): + return False + + ptr = tensor.data_ptr() + size = tensor.numel() * tensor.element_size() + + size_stored = PINNED_MEMORY.get(ptr, None) + if size_stored is None: + logging.warning("Tried to unpin tensor not pinned by ComfyUI") + return False + + if size != size_stored: + logging.warning("Size of pinned tensor changed") + return False + + if torch.cuda.cudart().cudaHostUnregister(ptr) == 0: + TOTAL_PINNED_MEMORY -= PINNED_MEMORY.pop(ptr) + if len(PINNED_MEMORY) == 0: + TOTAL_PINNED_MEMORY = 0 + return True + + return False + + def sage_attention_enabled(): return args.use_sage_attention @@ -1531,7 +1639,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if is_amd(): arch = torch.cuda.get_device_properties(device).gcnArchName - if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16 + if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16 if manual_cast: return True return False @@ -1607,6 +1715,23 @@ def extended_fp16_support(): return True +LORA_COMPUTE_DTYPES = {} + + +def lora_compute_dtype(device): + dtype = LORA_COMPUTE_DTYPES.get(device, None) + if dtype is not None: + return 
dtype + + if should_use_fp16(device): + dtype = torch.float16 + else: + dtype = torch.float32 + + LORA_COMPUTE_DTYPES[device] = dtype + return dtype + + def soft_empty_cache(force=False): with model_management_lock: _soft_empty_cache(force=force) diff --git a/comfy/model_management_types.py b/comfy/model_management_types.py index 07731e92c..00ec4079b 100644 --- a/comfy/model_management_types.py +++ b/comfy/model_management_types.py @@ -353,6 +353,9 @@ class MemoryMeasurements: current_weight_patches_uuid: Any = None _device: torch.device | None = None + def __init__(self): + self.model_offload_buffer_memory = None + @property def device(self) -> torch.device: if isinstance(self.model, DeviceSettable): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 453d812e6..f79bbe880 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -40,8 +40,10 @@ from .component_model.deprecation import _deprecate_method from .float import stochastic_rounding from .gguf import move_patch_to_device, is_torch_compatible, is_quantized, GGMLOps from .hooks import EnumHookMode, _HookRef, HookGroup, EnumHookType, WeightHook, create_transformer_options_from_hooks +from .lora import calculate_weight from .lora_types import PatchDict, PatchDictKey, PatchTuple, PatchWeightTuple, ModelPatchesDictValue, PatchSupport from .model_base import BaseModel +from .model_management import lora_compute_dtype from .model_management_types import ModelManageable, MemoryMeasurements, ModelOptions, LatentFormatT, LoadingListItem, TrainingSupport, HooksSupport from .patcher_extension import CallbacksMP, WrappersMP, PatcherInjection @@ -144,27 +146,23 @@ class LowVramPatch: def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches - self.convert_func = convert_func + self.convert_func = convert_func # TODO: remove self.set_func = set_func def __call__(self, weight): - intermediate_dtype = weight.dtype - if self.convert_func is not None: - weight = self.convert_func(weight.to(dtype=torch.float32, copy=True), inplace=True) + return calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) - if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: # intermediate_dtype has to be one that is supported in math ops - intermediate_dtype = torch.float32 - out = lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is None: - return stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key)) - else: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True) - out = lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is not None: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype) - else: - return out +# The above patch logic may cast up the weight to fp32, and do math. 
Go with fp32 x 3 +LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 + + +def low_vram_patch_estimate_vram(model, key): + weight, set_func, convert_func = get_key_weight(model, key) + if weight is None: + return 0 + return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR + def get_key_weight(model, key): set_func = None @@ -253,7 +251,6 @@ class ModelPatcher(ModelManageable, PatchSupport): self.object_patches_backup = {} self.weight_wrapper_patches = {} self._model_options: ModelOptions = {"transformer_options": {}} - self.model_size() self.load_device = load_device self.offload_device = offload_device self.weight_inplace_update = weight_inplace_update @@ -262,6 +259,7 @@ class ModelPatcher(ModelManageable, PatchSupport): self.patches_uuid: uuid.UUID = uuid.uuid4() self.ckpt_name = ckpt_name self._memory_measurements = MemoryMeasurements(self.model) + self.pinned = set() self.attachments: dict[str] = {} self.additional_models: dict[str, list[ModelPatcher]] = {} self.callbacks: dict[str, dict[str, list[Callable]]] = CallbacksMP.init_callbacks() @@ -322,17 +320,28 @@ class ModelPatcher(ModelManageable, PatchSupport): def lowvram_patch_counter(self): return self._memory_measurements.lowvram_patch_counter + @property + def model_offload_buffer_memory(self) -> int: + return self._memory_measurements.model_offload_buffer_memory + + @model_offload_buffer_memory.setter + def model_offload_buffer_memory(self, value): + self._memory_measurements.model_offload_buffer_memory = value + def model_size(self): if self.size > 0: return self.size self.size = model_management.module_size(self.model) return self.size + def get_ram_usage(self): + return self.model_size() + def loaded_size(self): return self._memory_measurements.model_loaded_weight_memory def clone(self): - n = self.__class__(self.model, self.load_device, self.offload_device, self.size, weight_inplace_update=self.weight_inplace_update) + n = self.__class__(self.model, self.load_device, self.offload_device, self.model_size(), weight_inplace_update=self.weight_inplace_update) n._memory_measurements = self._memory_measurements n.ckpt_name = self.ckpt_name n.patches = {} @@ -346,6 +355,7 @@ class ModelPatcher(ModelManageable, PatchSupport): n.backup = self.backup n.object_patches_backup = self.object_patches_backup n._parent = self + n.pinned = self.pinned n.force_cast_weights = self.force_cast_weights @@ -430,8 +440,11 @@ class ModelPatcher(ModelManageable, PatchSupport): return True def memory_required(self, input_shape) -> int: - assert isinstance(self.model, BaseModel) - return self.model.memory_required(input_shape=input_shape) + if isinstance(self.model, BaseModel): + return self.model.memory_required(input_shape=input_shape) + else: + # todo: some other heuristic to determine memory required + raise ValueError("unexpected call to memory required on object that doesn't have a BaseModel but is using ModelPatcher") def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False): if len(inspect.signature(sampler_cfg_function).parameters) == 3: @@ -504,6 +517,18 @@ class ModelPatcher(ModelManageable, PatchSupport): def set_model_post_input_patch(self, patch): self.set_model_patch(patch, "post_input") + def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs): + rope_options = self.model_options["transformer_options"].get("rope_options", {}) + rope_options["scale_x"] = scale_x + rope_options["scale_y"] = scale_y + rope_options["scale_t"] = scale_t + + 
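# NOTE (illustrative, not part of this patch): the scale/shift values collected here are
# stored under model_options["transformer_options"]["rope_options"]; model code that builds
# RoPE position embeddings is expected to read them from there (the consumer side is not
# shown in this diff). Hypothetical usage on a cloned patcher:
#
#     patched = patcher.clone()
#     patched.set_model_rope_options(scale_x=1.0, shift_x=0.0,
#                                    scale_y=1.0, shift_y=0.0,
#                                    scale_t=0.5, shift_t=0.0)
#     rope = patched.model_options["transformer_options"]["rope_options"]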
rope_options["shift_x"] = shift_x + rope_options["shift_y"] = shift_y + rope_options["shift_t"] = shift_t + + self.model_options["transformer_options"]["rope_options"] = rope_options + def add_object_patch(self, name, obj): self.object_patches[name] = obj @@ -686,10 +711,11 @@ class ModelPatcher(ModelManageable, PatchSupport): if key not in self.backup: self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update) + temp_dtype = lora_compute_dtype(device_to) if device_to is not None: - temp_weight = model_management.cast_to_device(weight, device_to, torch.float32, copy=True) + temp_weight = model_management.cast_to_device(weight, device_to, temp_dtype, copy=True) else: - temp_weight = weight.to(torch.float32, copy=True) + temp_weight = weight.to(temp_dtype, copy=True) if convert_func is not None: temp_weight = convert_func(temp_weight, inplace=True) @@ -703,6 +729,21 @@ class ModelPatcher(ModelManageable, PatchSupport): else: set_func(out_weight, inplace_update=inplace_update, seed=string_to_seed(key)) + def pin_weight_to_device(self, key): + weight, set_func, convert_func = get_key_weight(self.model, key) + if model_management.pin_memory(weight): + self.pinned.add(key) + + def unpin_weight(self, key): + if key in self.pinned: + weight, set_func, convert_func = get_key_weight(self.model, key) + model_management.unpin_memory(weight) + self.pinned.remove(key) + + def unpin_all_weights(self): + for key in list(self.pinned): + self.unpin_weight(key) + def _load_list(self) -> list[LoadingListItem]: loading = [] for n, m in self.model.named_modules(): @@ -715,7 +756,16 @@ class ModelPatcher(ModelManageable, PatchSupport): skip = True # skip random weights in non leaf modules break if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0): - loading.append(LoadingListItem(model_management.module_size(m), n, m, params)) + module_mem = model_management.module_size(m) + module_offload_mem = module_mem + if hasattr(m, "comfy_cast_weights"): + weight_key = "{}.weight".format(n) + bias_key = "{}.bias".format(n) + if weight_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key) + if bias_key in self.patches: + module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key) + loading.append(LoadingListItem(module_offload_mem, module_mem, n, m, params)) return loading def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False): @@ -727,25 +777,30 @@ class ModelPatcher(ModelManageable, PatchSupport): mem_counter = 0 patch_counter = 0 lowvram_counter = 0 + lowvram_mem_counter = 0 loading = self._load_list() load_completely: list[LoadingListItem] = [] + offloaded = [] + offload_buffer = 0 loading.sort(reverse=True) - for x in loading: - n = x.name - m = x.module - params = x.params - module_mem = x.module_size + for i, x in enumerate(loading): + module_offload_mem, module_mem, n, m, params = x lowvram_weight = False + potential_offload = max(offload_buffer, module_offload_mem + sum([x1[1] for x1 in loading[i + 1:i + 1 + model_management.NUM_STREAMS]])) + lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory + weight_key = "{}.weight".format(n) bias_key = "{}.bias".format(n) if not full_load and hasattr(m, "comfy_cast_weights"): - if mem_counter + module_mem >= lowvram_model_memory: + if not lowvram_fits: + offload_buffer = potential_offload lowvram_weight = True lowvram_counter += 
1 + lowvram_mem_counter += module_mem if hasattr(m, "prev_comfy_cast_weights"): # Already lowvramed continue @@ -771,13 +826,16 @@ class ModelPatcher(ModelManageable, PatchSupport): patch_counter += 1 cast_weight = True + offloaded.append((module_mem, n, m, params)) else: if hasattr(m, "comfy_cast_weights"): wipe_lowvram_weight(m) - if full_load or mem_counter + module_mem < lowvram_model_memory: + if full_load or lowvram_fits: mem_counter += module_mem load_completely.append(LoadingListItem(module_mem, n, m, params)) + else: + offload_buffer = potential_offload if cast_weight and hasattr(m, "comfy_cast_weights"): m.prev_comfy_cast_weights = m.comfy_cast_weights @@ -802,7 +860,11 @@ class ModelPatcher(ModelManageable, PatchSupport): continue for param in params: - self.patch_weight_to_device("{}.{}".format(n, param), device_to=device_to) + key = "{}.{}".format(n, param) + self.unpin_weight(key) + self.patch_weight_to_device(key, device_to=device_to) + if model_management.is_device_cuda(device_to): + torch.cuda.synchronize() models_loaded_regularly.append("name={} module={}".format(n, m)) m.comfy_patched_weights = True @@ -810,11 +872,21 @@ class ModelPatcher(ModelManageable, PatchSupport): for x in load_completely: x.module.to(device_to) + for x in offloaded: + n = x[1] + params = x[3] + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) + if lowvram_counter > 0: logger.debug(f"loaded partially lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB patch_counter={patch_counter}") + if hasattr(self.model, "model_lowvram"): + self.model.model_lowvram = True self._memory_measurements.model_lowvram = True else: logger.debug(f"loaded completely lowvram_model_memory={lowvram_model_memory / (1024 * 1024):.1f}MB mem_counter={mem_counter / (1024 * 1024):.1f}MB full_load={full_load}") + if hasattr(self.model, "model_lowvram"): + self.model.model_lowvram = False self._memory_measurements.model_lowvram = False if full_load: self.model.to(device_to) @@ -846,6 +918,7 @@ class ModelPatcher(ModelManageable, PatchSupport): self.model_device = device_to self._memory_measurements.model_loaded_weight_memory = mem_counter + self._memory_measurements.model_offload_buffer_memory = offload_buffer self._memory_measurements.current_weight_patches_uuid = self.patches_uuid for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD): @@ -881,6 +954,7 @@ class ModelPatcher(ModelManageable, PatchSupport): p.patches = [] if unpatch_weights: self.unpatch_hooks() + self.unpin_all_weights() if self._memory_measurements.model_lowvram: for m in self.model.modules(): move_weight_functions(m, device_to) @@ -907,6 +981,7 @@ class ModelPatcher(ModelManageable, PatchSupport): self.model.to(device_to) self.model_device = device_to self._memory_measurements.model_loaded_weight_memory = 0 + self._memory_measurements.model_offload_buffer_memory = 0 for m in self.model.modules(): if hasattr(m, "comfy_patched_weights"): @@ -918,7 +993,7 @@ class ModelPatcher(ModelManageable, PatchSupport): self.object_patches_backup.clear() - def partially_unload(self, device_to, memory_to_free=0): + def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=False): freed_layers: list[str] = [] with self.use_ejected(): hooks_unpatched = False @@ -926,13 +1001,18 @@ class ModelPatcher(ModelManageable, PatchSupport): patch_counter = 0 unload_list = self._load_list() unload_list.sort() + + offload_buffer = self._memory_measurements.model_offload_buffer_memory 
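# NOTE (illustrative sketch, not part of this patch): modules freed below end up with their
# CPU weights page-locked via pin_weight_to_device() -> model_management.pin_memory(), which
# registers the existing allocation with cudaHostRegister so later prefetches back to the GPU
# can be issued as non-blocking copies on the offload streams. The basic round trip, assuming
# pinning is enabled (MAX_PINNED_MEMORY > 0) and a CUDA device is present:
#
#     from comfy import model_management
#     w = torch.nn.Parameter(torch.zeros(1024, 1024), requires_grad=False)
#     if model_management.pin_memory(w):              # page-lock the host allocation
#         w_gpu = w.to("cuda", non_blocking=True)     # async H2D copy is now safe
#         ...
#         model_management.unpin_memory(w)            # release when the model unloads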
+ if len(unload_list) > 0: + NS = model_management.NUM_STREAMS + offload_weight_factor = [min(offload_buffer / (NS + 1), unload_list[0][1])] * NS + for unload in unload_list: - if memory_to_free < memory_freed: + if memory_to_free + offload_buffer - self._memory_measurements.model_offload_buffer_memory < memory_freed: break - module_mem = unload[0] - n = unload[1] - m = unload[2] - params = unload[3] + module_offload_mem, module_mem, n, m, params = unload + + potential_offload = module_offload_mem + sum(offload_weight_factor) lowvram_possible = hasattr(m, "comfy_cast_weights") if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True: @@ -963,13 +1043,19 @@ class ModelPatcher(ModelManageable, PatchSupport): module_mem += move_weight_functions(m, device_to) if lowvram_possible: if weight_key in self.patches: - _, set_func, convert_func = get_key_weight(self.model, weight_key) - m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(weight_key) + else: + _, set_func, convert_func = get_key_weight(self.model, weight_key) + m.weight_function.append(LowVramPatch(weight_key, self.patches, convert_func, set_func)) + patch_counter += 1 if bias_key in self.patches: - _, set_func, convert_func = get_key_weight(self.model, bias_key) - m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) - patch_counter += 1 + if force_patch_weights: + self.patch_weight_to_device(bias_key) + else: + _, set_func, convert_func = get_key_weight(self.model, bias_key) + m.bias_function.append(LowVramPatch(bias_key, self.patches, convert_func, set_func)) + patch_counter += 1 cast_weight = True if cast_weight: @@ -978,12 +1064,19 @@ class ModelPatcher(ModelManageable, PatchSupport): m.comfy_patched_weights = False memory_freed += module_mem freed_layers.append(n) + offload_buffer = max(offload_buffer, potential_offload) + offload_weight_factor.append(module_mem) + offload_weight_factor.pop(0) - logger.debug("freed {}".format(natsorted(freed_layers))) + for param in params: + self.pin_weight_to_device("{}.{}".format(n, param)) + + logger.debug(f"Freed {natsorted(freed_layers)}") self._memory_measurements.model_lowvram = True self._memory_measurements.lowvram_patch_counter += patch_counter self._memory_measurements.model_loaded_weight_memory -= memory_freed + self._memory_measurements.model_offload_buffer_memory = offload_buffer return memory_freed def partially_load(self, device_to, extra_memory=0, force_patch_weights=False) -> int: @@ -996,6 +1089,9 @@ class ModelPatcher(ModelManageable, PatchSupport): extra_memory += (used - self._memory_measurements.model_loaded_weight_memory) self.patch_model(load_weights=False) + if extra_memory < 0 and not unpatch_weights: + self.partially_unload(self.offload_device, -extra_memory, force_patch_weights=force_patch_weights) + return 0 full_load = False if not self._memory_measurements.model_lowvram and self._memory_measurements.model_loaded_weight_memory > 0: return 0 @@ -1399,4 +1495,5 @@ class ModelPatcher(ModelManageable, PatchSupport): self.clear_cached_hook_weights() def __del__(self): + self.unpin_all_weights() self.detach(unpatch_all=False) diff --git a/comfy/nested_tensor.py b/comfy/nested_tensor.py new file mode 100644 index 000000000..b700816fa --- /dev/null +++ b/comfy/nested_tensor.py @@ -0,0 +1,91 @@ +import torch + +class NestedTensor: + def __init__(self, tensors): + self.tensors = list(tensors) + self.is_nested = 
True + + def _copy(self): + return NestedTensor(self.tensors) + + def apply_operation(self, other, operation): + o = self._copy() + if isinstance(other, NestedTensor): + for i, t in enumerate(o.tensors): + o.tensors[i] = operation(t, other.tensors[i]) + else: + for i, t in enumerate(o.tensors): + o.tensors[i] = operation(t, other) + return o + + def __add__(self, b): + return self.apply_operation(b, lambda x, y: x + y) + + def __sub__(self, b): + return self.apply_operation(b, lambda x, y: x - y) + + def __mul__(self, b): + return self.apply_operation(b, lambda x, y: x * y) + + # def __itruediv__(self, b): + # return self.apply_operation(b, lambda x, y: x / y) + + def __truediv__(self, b): + return self.apply_operation(b, lambda x, y: x / y) + + def __getitem__(self, *args, **kwargs): + return self.apply_operation(None, lambda x, y: x.__getitem__(*args, **kwargs)) + + def unbind(self): + return self.tensors + + def to(self, *args, **kwargs): + o = self._copy() + for i, t in enumerate(o.tensors): + o.tensors[i] = t.to(*args, **kwargs) + return o + + def new_ones(self, *args, **kwargs): + return self.tensors[0].new_ones(*args, **kwargs) + + def float(self): + return self.to(dtype=torch.float) + + def chunk(self, *args, **kwargs): + return self.apply_operation(None, lambda x, y: x.chunk(*args, **kwargs)) + + def size(self): + return self.tensors[0].size() + + @property + def shape(self): + return self.tensors[0].shape + + @property + def ndim(self): + dims = 0 + for t in self.tensors: + dims = max(t.ndim, dims) + return dims + + @property + def device(self): + return self.tensors[0].device + + @property + def dtype(self): + return self.tensors[0].dtype + + @property + def layout(self): + return self.tensors[0].layout + + +def cat_nested(tensors, *args, **kwargs): + cated_tensors = [] + for i in range(len(tensors[0].tensors)): + tens = [] + for j in range(len(tensors)): + tens.append(tensors[j].tensors[i]) + cated_tensors.append(torch.cat(tens, *args, **kwargs)) + return NestedTensor(cated_tensors) diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 50f82f590..a29820970 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -745,8 +745,10 @@ class LoraLoaderModelOnly(LoraLoader): class VAELoader: + video_taes = ["taehv", "lighttaew2_2", "lighttaew2_1", "lighttaehy1_5"] + image_taes = ["taesd", "taesdxl", "taesd3", "taef1"] @staticmethod - def vae_list(): + def vae_list(s): vaes = get_filename_list_with_downloadable("vae", KNOWN_VAES) approx_vaes = get_filename_list_with_downloadable("vae_approx", KNOWN_APPROX_VAES) sdxl_taesd_enc = False @@ -775,6 +777,11 @@ class VAELoader: f1_taesd_dec = True elif v.startswith("taef1_decoder."): f1_taesd_enc = True + else: + for tae in s.video_taes: + if v.startswith(tae): + vaes.append(v) + if sd1_taesd_dec and sd1_taesd_enc: vaes.append("taesd") if sdxl_taesd_dec and sdxl_taesd_enc: @@ -818,8 +825,7 @@ class VAELoader: @classmethod def INPUT_TYPES(s): - return {"required": {"vae_name": (s.vae_list(),)}} - + return {"required": {"vae_name": (s.vae_list(s),)}} RETURN_TYPES = ("VAE",) FUNCTION = "load_vae" @@ -831,10 +837,13 @@ class VAELoader: if vae_name == "pixel_space": sd_ = {} sd_["pixel_space_vae"] = torch.tensor(1.0) - elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]: + elif vae_name in self.image_taes: sd_ = self.load_taesd(vae_name) else: - vae_path = get_full_path_or_raise("vae", vae_name, KNOWN_VAES) + if os.path.splitext(vae_name)[0] in self.video_taes: + vae_path = 
folder_paths.get_full_path_or_raise("vae_approx", vae_name) + else: + vae_path = get_full_path_or_raise("vae", vae_name, KNOWN_VAES) sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True) vae = sd.VAE(sd=sd_, metadata=metadata, ckpt_name=vae_name) vae.throw_exception_if_invalid() @@ -1016,7 +1025,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": {"clip_name": (get_filename_list_with_downloadable("text_encoders", KNOWN_CLIP_MODELS),), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"],), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "flux2", "ovis"],), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -1046,7 +1055,7 @@ class DualCLIPLoader: def INPUT_TYPES(s): return {"required": {"clip_name1": (get_filename_list_with_downloadable("text_encoders"),), "clip_name2": ( get_filename_list_with_downloadable("text_encoders"),), - "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image"],), + "type": (["sdxl", "sd3", "flux", "hunyuan_video", "hidream", "hunyuan_image", "hunyuan_video_15", "kandinsky5", "kandinsky5_image"],), }, "optional": { "device": (["default", "cpu"], {"advanced": True}), @@ -2003,6 +2012,11 @@ class ImageBatch: CATEGORY = "image" def batch(self, image1, image2): + if image1.shape[-1] != image2.shape[-1]: + if image1.shape[-1] > image2.shape[-1]: + image2 = torch.nn.functional.pad(image2, (0,1), mode='constant', value=1.0) + else: + image1 = torch.nn.functional.pad(image1, (0,1), mode='constant', value=1.0) if image1.shape[1:] != image2.shape[1:]: image2 = utils.common_upscale(image2.movedim(-1, 1), image1.shape[2], image1.shape[1], "bilinear", "center").movedim(1, -1) s = torch.cat((image1, image2), dim=0) diff --git a/comfy/ops.py b/comfy/ops.py index c705460e1..6b2ef2a65 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -17,20 +17,22 @@ """ import contextlib import logging -import torch -from torch import Tensor from typing import Optional, Type, Union +import torch +from torch import Tensor + from . 
import model_management, rmsnorm -from .interruption import throw_exception_if_processing_interrupted from .cli_args import args, PerformanceFeature from .execution_context import current_execution_context -from .float import stochastic_rounding +from .interruption import throw_exception_if_processing_interrupted logger = logging.getLogger(__name__) _RUN_EVERY_OP_ENABLED = model_management.torch_version_numeric >= (2, 5) +import json + def run_every_op(): global _RUN_EVERY_OP_ENABLED @@ -49,7 +51,7 @@ def _scaled_dot_product_attention(q, k, v, *args, **kwargs): try: - if torch.cuda.is_available(): + if torch.cuda.is_available() and model_management.WINDOWS: from torch.nn.attention import SDPBackend, sdpa_kernel # pylint: disable=import-error import inspect @@ -82,7 +84,8 @@ except Exception as exc_info: NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False try: if model_management.is_nvidia(): - if torch.backends.cudnn.version() >= 91002 and model_management.torch_version_numeric >= (2, 9) and model_management.torch_version_numeric <= (2, 10): + cudnn_version = torch.backends.cudnn.version() + if (cudnn_version >= 91002 and cudnn_version < 91500) and model_management.torch_version_numeric >= (2, 9) and model_management.torch_version_numeric <= (2, 10): # TODO: change upper bound version once it's fixed' NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True logger.debug("working around nvidia conv3d memory bug.") @@ -96,41 +99,81 @@ def cast_to_input(weight, input, non_blocking=False, copy=True): return model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) -@torch.compiler.disable() -def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): +def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False): + # NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass + # offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This + # will add async-offload support to your cast and improve performance. if input is not None: if dtype is None: - dtype = input.dtype + if isinstance(input, QuantizedTensor): + dtype = input._layout_params["orig_dtype"] + else: + dtype = input.dtype if bias_dtype is None: bias_dtype = dtype if device is None: device = input.device - offload_stream = model_management.get_offload_stream(device) + if offloadable and (device != s.weight.device or + (s.bias is not None and device != s.bias.device)): + offload_stream = model_management.get_offload_stream(device) + else: + offload_stream = None if offload_stream is not None: wf_context = offload_stream + if hasattr(wf_context, "as_context"): + wf_context = wf_context.as_context(offload_stream) else: wf_context = contextlib.nullcontext() - bias = None - non_blocking = True if torch.jit.is_tracing() or torch.jit.is_scripting() else model_management.device_supports_non_blocking(device) - if s.bias is not None: - has_function = len(s.bias_function) > 0 - bias = model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) - if has_function: - with wf_context: - for f in s.bias_function: - bias = f(bias) + # todo: how is wf_context used? 
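# NOTE (illustrative sketch, not part of this patch): the pattern the NOTE above asks custom
# node authors to follow is a cast / use / uncast sandwich. Assuming a layer `self` with
# .weight/.bias managed by comfy ops, a custom forward would look like:
#
#     def forward_comfy_cast_weights(self, input):
#         weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
#         out = torch.nn.functional.linear(input, weight, bias)
#         uncast_bias_weight(self, weight, bias, offload_stream)  # after the last use
#         return out
#
# With offloadable=True the host-to-device cast may run on a separate offload stream;
# uncast_bias_weight() then makes that stream wait on the compute stream so the casted
# weight/bias are not recycled while the op is still reading them.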
+ non_blocking = model_management.device_supports_non_blocking(device) - has_function = len(s.weight_function) > 0 - weight = model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) - if has_function: - with wf_context: - for f in s.weight_function: - weight = f(weight) + weight_has_function = len(s.weight_function) > 0 + bias_has_function = len(s.bias_function) > 0 + + weight = model_management.cast_to(s.weight, None, device, non_blocking=non_blocking, copy=weight_has_function, stream=offload_stream) + + bias = None + if s.bias is not None: + bias = model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=bias_has_function, stream=offload_stream) model_management.sync_stream(device, offload_stream) - return weight, bias + + bias_a = bias + weight_a = weight + + if s.bias is not None: + for f in s.bias_function: + bias = f(bias) + + if weight_has_function or weight.dtype != dtype: + weight = weight.to(dtype=dtype) + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + for f in s.weight_function: + weight = f(weight) + + if offloadable: + return weight, bias, (offload_stream, weight_a, bias_a) + else: + # Legacy function signature + return weight, bias + + +def uncast_bias_weight(s, weight, bias, offload_stream): + if offload_stream is None: + return + os, weight_a, bias_a = offload_stream + if os is None: + return + if weight_a is not None: + device = weight_a.device + else: + if bias_a is None: + return + device = bias_a.device + os.wait_stream(model_management.current_stream(device)) class SkipInit: @@ -191,8 +234,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.linear(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -206,8 +251,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -221,8 +268,10 @@ class disable_weight_init: return None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -245,8 +294,10 @@ class disable_weight_init: return super()._conv_forward(input, weight, bias, *args, **kwargs) def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return self._conv_forward(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._conv_forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -260,8 +311,10 @@ class disable_weight_init: return 
None def forward_comfy_cast_weights(self, input): - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -276,11 +329,14 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): if self.weight is not None: - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) else: weight = None bias = None - return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + offload_stream = None + x = torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -296,11 +352,15 @@ class disable_weight_init: def forward_comfy_cast_weights(self, input): if self.weight is not None: - weight, bias = cast_bias_weight(self, input) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) else: weight = None - return rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated - # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + bias = None + offload_stream = None + x = rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated + # x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -319,10 +379,12 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.conv_transpose2d( + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.conv_transpose2d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -341,10 +403,12 @@ class disable_weight_init: input, output_size, self.stride, self.padding, self.kernel_size, num_spatial_dims, self.dilation) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.conv_transpose1d( + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.conv_transpose1d( input, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -362,8 +426,10 @@ class disable_weight_init: output_dtype = out_dtype if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16: out_dtype = None - weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype) - return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) + weight, bias, offload_stream = cast_bias_weight(self, device=input.device, 
dtype=out_dtype, offloadable=True) + x = torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype) + uncast_bias_weight(self, weight, bias, offload_stream) + return x def forward(self, *args, **kwargs): run_every_op() @@ -417,48 +483,33 @@ class manual_cast(disable_weight_init): def fp8_linear(self, input): + """ + Legacy FP8 linear function for backward compatibility. + Uses QuantizedTensor subclass for dispatch. + """ dtype = self.weight.dtype if dtype not in [torch.float8_e4m3fn]: return None - tensor_2d = False - if len(input.shape) == 2: - tensor_2d = True - input = input.unsqueeze(1) - - input_shape = input.shape input_dtype = input.dtype - if len(input.shape) == 3: - w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) - w = w.t() - scale_weight = self.scale_weight - scale_input = self.scale_input - if scale_weight is None: - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - else: - scale_weight = scale_weight.to(input.device) + if input.ndim == 3 or input.ndim == 2: + w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - if scale_input is None: - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - input = input.reshape(-1, input_shape[2]).to(dtype).contiguous() - else: - scale_input = scale_input.to(input.device) - input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous() + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} + quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) - if bias is not None: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) - else: - o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) + # Wrap weight in QuantizedTensor - this enables unified dispatch + # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! 
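# NOTE (illustrative, not part of this patch): the routing works because QuantizedTensor
# overrides __torch_dispatch__. When torch.nn.functional.linear sees a QuantizedTensor
# argument, aten.linear.default is dispatched to the handler registered for
# "TensorCoreFP8Layout" in quant_ops.py, which unpacks the raw fp8 payloads and scales and
# calls torch._scaled_mm; ops without a registered handler fall back to
# dequantize-then-compute. Roughly what that handler ends up executing:
#
#     out = torch._scaled_mm(q_input_2d, q_weight.t(), bias=bias,
#                            scale_a=input_scale, scale_b=weight_scale,
#                            out_dtype=orig_dtype)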
+ layout_params_weight = {'scale': scale_weight, 'orig_dtype': input_dtype} + quantized_weight = QuantizedTensor(w, "TensorCoreFP8Layout", layout_params_weight) + o = torch.nn.functional.linear(quantized_input, quantized_weight, bias) - if isinstance(o, tuple): - o = o[0] - - if tensor_2d: - return o.reshape(input_shape[0], -1) - - return o.reshape((-1, input_shape[1], self.weight.shape[0])) + uncast_bias_weight(self, w, bias, offload_stream) + return o return None @@ -471,7 +522,7 @@ class fp8_ops(manual_cast): return None def forward_comfy_cast_weights(self, input): - if not self.training: + if len(self.weight_function) == 0 and len(self.bias_function) == 0: try: out = fp8_linear(self, input) if out is not None: @@ -479,67 +530,16 @@ class fp8_ops(manual_cast): except Exception as e: logger.info("Exception during fp8 op: {}".format(e)) - weight, bias = cast_bias_weight(self, input) - return torch.nn.functional.linear(input, weight, bias) + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = torch.nn.functional.linear(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x class scaled_fp8_op_base(manual_cast): pass -def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): - logger.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) - - class scaled_fp8_op(scaled_fp8_op_base): - class Linear(manual_cast.Linear): - def __init__(self, *args, **kwargs): - if override_dtype is not None: - kwargs['dtype'] = override_dtype - super().__init__(*args, **kwargs) - - def reset_parameters(self): - if not hasattr(self, 'scale_weight'): - self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - - if not scale_input: - self.scale_input = None - - if not hasattr(self, 'scale_input'): - self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - return None - - def forward_comfy_cast_weights(self, input): - if fp8_matrix_mult: - out = fp8_linear(self, input) - if out is not None: - return out - - weight, bias = cast_bias_weight(self, input) - - if weight.numel() < input.numel(): # TODO: optimize - return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) - else: - return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) - - def convert_weight(self, weight, inplace=False, **kwargs): - if inplace: - weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) - return weight - else: - return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype) - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): - weight = stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) - if return_weight: - return weight - if inplace_update: - self.weight.data.copy_(weight) - else: - self.weight = torch.nn.Parameter(weight, requires_grad=False) - - return scaled_fp8_op - - CUBLAS_IS_AVAILABLE = False try: from cublas_ops import CublasLinear @@ -565,14 +565,178 @@ else: Operations = Type[Union[manual_cast, fp8_ops, disable_weight_init, skip_init, scaled_fp8_op_base]] +# ============================================================================== +# Mixed Precision 
Operations +# ============================================================================== +from .quant_ops import QuantizedTensor, QUANT_ALGOS -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8: Optional[torch.dtype] = None, inference_mode: Optional[bool] = None) -> Operations: + +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): + class MixedPrecisionOps(manual_cast): + _quant_config = quant_config + _compute_dtype = compute_dtype + _full_precision_mm = full_precision_mm + + class Linear(torch.nn.Module, CastWeightBiasOp): + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + device=None, + dtype=None, + ) -> None: + super().__init__() + + self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype} + # self.factory_kwargs = {"device": device, "dtype": dtype} + + self.in_features = in_features + self.out_features = out_features + if bias: + self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs)) + else: + self.register_parameter("bias", None) + + self.tensor_class = None + self._full_precision_mm = MixedPrecisionOps._full_precision_mm + + def reset_parameters(self): + return None + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + device = self.factory_kwargs["device"] + layer_name = prefix.rstrip('.') + weight_key = f"{prefix}weight" + weight = state_dict.pop(weight_key, None) + if weight is None: + raise ValueError(f"Missing weight for layer {layer_name}") + + manually_loaded_keys = [weight_key] + + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + if layer_conf is None: + self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) + else: + self.quant_format = layer_conf.get("format", None) + if not self._full_precision_mm: + self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + + if self.quant_format is None: + raise ValueError(f"Unknown quantization format for layer {layer_name}") + + qconfig = QUANT_ALGOS[self.quant_format] + self.layout_type = qconfig["comfy_tensor_layout"] + + weight_scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(weight_scale_key, None) + if scale is not None: + scale = scale.to(device) + layout_params = { + 'scale': scale, + 'orig_dtype': MixedPrecisionOps._compute_dtype, + 'block_size': qconfig.get("group_size", None), + } + + if scale is not None: + manually_loaded_keys.append(weight_scale_key) + + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), + requires_grad=False + ) + + for param_name in qconfig["parameters"]: + param_key = f"{prefix}{param_name}" + _v = state_dict.pop(param_key, None) + if _v is None: + continue + self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + manually_loaded_keys.append(param_key) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + for key in manually_loaded_keys: + if key in missing_keys: + missing_keys.remove(key) + + def state_dict(self, *args, destination=None, prefix="", **kwargs): + sd = super().state_dict(*args, 
destination=destination, prefix=prefix, **kwargs) + if isinstance(self.weight, QuantizedTensor): + sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + quant_conf = {"format": self.quant_format} + if self._full_precision_mm: + quant_conf["full_precision_matrix_mult"] = True + sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8) + return sd + + def _forward(self, input, weight, bias): + return torch.nn.functional.linear(input, weight, bias) + + def forward_comfy_cast_weights(self, input): + weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) + x = self._forward(input, weight, bias) + uncast_bias_weight(self, weight, bias, offload_stream) + return x + + def forward(self, input, *args, **kwargs): + run_every_op() + + if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: + return self.forward_comfy_cast_weights(input, *args, **kwargs) + if (getattr(self, 'layout_type', None) is not None and + not isinstance(input, QuantizedTensor)): + input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) + return self._forward(input, self.weight, self.bias) + + def convert_weight(self, weight, inplace=False, **kwargs): + if isinstance(weight, QuantizedTensor): + return weight.dequantize() + else: + return weight + + def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): + if getattr(self, 'layout_type', None) is not None: + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + else: + weight = weight.to(self.weight.dtype) + if return_weight: + return weight + + assert inplace_update is False # TODO: eventually remove the inplace_update stuff + self.weight = torch.nn.Parameter(weight, requires_grad=False) + + def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working + if recurse: + for module in self.children(): + module._apply(fn) + + for key, param in self._parameters.items(): + if param is None: + continue + self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + return self + + return MixedPrecisionOps + + +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None, inference_mode: Optional[bool] = None): if inference_mode is None: # todo: check a context here, since this isn't being used by any callers yet inference_mode = current_execution_context().inference_mode - fp8_compute = model_management.supports_fp8_compute(load_device) - if scaled_fp8 is not None: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) + fp8_compute = model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular + + if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: + logging.info("Using mixed precision operations") + return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) if ( fp8_compute and diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py new file mode 100644 index 000000000..571d3f760 
--- /dev/null +++ b/comfy/quant_ops.py @@ -0,0 +1,577 @@ +import torch +import logging +from typing import Tuple, Dict +import comfy.float + +_LAYOUT_REGISTRY = {} +_GENERIC_UTILS = {} + + +def register_layout_op(torch_op, layout_type): + """ + Decorator to register a layout-specific operation handler. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.linear.default) + layout_type: Layout class (e.g., TensorCoreFP8Layout) + Example: + @register_layout_op(torch.ops.aten.linear.default, TensorCoreFP8Layout) + def fp8_linear(func, args, kwargs): + # FP8-specific linear implementation + ... + """ + def decorator(handler_func): + if torch_op not in _LAYOUT_REGISTRY: + _LAYOUT_REGISTRY[torch_op] = {} + _LAYOUT_REGISTRY[torch_op][layout_type] = handler_func + return handler_func + return decorator + + +def register_generic_util(torch_op): + """ + Decorator to register a generic utility that works for all layouts. + Args: + torch_op: PyTorch operation (e.g., torch.ops.aten.detach.default) + + Example: + @register_generic_util(torch.ops.aten.detach.default) + def generic_detach(func, args, kwargs): + # Works for any layout + ... + """ + def decorator(handler_func): + _GENERIC_UTILS[torch_op] = handler_func + return handler_func + return decorator + + +def _get_layout_from_args(args): + for arg in args: + if isinstance(arg, QuantizedTensor): + return arg._layout_type + elif isinstance(arg, (list, tuple)): + for item in arg: + if isinstance(item, QuantizedTensor): + return item._layout_type + return None + + +def _move_layout_params_to_device(params, device): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.to(device=device) + else: + new_params[k] = v + return new_params + + +def _copy_layout_params(params): + new_params = {} + for k, v in params.items(): + if isinstance(v, torch.Tensor): + new_params[k] = v.clone() + else: + new_params[k] = v + return new_params + +def _copy_layout_params_inplace(src, dst, non_blocking=False): + for k, v in src.items(): + if isinstance(v, torch.Tensor): + dst[k].copy_(v, non_blocking=non_blocking) + else: + dst[k] = v + +class QuantizedLayout: + """ + Base class for quantization layouts. + + A layout encapsulates the format-specific logic for quantization/dequantization + and provides a uniform interface for extracting raw tensors needed for computation. + + New quantization formats should subclass this and implement the required methods. + """ + @classmethod + def quantize(cls, tensor, **kwargs) -> Tuple[torch.Tensor, Dict]: + raise NotImplementedError(f"{cls.__name__} must implement quantize()") + + @staticmethod + def dequantize(qdata, **layout_params) -> torch.Tensor: + raise NotImplementedError("TensorLayout must implement dequantize()") + + @classmethod + def get_plain_tensors(cls, qtensor) -> torch.Tensor: + raise NotImplementedError(f"{cls.__name__} must implement get_plain_tensors()") + + +class QuantizedTensor(torch.Tensor): + """ + Universal quantized tensor that works with any layout. + + This tensor subclass uses a pluggable layout system to support multiple + quantization formats (FP8, INT4, INT8, etc.) without code duplication. + + The layout_type determines format-specific behavior, while common operations + (detach, clone, to) are handled generically. + + Attributes: + _qdata: The quantized tensor data + _layout_type: Layout class (e.g., TensorCoreFP8Layout) + _layout_params: Dict with layout-specific params (scale, zero_point, etc.) 
+ """ + + @staticmethod + def __new__(cls, qdata, layout_type, layout_params): + """ + Create a quantized tensor. + + Args: + qdata: The quantized data tensor + layout_type: Layout class (subclass of QuantizedLayout) + layout_params: Dict with layout-specific parameters + """ + return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, device=qdata.device, dtype=qdata.dtype, requires_grad=False) + + def __init__(self, qdata, layout_type, layout_params): + self._qdata = qdata + self._layout_type = layout_type + self._layout_params = layout_params + + def __repr__(self): + layout_name = self._layout_type + param_str = ", ".join(f"{k}={v}" for k, v in list(self._layout_params.items())[:2]) + return f"QuantizedTensor(shape={self.shape}, layout={layout_name}, {param_str})" + + @property + def layout_type(self): + return self._layout_type + + def __tensor_flatten__(self): + """ + Tensor flattening protocol for proper device movement. + """ + inner_tensors = ["_qdata"] + ctx = { + "layout_type": self._layout_type, + } + + tensor_params = {} + non_tensor_params = {} + for k, v in self._layout_params.items(): + if isinstance(v, torch.Tensor): + tensor_params[k] = v + else: + non_tensor_params[k] = v + + ctx["tensor_param_keys"] = list(tensor_params.keys()) + ctx["non_tensor_params"] = non_tensor_params + + for k, v in tensor_params.items(): + attr_name = f"_layout_param_{k}" + object.__setattr__(self, attr_name, v) + inner_tensors.append(attr_name) + + return inner_tensors, ctx + + @staticmethod + def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride): + """ + Tensor unflattening protocol for proper device movement. + Reconstructs the QuantizedTensor after device movement. + """ + layout_type = ctx["layout_type"] + layout_params = dict(ctx["non_tensor_params"]) + + for key in ctx["tensor_param_keys"]: + attr_name = f"_layout_param_{key}" + layout_params[key] = inner_tensors[attr_name] + + return QuantizedTensor(inner_tensors["_qdata"], layout_type, layout_params) + + @classmethod + def from_float(cls, tensor, layout_type, **quantize_kwargs) -> 'QuantizedTensor': + qdata, layout_params = LAYOUTS[layout_type].quantize(tensor, **quantize_kwargs) + return cls(qdata, layout_type, layout_params) + + def dequantize(self) -> torch.Tensor: + return LAYOUTS[self._layout_type].dequantize(self._qdata, **self._layout_params) + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + + # Step 1: Check generic utilities first (detach, clone, to, etc.) + if func in _GENERIC_UTILS: + return _GENERIC_UTILS[func](func, args, kwargs) + + # Step 2: Check layout-specific handlers (linear, matmul, etc.) + layout_type = _get_layout_from_args(args) + if layout_type and func in _LAYOUT_REGISTRY: + handler = _LAYOUT_REGISTRY[func].get(layout_type) + if handler: + return handler(func, args, kwargs) + + # Step 3: Fallback to dequantization + if isinstance(args[0] if args else None, QuantizedTensor): + logging.info(f"QuantizedTensor: Unhandled operation {func}, falling back to dequantization. 
kwargs={kwargs}") + return cls._dequant_and_fallback(func, args, kwargs) + + @classmethod + def _dequant_and_fallback(cls, func, args, kwargs): + def dequant_arg(arg): + if isinstance(arg, QuantizedTensor): + return arg.dequantize() + elif isinstance(arg, (list, tuple)): + return type(arg)(dequant_arg(a) for a in arg) + return arg + + new_args = dequant_arg(args) + new_kwargs = dequant_arg(kwargs) + return func(*new_args, **new_kwargs) + + def data_ptr(self): + return self._qdata.data_ptr() + + def is_pinned(self): + return self._qdata.is_pinned() + + def is_contiguous(self, *arg, **kwargs): + return self._qdata.is_contiguous(*arg, **kwargs) + + def storage(self): + return self._qdata.storage() + +# ============================================================================== +# Generic Utilities (Layout-Agnostic Operations) +# ============================================================================== + +def _create_transformed_qtensor(qt, transform_fn): + new_data = transform_fn(qt._qdata) + new_params = _copy_layout_params(qt._layout_params) + return QuantizedTensor(new_data, qt._layout_type, new_params) + + +def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): + if target_layout is not None and target_layout != torch.strided: + logging.warning( + f"QuantizedTensor: layout change requested to {target_layout}, " + f"but not supported. Ignoring layout." + ) + + # Handle device transfer + current_device = qt._qdata.device + if target_device is not None: + # Normalize device for comparison + if isinstance(target_device, str): + target_device = torch.device(target_device) + if isinstance(current_device, str): + current_device = torch.device(current_device) + + if target_device != current_device: + logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") + new_q_data = qt._qdata.to(device=target_device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + if target_dtype is not None: + new_params["orig_dtype"] = target_dtype + new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) + logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") + return new_qt + + logging.debug(f"QuantizedTensor.{op_name}: No device change needed, returning original") + return qt + + +@register_generic_util(torch.ops.aten.detach.default) +def generic_detach(func, args, kwargs): + """Detach operation - creates a detached copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.detach()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.clone.default) +def generic_clone(func, args, kwargs): + """Clone operation - creates a deep copy of the quantized tensor.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _create_transformed_qtensor(qt, lambda x: x.clone()) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._to_copy.default) +def generic_to_copy(func, args, kwargs): + """Device/dtype transfer operation - handles .to(device) calls.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + op_name="_to_copy" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.to.dtype_layout) +def generic_to_dtype_layout(func, args, kwargs): + """Handle .to(device) calls using the dtype_layout 
variant.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + return _handle_device_transfer( + qt, + target_device=kwargs.get('device', None), + target_dtype=kwargs.get('dtype', None), + target_layout=kwargs.get('layout', None), + op_name="to" + ) + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.copy_.default) +def generic_copy_(func, args, kwargs): + qt_dest = args[0] + src = args[1] + non_blocking = args[2] if len(args) > 2 else False + if isinstance(qt_dest, QuantizedTensor): + if isinstance(src, QuantizedTensor): + # Copy from another quantized tensor + qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) + qt_dest._layout_type = src._layout_type + orig_dtype = qt_dest._layout_params["orig_dtype"] + _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) + qt_dest._layout_params["orig_dtype"] = orig_dtype + else: + # Copy from regular tensor - just copy raw data + qt_dest._qdata.copy_(src) + return qt_dest + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten.to.dtype) +def generic_to_dtype(func, args, kwargs): + """Handle .to(dtype) calls - dtype conversion only.""" + src = args[0] + if isinstance(src, QuantizedTensor): + # For dtype-only conversion, just change the orig_dtype, no real cast is needed + target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype') + src._layout_params["orig_dtype"] = target_dtype + return src + return func(*args, **kwargs) + + +@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default) +def generic_has_compatible_shallow_copy_type(func, args, kwargs): + return True + + +@register_generic_util(torch.ops.aten.empty_like.default) +def generic_empty_like(func, args, kwargs): + """Empty_like operation - creates an empty tensor with the same quantized structure.""" + qt = args[0] + if isinstance(qt, QuantizedTensor): + # Create empty tensor with same shape and dtype as the quantized data + hp_dtype = kwargs.pop('dtype', qt._layout_params["orig_dtype"]) + new_qdata = torch.empty_like(qt._qdata, **kwargs) + + # Handle device transfer for layout params + target_device = kwargs.get('device', new_qdata.device) + new_params = _move_layout_params_to_device(qt._layout_params, target_device) + + # Update orig_dtype if dtype is specified + new_params['orig_dtype'] = hp_dtype + + return QuantizedTensor(new_qdata, qt._layout_type, new_params) + return func(*args, **kwargs) + +# ============================================================================== +# FP8 Layout + Operation Handlers +# ============================================================================== +class TensorCoreFP8Layout(QuantizedLayout): + """ + Storage format: + - qdata: FP8 tensor (torch.float8_e4m3fn or torch.float8_e5m2) + - scale: Scalar tensor (float32) for dequantization + - orig_dtype: Original dtype before quantization (for casting back) + """ + @classmethod + def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): + orig_dtype = tensor.dtype + + if isinstance(scale, str) and scale == "recalculate": + scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max + + if scale is not None: + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) + + if inplace_ops: + tensor *= (1.0 / scale).to(tensor.dtype) + else: + tensor = tensor * (1.0 / scale).to(tensor.dtype) + else: + scale = torch.ones((), device=tensor.device, dtype=torch.float32) 
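# NOTE (self-contained sketch, not part of this patch): the round trip implemented by
# quantize()/dequantize() is q = to_fp8(clamp(x / scale)) and x_hat = q.to(orig_dtype) * scale.
# The same math with plain torch, handy for sanity-checking a scale:
#
#     x = torch.randn(4, 8, dtype=torch.bfloat16)
#     scale = x.abs().amax().float() / torch.finfo(torch.float8_e4m3fn).max
#     q = (x.float() / scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
#     x_hat = (q.to(torch.float32) * scale).to(torch.bfloat16)
#     max_err = (x.float() - x_hat.float()).abs().max()  # small but nonzero: fp8 is lossy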
+ + if stochastic_rounding > 0: + tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) + else: + lp_amax = torch.finfo(dtype).max + torch.clamp(tensor, min=-lp_amax, max=lp_amax, out=tensor) + tensor = tensor.to(dtype, memory_format=torch.contiguous_format) + + layout_params = { + 'scale': scale, + 'orig_dtype': orig_dtype + } + return tensor, layout_params + + @staticmethod + def dequantize(qdata, scale, orig_dtype, **kwargs): + plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype) + plain_tensor.mul_(scale) + return plain_tensor + + @classmethod + def get_plain_tensors(cls, qtensor): + return qtensor._qdata, qtensor._layout_params['scale'] + +QUANT_ALGOS = { + "float8_e4m3fn": { + "storage_t": torch.float8_e4m3fn, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreFP8Layout", + }, +} + +LAYOUTS = { + "TensorCoreFP8Layout": TensorCoreFP8Layout, +} + + +@register_layout_op(torch.ops.aten.linear.default, "TensorCoreFP8Layout") +def fp8_linear(func, args, kwargs): + input_tensor = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + out_dtype = kwargs.get("out_dtype") + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + weight_t = plain_weight.t() + + tensor_2d = False + if len(plain_input.shape) == 2: + tensor_2d = True + plain_input = plain_input.unsqueeze(1) + + input_shape = plain_input.shape + if len(input_shape) != 3: + return None + + try: + output = torch._scaled_mm( + plain_input.reshape(-1, input_shape[2]).contiguous(), + weight_t, + bias=bias, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + ) + + if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 + output = output[0] + + if not tensor_2d: + output = output.reshape((-1, input_shape[1], weight.shape[0])) + + if output.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: + output_scale = scale_a * scale_b + output_params = { + 'scale': output_scale, + 'orig_dtype': input_tensor._layout_params['orig_dtype'] + } + return QuantizedTensor(output, "TensorCoreFP8Layout", output_params) + else: + return output + + except Exception as e: + raise RuntimeError(f"FP8 _scaled_mm failed, falling back to dequantization: {e}") + + # Case 2: DQ Fallback + if isinstance(weight, QuantizedTensor): + weight = weight.dequantize() + if isinstance(input_tensor, QuantizedTensor): + input_tensor = input_tensor.dequantize() + + return torch.nn.functional.linear(input_tensor, weight, bias) + +def fp8_mm_(input_tensor, weight, bias=None, out_dtype=None): + if out_dtype is None: + out_dtype = input_tensor._layout_params['orig_dtype'] + + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + plain_weight, scale_b = TensorCoreFP8Layout.get_plain_tensors(weight) + + output = torch._scaled_mm( + plain_input.contiguous(), + plain_weight, + bias=bias, + scale_a=scale_a, + scale_b=scale_b, + out_dtype=out_dtype, + ) + + if isinstance(output, tuple): # TODO: remove when we drop support for torch 2.4 + output = output[0] + return output + +@register_layout_op(torch.ops.aten.addmm.default, "TensorCoreFP8Layout") +def fp8_addmm(func, args, kwargs): + input_tensor = args[1] + weight = args[2] + bias = args[0] + + if 
isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + return fp8_mm_(input_tensor, weight, bias=bias, out_dtype=kwargs.get("out_dtype", None)) + + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + if isinstance(args[2], QuantizedTensor): + a[2] = args[2].dequantize() + + return func(*a, **kwargs) + +@register_layout_op(torch.ops.aten.mm.default, "TensorCoreFP8Layout") +def fp8_mm(func, args, kwargs): + input_tensor = args[0] + weight = args[1] + + if isinstance(input_tensor, QuantizedTensor) and isinstance(weight, QuantizedTensor): + return fp8_mm_(input_tensor, weight, bias=None, out_dtype=kwargs.get("out_dtype", None)) + + a = list(args) + if isinstance(args[0], QuantizedTensor): + a[0] = args[0].dequantize() + if isinstance(args[1], QuantizedTensor): + a[1] = args[1].dequantize() + return func(*a, **kwargs) + +@register_layout_op(torch.ops.aten.view.default, "TensorCoreFP8Layout") +@register_layout_op(torch.ops.aten.t.default, "TensorCoreFP8Layout") +def fp8_func(func, args, kwargs): + input_tensor = args[0] + if isinstance(input_tensor, QuantizedTensor): + plain_input, scale_a = TensorCoreFP8Layout.get_plain_tensors(input_tensor) + ar = list(args) + ar[0] = plain_input + return QuantizedTensor(func(*ar, **kwargs), "TensorCoreFP8Layout", input_tensor._layout_params) + return func(*args, **kwargs) diff --git a/comfy/sample.py b/comfy/sample.py index 0173c1ca9..6ec6da18b 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -5,6 +5,21 @@ from . import model_management from . import samplers from . import utils from .component_model.deprecation import _deprecate_method +from .nested_tensor import NestedTensor + + +def prepare_noise_inner(latent_image, generator, noise_inds=None): + if noise_inds is None: + return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + + unique_inds, inverse = np.unique(noise_inds, return_inverse=True) + noises = [] + for i in range(unique_inds[-1] + 1): + noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") + if i in unique_inds: + noises.append(noise) + noises = [noises[i] for i in inverse] + return torch.cat(noises, axis=0) def prepare_noise(latent_image, seed, noise_inds=None): @@ -13,36 +28,41 @@ def prepare_noise(latent_image, seed, noise_inds=None): optional arg skip can be used to skip and discard x number of noise generations for a given seed """ generator = torch.manual_seed(seed) - if noise_inds is None: - return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - unique_inds, inverse = np.unique(noise_inds, return_inverse=True) - noises = [] - for i in range(unique_inds[-1]+1): - noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - if i in unique_inds: - noises.append(noise) - noises = [noises[i] for i in inverse] - noises = torch.cat(noises, axis=0) + if latent_image.is_nested: + tensors = latent_image.unbind() + noises = [] + for t in tensors: + noises.append(prepare_noise_inner(t, generator, noise_inds)) + noises = NestedTensor(noises) + else: + noises = prepare_noise_inner(latent_image, generator, noise_inds) + return noises + def fix_empty_latent_channels(model, latent_image): 
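+ # Nested latents are returned as-is; the empty-latent channel resize in this function only applies to regular tensors.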
- latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels + if latent_image.is_nested: + return latent_image + latent_format = model.get_model_object("latent_format") # Resize the empty latent image so it has the right number of channels if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0: latent_image = utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1) if latent_format.latent_dimensions == 3 and latent_image.ndim == 4: latent_image = latent_image.unsqueeze(2) return latent_image + @_deprecate_method(version="0.3.2", message="Warning: comfy.sample.prepare_sampling isn't used anymore and can be removed") def prepare_sampling(model, noise_shape, positive, negative, noise_mask): pass return model, positive, negative, noise_mask, [] + @_deprecate_method(version="0.3.2", message="Warning: comfy.sample.cleanup_additional_models isn't used anymore and can be removed") def cleanup_additional_models(models): pass + def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False, noise_mask=None, sigmas=None, callback=None, disable_pbar=False, seed=None): sampler = samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) @@ -50,6 +70,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative samples = samples.to(model_management.intermediate_device()) return samples + def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): samples = samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = samples.to(model_management.intermediate_device()) diff --git a/comfy/samplers.py b/comfy/samplers.py index 147ba2895..ec426b5e6 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -24,7 +24,7 @@ from .model_management_types import ModelOptions from .model_patcher import ModelPatcher from .sampler_names import SCHEDULER_NAMES, SAMPLER_NAMES, KSAMPLER_NAMES from .context_windows import ContextHandlerABC -from .utils import common_upscale +from .utils import common_upscale, pack_latents, unpack_latents from .patcher_extension import WrapperExecutor, get_all_wrappers, WrappersMP from .component_model import module_property @@ -827,7 +827,7 @@ def ksampler(sampler_name, extra_options={}, inpaint_options={}): return KSAMPLER(sampler_function, extra_options, inpaint_options) -def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None): +def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=None, seed=None, latent_shapes=None): for k in conds: conds[k] = conds[k][:] resolve_areas_and_cond_masks_multidim(conds[k], noise.shape[2:], device) @@ -837,7 +837,7 @@ def process_conds(model, noise, conds, device, latent_image=None, denoise_mask=N if hasattr(model, 'extra_conds'): for k in conds: - conds[k] = encode_model_conds(model.extra_conds, conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed) + conds[k] = encode_model_conds(model.extra_conds, 
conds[k], noise, device, k, latent_image=latent_image, denoise_mask=denoise_mask, seed=seed, latent_shapes=latent_shapes) # make sure each cond area has an opposite one with the same area for k in conds: @@ -1008,11 +1008,11 @@ class CFGGuider: def predict_noise(self, x, timestep, model_options={}, seed=None): return sampling_function(self.inner_model, x, timestep, self.conds.get("negative", None), self.conds.get("positive", None), self.cfg, model_options=model_options, seed=seed) - def inner_sample(self, noise, latent_image, device, sampler: KSAMPLER, sigmas, denoise_mask, callback, disable_pbar, seed): + def inner_sample(self, noise, latent_image, device, sampler: KSAMPLER, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=None): if latent_image is not None and torch.count_nonzero(latent_image) > 0: # Don't shift the empty latent image. latent_image = self.inner_model.process_latent_in(latent_image) - self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed) + self.conds = process_conds(self.inner_model, noise, self.conds, device, latent_image, denoise_mask, seed, latent_shapes=latent_shapes) extra_model_options = model_patcher.create_model_options_clone(self.model_options) extra_model_options.setdefault("transformer_options", {})["sample_sigmas"] = sigmas @@ -1026,7 +1026,7 @@ class CFGGuider: samples = executor.execute(self, sigmas, extra_args, callback, noise, latent_image, denoise_mask, disable_pbar) return self.inner_model.process_latent_out(samples.to(torch.float32)) - def outer_sample(self, noise, latent_image, sampler: KSAMPLER, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None): + def outer_sample(self, noise, latent_image, sampler: KSAMPLER, sigmas, denoise_mask=None, callback=None, disable_pbar=False, seed=None, latent_shapes=None): self.inner_model, self.conds, self.loaded_models = sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) device = self.model_patcher.load_device @@ -1040,7 +1040,7 @@ class CFGGuider: try: self.model_patcher.pre_run() - output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + output = self.inner_sample(noise, latent_image, device, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) finally: self.model_patcher.cleanup() @@ -1053,6 +1053,12 @@ class CFGGuider: if sigmas.shape[-1] == 0: return latent_image + if latent_image.is_nested: + latent_image, latent_shapes = pack_latents(latent_image.unbind()) + noise, _ = pack_latents(noise.unbind()) + else: + latent_shapes = [latent_image.shape] + self.conds = {} for k in self.original_conds: self.conds[k] = list(map(lambda a: a.copy(), self.original_conds[k])) @@ -1072,7 +1078,7 @@ class CFGGuider: self, patcher_extension.get_all_wrappers(patcher_extension.WrappersMP.OUTER_SAMPLE, self.model_options, is_model_options=True) ) - output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed) + output = executor.execute(noise, latent_image, sampler, sigmas, denoise_mask, callback, disable_pbar, seed, latent_shapes=latent_shapes) except ValueError as exc_info: if "fp8e4nv" in str(exc_info): logger.error(f"Load the weights for model {self.model_patcher} as fp8_e5m2 to use floating point 8-bit inference with torch.compile and triton on Ampere architecture") @@ -1084,6 +1090,9 @@ class CFGGuider: self.model_patcher.restore_hook_patches() del 
self.conds + + if len(latent_shapes) > 1: + output = NestedTensor(unpack_latents(output, latent_shapes)) return output diff --git a/comfy/sd.py b/comfy/sd.py index b258563e0..1a639bc21 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -6,12 +6,11 @@ import logging import math import os import os.path -import torch -import yaml from enum import Enum from typing import Any, Optional -from humanize import naturalsize +import torch +import yaml from . import clip_vision from . import diffusers_convert @@ -34,14 +33,15 @@ from .ldm.flux.redux import ReduxImageEncoder from .ldm.genmo.vae import model as genmo_model from .ldm.hunyuan3d.vae import ShapeVAE from .ldm.lightricks.vae import causal_video_autoencoder as lightricks -from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.mmaudio.vae.autoencoder import AudioAutoencoder +from .ldm.models.autoencoder import AutoencoderKL, AutoencodingEngine from .ldm.wan import vae as wan_vae from .ldm.wan import vae2_2 as wan_vae2_2 from .lora import load_lora, model_lora_keys_unet, model_lora_keys_clip from .lora_convert import convert_lora -from .model_management import load_models_gpu +from .model_management import load_models_gpu, module_size from .model_patcher import ModelPatcher +from .pixel_space_convert import PixelspaceConversionVAE from .t2i_adapter import adapter from .taesd import taesd from .text_encoders import ace @@ -50,21 +50,25 @@ from .text_encoders import cosmos from .text_encoders import flux from .text_encoders import genmo from .text_encoders import hidream -from .text_encoders import hunyuan_video from .text_encoders import hunyuan_image +from .text_encoders import hunyuan_video from .text_encoders import hydit +from .text_encoders import kandinsky5 from .text_encoders import long_clipl from .text_encoders import lt from .text_encoders import lumina2 from .text_encoders import omnigen2 +from .text_encoders import ovis from .text_encoders import pixart_t5 from .text_encoders import qwen_image from .text_encoders import sa_t5 from .text_encoders import sd2_clip from .text_encoders import sd3_clip from .text_encoders import wan -from .utils import ProgressBar, FileMetadata -from .pixel_space_convert import PixelspaceConversionVAE +from .text_encoders import z_image +from .utils import ProgressBar, FileMetadata, state_dict_prefix_replace +from .taesd.taehv import TAEHV +from .latent_formats import HunyuanVideo15, HunyuanVideo logger = logging.getLogger(__name__) @@ -101,7 +105,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_ class CLIP: - def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0, model_options={}): + def __init__(self, target: CLIPTarget = None, embedding_directory=None, no_init=False, textmodel_json_config=None, tokenizer_data: dict | None = None, parameters=0, state_dict=[], model_options={}): if tokenizer_data is None: tokenizer_data = dict() if no_init: @@ -137,6 +141,27 @@ class CLIP: self.patcher.hook_mode = EnumHookMode.MinVram self.patcher.is_clip = True self.apply_hooks_to_conds = None + if len(state_dict) > 0: + if isinstance(state_dict, list): + for c in state_dict: + m, u = self.load_sd(c) + if len(m) > 0: + logging.warning("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected: {}".format(u)) + else: + m, u = self.load_sd(state_dict, full_model=True) + if len(m) > 0: + m_filter = list(filter(lambda a: ".logit_scale" 
not in a and ".transformer.text_projection.weight" not in a, m)) + if len(m_filter) > 0: + logging.warning("clip missing: {}".format(m)) + else: + logging.debug("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected {}:".format(u)) + if params['device'] == load_device: model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None @@ -156,6 +181,9 @@ class CLIP: n.apply_hooks_to_conds = self.apply_hooks_to_conds return n + def get_ram_usage(self): + return self.patcher.get_ram_usage() + def add_patches(self, patches, strength_patch=1.0, strength_model=1.0): return self.patcher.add_patches(patches, strength_patch, strength_model) @@ -199,6 +227,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) all_hooks.reset() self.patcher.patch_hooks(None) if show_pbar: @@ -246,6 +275,7 @@ class CLIP: self.cond_stage_model.set_clip_options({"projected_pooled": False}) self.load_model() + self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device}) o = self.cond_stage_model.encode_token_weights(tokens) cond, pooled = o[:2] if return_dict: @@ -310,6 +340,7 @@ class VAE: self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False self.not_video = False + self.size = None self.downscale_index_formula = None self.upscale_index_formula = None @@ -369,7 +400,7 @@ class VAE: self.memory_used_encode = lambda shape, dtype: (700 * shape[2] * shape[3]) * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: (700 * shape[2] * shape[3] * 32 * 32) * model_management.dtype_size(dtype) - elif sd['decoder.conv_in.weight'].shape[1] == 32: + elif sd['decoder.conv_in.weight'].shape[1] == 32 and sd['decoder.conv_in.weight'].ndim == 5: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True, "refiner_vae": False} self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] @@ -395,6 +426,17 @@ class VAE: self.upscale_ratio = 4 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.weight"].shape[1] + if 'decoder.post_quant_conv.weight' in sd: + sd = state_dict_prefix_replace(sd, {"decoder.post_quant_conv.": "post_quant_conv.", "encoder.quant_conv.": "quant_conv."}) + + if 'bn.running_mean' in sd: + ddconfig["batch_norm_latent"] = True + self.downscale_ratio *= 2 + self.upscale_ratio *= 2 + self.latent_channels *= 4 + old_memory_used_decode = self.memory_used_decode + self.memory_used_decode = lambda shape, dtype: old_memory_used_decode(shape, dtype) * 4.0 + if 'post_quant_conv.weight' in sd: self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) else: @@ -454,20 +496,20 @@ class VAE: elif "decoder.conv_in.conv.weight" in sd and sd['decoder.conv_in.conv.weight'].shape[1] == 32: ddconfig = {"block_out_channels": [128, 256, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 16, "ffactor_temporal": 4, "downsample_match_channel": True, "upsample_match_channel": True} ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] - self.latent_channels = 64 + self.latent_channels = 32 
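+ # Matches the 32-channel conv_in of the hunyuan_video refiner VAE; spatial compression is 16x and temporal compression 4x (see the index formulas that follow).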
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) self.upscale_index_formula = (4, 16, 16) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) self.downscale_index_formula = (4, 16, 16) self.latent_dim = 3 - self.not_video = True + self.not_video = False self.working_dtypes = [torch.float16, torch.bfloat16, torch.float32] self.first_stage_model = AutoencodingEngine(regularizer_config={'target': "comfy.ldm.models.autoencoder.EmptyRegularizer"}, encoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Encoder", 'params': ddconfig}, decoder_config={'target': "comfy.ldm.hunyuan_video.vae_refiner.Decoder", 'params': ddconfig}) - self.memory_used_encode = lambda shape, dtype: (1400 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) - self.memory_used_decode = lambda shape, dtype: (1400 * shape[-3] * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype) + self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype) elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True @@ -479,8 +521,10 @@ class VAE: self.latent_dim = 3 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) - self.memory_used_decode = lambda shape, dtype: (1500 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype) - self.memory_used_encode = lambda shape, dtype: (900 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype) + # This is likely to significantly over-estimate with single image or low frame counts as the + # implementation is able to completely skip caching. 
Rework if used as an image only VAE + self.memory_used_decode = lambda shape, dtype: (2800 * min(8, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype) + self.memory_used_encode = lambda shape, dtype: (1400 * min(9, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] elif "decoder.unpatcher3d.wavelets" in sd: self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 8, 8) @@ -509,13 +553,14 @@ class VAE: self.memory_used_encode = lambda shape, dtype: 3300 * shape[3] * shape[4] * model_management.dtype_size(dtype) self.memory_used_decode = lambda shape, dtype: 8000 * shape[3] * shape[4] * (16 * 16) * model_management.dtype_size(dtype) else: # Wan 2.1 VAE + dim = sd["decoder.head.0.gamma"].shape[0] self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) self.upscale_index_formula = (4, 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 self.latent_channels = 16 - ddconfig = {"dim": 96, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} + ddconfig = {"dim": dim, "z_dim": self.latent_channels, "dim_mult": [1, 2, 4, 4], "num_res_blocks": 2, "attn_scales": [], "temperal_downsample": [False, True, True], "dropout": 0.0} self.first_stage_model = wan_vae.WanVAE(**ddconfig) self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32] @@ -593,6 +638,35 @@ class VAE: self.process_input = lambda audio: audio self.working_dtypes = [torch.float32] self.crop_input = False + elif "decoder.22.bias" in sd: # taehv, taew and lighttae + self.latent_channels = sd["decoder.1.weight"].shape[1] + self.latent_dim = 3 + self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 16, 16) + self.upscale_index_formula = (4, 16, 16) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 16, 16) + self.downscale_index_formula = (4, 16, 16) + if self.latent_channels == 48: # Wan 2.2 + self.first_stage_model = TAEHV(latent_channels=self.latent_channels, latent_format=None) # taehv doesn't need scaling + self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.process_output = lambda image: image + self.memory_used_decode = lambda shape, dtype: (1800 * (max(1, (shape[-3] ** 0.7 * 0.1)) * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)) + elif self.latent_channels == 32 and sd["decoder.22.bias"].shape[0] == 12: # lighttae_hv15 + self.first_stage_model = TAEHV(latent_channels=self.latent_channels, latent_format=HunyuanVideo15) + self.process_input = lambda image: (_ for _ in ()).throw(NotImplementedError("This light tae doesn't support encoding currently")) + self.memory_used_decode = lambda shape, dtype: (1200 * (max(1, (shape[-3] ** 0.7 * 0.05)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) + else: + if sd["decoder.1.weight"].dtype == torch.float16: # taehv currently only available in float16, so assume it's not lighttaew2_1 as otherwise state dicts are identical + latent_format = HunyuanVideo + else: + latent_format = None # lighttaew2_1 doesn't need scaling + self.first_stage_model = TAEHV(latent_channels=self.latent_channels, latent_format=latent_format) + self.process_input = self.process_output = lambda image: image + self.upscale_ratio = (lambda a: max(0, a * 4 - 
3), 8, 8) + self.upscale_index_formula = (4, 8, 8) + self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) + self.memory_used_encode = lambda shape, dtype: (700 * (max(1, (shape[-3] ** 0.66 * 0.11)) * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)) + self.memory_used_decode = lambda shape, dtype: (50 * (max(1, (shape[-3] ** 0.65 * 0.26)) * shape[-2] * shape[-1] * 32 * 32) * model_management.dtype_size(dtype)) else: logger.warning("WARNING: No VAE weights detected, VAE not initalized.") self.first_stage_model = None @@ -620,6 +694,8 @@ class VAE: self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) logger.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + # todo: why is this being called here? for what side effects exactly? + self.model_size() def clone(self): n = VAE(no_init=True) @@ -644,6 +720,15 @@ class VAE: n.patcher = self.patcher.clone() return n + def model_size(self): + if self.size is not None: + return self.size + self.size = module_size(self.first_stage_model) + return self.size + + def get_ram_usage(self): + return self.model_size() + def throw_exception_if_invalid(self): if self.first_stage_model is None: raise RuntimeError("ERROR: VAE is invalid: None\n\nIf the VAE is from a checkpoint loader node your checkpoint does not contain a valid VAE.") @@ -704,6 +789,7 @@ class VAE: return samples def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): + extra_channel_size = 0 if self.latent_dim == 1: encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() out_channels = self.latent_channels @@ -730,6 +816,8 @@ class VAE: self.throw_exception_if_invalid() pixel_samples = None do_tile = False + if self.latent_dim == 2 and samples_in.ndim == 5: + samples_in = samples_in[:, :, 0] try: memory_used = self.memory_used_decode(samples_in.shape, self.vae_dtype) model_management.load_models_gpu([self.patcher], memory_required=memory_used, force_full_load=self.disable_offload) @@ -960,6 +1048,10 @@ class CLIPType(Enum): OMNIGEN2 = 17 QWEN_IMAGE = 18 HUNYUAN_IMAGE = 19 + HUNYUAN_VIDEO_15 = 20 + OVIS = 21 + KANDINSKY5 = 22 + KANDINSKY5_IMAGE = 23 @dataclasses.dataclass @@ -975,8 +1067,11 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI model_options = dict() clip_data = [] for p in ckpt_paths: - clip_data.append(utils.load_torch_file(p, safe_load=True)) - return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options, textmodel_json_config=textmodel_json_config) + sd, metadata = utils.load_torch_file(p, safe_load=True, return_metadata=True) + if model_options.get("custom_operations", None) is None: + sd, metadata = utils.convert_old_quants(sd, model_prefix="", metadata=metadata) + clip_data.append(sd) + return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) class TEModel(Enum): @@ -993,6 +1088,10 @@ class TEModel(Enum): QWEN25_7B = 11 BYT5_SMALL_GLYPH = 12 GEMMA_3_4B = 13 + MISTRAL3_24B = 14 + MISTRAL3_24B_PRUNED_FLUX2 = 15 + QWEN3_4B = 16 + QWEN3_2B = 17 def detect_te_model(sd): @@ -1026,6 +1125,18 @@ def detect_te_model(sd): if weight.shape[0] == 512: return TEModel.QWEN25_7B if 
"model.layers.0.post_attention_layernorm.weight" in sd: + weight = sd['model.layers.0.post_attention_layernorm.weight'] + if 'model.layers.0.self_attn.q_norm.weight' in sd: + if weight.shape[0] == 2560: + return TEModel.QWEN3_4B + elif weight.shape[0] == 2048: + return TEModel.QWEN3_2B + if weight.shape[0] == 5120: + if "model.layers.39.post_attention_layernorm.weight" in sd: + return TEModel.MISTRAL3_24B + else: + return TEModel.MISTRAL3_24B_PRUNED_FLUX2 + return TEModel.LLAMA3_8 return None @@ -1077,7 +1188,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel @@ -1101,7 +1212,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.HIDREAM: clip_target.clip = hidream.hidream_clip(**t5xxl_detect(clip_data), - clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None) clip_target.tokenizer = hidream.HiDreamTokenizer else: # CLIPType.MOCHI clip_target.clip = genmo.mochi_te(**t5xxl_detect(clip_data)) @@ -1130,7 +1241,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = hidream.hidream_clip(**llama_detect(clip_data), - clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) clip_target.tokenizer = hidream.HiDreamTokenizer elif te_model == TEModel.QWEN25_3B: clip_target.clip = omnigen2.te(**llama_detect(clip_data)) @@ -1142,13 +1253,23 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip else: clip_target.clip = qwen_image.te(**llama_detect(clip_data)) clip_target.tokenizer = qwen_image.QwenImageTokenizer + elif te_model == TEModel.MISTRAL3_24B or te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2: + clip_target.clip = flux.flux2_te(**llama_detect(clip_data), pruned=te_model == TEModel.MISTRAL3_24B_PRUNED_FLUX2) + clip_target.tokenizer = flux.Flux2Tokenizer + tokenizer_data["tekken_model"] = clip_data[0].get("tekken_model", None) + elif te_model == TEModel.QWEN3_4B: + clip_target.clip = z_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = z_image.ZImageTokenizer + elif te_model == TEModel.QWEN3_2B: + clip_target.clip = ovis.te(**llama_detect(clip_data)) + clip_target.tokenizer = ovis.OvisTokenizer else: # clip_l if clip_type == CLIPType.SD3: clip_target.clip = sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = 
hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel @@ -1188,6 +1309,15 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.HUNYUAN_IMAGE: clip_target.clip = hunyuan_image.te(**llama_detect(clip_data)) clip_target.tokenizer = hunyuan_image.HunyuanImageTokenizer + elif clip_type == CLIPType.HUNYUAN_VIDEO_15: + clip_target.clip = hunyuan_image.te(**llama_detect(clip_data)) + clip_target.tokenizer = hunyuan_video.HunyuanVideo15Tokenizer + elif clip_type == CLIPType.KANDINSKY5: + clip_target.clip = kandinsky5.te(**llama_detect(clip_data)) + clip_target.tokenizer = kandinsky5.Kandinsky5Tokenizer + elif clip_type == CLIPType.KANDINSKY5_IMAGE: + clip_target.clip = kandinsky5.te(**llama_detect(clip_data)) + clip_target.tokenizer = kandinsky5.Kandinsky5TokenizerImage else: clip_target.clip = sdxl_clip.SDXLClipModel clip_target.tokenizer = sdxl_clip.SDXLTokenizer @@ -1203,14 +1333,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters += utils.calculate_parameters(c) tokenizer_data, model_options = long_clipl.model_options_long_clip(c, tokenizer_data, model_options) - clip = CLIP(clip_target, embedding_directory=embedding_directory, textmodel_json_config=textmodel_json_config, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options) - for c in clip_data: - m, u = clip.load_sd(c) - if len(m) > 0: - logger.warning("clip missing: {}".format(m)) - - if len(u) > 0: - logger.debug("clip unexpected: {}".format(u)) + clip = CLIP(clip_target, textmodel_json_config=textmodel_json_config, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options) return clip @@ -1285,6 +1408,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c weight_dtype = utils.weight_dtype(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata) + model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: logger.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") @@ -1294,16 +1421,21 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None - model_config.custom_operations = model_options.get("custom_operations", None) + if custom_operations is not None: + model_config.custom_operations = custom_operations + unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None)) if unet_dtype is None: unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype) - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if model_config.quant_config is not None: + manual_cast_dtype = 
model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config.clip_vision_prefix is not None: @@ -1321,22 +1453,33 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c vae = VAE(sd=vae_sd, metadata=metadata) if output_clip: + if te_model_options.get("custom_operations", None) is None: + scaled_fp8_list = [] + for k in list(sd.keys()): # Convert scaled fp8 to mixed ops + if k.endswith(".scaled_fp8"): + scaled_fp8_list.append(k[:-len("scaled_fp8")]) + + if len(scaled_fp8_list) > 0: + out_sd = {} + for k in sd: + skip = False + for pref in scaled_fp8_list: + skip = skip or k.startswith(pref) + if not skip: + out_sd[k] = sd[k] + + for pref in scaled_fp8_list: + quant_sd, qmetadata = utils.convert_old_quants(sd, pref, metadata={}) + for k in quant_sd: + out_sd[k] = quant_sd[k] + sd = out_sd + clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: parameters = utils.calculate_parameters(clip_sd) - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options) - m, u = clip.load_sd(clip_sd, full_model=True) - if len(m) > 0: - m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) - if len(m_filter) > 0: - logger.warning("clip missing: {}".format(m)) - else: - logger.debug("clip missing: {}".format(m)) - - if len(u) > 0: - logger.debug("clip unexpected {}:".format(u)) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options) else: logger.warning(f"no CLIP/text encoder weights in checkpoint {ckpt_path}, the text encoder model will not be loaded.") @@ -1385,6 +1528,9 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O if len(temp_sd) > 0: sd = temp_sd + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = utils.convert_old_quants(sd, "", metadata=metadata) parameters = utils.calculate_parameters(sd) weight_dtype = utils.weight_dtype(sd) load_device = model_management.get_torch_device() @@ -1414,7 +1560,7 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O offload_device = model_management.unet_offload_device() unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None if dtype is None: @@ -1422,9 +1568,15 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O else: unet_dtype = dtype - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if model_config.quant_config is not None: + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) - model_config.custom_operations = model_options.get("custom_operations", 
model_config.custom_operations) + + if custom_operations is not None: + model_config.custom_operations = custom_operations + if model_options.get("fp8_optimizations", False): model_config.optimizations["fp8"] = True @@ -1437,7 +1589,7 @@ def load_diffusion_model_state_dict(sd, model_options: dict = None, ckpt_path: O return model_patcher.ModelPatcher(model, load_device=load_device, offload_device=offload_device, ckpt_name=os.path.basename(ckpt_path)) -def load_diffusion_model(unet_path, model_options: dict = None): +def load_diffusion_model(unet_path, model_options=None): if model_options is None: model_options = {} sd, metadata = utils.load_torch_file(unet_path, return_metadata=True) @@ -1468,6 +1620,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if vae is not None: vae_sd = vae.get_sd() + if metadata is None: + metadata = {} + model_management.load_models_gpu(load_models, force_patch_weights=True) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 24fa61b32..556885fad 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -28,6 +28,7 @@ except ImportError: logger = logging.getLogger(__name__) + def gen_empty_tokens(special_tokens, length): start_token = special_tokens.get("start", None) end_token = special_tokens.get("end", None) @@ -132,19 +133,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): config[k] = v operations = model_options.get("custom_operations", None) - scaled_fp8 = None + quant_config = model_options.get("quantization_metadata", None) if operations is None: - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) + if quant_config is not None: + operations = ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True) + logging.info("Using MixedPrecisionOps for text encoder") else: operations = ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) - if scaled_fp8 is not None: - self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8)) self.num_layers = self.transformer.num_layers @@ -162,6 +161,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer_norm_hidden_state = layer_norm_hidden_state self.return_projected_pooled = return_projected_pooled self.return_attention_masks = return_attention_masks + self.execution_device = None if layer == "hidden": assert layer_idx is not None @@ -178,7 +178,8 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): def set_clip_options(self, options): layer_idx = options.get("layer", self.layer_idx) self.return_projected_pooled = options.get("projected_pooled", self.return_projected_pooled) - if self.layer == "all": + self.execution_device = options.get("execution_device", self.execution_device) + if isinstance(self.layer, list) or self.layer == "all": pass elif layer_idx is None or abs(layer_idx) > self.num_layers: self.layer = "last" @@ -190,6 +191,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): self.layer = self.options_default[0] self.layer_idx = self.options_default[1] self.return_projected_pooled = self.options_default[2] + self.execution_device = None def process_tokens(self, tokens, device): end_token = self.special_tokens.get("end", None) @@ -273,14 +275,20 @@ class 
SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): return torch.cat(embeds_out), torch.tensor(attention_masks, device=device, dtype=torch.long), num_tokens, embeds_info def forward(self, tokens): - device = self.transformer.get_input_embeddings().weight.device + if self.execution_device is None: + device = self.transformer.get_input_embeddings().weight.device + else: + device = self.execution_device + embeds, attention_mask, num_tokens, embeds_info = self.process_tokens(tokens, device) attention_mask_model = None if self.enable_attention_masks: attention_mask_model = attention_mask - if self.layer == "all": + if isinstance(self.layer, list): + intermediate_output = self.layer + elif self.layer == "all": intermediate_output = "all" else: intermediate_output = self.layer_idx @@ -478,6 +486,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No valid_file = None for embed_dir in embedding_directory: + # todo: improve this, so that it is more compatible between linux and windows embed_path = os.path.abspath(os.path.join(embed_dir, embedding_name)) embed_dir = os.path.abspath(embed_dir) try: @@ -546,7 +555,7 @@ SDTokenizerT = TypeVar('SDTokenizerT', bound='SDTokenizer') class SDTokenizer: - def __init__(self, tokenizer_path: Optional[Union[torch.Tensor, bytes, bytearray, memoryview, str, Path, Traversable]] = None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, tokenizer_data=None, tokenizer_args=None): + def __init__(self, tokenizer_path: Optional[Union[torch.Tensor, bytes, bytearray, memoryview, str, Path, Traversable]] = None, max_length=77, pad_with_end=True, embedding_directory=None, embedding_size=768, embedding_key='clip_l', tokenizer_class=CLIPTokenizer, has_start_token=True, has_end_token=True, pad_to_max_length=True, min_length=None, pad_token=None, end_token=None, min_padding=None, pad_left=False, tokenizer_data=None, tokenizer_args=None): if tokenizer_data is None: tokenizer_data = dict() if tokenizer_args is None: @@ -568,6 +577,7 @@ class SDTokenizer: self.min_length = tokenizer_data.get("{}_min_length".format(embedding_key), min_length) self.end_token = None self.min_padding = min_padding + self.pad_left = pad_left empty = self.tokenizer('')["input_ids"] self.tokenizer_adds_end_token = has_end_token @@ -635,6 +645,13 @@ class SDTokenizer: return (embed, "{} {}".format(embedding_name[len(stripped):], leftover)) return (embed, leftover) + def pad_tokens(self, tokens, amount): + if self.pad_left: + for i in range(amount): + tokens.insert(0, (self.pad_token, 1.0, 0)) + else: + tokens.extend([(self.pad_token, 1.0, 0)] * amount) + def tokenize_with_weights(self, text: str, return_word_ids=False, tokenizer_options={}, **kwargs): ''' Takes a prompt and converts it to a list of (token, weight, word id) elements. 
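+ # Padding in tokenize_with_weights is routed through pad_tokens(), so a tokenizer constructed with pad_left=True prepends its pad tokens instead of appending them.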
@@ -720,7 +737,7 @@ class SDTokenizer: if self.end_token is not None: batch.append((self.end_token, 1.0, 0)) if self.pad_to_max_length: - batch.extend([(self.pad_token, 1.0, 0)] * (remaining_length)) + self.pad_tokens(batch, remaining_length) # start new batch batch = [] if self.start_token is not None: @@ -734,11 +751,11 @@ class SDTokenizer: if self.end_token is not None: batch.append((self.end_token, 1.0, 0)) if min_padding is not None: - batch.extend([(self.pad_token, 1.0, 0)] * min_padding) + self.pad_tokens(batch, min_padding) if self.pad_to_max_length and len(batch) < self.max_length: - batch.extend([(self.pad_token, 1.0, 0)] * (self.max_length - len(batch))) + self.pad_tokens(batch, self.max_length - len(batch)) if min_length is not None and len(batch) < min_length: - batch.extend([(self.pad_token, 1.0, 0)] * (min_length - len(batch))) + self.pad_tokens(batch, min_length - len(batch)) if not return_word_ids: batched_tokens = [[(t, w) for t, w, _ in x] for x in batched_tokens] @@ -756,7 +773,7 @@ SD1TokenizerT = TypeVar("SD1TokenizerT", bound="SD1Tokenizer") class SD1Tokenizer: - def __init__(self, embedding_directory=None, tokenizer_data: dict=None, clip_name="l", tokenizer=SDTokenizer, name=None): + def __init__(self, embedding_directory=None, tokenizer_data: dict = None, clip_name="l", tokenizer=SDTokenizer, name=None): if tokenizer_data is None: tokenizer_data = {} if name is not None: @@ -792,11 +809,12 @@ class SD1Tokenizer: def state_dict(self): return getattr(self, self.clip).state_dict() + class SD1CheckpointClipModel(SDClipModel): def __init__(self, device="cpu", dtype=None, model_options=None, textmodel_json_config=None): - super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config) if model_options is None: model_options = {} + super().__init__(device=device, return_projected_pooled=False, dtype=dtype, model_options=model_options, textmodel_json_config=textmodel_json_config) class SD1ClipModel(torch.nn.Module): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 2f9fdad27..78df0943c 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -27,6 +27,8 @@ from .text_encoders import sd3_clip from .text_encoders import wan from .text_encoders import qwen_image from .text_encoders import hunyuan_image +from .text_encoders import kandinsky5 +from .text_encoders import z_image class SD15(supported_models_base.BASE): @@ -798,6 +800,40 @@ class FluxSchnell(Flux): return out +class Flux2(Flux): + unet_config = { + "image_model": "flux2", + } + + sampling_settings = { + "shift": 2.02, + } + + unet_extra_config = {} + latent_format = latent_formats.Flux2 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def __init__(self, unet_config): + super().__init__(unet_config) + self.memory_usage_factor = self.memory_usage_factor * (2.0 * 2.0) * 2.36 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Flux2(self, device=device) + return out + + def clip_target(self, state_dict=None): + if state_dict is None: + state_dict = {} + return None # TODO + # pref = self.text_encoder_key_prefix[0] + # t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref)) + # return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect)) + + 
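+ # Flux2 does not expose a text encoder target yet: clip_target() returns None, so load_state_dict_guess_config() will not build a CLIP/text encoder for Flux2 checkpoints.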
class GenmoMochi(supported_models_base.BASE): unet_config = { "image_model": "mochi_preview", @@ -1039,7 +1075,7 @@ class Lumina2(supported_models_base.BASE): "shift": 6.0, } - memory_usage_factor = 1.2 + memory_usage_factor = 1.4 unet_extra_config = {} latent_format = latent_formats.Flux @@ -1061,6 +1097,27 @@ class Lumina2(supported_models_base.BASE): return supported_models_base.ClipTarget(lumina2.LuminaTokenizer, lumina2.te(**hunyuan_detect)) +class ZImage(Lumina2): + unet_config = { + "image_model": "lumina2", + "dim": 3840, + } + + sampling_settings = { + "multiplier": 1.0, + "shift": 3.0, + } + + memory_usage_factor = 1.7 + + supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref)) + return supported_models_base.ClipTarget(z_image.ZImageTokenizer, z_image.te(**hunyuan_detect)) + + class WAN21_T2V(supported_models_base.BASE): unet_config = { "image_model": "wan2.1", @@ -1483,6 +1540,108 @@ class HunyuanImage21Refiner(HunyuanVideo): return out -models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage] +class HunyuanVideo15(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vision_in_dim": 1152, + } + + sampling_settings = { + "shift": 7.0, + } + memory_usage_factor = 4.0 # TODO + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + latent_format = latent_formats.HunyuanVideo15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideo15(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(hunyuan_video.HunyuanVideo15Tokenizer, hunyuan_image.te(**hunyuan_detect)) + + +class HunyuanVideo15_SR_Distilled(HunyuanVideo): + unet_config = { + "image_model": "hunyuan_video", + "vision_in_dim": 1152, + "in_channels": 98, + } + + sampling_settings = { + "shift": 2.0, + } + memory_usage_factor = 4.0 # TODO + supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32] + + latent_format = latent_formats.HunyuanVideo15 + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.HunyuanVideo15_SR_Distilled(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(hunyuan_video.HunyuanVideo15Tokenizer, hunyuan_image.te(**hunyuan_detect)) + + +class Kandinsky5(supported_models_base.BASE): + 
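+ # Base Kandinsky5 config (HunyuanVideo latent format, shift 10.0); Kandinsky5Image below overrides model_dim/visual_embed_dim, switches to the Flux latent format, and uses the image tokenizer.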
unet_config = { + "image_model": "kandinsky5", + } + + sampling_settings = { + "shift": 10.0, + } + + unet_extra_config = {} + latent_format = latent_formats.HunyuanVideo + + memory_usage_factor = 1.1 # TODO + + supported_inference_dtypes = [torch.bfloat16, torch.float32] + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Kandinsky5(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(kandinsky5.Kandinsky5Tokenizer, kandinsky5.te(**hunyuan_detect)) + + +class Kandinsky5Image(Kandinsky5): + unet_config = { + "image_model": "kandinsky5", + "model_dim": 2560, + "visual_embed_dim": 64, + } + + sampling_settings = { + "shift": 3.0, + } + + latent_format = latent_formats.Flux + memory_usage_factor = 1.1 # TODO + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.Kandinsky5Image(self, device=device) + return out + + def clip_target(self, state_dict={}): + pref = self.text_encoder_key_prefix[0] + hunyuan_detect = hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref)) + return supported_models_base.ClipTarget(kandinsky5.Kandinsky5TokenizerImage, kandinsky5.te(**hunyuan_detect)) + + +models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5] models += [SVD_img2vid] diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index e8aa64ad2..9f739bccf 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -18,10 +18,10 @@ from typing import Optional import torch +import logging from . import model_base from . import utils from . import latent_formats -from .ops import Operations class ClipTarget: @@ -30,6 +30,7 @@ class ClipTarget: self.tokenizer = tokenizer self.params = {} + class BASE: unet_config = {} unet_extra_config = { @@ -51,8 +52,8 @@ class BASE: memory_usage_factor = 2.0 manual_cast_dtype: Optional[torch.dtype] = None - custom_operations: Optional[Operations] = None - scaled_fp8: Optional[torch.dtype] = None + custom_operations: Optional[torch.dtype] = None + quant_config = None # quantization configuration for mixed precision optimizations = {"fp8": False} @classmethod @@ -120,3 +121,7 @@ class BASE: def set_inference_dtype(self, dtype, manual_cast_dtype): self.unet_config['dtype'] = dtype self.manual_cast_dtype = manual_cast_dtype + + def __getattr__(self, name): + logging.warning("\nWARNING, you accessed {} from the model config object which doesn't exist. 
Please fix your code.\n".format(name)) + return None diff --git a/comfy/taesd/taehv.py b/comfy/taesd/taehv.py new file mode 100644 index 000000000..3dfe1e4d4 --- /dev/null +++ b/comfy/taesd/taehv.py @@ -0,0 +1,171 @@ +# Tiny AutoEncoder for HunyuanVideo and WanVideo https://github.com/madebyollin/taehv + +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm.auto import tqdm +from collections import namedtuple, deque + +import comfy.ops +operations=comfy.ops.disable_weight_init + +DecoderResult = namedtuple("DecoderResult", ("frame", "memory")) +TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index")) + +def conv(n_in, n_out, **kwargs): + return operations.Conv2d(n_in, n_out, 3, padding=1, **kwargs) + +class Clamp(nn.Module): + def forward(self, x): + return torch.tanh(x / 3) * 3 + +class MemBlock(nn.Module): + def __init__(self, n_in, n_out, act_func): + super().__init__() + self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out)) + self.skip = operations.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity() + self.act = act_func + def forward(self, x, past): + return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x)) + +class TPool(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = operations.Conv2d(n_f*stride,n_f, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + return self.conv(x.reshape(-1, self.stride * C, H, W)) + +class TGrow(nn.Module): + def __init__(self, n_f, stride): + super().__init__() + self.stride = stride + self.conv = operations.Conv2d(n_f, n_f*stride, 1, bias=False) + def forward(self, x): + _NT, C, H, W = x.shape + x = self.conv(x) + return x.reshape(-1, C, H, W) + +def apply_model_with_memblocks(model, x, parallel, show_progress_bar): + + B, T, C, H, W = x.shape + if parallel: + x = x.reshape(B*T, C, H, W) + # parallel over input timesteps, iterate over blocks + for b in tqdm(model, disable=not show_progress_bar): + if isinstance(b, MemBlock): + BT, C, H, W = x.shape + T = BT // B + _x = x.reshape(B, T, C, H, W) + mem = F.pad(_x, (0,0,0,0,0,0,1,0), value=0)[:,:T].reshape(x.shape) + x = b(x, mem) + else: + x = b(x) + BT, C, H, W = x.shape + T = BT // B + x = x.view(B, T, C, H, W) + else: + out = [] + work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))]) + progress_bar = tqdm(range(T), disable=not show_progress_bar) + mem = [None] * len(model) + while work_queue: + xt, i = work_queue.popleft() + if i == 0: + progress_bar.update(1) + if i == len(model): + out.append(xt) + del xt + else: + b = model[i] + if isinstance(b, MemBlock): + if mem[i] is None: + xt_new = b(xt, xt * 0) + mem[i] = xt.detach().clone() + else: + xt_new = b(xt, mem[i]) + mem[i] = xt.detach().clone() + del xt + work_queue.appendleft(TWorkItem(xt_new, i+1)) + elif isinstance(b, TPool): + if mem[i] is None: + mem[i] = [] + mem[i].append(xt.detach().clone()) + if len(mem[i]) == b.stride: + B, C, H, W = xt.shape + xt = b(torch.cat(mem[i], 1).view(B*b.stride, C, H, W)) + mem[i] = [] + work_queue.appendleft(TWorkItem(xt, i+1)) + elif isinstance(b, TGrow): + xt = b(xt) + NT, C, H, W = xt.shape + for xt_next in reversed(xt.view(B, b.stride*C, H, W).chunk(b.stride, 1)): + work_queue.appendleft(TWorkItem(xt_next, i+1)) + del xt + else: + xt = b(xt) + work_queue.appendleft(TWorkItem(xt, i+1)) + progress_bar.close() + x = torch.stack(out, 1) + return x + + +class TAEHV(nn.Module): 
+ def __init__(self, latent_channels, parallel=False, decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True), latent_format=None, show_progress_bar=True): + super().__init__() + self.image_channels = 3 + self.patch_size = 1 + self.latent_channels = latent_channels + self.parallel = parallel + self.latent_format = latent_format + self.show_progress_bar = show_progress_bar + self.process_in = latent_format().process_in if latent_format is not None else (lambda x: x) + self.process_out = latent_format().process_out if latent_format is not None else (lambda x: x) + if self.latent_channels in [48, 32]: # Wan 2.2 and HunyuanVideo1.5 + self.patch_size = 2 + if self.latent_channels == 32: # HunyuanVideo1.5 + act_func = nn.LeakyReLU(0.2, inplace=True) + else: # HunyuanVideo, Wan 2.1 + act_func = nn.ReLU(inplace=True) + + self.encoder = nn.Sequential( + conv(self.image_channels*self.patch_size**2, 64), act_func, + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), MemBlock(64, 64, act_func), + conv(64, self.latent_channels), + ) + n_f = [256, 128, 64, 64] + self.frames_to_trim = 2**sum(decoder_time_upscale) - 1 + self.decoder = nn.Sequential( + Clamp(), conv(self.latent_channels, n_f[0]), act_func, + MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), MemBlock(n_f[0], n_f[0], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False), + MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), MemBlock(n_f[1], n_f[1], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False), + MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), MemBlock(n_f[2], n_f[2], act_func), nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False), + act_func, conv(n_f[3], self.image_channels*self.patch_size**2), + ) + @property + def show_progress_bar(self): + return self._show_progress_bar + + @show_progress_bar.setter + def show_progress_bar(self, value): + self._show_progress_bar = value + + def encode(self, x, **kwargs): + if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size) + x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] + if x.shape[1] % 4 != 0: + # pad at end to multiple of 4 + n_pad = 4 - x.shape[1] % 4 + padding = x[:, -1:].repeat_interleave(n_pad, dim=1) + x = torch.cat([x, padding], 1) + x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1) + return self.process_out(x) + + def decode(self, x, **kwargs): + x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W] + x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar) + if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size) + return x[:, self.frames_to_trim:].movedim(2, 1) diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index 0876d5b54..98a02a72b 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -11,10 +11,10 @@ class 
T5XXLModel(sd1_clip.SDClipModel): if model_options is None: model_options = {} textmodel_json_config = get_path_as_dict(textmodel_json_config, "t5_old_config_xxl.json", package=__package__) - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options) @@ -43,14 +43,14 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer): -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class CosmosTEModel_(CosmosT5XXL): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index d202c1396..ca2ff55bf 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -1,12 +1,15 @@ import copy import torch -from transformers import T5TokenizerFast +from transformers import T5TokenizerFast, LlamaTokenizerFast from .sd3_clip import T5XXLModel from .. 
import sd1_clip, model_management from ..component_model import files +import json +import base64 + class T5XXLTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data=None): @@ -73,14 +76,127 @@ class FluxClipModel(torch.nn.Module): return self.t5xxl.load_sd(sd) -def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None): +def flux_clip(dtype_t5=None, t5_quantization_metadata=None): class FluxClipModel_(FluxClipModel): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) return FluxClipModel_ + + +def load_mistral_tokenizer(data): + if torch.is_tensor(data): + data = data.numpy().tobytes() + + # we just have to use the latest transformers + from transformers.integrations.mistral import MistralConverter + + mistral_vocab = json.loads(data) + + special_tokens = {} + vocab = {} + + max_vocab = mistral_vocab["config"]["default_vocab_size"] + max_vocab -= len(mistral_vocab["special_tokens"]) + + for w in mistral_vocab["vocab"]: + r = w["rank"] + if r >= max_vocab: + continue + + vocab[base64.b64decode(w["token_bytes"])] = r + + for w in mistral_vocab["special_tokens"]: + if "token_bytes" in w: + special_tokens[base64.b64decode(w["token_bytes"])] = w["rank"] + else: + special_tokens[w["token_str"]] = w["rank"] + + all_special = [] + for v in special_tokens: + all_special.append(v) + + special_tokens.update(vocab) + vocab = special_tokens + return {"tokenizer_object": MistralConverter(vocab=vocab, additional_special_tokens=all_special).converted(), "legacy": False} + + +class MistralTokenizerClass: + @staticmethod + def from_pretrained(path, **kwargs): + return LlamaTokenizerFast(**kwargs) + + +class Mistral3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} + self.tekken_data = tokenizer_data.get("tekken_model", None) + super().__init__("", pad_with_end=False, embedding_size=5120, embedding_key='mistral3_24b', tokenizer_class=MistralTokenizerClass, has_end_token=False, pad_to_max_length=False, pad_token=11, max_length=99999999, min_length=1, pad_left=True, tokenizer_args=load_mistral_tokenizer(self.tekken_data), tokenizer_data=tokenizer_data) + + def state_dict(self): + return {"tekken_model": self.tekken_data} + + +class Flux2Tokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="mistral3_24b", tokenizer=Mistral3Tokenizer) + self.llama_template = '[SYSTEM_PROMPT]You are an AI that reasons about image descriptions. 
You give structured responses focusing on object relationships, object\nattribution and actions without speculation.[/SYSTEM_PROMPT][INST]{}[/INST]' + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + + +class Mistral3_24BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer=None, layer_idx=None, dtype=None, attention_mask=True, model_options={}): + if layer is None: + layer = [10, 20, 30] + textmodel_json_config = {} + num_layers = model_options.get("num_layers", None) + if num_layers is not None: + textmodel_json_config["num_hidden_layers"] = num_layers + if num_layers < 40: + textmodel_json_config["final_norm"] = False + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 1, "pad": 0}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Mistral3Small24B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class Flux2TEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}, name="mistral3_24b", clip_model=Mistral3_24BModel): + super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + out, pooled, extra = super().encode_token_weights(token_weight_pairs) + + out = torch.stack((out[:, 0], out[:, 1], out[:, 2]), dim=1) + out = out.movedim(1, 2) + out = out.reshape(out.shape[0], out.shape[1], -1) + return out, pooled, extra + + +def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): + class Flux2TEModel_(Flux2TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + if pruned: + model_options = model_options.copy() + model_options["num_layers"] = 30 + super().__init__(device=device, dtype=dtype, model_options=model_options) + + return Flux2TEModel_ diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 529a94830..019f1f512 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -33,14 +33,14 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer): tokenizer_data = {} -def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def mochi_te(dtype_t5=None, t5_quantization_metadata=None): class MochiTEModel_(MochiT5XXL): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index a617ddd9a..68e64fb17 100644 --- a/comfy/text_encoders/hidream.py +++ 
b/comfy/text_encoders/hidream.py @@ -9,6 +9,7 @@ from ..model_management import intermediate_device, pick_weight_dtype logger = logging.getLogger(__name__) + class HiDreamTokenizer: def __init__(self, embedding_directory=None, tokenizer_data=None): if tokenizer_data is None: @@ -148,17 +149,17 @@ class HiDreamTEModel(torch.nn.Module): return self.llama.load_sd(sd) -def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None): +def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None): class HiDreamTEModel_(HiDreamTEModel): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HiDreamTEModel_ diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index 6e07d5a88..e1587b839 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -14,7 +14,7 @@ class ByT5SmallTokenizer(sd1_clip.SDTokenizer): if tokenizer_data is None: tokenizer_data = {} tokenizer_path = files.get_package_as_path("byt5_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, embedding_directory=None, pad_with_end=False, embedding_size=1472, embedding_key='byt5_small', tokenizer_class=ByT5Tokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_data=tokenizer_data) class HunyuanImageTokenizer(QwenImageTokenizer): @@ -52,10 +52,10 @@ class Qwen25_7BVLIModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options=None): if model_options is None: model_options = {} - llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -108,14 +108,14 @@ class HunyuanImageTEModel(QwenImageTEModel): return 
super().load_sd(sd) -def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None): +def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(HunyuanImageTEModel): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index d263deb67..fc2c60afb 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -2,10 +2,12 @@ import torch import numbers from transformers import LlamaTokenizerFast +from .hunyuan_image import HunyuanImageTokenizer from .llama import Llama2 from .. import sd1_clip from ..component_model import files from ..model_management import pick_weight_dtype +from ..utils import detect_layer_quantization def llama_detect(state_dict, prefix=""): @@ -14,9 +16,9 @@ def llama_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_llama"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + quant = detect_layer_quantization(state_dict, prefix) + if quant is not None: + out["llama_quantization_metadata"] = quant return out @@ -35,10 +37,10 @@ class LLAMAModel(sd1_clip.SDClipModel): special_tokens = {"start": 128000, "pad": 128258} if model_options is None: model_options = {} - llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata textmodel_json_config = textmodel_json_config or {} vocab_size = model_options.get("vocab_size", None) @@ -83,6 +85,15 @@ class HunyuanVideoTokenizer: return {} +class HunyuanVideo15Tokenizer(HunyuanImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a helpful assistant. Describe the video by detailing the following aspects:\n1. The main content and theme of the video.\n2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects.\n3. Actions, events, behaviors temporal relationships, physical movement changes of the objects.\n4. background environment, light, style and atmosphere.\n5. 
camera angles, movements, and transitions used in the video.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs): + return super().tokenize_with_weights(text, return_word_ids, prevent_empty_text=True, **kwargs) + + class HunyuanVideoClipModel(torch.nn.Module): def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options=None): super().__init__() @@ -161,14 +172,14 @@ class HunyuanVideoClipModel(torch.nn.Module): return self.llama.load_sd(sd) -def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None): +def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None): class HunyuanVideoClipModel_(HunyuanVideoClipModel): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HunyuanVideoClipModel_ diff --git a/comfy/text_encoders/kandinsky5.py b/comfy/text_encoders/kandinsky5.py new file mode 100644 index 000000000..be086458c --- /dev/null +++ b/comfy/text_encoders/kandinsky5.py @@ -0,0 +1,68 @@ +from comfy import sd1_clip +from .qwen_image import QwenImageTokenizer, QwenImageTEModel +from .llama import Qwen25_7BVLI + + +class Kandinsky5Tokenizer(QwenImageTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a prompt engineer. Describe the video in detail.\nDescribe how the camera moves or shakes, describe the zoom and view angle, whether it follows the objects.\nDescribe the location of the video, main characters or objects and their action.\nDescribe the dynamism of the video and presented actions.\nName the visual style of the video: whether it is a professional footage, user generated content, some kind of animation, video game or screen content.\nDescribe the visual effects, postprocessing and transitions if they are presented in the video.\nPay attention to the order of key actions shown in the scene.<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + + def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs): + out = super().tokenize_with_weights(text, return_word_ids, **kwargs) + out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs) + + return out + + +class Kandinsky5TokenizerImage(Kandinsky5Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data) + self.llama_template = "<|im_start|>system\nYou are a promt engineer. 
Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>" + + +class Qwen25_7BVLIModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}): + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class Kandinsky5TEModel(QwenImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options) + self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=False, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs): + cond, p, extra = super().encode_token_weights(token_weight_pairs, template_end=-1) + l_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs["l"]) + + return cond, l_pooled, extra + + def set_clip_options(self, options): + super().set_clip_options(options) + self.clip_l.set_clip_options(options) + + def reset_clip_options(self): + super().reset_clip_options() + self.clip_l.reset_clip_options() + + def load_sd(self, sd): + if "text_model.encoder.layers.1.mlp.fc1.weight" in sd: + return self.clip_l.load_sd(sd) + else: + return super().load_sd(sd) + +def te(dtype_llama=None, llama_quantization_metadata=None): + class Kandinsky5TEModel_(Kandinsky5TEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if llama_quantization_metadata is not None: + model_options = model_options.copy() + model_options["llama_quantization_metadata"] = llama_quantization_metadata + if dtype_llama is not None: + dtype = dtype_llama + super().__init__(device=device, dtype=dtype, model_options=model_options) + return Kandinsky5TEModel_ diff --git a/comfy/text_encoders/llama.py b/comfy/text_encoders/llama.py index ee9c2ef2e..bf8bda48f 100644 --- a/comfy/text_encoders/llama.py +++ b/comfy/text_encoders/llama.py @@ -33,6 +33,30 @@ class Llama2Config: q_norm = None k_norm = None rope_scale = None + final_norm: bool = True + + +@dataclass +class Mistral3Small24BConfig: + vocab_size: int = 131072 + hidden_size: int = 5120 + intermediate_size: int = 32768 + num_hidden_layers: int = 40 + num_attention_heads: int = 32 + num_key_value_heads: int = 8 + max_position_embeddings: int = 8192 + rms_norm_eps: float = 1e-5 + rope_theta: float = 1000000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = None + k_norm = None + rope_scale = None + final_norm: bool = True @dataclass @@ -55,6 +79,53 @@ class Qwen25_3BConfig: q_norm = None k_norm = None rope_scale = None + final_norm: bool = True + + +@dataclass +class Qwen3_4BConfig: + vocab_size: int = 151936 + hidden_size: int = 2560 + intermediate_size: int = 9728 + num_hidden_layers: int = 36 + num_attention_heads: int = 32 + 
num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True + + +@dataclass +class Ovis25_2BConfig: + vocab_size: int = 151936 + hidden_size: int = 2048 + intermediate_size: int = 6144 + num_hidden_layers: int = 28 + num_attention_heads: int = 16 + num_key_value_heads: int = 8 + max_position_embeddings: int = 40960 + rms_norm_eps: float = 1e-6 + rope_theta: float = 1000000.0 + transformer_type: str = "llama" + head_dim = 128 + rms_norm_add = False + mlp_activation = "silu" + qkv_bias = False + rope_dims = None + q_norm = "gemma3" + k_norm = "gemma3" + rope_scale = None + final_norm: bool = True @dataclass @@ -77,6 +148,7 @@ class Qwen25_7BVLI_Config: q_norm = None k_norm = None rope_scale = None + final_norm: bool = True @dataclass @@ -100,6 +172,7 @@ class Gemma2_2B_Config: k_norm = None sliding_attention = None rope_scale = None + final_norm: bool = True @dataclass @@ -123,6 +196,7 @@ class Gemma3_4B_Config: k_norm = "gemma3" sliding_attention = [False, False, False, False, False, 1024] rope_scale = [1.0, 8.0] + final_norm: bool = True class RMSNorm(nn.Module): @@ -375,7 +449,12 @@ class Llama2_(nn.Module): transformer(config, index=i, device=device, dtype=dtype, ops=ops) for i in range(config.num_hidden_layers) ]) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + + if config.final_norm: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, add=config.rms_norm_add, device=device, dtype=dtype) + else: + self.norm = None + # self.lm_head = ops.Linear(config.hidden_size, config.vocab_size, bias=False, device=device, dtype=dtype) def forward(self, x, attention_mask=None, embeds=None, num_tokens=None, intermediate_output=None, final_layer_norm_intermediate=True, dtype=None, position_ids=None, embeds_info=[]): @@ -411,8 +490,12 @@ class Llama2_(nn.Module): intermediate = None all_intermediate = None + only_layers = None if intermediate_output is not None: - if intermediate_output == "all": + if isinstance(intermediate_output, list): + all_intermediate = [] + only_layers = set(intermediate_output) + elif intermediate_output == "all": all_intermediate = [] intermediate_output = None elif intermediate_output < 0: @@ -420,7 +503,8 @@ class Llama2_(nn.Module): for i, layer in enumerate(self.layers): if all_intermediate is not None: - all_intermediate.append(x.unsqueeze(1).clone()) + if only_layers is None or (i in only_layers): + all_intermediate.append(x.unsqueeze(1).clone()) x = layer( x=x, attention_mask=mask, @@ -430,14 +514,17 @@ class Llama2_(nn.Module): if i == intermediate_output: intermediate = x.clone() - x = self.norm(x) + if self.norm is not None: + x = self.norm(x) + if all_intermediate is not None: - all_intermediate.append(x.unsqueeze(1).clone()) + if only_layers is None or ((i + 1) in only_layers): + all_intermediate.append(x.unsqueeze(1).clone()) if all_intermediate is not None: intermediate = torch.cat(all_intermediate, dim=1) - if intermediate is not None and final_layer_norm_intermediate: + if intermediate is not None and final_layer_norm_intermediate and self.norm is not None: intermediate = self.norm(intermediate) return x, intermediate @@ -466,6 +553,16 @@ class Llama2(BaseLlama, torch.nn.Module): 
self.dtype = dtype +class Mistral3Small24B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Mistral3Small24BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + + class Qwen25_3B(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() @@ -476,6 +573,26 @@ class Qwen25_3B(BaseLlama, torch.nn.Module): self.dtype = dtype +class Qwen3_4B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Qwen3_4BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + + +class Ovis25_2B(BaseLlama, torch.nn.Module): + def __init__(self, config_dict, dtype, device, operations): + super().__init__() + config = Ovis25_2BConfig(**config_dict) + self.num_layers = config.num_hidden_layers + + self.model = Llama2_(config, device=device, dtype=dtype, ops=operations) + self.dtype = dtype + + class Qwen25_7BVLI(BaseLlama, torch.nn.Module): def __init__(self, config_dict, dtype, device, operations): super().__init__() diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index f4834ce4d..89a58d872 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -8,29 +8,35 @@ class Gemma2BTokenizer(sd1_clip.SDTokenizer): if tokenizer_data is None: tokenizer_data = {} tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2304, embedding_directory=None, embedding_key='gemma2_2b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} + class Gemma3_4BTokenizer(sd1_clip.SDTokenizer): - def __init__(self, embedding_directory=None, tokenizer_data={}): + def __init__(self, embedding_directory=None, tokenizer_data=None): + if tokenizer_data is None: + tokenizer_data = {} tokenizer = tokenizer_data.get("spiece_model", None) - super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) + super().__init__(tokenizer, pad_with_end=False, embedding_size=2560, embedding_directory=None, embedding_key='gemma3_4b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data) def state_dict(self): return {"spiece_model": self.tokenizer.serialize_model()} + class LuminaTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data=None): if tokenizer_data is None: tokenizer_data = {} 
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma2_2b", tokenizer=Gemma2BTokenizer) + class NTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="gemma3_4b", tokenizer=Gemma3_4BTokenizer) + class Gemma2_2BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options=None, textmodel_json_config=None): if model_options is None: @@ -38,10 +44,12 @@ class Gemma2_2BModel(sd1_clip.SDClipModel): textmodel_json_config = textmodel_json_config or {} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma2_2B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + class Gemma3_4BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"start": 2, "pad": 0}, layer_norm_hidden_state=False, model_class=Gemma3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + class LuminaModel(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options=None, name="gemma2_2b", clip_model=Gemma2_2BModel): if model_options is None: @@ -49,21 +57,22 @@ class LuminaModel(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) - -def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): +def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"): model = None if model_type == "gemma2_2b": model = Gemma2_2BModel elif model_type == "gemma3_4b": model = Gemma3_4BModel + class LuminaTEModel_(LuminaModel): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) + return LuminaTEModel_ diff --git a/comfy/text_encoders/omnigen2.py b/comfy/text_encoders/omnigen2.py index 1fa9669d5..dc9941488 100644 --- a/comfy/text_encoders/omnigen2.py +++ b/comfy/text_encoders/omnigen2.py @@ -1,8 +1,7 @@ from transformers import Qwen2Tokenizer -from .. import sd1_clip -from .llama import Qwen25_3B -import os +from .llama import Qwen25_3B +from .. 
import sd1_clip from ..component_model import files @@ -10,8 +9,8 @@ class Qwen25_3BTokenizer(sd1_clip.SDTokenizer): def __init__(self, embedding_directory=None, tokenizer_data=None): if tokenizer_data is None: tokenizer_data = {} - tokenizer_path = files.get_package_as_path("comfy.text_encoders.qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + tokenizer_path = files.get_package_as_path("comfy.text_encoders.qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_directory=embedding_directory, embedding_key='qwen25_3b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer): @@ -21,20 +20,20 @@ class Omnigen2Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen25_3b", tokenizer=Qwen25_3BTokenizer) self.llama_template = '<|im_start|>system\nYou are a helpful assistant that generates high-quality images based on user instructions.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n' - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None,**kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): if llama_template is None: llama_text = self.llama_template.format(text) else: llama_text = llama_template.format(text) return super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, **kwargs) + class Qwen25_3BModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options=None, textmodel_json_config=None): if model_options is None: model_options = {} textmodel_json_config = textmodel_json_config or {} - super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) - + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_3B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) class Omnigen2Model(sd1_clip.SD1ClipModel): @@ -44,15 +43,16 @@ class Omnigen2Model(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class Omnigen2TEModel_(Omnigen2Model): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: 
dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) + return Omnigen2TEModel_ diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py new file mode 100644 index 000000000..5754424d2 --- /dev/null +++ b/comfy/text_encoders/ovis.py @@ -0,0 +1,66 @@ +from transformers import Qwen2Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import os +import torch +import numbers + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2048, embedding_key='qwen3_2b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=284, pad_token=151643, tokenizer_data=tokenizer_data) + + +class OvisTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_2b", tokenizer=Qwen3Tokenizer) + self.llama_template = "<|im_start|>user\nDescribe the image by detailing the color, quantity, text, shape, size, texture, spatial relationships of the objects and background: {}<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + +class Ovis25_2BModel(sd1_clip.SDClipModel): + def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Ovis25_2B, enable_attention_masks=attention_mask, return_attention_masks=False, zero_out_masked=True, model_options=model_options) + + +class OvisTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_2b", clip_model=Ovis25_2BModel, model_options=model_options) + + def encode_token_weights(self, token_weight_pairs, template_end=-1): + out, pooled = super().encode_token_weights(token_weight_pairs) + tok_pairs = token_weight_pairs["qwen3_2b"][0] + count_im_start = 0 + if template_end == -1: + for i, v in enumerate(tok_pairs): + elem = v[0] + if not torch.is_tensor(elem): + if isinstance(elem, numbers.Integral): + if elem == 4004 and count_im_start < 1: + template_end = i + count_im_start += 1 + + if out.shape[1] > (template_end + 1): + if tok_pairs[template_end + 1][0] == 25: + template_end += 1 + + out = out[:, template_end:] + return out, pooled, {} + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class OvisTEModel_(OvisTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return OvisTEModel_ 
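Every text-encoder change in this patch follows the same shape: a module-level `te(...)` factory captures an optional dtype override and the detected quantization metadata, and returns a subclass whose `__init__` injects them into `model_options` before delegating to the real model class. The following is a minimal, self-contained sketch of that pattern for orientation only; `ExampleTEModel` and the usage at the bottom are hypothetical and not part of this diff, only the closure shape mirrors factories such as `ovis.te()` or `z_image.te()`.

```python
# Sketch of the te() factory pattern used throughout this patch (illustrative only).
class ExampleTEModel:
    def __init__(self, device="cpu", dtype=None, model_options=None):
        self.device = device
        self.dtype = dtype
        self.model_options = model_options or {}


def te(dtype_llama=None, llama_quantization_metadata=None):
    class ExampleTEModel_(ExampleTEModel):
        def __init__(self, device="cpu", dtype=None, model_options=None):
            if model_options is None:
                model_options = {}
            if llama_quantization_metadata is not None:
                # copy before mutating so the caller's dict is left untouched
                model_options = model_options.copy()
                model_options["quantization_metadata"] = llama_quantization_metadata
            if dtype_llama is not None:
                # the dtype detected from the checkpoint wins over the caller's default
                dtype = dtype_llama
            super().__init__(device=device, dtype=dtype, model_options=model_options)
    return ExampleTEModel_


# Usage: build the class once from detected checkpoint properties,
# then instantiate it like any other text-encoder model class.
TEClass = te(dtype_llama="bf16", llama_quantization_metadata={"mixed_ops": True})
model = TEClass(device="cpu")
assert model.model_options["quantization_metadata"] == {"mixed_ops": True}
```

Copying `model_options` before injecting keys keeps the caller's dictionary unmodified, which is why most of the factories in this patch call `.copy()` before setting `quantization_metadata`.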
diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index 95358e072..44adde026 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -35,19 +35,20 @@ class T5XXLTokenizer(sd1_clip.SDTokenizer): class PixArtTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data=None): - super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) if tokenizer_data is None: tokenizer_data = {} + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None): + +def pixart_te(dtype_t5=None, t5_quantization_metadata=None): class PixArtTEModel_(PixArtT5XXL): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json index 67688e82c..df5b5d7fe 100644 --- a/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json +++ b/comfy/text_encoders/qwen25_tokenizer/tokenizer_config.json @@ -179,36 +179,36 @@ "special": false }, "151665": { - "content": "<|img|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151666": { - "content": "<|endofimg|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151667": { - "content": "<|meta|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false }, "151668": { - "content": "<|endofmeta|>", + "content": "", "lstrip": false, "normalized": false, "rstrip": false, "single_word": false, - "special": true + "special": false } }, "additional_special_tokens": [ diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index a56f18291..ce4773d27 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -12,7 +12,7 @@ class Qwen25_7BVLITokenizer(sd1_clip.SDTokenizer): if tokenizer_data is None: tokenizer_data = {} tokenizer_path = files.get_package_as_path("comfy.text_encoders.qwen25_tokenizer") - super().__init__(tokenizer_path, pad_with_end=False, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + super().__init__(tokenizer_path, pad_with_end=False, embedding_directory=embedding_directory, embedding_size=3584, embedding_key='qwen25_7b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) class QwenImageTokenizer(sd1_clip.SD1Tokenizer): @@ -23,12 +23,14 @@ class QwenImageTokenizer(sd1_clip.SD1Tokenizer): 
self.llama_template = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" self.llama_template_images = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n" - def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], **kwargs): + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, images=[], prevent_empty_text=False, **kwargs): skip_template = False if text.startswith('<|im_start|>'): skip_template = True if text.startswith('<|start_header_id|>'): skip_template = True + if prevent_empty_text and text == '': + text = ' ' if skip_template: llama_text = text @@ -94,14 +96,14 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel): return out, pooled, extra -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(QwenImageTEModel): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index d6898d7ef..2511969e3 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -1,5 +1,6 @@ import copy import logging +import comfy.utils import torch from transformers import T5TokenizerFast @@ -17,10 +18,10 @@ class T5XXLModel(sd1_clip.SDClipModel): if model_options is None: model_options = {} textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "t5_config_xxl.json", package=__package__) - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata model_options = {**model_options, "model_name": "t5xxl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -32,9 +33,9 @@ def t5_xxl_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_t5"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + if quant is not 
None: + out["t5_quantization_metadata"] = quant return out @@ -175,14 +176,14 @@ class SD3ClipModel(torch.nn.Module): return self.t5xxl.load_sd(sd) -def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False): +def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False): class SD3ClipModel_(SD3ClipModel): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options) return SD3ClipModel_ diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py index 4567ed55b..6e1ec8c66 100644 --- a/comfy/text_encoders/wan.py +++ b/comfy/text_encoders/wan.py @@ -37,14 +37,14 @@ class WanT5Model(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class WanTEModel(WanT5Model): def __init__(self, device="cpu", dtype=None, model_options=None): if model_options is None: model_options = {} - if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5_quantization_metadata if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py new file mode 100644 index 000000000..19adde0b7 --- /dev/null +++ b/comfy/text_encoders/z_image.py @@ -0,0 +1,45 @@ +from transformers import Qwen2Tokenizer +import comfy.text_encoders.llama +from comfy import sd1_clip +import os + +class Qwen3Tokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "qwen25_tokenizer") + super().__init__(tokenizer_path, pad_with_end=False, embedding_size=2560, embedding_key='qwen3_4b', tokenizer_class=Qwen2Tokenizer, has_start_token=False, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=151643, tokenizer_data=tokenizer_data) + + +class ZImageTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, name="qwen3_4b", tokenizer=Qwen3Tokenizer) + self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n" + + def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs): + if llama_template is None: + llama_text = self.llama_template.format(text) + else: + llama_text = llama_template.format(text) + + tokens = super().tokenize_with_weights(llama_text, return_word_ids=return_word_ids, disable_weights=True, **kwargs) + return tokens + + +class Qwen3_4BModel(sd1_clip.SDClipModel): + def 
__init__(self, device="cpu", layer="hidden", layer_idx=-2, dtype=None, attention_mask=True, model_options={}): + super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen3_4B, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) + + +class ZImageTEModel(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options) + + +def te(dtype_llama=None, llama_quantization_metadata=None): + class ZImageTEModel_(ZImageTEModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + if dtype_llama is not None: + dtype = dtype_llama + if llama_quantization_metadata is not None: + model_options["quantization_metadata"] = llama_quantization_metadata + super().__init__(device=device, dtype=dtype, model_options=model_options) + return ZImageTEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index 67c02f2c2..170f58436 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -79,6 +79,7 @@ if hasattr(torch.serialization, "add_safe_globals"): # TODO: this was added in except (ImportError, ModuleNotFoundError): from numpy import generic as scalar from numpy import dtype + try: from numpy.dtypes import Float64DType # pylint: disable=no-name-in-module,import-error except (ImportError, ModuleNotFoundError): @@ -779,6 +780,72 @@ def flux_to_diffusers(mmdit_config, output_prefix=""): return key_map +def z_image_to_diffusers(mmdit_config, output_prefix=""): + n_layers = mmdit_config.get("n_layers", 0) + hidden_size = mmdit_config.get("dim", 0) + n_context_refiner = mmdit_config.get("n_refiner_layers", 2) + n_noise_refiner = mmdit_config.get("n_refiner_layers", 2) + key_map = {} + + def add_block_keys(prefix_from, prefix_to, has_adaln=True): + for end in ("weight", "bias"): + k = "{}.attention.".format(prefix_from) + qkv = "{}.attention.qkv.{}".format(prefix_to, end) + key_map["{}to_q.{}".format(k, end)] = (qkv, (0, 0, hidden_size)) + key_map["{}to_k.{}".format(k, end)] = (qkv, (0, hidden_size, hidden_size)) + key_map["{}to_v.{}".format(k, end)] = (qkv, (0, hidden_size * 2, hidden_size)) + + block_map = { + "attention.norm_q.weight": "attention.q_norm.weight", + "attention.norm_k.weight": "attention.k_norm.weight", + "attention.to_out.0.weight": "attention.out.weight", + "attention.to_out.0.bias": "attention.out.bias", + "attention_norm1.weight": "attention_norm1.weight", + "attention_norm2.weight": "attention_norm2.weight", + "feed_forward.w1.weight": "feed_forward.w1.weight", + "feed_forward.w2.weight": "feed_forward.w2.weight", + "feed_forward.w3.weight": "feed_forward.w3.weight", + "ffn_norm1.weight": "ffn_norm1.weight", + "ffn_norm2.weight": "ffn_norm2.weight", + } + if has_adaln: + block_map["adaLN_modulation.0.weight"] = "adaLN_modulation.0.weight" + block_map["adaLN_modulation.0.bias"] = "adaLN_modulation.0.bias" + for k, v in block_map.items(): + key_map["{}.{}".format(prefix_from, k)] = "{}.{}".format(prefix_to, v) + + for i in range(n_layers): + add_block_keys("layers.{}".format(i), "{}layers.{}".format(output_prefix, i)) + + for i in range(n_context_refiner): + add_block_keys("context_refiner.{}".format(i), "{}context_refiner.{}".format(output_prefix, i)) + + for i in range(n_noise_refiner): + add_block_keys("noise_refiner.{}".format(i), 
"{}noise_refiner.{}".format(output_prefix, i)) + + MAP_BASIC = [ + ("final_layer.linear.weight", "all_final_layer.2-1.linear.weight"), + ("final_layer.linear.bias", "all_final_layer.2-1.linear.bias"), + ("final_layer.adaLN_modulation.1.weight", "all_final_layer.2-1.adaLN_modulation.1.weight"), + ("final_layer.adaLN_modulation.1.bias", "all_final_layer.2-1.adaLN_modulation.1.bias"), + ("x_embedder.weight", "all_x_embedder.2-1.weight"), + ("x_embedder.bias", "all_x_embedder.2-1.bias"), + ("x_pad_token", "x_pad_token"), + ("cap_embedder.0.weight", "cap_embedder.0.weight"), + ("cap_embedder.1.weight", "cap_embedder.1.weight"), + ("cap_embedder.1.bias", "cap_embedder.1.bias"), + ("cap_pad_token", "cap_pad_token"), + ("t_embedder.mlp.0.weight", "t_embedder.mlp.0.weight"), + ("t_embedder.mlp.0.bias", "t_embedder.mlp.0.bias"), + ("t_embedder.mlp.2.weight", "t_embedder.mlp.2.weight"), + ("t_embedder.mlp.2.bias", "t_embedder.mlp.2.bias"), + ] + + for c, diffusers in MAP_BASIC: + key_map[diffusers] = "{}{}".format(output_prefix, c) + + return key_map + def repeat_to_batch_size(tensor, batch_size, dim=0): if tensor.shape[dim] > batch_size: @@ -1379,3 +1446,94 @@ def upscale_dit_mask(mask: torch.Tensor, img_size_in, img_size_out): dim=1 ) return out + + +def pack_latents(latents): + latent_shapes = [] + tensors = [] + for tensor in latents: + latent_shapes.append(tensor.shape) + tensors.append(tensor.reshape(tensor.shape[0], 1, -1)) + + latent = torch.cat(tensors, dim=-1) + return latent, latent_shapes + + +def unpack_latents(combined_latent, latent_shapes): + if len(latent_shapes) > 1: + output_tensors = [] + for shape in latent_shapes: + cut = math.prod(shape[1:]) + tens = combined_latent[:, :, :cut] + combined_latent = combined_latent[:, :, cut:] + output_tensors.append(tens.reshape([tens.shape[0]] + list(shape)[1:])) + else: + output_tensors = combined_latent + return output_tensors + + +def detect_layer_quantization(state_dict, prefix): + for k in state_dict: + if k.startswith(prefix) and k.endswith(".comfy_quant"): + logging.info("Found quantization metadata version 1") + return {"mixed_ops": True} + return None + + +def convert_old_quants(state_dict, model_prefix="", metadata={}): + if metadata is None: + metadata = {} + + quant_metadata = None + if "_quantization_metadata" not in metadata: + scaled_fp8_key = "{}scaled_fp8".format(model_prefix) + + if scaled_fp8_key in state_dict: + scaled_fp8_weight = state_dict[scaled_fp8_key] + scaled_fp8_dtype = scaled_fp8_weight.dtype + if scaled_fp8_dtype == torch.float32: + scaled_fp8_dtype = torch.float8_e4m3fn + + if scaled_fp8_weight.nelement() == 2: + full_precision_matrix_mult = True + else: + full_precision_matrix_mult = False + + out_sd = {} + layers = {} + for k in list(state_dict.keys()): + if not k.startswith(model_prefix): + out_sd[k] = state_dict[k] + continue + k_out = k + w = state_dict.pop(k) + layer = None + if k_out.endswith(".scale_weight"): + layer = k_out[:-len(".scale_weight")] + k_out = "{}.weight_scale".format(layer) + + if layer is not None: + layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + if full_precision_matrix_mult: + layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult + layers[layer] = layer_conf + + if k_out.endswith(".scale_input"): + layer = k_out[:-len(".scale_input")] + k_out = "{}.input_scale".format(layer) + if w.item() == 1.0: + continue + + out_sd[k_out] = w + + state_dict = out_sd + quant_metadata = {"layers": layers} + else: + 
quant_metadata = json.loads(metadata["_quantization_metadata"]) + + if quant_metadata is not None: + layers = quant_metadata["layers"] + for k, v in layers.items(): + state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8) + + return state_dict, metadata diff --git a/comfy/weight_adapter/lora.py b/comfy/weight_adapter/lora.py index 58e3a6be9..b81ca14b2 100644 --- a/comfy/weight_adapter/lora.py +++ b/comfy/weight_adapter/lora.py @@ -197,6 +197,7 @@ class LoRAAdapter(WeightAdapterBase): lora_diff = torch.mm( mat1.flatten(start_dim=1), mat2.flatten(start_dim=1) ).reshape(weight.shape) + del mat1, mat2 if dora_scale is not None: weight = weight_decompose( dora_scale, diff --git a/comfy_api/feature_flags.py b/comfy_api/feature_flags.py index c6ccc4ec2..838cea3c1 100644 --- a/comfy_api/feature_flags.py +++ b/comfy_api/feature_flags.py @@ -13,6 +13,7 @@ from comfy.cli_args import args SERVER_FEATURE_FLAGS: Dict[str, Any] = { "supports_preview_metadata": True, "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes + "extension": {"manager": {"supports_v4": True}}, } diff --git a/comfy_api/internal/async_to_sync.py b/comfy_api/internal/async_to_sync.py index 11b105d64..0674127f2 100644 --- a/comfy_api/internal/async_to_sync.py +++ b/comfy_api/internal/async_to_sync.py @@ -8,7 +8,7 @@ import os import textwrap import threading from enum import Enum -from typing import Optional, Type, get_origin, get_args +from typing import Optional, Type, get_origin, get_args, get_type_hints class TypeTracker: @@ -225,11 +225,18 @@ class AsyncToSyncConverter: self._async_instance = async_class(*args, **kwargs) # Handle annotated class attributes (like execution: Execution) - # Get all annotations from the class hierarchy - all_annotations = {} - for base_class in reversed(inspect.getmro(async_class)): - if hasattr(base_class, "__annotations__"): - all_annotations.update(base_class.__annotations__) + # Get all annotations from the class hierarchy and resolve string annotations + try: + # get_type_hints resolves string annotations to actual type objects + # This handles classes using 'from __future__ import annotations' + all_annotations = get_type_hints(async_class) + except Exception: + # Fallback to raw annotations if get_type_hints fails + # (e.g., for undefined forward references) + all_annotations = {} + for base_class in reversed(inspect.getmro(async_class)): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) # For each annotated attribute, check if it needs to be created or wrapped for attr_name, attr_type in all_annotations.items(): @@ -630,15 +637,19 @@ class AsyncToSyncConverter: """Extract class attributes that are classes themselves.""" class_attributes = [] + # Get resolved type hints to handle string annotations + try: + type_hints = get_type_hints(async_class) + except Exception: + type_hints = {} + # Look for class attributes that are classes for name, attr in sorted(inspect.getmembers(async_class)): if isinstance(attr, type) and not name.startswith("_"): class_attributes.append((name, attr)) - elif ( - hasattr(async_class, "__annotations__") - and name in async_class.__annotations__ - ): - annotation = async_class.__annotations__[name] + elif name in type_hints: + # Use resolved type hint instead of raw annotation + annotation = type_hints[name] if isinstance(annotation, type): class_attributes.append((name, annotation)) @@ -913,11 +924,15 @@ class AsyncToSyncConverter: 
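`convert_old_quants` bridges legacy scaled-fp8 checkpoints to the new per-layer metadata format: `.scale_weight`/`.scale_input` keys are renamed to `.weight_scale`/`.input_scale`, and each quantized layer gains a `.comfy_quant` tensor carrying its JSON config as raw bytes, which is what `detect_layer_quantization` looks for. An illustrative sketch of the resulting state-dict entries (layer name and scale value are made up, and this is not the ComfyUI implementation itself):

```python
import json
import torch

layer = "model.diffusion_model.blocks.0.attn.qkv"   # hypothetical layer name
old_sd = {f"{layer}.scale_weight": torch.tensor(0.02)}

new_sd = {f"{layer}.weight_scale": old_sd[f"{layer}.scale_weight"]}  # renamed key
layer_conf = {"format": "float8_e4m3fn"}
new_sd[f"{layer}.comfy_quant"] = torch.frombuffer(
    json.dumps(layer_conf).encode("utf-8"), dtype=torch.uint8
)

# detect_layer_quantization() keys off the ".comfy_quant" suffix:
print(any(k.endswith(".comfy_quant") for k in new_sd))  # True
# and the per-layer config round-trips back out of the marker tensor:
print(json.loads(bytes(new_sd[f"{layer}.comfy_quant"].numpy()).decode("utf-8")))
```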
attribute_mappings = {} # First check annotations for typed attributes (including from parent classes) - # Collect all annotations from the class hierarchy - all_annotations = {} - for base_class in reversed(inspect.getmro(async_class)): - if hasattr(base_class, "__annotations__"): - all_annotations.update(base_class.__annotations__) + # Resolve string annotations to actual types + try: + all_annotations = get_type_hints(async_class) + except Exception: + # Fallback to raw annotations + all_annotations = {} + for base_class in reversed(inspect.getmro(async_class)): + if hasattr(base_class, "__annotations__"): + all_annotations.update(base_class.__annotations__) for attr_name, attr_type in sorted(all_annotations.items()): for class_name, class_type in class_attributes: diff --git a/comfy_api/latest/__init__.py b/comfy_api/latest/__init__.py index b7a3fa9c1..35e1ac853 100644 --- a/comfy_api/latest/__init__.py +++ b/comfy_api/latest/__init__.py @@ -5,11 +5,11 @@ from typing import Type, TYPE_CHECKING from comfy_api.internal import ComfyAPIBase from comfy_api.internal.singleton import ProxiedSingleton from comfy_api.internal.async_to_sync import create_sync_class -from comfy_api.latest._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput -from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents -from comfy_api.latest._util import VideoCodec, VideoContainer, VideoComponents -from . import _io as io -from . import _ui as ui +from ._input import ImageInput, AudioInput, MaskInput, LatentInput, VideoInput +from ._input_impl import VideoFromFile, VideoFromComponents +from ._util import VideoCodec, VideoContainer, VideoComponents, MESH, VOXEL +from . import _io_public as io +from . import _ui_public as ui # from comfy_api.latest._resources import _RESOURCES as resources #noqa: F401 from comfy_execution.utils import get_executing_context from comfy_execution.progress import get_progress_state, PreviewImageTuple @@ -80,7 +80,7 @@ class ComfyExtension(ABC): async def on_load(self) -> None: """ Called when an extension is loaded. - This should be used to initialize any global resources neeeded by the extension. + This should be used to initialize any global resources needed by the extension. """ @abstractmethod @@ -104,6 +104,8 @@ class Types: VideoCodec = VideoCodec VideoContainer = VideoContainer VideoComponents = VideoComponents + MESH = MESH + VOXEL = VOXEL ComfyAPI = ComfyAPI_latest diff --git a/comfy_api/latest/_input/video_types.py b/comfy_api/latest/_input/video_types.py index a335df4d0..e634a0311 100644 --- a/comfy_api/latest/_input/video_types.py +++ b/comfy_api/latest/_input/video_types.py @@ -1,9 +1,10 @@ from __future__ import annotations from abc import ABC, abstractmethod +from fractions import Fraction from typing import Optional, Union, IO import io import av -from comfy_api.util import VideoContainer, VideoCodec, VideoComponents +from .._util import VideoContainer, VideoCodec, VideoComponents class VideoInput(ABC): """ @@ -72,6 +73,33 @@ class VideoInput(ABC): frame_count = components.images.shape[0] return float(frame_count / components.frame_rate) + def get_frame_count(self) -> int: + """ + Returns the number of frames in the video. + + Default implementation uses :meth:`get_components`, which may require + loading all frames into memory. File-based implementations should + override this method and use container/stream metadata instead. + + Returns: + Total number of frames as an integer. 
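The switch from raw `__annotations__` to `typing.get_type_hints` in `async_to_sync.py` matters for classes defined under `from __future__ import annotations`, where annotations are stored as strings rather than types. A small standalone demonstration:

```python
from __future__ import annotations
from typing import get_type_hints

class Execution:
    pass

class AsyncAPI:
    execution: Execution

print(AsyncAPI.__annotations__)   # {'execution': 'Execution'}  <- just a string
print(get_type_hints(AsyncAPI))   # {'execution': <class '__main__.Execution'>}
```

The converter keeps the raw-annotation walk as a fallback for classes whose forward references cannot be resolved.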
+ """ + return int(self.get_components().images.shape[0]) + + def get_frame_rate(self) -> Fraction: + """ + Returns the frame rate of the video. + + Default implementation materializes the video into memory via + `get_components()`. Subclasses that can inspect the underlying + container (e.g. `VideoFromFile`) should override this with a more + efficient implementation. + + Returns: + Frame rate as a Fraction. + """ + return self.get_components().frame_rate + def get_container_format(self) -> str: """ Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). diff --git a/comfy_api/latest/_input_impl/video_types.py b/comfy_api/latest/_input_impl/video_types.py index b7291bef1..1c733f571 100644 --- a/comfy_api/latest/_input_impl/video_types.py +++ b/comfy_api/latest/_input_impl/video_types.py @@ -3,14 +3,14 @@ from av.container import InputContainer from av.subtitles.stream import SubtitleStream # pylint: disable=no-name-in-module from fractions import Fraction from typing import Optional -from comfy_api.latest._input import AudioInput, VideoInput +from .._input import AudioInput, VideoInput import av import io import json import numpy as np import math import torch -from comfy_api.latest._util import VideoContainer, VideoCodec, VideoComponents +from .._util import VideoContainer, VideoCodec, VideoComponents def container_to_output_format(container_format: str | None) -> str | None: @@ -121,6 +121,71 @@ class VideoFromFile(VideoInput): raise ValueError(f"Could not determine duration for file '{self.__file}'") + def get_frame_count(self) -> int: + """ + Returns the number of frames in the video without materializing them as + torch tensors. + """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + + with av.open(self.__file, mode="r") as container: + video_stream = self._get_first_video_stream(container) + # 1. Prefer the frames field if available + if video_stream.frames and video_stream.frames > 0: + return int(video_stream.frames) + + # 2. Try to estimate from duration and average_rate using only metadata + if container.duration is not None and video_stream.average_rate: + duration_seconds = float(container.duration / av.time_base) + estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) + if estimated_frames > 0: + return estimated_frames + + if ( + getattr(video_stream, "duration", None) is not None + and getattr(video_stream, "time_base", None) is not None + and video_stream.average_rate + ): + duration_seconds = float(video_stream.duration * video_stream.time_base) + estimated_frames = int(round(duration_seconds * float(video_stream.average_rate))) + if estimated_frames > 0: + return estimated_frames + + # 3. Last resort: decode frames and count them (streaming) + frame_count = 0 + container.seek(0) + for packet in container.demux(video_stream): + for _ in packet.decode(): + frame_count += 1 + + if frame_count == 0: + raise ValueError(f"Could not determine frame count for file '{self.__file}'") + return frame_count + + def get_frame_rate(self) -> Fraction: + """ + Returns the average frame rate of the video using container metadata + without decoding all frames. 
+ """ + if isinstance(self.__file, io.BytesIO): + self.__file.seek(0) + + with av.open(self.__file, mode="r") as container: + video_stream = self._get_first_video_stream(container) + # Preferred: use PyAV's average_rate (usually already a Fraction-like) + if video_stream.average_rate: + return Fraction(video_stream.average_rate) + + # Fallback: estimate from frames + duration if available + if video_stream.frames and container.duration: + duration_seconds = float(container.duration / av.time_base) + if duration_seconds > 0: + return Fraction(video_stream.frames / duration_seconds).limit_denominator() + + # Last resort: match get_components_internal default + return Fraction(1) + def get_container_format(self) -> str: """ Returns the container format of the video (e.g., 'mp4', 'mov', 'avi'). @@ -238,6 +303,13 @@ class VideoFromFile(VideoInput): packet.stream = stream_map[packet.stream] output_container.mux(packet) + def _get_first_video_stream(self, container: InputContainer): + video_stream = next((s for s in container.streams if s.type == "video"), None) + if video_stream is None: + raise ValueError(f"No video stream found in file '{self.__file}'") + return video_stream + + class VideoFromComponents(VideoInput): """ Class representing video input from tensors. @@ -264,7 +336,10 @@ class VideoFromComponents(VideoInput): raise ValueError("Only MP4 format is supported for now") if codec != VideoCodec.AUTO and codec != VideoCodec.H264: raise ValueError("Only H264 codec is supported for now") - with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}) as output: + extra_kwargs = {} + if isinstance(format, VideoContainer) and format != VideoContainer.AUTO: + extra_kwargs["format"] = format.value + with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output: # Add metadata before writing any streams if metadata is not None: for key, value in metadata.items(): diff --git a/comfy_api/latest/_io.py b/comfy_api/latest/_io.py index eaa6cc181..11f85444e 100644 --- a/comfy_api/latest/_io.py +++ b/comfy_api/latest/_io.py @@ -4,7 +4,8 @@ import copy import inspect from abc import ABC, abstractmethod from collections import Counter -from dataclasses import asdict, dataclass +from collections.abc import Iterable +from dataclasses import asdict, dataclass, field from enum import Enum from typing import Any, Callable, Literal, TypedDict, TypeVar, TYPE_CHECKING from typing_extensions import NotRequired, final @@ -25,8 +26,9 @@ if TYPE_CHECKING: from comfy_api.input import VideoInput from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class, prune_dict, shallow_clone_class) -from comfy_api.latest._resources import Resources, ResourcesLocal +from ._resources import Resources, ResourcesLocal from comfy_execution.graph_utils import ExecutionBlocker +from ._util import MESH, VOXEL # from comfy_extras.nodes_images import SVG as SVG_ # NOTE: needs to be moved before can be imported due to circular reference @@ -149,6 +151,9 @@ class _IO_V3: def __init__(self): pass + def validate(self): + pass + @property def io_type(self): return self.Parent.io_type @@ -181,6 +186,9 @@ class Input(_IO_V3): def get_io_type(self): return _StringIOType(self.io_type) + def get_all(self) -> list[Input]: + return [self] + class WidgetInput(Input): ''' Base class for a V3 Input with widget. 
@@ -560,6 +568,8 @@ class Conditioning(ComfyTypeIO): '''Used by WAN Camera.''' time_dim_concat: NotRequired[torch.Tensor] '''Used by WAN Phantom Subject.''' + time_dim_replace: NotRequired[torch.Tensor] + '''Used by Kandinsky5 I2V.''' CondList = list[tuple[torch.Tensor, PooledDict]] Type = CondList @@ -628,6 +638,10 @@ class UpscaleModel(ComfyTypeIO): if TYPE_CHECKING: Type = ImageModelDescriptor +@comfytype(io_type="LATENT_UPSCALE_MODEL") +class LatentUpscaleModel(ComfyTypeIO): + Type = Any + @comfytype(io_type="AUDIO") class Audio(ComfyTypeIO): class AudioDict(TypedDict): @@ -656,11 +670,11 @@ class LossMap(ComfyTypeIO): @comfytype(io_type="VOXEL") class Voxel(ComfyTypeIO): - Type = Any # TODO: VOXEL class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3 + Type = VOXEL @comfytype(io_type="MESH") class Mesh(ComfyTypeIO): - Type = Any # TODO: MESH class is defined in comfy_extras/nodes_hunyuan3d.py; should be moved to somewhere else before referenced directly in v3 + Type = MESH @comfytype(io_type="HOOKS") class Hooks(ComfyTypeIO): @@ -809,13 +823,61 @@ class MultiType: else: return super().as_dict() +@comfytype(io_type="COMFY_MATCHTYPE_V3") +class MatchType(ComfyTypeIO): + class Template: + def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType] = AnyType): + self.template_id = template_id + # account for syntactic sugar + if not isinstance(allowed_types, Iterable): + allowed_types = [allowed_types] + for t in allowed_types: + if not isinstance(t, type): + if not isinstance(t, _ComfyType): + raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__class__.__name__}") + else: + if not issubclass(t, _ComfyType): + raise ValueError(f"Allowed types must be a ComfyType or a list of ComfyTypes, got {t.__name__}") + self.allowed_types = allowed_types + + def as_dict(self): + return { + "template_id": self.template_id, + "allowed_types": ",".join([t.io_type for t in self.allowed_types]), + } + + class Input(Input): + def __init__(self, id: str, template: MatchType.Template, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + + class Output(Output): + def __init__(self, template: MatchType.Template, id: str=None, display_name: str=None, tooltip: str=None, + is_output_list=False): + super().__init__(id, display_name, tooltip, is_output_list) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) + class DynamicInput(Input, ABC): ''' Abstract class for dynamic input registration. ''' - @abstractmethod def get_dynamic(self) -> list[Input]: - ... + return [] + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + pass + class DynamicOutput(Output, ABC): ''' @@ -825,99 +887,223 @@ class DynamicOutput(Output, ABC): is_output_list=False): super().__init__(id, display_name, tooltip, is_output_list) - @abstractmethod def get_dynamic(self) -> list[Output]: - ... 
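`MatchType` is exported through the public `io` module, so a node schema can tie an input and an output to the same template and let the frontend resolve them to one concrete type. A hypothetical schema fragment (ids, names, and the allowed types are made up; assumes a ComfyUI checkout containing this change):

```python
from comfy_api.latest import io

# One shared template: whatever connects to "value" also defines "value_out".
passthrough = io.MatchType.Template("T", allowed_types=[io.Image, io.Latent])

inputs = [io.MatchType.Input("value", template=passthrough)]
outputs = [io.MatchType.Output(template=passthrough, id="value_out")]
```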
+ return [] @comfytype(io_type="COMFY_AUTOGROW_V3") -class AutogrowDynamic(ComfyTypeI): - Type = list[Any] - class Input(DynamicInput): - def __init__(self, id: str, template_input: Input, min: int=1, max: int=None, - display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): - super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) - self.template_input = template_input - if min is not None: - assert(min >= 1) - if max is not None: - assert(max >= 1) +class Autogrow(ComfyTypeI): + Type = dict[str, Any] + _MaxNames = 100 # NOTE: max 100 names for sanity + + class _AutogrowTemplate: + def __init__(self, input: Input): + # dynamic inputs are not allowed as the template input + assert(not isinstance(input, DynamicInput)) + self.input = copy.copy(input) + if isinstance(self.input, WidgetInput): + self.input.force_input = True + self.names: list[str] = [] + self.cached_inputs = {} + + def _create_input(self, input: Input, name: str): + new_input = copy.copy(self.input) + new_input.id = name + return new_input + + def _create_cached_inputs(self): + for name in self.names: + self.cached_inputs[name] = self._create_input(self.input, name) + + def get_all(self) -> list[Input]: + return list(self.cached_inputs.values()) + + def as_dict(self): + return prune_dict({ + "input": create_input_dict_v1([self.input]), + }) + + def validate(self): + self.input.validate() + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + real_inputs = [] + for name, input in self.cached_inputs.items(): + if name in live_inputs: + real_inputs.append(input) + add_to_input_dict_v1(d, real_inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, real_inputs, curr_prefix) + + class TemplatePrefix(_AutogrowTemplate): + def __init__(self, input: Input, prefix: str, min: int=1, max: int=10): + super().__init__(input) + self.prefix = prefix + assert(min >= 0) + assert(max >= 1) + assert(max <= Autogrow._MaxNames) self.min = min self.max = max + self.names = [f"{self.prefix}{i}" for i in range(self.max)] + self._create_cached_inputs() + + def as_dict(self): + return super().as_dict() | prune_dict({ + "prefix": self.prefix, + "min": self.min, + "max": self.max, + }) + + class TemplateNames(_AutogrowTemplate): + def __init__(self, input: Input, names: list[str], min: int=1): + super().__init__(input) + self.names = names[:Autogrow._MaxNames] + assert(min >= 0) + self.min = min + self._create_cached_inputs() + + def as_dict(self): + return super().as_dict() | prune_dict({ + "names": self.names, + "min": self.min, + }) + + class Input(DynamicInput): + def __init__(self, id: str, template: Autogrow.TemplatePrefix | Autogrow.TemplateNames, + display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): + super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) + self.template = template + + def as_dict(self): + return super().as_dict() | prune_dict({ + "template": self.template.as_dict(), + }) def get_dynamic(self) -> list[Input]: - curr_count = 1 - new_inputs = [] - for i in range(self.min): - new_input = copy.copy(self.template_input) - new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" - if new_input.display_name is not None: - new_input.display_name = f"{new_input.display_name}{curr_count}" - new_input.optional = self.optional or new_input.optional - if isinstance(self.template_input, WidgetInput): - new_input.force_input = True - new_inputs.append(new_input) - curr_count 
+= 1 - # pretend to expand up to max - for i in range(curr_count-1, self.max): - new_input = copy.copy(self.template_input) - new_input.id = f"{new_input.id}{curr_count}_${self.id}_ag$" - if new_input.display_name is not None: - new_input.display_name = f"{new_input.display_name}{curr_count}" - new_input.optional = True - if isinstance(self.template_input, WidgetInput): - new_input.force_input = True - new_inputs.append(new_input) - curr_count += 1 - return new_inputs + return self.template.get_all() -@comfytype(io_type="COMFY_COMBODYNAMIC_V3") -class ComboDynamic(ComfyTypeI): - class Input(DynamicInput): - def __init__(self, id: str): - pass + def get_all(self) -> list[Input]: + return [self] + self.template.get_all() -@comfytype(io_type="COMFY_MATCHTYPE_V3") -class MatchType(ComfyTypeIO): - class Template: - def __init__(self, template_id: str, allowed_types: _ComfyType | list[_ComfyType]): - self.template_id = template_id - self.allowed_types = [allowed_types] if isinstance(allowed_types, _ComfyType) else allowed_types + def validate(self): + self.template.validate() + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + curr_prefix = f"{curr_prefix}{self.id}." + # need to remove self from expected inputs dictionary; replaced by template inputs in frontend + for inner_dict in d.values(): + if self.id in inner_dict: + del inner_dict[self.id] + self.template.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + +@comfytype(io_type="COMFY_DYNAMICCOMBO_V3") +class DynamicCombo(ComfyTypeI): + Type = dict[str, Any] + + class Option: + def __init__(self, key: str, inputs: list[Input]): + self.key = key + self.inputs = inputs def as_dict(self): return { - "template_id": self.template_id, - "allowed_types": "".join(t.io_type for t in self.allowed_types), + "key": self.key, + "inputs": create_input_dict_v1(self.inputs), } class Input(DynamicInput): - def __init__(self, id: str, template: MatchType.Template, + def __init__(self, id: str, options: list[DynamicCombo.Option], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, extra_dict=None): super().__init__(id, display_name, optional, tooltip, lazy, extra_dict) - self.template = template + self.options = options + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + # check if dynamic input's id is in live_inputs + if self.id in live_inputs: + curr_prefix = f"{curr_prefix}{self.id}." 
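`Autogrow` replaces the old `AutogrowDynamic` input: instead of mangled ids like `image1_$images_ag$`, it pre-creates named slots from either a prefix or an explicit name list and expands the V1 schema only for the slots that are actually live. A hypothetical usage sketch (node and input names are made up; `Autogrow` is not yet re-exported in `__all__`, so it is pulled from the private module here):

```python
from comfy_api.latest import _io as io

# Prefix template: pre-creates slots image_0 .. image_3 for the frontend to grow into.
images = io.Autogrow.Input(
    "images",
    template=io.Autogrow.TemplatePrefix(io.Image.Input("image"), prefix="image_", min=1, max=4),
)

# Names template: the slots are fixed and explicitly named.
conds = io.Autogrow.Input(
    "conds",
    template=io.Autogrow.TemplateNames(io.Conditioning.Input("cond"), names=["positive", "negative"]),
)
```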
+ key = live_inputs[self.id] + selected_option = None + for option in self.options: + if option.key == key: + selected_option = option + break + if selected_option is not None: + add_to_input_dict_v1(d, selected_option.inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, selected_option.inputs, curr_prefix, self) def get_dynamic(self) -> list[Input]: - return [self] + return [input for option in self.options for input in option.inputs] + + def get_all(self) -> list[Input]: + return [self] + [input for option in self.options for input in option.inputs] def as_dict(self): return super().as_dict() | prune_dict({ - "template": self.template.as_dict(), + "options": [o.as_dict() for o in self.options], }) - class Output(DynamicOutput): - def __init__(self, id: str, template: MatchType.Template, display_name: str=None, tooltip: str=None, - is_output_list=False): - super().__init__(id, display_name, tooltip, is_output_list) - self.template = template + def validate(self): + # make sure all nested inputs are validated + for option in self.options: + for input in option.inputs: + input.validate() - def get_dynamic(self) -> list[Output]: - return [self] +@comfytype(io_type="COMFY_DYNAMICSLOT_V3") +class DynamicSlot(ComfyTypeI): + Type = dict[str, Any] + + class Input(DynamicInput): + def __init__(self, slot: Input, inputs: list[Input], + display_name: str=None, tooltip: str=None, lazy: bool=None, extra_dict=None): + assert(not isinstance(slot, DynamicInput)) + self.slot = copy.copy(slot) + self.slot.display_name = slot.display_name if slot.display_name is not None else display_name + optional = True + self.slot.tooltip = slot.tooltip if slot.tooltip is not None else tooltip + self.slot.lazy = slot.lazy if slot.lazy is not None else lazy + self.slot.extra_dict = slot.extra_dict if slot.extra_dict is not None else extra_dict + super().__init__(slot.id, self.slot.display_name, optional, self.slot.tooltip, self.slot.lazy, self.slot.extra_dict) + self.inputs = inputs + self.force_input = None + # force widget inputs to have no widgets, otherwise this would be awkward + if isinstance(self.slot, WidgetInput): + self.force_input = True + self.slot.force_input = True + + def expand_schema_for_dynamic(self, d: dict[str, Any], live_inputs: dict[str, Any], curr_prefix=''): + if self.id in live_inputs: + curr_prefix = f"{curr_prefix}{self.id}." 
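`DynamicCombo` exposes a combo whose selected key decides which nested inputs the frontend should show; only the selected option's inputs are expanded into the V1 schema. A hypothetical fragment (option keys and nested inputs are invented; like `Autogrow`, this type is not yet in the public `__all__`):

```python
from comfy_api.latest import _io as io

resize_mode = io.DynamicCombo.Input(
    "resize_mode",
    options=[
        io.DynamicCombo.Option("scale_by", [io.Float.Input("scale", default=1.0)]),
        io.DynamicCombo.Option("exact", [io.Int.Input("width"), io.Int.Input("height")]),
    ],
)
```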
+ add_to_input_dict_v1(d, self.inputs, live_inputs, curr_prefix) + add_dynamic_id_mapping(d, [self.slot] + self.inputs, curr_prefix) + + def get_dynamic(self) -> list[Input]: + return [self.slot] + self.inputs + + def get_all(self) -> list[Input]: + return [self] + [self.slot] + self.inputs def as_dict(self): return super().as_dict() | prune_dict({ - "template": self.template.as_dict(), + "slotType": str(self.slot.get_io_type()), + "inputs": create_input_dict_v1(self.inputs), + "forceInput": self.force_input, }) + def validate(self): + self.slot.validate() + for input in self.inputs: + input.validate() + +def add_dynamic_id_mapping(d: dict[str, Any], inputs: list[Input], curr_prefix: str, self: DynamicInput=None): + dynamic = d.setdefault("dynamic_paths", {}) + if self is not None: + dynamic[self.id] = f"{curr_prefix}{self.id}" + for i in inputs: + if not isinstance(i, DynamicInput): + dynamic[f"{i.id}"] = f"{curr_prefix}{i.id}" + +class V3Data(TypedDict): + hidden_inputs: dict[str, Any] + dynamic_paths: dict[str, Any] class HiddenHolder: def __init__(self, unique_id: str, prompt: Any, @@ -979,6 +1165,7 @@ class NodeInfoV1: output_is_list: list[bool]=None output_name: list[str]=None output_tooltips: list[str]=None + output_matchtypes: list[str]=None name: str=None display_name: str=None description: str=None @@ -1015,9 +1202,9 @@ class Schema: """Display name of node.""" category: str = "sd" """The category of the node, as per the "Add Node" menu.""" - inputs: list[Input]=None - outputs: list[Output]=None - hidden: list[Hidden]=None + inputs: list[Input] = field(default_factory=list) + outputs: list[Output] = field(default_factory=list) + hidden: list[Hidden] = field(default_factory=list) description: str="" """Node description, shown as a tooltip when hovering over the node.""" is_input_list: bool = False @@ -1057,7 +1244,11 @@ class Schema: '''Validate the schema: - verify ids on inputs and outputs are unique - both internally and in relation to each other ''' - input_ids = [i.id for i in self.inputs] if self.inputs is not None else [] + nested_inputs: list[Input] = [] + if self.inputs is not None: + for input in self.inputs: + nested_inputs.extend(input.get_all()) + input_ids = [i.id for i in nested_inputs] if nested_inputs is not None else [] output_ids = [o.id for o in self.outputs] if self.outputs is not None else [] input_set = set(input_ids) output_set = set(output_ids) @@ -1073,6 +1264,13 @@ class Schema: issues.append(f"Ids must be unique between inputs and outputs, but {intersection} are not.") if len(issues) > 0: raise ValueError("\n".join(issues)) + # validate inputs and outputs + if self.inputs is not None: + for input in self.inputs: + input.validate() + if self.outputs is not None: + for output in self.outputs: + output.validate() def finalize(self): """Add hidden based on selected schema options, and give outputs without ids default ids.""" @@ -1098,19 +1296,10 @@ class Schema: if output.id is None: output.id = f"_{i}_{output.io_type}_" - def get_v1_info(self, cls) -> NodeInfoV1: + def get_v1_info(self, cls, live_inputs: dict[str, Any]=None) -> NodeInfoV1: + # NOTE: live_inputs will not be used anymore very soon and this will be done another way # get V1 inputs - input = { - "required": {} - } - if self.inputs: - for i in self.inputs: - if isinstance(i, DynamicInput): - dynamic_inputs = i.get_dynamic() - for d in dynamic_inputs: - add_to_dict_v1(d, input) - else: - add_to_dict_v1(i, input) + input = create_input_dict_v1(self.inputs, live_inputs) if self.hidden: for hidden 
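`DynamicSlot` applies the same idea to a single connection: the extra inputs only become part of the schema once something is plugged into the slot. A hypothetical fragment (names and tooltip are invented):

```python
from comfy_api.latest import _io as io

mask = io.DynamicSlot.Input(
    slot=io.Mask.Input("mask"),
    inputs=[io.Float.Input("mask_strength", default=1.0)],
    tooltip="Optional mask; the strength input appears once a mask is connected.",
)
```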
in self.hidden: input.setdefault("hidden", {})[hidden.name] = (hidden.value,) @@ -1119,12 +1308,24 @@ class Schema: output_is_list = [] output_name = [] output_tooltips = [] + output_matchtypes = [] + any_matchtypes = False if self.outputs: for o in self.outputs: output.append(o.io_type) output_is_list.append(o.is_output_list) output_name.append(o.display_name if o.display_name else o.io_type) output_tooltips.append(o.tooltip if o.tooltip else None) + # special handling for MatchType + if isinstance(o, MatchType.Output): + output_matchtypes.append(o.template.template_id) + any_matchtypes = True + else: + output_matchtypes.append(None) + + # clear out lists that are all None + if not any_matchtypes: + output_matchtypes = None info = NodeInfoV1( input=input, @@ -1133,6 +1334,7 @@ class Schema: output_is_list=output_is_list, output_name=output_name, output_tooltips=output_tooltips, + output_matchtypes=output_matchtypes, name=self.node_id, display_name=self.display_name, category=self.category, @@ -1178,16 +1380,57 @@ class Schema: return info -def add_to_dict_v1(i: Input, input: dict): +def create_input_dict_v1(inputs: list[Input], live_inputs: dict[str, Any]=None) -> dict: + input = { + "required": {} + } + add_to_input_dict_v1(input, inputs, live_inputs) + return input + +def add_to_input_dict_v1(d: dict[str, Any], inputs: list[Input], live_inputs: dict[str, Any]=None, curr_prefix=''): + for i in inputs: + if isinstance(i, DynamicInput): + add_to_dict_v1(i, d) + if live_inputs is not None: + i.expand_schema_for_dynamic(d, live_inputs, curr_prefix) + else: + add_to_dict_v1(i, d) + +def add_to_dict_v1(i: Input, d: dict, dynamic_dict: dict=None): key = "optional" if i.optional else "required" as_dict = i.as_dict() # for v1, we don't want to include the optional key as_dict.pop("optional", None) - input.setdefault(key, {})[i.id] = (i.get_io_type(), as_dict) + if dynamic_dict is None: + value = (i.get_io_type(), as_dict) + else: + value = (i.get_io_type(), as_dict, dynamic_dict) + d.setdefault(key, {})[i.id] = value def add_to_dict_v3(io: Input | Output, d: dict): d[io.id] = (io.get_io_type(), io.as_dict()) +def build_nested_inputs(values: dict[str, Any], v3_data: V3Data): + paths = v3_data.get("dynamic_paths", None) + if paths is None: + return values + values = values.copy() + result = {} + + for key, path in paths.items(): + parts = path.split(".") + current = result + + for i, p in enumerate(parts): + is_last = (i == len(parts) - 1) + + if is_last: + current[p] = values.pop(key, None) + else: + current = current.setdefault(p, {}) + + values.update(result) + return values class _ComfyNodeBaseInternal(_ComfyNodeInternal): @@ -1307,12 +1550,12 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def PREPARE_CLASS_CLONE(cls, hidden_inputs: dict) -> type[ComfyNode]: + def PREPARE_CLASS_CLONE(cls, v3_data: V3Data) -> type[ComfyNode]: """Creates clone of real node class to prevent monkey-patching.""" c_type: type[ComfyNode] = cls if is_class(cls) else type(cls) type_clone: type[ComfyNode] = shallow_clone_class(c_type) # set hidden - type_clone.hidden = HiddenHolder.from_dict(hidden_inputs) + type_clone.hidden = HiddenHolder.from_dict(v3_data["hidden_inputs"]) return type_clone @final @@ -1429,14 +1672,18 @@ class _ComfyNodeBaseInternal(_ComfyNodeInternal): @final @classmethod - def INPUT_TYPES(cls, include_hidden=True, return_schema=False) -> dict[str, dict] | tuple[dict[str, dict], Schema]: + def INPUT_TYPES(cls, include_hidden=True, return_schema=False, live_inputs=None) -> 
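`build_nested_inputs` is the execution-side counterpart of `dynamic_paths`: flat prompt values whose ids were recorded by `add_dynamic_id_mapping` get regrouped into nested dicts keyed by the dotted path. The standalone mirror below reproduces the helper as written above; the example paths follow the shape produced for a `DynamicCombo` input, with made-up values:

```python
def build_nested_inputs(values, v3_data):
    paths = v3_data.get("dynamic_paths")
    if paths is None:
        return values
    values = values.copy()
    result = {}
    for key, path in paths.items():
        parts = path.split(".")
        current = result
        for i, p in enumerate(parts):
            if i == len(parts) - 1:
                current[p] = values.pop(key, None)   # leaf: move the flat value here
            else:
                current = current.setdefault(p, {})  # intermediate path segment
    values.update(result)
    return values

values = {"resize_mode": "scale_by", "scale": 2.0, "seed": 7}
v3_data = {"dynamic_paths": {"resize_mode": "resize_mode.resize_mode",
                             "scale": "resize_mode.scale"}}
print(build_nested_inputs(values, v3_data))
# {'seed': 7, 'resize_mode': {'resize_mode': 'scale_by', 'scale': 2.0}}
```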
dict[str, dict] | tuple[dict[str, dict], Schema, V3Data]: schema = cls.FINALIZE_SCHEMA() - info = schema.get_v1_info(cls) + info = schema.get_v1_info(cls, live_inputs) input = info.input if not include_hidden: input.pop("hidden", None) if return_schema: - return input, schema + v3_data: V3Data = {} + dynamic = input.pop("dynamic_paths", None) + if dynamic is not None: + v3_data["dynamic_paths"] = dynamic + return input, schema, v3_data return input @final @@ -1509,7 +1756,7 @@ class ComfyNode(_ComfyNodeBaseInternal): raise NotImplementedError @classmethod - def validate_inputs(cls, **kwargs) -> bool: + def validate_inputs(cls, **kwargs) -> bool | str: """Optionally, define this function to validate inputs; equivalent to V1's VALIDATE_INPUTS.""" raise NotImplementedError @@ -1625,6 +1872,7 @@ __all__ = [ "StyleModel", "Gligen", "UpscaleModel", + "LatentUpscaleModel", "Audio", "Video", "SVG", @@ -1648,6 +1896,10 @@ __all__ = [ "SEGS", "AnyType", "MultiType", + # Dynamic Types + "MatchType", + # "DynamicCombo", + # "Autogrow", # Other classes "HiddenHolder", "Hidden", @@ -1658,4 +1910,5 @@ __all__ = [ "NodeOutput", "add_to_dict_v1", "add_to_dict_v3", + "V3Data", ] diff --git a/comfy_api/latest/_io_public.py b/comfy_api/latest/_io_public.py new file mode 100644 index 000000000..43c7680f3 --- /dev/null +++ b/comfy_api/latest/_io_public.py @@ -0,0 +1 @@ +from ._io import * # noqa: F403 diff --git a/comfy_api/latest/_ui.py b/comfy_api/latest/_ui.py index f9f56cfd8..9e9606a4b 100644 --- a/comfy_api/latest/_ui.py +++ b/comfy_api/latest/_ui.py @@ -4,6 +4,7 @@ import json import logging import os import random +import uuid from io import BytesIO from typing import Type @@ -21,7 +22,7 @@ from PIL.PngImagePlugin import PngInfo # used for image preview from comfy.cli_args import args from comfy.cmd import folder_paths -from comfy_api.latest._io import ComfyNode, FolderType, Image, _UIOutput +from ._io import ComfyNode, FolderType, Image, _UIOutput logger = logging.getLogger(__name__) @@ -324,9 +325,10 @@ class AudioSaveHelper: for key, value in metadata.items(): output_container.metadata[key] = value + layout = "mono" if waveform.shape[0] == 1 else "stereo" # Set up the output stream with appropriate properties if format == "opus": - out_stream = output_container.add_stream("libopus", rate=sample_rate) + out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout) if quality == "64k": out_stream.bit_rate = 64000 elif quality == "96k": @@ -338,7 +340,7 @@ class AudioSaveHelper: elif quality == "320k": out_stream.bit_rate = 320000 elif format == "mp3": - out_stream = output_container.add_stream("libmp3lame", rate=sample_rate) + out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) if quality == "V0": # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool out_stream.codec_context.qscale = 1 @@ -347,12 +349,12 @@ class AudioSaveHelper: elif quality == "320k": out_stream.bit_rate = 320000 else: # format == "flac": - out_stream = output_container.add_stream("flac", rate=sample_rate) + out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout) frame = av.AudioFrame.from_ndarray( waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format="flt", - layout="mono" if waveform.shape[0] == 1 else "stereo", + layout=layout, ) frame.sample_rate = sample_rate frame.pts = 0 @@ -442,9 +444,19 @@ class PreviewUI3D(_UIOutput): def __init__(self, model_file, 
camera_info, **kwargs): self.model_file = model_file self.camera_info = camera_info + self.bg_image_path = None + bg_image = kwargs.get("bg_image", None) + if bg_image is not None: + img_array = (bg_image[0].cpu().numpy() * 255).astype(np.uint8) + img = PILImage.fromarray(img_array) + temp_dir = folder_paths.get_temp_directory() + filename = f"bg_{uuid.uuid4().hex}.png" + bg_image_path = os.path.join(temp_dir, filename) + img.save(bg_image_path, compress_level=1) + self.bg_image_path = f"temp/{filename}" def as_dict(self): - return {"result": [self.model_file, self.camera_info]} + return {"result": [self.model_file, self.camera_info, self.bg_image_path]} class PreviewText(_UIOutput): diff --git a/comfy_api/latest/_ui_public.py b/comfy_api/latest/_ui_public.py new file mode 100644 index 000000000..85b11d78b --- /dev/null +++ b/comfy_api/latest/_ui_public.py @@ -0,0 +1 @@ +from ._ui import * # noqa: F403 diff --git a/comfy_api/latest/_util/__init__.py b/comfy_api/latest/_util/__init__.py index 9019c46db..fc5431dda 100644 --- a/comfy_api/latest/_util/__init__.py +++ b/comfy_api/latest/_util/__init__.py @@ -1,8 +1,11 @@ from .video_types import VideoContainer, VideoCodec, VideoComponents +from .geometry_types import VOXEL, MESH __all__ = [ # Utility Types "VideoContainer", "VideoCodec", "VideoComponents", + "VOXEL", + "MESH", ] diff --git a/comfy_api/latest/_util/geometry_types.py b/comfy_api/latest/_util/geometry_types.py new file mode 100644 index 000000000..385122778 --- /dev/null +++ b/comfy_api/latest/_util/geometry_types.py @@ -0,0 +1,12 @@ +import torch + + +class VOXEL: + def __init__(self, data: torch.Tensor): + self.data = data + + +class MESH: + def __init__(self, vertices: torch.Tensor, faces: torch.Tensor): + self.vertices = vertices + self.faces = faces diff --git a/comfy_api/latest/_util/video_types.py b/comfy_api/latest/_util/video_types.py index c3e3d8e3a..fd3b5a510 100644 --- a/comfy_api/latest/_util/video_types.py +++ b/comfy_api/latest/_util/video_types.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum from fractions import Fraction from typing import Optional -from comfy_api.latest._input import ImageInput, AudioInput +from .._input import ImageInput, AudioInput class VideoCodec(str, Enum): AUTO = "auto" diff --git a/comfy_api/v0_0_2/__init__.py b/comfy_api/v0_0_2/__init__.py index de0f95001..c4fa1d971 100644 --- a/comfy_api/v0_0_2/__init__.py +++ b/comfy_api/v0_0_2/__init__.py @@ -6,7 +6,7 @@ from comfy_api.latest import ( ) from typing import Type, TYPE_CHECKING from comfy_api.internal.async_to_sync import create_sync_class -from comfy_api.latest import io, ui, ComfyExtension #noqa: F401 +from comfy_api.latest import io, ui, IO, UI, ComfyExtension #noqa: F401 class ComfyAPIAdapter_v0_0_2(ComfyAPI_latest): @@ -42,4 +42,8 @@ __all__ = [ "InputImpl", "Types", "ComfyExtension", + "io", + "IO", + "ui", + "UI", ] diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py deleted file mode 100644 index e37ab486c..000000000 --- a/comfy_api_nodes/apinode_utils.py +++ /dev/null @@ -1,718 +0,0 @@ -from __future__ import annotations -import aiohttp -import io -import logging -import mimetypes -import os -from typing import Optional, Union -from comfy.utils import common_upscale -from comfy_api.input_impl import VideoFromFile -from comfy_api.util import VideoContainer, VideoCodec -from comfy_api.input.video_types import VideoInput -from comfy_api.input.basic_types import AudioInput -from comfy_api_nodes.apis.client import ( - ApiClient, 
- ApiEndpoint, - HttpMethod, - SynchronousOperation, - UploadRequest, - UploadResponse, -) -from comfy.cmd.server import PromptServer -from comfy.cli_args import args - -import numpy as np -from PIL import Image -import torch -import math -import base64 -import uuid -from io import BytesIO -import av - - -async def download_url_to_video_output( - video_url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None -) -> VideoFromFile: - """Downloads a video from a URL and returns a `VIDEO` output. - - Args: - video_url: The URL of the video to download. - - Returns: - A Comfy node `VIDEO` output. - """ - video_io = await download_url_to_bytesio(video_url, timeout, auth_kwargs=auth_kwargs) - if video_io is None: - error_msg = f"Failed to download video from {video_url}" - logging.error(error_msg) - raise ValueError(error_msg) - return VideoFromFile(video_io) - - -def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: - """Downscale input image tensor to roughly the specified total pixels.""" - samples = image.movedim(-1, 1) - total = int(total_pixels) - scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) - if scale_by >= 1: - return image - width = round(samples.shape[3] * scale_by) - height = round(samples.shape[2] * scale_by) - - s = common_upscale(samples, width, height, "lanczos", "disabled") - s = s.movedim(1, -1) - return s - - -async def validate_and_cast_response( - response, timeout: int = None, node_id: Union[str, None] = None -) -> torch.Tensor: - """Validates and casts a response to a torch.Tensor. - - Args: - response: The response to validate and cast. - timeout: Request timeout in seconds. Defaults to None (no timeout). - - Returns: - A torch.Tensor representing the image (1, H, W, C). - - Raises: - ValueError: If the response is not valid. - """ - # validate raw JSON response - data = response.data - if not data or len(data) == 0: - raise ValueError("No images returned from API endpoint") - - # Initialize list to store image tensors - image_tensors: list[torch.Tensor] = [] - - # Process each image in the data array - async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: - for img_data in data: - img_bytes: bytes - if img_data.b64_json: - img_bytes = base64.b64decode(img_data.b64_json) - elif img_data.url: - if node_id: - PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id) - async with session.get(img_data.url) as resp: - if resp.status != 200: - raise ValueError("Failed to download generated image") - img_bytes = await resp.read() - else: - raise ValueError("Invalid image payload – neither URL nor base64 data present.") - - pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA") - arr = np.asarray(pil_img).astype(np.float32) / 255.0 - image_tensors.append(torch.from_numpy(arr)) - - return torch.stack(image_tensors, dim=0) - - -def validate_aspect_ratio( - aspect_ratio: str, - minimum_ratio: float, - maximum_ratio: float, - minimum_ratio_str: str, - maximum_ratio_str: str, -) -> float: - """Validates and casts an aspect ratio string to a float. - - Args: - aspect_ratio: The aspect ratio string to validate. - minimum_ratio: The minimum aspect ratio. - maximum_ratio: The maximum aspect ratio. - minimum_ratio_str: The minimum aspect ratio string. - maximum_ratio_str: The maximum aspect ratio string. - - Returns: - The validated and cast aspect ratio. - - Raises: - Exception: If the aspect ratio is not valid. 
- """ - # get ratio values - numbers = aspect_ratio.split(":") - if len(numbers) != 2: - raise TypeError( - f"Aspect ratio must be in the format X:Y, such as 16:9, but was {aspect_ratio}." - ) - try: - numerator = int(numbers[0]) - denominator = int(numbers[1]) - except ValueError as exc: - raise TypeError( - f"Aspect ratio must contain numbers separated by ':', such as 16:9, but was {aspect_ratio}." - ) from exc - calculated_ratio = numerator / denominator - # if not close to minimum and maximum, check bounds - if not math.isclose(calculated_ratio, minimum_ratio) or not math.isclose( - calculated_ratio, maximum_ratio - ): - if calculated_ratio < minimum_ratio: - raise TypeError( - f"Aspect ratio cannot reduce to any less than {minimum_ratio_str} ({minimum_ratio}), but was {aspect_ratio} ({calculated_ratio})." - ) - if calculated_ratio > maximum_ratio: - raise TypeError( - f"Aspect ratio cannot reduce to any greater than {maximum_ratio_str} ({maximum_ratio}), but was {aspect_ratio} ({calculated_ratio})." - ) - return aspect_ratio - - -def mimetype_to_extension(mime_type: str) -> str: - """Converts a MIME type to a file extension.""" - return mime_type.split("/")[-1].lower() - - -async def download_url_to_bytesio( - url: str, timeout: int = None, auth_kwargs: Optional[dict[str, str]] = None -) -> BytesIO: - """Downloads content from a URL using requests and returns it as BytesIO. - - Args: - url: The URL to download. - timeout: Request timeout in seconds. Defaults to None (no timeout). - - Returns: - BytesIO object containing the downloaded content. - """ - headers = {} - if url.startswith("/proxy/"): - url = str(args.comfy_api_base).rstrip("/") + url - auth_token = auth_kwargs.get("auth_token") - comfy_api_key = auth_kwargs.get("comfy_api_key") - if auth_token: - headers["Authorization"] = f"Bearer {auth_token}" - elif comfy_api_key: - headers["X-API-KEY"] = comfy_api_key - timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None - async with aiohttp.ClientSession(timeout=timeout_cfg) as session: - async with session.get(url, headers=headers) as resp: - resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) - return BytesIO(await resp.read()) - - -def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: - """Converts image data from BytesIO to a torch.Tensor. - - Args: - image_bytesio: BytesIO object containing the image data. - mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA"). - - Returns: - A torch.Tensor representing the image (1, H, W, C). - - Raises: - PIL.UnidentifiedImageError: If the image data cannot be identified. - ValueError: If the specified mode is invalid. 
- """ - image = Image.open(image_bytesio) - image = image.convert(mode) - image_array = np.array(image).astype(np.float32) / 255.0 - return torch.from_numpy(image_array).unsqueeze(0) - - -async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: - """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" - image_bytesio = await download_url_to_bytesio(url, timeout) - return bytesio_to_image_tensor(image_bytesio) - - -def process_image_response(response_content: bytes | str) -> torch.Tensor: - """Uses content from a Response object and converts it to a torch.Tensor""" - return bytesio_to_image_tensor(BytesIO(response_content)) - - -def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: - """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" - if len(image.shape) > 3: - image = image[0] - # TODO: remove alpha if not allowed and present - input_tensor = image.cpu() - input_tensor = downscale_image_tensor( - input_tensor.unsqueeze(0), total_pixels=total_pixels - ).squeeze() - image_np = (input_tensor.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - return img - - -def _pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: - """Converts a PIL Image to a BytesIO object.""" - if not mime_type: - mime_type = "image/png" - - img_byte_arr = io.BytesIO() - # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG') - pil_format = mime_type.split("/")[-1].upper() - if pil_format == "JPG": - pil_format = "JPEG" - img.save(img_byte_arr, format=pil_format) - img_byte_arr.seek(0) - return img_byte_arr - - -def tensor_to_bytesio( - image: torch.Tensor, - name: Optional[str] = None, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> BytesIO: - """Converts a torch.Tensor image to a named BytesIO object. - - Args: - image: Input torch.Tensor image. - name: Optional filename for the BytesIO object. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). - - Returns: - Named BytesIO object containing the image data, with pointer set to the start of buffer. - """ - if not mime_type: - mime_type = "image/png" - - pil_image = _tensor_to_pil(image, total_pixels=total_pixels) - img_binary = _pil_to_bytesio(pil_image, mime_type=mime_type) - img_binary.name = ( - f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" - ) - return img_binary - - -def tensor_to_base64_string( - image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> str: - """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. - - Args: - image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). - - Returns: - Base64 encoded string of the image. - """ - pil_image = _tensor_to_pil(image_tensor, total_pixels=total_pixels) - img_byte_arr = _pil_to_bytesio(pil_image, mime_type=mime_type) - img_bytes = img_byte_arr.getvalue() - # Encode bytes to base64 string - base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8") - return base64_encoded_string - - -def tensor_to_data_uri( - image_tensor: torch.Tensor, - total_pixels: int = 2048 * 2048, - mime_type: str = "image/png", -) -> str: - """Converts a tensor image to a Data URI string. 
- - Args: - image_tensor: Input torch.Tensor image. - total_pixels: Maximum total pixels for potential downscaling. - mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). - - Returns: - Data URI string (e.g., 'data:image/png;base64,...'). - """ - base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type) - return f"data:{mime_type};base64,{base64_string}" - - -def text_filepath_to_base64_string(filepath: str) -> str: - """Converts a text file to a base64 string.""" - with open(filepath, "rb") as f: - file_content = f.read() - return base64.b64encode(file_content).decode("utf-8") - - -def text_filepath_to_data_uri(filepath: str) -> str: - """Converts a text file to a data URI.""" - base64_string = text_filepath_to_base64_string(filepath) - mime_type, _ = mimetypes.guess_type(filepath) - if mime_type is None: - mime_type = "application/octet-stream" - return f"data:{mime_type};base64,{base64_string}" - - -async def upload_file_to_comfyapi( - file_bytes_io: BytesIO, - filename: str, - upload_mime_type: Optional[str], - auth_kwargs: Optional[dict[str, str]] = None, -) -> str: - """ - Uploads a single file to ComfyUI API and returns its download URL. - - Args: - file_bytes_io: BytesIO object containing the file data. - filename: The filename of the file. - upload_mime_type: MIME type of the file. - auth_kwargs: Optional authentication token(s). - - Returns: - The download URL for the uploaded file. - """ - if upload_mime_type is None: - request_object = UploadRequest(file_name=filename) - else: - request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/customers/storage", - method=HttpMethod.POST, - request_model=UploadRequest, - response_model=UploadResponse, - ), - request=request_object, - auth_kwargs=auth_kwargs, - ) - - response: UploadResponse = await operation.execute() - await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type) - return response.download_url - - -def video_to_base64_string( - video: VideoInput, - container_format: VideoContainer = None, - codec: VideoCodec = None -) -> str: - """ - Converts a video input to a base64 string. - - Args: - video: The video input to convert - container_format: Optional container format to use (defaults to video.container if available) - codec: Optional codec to use (defaults to video.codec if available) - """ - video_bytes_io = io.BytesIO() - - # Use provided format/codec if specified, otherwise use video's own if available - format_to_use = container_format if container_format is not None else getattr(video, 'container', VideoContainer.MP4) - codec_to_use = codec if codec is not None else getattr(video, 'codec', VideoCodec.H264) - - video.save_to(video_bytes_io, format=format_to_use, codec=codec_to_use) - video_bytes_io.seek(0) - return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") - - -async def upload_video_to_comfyapi( - video: VideoInput, - auth_kwargs: Optional[dict[str, str]] = None, - container: VideoContainer = VideoContainer.MP4, - codec: VideoCodec = VideoCodec.H264, - max_duration: Optional[int] = None, -) -> str: - """ - Uploads a single video to ComfyUI API and returns its download URL. - Uses the specified container and codec for saving the video before upload. - - Args: - video: VideoInput object (Comfy VIDEO type). - auth_kwargs: Optional authentication token(s). - container: The video container format to use (default: MP4). 
- codec: The video codec to use (default: H264). - max_duration: Optional maximum duration of the video in seconds. If the video is longer than this, an error will be raised. - - Returns: - The download URL for the uploaded video file. - """ - if max_duration is not None: - try: - actual_duration = video.duration_seconds - if actual_duration is not None and actual_duration > max_duration: - raise ValueError( - f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." - ) - except Exception as e: - logging.error("Error getting video duration: %s", str(e)) - raise ValueError(f"Could not verify video duration from source: {e}") from e - - upload_mime_type = f"video/{container.value.lower()}" - filename = f"uploaded_video.{container.value.lower()}" - - # Convert VideoInput to BytesIO using specified container/codec - video_bytes_io = io.BytesIO() - video.save_to(video_bytes_io, format=container, codec=codec) - video_bytes_io.seek(0) - - return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs) - - -def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: - """ - Prepares audio waveform for av library by converting to a contiguous numpy array. - - Args: - waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type. - - Returns: - Contiguous numpy array of the audio waveform. If the audio was batched, - the first item is taken. - """ - if waveform.ndim != 3 or waveform.shape[0] != 1: - raise ValueError("Expected waveform tensor shape (1, channels, samples)") - - # If batch is > 1, take first item - if waveform.shape[0] > 1: - waveform = waveform[0] - - # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array - audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy() - if audio_data_np.dtype != np.float32: - audio_data_np = audio_data_np.astype(np.float32) - - return audio_data_np - - -def audio_ndarray_to_bytesio( - audio_data_np: np.ndarray, - sample_rate: int, - container_format: str = "mp4", - codec_name: str = "aac", -) -> BytesIO: - """ - Encodes a numpy array of audio data into a BytesIO object. - """ - audio_bytes_io = io.BytesIO() - with av.open(audio_bytes_io, mode="w", format=container_format) as output_container: - audio_stream = output_container.add_stream(codec_name, rate=sample_rate) - frame = av.AudioFrame.from_ndarray( - audio_data_np, - format="fltp", - layout="stereo" if audio_data_np.shape[0] > 1 else "mono", - ) - frame.sample_rate = sample_rate - frame.pts = 0 - - for packet in audio_stream.encode(frame): - output_container.mux(packet) - - # Flush stream - for packet in audio_stream.encode(None): - output_container.mux(packet) - - audio_bytes_io.seek(0) - return audio_bytes_io - - -async def upload_audio_to_comfyapi( - audio: AudioInput, - auth_kwargs: Optional[dict[str, str]] = None, - container_format: str = "mp4", - codec_name: str = "aac", - mime_type: str = "audio/mp4", - filename: str = "uploaded_audio.mp4", -) -> str: - """ - Uploads a single audio input to ComfyUI API and returns its download URL. - Encodes the raw waveform into the specified format before uploading. - - Args: - audio: a Comfy `AUDIO` type (contains waveform tensor and sample_rate) - auth_kwargs: Optional authentication token(s). - - Returns: - The download URL for the uploaded audio file. 
- """ - sample_rate: int = audio["sample_rate"] - waveform: torch.Tensor = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio( - audio_data_np, sample_rate, container_format, codec_name - ) - - return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) - - -def f32_pcm(wav: torch.Tensor) -> torch.Tensor: - """Convert audio to float 32 bits PCM format. Copy-paste from nodes_audio.py file.""" - if wav.dtype.is_floating_point: - return wav - elif wav.dtype == torch.int16: - return wav.float() / (2 ** 15) - elif wav.dtype == torch.int32: - return wav.float() / (2 ** 31) - raise ValueError(f"Unsupported wav dtype: {wav.dtype}") - - -def audio_bytes_to_audio_input(audio_bytes: bytes,) -> dict: - """ - Decode any common audio container from bytes using PyAV and return - a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. - """ - with av.open(io.BytesIO(audio_bytes)) as af: - if not af.streams.audio: - raise ValueError("No audio stream found in response.") - stream = af.streams.audio[0] - - in_sr = int(stream.codec_context.sample_rate) - out_sr = in_sr - - frames: list[torch.Tensor] = [] - n_channels = stream.channels or 1 - - for frame in af.decode(streams=stream.index): - arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] - buf = torch.from_numpy(arr) - if buf.ndim == 1: - buf = buf.unsqueeze(0) # [T] -> [1, T] - elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: - buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] - elif buf.shape[0] != n_channels: - buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] - frames.append(buf) - - if not frames: - raise ValueError("Decoded zero audio frames.") - - wav = torch.cat(frames, dim=1) # [C, T] - wav = f32_pcm(wav) - return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} - - -def audio_input_to_mp3(audio: AudioInput) -> io.BytesIO: - waveform = audio["waveform"].cpu() - - output_buffer = io.BytesIO() - output_container = av.open(output_buffer, mode='w', format="mp3") - - out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) - out_stream.bit_rate = 320000 - - frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout='mono' if waveform.shape[0] == 1 else 'stereo') - frame.sample_rate = audio["sample_rate"] - frame.pts = 0 - output_container.mux(out_stream.encode(frame)) - output_container.mux(out_stream.encode(None)) - output_container.close() - output_buffer.seek(0) - return output_buffer - - -def audio_to_base64_string( - audio: AudioInput, container_format: str = "mp4", codec_name: str = "aac" -) -> str: - """Converts an audio input to a base64 string.""" - sample_rate: int = audio["sample_rate"] - waveform: torch.Tensor = audio["waveform"] - audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) - audio_bytes_io = audio_ndarray_to_bytesio( - audio_data_np, sample_rate, container_format, codec_name - ) - audio_bytes = audio_bytes_io.getvalue() - return base64.b64encode(audio_bytes).decode("utf-8") - - -async def upload_images_to_comfyapi( - image: torch.Tensor, - max_images=8, - auth_kwargs: Optional[dict[str, str]] = None, - mime_type: Optional[str] = None, -) -> list[str]: - """ - Uploads images to ComfyUI API and returns download URLs. - To upload multiple images, stack them in the batch dimension first. - - Args: - image: Input torch.Tensor image. 
- max_images: Maximum number of images to upload. - auth_kwargs: Optional authentication token(s). - mime_type: Optional MIME type for the image. - """ - # if batch, try to upload each file if max_images is greater than 0 - download_urls: list[str] = [] - is_batch = len(image.shape) > 3 - batch_len = image.shape[0] if is_batch else 1 - - for idx in range(min(batch_len, max_images)): - tensor = image[idx] if is_batch else image - img_io = tensor_to_bytesio(tensor, mime_type=mime_type) - url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs) - download_urls.append(url) - return download_urls - - -def resize_mask_to_image( - mask: torch.Tensor, - image: torch.Tensor, - upscale_method="nearest-exact", - crop="disabled", - allow_gradient=True, - add_channel_dim=False, -): - """ - Resize mask to be the same dimensions as an image, while maintaining proper format for API calls. - """ - _, H, W, _ = image.shape - mask = mask.unsqueeze(-1) - mask = mask.movedim(-1, 1) - mask = common_upscale( - mask, width=W, height=H, upscale_method=upscale_method, crop=crop - ) - mask = mask.movedim(1, -1) - if not add_channel_dim: - mask = mask.squeeze(-1) - if not allow_gradient: - mask = (mask > 0.5).float() - return mask - - -def validate_string( - string: str, - strip_whitespace=True, - field_name="prompt", - min_length=None, - max_length=None, -): - if string is None: - raise Exception(f"Field '{field_name}' cannot be empty.") - if strip_whitespace: - string = string.strip() - if min_length and len(string) < min_length: - raise Exception( - f"Field '{field_name}' cannot be shorter than {min_length} characters; was {len(string)} characters long." - ) - if max_length and len(string) > max_length: - raise Exception( - f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long." - ) - - -def image_tensor_pair_to_batch( - image1: torch.Tensor, image2: torch.Tensor -) -> torch.Tensor: - """ - Converts a pair of image tensors to a batch tensor. - If the images are not the same size, the smaller image is resized to - match the larger image. - """ - if image1.shape[1:] != image2.shape[1:]: - image2 = common_upscale( - image2.movedim(-1, 1), - image1.shape[2], - image1.shape[1], - "bilinear", - "center", - ).movedim(1, -1) - return torch.cat((image1, image2), dim=0) - - -def get_size(path_or_object: Union[str, io.BytesIO]) -> int: - if isinstance(path_or_object, str): - return os.path.getsize(path_or_object) - return len(path_or_object.getvalue()) - - -def validate_container_format_is_mp4(video: VideoInput) -> None: - """Validates video container format is MP4.""" - container_format = video.get_container_format() - if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: - raise ValueError(f"Only MP4 container format supported. Got: {container_format}") diff --git a/comfy_api_nodes/apis/PixverseController.py b/comfy_api_nodes/apis/PixverseController.py deleted file mode 100644 index 310c0f546..000000000 --- a/comfy_api_nodes/apis/PixverseController.py +++ /dev/null @@ -1,17 +0,0 @@ -# generated by datamodel-codegen: -# filename: filtered-openapi.yaml -# timestamp: 2025-04-29T23:44:54+00:00 - -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel - -from . 
import PixverseDto - - -class ResponseData(BaseModel): - ErrCode: Optional[int] = None - ErrMsg: Optional[str] = None - Resp: Optional[PixverseDto.V2OpenAPII2VResp] = None diff --git a/comfy_api_nodes/apis/PixverseDto.py b/comfy_api_nodes/apis/PixverseDto.py deleted file mode 100644 index 323c38e96..000000000 --- a/comfy_api_nodes/apis/PixverseDto.py +++ /dev/null @@ -1,57 +0,0 @@ -# generated by datamodel-codegen: -# filename: filtered-openapi.yaml -# timestamp: 2025-04-29T23:44:54+00:00 - -from __future__ import annotations - -from typing import Optional - -from pydantic import BaseModel, Field - - -class V2OpenAPII2VResp(BaseModel): - video_id: Optional[int] = Field(None, description='Video_id') - - -class V2OpenAPIT2VReq(BaseModel): - aspect_ratio: str = Field( - ..., description='Aspect ratio (16:9, 4:3, 1:1, 3:4, 9:16)', examples=['16:9'] - ) - duration: int = Field( - ..., - description='Video duration (5, 8 seconds, --model=v3.5 only allows 5,8; --quality=1080p does not support 8s)', - examples=[5], - ) - model: str = Field( - ..., description='Model version (only supports v3.5)', examples=['v3.5'] - ) - motion_mode: Optional[str] = Field( - 'normal', - description='Motion mode (normal, fast, --fast only available when duration=5; --quality=1080p does not support fast)', - examples=['normal'], - ) - negative_prompt: Optional[str] = Field( - None, description='Negative prompt\n', max_length=2048 - ) - prompt: str = Field(..., description='Prompt', max_length=2048) - quality: str = Field( - ..., - description='Video quality ("360p"(Turbo model), "540p", "720p", "1080p")', - examples=['540p'], - ) - seed: Optional[int] = Field(None, description='Random seed, range: 0 - 2147483647') - style: Optional[str] = Field( - None, - description='Style (effective when model=v3.5, "anime", "3d_animation", "clay", "comic", "cyberpunk") Do not include style parameter unless needed', - examples=['anime'], - ) - template_id: Optional[int] = Field( - None, - description='Template ID (template_id must be activated before use)', - examples=[302325299692608], - ) - water_mark: Optional[bool] = Field( - False, - description='Watermark (true: add watermark, false: no watermark)', - examples=[False], - ) diff --git a/comfy_api_nodes/apis/bfl_api.py b/comfy_api_nodes/apis/bfl_api.py index 0e90aef7c..d8d3557b3 100644 --- a/comfy_api_nodes/apis/bfl_api.py +++ b/comfy_api_nodes/apis/bfl_api.py @@ -50,44 +50,6 @@ class BFLFluxFillImageRequest(BaseModel): mask: str = Field(None, description='A Base64-encoded string representing the mask of the areas you with to modify.') -class BFLFluxCannyImageRequest(BaseModel): - prompt: str = Field(..., description='Text prompt for image generation') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - canny_low_threshold: Optional[int] = Field(None, description='Low threshold for Canny edge detection') - canny_high_threshold: Optional[int] = Field(None, description='High threshold for Canny edge detection') - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. 
Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') - preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') - - -class BFLFluxDepthImageRequest(BaseModel): - prompt: str = Field(..., description='Text prompt for image generation') - prompt_upsampling: Optional[bool] = Field( - None, description='Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation.' - ) - seed: Optional[int] = Field(None, description='The seed value for reproducibility.') - steps: conint(ge=15, le=50) = Field(..., description='Number of steps for the image generation process') - guidance: confloat(ge=1, le=100) = Field(..., description='Guidance strength for the image generation process') - safety_tolerance: Optional[conint(ge=0, le=6)] = Field( - 6, description='Tolerance level for input and output moderation. Between 0 and 6, 0 being most strict, 6 being least strict. Defaults to 2.' - ) - output_format: Optional[BFLOutputFormat] = Field( - BFLOutputFormat.png, description="Output format for the generated image. Can be 'jpeg' or 'png'.", examples=['png'] - ) - control_image: Optional[str] = Field(None, description='Base64 encoded image to use as control input if no preprocessed image is provided') - preprocessed_image: Optional[str] = Field(None, description='Optional pre-processed image that will bypass the control preprocessing step') - - class BFLFluxProGenerateRequest(BaseModel): prompt: str = Field(..., description='The text prompt for image generation.') prompt_upsampling: Optional[bool] = Field( @@ -108,6 +70,29 @@ class BFLFluxProGenerateRequest(BaseModel): # ) +class Flux2ProGenerateRequest(BaseModel): + prompt: str = Field(...) + width: int = Field(1024, description="Must be a multiple of 32.") + height: int = Field(768, description="Must be a multiple of 32.") + seed: int | None = Field(None) + prompt_upsampling: bool | None = Field(None) + input_image: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_2: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_3: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_4: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_5: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_6: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_7: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_8: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + input_image_9: str | None = Field(None, description="Base64 encoded image for image-to-image generation") + safety_tolerance: int | None = Field( + 5, description="Tolerance level for input and output moderation. 
Value 0 being most strict.", ge=0, le=5 + ) + output_format: str | None = Field( + "png", description="Output format for the generated image. Can be 'jpeg' or 'png'." + ) + + class BFLFluxKontextProGenerateRequest(BaseModel): prompt: str = Field(..., description='The text prompt for what you wannt to edit.') input_image: Optional[str] = Field(None, description='Image to edit in base64 format') @@ -147,8 +132,9 @@ class BFLFluxProUltraGenerateRequest(BaseModel): class BFLFluxProGenerateResponse(BaseModel): - id: str = Field(..., description='The unique identifier for the generation task.') - polling_url: str = Field(..., description='URL to poll for the generation result.') + id: str = Field(..., description="The unique identifier for the generation task.") + polling_url: str = Field(..., description="URL to poll for the generation result.") + cost: float | None = Field(None, description="Price in cents") class BFLStatus(str, Enum): @@ -160,15 +146,8 @@ class BFLStatus(str, Enum): error = "Error" -class BFLFluxProStatusResponse(BaseModel): +class BFLFluxStatusResponse(BaseModel): id: str = Field(..., description="The unique identifier for the generation task.") status: BFLStatus = Field(..., description="The status of the task.") - result: Optional[Dict[str, Any]] = Field( - None, description="The result of the task (null if not completed)." - ) - progress: confloat(ge=0.0, le=1.0) = Field( - ..., description="The progress of the task (0.0 to 1.0)." - ) - details: Optional[Dict[str, Any]] = Field( - None, description="Additional details about the task (null if not available)." - ) + result: Optional[Dict[str, Any]] = Field(None, description="The result of the task (null if not completed).") + progress: Optional[float] = Field(None, description="The progress of the task (0.0 to 1.0).", ge=0.0, le=1.0) diff --git a/comfy_api_nodes/apis/bytedance_api.py b/comfy_api_nodes/apis/bytedance_api.py new file mode 100644 index 000000000..77cd76f9b --- /dev/null +++ b/comfy_api_nodes/apis/bytedance_api.py @@ -0,0 +1,144 @@ +from typing import Literal + +from pydantic import BaseModel, Field + + +class Text2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str | None = Field("url") + size: str | None = Field(None) + seed: int | None = Field(0, ge=0, le=2147483647) + guidance_scale: float | None = Field(..., ge=1.0, le=10.0) + watermark: bool | None = Field(True) + + +class Image2ImageTaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str | None = Field("url") + image: str = Field(..., description="Base64 encoded string or image URL") + size: str | None = Field("adaptive") + seed: int | None = Field(..., ge=0, le=2147483647) + guidance_scale: float | None = Field(..., ge=1.0, le=10.0) + watermark: bool | None = Field(True) + + +class Seedream4Options(BaseModel): + max_images: int = Field(15) + + +class Seedream4TaskCreationRequest(BaseModel): + model: str = Field(...) + prompt: str = Field(...) + response_format: str = Field("url") + image: list[str] | None = Field(None, description="Image URLs") + size: str = Field(...) + seed: int = Field(..., ge=0, le=2147483647) + sequential_image_generation: str = Field("disabled") + sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) + watermark: bool = Field(True) + + +class ImageTaskCreationResponse(BaseModel): + model: str = Field(...) 
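# Editor's note: a small usage sketch for the request models defined above. The model id,
# size, and guidance_scale values are illustrative placeholders, not recommendations.
example_request = Text2ImageTaskCreationRequest(
    model="<bytedance-t2i-model-id>",   # placeholder; use the provider's published model id
    prompt="a lighthouse at dusk, watercolor",
    size="1024x1024",
    seed=42,
    guidance_scale=5.5,
    watermark=False,
)
example_payload = example_request.model_dump(exclude_none=True)  # JSON-ready request body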
+ created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") + data: list = Field([], description="Contains information about the generated image(s).") + error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") + + +class TaskTextContent(BaseModel): + type: str = Field("text") + text: str = Field(...) + + +class TaskImageContentUrl(BaseModel): + url: str = Field(...) + + +class TaskImageContent(BaseModel): + type: str = Field("image_url") + image_url: TaskImageContentUrl = Field(...) + role: Literal["first_frame", "last_frame", "reference_image"] | None = Field(None) + + +class Text2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + content: list[TaskTextContent] = Field(..., min_length=1) + + +class Image2VideoTaskCreationRequest(BaseModel): + model: str = Field(...) + content: list[TaskTextContent | TaskImageContent] = Field(..., min_length=2) + + +class TaskCreationResponse(BaseModel): + id: str = Field(...) + + +class TaskStatusError(BaseModel): + code: str = Field(...) + message: str = Field(...) + + +class TaskStatusResult(BaseModel): + video_url: str = Field(...) + + +class TaskStatusResponse(BaseModel): + id: str = Field(...) + model: str = Field(...) + status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) + error: TaskStatusError | None = Field(None) + content: TaskStatusResult | None = Field(None) + + +RECOMMENDED_PRESETS = [ + ("1024x1024 (1:1)", 1024, 1024), + ("864x1152 (3:4)", 864, 1152), + ("1152x864 (4:3)", 1152, 864), + ("1280x720 (16:9)", 1280, 720), + ("720x1280 (9:16)", 720, 1280), + ("832x1248 (2:3)", 832, 1248), + ("1248x832 (3:2)", 1248, 832), + ("1512x648 (21:9)", 1512, 648), + ("2048x2048 (1:1)", 2048, 2048), + ("Custom", None, None), +] + +RECOMMENDED_PRESETS_SEEDREAM_4 = [ + ("2048x2048 (1:1)", 2048, 2048), + ("2304x1728 (4:3)", 2304, 1728), + ("1728x2304 (3:4)", 1728, 2304), + ("2560x1440 (16:9)", 2560, 1440), + ("1440x2560 (9:16)", 1440, 2560), + ("2496x1664 (3:2)", 2496, 1664), + ("1664x2496 (2:3)", 1664, 2496), + ("3024x1296 (21:9)", 3024, 1296), + ("4096x4096 (1:1)", 4096, 4096), + ("Custom", None, None), +] + +# The time in this dictionary are given for 10 seconds duration. +VIDEO_TASKS_EXECUTION_TIME = { + "seedance-1-0-lite-t2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-lite-i2v-250428": { + "480p": 40, + "720p": 60, + "1080p": 90, + }, + "seedance-1-0-pro-250528": { + "480p": 70, + "720p": 85, + "1080p": 115, + }, + "seedance-1-0-pro-fast-251015": { + "480p": 50, + "720p": 65, + "1080p": 100, + }, +} diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py deleted file mode 100644 index 9a78640e6..000000000 --- a/comfy_api_nodes/apis/client.py +++ /dev/null @@ -1,981 +0,0 @@ -""" -API Client Framework for api.comfy.org. - -This module provides a flexible framework for making API requests from ComfyUI nodes. -It supports both synchronous and asynchronous API operations with proper type validation. - -Key Components: --------------- -1. ApiClient - Handles HTTP requests with authentication and error handling -2. ApiEndpoint - Defines a single HTTP endpoint with its request/response models -3. ApiOperation - Executes a single synchronous API operation - -Usage Examples: --------------- - -# Example 1: Synchronous API Operation -# ------------------------------------ -# For a simple API call that returns the result immediately: - -# 1. 
Create the API client -api_client = ApiClient( - base_url="https://api.example.com", - auth_token="your_auth_token_here", - comfy_api_key="your_comfy_api_key_here", - timeout=30.0, - verify_ssl=True -) - -# 2. Define the endpoint -user_info_endpoint = ApiEndpoint( - path="/v1/users/me", - method=HttpMethod.GET, - request_model=EmptyRequest, # No request body needed - response_model=UserProfile, # Pydantic model for the response - query_params=None -) - -# 3. Create the request object -request = EmptyRequest() - -# 4. Create and execute the operation -operation = ApiOperation( - endpoint=user_info_endpoint, - request=request -) -user_profile = await operation.execute(client=api_client) # Returns immediately with the result - - -# Example 2: Asynchronous API Operation with Polling -# ------------------------------------------------- -# For an API that starts a task and requires polling for completion: - -# 1. Define the endpoints (initial request and polling) -generate_image_endpoint = ApiEndpoint( - path="/v1/images/generate", - method=HttpMethod.POST, - request_model=ImageGenerationRequest, - response_model=TaskCreatedResponse, - query_params=None -) - -check_task_endpoint = ApiEndpoint( - path="/v1/tasks/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=ImageGenerationResult, - query_params=None -) - -# 2. Create the request object -request = ImageGenerationRequest( - prompt="a beautiful sunset over mountains", - width=1024, - height=1024, - num_images=1 -) - -# 3. Create and execute the polling operation -operation = PollingOperation( - initial_endpoint=generate_image_endpoint, - initial_request=request, - poll_endpoint=check_task_endpoint, - task_id_field="task_id", - status_field="status", - completed_statuses=["completed"], - failed_statuses=["failed", "error"] -) - -# This will make the initial request and then poll until completion -result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done -""" - -from __future__ import annotations -import aiohttp -import asyncio -import logging -import io -import os -import socket -from aiohttp.client_exceptions import ClientError, ClientResponseError -from typing import Type, Optional, Any, TypeVar, Generic, Callable -from enum import Enum -import json -from urllib.parse import urljoin, urlparse -from pydantic import BaseModel, Field -import uuid # For generating unique operation IDs - -from comfy.cmd.server import PromptServer -from comfy.cli_args import args -from comfy import utils -from . import request_logger - -T = TypeVar("T", bound=BaseModel) -R = TypeVar("R", bound=BaseModel) -P = TypeVar("P", bound=BaseModel) # For poll response - -PROGRESS_BAR_MAX = 100 - - -class NetworkError(Exception): - """Base exception for network-related errors with diagnostic information.""" - pass - - -class LocalNetworkError(NetworkError): - """Exception raised when local network connectivity issues are detected.""" - pass - - -class ApiServerError(NetworkError): - """Exception raised when the API server is unreachable but internet is working.""" - pass - - -class EmptyRequest(BaseModel): - """Base class for empty request bodies. - For GET requests, fields will be sent as query parameters.""" - - pass - - -class UploadRequest(BaseModel): - file_name: str = Field(..., description="Filename to upload") - content_type: Optional[str] = Field( - None, - description="Mime type of the file. 
For example: image/png, image/jpeg, video/mp4, etc.", - ) - - -class UploadResponse(BaseModel): - download_url: str = Field(..., description="URL to GET uploaded file") - upload_url: str = Field(..., description="URL to PUT file to upload") - - -class HttpMethod(str, Enum): - GET = "GET" - POST = "POST" - PUT = "PUT" - DELETE = "DELETE" - PATCH = "PATCH" - - -class ApiClient: - """ - Client for making HTTP requests to an API with authentication, error handling, and retry logic. - """ - - def __init__( - self, - base_url: str, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - timeout: float = 3600.0, - verify_ssl: bool = True, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - retry_status_codes: Optional[tuple[int, ...]] = None, - session: Optional[aiohttp.ClientSession] = None, - ): - self.base_url = base_url - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - self.timeout = timeout - self.verify_ssl = verify_ssl - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests), - # 500, 502, 503, 504 (Server Errors) - self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504) - self._session: Optional[aiohttp.ClientSession] = session - self._owns_session = session is None # Track if we have to close it - - @staticmethod - def _generate_operation_id(path: str) -> str: - """Generates a unique operation ID for logging.""" - return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}" - - @staticmethod - def _create_json_payload_args( - data: Optional[dict[str, Any]] = None, - headers: Optional[dict[str, str]] = None, - ) -> dict[str, Any]: - return { - "json": data, - "headers": headers, - } - - def _create_form_data_args( - self, - data: dict[str, Any] | None, - files: dict[str, Any] | None, - headers: Optional[dict[str, str]] = None, - multipart_parser: Callable | None = None, - ) -> dict[str, Any]: - if headers and "Content-Type" in headers: - del headers["Content-Type"] - - if multipart_parser and data: - data = multipart_parser(data) - - if isinstance(data, aiohttp.FormData): - form = data # If the parser already returned a FormData, pass it through - else: - form = aiohttp.FormData(default_to_multipart=True) - if data: # regular text fields - for k, v in data.items(): - if v is None: - continue # aiohttp fails to serialize "None" values - # aiohttp expects strings or bytes; convert enums etc. 
- form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) - - if files: - file_iter = files if isinstance(files, list) else files.items() - for field_name, file_obj in file_iter: - if file_obj is None: - continue # aiohttp fails to serialize "None" values - # file_obj can be (filename, bytes/io.BytesIO, content_type) tuple - if isinstance(file_obj, tuple): - filename, file_value, content_type = self._unpack_tuple(file_obj) - else: - file_value = file_obj - filename = getattr(file_obj, "name", field_name) - content_type = "application/octet-stream" - - form.add_field( - name=field_name, - value=file_value, - filename=filename, - content_type=content_type, - ) - return {"data": form, "headers": headers or {}} - - @staticmethod - def _create_urlencoded_form_data_args( - data: dict[str, Any], - headers: Optional[dict[str, str]] = None, - ) -> dict[str, Any]: - headers = headers or {} - headers["Content-Type"] = "application/x-www-form-urlencoded" - return { - "data": data, - "headers": headers, - } - - def get_headers(self) -> dict[str, str]: - """Get headers for API requests, including authentication if available""" - headers = {"Content-Type": "application/json", "Accept": "application/json"} - - if self.auth_token: - headers["Authorization"] = f"Bearer {self.auth_token}" - elif self.comfy_api_key: - headers["X-API-KEY"] = self.comfy_api_key - - return headers - - async def _check_connectivity(self, target_url: str) -> dict[str, bool]: - """ - Check connectivity to determine if network issues are local or server-related. - - Args: - target_url: URL to check connectivity to - - Returns: - Dictionary with connectivity status details - """ - results = { - "internet_accessible": False, - "api_accessible": False, - "is_local_issue": False, - "is_api_issue": False, - } - timeout = aiohttp.ClientTimeout(total=5.0) - async with aiohttp.ClientSession(timeout=timeout) as session: - try: - async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp: - results["internet_accessible"] = resp.status < 500 - except (ClientError, asyncio.TimeoutError, socket.gaierror): - results["is_local_issue"] = True - return results # cannot reach the internet – early exit - - # Now check API health endpoint - parsed = urlparse(target_url) - health_url = f"{parsed.scheme}://{parsed.netloc}/health" - try: - async with session.get(health_url, ssl=self.verify_ssl) as resp: - results["api_accessible"] = resp.status < 500 - except ClientError: - pass # leave as False - - results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] - return results - - async def request( - self, - method: str, - path: str, - params: Optional[dict[str, Any]] = None, - data: Optional[dict[str, Any]] = None, - files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None, - headers: Optional[dict[str, str]] = None, - content_type: str = "application/json", - multipart_parser: Callable | None = None, - retry_count: int = 0, # Used internally for tracking retries - ) -> dict[str, Any]: - """ - Make an HTTP request to the API with automatic retries for transient errors. - - Args: - method: HTTP method (GET, POST, etc.) - path: API endpoint path (will be joined with base_url) - params: Query parameters - data: body data - files: Files to upload - headers: Additional headers - content_type: Content type of the request. Defaults to application/json. 
- retry_count: Internal parameter for tracking retries, do not set manually - - Returns: - Parsed JSON response - - Raises: - LocalNetworkError: If local network connectivity issues are detected - ApiServerError: If the API server is unreachable but internet is working - Exception: For other request failures - """ - - # Build full URL and merge headers - relative_path = path.lstrip("/") - url = urljoin(self.base_url, relative_path) - self._check_auth(self.auth_token, self.comfy_api_key) - - request_headers = self.get_headers() - if headers: - request_headers.update(headers) - if files: - request_headers.pop("Content-Type", None) - if params: - params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values - - logging.debug("[DEBUG] Request Headers: %s", request_headers) - logging.debug("[DEBUG] Files: %s", files) - logging.debug("[DEBUG] Params: %s", params) - logging.debug("[DEBUG] Data: %s", data) - - if content_type == "application/x-www-form-urlencoded": - payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers) - elif content_type == "multipart/form-data": - payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser) - else: - payload_args = self._create_json_payload_args(data, request_headers) - - operation_id = self._generate_operation_id(path) - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - request_headers=request_headers, - request_params=params, - request_data=data if content_type == "application/json" else "[form-data or other]", - ) - - session = await self._get_session() - try: - async with session.request( - method, - url, - params=params, - ssl=self.verify_ssl, - **payload_args, - ) as resp: - if resp.status >= 400: - try: - error_data = await resp.json() - except (aiohttp.ContentTypeError, json.JSONDecodeError): - error_data = await resp.text() - - return await self._handle_http_error( - ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data), - operation_id, - method, - url, - params, - data, - files, - headers, - content_type, - multipart_parser, - retry_count=retry_count, - response_content=error_data, - ) - - # Success – parse JSON (safely) and log - try: - payload = await resp.json() - response_content_to_log = payload - except (aiohttp.ContentTypeError, json.JSONDecodeError): - payload = {} - response_content_to_log = await resp.text() - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content=response_content_to_log, - ) - return payload - - except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: - # Treat as *connection* problem – optionally retry, else escalate - if retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning("Connection error. 
Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1, - self.max_retries, str(e)) - await asyncio.sleep(delay) - return await self.request( - method, - path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - # One final connectivity check for diagnostics - connectivity = await self._check_connectivity(self.base_url) - if connectivity["is_local_issue"]: - raise LocalNetworkError( - "Unable to connect to the API server due to local network issues. " - "Please check your internet connection and try again." - ) from e - raise ApiServerError( - f"The API server at {self.base_url} is currently unreachable. " - f"The service may be experiencing issues. Please try again later." - ) from e - - @staticmethod - def _check_auth(auth_token, comfy_api_key): - """Verify that an auth token is present or comfy_api_key is present""" - if auth_token is None and comfy_api_key is None: - raise Exception("Unauthorized: Please login first to use this node.") - return auth_token or comfy_api_key - - @staticmethod - async def upload_file( - upload_url: str, - file: io.BytesIO | str, - content_type: str | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - ) -> aiohttp.ClientResponse: - """Upload a file to the API with retry logic. - - Args: - upload_url: The URL to upload to - file: Either a file path string, BytesIO object, or tuple of (file_path, filename) - content_type: Optional mime type to set for the upload - max_retries: Maximum number of retry attempts - retry_delay: Initial delay between retries in seconds - retry_backoff_factor: Multiplier for the delay after each retry - """ - headers: dict[str, str] = {} - skip_auto_headers: set[str] = set() - if content_type: - headers["Content-Type"] = content_type - else: - # tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status. 
- skip_auto_headers.add("Content-Type") - - # Extract file bytes - if isinstance(file, io.BytesIO): - file.seek(0) - data = file.read() - elif isinstance(file, str): - with open(file, "rb") as f: - data = f.read() - else: - raise ValueError("File must be BytesIO or str path") - - parsed = urlparse(upload_url) - basename = os.path.basename(parsed.path) or parsed.netloc or "upload" - operation_id = f"upload_{basename}_{uuid.uuid4().hex[:8]}" - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - request_headers=headers, - request_data=f"[File data {len(data)} bytes]", - ) - - delay = retry_delay - for attempt in range(max_retries + 1): - try: - timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.put( - upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers, - ) as resp: - resp.raise_for_status() - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=resp.status, - response_headers=dict(resp.headers), - response_content="File uploaded successfully.", - ) - return resp - except (ClientError, asyncio.TimeoutError) as e: - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", - request_url=upload_url, - response_status_code=e.status if hasattr(e, "status") else None, - response_headers=dict(e.headers) if hasattr(e, "headers") else None, # pylint: disable=no-member - response_content=None, - error_message=f"{type(e).__name__}: {str(e)}", - ) - if attempt < max_retries: - logging.warning( - "Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e) - ) - await asyncio.sleep(delay) - delay *= retry_backoff_factor - else: - raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e - - async def _handle_http_error( - self, - exc: ClientResponseError, - operation_id: str, - *req_meta, - retry_count: int, - response_content: dict | str = "", - ) -> dict[str, Any]: - status_code = exc.status - if status_code == 401: - user_friendly = "Unauthorized: Please login first to use this node." - elif status_code == 402: - user_friendly = "Payment Required: Please add credits to your account to use this node." - elif status_code == 409: - user_friendly = "There is a problem with your account. Please contact support@comfy.org." - elif status_code == 429: - user_friendly = "Rate Limit Exceeded: Please try again later." 
- else: - if isinstance(response_content, dict): - if "error" in response_content and "message" in response_content["error"]: - user_friendly = f"API Error: {response_content['error']['message']}" - if "type" in response_content["error"]: - user_friendly += f" (Type: {response_content['error']['type']})" - else: # Handle cases where error is just a JSON dict with unknown format - user_friendly = f"API Error: {json.dumps(response_content)}" - else: - if len(response_content) < 200: # Arbitrary limit for display - user_friendly = f"API Error (raw): {response_content}" - else: - user_friendly = f"API Error (raw, status {response_content})" - - request_logger.log_request_response( - operation_id=operation_id, - request_method=req_meta[0], - request_url=req_meta[1], - response_status_code=exc.status, - response_headers=dict(req_meta[5]) if req_meta[5] else None, - response_content=response_content, - error_message=f"HTTP Error {exc.status}", - ) - - logging.debug("[DEBUG] API Error: %s (Status: %s)", user_friendly, status_code) - if response_content: - logging.debug("[DEBUG] Response content: %s", response_content) - - # Retry if eligible - if status_code in self.retry_status_codes and retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - "HTTP error %s. Retrying in %.2fs (%s/%s)", - status_code, - delay, - retry_count + 1, - self.max_retries, - ) - await asyncio.sleep(delay) - return await self.request( - req_meta[0], # method - req_meta[1].replace(self.base_url, ""), # path - params=req_meta[2], - data=req_meta[3], - files=req_meta[4], - headers=req_meta[5], - content_type=req_meta[6], - multipart_parser=req_meta[7], - retry_count=retry_count + 1, - ) - - raise Exception(user_friendly) from exc - - @staticmethod - def _unpack_tuple(t): - """Helper to normalise (filename, file, content_type) tuples.""" - if len(t) == 3: - return t - elif len(t) == 2: - return t[0], t[1], "application/octet-stream" - else: - raise ValueError("files tuple must be (filename, file[, content_type])") - - async def _get_session(self) -> aiohttp.ClientSession: - if self._session is None or self._session.closed: - timeout = aiohttp.ClientTimeout(total=self.timeout) - self._session = aiohttp.ClientSession(timeout=timeout) - self._owns_session = True - return self._session - - async def close(self) -> None: - if self._owns_session and self._session and not self._session.closed: - await self._session.close() - - async def __aenter__(self) -> "ApiClient": - """Allow usage as async‑context‑manager – ensures clean teardown""" - return self - - async def __aexit__(self, exc_type, exc, tb): - await self.close() - - -class ApiEndpoint(Generic[T, R]): - """Defines an API endpoint with its request and response types""" - - def __init__( - self, - path: str, - method: HttpMethod, - request_model: Type[T], - response_model: Type[R], - query_params: Optional[dict[str, Any]] = None, - ): - """Initialize an API endpoint definition. - - Args: - path: The URL path for this endpoint, can include placeholders like {id} - method: The HTTP method to use (GET, POST, etc.) 
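# Editor's note: both retry paths above (connection errors in request() and retryable HTTP
# status codes in _handle_http_error) share the same exponential backoff schedule,
# delay = retry_delay * retry_backoff_factor ** retry_count. A standalone sketch of that
# schedule follows; the real client catches aiohttp ClientError / timeouts rather than
# ConnectionError, which is used here only as a stand-in.
import asyncio


async def example_with_backoff(attempt_fn, max_retries: int = 3,
                               retry_delay: float = 1.0, backoff_factor: float = 2.0):
    for retry_count in range(max_retries + 1):
        try:
            return await attempt_fn()
        except ConnectionError:
            if retry_count == max_retries:
                raise
            delay = retry_delay * (backoff_factor ** retry_count)  # 1s, 2s, 4s, ...
            await asyncio.sleep(delay)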
- request_model: Pydantic model class that defines the structure and validation rules for API requests to this endpoint - response_model: Pydantic model class that defines the structure and validation rules for API responses from this endpoint - query_params: Optional dictionary of query parameters to include in the request - """ - self.path = path - self.method = method - self.request_model = request_model - self.response_model = response_model - self.query_params = query_params or {} - - -class SynchronousOperation(Generic[T, R]): - """Represents a single synchronous API operation.""" - - def __init__( - self, - endpoint: ApiEndpoint[T, R], - request: T, - files: Optional[dict[str, Any] | list[tuple[str, Any]]] = None, - api_base: str | None = None, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[dict[str, str]] = None, - timeout: float = 7200.0, - verify_ssl: bool = True, - content_type: str = "application/json", - multipart_parser: Callable | None = None, - max_retries: int = 3, - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - ) -> None: - self.endpoint = endpoint - self.request = request - self.files = files - self.api_base: str = api_base or args.comfy_api_base - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - if auth_kwargs is not None: - self.auth_token = auth_kwargs.get("auth_token", self.auth_token) - self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) - self.timeout = timeout - self.verify_ssl = verify_ssl - self.content_type = content_type - self.multipart_parser = multipart_parser - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - - async def execute(self, client: Optional[ApiClient] = None) -> R: - owns_client = client is None - if owns_client: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - timeout=self.timeout, - verify_ssl=self.verify_ssl, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - - try: - request_dict: Optional[dict[str, Any]] - if isinstance(self.request, EmptyRequest): - request_dict = None - else: - request_dict = self.request.model_dump(exclude_none=True) - for k, v in list(request_dict.items()): - if isinstance(v, Enum): - request_dict[k] = v.value - - logging.debug("[DEBUG] API Request: %s %s", self.endpoint.method.value, self.endpoint.path) - logging.debug("[DEBUG] Request Data: %s", json.dumps(request_dict, indent=2)) - logging.debug("[DEBUG] Query Params: %s", self.endpoint.query_params) - - response_json = await client.request( - self.endpoint.method.value, - self.endpoint.path, - params=self.endpoint.query_params, - data=request_dict, - files=self.files, - content_type=self.content_type, - multipart_parser=self.multipart_parser, - ) - - logging.debug("=" * 50) - logging.debug("[DEBUG] RESPONSE DETAILS:") - logging.debug("[DEBUG] Status Code: 200 (Success)") - logging.debug("[DEBUG] Response Body: %s", json.dumps(response_json, indent=2)) - logging.debug("=" * 50) - - parsed_response = self.endpoint.response_model.model_validate(response_json) - logging.debug("[DEBUG] Parsed Response: %s", parsed_response) - return parsed_response - finally: - if owns_client: - await client.close() - - -class TaskStatus(str, Enum): - """Enum for task status values""" - - COMPLETED = "completed" - FAILED = "failed" - PENDING = "pending" - - -class 
PollingOperation(Generic[T, R]): - """Represents an asynchronous API operation that requires polling for completion.""" - - def __init__( - self, - poll_endpoint: ApiEndpoint[EmptyRequest, R], - completed_statuses: list[str], - failed_statuses: list[str], - *, - status_extractor: Callable[[R], Optional[str]], - progress_extractor: Callable[[R], Optional[float]] | None = None, - result_url_extractor: Callable[[R], Optional[str]] | None = None, - price_extractor: Callable[[R], Optional[float]] | None = None, - request: Optional[T] = None, - api_base: str | None = None, - auth_token: Optional[str] = None, - comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[dict[str, str]] = None, - poll_interval: float = 5.0, - max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval) - max_retries: int = 3, # Max retries per individual API call - retry_delay: float = 1.0, - retry_backoff_factor: float = 2.0, - estimated_duration: Optional[float] = None, - node_id: Optional[str] = None, - ) -> None: - self.poll_endpoint = poll_endpoint - self.request = request - self.api_base: str = api_base or args.comfy_api_base - self.auth_token = auth_token - self.comfy_api_key = comfy_api_key - if auth_kwargs is not None: - self.auth_token = auth_kwargs.get("auth_token", self.auth_token) - self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) - self.poll_interval = poll_interval - self.max_poll_attempts = max_poll_attempts - self.max_retries = max_retries - self.retry_delay = retry_delay - self.retry_backoff_factor = retry_backoff_factor - self.estimated_duration = estimated_duration - self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None)) - self.progress_extractor = progress_extractor - self.result_url_extractor = result_url_extractor - self.price_extractor = price_extractor - self.node_id = node_id - self.completed_statuses = completed_statuses - self.failed_statuses = failed_statuses - self.final_response: Optional[R] = None - self.extracted_price: Optional[float] = None - - async def execute(self, client: Optional[ApiClient] = None) -> R: - owns_client = client is None - if owns_client: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - try: - return await self._poll_until_complete(client) - finally: - if owns_client: - await client.close() - - def _display_text_on_node(self, text: str): - if not self.node_id: - return - if self.extracted_price is not None: - text = f"Price: ${self.extracted_price}\n{text}" - PromptServer.instance.send_progress_text(text, self.node_id) - - def _display_time_progress_on_node(self, time_completed: int | float): - if not self.node_id: - return - if self.estimated_duration is not None: - remaining = max(0, int(self.estimated_duration) - time_completed) - message = f"Task in progress: {time_completed}s (~{remaining}s remaining)" - else: - message = f"Task in progress: {time_completed}s" - self._display_text_on_node(message) - - def _check_task_status(self, response: R) -> TaskStatus: - try: - status = self.status_extractor(response) - if status in self.completed_statuses: - return TaskStatus.COMPLETED - if status in self.failed_statuses: - return TaskStatus.FAILED - return TaskStatus.PENDING - except Exception as e: - logging.error("Error extracting status: %s", e) - return TaskStatus.PENDING - - async def 
_poll_until_complete(self, client: ApiClient) -> R: - """Poll until the task is complete""" - consecutive_errors = 0 - max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors - - if self.progress_extractor: - progress = utils.ProgressBar(PROGRESS_BAR_MAX) - - status = TaskStatus.PENDING - for poll_count in range(1, self.max_poll_attempts + 1): - try: - logging.debug("[DEBUG] Polling attempt #%s", poll_count) - - request_dict = None if self.request is None else self.request.model_dump(exclude_none=True) - - if poll_count == 1: - logging.debug( - "[DEBUG] Poll Request: %s %s", - self.poll_endpoint.method.value, - self.poll_endpoint.path, - ) - logging.debug( - "[DEBUG] Poll Request Data: %s", - json.dumps(request_dict, indent=2) if request_dict else "None", - ) - - # Query task status - resp = await client.request( - self.poll_endpoint.method.value, - self.poll_endpoint.path, - params=self.poll_endpoint.query_params, - data=request_dict, - ) - consecutive_errors = 0 # reset on success - response_obj: R = self.poll_endpoint.response_model.model_validate(resp) - - # Check if task is complete - status = self._check_task_status(response_obj) - logging.debug("[DEBUG] Task Status: %s", status) - - # If progress extractor is provided, extract progress - if self.progress_extractor: - new_progress = self.progress_extractor(response_obj) - if new_progress is not None: - progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX) - - if self.price_extractor: - price = self.price_extractor(response_obj) - if price is not None: - self.extracted_price = price - - if status == TaskStatus.COMPLETED: - message = "Task completed successfully" - if self.result_url_extractor: - result_url = self.result_url_extractor(response_obj) - if result_url: - message = f"Result URL: {result_url}" - logging.debug("[DEBUG] %s", message) - self._display_text_on_node(message) - self.final_response = response_obj - if self.progress_extractor: - progress.update(100) - return self.final_response - if status == TaskStatus.FAILED: - message = f"Task failed: {json.dumps(resp)}" - logging.error("[DEBUG] %s", message) - raise Exception(message) - logging.debug("[DEBUG] Task still pending, continuing to poll...") - # Task pending – wait - for i in range(int(self.poll_interval)): - self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i) - await asyncio.sleep(1) - - except (LocalNetworkError, ApiServerError, NetworkError) as e: - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors: - raise Exception( - f"Polling aborted after {consecutive_errors} network errors: {str(e)}" - ) from e - logging.warning( - "Network error (%s/%s): %s", - consecutive_errors, - max_consecutive_errors, - str(e), - ) - await asyncio.sleep(self.poll_interval) - except Exception as e: - # For other errors, increment count and potentially abort - consecutive_errors += 1 - if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED: - raise Exception( - f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}" - ) from e - - logging.error("[DEBUG] Polling error: %s", str(e)) - logging.warning( - "Error during polling (attempt %s/%s): %s. 
Will retry in %s seconds.", - poll_count, - self.max_poll_attempts, - str(e), - self.poll_interval, - ) - await asyncio.sleep(self.poll_interval) - - # If we've exhausted all polling attempts - raise Exception( - f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). " - "The operation may still be running on the server but is taking longer than expected." - ) diff --git a/comfy_api_nodes/apis/gemini_api.py b/comfy_api_nodes/apis/gemini_api.py index 2bf28bf93..f8edc38c9 100644 --- a/comfy_api_nodes/apis/gemini_api.py +++ b/comfy_api_nodes/apis/gemini_api.py @@ -1,22 +1,228 @@ -from typing import Optional +from datetime import date +from enum import Enum +from typing import Any -from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata -from pydantic import BaseModel +from pydantic import BaseModel, Field + + +class GeminiSafetyCategory(str, Enum): + HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT" + HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH" + HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT" + HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT" + + +class GeminiSafetyThreshold(str, Enum): + OFF = "OFF" + BLOCK_NONE = "BLOCK_NONE" + BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE" + BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE" + BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH" + + +class GeminiSafetySetting(BaseModel): + category: GeminiSafetyCategory + threshold: GeminiSafetyThreshold + + +class GeminiRole(str, Enum): + user = "user" + model = "model" + + +class GeminiMimeType(str, Enum): + application_pdf = "application/pdf" + audio_mpeg = "audio/mpeg" + audio_mp3 = "audio/mp3" + audio_wav = "audio/wav" + image_png = "image/png" + image_jpeg = "image/jpeg" + image_webp = "image/webp" + text_plain = "text/plain" + video_mov = "video/mov" + video_mpeg = "video/mpeg" + video_mp4 = "video/mp4" + video_mpg = "video/mpg" + video_avi = "video/avi" + video_wmv = "video/wmv" + video_mpegps = "video/mpegps" + video_flv = "video/flv" + + +class GeminiInlineData(BaseModel): + data: str | None = Field( + None, + description="The base64 encoding of the image, PDF, or video to include inline in the prompt. " + "When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB", + ) + mimeType: GeminiMimeType | None = Field(None) + + +class GeminiFileData(BaseModel): + fileUri: str | None = Field(None) + mimeType: GeminiMimeType | None = Field(None) + + +class GeminiPart(BaseModel): + inlineData: GeminiInlineData | None = Field(None) + fileData: GeminiFileData | None = Field(None) + text: str | None = Field(None) + + +class GeminiTextPart(BaseModel): + text: str | None = Field(None) + + +class GeminiContent(BaseModel): + parts: list[GeminiPart] = Field([]) + role: GeminiRole = Field(..., examples=["user"]) + + +class GeminiSystemInstructionContent(BaseModel): + parts: list[GeminiTextPart] = Field( + ..., + description="A list of ordered parts that make up a single message. " + "Different parts may have different IANA MIME types.", + ) + role: GeminiRole | None = Field(..., description="The role field of systemInstruction may be ignored.") + + +class GeminiFunctionDeclaration(BaseModel): + description: str | None = Field(None) + name: str = Field(...) 
+ parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters") + + +class GeminiTool(BaseModel): + functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None) + + +class GeminiOffset(BaseModel): + nanos: int | None = Field(None, ge=0, le=999999999) + seconds: int | None = Field(None, ge=-315576000000, le=315576000000) + + +class GeminiVideoMetadata(BaseModel): + endOffset: GeminiOffset | None = Field(None) + startOffset: GeminiOffset | None = Field(None) + + +class GeminiGenerationConfig(BaseModel): + maxOutputTokens: int | None = Field(None, ge=16, le=8192) + seed: int | None = Field(None) + stopSequences: list[str] | None = Field(None) + temperature: float | None = Field(None, ge=0.0, le=2.0) + topK: int | None = Field(None, ge=1) + topP: float | None = Field(None, ge=0.0, le=1.0) class GeminiImageConfig(BaseModel): - aspectRatio: Optional[str] = None + aspectRatio: str | None = Field(None) + imageSize: str | None = Field(None) class GeminiImageGenerationConfig(GeminiGenerationConfig): - responseModalities: Optional[list[str]] = None - imageConfig: Optional[GeminiImageConfig] = None + responseModalities: list[str] | None = Field(None) + imageConfig: GeminiImageConfig | None = Field(None) class GeminiImageGenerateContentRequest(BaseModel): - contents: list[GeminiContent] - generationConfig: Optional[GeminiImageGenerationConfig] = None - safetySettings: Optional[list[GeminiSafetySetting]] = None - systemInstruction: Optional[GeminiSystemInstructionContent] = None - tools: Optional[list[GeminiTool]] = None - videoMetadata: Optional[GeminiVideoMetadata] = None + contents: list[GeminiContent] = Field(...) + generationConfig: GeminiImageGenerationConfig | None = Field(None) + safetySettings: list[GeminiSafetySetting] | None = Field(None) + systemInstruction: GeminiSystemInstructionContent | None = Field(None) + tools: list[GeminiTool] | None = Field(None) + videoMetadata: GeminiVideoMetadata | None = Field(None) + + +class GeminiGenerateContentRequest(BaseModel): + contents: list[GeminiContent] = Field(...) 
+ generationConfig: GeminiGenerationConfig | None = Field(None) + safetySettings: list[GeminiSafetySetting] | None = Field(None) + systemInstruction: GeminiSystemInstructionContent | None = Field(None) + tools: list[GeminiTool] | None = Field(None) + videoMetadata: GeminiVideoMetadata | None = Field(None) + + +class Modality(str, Enum): + MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED" + TEXT = "TEXT" + IMAGE = "IMAGE" + VIDEO = "VIDEO" + AUDIO = "AUDIO" + DOCUMENT = "DOCUMENT" + + +class ModalityTokenCount(BaseModel): + modality: Modality | None = None + tokenCount: int | None = Field(None, description="Number of tokens for the given modality.") + + +class Probability(str, Enum): + NEGLIGIBLE = "NEGLIGIBLE" + LOW = "LOW" + MEDIUM = "MEDIUM" + HIGH = "HIGH" + UNKNOWN = "UNKNOWN" + + +class GeminiSafetyRating(BaseModel): + category: GeminiSafetyCategory | None = None + probability: Probability | None = Field( + None, + description="The probability that the content violates the specified safety category", + ) + + +class GeminiCitation(BaseModel): + authors: list[str] | None = None + endIndex: int | None = None + license: str | None = None + publicationDate: date | None = None + startIndex: int | None = None + title: str | None = None + uri: str | None = None + + +class GeminiCitationMetadata(BaseModel): + citations: list[GeminiCitation] | None = None + + +class GeminiCandidate(BaseModel): + citationMetadata: GeminiCitationMetadata | None = None + content: GeminiContent | None = None + finishReason: str | None = None + safetyRatings: list[GeminiSafetyRating] | None = None + + +class GeminiPromptFeedback(BaseModel): + blockReason: str | None = None + blockReasonMessage: str | None = None + safetyRatings: list[GeminiSafetyRating] | None = None + + +class GeminiUsageMetadata(BaseModel): + cachedContentTokenCount: int | None = Field( + None, + description="Output only. Number of tokens in the cached part in the input (the cached content).", + ) + candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).") + candidatesTokensDetails: list[ModalityTokenCount] | None = Field( + None, description="Breakdown of candidate tokens by modality." + ) + promptTokenCount: int | None = Field( + None, + description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.", + ) + promptTokensDetails: list[ModalityTokenCount] | None = Field( + None, description="Breakdown of prompt tokens by modality." 
+ ) + thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.") + toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).") + + +class GeminiGenerateContentResponse(BaseModel): + candidates: list[GeminiCandidate] | None = Field(None) + promptFeedback: GeminiPromptFeedback | None = Field(None) + usageMetadata: GeminiUsageMetadata | None = Field(None) + modelVersion: str | None = Field(None) diff --git a/comfy_api_nodes/apis/kling_api.py b/comfy_api_nodes/apis/kling_api.py new file mode 100644 index 000000000..d8949f8ac --- /dev/null +++ b/comfy_api_nodes/apis/kling_api.py @@ -0,0 +1,86 @@ +from pydantic import BaseModel, Field + + +class OmniProText2VideoRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + aspect_ratio: str = Field(..., description="'16:9', '9:16' or '1:1'") + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + + +class OmniParamImage(BaseModel): + image_url: str = Field(...) + type: str | None = Field(None, description="Can be 'first_frame' or 'end_frame'") + + +class OmniParamVideo(BaseModel): + video_url: str = Field(...) + refer_type: str | None = Field(..., description="Can be 'base' or 'feature'") + keep_original_sound: str = Field(..., description="'yes' or 'no'") + + +class OmniProFirstLastFrameRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + image_list: list[OmniParamImage] = Field(..., min_length=1, max_length=7) + duration: str = Field(..., description="'5' or '10'") + prompt: str = Field(...) + mode: str = Field("pro") + + +class OmniProReferences2VideoRequest(BaseModel): + model_name: str = Field(..., description="kling-video-o1") + aspect_ratio: str | None = Field(..., description="'16:9', '9:16' or '1:1'") + image_list: list[OmniParamImage] | None = Field( + None, max_length=7, description="Max length 4 when video is present." + ) + video_list: list[OmniParamVideo] | None = Field(None, max_length=1) + duration: str | None = Field(..., description="From 3 to 10.") + prompt: str = Field(...) + mode: str = Field("pro") + + +class TaskStatusVideoResult(BaseModel): + duration: str | None = Field(None, description="Total video duration") + id: str | None = Field(None, description="Generated video ID") + url: str | None = Field(None, description="URL for generated video") + + +class TaskStatusImageResult(BaseModel): + index: int = Field(..., description="Image Number,0-9") + url: str = Field(..., description="URL for generated image") + + +class OmniTaskStatusResults(BaseModel): + videos: list[TaskStatusVideoResult] | None = Field(None) + images: list[TaskStatusImageResult] | None = Field(None) + + +class OmniTaskStatusResponseData(BaseModel): + created_at: int | None = Field(None, description="Task creation time") + updated_at: int | None = Field(None, description="Task update time") + task_status: str | None = None + task_status_msg: str | None = Field(None, description="Additional failure reason. 
Only for polling endpoint.") + task_id: str | None = Field(None, description="Task ID") + task_result: OmniTaskStatusResults | None = Field(None) + + +class OmniTaskStatusResponse(BaseModel): + code: int | None = Field(None, description="Error code") + message: str | None = Field(None, description="Error message") + request_id: str | None = Field(None, description="Request ID") + data: OmniTaskStatusResponseData | None = Field(None) + + +class OmniImageParamImage(BaseModel): + image: str = Field(...) + + +class OmniProImageRequest(BaseModel): + model_name: str = Field(..., description="kling-image-o1") + resolution: str = Field(..., description="'1k' or '2k'") + aspect_ratio: str | None = Field(...) + prompt: str = Field(...) + mode: str = Field("pro") + n: int | None = Field(1, le=9) + image_list: list[OmniImageParamImage] | None = Field(..., max_length=10) diff --git a/comfy_api_nodes/apis/minimax_api.py b/comfy_api_nodes/apis/minimax_api.py new file mode 100644 index 000000000..d747e177a --- /dev/null +++ b/comfy_api_nodes/apis/minimax_api.py @@ -0,0 +1,120 @@ +from enum import Enum +from typing import Optional + +from pydantic import BaseModel, Field + + +class MinimaxBaseResponse(BaseModel): + status_code: int = Field( + ..., + description='Status code. 0 indicates success, other values indicate errors.', + ) + status_msg: str = Field( + ..., description='Specific error details or success message.' + ) + + +class File(BaseModel): + bytes: Optional[int] = Field(None, description='File size in bytes') + created_at: Optional[int] = Field( + None, description='Unix timestamp when the file was created, in seconds' + ) + download_url: Optional[str] = Field( + None, description='The URL to download the video' + ) + backup_download_url: Optional[str] = Field( + None, description='The backup URL to download the video' + ) + + file_id: Optional[int] = Field(None, description='Unique identifier for the file') + filename: Optional[str] = Field(None, description='The name of the file') + purpose: Optional[str] = Field(None, description='The purpose of using the file') + + +class MinimaxFileRetrieveResponse(BaseModel): + base_resp: MinimaxBaseResponse + file: File + + +class MiniMaxModel(str, Enum): + T2V_01_Director = 'T2V-01-Director' + I2V_01_Director = 'I2V-01-Director' + S2V_01 = 'S2V-01' + I2V_01 = 'I2V-01' + I2V_01_live = 'I2V-01-live' + T2V_01 = 'T2V-01' + Hailuo_02 = 'MiniMax-Hailuo-02' + + +class Status6(str, Enum): + Queueing = 'Queueing' + Preparing = 'Preparing' + Processing = 'Processing' + Success = 'Success' + Fail = 'Fail' + + +class MinimaxTaskResultResponse(BaseModel): + base_resp: MinimaxBaseResponse + file_id: Optional[str] = Field( + None, + description='After the task status changes to Success, this field returns the file ID corresponding to the generated video.', + ) + status: Status6 = Field( + ..., + description="Task status: 'Queueing' (in queue), 'Preparing' (task is preparing), 'Processing' (generating), 'Success' (task completed successfully), or 'Fail' (task failed).", + ) + task_id: str = Field(..., description='The task ID being queried.') + + +class SubjectReferenceItem(BaseModel): + image: Optional[str] = Field( + None, description='URL or base64 encoding of the subject reference image.' + ) + mask: Optional[str] = Field( + None, + description='URL or base64 encoding of the mask for the subject reference image.', + ) + + +class MinimaxVideoGenerationRequest(BaseModel): + callback_url: Optional[str] = Field( + None, + description='Optional. 
URL to receive real-time status updates about the video generation task.', + ) + first_frame_image: Optional[str] = Field( + None, + description='URL or base64 encoding of the first frame image. Required when model is I2V-01, I2V-01-Director, or I2V-01-live.', + ) + model: MiniMaxModel = Field( + ..., + description='Required. ID of model. Options: T2V-01-Director, I2V-01-Director, S2V-01, I2V-01, I2V-01-live, T2V-01', + ) + prompt: Optional[str] = Field( + None, + description='Description of the video. Should be less than 2000 characters. Supports camera movement instructions in [brackets].', + max_length=2000, + ) + prompt_optimizer: Optional[bool] = Field( + True, + description='If true (default), the model will automatically optimize the prompt. Set to false for more precise control.', + ) + subject_reference: Optional[list[SubjectReferenceItem]] = Field( + None, + description='Only available when model is S2V-01. The model will generate a video based on the subject uploaded through this parameter.', + ) + duration: Optional[int] = Field( + None, + description="The length of the output video in seconds." + ) + resolution: Optional[str] = Field( + None, + description="The dimensions of the video display. 1080p corresponds to 1920 x 1080 pixels, 768p corresponds to 1366 x 768 pixels." + ) + + +class MinimaxVideoGenerationResponse(BaseModel): + base_resp: MinimaxBaseResponse + task_id: str = Field( + ..., description='The task ID for the asynchronous video generation task.' + ) diff --git a/comfy_api_nodes/apis/pika_defs.py b/comfy_api_nodes/apis/pika_api.py similarity index 100% rename from comfy_api_nodes/apis/pika_defs.py rename to comfy_api_nodes/apis/pika_api.py diff --git a/comfy_api_nodes/apis/topaz_api.py b/comfy_api_nodes/apis/topaz_api.py new file mode 100644 index 000000000..4d9e62e72 --- /dev/null +++ b/comfy_api_nodes/apis/topaz_api.py @@ -0,0 +1,133 @@ +from typing import Optional, Union + +from pydantic import BaseModel, Field + + +class ImageEnhanceRequest(BaseModel): + model: str = Field("Reimagine") + output_format: str = Field("jpeg") + subject_detection: str = Field("All") + face_enhancement: bool = Field(True) + face_enhancement_creativity: float = Field(0, description="Is ignored if face_enhancement is false") + face_enhancement_strength: float = Field(0.8, description="Is ignored if face_enhancement is false") + source_url: str = Field(...) + output_width: Optional[int] = Field(None) + output_height: Optional[int] = Field(None) + crop_to_fill: bool = Field(False) + prompt: Optional[str] = Field(None, description="Text prompt for creative upscaling guidance") + creativity: int = Field(3, description="Creativity settings range from 1 to 9") + face_preservation: str = Field("true", description="To preserve the identity of characters") + color_preservation: str = Field("true", description="To preserve the original color") + + +class ImageAsyncTaskResponse(BaseModel): + process_id: str = Field(...) + + +class ImageStatusResponse(BaseModel): + process_id: str = Field(...) + status: str = Field(...) + progress: Optional[int] = Field(None) + credits: int = Field(...) + + +class ImageDownloadResponse(BaseModel): + download_url: str = Field(...) + expiry: int = Field(...) + + +class Resolution(BaseModel): + width: int = Field(...) + height: int = Field(...) + + +class CreateCreateVideoRequestSource(BaseModel): + container: str = Field(...) 
+ size: int = Field(..., description="Size of the video file in bytes") + duration: int = Field(..., description="Duration of the video file in seconds") + frameCount: int = Field(..., description="Total number of frames in the video") + frameRate: int = Field(...) + resolution: Resolution = Field(...) + + +class VideoFrameInterpolationFilter(BaseModel): + model: str = Field(...) + slowmo: Optional[int] = Field(None) + fps: int = Field(...) + duplicate: bool = Field(...) + duplicate_threshold: float = Field(...) + + +class VideoEnhancementFilter(BaseModel): + model: str = Field(...) + auto: Optional[str] = Field(None, description="Auto, Manual, Relative") + focusFixLevel: Optional[str] = Field(None, description="Downscales video input for correction of blurred subjects") + compression: Optional[float] = Field(None, description="Strength of compression recovery") + details: Optional[float] = Field(None, description="Amount of detail reconstruction") + prenoise: Optional[float] = Field(None, description="Amount of noise to add to input to reduce over-smoothing") + noise: Optional[float] = Field(None, description="Amount of noise reduction") + halo: Optional[float] = Field(None, description="Amount of halo reduction") + preblur: Optional[float] = Field(None, description="Anti-aliasing and deblurring strength") + blur: Optional[float] = Field(None, description="Amount of sharpness applied") + grain: Optional[float] = Field(None, description="Grain after AI model processing") + grainSize: Optional[float] = Field(None, description="Size of generated grain") + recoverOriginalDetailValue: Optional[float] = Field(None, description="Source details into the output video") + creativity: Optional[str] = Field(None, description="Creativity level(high, low) for slc-1 only") + isOptimizedMode: Optional[bool] = Field(None, description="Set to true for Starlight Creative (slc-1) only") + + +class OutputInformationVideo(BaseModel): + resolution: Resolution = Field(...) + frameRate: int = Field(...) + audioCodec: Optional[str] = Field(..., description="Required if audioTransfer is Copy or Convert") + audioTransfer: str = Field(..., description="Copy, Convert, None") + dynamicCompressionLevel: str = Field(..., description="Low, Mid, High") + + +class Overrides(BaseModel): + isPaidDiffusion: bool = Field(True) + + +class CreateVideoRequest(BaseModel): + source: CreateCreateVideoRequestSource = Field(...) + filters: list[Union[VideoFrameInterpolationFilter, VideoEnhancementFilter]] = Field(...) + output: OutputInformationVideo = Field(...) + overrides: Overrides = Field(Overrides(isPaidDiffusion=True)) + + +class CreateVideoResponse(BaseModel): + requestId: str = Field(...) + + +class VideoAcceptResponse(BaseModel): + uploadId: str = Field(...) + urls: list[str] = Field(...) + + +class VideoCompleteUploadRequestPart(BaseModel): + partNum: int = Field(...) + eTag: str = Field(...) + + +class VideoCompleteUploadRequest(BaseModel): + uploadResults: list[VideoCompleteUploadRequestPart] = Field(...) + + +class VideoCompleteUploadResponse(BaseModel): + message: str = Field(..., description="Confirmation message") + + +class VideoStatusResponseEstimates(BaseModel): + cost: list[int] = Field(...) + + +class VideoStatusResponseDownloadUrl(BaseModel): + url: str = Field(...) + + +class VideoStatusResponse(BaseModel): + status: str = Field(...) 
+ estimates: Optional[VideoStatusResponseEstimates] = Field(None) + progress: Optional[float] = Field(None) + message: Optional[str] = Field("") + download: Optional[VideoStatusResponseDownloadUrl] = Field(None) diff --git a/comfy_api_nodes/apis/tripo_api.py b/comfy_api_nodes/apis/tripo_api.py index 9f43d4d09..713260e2a 100644 --- a/comfy_api_nodes/apis/tripo_api.py +++ b/comfy_api_nodes/apis/tripo_api.py @@ -1,13 +1,20 @@ from __future__ import annotations -from comfy_api_nodes.apis import ( - TripoModelVersion, - TripoTextureQuality, -) from enum import Enum from typing import Optional, List, Dict, Any, Union from pydantic import BaseModel, Field, RootModel +class TripoModelVersion(str, Enum): + v2_5_20250123 = 'v2.5-20250123' + v2_0_20240919 = 'v2.0-20240919' + v1_4_20240625 = 'v1.4-20240625' + + +class TripoTextureQuality(str, Enum): + standard = 'standard' + detailed = 'detailed' + + class TripoStyle(str, Enum): PERSON_TO_CARTOON = "person:person2cartoon" ANIMAL_VENOM = "animal:venom" diff --git a/comfy_api_nodes/apis/veo_api.py b/comfy_api_nodes/apis/veo_api.py new file mode 100644 index 000000000..23ca725b7 --- /dev/null +++ b/comfy_api_nodes/apis/veo_api.py @@ -0,0 +1,99 @@ +from typing import Optional + +from pydantic import BaseModel, Field + + +class VeoRequestInstanceImage(BaseModel): + bytesBase64Encoded: str | None = Field(None) + gcsUri: str | None = Field(None) + mimeType: str | None = Field(None) + + +class VeoRequestInstance(BaseModel): + image: VeoRequestInstanceImage | None = Field(None) + lastFrame: VeoRequestInstanceImage | None = Field(None) + prompt: str = Field(..., description='Text description of the video') + + +class VeoRequestParameters(BaseModel): + aspectRatio: Optional[str] = Field(None, examples=['16:9']) + durationSeconds: Optional[int] = None + enhancePrompt: Optional[bool] = None + generateAudio: Optional[bool] = Field( + None, + description='Generate audio for the video. 
Only supported by veo 3 models.', + ) + negativePrompt: Optional[str] = None + personGeneration: str | None = Field(None, description="ALLOW or BLOCK") + sampleCount: Optional[int] = None + seed: Optional[int] = None + storageUri: Optional[str] = Field( + None, description='Optional Cloud Storage URI to upload the video' + ) + resolution: str | None = Field(None) + + +class VeoGenVidRequest(BaseModel): + instances: list[VeoRequestInstance] | None = Field(None) + parameters: VeoRequestParameters | None = Field(None) + + +class VeoGenVidResponse(BaseModel): + name: str = Field( + ..., + description='Operation resource name', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8' + ], + ) + + +class VeoGenVidPollRequest(BaseModel): + operationName: str = Field( + ..., + description='Full operation name (from predict response)', + examples=[ + 'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID' + ], + ) + + +class Video(BaseModel): + bytesBase64Encoded: Optional[str] = Field( + None, description='Base64-encoded video content' + ) + gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video') + mimeType: Optional[str] = Field(None, description='Video MIME type') + + +class Error1(BaseModel): + code: Optional[int] = Field(None, description='Error code') + message: Optional[str] = Field(None, description='Error message') + + +class Response1(BaseModel): + field_type: Optional[str] = Field( + None, + alias='@type', + examples=[ + 'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse' + ], + ) + raiMediaFilteredCount: Optional[int] = Field( + None, description='Count of media filtered by responsible AI policies' + ) + raiMediaFilteredReasons: Optional[list[str]] = Field( + None, description='Reasons why media was filtered by responsible AI policies' + ) + videos: Optional[list[Video]] = Field(None) + + +class VeoGenVidPollResponse(BaseModel): + done: Optional[bool] = None + error: Optional[Error1] = Field( + None, description='Error details if operation failed' + ) + name: Optional[str] = None + response: Optional[Response1] = Field( + None, description='The actual prediction response if done is true' + ) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index 10755a9b2..8826dea0c 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,146 +1,47 @@ -import asyncio -import io from inspect import cleandoc -from typing import Union, Optional + +import torch +from pydantic import BaseModel from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO + +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.bfl_api import ( - BFLStatus, BFLFluxExpandImageRequest, BFLFluxFillImageRequest, - BFLFluxCannyImageRequest, - BFLFluxDepthImageRequest, - BFLFluxProGenerateRequest, BFLFluxKontextProGenerateRequest, - BFLFluxProUltraGenerateRequest, BFLFluxProGenerateResponse, + BFLFluxProUltraGenerateRequest, + BFLFluxStatusResponse, + BFLStatus, + Flux2ProGenerateRequest, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, -) -from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, - validate_aspect_ratio, - process_image_response, + download_url_to_image_tensor, + get_number_of_images, + poll_op, resize_mask_to_image, + sync_op, + 
tensor_to_base64_string, + validate_aspect_ratio_string, validate_string, ) -import numpy as np -from PIL import Image -import aiohttp -import torch -import base64 -import time -from comfy.cmd.server import PromptServer - def convert_mask_to_image(mask: torch.Tensor): """ Make mask have the expected amount of dims (4) and channels (3) to be recognized as an image. """ mask = mask.unsqueeze(-1) - mask = torch.cat([mask]*3, dim=-1) + mask = torch.cat([mask] * 3, dim=-1) return mask -async def handle_bfl_synchronous_operation( - operation: SynchronousOperation, - timeout_bfl_calls=360, - node_id: Union[str, None] = None, -): - response_api: BFLFluxProGenerateResponse = await operation.execute() - return await _poll_until_generated( - response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id - ) - - -async def _poll_until_generated( - polling_url: str, timeout=360, node_id: Union[str, None] = None -): - # used bfl-comfy-nodes to verify code implementation: - # https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main - start_time = time.time() - retries_404 = 0 - max_retries_404 = 5 - retry_404_seconds = 2 - retry_202_seconds = 2 - retry_pending_seconds = 1 - - async with aiohttp.ClientSession() as session: - # NOTE: should True loop be replaced with checking if workflow has been interrupted? - while True: - if node_id: - time_elapsed = time.time() - start_time - PromptServer.instance.send_progress_text( - f"Generating ({time_elapsed:.0f}s)", node_id - ) - - async with session.get(polling_url) as response: - if response.status == 200: - result = await response.json() - if result["status"] == BFLStatus.ready: - img_url = result["result"]["sample"] - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {img_url}", node_id - ) - async with session.get(img_url) as img_resp: - return process_image_response(await img_resp.content.read()) - elif result["status"] in [ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - ]: - status = result["status"] - raise Exception( - f"BFL API did not return an image due to: {status}." - ) - elif result["status"] == BFLStatus.error: - raise Exception(f"BFL API encountered an error: {result}.") - elif result["status"] == BFLStatus.pending: - await asyncio.sleep(retry_pending_seconds) - continue - elif response.status == 404: - if retries_404 < max_retries_404: - retries_404 += 1 - await asyncio.sleep(retry_404_seconds) - continue - raise Exception( - f"BFL API could not find task after {max_retries_404} tries." - ) - elif response.status == 202: - await asyncio.sleep(retry_202_seconds) - elif time.time() - start_time > timeout: - raise Exception( - f"BFL API experienced a timeout; could not return request under {timeout} seconds." - ) - else: - raise Exception(f"BFL API encountered an error: {response.json()}") - -def convert_image_to_base64(image: torch.Tensor): - scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048) - # remove batch dimension if present - if len(scaled_image.shape) > 3: - scaled_image = scaled_image[0] - image_np = (scaled_image.numpy() * 255).astype(np.uint8) - img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() - img.save(img_byte_arr, format="PNG") - return base64.b64encode(img_byte_arr.getvalue()).decode() - - class FluxProUltraImageNode(IO.ComfyNode): """ Generates images using Flux Pro 1.1 Ultra via api based on prompt and resolution. 
""" - MINIMUM_RATIO = 1 / 4 - MAXIMUM_RATIO = 4 / 1 - MINIMUM_RATIO_STR = "1:4" - MAXIMUM_RATIO_STR = "4:1" - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( @@ -158,7 +59,9 @@ class FluxProUltraImageNode(IO.ComfyNode): IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Int.Input( "seed", @@ -203,16 +106,7 @@ class FluxProUltraImageNode(IO.ComfyNode): @classmethod def validate_inputs(cls, aspect_ratio: str): - try: - validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ) - except Exception as e: - return str(e) + validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1)) return True @classmethod @@ -220,49 +114,44 @@ class FluxProUltraImageNode(IO.ComfyNode): cls, prompt: str, aspect_ratio: str, - prompt_upsampling=False, - raw=False, - seed=0, - image_prompt=None, - image_prompt_strength=0.1, + prompt_upsampling: bool = False, + raw: bool = False, + seed: int = 0, + image_prompt: torch.Tensor | None = None, + image_prompt_strength: float = 0.1, ) -> IO.NodeOutput: if image_prompt is None: validate_string(prompt, strip_whitespace=False) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.1-ultra/generate", - method=HttpMethod.POST, - request_model=BFLFluxProUltraGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxProUltraGenerateRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.1-ultra/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxProUltraGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, seed=seed, - aspect_ratio=validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ), + aspect_ratio=aspect_ratio, raw=raw, - image_prompt=( - image_prompt - if image_prompt is None - else convert_image_to_base64(image_prompt) - ), - image_prompt_strength=( - None if image_prompt is None else round(image_prompt_strength, 2) - ), + image_prompt=(image_prompt if image_prompt is None else tensor_to_base64_string(image_prompt)), + image_prompt_strength=(None if image_prompt is None else round(image_prompt_strength, 2)), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return 
IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxKontextProImageNode(IO.ComfyNode): @@ -270,11 +159,6 @@ class FluxKontextProImageNode(IO.ComfyNode): Edits images using Flux.1 Kontext [pro] via api based on prompt and aspect ratio. """ - MINIMUM_RATIO = 1 / 4 - MAXIMUM_RATIO = 4 / 1 - MINIMUM_RATIO_STR = "1:4" - MAXIMUM_RATIO_STR = "4:1" - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( @@ -347,46 +231,43 @@ class FluxKontextProImageNode(IO.ComfyNode): aspect_ratio: str, guidance: float, steps: int, - input_image: Optional[torch.Tensor]=None, + input_image: torch.Tensor | None = None, seed=0, prompt_upsampling=False, ) -> IO.NodeOutput: - aspect_ratio = validate_aspect_ratio( - aspect_ratio, - minimum_ratio=cls.MINIMUM_RATIO, - maximum_ratio=cls.MAXIMUM_RATIO, - minimum_ratio_str=cls.MINIMUM_RATIO_STR, - maximum_ratio_str=cls.MAXIMUM_RATIO_STR, - ) + validate_aspect_ratio_string(aspect_ratio, (1, 4), (4, 1)) if input_image is None: validate_string(prompt, strip_whitespace=False) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=cls.BFL_PATH, - method=HttpMethod.POST, - request_model=BFLFluxKontextProGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxKontextProGenerateRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path=cls.BFL_PATH, method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxKontextProGenerateRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, guidance=round(guidance, 1), steps=steps, seed=seed, aspect_ratio=aspect_ratio, - input_image=( - input_image - if input_image is None - else convert_image_to_base64(input_image) - ) + input_image=(input_image if input_image is None else tensor_to_base64_string(input_image)), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxKontextMaxImageNode(FluxKontextProImageNode): @@ -400,117 +281,6 @@ class FluxKontextMaxImageNode(FluxKontextProImageNode): DISPLAY_NAME = "Flux.1 Kontext [max] Image" -class FluxProImageNode(IO.ComfyNode): - """ - Generates images synchronously based on prompt and resolution. - """ - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="FluxProImageNode", - display_name="Flux 1.1 [pro] Image", - category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), - inputs=[ - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Int.Input( - "width", - default=1024, - min=256, - max=1440, - step=32, - ), - IO.Int.Input( - "height", - default=768, - min=256, - max=1440, - step=32, - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), - IO.Image.Input( - "image_prompt", - optional=True, - ), - # "image_prompt_strength": ( - # IO.FLOAT, - # { - # "default": 0.1, - # "min": 0.0, - # "max": 1.0, - # "step": 0.01, - # "tooltip": "Blend between the prompt and the image prompt.", - # }, - # ), - ], - outputs=[IO.Image.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - prompt: str, - prompt_upsampling, - width: int, - height: int, - seed=0, - image_prompt=None, - # image_prompt_strength=0.1, - ) -> IO.NodeOutput: - image_prompt = ( - image_prompt - if image_prompt is None - else convert_image_to_base64(image_prompt) - ) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.1/generate", - method=HttpMethod.POST, - request_model=BFLFluxProGenerateRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxProGenerateRequest( - prompt=prompt, - prompt_upsampling=prompt_upsampling, - width=width, - height=height, - seed=seed, - image_prompt=image_prompt, - ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - - class FluxProExpandNode(IO.ComfyNode): """ Outpaints image based on prompt. @@ -534,7 +304,9 @@ class FluxProExpandNode(IO.ComfyNode): IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. 
" + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Int.Input( "top", @@ -610,16 +382,11 @@ class FluxProExpandNode(IO.ComfyNode): guidance: float, seed=0, ) -> IO.NodeOutput: - image = convert_image_to_base64(image) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-expand/generate", - method=HttpMethod.POST, - request_model=BFLFluxExpandImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxExpandImageRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-expand/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxExpandImageRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, top=top, @@ -629,16 +396,25 @@ class FluxProExpandNode(IO.ComfyNode): steps=steps, guidance=guidance, seed=seed, - image=image, + image=tensor_to_base64_string(image), ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class FluxProFillNode(IO.ComfyNode): @@ -665,7 +441,9 @@ class FluxProFillNode(IO.ComfyNode): IO.Boolean.Input( "prompt_upsampling", default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", + tooltip="Whether to perform upsampling on the prompt. 
" + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", ), IO.Float.Input( "guidance", @@ -712,94 +490,68 @@ class FluxProFillNode(IO.ComfyNode): ) -> IO.NodeOutput: # prepare mask mask = resize_mask_to_image(mask, image) - mask = convert_image_to_base64(convert_mask_to_image(mask)) - # make sure image will have alpha channel removed - image = convert_image_to_base64(image[:, :, :, :3]) - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-fill/generate", - method=HttpMethod.POST, - request_model=BFLFluxFillImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxFillImageRequest( + mask = tensor_to_base64_string(convert_mask_to_image(mask)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-pro-1.0-fill/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + data=BFLFluxFillImageRequest( prompt=prompt, prompt_upsampling=prompt_upsampling, steps=steps, guidance=guidance, seed=seed, - image=image, + image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed mask=mask, ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, + progress_extractor=lambda r: r.progress, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) -class FluxProCannyNode(IO.ComfyNode): - """ - Generate image using a control image (canny). - """ +class Flux2ProImageNode(IO.ComfyNode): @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( - node_id="FluxProCannyNode", - display_name="Flux.1 Canny Control Image", + node_id="Flux2ProImageNode", + display_name="Flux.2 [pro] Image", category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously based on prompt and resolution.", inputs=[ - IO.Image.Input("control_image"), IO.String.Input( "prompt", multiline=True, default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. 
If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Float.Input( - "canny_low_threshold", - default=0.1, - min=0.01, - max=0.99, - step=0.01, - tooltip="Low threshold for Canny edge detection; ignored if skip_processing is True", - ), - IO.Float.Input( - "canny_high_threshold", - default=0.4, - min=0.01, - max=0.99, - step=0.01, - tooltip="High threshold for Canny edge detection; ignored if skip_processing is True", - ), - IO.Boolean.Input( - "skip_preprocessing", - default=False, - tooltip="Whether to skip preprocessing; set to True if control_image already is canny-fied, False if it is a raw image.", - ), - IO.Float.Input( - "guidance", - default=30, - min=1, - max=100, - tooltip="Guidance strength for the image generation process", + tooltip="Prompt for the image generation or edit", ), IO.Int.Input( - "steps", - default=50, - min=15, - max=50, - tooltip="Number of steps for the image generation process", + "width", + default=1024, + min=256, + max=2048, + step=32, + ), + IO.Int.Input( + "height", + default=768, + min=256, + max=2048, + step=32, ), IO.Int.Input( "seed", @@ -809,6 +561,14 @@ class FluxProCannyNode(IO.ComfyNode): control_after_generate=True, tooltip="The random seed used for creating the noise.", ), + IO.Boolean.Input( + "prompt_upsampling", + default=False, + tooltip="Whether to perform upsampling on the prompt. " + "If active, automatically modifies the prompt for more creative generation, " + "but results are nondeterministic (same seed will not produce exactly the same result).", + ), + IO.Image.Input("images", optional=True, tooltip="Up to 4 images to be used as references."), ], outputs=[IO.Image.Output()], hidden=[ @@ -822,162 +582,54 @@ class FluxProCannyNode(IO.ComfyNode): @classmethod async def execute( cls, - control_image: torch.Tensor, prompt: str, + width: int, + height: int, + seed: int, prompt_upsampling: bool, - canny_low_threshold: float, - canny_high_threshold: float, - skip_preprocessing: bool, - steps: int, - guidance: float, - seed=0, + images: torch.Tensor | None = None, ) -> IO.NodeOutput: - control_image = convert_image_to_base64(control_image[:, :, :, :3]) - preprocessed_image = None - - # scale canny threshold between 0-500, to match BFL's API - def scale_value(value: float, min_val=0, max_val=500): - return min_val + value * (max_val - min_val) - canny_low_threshold = int(round(scale_value(canny_low_threshold))) - canny_high_threshold = int(round(scale_value(canny_high_threshold))) - - - if skip_preprocessing: - preprocessed_image = control_image - control_image = None - canny_low_threshold = None - canny_high_threshold = None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-canny/generate", - method=HttpMethod.POST, - request_model=BFLFluxCannyImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxCannyImageRequest( + reference_images = {} + if images is not None: + if get_number_of_images(images) > 9: + raise ValueError("The current maximum number of supported images is 9.") + for image_index in range(images.shape[0]): + key_name = f"input_image_{image_index + 1}" if image_index else "input_image" + reference_images[key_name] = tensor_to_base64_string(images[image_index], total_pixels=2048 * 2048) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/bfl/flux-2-pro/generate", method="POST"), + response_model=BFLFluxProGenerateResponse, + 
data=Flux2ProGenerateRequest( prompt=prompt, - prompt_upsampling=prompt_upsampling, - steps=steps, - guidance=guidance, + width=width, + height=height, seed=seed, - control_image=control_image, - canny_low_threshold=canny_low_threshold, - canny_high_threshold=canny_high_threshold, - preprocessed_image=preprocessed_image, - ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) - - -class FluxProDepthNode(IO.ComfyNode): - """ - Generate image using a control image (depth). - """ - - @classmethod - def define_schema(cls) -> IO.Schema: - return IO.Schema( - node_id="FluxProDepthNode", - display_name="Flux.1 Depth Control Image", - category="api node/image/BFL", - description=cleandoc(cls.__doc__ or ""), - inputs=[ - IO.Image.Input("control_image"), - IO.String.Input( - "prompt", - multiline=True, - default="", - tooltip="Prompt for the image generation", - ), - IO.Boolean.Input( - "prompt_upsampling", - default=False, - tooltip="Whether to perform upsampling on the prompt. If active, automatically modifies the prompt for more creative generation, but results are nondeterministic (same seed will not produce exactly the same result).", - ), - IO.Boolean.Input( - "skip_preprocessing", - default=False, - tooltip="Whether to skip preprocessing; set to True if control_image already is depth-ified, False if it is a raw image.", - ), - IO.Float.Input( - "guidance", - default=15, - min=1, - max=100, - tooltip="Guidance strength for the image generation process", - ), - IO.Int.Input( - "steps", - default=50, - min=15, - max=50, - tooltip="Number of steps for the image generation process", - ), - IO.Int.Input( - "seed", - default=0, - min=0, - max=0xFFFFFFFFFFFFFFFF, - control_after_generate=True, - tooltip="The random seed used for creating the noise.", - ), - ], - outputs=[IO.Image.Output()], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], - is_api_node=True, - ) - - @classmethod - async def execute( - cls, - control_image: torch.Tensor, - prompt: str, - prompt_upsampling: bool, - skip_preprocessing: bool, - steps: int, - guidance: float, - seed=0, - ) -> IO.NodeOutput: - control_image = convert_image_to_base64(control_image[:,:,:,:3]) - preprocessed_image = None - - if skip_preprocessing: - preprocessed_image = control_image - control_image = None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/bfl/flux-pro-1.0-depth/generate", - method=HttpMethod.POST, - request_model=BFLFluxDepthImageRequest, - response_model=BFLFluxProGenerateResponse, - ), - request=BFLFluxDepthImageRequest( - prompt=prompt, prompt_upsampling=prompt_upsampling, - steps=steps, - guidance=guidance, - seed=seed, - control_image=control_image, - preprocessed_image=preprocessed_image, + **reference_images, ), - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - output_image = await handle_bfl_synchronous_operation(operation, node_id=cls.hidden.unique_id) - return IO.NodeOutput(output_image) + + def price_extractor(_r: BaseModel) -> float | None: + return None if initial_response.cost is None else initial_response.cost / 100 + + response = await poll_op( + cls, + ApiEndpoint(initial_response.polling_url), + response_model=BFLFluxStatusResponse, + status_extractor=lambda r: r.status, 
+ progress_extractor=lambda r: r.progress, + price_extractor=price_extractor, + completed_statuses=[BFLStatus.ready], + failed_statuses=[ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + BFLStatus.error, + BFLStatus.task_not_found, + ], + queued_statuses=[], + ) + return IO.NodeOutput(await download_url_to_image_tensor(response.result["sample"])) class BFLExtension(ComfyExtension): @@ -985,13 +637,11 @@ class BFLExtension(ComfyExtension): async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ FluxProUltraImageNode, - # FluxProImageNode, FluxKontextProImageNode, FluxKontextMaxImageNode, FluxProExpandNode, FluxProFillNode, - FluxProCannyNode, - FluxProDepthNode, + Flux2ProImageNode, ] diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index f3d3f8d3e..57c0218d0 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -1,35 +1,41 @@ import logging import math -from enum import Enum -from typing import Literal, Optional, Type, Union -from typing_extensions import override import torch -from pydantic import BaseModel, Field +from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import ( - validate_image_aspect_ratio_range, - get_number_of_images, - validate_image_dimensions, +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis.bytedance_api import ( + RECOMMENDED_PRESETS, + RECOMMENDED_PRESETS_SEEDREAM_4, + VIDEO_TASKS_EXECUTION_TIME, + Image2ImageTaskCreationRequest, + Image2VideoTaskCreationRequest, + ImageTaskCreationResponse, + Seedream4Options, + Seedream4TaskCreationRequest, + TaskCreationResponse, + TaskImageContent, + TaskImageContentUrl, + TaskStatusResponse, + TaskTextContent, + Text2ImageTaskCreationRequest, + Text2VideoTaskCreationRequest, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - EmptyRequest, - HttpMethod, - SynchronousOperation, - PollingOperation, - T, -) -from comfy_api_nodes.apinode_utils import ( download_url_to_image_tensor, download_url_to_video_output, - upload_images_to_comfyapi, - validate_string, + get_number_of_images, image_tensor_pair_to_batch, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_image_aspect_ratio, + validate_image_dimensions, + validate_string, ) - BYTEPLUS_IMAGE_ENDPOINT = "/proxy/byteplus/api/v3/images/generations" # Long-running tasks endpoints(e.g., video) @@ -37,161 +43,6 @@ BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} -class Text2ImageModelName(str, Enum): - seedream_3 = "seedream-3-0-t2i-250415" - - -class Image2ImageModelName(str, Enum): - seededit_3 = "seededit-3-0-i2i-250628" - - -class Text2VideoModelName(str, Enum): - seedance_1_pro = "seedance-1-0-pro-250528" - seedance_1_lite = "seedance-1-0-lite-t2v-250428" - - -class Image2VideoModelName(str, Enum): - """note(August 31): Pro model only supports FirstFrame: https://docs.byteplus.com/en/docs/ModelArk/1520757""" - seedance_1_pro = "seedance-1-0-pro-250528" - seedance_1_lite = "seedance-1-0-lite-i2v-250428" - - -class Text2ImageTaskCreationRequest(BaseModel): - model: Text2ImageModelName = Text2ImageModelName.seedream_3 - prompt: str = Field(...) 
- response_format: Optional[str] = Field("url") - size: Optional[str] = Field(None) - seed: Optional[int] = Field(0, ge=0, le=2147483647) - guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) - watermark: Optional[bool] = Field(True) - - -class Image2ImageTaskCreationRequest(BaseModel): - model: Image2ImageModelName = Image2ImageModelName.seededit_3 - prompt: str = Field(...) - response_format: Optional[str] = Field("url") - image: str = Field(..., description="Base64 encoded string or image URL") - size: Optional[str] = Field("adaptive") - seed: Optional[int] = Field(..., ge=0, le=2147483647) - guidance_scale: Optional[float] = Field(..., ge=1.0, le=10.0) - watermark: Optional[bool] = Field(True) - - -class Seedream4Options(BaseModel): - max_images: int = Field(15) - - -class Seedream4TaskCreationRequest(BaseModel): - model: str = Field("seedream-4-0-250828") - prompt: str = Field(...) - response_format: str = Field("url") - image: Optional[list[str]] = Field(None, description="Image URLs") - size: str = Field(...) - seed: int = Field(..., ge=0, le=2147483647) - sequential_image_generation: str = Field("disabled") - sequential_image_generation_options: Seedream4Options = Field(Seedream4Options(max_images=15)) - watermark: bool = Field(True) - - -class ImageTaskCreationResponse(BaseModel): - model: str = Field(...) - created: int = Field(..., description="Unix timestamp (in seconds) indicating time when the request was created.") - data: list = Field([], description="Contains information about the generated image(s).") - error: dict = Field({}, description="Contains `code` and `message` fields in case of error.") - - -class TaskTextContent(BaseModel): - type: str = Field("text") - text: str = Field(...) - - -class TaskImageContentUrl(BaseModel): - url: str = Field(...) - - -class TaskImageContent(BaseModel): - type: str = Field("image_url") - image_url: TaskImageContentUrl = Field(...) - role: Optional[Literal["first_frame", "last_frame", "reference_image"]] = Field(None) - - -class Text2VideoTaskCreationRequest(BaseModel): - model: Text2VideoModelName = Text2VideoModelName.seedance_1_pro - content: list[TaskTextContent] = Field(..., min_length=1) - - -class Image2VideoTaskCreationRequest(BaseModel): - model: Image2VideoModelName = Image2VideoModelName.seedance_1_pro - content: list[Union[TaskTextContent, TaskImageContent]] = Field(..., min_length=2) - - -class TaskCreationResponse(BaseModel): - id: str = Field(...) - - -class TaskStatusError(BaseModel): - code: str = Field(...) - message: str = Field(...) - - -class TaskStatusResult(BaseModel): - video_url: str = Field(...) - - -class TaskStatusResponse(BaseModel): - id: str = Field(...) - model: str = Field(...) - status: Literal["queued", "running", "cancelled", "succeeded", "failed"] = Field(...) 
- error: Optional[TaskStatusError] = Field(None) - content: Optional[TaskStatusResult] = Field(None) - - -RECOMMENDED_PRESETS = [ - ("1024x1024 (1:1)", 1024, 1024), - ("864x1152 (3:4)", 864, 1152), - ("1152x864 (4:3)", 1152, 864), - ("1280x720 (16:9)", 1280, 720), - ("720x1280 (9:16)", 720, 1280), - ("832x1248 (2:3)", 832, 1248), - ("1248x832 (3:2)", 1248, 832), - ("1512x648 (21:9)", 1512, 648), - ("2048x2048 (1:1)", 2048, 2048), - ("Custom", None, None), -] - -RECOMMENDED_PRESETS_SEEDREAM_4 = [ - ("2048x2048 (1:1)", 2048, 2048), - ("2304x1728 (4:3)", 2304, 1728), - ("1728x2304 (3:4)", 1728, 2304), - ("2560x1440 (16:9)", 2560, 1440), - ("1440x2560 (9:16)", 1440, 2560), - ("2496x1664 (3:2)", 2496, 1664), - ("1664x2496 (2:3)", 1664, 2496), - ("3024x1296 (21:9)", 3024, 1296), - ("4096x4096 (1:1)", 4096, 4096), - ("Custom", None, None), -] - -# The time in this dictionary are given for 10 seconds duration. -VIDEO_TASKS_EXECUTION_TIME = { - "seedance-1-0-lite-t2v-250428": { - "480p": 40, - "720p": 60, - "1080p": 90, - }, - "seedance-1-0-lite-i2v-250428": { - "480p": 40, - "720p": 60, - "1080p": 90, - }, - "seedance-1-0-pro-250528": { - "480p": 70, - "720p": 85, - "1080p": 115, - }, -} - - def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: if response.error: error_msg = f"ByteDance request failed. Code: {response.error['code']}, message: {response.error['message']}" @@ -201,42 +52,6 @@ def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: return response.data[0]["url"] -def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: - """Returns the video URL from the task status response if it exists.""" - if hasattr(response, "content") and response.content: - return response.content.video_url - return None - - -async def poll_until_finished( - auth_kwargs: dict[str, str], - task_id: str, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> TaskStatusResponse: - """Polls the ByteDance API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - completed_statuses=[ - "succeeded", - ], - failed_statuses=[ - "cancelled", - "failed", - ], - status_extractor=lambda response: response.status, - auth_kwargs=auth_kwargs, - result_url_extractor=get_video_url_from_task_status, - estimated_duration=estimated_duration, - node_id=node_id, - ).execute() - - class ByteDanceImageNode(IO.ComfyNode): @classmethod @@ -247,12 +62,7 @@ class ByteDanceImageNode(IO.ComfyNode): category="api node/image/ByteDance", description="Generate images using ByteDance models via api based on prompt", inputs=[ - IO.Combo.Input( - "model", - options=Text2ImageModelName, - default=Text2ImageModelName.seedream_3, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["seedream-3-0-t2i-250415"]), IO.String.Input( "prompt", multiline=True, @@ -303,7 +113,7 @@ class ByteDanceImageNode(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image", + tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), ], @@ -341,8 +151,7 @@ class ByteDanceImageNode(IO.ComfyNode): w, h = width, height if not (512 <= w <= 2048) or not (512 <= h <= 2048): raise ValueError( - f"Custom size out of range: {w}x{h}. 
" - "Both width and height must be between 512 and 2048 pixels." + f"Custom size out of range: {w}x{h}. " "Both width and height must be between 512 and 2048 pixels." ) payload = Text2ImageTaskCreationRequest( @@ -353,20 +162,12 @@ class ByteDanceImageNode(IO.ComfyNode): guidance_scale=guidance_scale, watermark=watermark, ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Text2ImageTaskCreationRequest, - response_model=ImageTaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) @@ -380,12 +181,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): category="api node/image/ByteDance", description="Edit images using ByteDance models via api based on prompt", inputs=[ - IO.Combo.Input( - "model", - options=Image2ImageModelName, - default=Image2ImageModelName.seededit_3, - tooltip="Model name", - ), + IO.Combo.Input("model", options=["seededit-3-0-i2i-250628"]), IO.Image.Input( "image", tooltip="The base image to edit", @@ -420,7 +216,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image", + tooltip='Whether to add an "AI generated" watermark to the image', optional=True, ), ], @@ -439,7 +235,7 @@ class ByteDanceImageEditNode(IO.ComfyNode): async def execute( cls, model: str, - image: torch.Tensor, + image: Input.Image, prompt: str, seed: int, guidance_scale: float, @@ -448,17 +244,8 @@ class ByteDanceImageEditNode(IO.ComfyNode): validate_string(prompt, strip_whitespace=True, min_length=1) if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") - validate_image_aspect_ratio_range(image, (1, 3), (3, 1)) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - source_url = (await upload_images_to_comfyapi( - image, - max_images=1, - mime_type="image/png", - auth_kwargs=auth_kwargs, - ))[0] + validate_image_aspect_ratio(image, (1, 3), (3, 1)) + source_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0] payload = Image2ImageTaskCreationRequest( model=model, prompt=prompt, @@ -467,16 +254,12 @@ class ByteDanceImageEditNode(IO.ComfyNode): guidance_scale=guidance_scale, watermark=watermark, ) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Image2ImageTaskCreationRequest, - response_model=ImageTaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + data=payload, + response_model=ImageTaskCreationResponse, + ) return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) @@ -492,7 +275,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=["seedream-4-0-250828"], + options=["seedream-4-5-251128", "seedream-4-0-250828"], tooltip="Model name", ), IO.String.Input( @@ -504,7 +287,7 @@ class 
ByteDanceSeedreamNode(IO.ComfyNode): IO.Image.Input( "image", tooltip="Input image(s) for image-to-image generation. " - "List of 1-10 images for single or multi-reference generation.", + "List of 1-10 images for single or multi-reference generation.", optional=True, ), IO.Combo.Input( @@ -517,7 +300,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): default=2048, min=1024, max=4096, - step=64, + step=8, tooltip="Custom width for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -526,7 +309,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): default=2048, min=1024, max=4096, - step=64, + step=8, tooltip="Custom height for image. Value is working only if `size_preset` is set to `Custom`", optional=True, ), @@ -534,9 +317,9 @@ class ByteDanceSeedreamNode(IO.ComfyNode): "sequential_image_generation", options=["disabled", "auto"], tooltip="Group image generation mode. " - "'disabled' generates a single image. " - "'auto' lets the model decide whether to generate multiple related images " - "(e.g., story scenes, character variations).", + "'disabled' generates a single image. " + "'auto' lets the model decide whether to generate multiple related images " + "(e.g., story scenes, character variations).", optional=True, ), IO.Int.Input( @@ -547,7 +330,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): step=1, display_mode=IO.NumberDisplay.number, tooltip="Maximum number of images to generate when sequential_image_generation='auto'. " - "Total images (input + generated) cannot exceed 15.", + "Total images (input + generated) cannot exceed 15.", optional=True, ), IO.Int.Input( @@ -564,7 +347,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the image.", + tooltip='Whether to add an "AI generated" watermark to the image.', optional=True, ), IO.Boolean.Input( @@ -590,7 +373,7 @@ class ByteDanceSeedreamNode(IO.ComfyNode): cls, model: str, prompt: str, - image: torch.Tensor = None, + image: Input.Image | None = None, size_preset: str = RECOMMENDED_PRESETS_SEEDREAM_4[0][0], width: int = 2048, height: int = 2048, @@ -611,9 +394,20 @@ class ByteDanceSeedreamNode(IO.ComfyNode): w, h = width, height if not (1024 <= w <= 4096) or not (1024 <= h <= 4096): raise ValueError( - f"Custom size out of range: {w}x{h}. " - "Both width and height must be between 1024 and 4096 pixels." + f"Custom size out of range: {w}x{h}. " "Both width and height must be between 1024 and 4096 pixels." ) + out_num_pixels = w * h + mp_provided = out_num_pixels / 1_000_000.0 + if "seedream-4-5" in model and out_num_pixels < 3686400: + raise ValueError( + f"Minimum image resolution that Seedream 4.5 can generate is 3.68MP, " + f"but {mp_provided:.2f}MP provided." + ) + if "seedream-4-0" in model and out_num_pixels < 921600: + raise ValueError( + f"Minimum image resolution that the selected model can generate is 0.92MP, " + f"but {mp_provided:.2f}MP provided." + ) n_input_images = get_number_of_images(image) if image is not None else 0 if n_input_images > 10: raise ValueError(f"Maximum of 10 reference images are supported, but {n_input_images} received.") @@ -621,41 +415,31 @@ class ByteDanceSeedreamNode(IO.ComfyNode): raise ValueError( "The maximum number of generated images plus the number of reference images cannot exceed 15." 
) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } reference_images_urls = [] if n_input_images: for i in image: - validate_image_aspect_ratio_range(i, (1, 3), (3, 1)) - reference_images_urls = (await upload_images_to_comfyapi( + validate_image_aspect_ratio(i, (1, 3), (3, 1)) + reference_images_urls = await upload_images_to_comfyapi( + cls, image, max_images=n_input_images, mime_type="image/png", - auth_kwargs=auth_kwargs, - )) - payload = Seedream4TaskCreationRequest( - model=model, - prompt=prompt, - image=reference_images_urls, - size=f"{w}x{h}", - seed=seed, - sequential_image_generation=sequential_image_generation, - sequential_image_generation_options=Seedream4Options(max_images=max_images), - watermark=watermark, - ) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_IMAGE_ENDPOINT, - method=HttpMethod.POST, - request_model=Seedream4TaskCreationRequest, - response_model=ImageTaskCreationResponse, + ) + response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_IMAGE_ENDPOINT, method="POST"), + response_model=ImageTaskCreationResponse, + data=Seedream4TaskCreationRequest( + model=model, + prompt=prompt, + image=reference_images_urls, + size=f"{w}x{h}", + seed=seed, + sequential_image_generation=sequential_image_generation, + sequential_image_generation_options=Seedream4Options(max_images=max_images), + watermark=watermark, ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - + ) if len(response.data) == 1: return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_response(response))) urls = [str(d["url"]) for d in response.data if isinstance(d, dict) and "url" in d] @@ -676,9 +460,8 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=Text2VideoModelName, - default=Text2VideoModelName.seedance_1_pro, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + default="seedance-1-0-pro-fast-251015", ), IO.String.Input( "prompt", @@ -719,13 +502,13 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. 
The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -764,19 +547,9 @@ class ByteDanceTextToVideoNode(IO.ComfyNode): f"--camerafixed {str(camera_fixed).lower()} " f"--watermark {str(watermark).lower()}" ) - - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } return await process_video_task( - request_model=Text2VideoTaskCreationRequest, - payload=Text2VideoTaskCreationRequest( - model=model, - content=[TaskTextContent(text=prompt)], - ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, + cls, + payload=Text2VideoTaskCreationRequest(model=model, content=[TaskTextContent(text=prompt)]), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -793,9 +566,8 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=Image2VideoModelName, - default=Image2VideoModelName.seedance_1_pro, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-t2v-250428", "seedance-1-0-pro-fast-251015"], + default="seedance-1-0-pro-fast-251015", ), IO.String.Input( "prompt", @@ -840,13 +612,13 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -866,7 +638,7 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): cls, model: str, prompt: str, - image: torch.Tensor, + image: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -877,15 +649,9 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): validate_string(prompt, strip_whitespace=True, min_length=1) raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) - validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth_kwargs))[0] + validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -897,13 +663,11 @@ class ByteDanceImageToVideoNode(IO.ComfyNode): ) return await process_video_task( - request_model=Image2VideoTaskCreationRequest, + cls, payload=Image2VideoTaskCreationRequest( model=model, content=[TaskTextContent(text=prompt), TaskImageContent(image_url=TaskImageContentUrl(url=image_url))], ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=max(1, 
math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -920,9 +684,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=[model.value for model in Image2VideoModelName], - default=Image2VideoModelName.seedance_1_lite.value, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( "prompt", @@ -971,13 +734,13 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): "camera_fixed", default=False, tooltip="Specifies whether to fix the camera. The platform appends an instruction " - "to fix the camera to your prompt, but does not guarantee the actual effect.", + "to fix the camera to your prompt, but does not guarantee the actual effect.", optional=True, ), IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -997,8 +760,8 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): cls, model: str, prompt: str, - first_frame: torch.Tensor, - last_frame: torch.Tensor, + first_frame: Input.Image, + last_frame: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -1010,18 +773,13 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "camerafixed", "watermark"]) for i in (first_frame, last_frame): validate_image_dimensions(i, min_width=300, min_height=300, max_width=6000, max_height=6000) - validate_image_aspect_ratio_range(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } + validate_image_aspect_ratio(i, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 download_urls = await upload_images_to_comfyapi( + cls, image_tensor_pair_to_batch(first_frame, last_frame), max_images=2, mime_type="image/png", - auth_kwargs=auth_kwargs, ) prompt = ( @@ -1035,7 +793,7 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): ) return await process_video_task( - request_model=Image2VideoTaskCreationRequest, + cls, payload=Image2VideoTaskCreationRequest( model=model, content=[ @@ -1044,8 +802,6 @@ class ByteDanceFirstLastFrameNode(IO.ComfyNode): TaskImageContent(image_url=TaskImageContentUrl(url=str(download_urls[1])), role="last_frame"), ], ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -1062,9 +818,8 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): inputs=[ IO.Combo.Input( "model", - options=[Image2VideoModelName.seedance_1_lite.value], - default=Image2VideoModelName.seedance_1_lite.value, - tooltip="Model name", + options=["seedance-1-0-pro-250528", "seedance-1-0-lite-i2v-250428"], + default="seedance-1-0-lite-i2v-250428", ), IO.String.Input( "prompt", @@ -1108,7 +863,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the video.", + tooltip='Whether to add an "AI generated" watermark to the video.', optional=True, ), ], @@ -1128,7 +883,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): cls, model: str, prompt: str, - images: torch.Tensor, + images: Input.Image, resolution: str, aspect_ratio: str, duration: int, @@ -1139,17 +894,9 @@ class 
ByteDanceImageReferenceNode(IO.ComfyNode): raise_if_text_params(prompt, ["resolution", "ratio", "duration", "seed", "watermark"]) for image in images: validate_image_dimensions(image, min_width=300, min_height=300, max_width=6000, max_height=6000) - validate_image_aspect_ratio_range(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 - - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - image_urls = await upload_images_to_comfyapi( - images, max_images=4, mime_type="image/png", auth_kwargs=auth_kwargs - ) + validate_image_aspect_ratio(image, (2, 5), (5, 2), strict=False) # 0.4 to 2.5 + image_urls = await upload_images_to_comfyapi(cls, images, max_images=4, mime_type="image/png") prompt = ( f"{prompt} " f"--resolution {resolution} " @@ -1160,44 +907,34 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ) x = [ TaskTextContent(text=prompt), - *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls] + *[TaskImageContent(image_url=TaskImageContentUrl(url=str(i)), role="reference_image") for i in image_urls], ] return await process_video_task( - request_model=Image2VideoTaskCreationRequest, - payload=Image2VideoTaskCreationRequest( - model=model, - content=x, - ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, + cls, + payload=Image2VideoTaskCreationRequest(model=model, content=x), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) async def process_video_task( - request_model: Type[T], - payload: Union[Text2VideoTaskCreationRequest, Image2VideoTaskCreationRequest], - auth_kwargs: dict, - node_id: str, - estimated_duration: Optional[int], + cls: type[IO.ComfyNode], + payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, + estimated_duration: int | None, ) -> IO.NodeOutput: - initial_response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=BYTEPLUS_TASK_ENDPOINT, - method=HttpMethod.POST, - request_model=request_model, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - response = await poll_until_finished( - auth_kwargs, - initial_response.id, - estimated_duration=estimated_duration, - node_id=node_id, + initial_response = await sync_op( + cls, + ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), + data=payload, + response_model=TaskCreationResponse, ) - return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(response))) + response = await poll_op( + cls, + ApiEndpoint(path=f"{BYTEPLUS_TASK_STATUS_ENDPOINT}/{initial_response.id}"), + status_extractor=lambda r: r.status, + estimated_duration=estimated_duration, + response_model=TaskStatusResponse, + ) + return IO.NodeOutput(await download_url_to_video_output(response.content.video_url)) def raise_if_text_params(prompt: str, text_params: list[str]) -> None: @@ -1221,5 +958,6 @@ class ByteDanceExtension(ComfyExtension): ByteDanceImageReferenceNode, ] + async def comfy_entrypoint() -> ByteDanceExtension: return ByteDanceExtension() diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 937984da8..3b02600fd 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -2,48 +2,56 @@ API Nodes for Gemini Multimodal LLM Usage via Remote API See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference """ -from __future__ import annotations -import json -import time 
-import os -import uuid import base64 -from io import BytesIO +import os from enum import Enum -from typing import Optional, Literal +from io import BytesIO +from typing import Literal import torch +from typing_extensions import override from comfy.cmd import folder_paths -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict -from comfy.cmd.server import PromptServer -from comfy_api_nodes.apis import ( +from comfy_api.latest import IO, ComfyExtension, Input, Types +from comfy_api_nodes.apis.gemini_api import ( GeminiContent, + GeminiFileData, GeminiGenerateContentRequest, GeminiGenerateContentResponse, + GeminiImageConfig, + GeminiImageGenerateContentRequest, + GeminiImageGenerationConfig, GeminiInlineData, - GeminiPart, GeminiMimeType, + GeminiPart, + GeminiRole, + GeminiSystemInstructionContent, + GeminiTextPart, + Modality, ) -from comfy_api_nodes.apis.gemini_api import GeminiImageGenerationConfig, GeminiImageGenerateContentRequest, GeminiImageConfig -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, -) -from comfy_api_nodes.apinode_utils import ( - validate_string, audio_to_base64_string, - video_to_base64_string, - tensor_to_base64_string, bytesio_to_image_tensor, + get_number_of_images, + sync_op, + tensor_to_base64_string, + upload_images_to_comfyapi, + validate_string, + video_to_base64_string, ) -from comfy_api.util import VideoContainer, VideoCodec - GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini" GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB +GEMINI_IMAGE_SYS_PROMPT = ( + "You are an expert image-generation engine. You must ALWAYS produce an image.\n" + "Interpret all user input—regardless of " + "format, intent, or abstraction—as literal visual directives for image composition.\n" + "If a prompt is conversational or lacks specific visual details, " + "you must creatively invent a concrete visual scenario that depicts the concept.\n" + "Prioritize generating the visual representation above any text, formatting, or conversational requests." +) class GeminiModel(str, Enum): @@ -55,6 +63,7 @@ class GeminiModel(str, Enum): gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17" gemini_2_5_pro = "gemini-2.5-pro" gemini_2_5_flash = "gemini-2.5-flash" + gemini_3_0_pro = "gemini-3-pro-preview" class GeminiImageModel(str, Enum): @@ -66,107 +75,50 @@ class GeminiImageModel(str, Enum): gemini_2_5_flash_image = "gemini-2.5-flash-image" -def get_gemini_endpoint( - model: GeminiModel, -) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]: - """ - Get the API endpoint for a given Gemini model. - - Args: - model: The Gemini model to use, either as enum or string value. - - Returns: - ApiEndpoint configured for the specific Gemini model. - """ - if isinstance(model, str): - model = GeminiModel(model) - return ApiEndpoint( - path=f"{GEMINI_BASE_ENDPOINT}/{model.value}", - method=HttpMethod.POST, - request_model=GeminiGenerateContentRequest, - response_model=GeminiGenerateContentResponse, - ) - - -def get_gemini_image_endpoint( - model: GeminiImageModel, -) -> ApiEndpoint[GeminiGenerateContentRequest, GeminiGenerateContentResponse]: - """ - Get the API endpoint for a given Gemini model. - - Args: - model: The Gemini model to use, either as enum or string value. - - Returns: - ApiEndpoint configured for the specific Gemini model. 
- """ - if isinstance(model, str): - model = GeminiImageModel(model) - return ApiEndpoint( - path=f"{GEMINI_BASE_ENDPOINT}/{model.value}", - method=HttpMethod.POST, - request_model=GeminiImageGenerateContentRequest, - response_model=GeminiGenerateContentResponse, - ) - - -def create_image_parts(image_input: torch.Tensor) -> list[GeminiPart]: - """ - Convert image tensor input to Gemini API compatible parts. - - Args: - image_input: Batch of image tensors from ComfyUI. - - Returns: - List of GeminiPart objects containing the encoded images. - """ +async def create_image_parts( + cls: type[IO.ComfyNode], + images: Input.Image, + image_limit: int = 0, +) -> list[GeminiPart]: image_parts: list[GeminiPart] = [] - for image_index in range(image_input.shape[0]): - image_as_b64 = tensor_to_base64_string( - image_input[image_index].unsqueeze(0) + if image_limit < 0: + raise ValueError("image_limit must be greater than or equal to 0 when creating Gemini image parts.") + total_images = get_number_of_images(images) + if total_images <= 0: + raise ValueError("No images provided to create_image_parts; at least one image is required.") + + # If image_limit == 0 --> use all images; otherwise clamp to image_limit. + effective_max = total_images if image_limit == 0 else min(total_images, image_limit) + + # Number of images we'll send as URLs (fileData) + num_url_images = min(effective_max, 10) # Vertex API max number of image links + reference_images_urls = await upload_images_to_comfyapi( + cls, + images, + max_images=num_url_images, + ) + for reference_image_url in reference_images_urls: + image_parts.append( + GeminiPart( + fileData=GeminiFileData( + mimeType=GeminiMimeType.image_png, + fileUri=reference_image_url, + ) + ) ) + for idx in range(num_url_images, effective_max): image_parts.append( GeminiPart( inlineData=GeminiInlineData( mimeType=GeminiMimeType.image_png, - data=image_as_b64, + data=tensor_to_base64_string(images[idx]), ) ) ) return image_parts -def create_text_part(text: str) -> GeminiPart: - """ - Create a text part for the Gemini API request. - - Args: - text: The text content to include in the request. - - Returns: - A GeminiPart object with the text content. - """ - return GeminiPart(text=text) - - -def get_parts_from_response( - response: GeminiGenerateContentResponse -) -> list[GeminiPart]: - """ - Extract all parts from the Gemini API response. - - Args: - response: The API response from Gemini. - - Returns: - List of response parts from the first candidate. - """ - return response.candidates[0].content.parts - - -def get_parts_by_type( - response: GeminiGenerateContentResponse, part_type: Literal["text"] | str -) -> list[GeminiPart]: +def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Literal["text"] | str) -> list[GeminiPart]: """ Filter response parts by their type. @@ -177,15 +129,21 @@ def get_parts_by_type( Returns: List of response parts matching the requested type. """ + if response.candidates is None: + if response.promptFeedback and response.promptFeedback.blockReason: + feedback = response.promptFeedback + raise ValueError( + f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})" + ) + raise ValueError( + "Gemini API returned no response candidates. If you are using the `IMAGE` modality, " + "try changing it to `IMAGE+TEXT` to view the model's reasoning and understand why image generation failed." 
+ ) parts = [] - for part in get_parts_from_response(response): + for part in response.candidates[0].content.parts: if part_type == "text" and hasattr(part, "text") and part.text: parts.append(part) - elif ( - hasattr(part, "inlineData") - and part.inlineData - and part.inlineData.mimeType == part_type - ): + elif hasattr(part, "inlineData") and part.inlineData and part.inlineData.mimeType == part_type: parts.append(part) # Skip parts that don't match the requested type return parts @@ -205,19 +163,63 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor: - image_tensors: list[torch.Tensor] = [] +def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: + image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/png") for part in parts: image_data = base64.b64decode(part.inlineData.data) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) image_tensors.append(returned_image) if len(image_tensors) == 0: - return torch.zeros((1,1024,1024,4)) + return torch.zeros((1, 1024, 1024, 4)) return torch.cat(image_tensors, dim=0) -class GeminiNode(ComfyNodeABC): +def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | None: + if not response.modelVersion: + return None + # Define prices (Cost per 1,000,000 tokens), see https://cloud.google.com/vertex-ai/generative-ai/pricing + if response.modelVersion in ("gemini-2.5-pro-preview-05-06", "gemini-2.5-pro"): + input_tokens_price = 1.25 + output_text_tokens_price = 10.0 + output_image_tokens_price = 0.0 + elif response.modelVersion in ( + "gemini-2.5-flash-preview-04-17", + "gemini-2.5-flash", + ): + input_tokens_price = 0.30 + output_text_tokens_price = 2.50 + output_image_tokens_price = 0.0 + elif response.modelVersion in ( + "gemini-2.5-flash-image-preview", + "gemini-2.5-flash-image", + ): + input_tokens_price = 0.30 + output_text_tokens_price = 2.50 + output_image_tokens_price = 30.0 + elif response.modelVersion == "gemini-3-pro-preview": + input_tokens_price = 2 + output_text_tokens_price = 12.0 + output_image_tokens_price = 0.0 + elif response.modelVersion == "gemini-3-pro-image-preview": + input_tokens_price = 2 + output_text_tokens_price = 12.0 + output_image_tokens_price = 120.0 + else: + return None + final_price = response.usageMetadata.promptTokenCount * input_tokens_price + if response.usageMetadata.candidatesTokensDetails: + for i in response.usageMetadata.candidatesTokensDetails: + if i.modality == Modality.IMAGE: + final_price += output_image_tokens_price * i.tokenCount # for Nano Banana models + else: + final_price += output_text_tokens_price * i.tokenCount + if response.usageMetadata.thoughtsTokenCount: + final_price += output_text_tokens_price * response.usageMetadata.thoughtsTokenCount + return final_price / 1_000_000.0 + + +class GeminiNode(IO.ComfyNode): """ Node to generate text responses from a Gemini model. @@ -228,95 +230,87 @@ class GeminiNode(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text inputs to the model, used to generate a response. 
You can include detailed instructions, questions, or context for the model.", - }, + def define_schema(cls): + return IO.Schema( + node_id="GeminiNode", + display_name="Google Gemini", + category="api node/text/Gemini", + description="Generate text responses with Google's Gemini AI model. " + "You can provide multiple types of inputs (text, images, audio, video) " + "as context for generating more relevant and meaningful responses.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text inputs to the model, used to generate a response. " + "You can include detailed instructions, questions, or context for the model.", ), - "model": ( - IO.COMBO, - { - "tooltip": "The Gemini model to use for generating responses.", - "options": [model.value for model in GeminiModel], - "default": GeminiModel.gemini_2_5_pro.value, - }, + IO.Combo.Input( + "model", + options=GeminiModel, + default=GeminiModel.gemini_2_5_pro, + tooltip="The Gemini model to use for generating responses.", ), - "seed": ( - IO.INT, - { - "default": 42, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.", - }, + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", ), - }, - "optional": { - "images": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", - }, + IO.Image.Input( + "images", + optional=True, + tooltip="Optional image(s) to use as context for the model. " + "To include multiple images, you can use the Batch Images node.", ), - "audio": ( - IO.AUDIO, - { - "tooltip": "Optional audio to use as context for the model.", - "default": None, - }, + IO.Audio.Input( + "audio", + optional=True, + tooltip="Optional audio to use as context for the model.", ), - "video": ( - IO.VIDEO, - { - "tooltip": "Optional video to use as context for the model.", - "default": None, - }, + IO.Video.Input( + "video", + optional=True, + tooltip="Optional video to use as context for the model.", ), - "files": ( - "GEMINI_INPUT_FILES", - { - "default": None, - "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.", - }, + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. 
" + "Accepts inputs from the Gemini Generate Content Input Files node.", ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + IO.String.Input( + "system_prompt", + multiline=True, + default="", + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), + ], + outputs=[ + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - DESCRIPTION = "Generate text responses with Google's Gemini AI model. You can provide multiple types of inputs (text, images, audio, video) as context for generating more relevant and meaningful responses." - RETURN_TYPES = ("STRING",) - FUNCTION = "api_call" - CATEGORY = "api node/text/Gemini" - API_NODE = True - - def create_video_parts(self, video_input: IO.VIDEO, **kwargs) -> list[GeminiPart]: - """ - Convert video input to Gemini API compatible parts. - - Args: - video_input: Video tensor from ComfyUI. - **kwargs: Additional arguments to pass to the conversion function. - - Returns: - List of GeminiPart objects containing the encoded video. - """ + @classmethod + def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]: + """Convert video input to Gemini API compatible parts.""" base_64_string = video_to_base64_string( - video_input, - container_format=VideoContainer.MP4, - codec=VideoCodec.H264 + video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264 ) return [ GeminiPart( @@ -327,7 +321,8 @@ class GeminiNode(ComfyNodeABC): ) ] - def create_audio_parts(self, audio_input: IO.AUDIO) -> list[GeminiPart]: + @classmethod + def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]: """ Convert audio input to Gemini API compatible parts. 
@@ -340,10 +335,10 @@ class GeminiNode(ComfyNodeABC): audio_parts: list[GeminiPart] = [] for batch_index in range(audio_input["waveform"].shape[0]): # Recreate an IO.AUDIO object for the given batch dimension index - audio_at_index = { - "waveform": audio_input["waveform"][batch_index].unsqueeze(0), - "sample_rate": audio_input["sample_rate"], - } + audio_at_index = Input.Audio( + waveform=audio_input["waveform"][batch_index].unsqueeze(0), + sample_rate=audio_input["sample_rate"], + ) # Convert to MP3 format for compatibility with Gemini API audio_bytes = audio_to_base64_string( audio_at_index, @@ -360,77 +355,58 @@ class GeminiNode(ComfyNodeABC): ) return audio_parts - async def api_call( - self, - prompt: str, - model: GeminiModel, - images: Optional[IO.IMAGE] = None, - audio: Optional[IO.AUDIO] = None, - video: Optional[IO.VIDEO] = None, - files: Optional[list[GeminiPart]] = None, - unique_id: Optional[str] = None, - **kwargs, - ) -> tuple[str]: - # Validate inputs + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + images: Input.Image | None = None, + audio: Input.Audio | None = None, + video: Input.Video | None = None, + files: list[GeminiPart] | None = None, + system_prompt: str = "", + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) # Create parts list with text prompt as the first part - parts: list[GeminiPart] = [create_text_part(prompt)] + parts: list[GeminiPart] = [GeminiPart(text=prompt)] # Add other modal parts if images is not None: - image_parts = create_image_parts(images) - parts.extend(image_parts) + parts.extend(await create_image_parts(cls, images)) if audio is not None: - parts.extend(self.create_audio_parts(audio)) + parts.extend(cls.create_audio_parts(audio)) if video is not None: - parts.extend(self.create_video_parts(video)) + parts.extend(cls.create_video_parts(video)) if files is not None: parts.extend(files) - # Create response - response = await SynchronousOperation( - endpoint=get_gemini_endpoint(model), - request=GeminiGenerateContentRequest( + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiGenerateContentRequest( contents=[ GeminiContent( - role="user", + role=GeminiRole.user, parts=parts, ) - ] + ], + systemInstruction=gemini_system_prompt, ), - auth_kwargs=kwargs, - ).execute() + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) - # Get result output output_text = get_text_from_response(response) - if unique_id and output_text: - # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. - render_spec = { - "node_id": unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - - return (output_text or "Empty response from Gemini model...",) + return IO.NodeOutput(output_text or "Empty response from Gemini model...") -class GeminiInputFiles(ComfyNodeABC): +class GeminiInputFiles(IO.ComfyNode): """ Loads and formats input files for use with the Gemini API. 
@@ -441,7 +417,7 @@ class GeminiInputFiles(ComfyNodeABC): """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: + def define_schema(cls): """ For details about the supported file input types, see: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference @@ -451,44 +427,42 @@ class GeminiInputFiles(ComfyNodeABC): f for f in os.scandir(input_dir) if f.is_file() - and (f.name.endswith(".txt") or f.name.endswith(".pdf")) - and f.stat().st_size < GEMINI_MAX_INPUT_FILE_SIZE + and (f.name.endswith(".txt") or f.name.endswith(".pdf")) + and f.stat().st_size < GEMINI_MAX_INPUT_FILE_SIZE ] input_files = sorted(input_files, key=lambda x: x.name) input_files = [f.name for f in input_files] - return { - "required": { - "file": ( - IO.COMBO, - { - "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", - "options": input_files, - "default": input_files[0] if input_files else None, - }, + return IO.Schema( + node_id="GeminiInputFiles", + display_name="Gemini Input Files", + category="api node/text/Gemini", + description="Loads and prepares input files to include as inputs for Gemini LLM nodes. " + "The files will be read by the Gemini model when generating a response. " + "The contents of the text file count toward the token limit. " + "🛈 TIP: Can be chained together with other Gemini Input File nodes.", + inputs=[ + IO.Combo.Input( + "file", + options=input_files, + default=input_files[0] if input_files else None, + tooltip="Input files to include as context for the model. " + "Only accepts text (.txt) and PDF (.pdf) files for now.", ), - }, - "optional": { - "GEMINI_INPUT_FILES": ( + IO.Custom("GEMINI_INPUT_FILES").Input( "GEMINI_INPUT_FILES", - { - "tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", - "default": None, - }, + optional=True, + tooltip="An optional additional file(s) to batch together with the file loaded from this node. " + "Allows chaining of input files so that a single message can include multiple input files.", ), - }, - } - - DESCRIPTION = "Loads and prepares input files to include as inputs for Gemini LLM nodes. The files will be read by the Gemini model when generating a response. The contents of the text file count toward the token limit. 🛈 TIP: Can be chained together with other Gemini Input File nodes." - RETURN_TYPES = ("GEMINI_INPUT_FILES",) - FUNCTION = "prepare_files" - CATEGORY = "api node/text/Gemini" - - def create_file_part(self, file_path: str) -> GeminiPart: - mime_type = ( - GeminiMimeType.application_pdf - if file_path.endswith(".pdf") - else GeminiMimeType.text_plain + ], + outputs=[ + IO.Custom("GEMINI_INPUT_FILES").Output(), + ], ) + + @classmethod + def create_file_part(cls, file_path: str) -> GeminiPart: + mime_type = GeminiMimeType.application_pdf if file_path.endswith(".pdf") else GeminiMimeType.text_plain # Use base64 string directly, not the data URI with open(file_path, "rb") as f: file_content = f.read() @@ -501,185 +475,287 @@ class GeminiInputFiles(ComfyNodeABC): ) ) - def prepare_files( - self, file: str, GEMINI_INPUT_FILES: list[GeminiPart] = [] - ) -> tuple[list[GeminiPart]]: - """ - Loads and formats input files for Gemini API. 
- """ - file_path = folder_paths.get_annotated_filepath(file) - input_file_content = self.create_file_part(file_path) - files = [input_file_content] + GEMINI_INPUT_FILES - return (files,) - - -class GeminiImage(ComfyNodeABC): - """ - Node to generate text and image responses from a Gemini model. - - This node allows users to interact with Google's Gemini AI models, providing - multimodal inputs (text, images, files) to generate coherent - text and image responses. The node works with the latest Gemini models, handling the - API communication and response parsing. - """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for generation", - }, - ), - "model": ( - IO.COMBO, - { - "tooltip": "The Gemini model to use for generating responses.", - "options": [model.value for model in GeminiImageModel], - "default": GeminiImageModel.gemini_2_5_flash_image.value, - }, - ), - "seed": ( - IO.INT, - { - "default": 42, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "When seed is fixed to a specific value, the model makes a best effort to provide the same response for repeated requests. Deterministic output isn't guaranteed. Also, changing the model or parameter settings, such as the temperature, can cause variations in the response even when you use the same seed value. By default, a random seed value is used.", - }, - ), - }, - "optional": { - "images": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", - }, - ), - "files": ( - "GEMINI_INPUT_FILES", - { - "default": None, - "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the Gemini Generate Content Input Files node.", - }, - ), - # TODO: later we can add this parameter later - # "n": ( - # IO.INT, - # { - # "default": 1, - # "min": 1, - # "max": 8, - # "step": 1, - # "display": "number", - # "tooltip": "How many images to generate", - # }, - # ), - "aspect_ratio": ( - IO.COMBO, - { - "tooltip": "Defaults to matching the output image size to that of your input image, or otherwise generates 1:1 squares.", - "options": ["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], - "default": "auto", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def execute(cls, file: str, GEMINI_INPUT_FILES: list[GeminiPart] | None = None) -> IO.NodeOutput: + """Loads and formats input files for Gemini API.""" + if GEMINI_INPUT_FILES is None: + GEMINI_INPUT_FILES = [] + file_path = folder_paths.get_annotated_filepath(file) + input_file_content = cls.create_file_part(file_path) + return IO.NodeOutput([input_file_content] + GEMINI_INPUT_FILES) - RETURN_TYPES = (IO.IMAGE, IO.STRING) - FUNCTION = "api_call" - CATEGORY = "api node/image/Gemini" - DESCRIPTION = "Edit images synchronously via Google API." 
- API_NODE = True - async def api_call( - self, - prompt: str, - model: GeminiImageModel, - images: Optional[IO.IMAGE] = None, - files: Optional[list[GeminiPart]] = None, - n=1, - aspect_ratio: str = "auto", - unique_id: Optional[str] = None, - **kwargs, - ): +class GeminiImage(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiImageNode", + display_name="Nano Banana (Google Gemini Image)", + category="api node/image/Gemini", + description="Edit images synchronously via Google API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt for generation", + default="", + ), + IO.Combo.Input( + "model", + options=GeminiImageModel, + default=GeminiImageModel.gemini_2_5_flash_image, + tooltip="The Gemini model to use for generating responses.", + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional image(s) to use as context for the model. " + "To include multiple images, you can use the Batch Images node.", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], + default="auto", + tooltip="Defaults to matching the output image size to that of your input image, " + "or otherwise generates 1:1 squares.", + optional=True, + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE+TEXT", "IMAGE"], + tooltip="Choose 'IMAGE' for image-only output, or " + "'IMAGE+TEXT' to return both the generated image and a text response.", + optional=True, + ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + images: Input.Image | None = None, + files: list[GeminiPart] | None = None, + aspect_ratio: str = "auto", + response_modalities: str = "IMAGE+TEXT", + system_prompt: str = "", + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=1) - parts: list[GeminiPart] = [create_text_part(prompt)] + parts: list[GeminiPart] = [GeminiPart(text=prompt)] if not aspect_ratio: aspect_ratio = "auto" # for backward compatability with old workflows; to-do remove this in December image_config = GeminiImageConfig(aspectRatio=aspect_ratio) if images is not None: - image_parts = create_image_parts(images) - parts.extend(image_parts) + parts.extend(await create_image_parts(cls, images)) if files is not None: parts.extend(files) - response = await SynchronousOperation( - endpoint=get_gemini_image_endpoint(model), - 
request=GeminiImageGenerateContentRequest( + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( contents=[ - GeminiContent( - role="user", - parts=parts, - ), + GeminiContent(role=GeminiRole.user, parts=parts), ], generationConfig=GeminiImageGenerationConfig( - responseModalities=["TEXT","IMAGE"], + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), imageConfig=None if aspect_ratio == "auto" else image_config, - ) + ), + systemInstruction=gemini_system_prompt, ), - auth_kwargs=kwargs, - ).execute() - - output_image = get_image_from_response(response) - output_text = get_text_from_response(response) - if unique_id and output_text: - # Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button. - render_spec = { - "node_id": unique_id, - "component": "ChatHistoryWidget", - "props": { - "history": json.dumps( - [ - { - "prompt": prompt, - "response": output_text, - "response_id": str(uuid.uuid4()), - "timestamp": time.time(), - } - ] - ), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - - output_text = output_text or "Empty response from Gemini model..." - return (output_image, output_text,) + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) -NODE_CLASS_MAPPINGS = { - "GeminiNode": GeminiNode, - "GeminiImageNode": GeminiImage, - "GeminiInputFiles": GeminiInputFiles, -} +class GeminiImage2(IO.ComfyNode): -NODE_DISPLAY_NAME_MAPPINGS = { - "GeminiNode": "Google Gemini", - "GeminiImageNode": "Google Gemini Image", - "GeminiInputFiles": "Gemini Input Files", -} + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="GeminiImage2Node", + display_name="Nano Banana Pro (Google Gemini Image)", + category="api node/image/Gemini", + description="Generate or edit images synchronously via Google Vertex API.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + tooltip="Text prompt describing the image to generate or the edits to apply. " + "Include any constraints, styles, or details the model should follow.", + default="", + ), + IO.Combo.Input( + "model", + options=["gemini-3-pro-image-preview"], + ), + IO.Int.Input( + "seed", + default=42, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="When the seed is fixed to a specific value, the model makes a best effort to provide " + "the same response for repeated requests. Deterministic output isn't guaranteed. " + "Also, changing the model or parameter settings, such as the temperature, " + "can cause variations in the response even when you use the same seed value. " + "By default, a random seed value is used.", + ), + IO.Combo.Input( + "aspect_ratio", + options=["auto", "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9", "21:9"], + default="auto", + tooltip="If set to 'auto', matches your input image's aspect ratio; " + "if no image is provided, a 16:9 square is usually generated.", + ), + IO.Combo.Input( + "resolution", + options=["1K", "2K", "4K"], + tooltip="Target output resolution. 
For 2K/4K the native Gemini upscaler is used.", + ), + IO.Combo.Input( + "response_modalities", + options=["IMAGE+TEXT", "IMAGE"], + tooltip="Choose 'IMAGE' for image-only output, or " + "'IMAGE+TEXT' to return both the generated image and a text response.", + ), + IO.Image.Input( + "images", + optional=True, + tooltip="Optional reference image(s). " + "To include multiple images, use the Batch Images node (up to 14).", + ), + IO.Custom("GEMINI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. " + "Accepts inputs from the Gemini Generate Content Input Files node.", + ), + IO.String.Input( + "system_prompt", + multiline=True, + default=GEMINI_IMAGE_SYS_PROMPT, + optional=True, + tooltip="Foundational instructions that dictate an AI's behavior.", + ), + ], + outputs=[ + IO.Image.Output(), + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: str, + seed: int, + aspect_ratio: str, + resolution: str, + response_modalities: str, + images: Input.Image | None = None, + files: list[GeminiPart] | None = None, + system_prompt: str = "", + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=True, min_length=1) + + parts: list[GeminiPart] = [GeminiPart(text=prompt)] + if images is not None: + if get_number_of_images(images) > 14: + raise ValueError("The current maximum number of supported images is 14.") + parts.extend(await create_image_parts(cls, images)) + if files is not None: + parts.extend(files) + + image_config = GeminiImageConfig(imageSize=resolution) + if aspect_ratio != "auto": + image_config.aspectRatio = aspect_ratio + + gemini_system_prompt = None + if system_prompt: + gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None) + + response = await sync_op( + cls, + ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"), + data=GeminiImageGenerateContentRequest( + contents=[ + GeminiContent(role=GeminiRole.user, parts=parts), + ], + generationConfig=GeminiImageGenerationConfig( + responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]), + imageConfig=image_config, + ), + systemInstruction=gemini_system_prompt, + ), + response_model=GeminiGenerateContentResponse, + price_extractor=calculate_tokens_price, + ) + return IO.NodeOutput(get_image_from_response(response), get_text_from_response(response)) + + +class GeminiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + GeminiNode, + GeminiImage, + GeminiImage2, + GeminiInputFiles, + ] + + +async def comfy_entrypoint() -> GeminiExtension: + return GeminiExtension() diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index 4eb225166..bf250ff8d 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -1,6 +1,6 @@ from io import BytesIO from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO +from comfy_api.latest import IO, ComfyExtension from PIL import Image import numpy as np import torch @@ -11,135 +11,129 @@ from comfy_api_nodes.apis import ( IdeogramV3Request, IdeogramV3EditRequest, ) - -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, -) - -from comfy_api_nodes.apinode_utils 
import ( - download_url_to_bytesio, bytesio_to_image_tensor, + download_url_as_bytesio, resize_mask_to_image, + sync_op, ) -from comfy.cmd.server import PromptServer V1_V1_RES_MAP = { - "Auto":"AUTO", - "512 x 1536":"RESOLUTION_512_1536", - "576 x 1408":"RESOLUTION_576_1408", - "576 x 1472":"RESOLUTION_576_1472", - "576 x 1536":"RESOLUTION_576_1536", - "640 x 1024":"RESOLUTION_640_1024", - "640 x 1344":"RESOLUTION_640_1344", - "640 x 1408":"RESOLUTION_640_1408", - "640 x 1472":"RESOLUTION_640_1472", - "640 x 1536":"RESOLUTION_640_1536", - "704 x 1152":"RESOLUTION_704_1152", - "704 x 1216":"RESOLUTION_704_1216", - "704 x 1280":"RESOLUTION_704_1280", - "704 x 1344":"RESOLUTION_704_1344", - "704 x 1408":"RESOLUTION_704_1408", - "704 x 1472":"RESOLUTION_704_1472", - "720 x 1280":"RESOLUTION_720_1280", - "736 x 1312":"RESOLUTION_736_1312", - "768 x 1024":"RESOLUTION_768_1024", - "768 x 1088":"RESOLUTION_768_1088", - "768 x 1152":"RESOLUTION_768_1152", - "768 x 1216":"RESOLUTION_768_1216", - "768 x 1232":"RESOLUTION_768_1232", - "768 x 1280":"RESOLUTION_768_1280", - "768 x 1344":"RESOLUTION_768_1344", - "832 x 960":"RESOLUTION_832_960", - "832 x 1024":"RESOLUTION_832_1024", - "832 x 1088":"RESOLUTION_832_1088", - "832 x 1152":"RESOLUTION_832_1152", - "832 x 1216":"RESOLUTION_832_1216", - "832 x 1248":"RESOLUTION_832_1248", - "864 x 1152":"RESOLUTION_864_1152", - "896 x 960":"RESOLUTION_896_960", - "896 x 1024":"RESOLUTION_896_1024", - "896 x 1088":"RESOLUTION_896_1088", - "896 x 1120":"RESOLUTION_896_1120", - "896 x 1152":"RESOLUTION_896_1152", - "960 x 832":"RESOLUTION_960_832", - "960 x 896":"RESOLUTION_960_896", - "960 x 1024":"RESOLUTION_960_1024", - "960 x 1088":"RESOLUTION_960_1088", - "1024 x 640":"RESOLUTION_1024_640", - "1024 x 768":"RESOLUTION_1024_768", - "1024 x 832":"RESOLUTION_1024_832", - "1024 x 896":"RESOLUTION_1024_896", - "1024 x 960":"RESOLUTION_1024_960", - "1024 x 1024":"RESOLUTION_1024_1024", - "1088 x 768":"RESOLUTION_1088_768", - "1088 x 832":"RESOLUTION_1088_832", - "1088 x 896":"RESOLUTION_1088_896", - "1088 x 960":"RESOLUTION_1088_960", - "1120 x 896":"RESOLUTION_1120_896", - "1152 x 704":"RESOLUTION_1152_704", - "1152 x 768":"RESOLUTION_1152_768", - "1152 x 832":"RESOLUTION_1152_832", - "1152 x 864":"RESOLUTION_1152_864", - "1152 x 896":"RESOLUTION_1152_896", - "1216 x 704":"RESOLUTION_1216_704", - "1216 x 768":"RESOLUTION_1216_768", - "1216 x 832":"RESOLUTION_1216_832", - "1232 x 768":"RESOLUTION_1232_768", - "1248 x 832":"RESOLUTION_1248_832", - "1280 x 704":"RESOLUTION_1280_704", - "1280 x 720":"RESOLUTION_1280_720", - "1280 x 768":"RESOLUTION_1280_768", - "1280 x 800":"RESOLUTION_1280_800", - "1312 x 736":"RESOLUTION_1312_736", - "1344 x 640":"RESOLUTION_1344_640", - "1344 x 704":"RESOLUTION_1344_704", - "1344 x 768":"RESOLUTION_1344_768", - "1408 x 576":"RESOLUTION_1408_576", - "1408 x 640":"RESOLUTION_1408_640", - "1408 x 704":"RESOLUTION_1408_704", - "1472 x 576":"RESOLUTION_1472_576", - "1472 x 640":"RESOLUTION_1472_640", - "1472 x 704":"RESOLUTION_1472_704", - "1536 x 512":"RESOLUTION_1536_512", - "1536 x 576":"RESOLUTION_1536_576", - "1536 x 640":"RESOLUTION_1536_640", + "Auto": "AUTO", + "512 x 1536": "RESOLUTION_512_1536", + "576 x 1408": "RESOLUTION_576_1408", + "576 x 1472": "RESOLUTION_576_1472", + "576 x 1536": "RESOLUTION_576_1536", + "640 x 1024": "RESOLUTION_640_1024", + "640 x 1344": "RESOLUTION_640_1344", + "640 x 1408": "RESOLUTION_640_1408", + "640 x 1472": "RESOLUTION_640_1472", + "640 x 1536": "RESOLUTION_640_1536", + "704 x 1152": 
"RESOLUTION_704_1152", + "704 x 1216": "RESOLUTION_704_1216", + "704 x 1280": "RESOLUTION_704_1280", + "704 x 1344": "RESOLUTION_704_1344", + "704 x 1408": "RESOLUTION_704_1408", + "704 x 1472": "RESOLUTION_704_1472", + "720 x 1280": "RESOLUTION_720_1280", + "736 x 1312": "RESOLUTION_736_1312", + "768 x 1024": "RESOLUTION_768_1024", + "768 x 1088": "RESOLUTION_768_1088", + "768 x 1152": "RESOLUTION_768_1152", + "768 x 1216": "RESOLUTION_768_1216", + "768 x 1232": "RESOLUTION_768_1232", + "768 x 1280": "RESOLUTION_768_1280", + "768 x 1344": "RESOLUTION_768_1344", + "832 x 960": "RESOLUTION_832_960", + "832 x 1024": "RESOLUTION_832_1024", + "832 x 1088": "RESOLUTION_832_1088", + "832 x 1152": "RESOLUTION_832_1152", + "832 x 1216": "RESOLUTION_832_1216", + "832 x 1248": "RESOLUTION_832_1248", + "864 x 1152": "RESOLUTION_864_1152", + "896 x 960": "RESOLUTION_896_960", + "896 x 1024": "RESOLUTION_896_1024", + "896 x 1088": "RESOLUTION_896_1088", + "896 x 1120": "RESOLUTION_896_1120", + "896 x 1152": "RESOLUTION_896_1152", + "960 x 832": "RESOLUTION_960_832", + "960 x 896": "RESOLUTION_960_896", + "960 x 1024": "RESOLUTION_960_1024", + "960 x 1088": "RESOLUTION_960_1088", + "1024 x 640": "RESOLUTION_1024_640", + "1024 x 768": "RESOLUTION_1024_768", + "1024 x 832": "RESOLUTION_1024_832", + "1024 x 896": "RESOLUTION_1024_896", + "1024 x 960": "RESOLUTION_1024_960", + "1024 x 1024": "RESOLUTION_1024_1024", + "1088 x 768": "RESOLUTION_1088_768", + "1088 x 832": "RESOLUTION_1088_832", + "1088 x 896": "RESOLUTION_1088_896", + "1088 x 960": "RESOLUTION_1088_960", + "1120 x 896": "RESOLUTION_1120_896", + "1152 x 704": "RESOLUTION_1152_704", + "1152 x 768": "RESOLUTION_1152_768", + "1152 x 832": "RESOLUTION_1152_832", + "1152 x 864": "RESOLUTION_1152_864", + "1152 x 896": "RESOLUTION_1152_896", + "1216 x 704": "RESOLUTION_1216_704", + "1216 x 768": "RESOLUTION_1216_768", + "1216 x 832": "RESOLUTION_1216_832", + "1232 x 768": "RESOLUTION_1232_768", + "1248 x 832": "RESOLUTION_1248_832", + "1280 x 704": "RESOLUTION_1280_704", + "1280 x 720": "RESOLUTION_1280_720", + "1280 x 768": "RESOLUTION_1280_768", + "1280 x 800": "RESOLUTION_1280_800", + "1312 x 736": "RESOLUTION_1312_736", + "1344 x 640": "RESOLUTION_1344_640", + "1344 x 704": "RESOLUTION_1344_704", + "1344 x 768": "RESOLUTION_1344_768", + "1408 x 576": "RESOLUTION_1408_576", + "1408 x 640": "RESOLUTION_1408_640", + "1408 x 704": "RESOLUTION_1408_704", + "1472 x 576": "RESOLUTION_1472_576", + "1472 x 640": "RESOLUTION_1472_640", + "1472 x 704": "RESOLUTION_1472_704", + "1536 x 512": "RESOLUTION_1536_512", + "1536 x 576": "RESOLUTION_1536_576", + "1536 x 640": "RESOLUTION_1536_640", } V1_V2_RATIO_MAP = { - "1:1":"ASPECT_1_1", - "4:3":"ASPECT_4_3", - "3:4":"ASPECT_3_4", - "16:9":"ASPECT_16_9", - "9:16":"ASPECT_9_16", - "2:1":"ASPECT_2_1", - "1:2":"ASPECT_1_2", - "3:2":"ASPECT_3_2", - "2:3":"ASPECT_2_3", - "4:5":"ASPECT_4_5", - "5:4":"ASPECT_5_4", + "1:1": "ASPECT_1_1", + "4:3": "ASPECT_4_3", + "3:4": "ASPECT_3_4", + "16:9": "ASPECT_16_9", + "9:16": "ASPECT_9_16", + "2:1": "ASPECT_2_1", + "1:2": "ASPECT_1_2", + "3:2": "ASPECT_3_2", + "2:3": "ASPECT_2_3", + "4:5": "ASPECT_4_5", + "5:4": "ASPECT_5_4", } V3_RATIO_MAP = { - "1:3":"1x3", - "3:1":"3x1", - "1:2":"1x2", - "2:1":"2x1", - "9:16":"9x16", - "16:9":"16x9", - "10:16":"10x16", - "16:10":"16x10", - "2:3":"2x3", - "3:2":"3x2", - "3:4":"3x4", - "4:3":"4x3", - "4:5":"4x5", - "5:4":"5x4", - "1:1":"1x1", + "1:3": "1x3", + "3:1": "3x1", + "1:2": "1x2", + "2:1": "2x1", + "9:16": "9x16", + "16:9": "16x9", + 
"10:16": "10x16", + "16:10": "16x10", + "2:3": "2x3", + "3:2": "3x2", + "3:4": "3x4", + "4:3": "4x3", + "4:5": "4x5", + "5:4": "5x4", + "1:1": "1x1", } -V3_RESOLUTIONS= [ +V3_RESOLUTIONS = [ "Auto", "512x1536", "576x1408", @@ -212,6 +206,7 @@ V3_RESOLUTIONS= [ "1536x640" ] + async def download_and_process_images(image_urls): """Helper function to download and process multiple images from URLs""" @@ -220,7 +215,7 @@ async def download_and_process_images(image_urls): for image_url in image_urls: # Using functions from apinode_utils.py to handle downloading and processing - image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO + image_bytesio = await download_url_as_bytesio(image_url) # Download image content to BytesIO img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode image_tensors.append(img_tensor) @@ -233,19 +228,6 @@ async def download_and_process_images(image_urls): return stacked_tensors -def display_image_urls_on_node(image_urls, node_id): - if node_id and image_urls: - if len(image_urls) == 1: - PromptServer.instance.send_progress_text( - f"Generated Image URL:\n{image_urls[0]}", node_id - ) - else: - urls_text = "Generated Image URLs:\n" + "\n".join( - f"{i+1}. {url}" for i, url in enumerate(image_urls) - ) - PromptServer.instance.send_progress_text(urls_text, node_id) - - class IdeogramV1(IO.ComfyNode): @classmethod @@ -321,57 +303,43 @@ class IdeogramV1(IO.ComfyNode): @classmethod async def execute( - cls, - prompt, - turbo=False, - aspect_ratio="1:1", - magic_prompt_option="AUTO", - seed=0, - negative_prompt="", - num_images=1, + cls, + prompt, + turbo=False, + aspect_ratio="1:1", + magic_prompt_option="AUTO", + seed=0, + negative_prompt="", + num_images=1, ): # Determine the model based on turbo setting aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) model = "V_1_TURBO" if turbo else "V_1" - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/ideogram/generate", - method=HttpMethod.POST, - request_model=IdeogramGenerateRequest, - response_model=IdeogramGenerateResponse, - ), - request=IdeogramGenerateRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/ideogram/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=IdeogramGenerateRequest( image_request=ImageRequest( prompt=prompt, model=model, num_images=num_images, seed=seed, aspect_ratio=aspect_ratio if aspect_ratio != "ASPECT_1_1" else None, - magic_prompt_option=( - magic_prompt_option if magic_prompt_option != "AUTO" else None - ), + magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None), negative_prompt=negative_prompt if negative_prompt else None, ) ), - auth_kwargs=auth, + max_retries=1, ) - response = await operation.execute() - if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") - - display_image_urls_on_node(image_urls, cls.hidden.unique_id) return IO.NodeOutput(await download_and_process_images(image_urls)) @@ -452,14 +420,14 @@ class IdeogramV2(IO.ComfyNode): display_mode=IO.NumberDisplay.number, optional=True, ), - #"color_palette": ( + # "color_palette": ( # IO.STRING, # { # "multiline": False, # 
"default": "", # "tooltip": "Color palette preset name or hex colors with weights", # }, - #), + # ), ], outputs=[ IO.Image.Output(), @@ -473,17 +441,17 @@ class IdeogramV2(IO.ComfyNode): @classmethod async def execute( - cls, - prompt, - turbo=False, - aspect_ratio="1:1", - resolution="Auto", - magic_prompt_option="AUTO", - seed=0, - style_type="NONE", - negative_prompt="", - num_images=1, - color_palette="", + cls, + prompt, + turbo=False, + aspect_ratio="1:1", + resolution="Auto", + magic_prompt_option="AUTO", + seed=0, + style_type="NONE", + negative_prompt="", + num_images=1, + color_palette="", ): aspect_ratio = V1_V2_RATIO_MAP.get(aspect_ratio, None) resolution = V1_V1_RES_MAP.get(resolution, None) @@ -500,18 +468,11 @@ class IdeogramV2(IO.ComfyNode): else: final_aspect_ratio = aspect_ratio if aspect_ratio != "ASPECT_1_1" else None - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/ideogram/generate", - method=HttpMethod.POST, - request_model=IdeogramGenerateRequest, - response_model=IdeogramGenerateResponse, - ), - request=IdeogramGenerateRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/ideogram/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=IdeogramGenerateRequest( image_request=ImageRequest( prompt=prompt, model=model, @@ -519,28 +480,20 @@ class IdeogramV2(IO.ComfyNode): seed=seed, aspect_ratio=final_aspect_ratio, resolution=final_resolution, - magic_prompt_option=( - magic_prompt_option if magic_prompt_option != "AUTO" else None - ), + magic_prompt_option=(magic_prompt_option if magic_prompt_option != "AUTO" else None), style_type=style_type if style_type != "NONE" else None, negative_prompt=negative_prompt if negative_prompt else None, color_palette=color_palette if color_palette else None, ) ), - auth_kwargs=auth, + max_retries=1, ) - - response = await operation.execute() - if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") - - display_image_urls_on_node(image_urls, cls.hidden.unique_id) return IO.NodeOutput(await download_and_process_images(image_urls)) @@ -643,23 +596,19 @@ class IdeogramV3(IO.ComfyNode): @classmethod async def execute( - cls, - prompt, - image=None, - mask=None, - resolution="Auto", - aspect_ratio="1:1", - magic_prompt_option="AUTO", - seed=0, - num_images=1, - rendering_speed="DEFAULT", - character_image=None, - character_mask=None, + cls, + prompt, + image=None, + mask=None, + resolution="Auto", + aspect_ratio="1:1", + magic_prompt_option="AUTO", + seed=0, + num_images=1, + rendering_speed="DEFAULT", + character_image=None, + character_mask=None, ): - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } if rendering_speed == "BALANCED": # for backward compatibility rendering_speed = "DEFAULT" @@ -694,9 +643,6 @@ class IdeogramV3(IO.ComfyNode): # Check if both image and mask are provided for editing mode if image is not None and mask is not None: - # Edit mode - path = "/proxy/ideogram/ideogram-v3/edit" - # Process image and mask input_tensor = image.squeeze().cpu() # Resize mask to match image dimension @@ -749,27 +695,20 @@ class IdeogramV3(IO.ComfyNode): if character_mask_binary: 
files["character_mask_binary"] = character_mask_binary - # Execute the operation for edit mode - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=IdeogramV3EditRequest, - response_model=IdeogramGenerateResponse, - ), - request=edit_request, + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/ideogram/ideogram-v3/edit", method="POST"), + response_model=IdeogramGenerateResponse, + data=edit_request, files=files, content_type="multipart/form-data", - auth_kwargs=auth, + max_retries=1, ) elif image is not None or mask is not None: # If only one of image or mask is provided, raise an error raise Exception("Ideogram V3 image editing requires both an image AND a mask") else: - # Generation mode - path = "/proxy/ideogram/ideogram-v3/generate" - # Create generation request gen_request = IdeogramV3Request( prompt=prompt, @@ -800,32 +739,22 @@ class IdeogramV3(IO.ComfyNode): if files: gen_request.style_type = "AUTO" - # Execute the operation for generation mode - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=IdeogramV3Request, - response_model=IdeogramGenerateResponse, - ), - request=gen_request, + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/ideogram/ideogram-v3/generate", method="POST"), + response_model=IdeogramGenerateResponse, + data=gen_request, files=files if files else None, content_type="multipart/form-data", - auth_kwargs=auth, + max_retries=1, ) - # Execute the operation and process response - response = await operation.execute() - if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") image_urls = [image_data.url for image_data in response.data if image_data.url] - if not image_urls: raise Exception("No image URLs were generated in the response") - - display_image_urls_on_node(image_urls, cls.hidden.unique_id) return IO.NodeOutput(await download_and_process_images(image_urls)) @@ -838,5 +767,6 @@ class IdeogramExtension(ComfyExtension): IdeogramV3, ] + async def comfy_entrypoint() -> IdeogramExtension: return IdeogramExtension() diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 67c8307c5..6c840dc47 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -4,18 +4,15 @@ For source of truth on the allowed permutations of request fields, please refere - [Compatibility Table](https://app.klingai.com/global/dev/document-api/apiReference/model/skillsMap) """ -from __future__ import annotations -from typing import Optional, TypeVar, Any -from collections.abc import Callable -import math import logging - -from typing_extensions import override +import math +import re import torch +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis import ( - KlingTaskStatus, KlingCameraControl, KlingCameraConfig, KlingCameraControlType, @@ -52,31 +49,33 @@ from comfy_api_nodes.apis import ( KlingCharacterEffectModelName, KlingSingleImageEffectModelName, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.apis.kling_api import ( + OmniImageParamImage, + OmniParamImage, + OmniParamVideo, + OmniProFirstLastFrameRequest, + OmniProImageRequest, + OmniProReferences2VideoRequest, + OmniProText2VideoRequest, + OmniTaskStatusResponse, +) +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - 
EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - tensor_to_base64_string, - download_url_to_video_output, - upload_video_to_comfyapi, - upload_audio_to_comfyapi, download_url_to_image_tensor, - validate_string, -) -from comfy_api_nodes.util.validation_utils import ( - validate_image_dimensions, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + tensor_to_base64_string, + upload_audio_to_comfyapi, + upload_images_to_comfyapi, + upload_video_to_comfyapi, validate_image_aspect_ratio, + validate_image_dimensions, + validate_string, validate_video_dimensions, validate_video_duration, ) -from comfy_api.input_impl import VideoFromFile -from comfy_api.input.basic_types import AudioInput -from comfy_api.input.video_types import VideoInput -from comfy_api.latest import ComfyExtension, IO KLING_API_VERSION = "v1" PATH_TEXT_TO_VIDEO = f"/proxy/kling/{KLING_API_VERSION}/videos/text2video" @@ -102,8 +101,6 @@ AVERAGE_DURATION_IMAGE_GEN = 32 AVERAGE_DURATION_VIDEO_EFFECTS = 320 AVERAGE_DURATION_VIDEO_EXTEND = 320 -R = TypeVar("R") - MODE_TEXT2VIDEO = { "standard mode / 5s duration / kling-v1": ("std", "5", "kling-v1"), @@ -138,6 +135,8 @@ MODE_START_END_FRAME = { "pro mode / 10s duration / kling-v1-6": ("pro", "10", "kling-v1-6"), "pro mode / 5s duration / kling-v2-1": ("pro", "5", "kling-v2-1"), "pro mode / 10s duration / kling-v2-1": ("pro", "10", "kling-v2-1"), + "pro mode / 5s duration / kling-v2-5-turbo": ("pro", "5", "kling-v2-5-turbo"), + "pro mode / 10s duration / kling-v2-5-turbo": ("pro", "10", "kling-v2-5-turbo"), } """ Returns a mapping of mode strings to their corresponding (mode, duration, model_name) tuples. @@ -214,32 +213,48 @@ VOICES_CONFIG = { } -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> R: - """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - KlingTaskStatus.succeed.value, - ], - failed_statuses=[KlingTaskStatus.failed.value], - status_extractor=lambda response: ( - response.data.task_status.value - if response.data and response.data.task_status - else None - ), - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - estimated_duration=estimated_duration, - node_id=node_id, - poll_interval=16.0, - max_poll_attempts=256, - ).execute() +def normalize_omni_prompt_references(prompt: str) -> str: + """ + Rewrites Kling Omni-style placeholders used in the app, like: + + @image, @image1, @image2, ... @imageN + @video, @video1, @video2, ... @videoN + + into the API-compatible form: + + <<>>, <<>>, ... + <<>>, <<>>, ... + + This is a UX shim for ComfyUI so users can type the same syntax as in the Kling app. + """ + if not prompt: + return prompt + + def _image_repl(match): + return f"<<>>" + + def _video_repl(match): + return f"<<>>" + + # (? and not @imageFoo + prompt = re.sub(r"(?\d*)(?!\w)", _image_repl, prompt) + return re.sub(r"(?\d*)(?!\w)", _video_repl, prompt) + + +async def finish_omni_video_task(cls: type[IO.ComfyNode], response: OmniTaskStatusResponse) -> IO.NodeOutput: + if response.code: + raise RuntimeError( + f"Kling request failed. 
Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/videos/omni-video/{response.data.task_id}"), + response_model=OmniTaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + max_poll_attempts=160, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.data.task_result.videos[0].url)) def is_valid_camera_control_configs(configs: list[float]) -> bool: @@ -318,7 +333,7 @@ def validate_input_image(image: torch.Tensor) -> None: See: https://app.klingai.com/global/dev/document-api/apiReference/model/imageToVideo """ validate_image_dimensions(image, min_width=300, min_height=300) - validate_image_aspect_ratio(image, min_aspect_ratio=1 / 2.5, max_aspect_ratio=2.5) + validate_image_aspect_ratio(image, (1, 2.5), (2.5, 1)) def get_video_from_response(response) -> KlingVideoResult: @@ -332,7 +347,7 @@ def get_video_from_response(response) -> KlingVideoResult: return video -def get_video_url_from_response(response) -> Optional[str]: +def get_video_url_from_response(response) -> str | None: """Returns the first video url from the Kling video generation task result. Will not raise an error if the response is not valid. """ @@ -351,7 +366,7 @@ def get_images_from_response(response) -> list[KlingImageResult]: return images -def get_images_urls_from_response(response) -> Optional[str]: +def get_images_urls_from_response(response) -> str | None: """Returns the list of image urls from the Kling image generation task result. Will not raise an error if the response is not valid. If there is only one image, returns the url as a string. If there are multiple images, returns a list of urls. """ @@ -377,8 +392,7 @@ async def image_result_to_node_output( async def execute_text2video( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], prompt: str, negative_prompt: str, cfg_scale: float, @@ -386,17 +400,14 @@ async def execute_text2video( model_mode: str, duration: str, aspect_ratio: str, - camera_control: Optional[KlingCameraControl] = None, + camera_control: KlingCameraControl | None = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingText2VideoRequest, - response_model=KlingText2VideoResponse, - ), - request=KlingText2VideoRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"), + response_model=KlingText2VideoResponse, + data=KlingText2VideoRequest( prompt=prompt if prompt else None, negative_prompt=negative_prompt if negative_prompt else None, duration=KlingVideoGenDuration(duration), @@ -406,24 +417,17 @@ async def execute_text2video( aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio), camera_control=camera_control, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_TEXT_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingText2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_TEXT_TO_VIDEO}/{task_id}"), + 
response_model=KlingText2VideoResponse, estimated_duration=AVERAGE_DURATION_T2V, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -432,8 +436,7 @@ async def execute_text2video( async def execute_image2video( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], start_frame: torch.Tensor, prompt: str, negative_prompt: str, @@ -442,8 +445,8 @@ async def execute_image2video( model_mode: str, aspect_ratio: str, duration: str, - camera_control: Optional[KlingCameraControl] = None, - end_frame: Optional[torch.Tensor] = None, + camera_control: KlingCameraControl | None = None, + end_frame: torch.Tensor | None = None, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_I2V) validate_input_image(start_frame) @@ -455,14 +458,11 @@ async def execute_image2video( if model_mode == "std" and model_name == KlingVideoGenModelName.kling_v2_5_turbo.value: model_mode = "pro" # October 5: currently "std" mode is not supported for this model - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - request=KlingImage2VideoRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=KlingImage2VideoResponse, + data=KlingImage2VideoRequest( model_name=KlingVideoGenModelName(model_name), image=tensor_to_base64_string(start_frame), image_tail=( @@ -477,24 +477,17 @@ async def execute_image2video( duration=KlingVideoGenDuration(duration), camera_control=camera_control, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", - method=HttpMethod.GET, - request_model=KlingImage2VideoRequest, - response_model=KlingImage2VideoResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}"), + response_model=KlingImage2VideoResponse, estimated_duration=AVERAGE_DURATION_I2V, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -503,16 +496,15 @@ async def execute_image2video( async def execute_video_effect( - auth_kwargs: dict[str, str], - node_id: str, + cls: type[IO.ComfyNode], dual_character: bool, effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, model_name: str, duration: KlingVideoGenDuration, image_1: torch.Tensor, - image_2: Optional[torch.Tensor] = None, - model_mode: Optional[KlingVideoGenMode] = None, -) -> tuple[VideoFromFile, str, str]: + image_2: torch.Tensor | None = None, + model_mode: KlingVideoGenMode | None = None, +) -> tuple[InputImpl.VideoFromFile, str, str]: if dual_character: request_input_field = KlingDualCharacterEffectInput( model_name=model_name, @@ -530,35 +522,25 @@ async def execute_video_effect( duration=duration, ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIDEO_EFFECTS, - method=HttpMethod.POST, - request_model=KlingVideoEffectsRequest, - 
response_model=KlingVideoEffectsResponse, - ), - request=KlingVideoEffectsRequest( + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_VIDEO_EFFECTS, method="POST"), + response_model=KlingVideoEffectsResponse, + data=KlingVideoEffectsRequest( effect_scene=effect_scene, input=request_input_field, ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_VIDEO_EFFECTS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVideoEffectsResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_EFFECTS}/{task_id}"), + response_model=KlingVideoEffectsResponse, estimated_duration=AVERAGE_DURATION_VIDEO_EFFECTS, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -567,15 +549,14 @@ async def execute_video_effect( async def execute_lipsync( - auth_kwargs: dict[str, str], - node_id: str, - video: VideoInput, - audio: Optional[AudioInput] = None, - voice_language: Optional[str] = None, - model_mode: Optional[str] = None, - text: Optional[str] = None, - voice_speed: Optional[float] = None, - voice_id: Optional[str] = None, + cls: type[IO.ComfyNode], + video: Input.Video, + audio: Input.Audio | None = None, + voice_language: str | None = None, + model_mode: str | None = None, + text: str | None = None, + voice_speed: float | None = None, + voice_id: str | None = None, ) -> IO.NodeOutput: if text: validate_string(text, field_name="Text", max_length=MAX_PROMPT_LENGTH_LIP_SYNC) @@ -583,24 +564,23 @@ async def execute_lipsync( validate_video_duration(video, 2, 10) # Upload video to Comfy API and get download URL - video_url = await upload_video_to_comfyapi(video, auth_kwargs=auth_kwargs) + video_url = await upload_video_to_comfyapi(cls, video) logging.info("Uploaded video to Comfy API. URL: %s", video_url) # Upload the audio file to Comfy API and get download URL if audio: - audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=auth_kwargs) + audio_url = await upload_audio_to_comfyapi( + cls, audio, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mpeg", filename="output.mp3" + ) logging.info("Uploaded audio to Comfy API. 
URL: %s", audio_url) else: audio_url = None - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_LIP_SYNC, - method=HttpMethod.POST, - request_model=KlingLipSyncRequest, - response_model=KlingLipSyncResponse, - ), - request=KlingLipSyncRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(PATH_LIP_SYNC, "POST"), + response_model=KlingLipSyncResponse, + data=KlingLipSyncRequest( input=KlingLipSyncInputObject( video_url=video_url, mode=model_mode, @@ -612,24 +592,17 @@ async def execute_lipsync( voice_id=voice_id, ), ), - auth_kwargs=auth_kwargs, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_LIP_SYNC}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingLipSyncResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_LIP_SYNC}/{task_id}"), + response_model=KlingLipSyncResponse, estimated_duration=AVERAGE_DURATION_LIP_SYNC, - node_id=node_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -807,11 +780,7 @@ class KlingTextToVideoNode(IO.ComfyNode): ) -> IO.NodeOutput: model_mode, duration, model_name = MODE_TEXT2VIDEO[mode] return await execute_text2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt=prompt, negative_prompt=negative_prompt, cfg_scale=cfg_scale, @@ -822,6 +791,474 @@ class KlingTextToVideoNode(IO.ComfyNode): ) +class OmniProTextToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProTextToVideoNode", + display_name="Kling Omni Text to Video (Pro)", + category="api node/video/Kling", + description="Use text prompts to generate videos with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. 
" + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Combo.Input("duration", options=[5, 10]), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=2500) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=OmniTaskStatusResponse, + data=OmniProText2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProFirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProFirstLastFrameNode", + display_name="Kling Omni First-Last-Frame to Video (Pro)", + category="api node/video/Kling", + description="Use a start frame, an optional end frame, or reference images with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("duration", options=["5", "10"]), + IO.Image.Input("first_frame"), + IO.Image.Input( + "end_frame", + optional=True, + tooltip="An optional end frame for the video. " + "This cannot be used simultaneously with 'reference_images'.", + ), + IO.Image.Input( + "reference_images", + optional=True, + tooltip="Up to 6 additional reference images.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + duration: int, + first_frame: Input.Image, + end_frame: Input.Image | None = None, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + if end_frame is not None and reference_images is not None: + raise ValueError("The 'end_frame' input cannot be used simultaneously with 'reference_images'.") + validate_image_dimensions(first_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(first_frame, (1, 2.5), (2.5, 1)) + image_list: list[OmniParamImage] = [ + OmniParamImage( + image_url=(await upload_images_to_comfyapi(cls, first_frame, wait_label="Uploading first frame"))[0], + type="first_frame", + ) + ] + if end_frame is not None: + validate_image_dimensions(end_frame, min_width=300, min_height=300) + validate_image_aspect_ratio(end_frame, (1, 2.5), (2.5, 1)) + image_list.append( + OmniParamImage( + image_url=(await upload_images_to_comfyapi(cls, end_frame, wait_label="Uploading end frame"))[0], + type="end_frame", + ) + ) + if reference_images is not None: + if get_number_of_images(reference_images) > 6: + raise ValueError("The maximum number of reference images allowed is 6.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, 
reference_images, wait_label="Uploading reference frame(s)"): + image_list.append(OmniParamImage(image_url=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=OmniTaskStatusResponse, + data=OmniProFirstLastFrameRequest( + model_name=model_name, + prompt=prompt, + duration=str(duration), + image_list=image_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProImageToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProImageToVideoNode", + display_name="Kling Omni Image to Video (Pro)", + category="api node/video/Kling", + description="Use up to 7 reference images to generate a video with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Image.Input( + "reference_images", + tooltip="Up to 7 reference images.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + reference_images: Input.Image, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + if get_number_of_images(reference_images) > 7: + raise ValueError("The maximum number of reference images is 7.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + image_list: list[OmniParamImage] = [] + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=OmniTaskStatusResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProVideoToVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProVideoToVideoNode", + display_name="Kling Omni Video to Video (Pro)", + category="api node/video/Kling", + description="Use a video and up to 4 reference images to generate a video with the latest Kling model.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. 
" + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), + IO.Int.Input("duration", default=3, min=3, max=10, display_mode=IO.NumberDisplay.slider), + IO.Video.Input("reference_video", tooltip="Video to use as a reference."), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Image.Input( + "reference_images", + tooltip="Up to 4 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + aspect_ratio: str, + duration: int, + reference_video: Input.Video, + keep_original_sound: bool, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + validate_video_duration(reference_video, min_duration=3.0, max_duration=10.05) + validate_video_dimensions(reference_video, min_width=720, min_height=720, max_width=2160, max_height=2160) + image_list: list[OmniParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 4: + raise ValueError("The maximum number of reference images allowed with a video input is 4.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + video_list = [ + OmniParamVideo( + video_url=await upload_video_to_comfyapi(cls, reference_video, wait_label="Uploading reference video"), + refer_type="feature", + keep_original_sound="yes" if keep_original_sound else "no", + ) + ] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=OmniTaskStatusResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=aspect_ratio, + duration=str(duration), + image_list=image_list if image_list else None, + video_list=video_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProEditVideoNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProEditVideoNode", + display_name="Kling Omni Edit Video (Pro)", + category="api node/video/Kling", + description="Edit an existing video with the latest model from Kling.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-video-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the video content. " + "This can include both positive and negative descriptions.", + ), + IO.Video.Input("video", tooltip="Video for editing. 
The output video length will be the same."), + IO.Boolean.Input("keep_original_sound", default=True), + IO.Image.Input( + "reference_images", + tooltip="Up to 4 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + video: Input.Video, + keep_original_sound: bool, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + validate_video_duration(video, min_duration=3.0, max_duration=10.05) + validate_video_dimensions(video, min_width=720, min_height=720, max_width=2160, max_height=2160) + image_list: list[OmniParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 4: + raise ValueError("The maximum number of reference images allowed with a video input is 4.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniParamImage(image_url=i)) + video_list = [ + OmniParamVideo( + video_url=await upload_video_to_comfyapi(cls, video, wait_label="Uploading base video"), + refer_type="base", + keep_original_sound="yes" if keep_original_sound else "no", + ) + ] + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/videos/omni-video", method="POST"), + response_model=OmniTaskStatusResponse, + data=OmniProReferences2VideoRequest( + model_name=model_name, + prompt=prompt, + aspect_ratio=None, + duration=None, + image_list=image_list if image_list else None, + video_list=video_list, + ), + ) + return await finish_omni_video_task(cls, response) + + +class OmniProImageNode(IO.ComfyNode): + + @classmethod + def define_schema(cls) -> IO.Schema: + return IO.Schema( + node_id="KlingOmniProImageNode", + display_name="Kling Omni Image (Pro)", + category="api node/image/Kling", + description="Create or edit images with the latest model from Kling.", + inputs=[ + IO.Combo.Input("model_name", options=["kling-image-o1"]), + IO.String.Input( + "prompt", + multiline=True, + tooltip="A text prompt describing the image content. 
" + "This can include both positive and negative descriptions.", + ), + IO.Combo.Input("resolution", options=["1K", "2K"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16", "1:1", "4:3", "3:4", "3:2", "2:3", "21:9"], + ), + IO.Image.Input( + "reference_images", + tooltip="Up to 10 additional reference images.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model_name: str, + prompt: str, + resolution: str, + aspect_ratio: str, + reference_images: Input.Image | None = None, + ) -> IO.NodeOutput: + prompt = normalize_omni_prompt_references(prompt) + validate_string(prompt, min_length=1, max_length=2500) + image_list: list[OmniImageParamImage] = [] + if reference_images is not None: + if get_number_of_images(reference_images) > 10: + raise ValueError("The maximum number of reference images is 10.") + for i in reference_images: + validate_image_dimensions(i, min_width=300, min_height=300) + validate_image_aspect_ratio(i, (1, 2.5), (2.5, 1)) + for i in await upload_images_to_comfyapi(cls, reference_images, wait_label="Uploading reference image"): + image_list.append(OmniImageParamImage(image=i)) + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/kling/v1/images/omni-image", method="POST"), + response_model=OmniTaskStatusResponse, + data=OmniProImageRequest( + model_name=model_name, + prompt=prompt, + resolution=resolution.lower(), + aspect_ratio=aspect_ratio, + image_list=image_list if image_list else None, + ), + ) + if response.code: + raise RuntimeError( + f"Kling request failed. Code: {response.code}, Message: {response.message}, Data: {response.data}" + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/kling/v1/images/omni-image/{response.data.task_id}"), + response_model=OmniTaskStatusResponse, + status_extractor=lambda r: (r.data.task_status if r.data else None), + ) + return IO.NodeOutput(await download_url_to_image_tensor(final_response.data.task_result.images[0].url)) + + class KlingCameraControlT2VNode(IO.ComfyNode): """ Kling Text to Video Camera Control Node. This node is a text to video node, but it supports controlling the camera. 
@@ -869,14 +1306,10 @@ class KlingCameraControlT2VNode(IO.ComfyNode): negative_prompt: str, cfg_scale: float, aspect_ratio: str, - camera_control: Optional[KlingCameraControl] = None, + camera_control: KlingCameraControl | None = None, ) -> IO.NodeOutput: return await execute_text2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, model_name=KlingVideoGenModelName.kling_v1, cfg_scale=cfg_scale, model_mode=KlingVideoGenMode.std, @@ -940,15 +1373,11 @@ class KlingImage2VideoNode(IO.ComfyNode): mode: str, aspect_ratio: str, duration: str, - camera_control: Optional[KlingCameraControl] = None, - end_frame: Optional[torch.Tensor] = None, + camera_control: KlingCameraControl | None = None, + end_frame: torch.Tensor | None = None, ) -> IO.NodeOutput: return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, start_frame=start_frame, prompt=prompt, negative_prompt=negative_prompt, @@ -1017,11 +1446,7 @@ class KlingCameraControlI2VNode(IO.ComfyNode): camera_control: KlingCameraControl, ) -> IO.NodeOutput: return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, model_name=KlingVideoGenModelName.kling_v1_5, start_frame=start_frame, cfg_scale=cfg_scale, @@ -1059,15 +1484,11 @@ class KlingStartEndFrameNode(IO.ComfyNode): IO.String.Input("prompt", multiline=True, tooltip="Positive text prompt"), IO.String.Input("negative_prompt", multiline=True, tooltip="Negative text prompt"), IO.Float.Input("cfg_scale", default=0.5, min=0.0, max=1.0), - IO.Combo.Input( - "aspect_ratio", - options=[i.value for i in KlingVideoGenAspectRatio], - default="16:9", - ), + IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]), IO.Combo.Input( "mode", options=modes, - default=modes[2], + default=modes[8], tooltip="The configuration to use for the video generation following the format: mode / duration / model_name.", ), ], @@ -1097,11 +1518,7 @@ class KlingStartEndFrameNode(IO.ComfyNode): ) -> IO.NodeOutput: mode, duration, model_name = MODE_START_END_FRAME[mode] return await execute_image2video( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt=prompt, negative_prompt=negative_prompt, model_name=model_name, @@ -1162,41 +1579,27 @@ class KlingVideoExtendNode(IO.ComfyNode): video_id: str, ) -> IO.NodeOutput: validate_prompts(prompt, negative_prompt, MAX_PROMPT_LENGTH_T2V) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_VIDEO_EXTEND, - method=HttpMethod.POST, - request_model=KlingVideoExtendRequest, - response_model=KlingVideoExtendResponse, - ), - request=KlingVideoExtendRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_VIDEO_EXTEND, method="POST"), + response_model=KlingVideoExtendResponse, + data=KlingVideoExtendRequest( prompt=prompt if prompt else None, negative_prompt=negative_prompt if negative_prompt else None, cfg_scale=cfg_scale, video_id=video_id, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() 
validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_VIDEO_EXTEND}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVideoExtendResponse, - ), - result_url_extractor=get_video_url_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_EXTEND}/{task_id}"), + response_model=KlingVideoExtendResponse, estimated_duration=AVERAGE_DURATION_VIDEO_EXTEND, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_video_result_response(final_response) @@ -1259,11 +1662,7 @@ class KlingDualCharacterVideoEffectNode(IO.ComfyNode): duration: KlingVideoGenDuration, ) -> IO.NodeOutput: video, _, duration = await execute_video_effect( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, dual_character=True, effect_scene=effect_scene, model_name=model_name, @@ -1286,7 +1685,10 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): category="api node/video/Kling", description="Achieve different special effects when generating a video based on the effect_scene.", inputs=[ - IO.Image.Input("image", tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1"), + IO.Image.Input( + "image", + tooltip=" Reference Image. URL or Base64 encoded string (without data:image prefix). File size cannot exceed 10MB, resolution not less than 300*300px, aspect ratio between 1:2.5 ~ 2.5:1", + ), IO.Combo.Input( "effect_scene", options=[i.value for i in KlingSingleImageEffectsScene], @@ -1324,11 +1726,7 @@ class KlingSingleImageVideoEffectNode(IO.ComfyNode): return IO.NodeOutput( *( await execute_video_effect( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, dual_character=False, effect_scene=effect_scene, model_name=model_name, @@ -1374,16 +1772,12 @@ class KlingLipSyncAudioToVideoNode(IO.ComfyNode): @classmethod async def execute( cls, - video: VideoInput, - audio: AudioInput, + video: Input.Video, + audio: Input.Audio, voice_language: str, ) -> IO.NodeOutput: return await execute_lipsync( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, video=video, audio=audio, voice_language=voice_language, @@ -1438,18 +1832,14 @@ class KlingLipSyncTextToVideoNode(IO.ComfyNode): @classmethod async def execute( cls, - video: VideoInput, + video: Input.Video, text: str, voice: str, voice_speed: float, ) -> IO.NodeOutput: voice_id, voice_language = VOICES_CONFIG[voice] return await execute_lipsync( - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, video=video, text=text, voice_language=voice_language, @@ -1496,40 +1886,26 @@ class KlingVirtualTryOnNode(IO.ComfyNode): cloth_image: torch.Tensor, model_name: KlingVirtualTryOnModelName, ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - 
endpoint=ApiEndpoint( - path=PATH_VIRTUAL_TRY_ON, - method=HttpMethod.POST, - request_model=KlingVirtualTryOnRequest, - response_model=KlingVirtualTryOnResponse, - ), - request=KlingVirtualTryOnRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_VIRTUAL_TRY_ON, method="POST"), + response_model=KlingVirtualTryOnResponse, + data=KlingVirtualTryOnRequest( human_image=tensor_to_base64_string(human_image), cloth_image=tensor_to_base64_string(cloth_image), model_name=model_name, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingVirtualTryOnResponse, - ), - result_url_extractor=get_images_urls_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}"), + response_model=KlingVirtualTryOnResponse, estimated_duration=AVERAGE_DURATION_VIRTUAL_TRY_ON, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_image_result_response(final_response) @@ -1613,7 +1989,7 @@ class KlingImageGenerationNode(IO.ComfyNode): human_fidelity: float, n: int, aspect_ratio: KlingImageGenAspectRatio, - image: Optional[torch.Tensor] = None, + image: torch.Tensor | None = None, ) -> IO.NodeOutput: validate_string(prompt, field_name="prompt", min_length=1, max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) validate_string(negative_prompt, field_name="negative_prompt", max_length=MAX_PROMPT_LENGTH_IMAGE_GEN) @@ -1625,18 +2001,11 @@ class KlingImageGenerationNode(IO.ComfyNode): else: image = tensor_to_base64_string(image) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_GENERATIONS, - method=HttpMethod.POST, - request_model=KlingImageGenerationsRequest, - response_model=KlingImageGenerationsResponse, - ), - request=KlingImageGenerationsRequest( + task_creation_response = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_GENERATIONS, method="POST"), + response_model=KlingImageGenerationsResponse, + data=KlingImageGenerationsRequest( model_name=model_name, prompt=prompt, negative_prompt=negative_prompt, @@ -1647,24 +2016,17 @@ class KlingImageGenerationNode(IO.ComfyNode): n=n, aspect_ratio=aspect_ratio, ), - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = await poll_until_finished( - auth, - ApiEndpoint( - path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=KlingImageGenerationsResponse, - ), - result_url_extractor=get_images_urls_from_response, + final_response = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_IMAGE_GENERATIONS}/{task_id}"), + response_model=KlingImageGenerationsResponse, estimated_duration=AVERAGE_DURATION_IMAGE_GEN, - node_id=cls.hidden.unique_id, + status_extractor=lambda r: (r.data.task_status.value if r.data and r.data.task_status else None), ) validate_image_result_response(final_response) @@ -1689,6 +2051,12 @@ class KlingExtension(ComfyExtension): 
KlingImageGenerationNode, KlingSingleImageVideoEffectNode, KlingDualCharacterVideoEffectNode, + OmniProTextToVideoNode, + OmniProFirstLastFrameNode, + OmniProImageToVideoNode, + OmniProVideoToVideoNode, + OmniProEditVideoNode, + # OmniProImageNode, # need support from backend ] diff --git a/comfy_api_nodes/nodes_ltxv.py b/comfy_api_nodes/nodes_ltxv.py new file mode 100644 index 000000000..7e61560dc --- /dev/null +++ b/comfy_api_nodes/nodes_ltxv.py @@ -0,0 +1,196 @@ +from io import BytesIO + +from pydantic import BaseModel, Field +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl +from comfy_api_nodes.util import ( + ApiEndpoint, + get_number_of_images, + sync_op_raw, + upload_images_to_comfyapi, + validate_string, +) + +MODELS_MAP = { + "LTX-2 (Pro)": "ltx-2-pro", + "LTX-2 (Fast)": "ltx-2-fast", +} + + +class ExecuteTaskRequest(BaseModel): + prompt: str = Field(...) + model: str = Field(...) + duration: int = Field(...) + resolution: str = Field(...) + fps: int | None = Field(25) + generate_audio: bool | None = Field(True) + image_uri: str | None = Field(None) + + +class TextToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="LtxvApiTextToVideo", + display_name="LTXV Text To Video", + category="api node/video/LTXV", + description="Professional-quality videos with customizable duration and resolution.", + inputs=[ + IO.Combo.Input("model", options=list(MODELS_MAP.keys())), + IO.String.Input( + "prompt", + multiline=True, + default="", + ), + IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8), + IO.Combo.Input( + "resolution", + options=[ + "1920x1080", + "2560x1440", + "3840x2160", + ], + ), + IO.Combo.Input("fps", options=[25, 50], default=25), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="When true, the generated video will include AI-generated audio matching the scene.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + prompt: str, + duration: int, + resolution: str, + fps: int = 25, + generate_audio: bool = False, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=10000) + if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25): + raise ValueError( + "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS." 
+ ) + response = await sync_op_raw( + cls, + ApiEndpoint("/proxy/ltx/v1/text-to-video", "POST"), + data=ExecuteTaskRequest( + prompt=prompt, + model=MODELS_MAP[model], + duration=duration, + resolution=resolution, + fps=fps, + generate_audio=generate_audio, + ), + as_binary=True, + max_retries=1, + ) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) + + +class ImageToVideoNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="LtxvApiImageToVideo", + display_name="LTXV Image To Video", + category="api node/video/LTXV", + description="Professional-quality videos with customizable duration and resolution based on start image.", + inputs=[ + IO.Image.Input("image", tooltip="First frame to be used for the video."), + IO.Combo.Input("model", options=list(MODELS_MAP.keys())), + IO.String.Input( + "prompt", + multiline=True, + default="", + ), + IO.Combo.Input("duration", options=[6, 8, 10, 12, 14, 16, 18, 20], default=8), + IO.Combo.Input( + "resolution", + options=[ + "1920x1080", + "2560x1440", + "3840x2160", + ], + ), + IO.Combo.Input("fps", options=[25, 50], default=25), + IO.Boolean.Input( + "generate_audio", + default=False, + optional=True, + tooltip="When true, the generated video will include AI-generated audio matching the scene.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + image: Input.Image, + model: str, + prompt: str, + duration: int, + resolution: str, + fps: int = 25, + generate_audio: bool = False, + ) -> IO.NodeOutput: + validate_string(prompt, min_length=1, max_length=10000) + if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25): + raise ValueError( + "Durations over 10s are only available for the Fast model at 1920x1080 resolution and 25 FPS." 
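Both LTX-2 nodes in the new `nodes_ltxv.py` repeat the duration guard whose error message appears just above. As a reading aid only, here is that rule factored into a standalone helper; the helper name is illustrative and the thresholds simply mirror the diff.

```python
# Hedged sketch: the duration guard shared by TextToVideoNode and
# ImageToVideoNode above, expressed as one helper. Not part of the change.
def validate_ltx2_duration(model: str, duration: int, resolution: str, fps: int) -> None:
    """Durations over 10s require the Fast model at 1920x1080 and 25 FPS."""
    if duration > 10 and (model != "LTX-2 (Fast)" or resolution != "1920x1080" or fps != 25):
        raise ValueError(
            "Durations over 10s are only available for the Fast model "
            "at 1920x1080 resolution and 25 FPS."
        )

# Example: validate_ltx2_duration("LTX-2 (Pro)", 12, "1920x1080", 25) raises ValueError.
```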
+ ) + if get_number_of_images(image) != 1: + raise ValueError("Currently only one input image is supported.") + response = await sync_op_raw( + cls, + ApiEndpoint("/proxy/ltx/v1/image-to-video", "POST"), + data=ExecuteTaskRequest( + image_uri=(await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png"))[0], + prompt=prompt, + model=MODELS_MAP[model], + duration=duration, + resolution=resolution, + fps=fps, + generate_audio=generate_audio, + ), + as_binary=True, + max_retries=1, + ) + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response))) + + +class LtxvApiExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TextToVideoNode, + ImageToVideoNode, + ] + + +async def comfy_entrypoint() -> LtxvApiExtension: + return LtxvApiExtension() diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 06c9845cf..b7c8c2eac 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -1,69 +1,51 @@ -from __future__ import annotations -from inspect import cleandoc from typing import Optional + +import torch from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api.input_impl.video_types import VideoFromFile + +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.luma_api import ( - LumaImageModel, - LumaVideoModel, - LumaVideoOutputResolution, - LumaVideoModelOutputDuration, LumaAspectRatio, - LumaState, - LumaImageGenerationRequest, - LumaGenerationRequest, - LumaGeneration, LumaCharacterRef, - LumaModifyImageRef, + LumaConceptChain, + LumaGeneration, + LumaGenerationRequest, + LumaImageGenerationRequest, LumaImageIdentity, + LumaImageModel, + LumaImageReference, + LumaIO, + LumaKeyframes, + LumaModifyImageRef, LumaReference, LumaReferenceChain, - LumaImageReference, - LumaKeyframes, - LumaConceptChain, - LumaIO, + LumaVideoModel, + LumaVideoModelOutputDuration, + LumaVideoOutputResolution, get_luma_concepts, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( + download_url_to_image_tensor, + download_url_to_video_output, + poll_op, + sync_op, upload_images_to_comfyapi, - process_image_response, validate_string, ) -from comfy.cmd.server import PromptServer - -import aiohttp -import torch -from io import BytesIO LUMA_T2V_AVERAGE_DURATION = 105 LUMA_I2V_AVERAGE_DURATION = 100 -def image_result_url_extractor(response: LumaGeneration): - return response.assets.image if hasattr(response, "assets") and hasattr(response.assets, "image") else None - -def video_result_url_extractor(response: LumaGeneration): - return response.assets.video if hasattr(response, "assets") and hasattr(response.assets, "video") else None class LumaReferenceNode(IO.ComfyNode): - """ - Holds an image and weight for use with Luma Generate Image node. 
- """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaReferenceNode", display_name="Luma Reference", category="api node/image/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Holds an image and weight for use with Luma Generate Image node.", inputs=[ IO.Image.Input( "image", @@ -83,17 +65,10 @@ class LumaReferenceNode(IO.ComfyNode): ), ], outputs=[IO.Custom(LumaIO.LUMA_REF).Output(display_name="luma_ref")], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], ) @classmethod - def execute( - cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None - ) -> IO.NodeOutput: + def execute(cls, image: torch.Tensor, weight: float, luma_ref: LumaReferenceChain = None) -> IO.NodeOutput: if luma_ref is not None: luma_ref = luma_ref.clone() else: @@ -103,17 +78,13 @@ class LumaReferenceNode(IO.ComfyNode): class LumaConceptsNode(IO.ComfyNode): - """ - Holds one or more Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaConceptsNode", display_name="Luma Concepts", category="api node/video/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Camera Concepts for use with Luma Text to Video and Luma Image to Video nodes.", inputs=[ IO.Combo.Input( "concept1", @@ -138,21 +109,16 @@ class LumaConceptsNode(IO.ComfyNode): ), ], outputs=[IO.Custom(LumaIO.LUMA_CONCEPTS).Output(display_name="luma_concepts")], - hidden=[ - IO.Hidden.auth_token_comfy_org, - IO.Hidden.api_key_comfy_org, - IO.Hidden.unique_id, - ], ) @classmethod def execute( - cls, - concept1: str, - concept2: str, - concept3: str, - concept4: str, - luma_concepts: LumaConceptChain = None, + cls, + concept1: str, + concept2: str, + concept3: str, + concept4: str, + luma_concepts: LumaConceptChain = None, ) -> IO.NodeOutput: chain = LumaConceptChain(str_list=[concept1, concept2, concept3, concept4]) if luma_concepts is not None: @@ -161,17 +127,13 @@ class LumaConceptsNode(IO.ComfyNode): class LumaImageGenerationNode(IO.ComfyNode): - """ - Generates images synchronously based on prompt and aspect ratio. 
- """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaImageNode", display_name="Luma Text to Image", category="api node/image/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Generates images synchronously based on prompt and aspect ratio.", inputs=[ IO.String.Input( "prompt", @@ -231,51 +193,36 @@ class LumaImageGenerationNode(IO.ComfyNode): @classmethod async def execute( - cls, - prompt: str, - model: str, - aspect_ratio: str, - seed, - style_image_weight: float, - image_luma_ref: LumaReferenceChain = None, - style_image: torch.Tensor = None, - character_image: torch.Tensor = None, + cls, + prompt: str, + model: str, + aspect_ratio: str, + seed, + style_image_weight: float, + image_luma_ref: Optional[LumaReferenceChain] = None, + style_image: Optional[torch.Tensor] = None, + character_image: Optional[torch.Tensor] = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=True, min_length=3) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } # handle image_luma_ref api_image_ref = None if image_luma_ref is not None: - api_image_ref = await cls._convert_luma_refs( - image_luma_ref, max_refs=4, auth_kwargs=auth_kwargs, - ) + api_image_ref = await cls._convert_luma_refs(image_luma_ref, max_refs=4) # handle style_luma_ref api_style_ref = None if style_image is not None: - api_style_ref = await cls._convert_style_image( - style_image, weight=style_image_weight, auth_kwargs=auth_kwargs, - ) + api_style_ref = await cls._convert_style_image(style_image, weight=style_image_weight) # handle character_ref images character_ref = None if character_image is not None: - download_urls = await upload_images_to_comfyapi( - character_image, max_images=4, auth_kwargs=auth_kwargs, - ) - character_ref = LumaCharacterRef( - identity0=LumaImageIdentity(images=download_urls) - ) + download_urls = await upload_images_to_comfyapi(cls, character_image, max_images=4) + character_ref = LumaCharacterRef(identity0=LumaImageIdentity(images=download_urls)) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations/image", - method=HttpMethod.POST, - request_model=LumaImageGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaImageGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations/image", method="POST"), + response_model=LumaGeneration, + data=LumaImageGenerationRequest( prompt=prompt, model=model, aspect_ratio=aspect_ratio, @@ -283,41 +230,21 @@ class LumaImageGenerationNode(IO.ComfyNode): style_ref=api_style_ref, character_ref=character_ref, ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=image_result_url_extractor, - node_id=cls.hidden.unique_id, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.image) as img_response: - img = 
process_image_response(await img_response.content.read()) - return IO.NodeOutput(img) + return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image)) @classmethod - async def _convert_luma_refs( - cls, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None - ): + async def _convert_luma_refs(cls, luma_ref: LumaReferenceChain, max_refs: int): luma_urls = [] ref_count = 0 for ref in luma_ref.refs: - download_urls = await upload_images_to_comfyapi( - ref.image, max_images=1, auth_kwargs=auth_kwargs - ) + download_urls = await upload_images_to_comfyapi(cls, ref.image, max_images=1) luma_urls.append(download_urls[0]) ref_count += 1 if ref_count >= max_refs: @@ -325,27 +252,19 @@ class LumaImageGenerationNode(IO.ComfyNode): return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) @classmethod - async def _convert_style_image( - cls, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None - ): - chain = LumaReferenceChain( - first_ref=LumaReference(image=style_image, weight=weight) - ) - return await cls._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) + async def _convert_style_image(cls, style_image: torch.Tensor, weight: float): + chain = LumaReferenceChain(first_ref=LumaReference(image=style_image, weight=weight)) + return await cls._convert_luma_refs(chain, max_refs=1) class LumaImageModifyNode(IO.ComfyNode): - """ - Modifies images synchronously based on prompt and aspect ratio. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaImageModifyNode", display_name="Luma Image to Image", category="api node/image/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Modifies images synchronously based on prompt and aspect ratio.", inputs=[ IO.Image.Input( "image", @@ -388,75 +307,44 @@ class LumaImageModifyNode(IO.ComfyNode): @classmethod async def execute( - cls, - prompt: str, - model: str, - image: torch.Tensor, - image_weight: float, - seed, + cls, + prompt: str, + model: str, + image: torch.Tensor, + image_weight: float, + seed, ) -> IO.NodeOutput: - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - # first, upload image - download_urls = await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=auth_kwargs, - ) + download_urls = await upload_images_to_comfyapi(cls, image, max_images=1) image_url = download_urls[0] - # next, make Luma call with download url provided - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations/image", - method=HttpMethod.POST, - request_model=LumaImageGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaImageGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations/image", method="POST"), + response_model=LumaGeneration, + data=LumaImageGenerationRequest( prompt=prompt, model=model, modify_image_ref=LumaModifyImageRef( - url=image_url, weight=round(max(min(1.0-image_weight, 0.98), 0.0), 2) + url=image_url, weight=round(max(min(1.0 - image_weight, 0.98), 0.0), 2) ), ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - 
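One detail worth calling out in `LumaImageModifyNode` above: the weight sent to the API is the inverse of the node's `image_weight`, clamped before upload. A small sketch of that arithmetic, with an illustrative function name:

```python
# Hedged sketch of the modify-weight mapping kept by LumaImageModifyNode above:
# the API weight is 1.0 - image_weight, clamped to [0.0, 0.98] and rounded.
def modify_weight(image_weight: float) -> float:
    return round(max(min(1.0 - image_weight, 0.98), 0.0), 2)

# modify_weight(0.0) == 0.98, modify_weight(0.25) == 0.75, modify_weight(1.0) == 0.0
```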
failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=image_result_url_extractor, - node_id=cls.hidden.unique_id, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.image) as img_response: - img = process_image_response(await img_response.content.read()) - return IO.NodeOutput(img) + return IO.NodeOutput(await download_url_to_image_tensor(response_poll.assets.image)) class LumaTextToVideoGenerationNode(IO.ComfyNode): - """ - Generates videos synchronously based on prompt and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaVideoNode", display_name="Luma Text to Video", category="api node/video/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on prompt and output_size.", inputs=[ IO.String.Input( "prompt", @@ -498,7 +386,7 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): "luma_concepts", tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", optional=True, - ) + ), ], outputs=[IO.Video.Output()], hidden=[ @@ -511,32 +399,25 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): @classmethod async def execute( - cls, - prompt: str, - model: str, - aspect_ratio: str, - resolution: str, - duration: str, - loop: bool, - seed, - luma_concepts: LumaConceptChain = None, + cls, + prompt: str, + model: str, + aspect_ratio: str, + resolution: str, + duration: str, + loop: bool, + seed, + luma_concepts: Optional[LumaConceptChain] = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, min_length=3) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations", - method=HttpMethod.POST, - request_model=LumaGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations", method="POST"), + response_model=LumaGeneration, + data=LumaGenerationRequest( prompt=prompt, model=model, resolution=resolution, @@ -545,47 +426,25 @@ class LumaTextToVideoGenerationNode(IO.ComfyNode): loop=loop, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - if cls.hidden.unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=video_result_url_extractor, - node_id=cls.hidden.unique_id, 
estimated_duration=LUMA_T2V_AVERAGE_DURATION, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.video) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video)) class LumaImageToVideoGenerationNode(IO.ComfyNode): - """ - Generates videos synchronously based on prompt, input images, and output_size. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="LumaImageToVideoNode", display_name="Luma Image to Video", category="api node/video/Luma", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on prompt, input images, and output_size.", inputs=[ IO.String.Input( "prompt", @@ -637,7 +496,7 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): "luma_concepts", tooltip="Optional Camera Concepts to dictate camera motion via the Luma Concepts node.", optional=True, - ) + ), ], outputs=[IO.Video.Output()], hidden=[ @@ -650,37 +509,27 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): @classmethod async def execute( - cls, - prompt: str, - model: str, - resolution: str, - duration: str, - loop: bool, - seed, - first_image: torch.Tensor = None, - last_image: torch.Tensor = None, - luma_concepts: LumaConceptChain = None, + cls, + prompt: str, + model: str, + resolution: str, + duration: str, + loop: bool, + seed, + first_image: torch.Tensor = None, + last_image: torch.Tensor = None, + luma_concepts: LumaConceptChain = None, ) -> IO.NodeOutput: if first_image is None and last_image is None: - raise Exception( - "At least one of first_image and last_image requires an input." 
- ) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - keyframes = await cls._convert_to_keyframes(first_image, last_image, auth_kwargs=auth_kwargs) + raise Exception("At least one of first_image and last_image requires an input.") + keyframes = await cls._convert_to_keyframes(first_image, last_image) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/luma/generations", - method=HttpMethod.POST, - request_model=LumaGenerationRequest, - response_model=LumaGeneration, - ), - request=LumaGenerationRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/luma/generations", method="POST"), + response_model=LumaGeneration, + data=LumaGenerationRequest( prompt=prompt, model=model, aspect_ratio=LumaAspectRatio.ratio_16_9, # ignored, but still needed by the API for some reason @@ -690,54 +539,31 @@ class LumaImageToVideoGenerationNode(IO.ComfyNode): keyframes=keyframes, concepts=luma_concepts.create_api_model() if luma_concepts else None, ), - auth_kwargs=auth_kwargs, ) - response_api: LumaGeneration = await operation.execute() - - if cls.hidden.unique_id: - PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", cls.hidden.unique_id) - - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/luma/generations/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=LumaGeneration, - ), - completed_statuses=[LumaState.completed], - failed_statuses=[LumaState.failed], + response_poll = await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/luma/generations/{response_api.id}"), + response_model=LumaGeneration, status_extractor=lambda x: x.state, - result_url_extractor=video_result_url_extractor, - node_id=cls.hidden.unique_id, estimated_duration=LUMA_I2V_AVERAGE_DURATION, - auth_kwargs=auth_kwargs, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.assets.video) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.assets.video)) @classmethod async def _convert_to_keyframes( - cls, - first_image: torch.Tensor = None, - last_image: torch.Tensor = None, - auth_kwargs: Optional[dict[str,str]] = None, + cls, + first_image: torch.Tensor = None, + last_image: torch.Tensor = None, ): if first_image is None and last_image is None: return None frame0 = None frame1 = None if first_image is not None: - download_urls = await upload_images_to_comfyapi( - first_image, max_images=1, auth_kwargs=auth_kwargs, - ) + download_urls = await upload_images_to_comfyapi(cls, first_image, max_images=1) frame0 = LumaImageReference(type="image", url=download_urls[0]) if last_image is not None: - download_urls = await upload_images_to_comfyapi( - last_image, max_images=1, auth_kwargs=auth_kwargs, - ) + download_urls = await upload_images_to_comfyapi(cls, last_image, max_images=1) frame1 = LumaImageReference(type="image", url=download_urls[0]) return LumaKeyframes(frame0=frame0, frame1=frame1) diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index d605300d7..2517060bc 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ 
-1,71 +1,57 @@ -from inspect import cleandoc from typing import Optional -import logging -import torch +import torch from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api_nodes.apis import ( + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apis.minimax_api import ( + MinimaxFileRetrieveResponse, + MiniMaxModel, + MinimaxTaskResultResponse, MinimaxVideoGenerationRequest, MinimaxVideoGenerationResponse, - MinimaxFileRetrieveResponse, - MinimaxTaskResultResponse, SubjectReferenceItem, - MiniMaxModel, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - download_url_to_bytesio, + download_url_to_video_output, + poll_op, + sync_op, upload_images_to_comfyapi, validate_string, ) -from comfy.cmd.server import PromptServer - I2V_AVERAGE_DURATION = 114 T2V_AVERAGE_DURATION = 234 async def _generate_mm_video( - *, - auth: dict[str, str], - node_id: str, - prompt_text: str, - seed: int, - model: str, - image: Optional[torch.Tensor] = None, # used for ImageToVideo - subject: Optional[torch.Tensor] = None, # used for SubjectToVideo - average_duration: Optional[int] = None, + cls: type[IO.ComfyNode], + *, + prompt_text: str, + seed: int, + model: str, + image: Optional[torch.Tensor] = None, # used for ImageToVideo + subject: Optional[torch.Tensor] = None, # used for SubjectToVideo + average_duration: Optional[int] = None, ) -> IO.NodeOutput: if image is None: validate_string(prompt_text, field_name="prompt_text") - # upload image, if passed in image_url = None if image is not None: - image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=auth))[0] + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1))[0] # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model subject_reference = None if subject is not None: - subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=auth))[0] + subject_url = (await upload_images_to_comfyapi(cls, subject, max_images=1))[0] subject_reference = [SubjectReferenceItem(image=subject_url)] - - video_generate_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/video_generation", - method=HttpMethod.POST, - request_model=MinimaxVideoGenerationRequest, - response_model=MinimaxVideoGenerationResponse, - ), - request=MinimaxVideoGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"), + response_model=MinimaxVideoGenerationResponse, + data=MinimaxVideoGenerationRequest( model=MiniMaxModel(model), prompt=prompt_text, callback_url=None, @@ -73,81 +59,50 @@ async def _generate_mm_video( subject_reference=subject_reference, prompt_optimizer=None, ), - auth_kwargs=auth, ) - response = await video_generate_operation.execute() task_id = response.task_id if not task_id: raise Exception(f"MiniMax generation failed: {response.base_resp}") - video_generate_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/minimax/query/video_generation", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxTaskResultResponse, - query_params={"task_id": task_id}, - ), - completed_statuses=["Success"], - failed_statuses=["Fail"], + task_result = await poll_op( + 
cls, + ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}), + response_model=MinimaxTaskResultResponse, status_extractor=lambda x: x.status.value, estimated_duration=average_duration, - node_id=node_id, - auth_kwargs=auth, ) - task_result = await video_generate_operation.execute() file_id = task_result.file_id if file_id is None: raise Exception("Request was not successful. Missing file ID.") - file_retrieve_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/files/retrieve", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxFileRetrieveResponse, - query_params={"file_id": int(file_id)}, - ), - request=EmptyRequest(), - auth_kwargs=auth, + file_result = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}), + response_model=MinimaxFileRetrieveResponse, ) - file_result = await file_retrieve_operation.execute() file_url = file_result.file.download_url if file_url is None: - raise Exception( - f"No video was found in the response. Full response: {file_result.model_dump()}" - ) - logging.info("Generated video URL: %s", file_url) - if node_id: - if hasattr(file_result.file, "backup_download_url"): - message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" - else: - message = f"Result URL: {file_url}" - PromptServer.instance.send_progress_text(message, node_id) - - # Download and return as VideoFromFile - video_io = await download_url_to_bytesio(file_url) - if video_io is None: - error_msg = f"Failed to download video from {file_url}" - logging.error(error_msg) - raise Exception(error_msg) - return IO.NodeOutput(VideoFromFile(video_io)) + raise Exception(f"No video was found in the response. Full response: {file_result.model_dump()}") + if file_result.file.backup_download_url: + try: + return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2)) + except Exception: # if we have a second URL to retrieve the result, try again using that one + return IO.NodeOutput( + await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3) + ) + return IO.NodeOutput(await download_url_to_video_output(file_url)) class MinimaxTextToVideoNode(IO.ComfyNode): - """ - Generates videos synchronously based on a prompt, and optional parameters using MiniMax's API. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxTextToVideoNode", display_name="MiniMax Text to Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on a prompt, and optional parameters.", inputs=[ IO.String.Input( "prompt_text", @@ -183,17 +138,13 @@ class MinimaxTextToVideoNode(IO.ComfyNode): @classmethod async def execute( - cls, - prompt_text: str, - model: str = "T2V-01", - seed: int = 0, + cls, + prompt_text: str, + model: str = "T2V-01", + seed: int = 0, ) -> IO.NodeOutput: return await _generate_mm_video( - auth={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt_text=prompt_text, seed=seed, model=model, @@ -204,17 +155,13 @@ class MinimaxTextToVideoNode(IO.ComfyNode): class MinimaxImageToVideoNode(IO.ComfyNode): - """ - Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. 
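`_generate_mm_video` above (and `MinimaxHailuoVideoNode` later in this file) replace the old manual download-and-log logic with a primary/backup retrieval strategy. A minimal sketch of that strategy, assuming the `download_url_to_video_output` signature used in these hunks; `fetch_video` itself is illustrative and not part of the change:

```python
# Hedged sketch of the primary/backup download strategy used by the MiniMax
# nodes above: try the primary URL with a short timeout, then fall back to the
# backup URL with a larger retry budget if one was provided.
from comfy_api_nodes.util import download_url_to_video_output


async def fetch_video(primary_url: str, backup_url: str | None):
    if backup_url:
        try:
            # Primary URL first, with a short timeout and limited retries.
            return await download_url_to_video_output(primary_url, timeout=10, max_retries=2)
        except Exception:
            # Fall back to the backup URL with a more generous retry budget.
            return await download_url_to_video_output(backup_url, max_retries=3)
    return await download_url_to_video_output(primary_url)
```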
- """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxImageToVideoNode", display_name="MiniMax Image to Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( "image", @@ -254,18 +201,14 @@ class MinimaxImageToVideoNode(IO.ComfyNode): @classmethod async def execute( - cls, - image: torch.Tensor, - prompt_text: str, - model: str = "I2V-01", - seed: int = 0, + cls, + image: torch.Tensor, + prompt_text: str, + model: str = "I2V-01", + seed: int = 0, ) -> IO.NodeOutput: return await _generate_mm_video( - auth={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt_text=prompt_text, seed=seed, model=model, @@ -276,17 +219,13 @@ class MinimaxImageToVideoNode(IO.ComfyNode): class MinimaxSubjectToVideoNode(IO.ComfyNode): - """ - Generates videos synchronously based on an image and prompt, and optional parameters using MiniMax's API. - """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxSubjectToVideoNode", display_name="MiniMax Subject to Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos synchronously based on an image and prompt, and optional parameters.", inputs=[ IO.Image.Input( "subject", @@ -326,18 +265,14 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode): @classmethod async def execute( - cls, - subject: torch.Tensor, - prompt_text: str, - model: str = "S2V-01", - seed: int = 0, + cls, + subject: torch.Tensor, + prompt_text: str, + model: str = "S2V-01", + seed: int = 0, ) -> IO.NodeOutput: return await _generate_mm_video( - auth={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - node_id=cls.hidden.unique_id, + cls, prompt_text=prompt_text, seed=seed, model=model, @@ -348,15 +283,13 @@ class MinimaxSubjectToVideoNode(IO.ComfyNode): class MinimaxHailuoVideoNode(IO.ComfyNode): - """Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.""" - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="MinimaxHailuoVideoNode", display_name="MiniMax Hailuo Video", category="api node/video/MiniMax", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos from prompt, with optional start frame using the new MiniMax Hailuo-02 model.", inputs=[ IO.String.Input( "prompt_text", @@ -411,19 +344,15 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): @classmethod async def execute( - cls, - prompt_text: str, - seed: int = 0, - first_frame_image: Optional[torch.Tensor] = None, # used for ImageToVideo - prompt_optimizer: bool = True, - duration: int = 6, - resolution: str = "768P", - model: str = "MiniMax-Hailuo-02", + cls, + prompt_text: str, + seed: int = 0, + first_frame_image: Optional[torch.Tensor] = None, # used for ImageToVideo + prompt_optimizer: bool = True, + duration: int = 6, + resolution: str = "768P", + model: str = "MiniMax-Hailuo-02", ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } if first_frame_image is None: validate_string(prompt_text, field_name="prompt_text") @@ -435,16 +364,13 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): # upload image, if passed in image_url = None if first_frame_image 
is not None: - image_url = (await upload_images_to_comfyapi(first_frame_image, max_images=1, auth_kwargs=auth))[0] + image_url = (await upload_images_to_comfyapi(cls, first_frame_image, max_images=1))[0] - video_generate_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/video_generation", - method=HttpMethod.POST, - request_model=MinimaxVideoGenerationRequest, - response_model=MinimaxVideoGenerationResponse, - ), - request=MinimaxVideoGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/video_generation", method="POST"), + response_model=MinimaxVideoGenerationResponse, + data=MinimaxVideoGenerationRequest( model=MiniMaxModel(model), prompt=prompt_text, callback_url=None, @@ -453,67 +379,42 @@ class MinimaxHailuoVideoNode(IO.ComfyNode): duration=duration, resolution=resolution, ), - auth_kwargs=auth, ) - response = await video_generate_operation.execute() task_id = response.task_id if not task_id: raise Exception(f"MiniMax generation failed: {response.base_resp}") average_duration = 120 if resolution == "768P" else 240 - video_generate_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/minimax/query/video_generation", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxTaskResultResponse, - query_params={"task_id": task_id}, - ), - completed_statuses=["Success"], - failed_statuses=["Fail"], + task_result = await poll_op( + cls, + ApiEndpoint(path="/proxy/minimax/query/video_generation", query_params={"task_id": task_id}), + response_model=MinimaxTaskResultResponse, status_extractor=lambda x: x.status.value, estimated_duration=average_duration, - node_id=cls.hidden.unique_id, - auth_kwargs=auth, ) - task_result = await video_generate_operation.execute() file_id = task_result.file_id if file_id is None: raise Exception("Request was not successful. Missing file ID.") - file_retrieve_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/minimax/files/retrieve", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MinimaxFileRetrieveResponse, - query_params={"file_id": int(file_id)}, - ), - request=EmptyRequest(), - auth_kwargs=auth, + file_result = await sync_op( + cls, + ApiEndpoint(path="/proxy/minimax/files/retrieve", query_params={"file_id": int(file_id)}), + response_model=MinimaxFileRetrieveResponse, ) - file_result = await file_retrieve_operation.execute() file_url = file_result.file.download_url if file_url is None: - raise Exception( - f"No video was found in the response. Full response: {file_result.model_dump()}" - ) - logging.info("Generated video URL: %s", file_url) - if cls.hidden.unique_id: - if hasattr(file_result.file, "backup_download_url"): - message = f"Result URL: {file_url}\nBackup URL: {file_result.file.backup_download_url}" - else: - message = f"Result URL: {file_url}" - PromptServer.instance.send_progress_text(message, cls.hidden.unique_id) + raise Exception(f"No video was found in the response. 
Full response: {file_result.model_dump()}") - video_io = await download_url_to_bytesio(file_url) - if video_io is None: - error_msg = f"Failed to download video from {file_url}" - logging.error(error_msg) - raise Exception(error_msg) - return IO.NodeOutput(VideoFromFile(video_io)) + if file_result.file.backup_download_url: + try: + return IO.NodeOutput(await download_url_to_video_output(file_url, timeout=10, max_retries=2)) + except Exception: # if we have a second URL to retrieve the result, try again using that one + return IO.NodeOutput( + await download_url_to_video_output(file_result.file.backup_download_url, max_retries=3) + ) + return IO.NodeOutput(await download_url_to_video_output(file_url)) class MinimaxExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 7566188dd..2771e4790 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -1,35 +1,28 @@ import logging -from typing import Any, Callable, Optional, TypeVar -import torch -from typing_extensions import override -from comfy_api_nodes.util.validation_utils import validate_image_dimensions +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input from comfy_api_nodes.apis import ( - MoonvalleyTextToVideoRequest, + MoonvalleyPromptResponse, MoonvalleyTextToVideoInferenceParams, + MoonvalleyTextToVideoRequest, MoonvalleyVideoToVideoInferenceParams, MoonvalleyVideoToVideoRequest, - MoonvalleyPromptResponse, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( download_url_to_video_output, + poll_op, + sync_op, + trim_video, upload_images_to_comfyapi, upload_video_to_comfyapi, validate_container_format_is_mp4, + validate_image_dimensions, + validate_string, ) -from comfy_api.input import VideoInput -from comfy_api.latest import ComfyExtension, InputImpl, IO -import av -import io - API_UPLOADS_ENDPOINT = "/proxy/moonvalley/uploads" API_PROMPTS_ENDPOINT = "/proxy/moonvalley/prompts" API_VIDEO2VIDEO_ENDPOINT = "/proxy/moonvalley/prompts/video-to-video" @@ -51,13 +44,6 @@ MAX_VID_HEIGHT = 10000 MAX_VIDEO_SIZE = 1024 * 1024 * 1024 # 1 GB max for in-memory video processing MOONVALLEY_MAREY_MAX_PROMPT_LENGTH = 5000 -R = TypeVar("R") - - -class MoonvalleyApiError(Exception): - """Base exception for Moonvalley API errors.""" - - pass def is_valid_task_creation_response(response: MoonvalleyPromptResponse) -> bool: @@ -69,67 +55,10 @@ def validate_task_creation_response(response) -> None: if not is_valid_task_creation_response(response): error_msg = f"Moonvalley Marey API: Initial request failed. Code: {response.code}, Message: {response.message}, Data: {response}" logging.error(error_msg) - raise MoonvalleyApiError(error_msg) + raise RuntimeError(error_msg) -def get_video_from_response(response): - video = response.output_url - logging.info( - "Moonvalley Marey API: Task %s succeeded. Video URL: %s", response.id, video - ) - return video - - -def get_video_url_from_response(response) -> Optional[str]: - """Returns the first video url from the Moonvalley video generation task result. - Will not raise an error if the response is not valid. 
- """ - if response: - return str(get_video_from_response(response)) - else: - return None - - -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - node_id: Optional[str] = None, -) -> R: - """Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - "completed", - ], - max_poll_attempts=240, # 64 minutes with 16s interval - poll_interval=16.0, - failed_statuses=["error"], - status_extractor=lambda response: ( - response.status if response and response.status else None - ), - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - node_id=node_id, - ).execute() - - -def validate_prompts( - prompt: str, negative_prompt: str, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH -): - """Verifies that the prompt isn't empty and that neither prompt is too long.""" - if not prompt: - raise ValueError("Positive prompt is empty") - if len(prompt) > max_length: - raise ValueError(f"Positive prompt is too long: {len(prompt)} characters") - if negative_prompt and len(negative_prompt) > max_length: - raise ValueError( - f"Negative prompt is too long: {len(negative_prompt)} characters" - ) - return True - - -def validate_video_to_video_input(video: VideoInput) -> VideoInput: +def validate_video_to_video_input(video: Input.Video) -> Input.Video: """ Validates and processes video input for Moonvalley Video-to-Video generation. @@ -150,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput: return _validate_and_trim_duration(video) -def _get_video_dimensions(video: VideoInput) -> tuple[int, int]: +def _get_video_dimensions(video: Input.Video) -> tuple[int, int]: """Extracts video dimensions with error handling.""" try: return video.get_dimensions() @@ -170,15 +99,11 @@ def _validate_video_dimensions(width: int, height: int) -> None: } if (width, height) not in supported_resolutions: - supported_list = ", ".join( - [f"{w}x{h}" for w, h in sorted(supported_resolutions)] - ) - raise ValueError( - f"Resolution {width}x{height} not supported. Supported: {supported_list}" - ) + supported_list = ", ".join([f"{w}x{h}" for w, h in sorted(supported_resolutions)]) + raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}") -def _validate_and_trim_duration(video: VideoInput) -> VideoInput: +def _validate_and_trim_duration(video: Input.Video) -> Input.Video: """Validates video duration and trims to 5 seconds if needed.""" duration = video.get_duration() _validate_minimum_duration(duration) @@ -188,133 +113,16 @@ def _validate_and_trim_duration(video: VideoInput) -> VideoInput: def _validate_minimum_duration(duration: float) -> None: """Ensures video is at least 5 seconds long.""" if duration < 5: - raise MoonvalleyApiError("Input video must be at least 5 seconds long.") + raise ValueError("Input video must be at least 5 seconds long.") -def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput: +def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video: """Trims video to 5 seconds if longer.""" if duration > 5: return trim_video(video, 5) return video -def trim_video(video: VideoInput, duration_sec: float) -> VideoInput: - """ - Returns a new VideoInput object trimmed from the beginning to the specified duration, - using av to avoid loading entire video into memory. 
- - Args: - video: Input video to trim - duration_sec: Duration in seconds to keep from the beginning - - Returns: - VideoFromFile object that owns the output buffer - """ - output_buffer = io.BytesIO() - - input_container = None - output_container = None - - try: - # Get the stream source - this avoids loading entire video into memory - # when the source is already a file path - input_source = video.get_stream_source() - - # Open containers - input_container = av.open(input_source, mode="r") - output_container = av.open(output_buffer, mode="w", format="mp4") - - # Set up output streams for re-encoding - video_stream = None - audio_stream = None - - for stream in input_container.streams: - logging.info("Found stream: type=%s, class=%s", stream.type, type(stream)) - if isinstance(stream, av.VideoStream): - # Create output video stream with same parameters - video_stream = output_container.add_stream( - "h264", rate=stream.average_rate - ) - video_stream.width = stream.width - video_stream.height = stream.height - video_stream.pix_fmt = "yuv420p" - logging.info( - "Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate - ) - elif isinstance(stream, av.AudioStream): - # Create output audio stream with same parameters - audio_stream = output_container.add_stream( - "aac", rate=stream.sample_rate - ) - audio_stream.sample_rate = stream.sample_rate - audio_stream.layout = stream.layout - logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels) - - # Calculate target frame count that's divisible by 16 - fps = input_container.streams.video[0].average_rate - estimated_frames = int(duration_sec * fps) - target_frames = ( - estimated_frames // 16 - ) * 16 # Round down to nearest multiple of 16 - - if target_frames == 0: - raise ValueError("Video too short: need at least 16 frames for Moonvalley") - - frame_count = 0 - audio_frame_count = 0 - - # Decode and re-encode video frames - if video_stream: - for frame in input_container.decode(video=0): - if frame_count >= target_frames: - break - - # Re-encode frame - for packet in video_stream.encode(frame): - output_container.mux(packet) - frame_count += 1 - - # Flush encoder - for packet in video_stream.encode(): - output_container.mux(packet) - - logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames) - - # Decode and re-encode audio frames - if audio_stream: - input_container.seek(0) # Reset to beginning for audio - for frame in input_container.decode(audio=0): - if frame.time >= duration_sec: - break - - # Re-encode frame - for packet in audio_stream.encode(frame): - output_container.mux(packet) - audio_frame_count += 1 - - # Flush encoder - for packet in audio_stream.encode(): - output_container.mux(packet) - - logging.info("Encoded %s audio frames", audio_frame_count) - - # Close containers - output_container.close() - input_container.close() - - # Return as VideoFromFile using the buffer - output_buffer.seek(0) - return InputImpl.VideoFromFile(output_buffer) - - except Exception as e: - # Clean up on error - if input_container is not None: - input_container.close() - if output_container is not None: - output_container.close() - raise RuntimeError(f"Failed to trim video: {str(e)}") from e - - def parse_width_height_from_res(resolution: str): # Accepts a string like "16:9 (1920 x 1080)" and returns width, height as a dict res_map = { @@ -338,19 +146,14 @@ def parse_control_parameter(value): return control_map.get(value, control_map["Motion Transfer"]) -async def 
get_response( - task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None -) -> MoonvalleyPromptResponse: - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{API_PROMPTS_ENDPOINT}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=MoonvalleyPromptResponse, - ), - result_url_extractor=get_video_url_from_response, - node_id=node_id, +async def get_response(cls: type[IO.ComfyNode], task_id: str) -> MoonvalleyPromptResponse: + return await poll_op( + cls, + ApiEndpoint(path=f"{API_PROMPTS_ENDPOINT}/{task_id}"), + response_model=MoonvalleyPromptResponse, + status_extractor=lambda r: (r.status if r and r.status else None), + poll_interval=16.0, + max_poll_attempts=240, ) @@ -435,7 +238,7 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): @classmethod async def execute( cls, - image: torch.Tensor, + image: Input.Image, prompt: str, negative_prompt: str, resolution: str, @@ -444,14 +247,10 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): steps: int, ) -> IO.NodeOutput: validate_image_dimensions(image, min_width=300, min_height=300, max_height=MAX_HEIGHT, max_width=MAX_WIDTH) - validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = parse_width_height_from_res(resolution) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, steps=steps, @@ -464,33 +263,17 @@ class MoonvalleyImg2VideoNode(IO.ComfyNode): # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" - - image_url = ( - await upload_images_to_comfyapi( - image, max_images=1, auth_kwargs=auth, mime_type=mime_type - ) - )[0] - - request = MoonvalleyTextToVideoRequest( - image_url=image_url, prompt_text=prompt, inference_params=inference_params - ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_IMG2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyTextToVideoRequest, - response_model=MoonvalleyPromptResponse, + image_url = (await upload_images_to_comfyapi(cls, image, max_images=1, mime_type=mime_type))[0] + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_IMG2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyTextToVideoRequest( + image_url=image_url, prompt_text=prompt, inference_params=inference_params ), - request=request, - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) + final_response = await get_response(cls, task_creation_response.id) video = await download_url_to_video_output(final_response.output_url) return IO.NodeOutput(video) @@ -576,21 +359,16 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): prompt: str, negative_prompt: str, seed: int, - video: Optional[VideoInput] = None, + video: Input.Video | None = None, control_type: str = "Motion Transfer", - motion_intensity: Optional[int] = 100, + motion_intensity: int | None = 100, steps=33, prompt_adherence=4.5, ) -> IO.NodeOutput: - auth = { - "auth_token": 
cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - validated_video = validate_video_to_video_input(video) - video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=auth) - - validate_prompts(prompt, negative_prompt) + video_url = await upload_video_to_comfyapi(cls, validated_video) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) # Only include motion_intensity for Motion Transfer control_params = {} @@ -605,35 +383,20 @@ class MoonvalleyVideo2VideoNode(IO.ComfyNode): guidance_scale=prompt_adherence, ) - control = parse_control_parameter(control_type) - - request = MoonvalleyVideoToVideoRequest( - control_type=control, - video_url=video_url, - prompt_text=prompt, - inference_params=inference_params, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_VIDEO2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyVideoToVideoRequest, - response_model=MoonvalleyPromptResponse, + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_VIDEO2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyVideoToVideoRequest( + control_type=parse_control_parameter(control_type), + video_url=video_url, + prompt_text=prompt, + inference_params=inference_params, ), - request=request, - auth_kwargs=auth, ) - task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) - - video = await download_url_to_video_output(final_response.output_url) - return IO.NodeOutput(video) + final_response = await get_response(cls, task_creation_response.id) + return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) class MoonvalleyTxt2VideoNode(IO.ComfyNode): @@ -720,14 +483,10 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode): seed: int, steps: int, ) -> IO.NodeOutput: - validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(prompt, min_length=1, max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) + validate_string(negative_prompt, field_name="negative_prompt", max_length=MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) width_height = parse_width_height_from_res(resolution) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - inference_params = MoonvalleyTextToVideoInferenceParams( negative_prompt=negative_prompt, steps=steps, @@ -737,30 +496,16 @@ class MoonvalleyTxt2VideoNode(IO.ComfyNode): width=width_height["width"], height=width_height["height"], ) - request = MoonvalleyTextToVideoRequest( - prompt_text=prompt, inference_params=inference_params - ) - init_op = SynchronousOperation( - endpoint=ApiEndpoint( - path=API_TXT2VIDEO_ENDPOINT, - method=HttpMethod.POST, - request_model=MoonvalleyTextToVideoRequest, - response_model=MoonvalleyPromptResponse, - ), - request=request, - auth_kwargs=auth, + task_creation_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=API_TXT2VIDEO_ENDPOINT, method="POST"), + response_model=MoonvalleyPromptResponse, + data=MoonvalleyTextToVideoRequest(prompt_text=prompt, inference_params=inference_params), ) - task_creation_response = await init_op.execute() 
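After this refactor all three Moonvalley nodes follow the same submit, validate, poll, download sequence, with `get_response` polling every 16 seconds for up to 240 attempts (roughly 64 minutes, matching the comment removed above). A hedged sketch of that shared flow; `run_moonvalley_task` is illustrative and simply reuses the module-level helpers shown in these hunks:

```python
# Hedged sketch of the flow shared by the Moonvalley Img2Video, Video2Video and
# Txt2Video nodes above. validate_task_creation_response and get_response are
# the module-level helpers from nodes_moonvalley.py shown earlier in this diff.
from comfy_api_nodes.util import ApiEndpoint, download_url_to_video_output, sync_op


async def run_moonvalley_task(cls, endpoint_path: str, request_data, response_model):
    task_creation_response = await sync_op(
        cls,
        endpoint=ApiEndpoint(path=endpoint_path, method="POST"),
        response_model=response_model,
        data=request_data,
    )
    validate_task_creation_response(task_creation_response)   # raises RuntimeError on a bad status
    final_response = await get_response(cls, task_creation_response.id)  # 16 s poll, up to 240 attempts
    return await download_url_to_video_output(final_response.output_url)
```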
validate_task_creation_response(task_creation_response) - task_id = task_creation_response.id - - final_response = await get_response( - task_id, auth_kwargs=auth, node_id=cls.hidden.unique_id - ) - - video = await download_url_to_video_output(final_response.output_url) - return IO.NodeOutput(video) + final_response = await get_response(cls, task_creation_response.id) + return IO.NodeOutput(await download_url_to_video_output(final_response.output_url)) class MoonvalleyExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index e3a6d3c85..8c58ce4ea 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -1,19 +1,14 @@ -import io -from typing import TypedDict, Optional -import json +from io import BytesIO import os -import time -import re -import uuid from enum import Enum from inspect import cleandoc import numpy as np import torch from PIL import Image -from comfy.comfy_types.node_typing import IO, ComfyNodeABC, InputTypeDict -from comfy.cmd.server import PromptServer from comfy.cmd import folder_paths - +import base64 +from comfy_api.latest import IO, ComfyExtension +from typing_extensions import override from comfy_api_nodes.apis import ( OpenAIImageGenerationRequest, @@ -23,7 +18,6 @@ from comfy_api_nodes.apis import ( OpenAIResponse, CreateModelResponseProperties, Item, - Includable, OutputContent, InputImageContent, Detail, @@ -34,43 +28,21 @@ from comfy_api_nodes.apis import ( InputFileContent, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) - -from comfy_api_nodes.apinode_utils import ( +from comfy_api_nodes.util import ( downscale_image_tensor, - validate_and_cast_response, + download_url_to_bytesio, validate_string, tensor_to_base64_string, + ApiEndpoint, + sync_op, + poll_op, text_filepath_to_data_uri, ) -from comfy_api_nodes.mapper_utils import model_field_to_node_input - RESPONSES_ENDPOINT = "/proxy/openai/v1/responses" STARTING_POINT_ID_PATTERN = r"" -class HistoryEntry(TypedDict): - """Type definition for a single history entry in the chat.""" - - prompt: str - response: str - response_id: str - timestamp: float - - -class ChatHistory(TypedDict): - """Type definition for the chat history dictionary.""" - - __annotations__: dict[str, list[HistoryEntry]] - - class SupportedOpenAIModel(str, Enum): o4_mini = "o4-mini" o1 = "o1" @@ -85,98 +57,123 @@ class SupportedOpenAIModel(str, Enum): gpt_5_nano = "gpt-5-nano" -class OpenAIDalle2(ComfyNodeABC): +async def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor: + """Validates and casts a response to a torch.Tensor. + + Args: + response: The response to validate and cast. + timeout: Request timeout in seconds. Defaults to None (no timeout). + + Returns: + A torch.Tensor representing the image (1, H, W, C). + + Raises: + ValueError: If the response is not valid. 
+ """ + # validate raw JSON response + data = response.data + if not data or len(data) == 0: + raise ValueError("No images returned from API endpoint") + + # Initialize list to store image tensors + image_tensors: list[torch.Tensor] = [] + + # Process each image in the data array + for img_data in data: + if img_data.b64_json: + img_io = BytesIO(base64.b64decode(img_data.b64_json)) + elif img_data.url: + img_io = BytesIO() + await download_url_to_bytesio(img_data.url, img_io, timeout=timeout) + else: + raise ValueError("Invalid image payload – neither URL nor base64 data present.") + + pil_img = Image.open(img_io).convert("RGBA") + arr = np.asarray(pil_img).astype(np.float32) / 255.0 + image_tensors.append(torch.from_numpy(arr)) + + return torch.stack(image_tensors, dim=0) + + +class OpenAIDalle2(IO.ComfyNode): """ Generates images synchronously via OpenAI's DALL·E 2 endpoint. """ - def __init__(self): - pass + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIDalle2", + display_name="OpenAI DALL·E 2", + category="api node/image/OpenAI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text prompt for DALL·E", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2 ** 31 - 1, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="not implemented yet in backend", + optional=True, + ), + IO.Combo.Input( + "size", + default="1024x1024", + options=["256x256", "512x512", "1024x1024"], + tooltip="Image size", + optional=True, + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=8, + step=1, + tooltip="How many images to generate", + display_mode=IO.NumberDisplay.number, + optional=True, + ), + IO.Image.Input( + "image", + tooltip="Optional reference image for image editing.", + optional=True, + ), + IO.Mask.Input( + "mask", + tooltip="Optional mask for inpainting (white areas will be replaced)", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for DALL·E", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2**31 - 1, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "not implemented yet in backend", - }, - ), - "size": ( - IO.COMBO, - { - "options": ["256x256", "512x512", "1024x1024"], - "default": "1024x1024", - "tooltip": "Image size", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "How many images to generate", - }, - ), - "image": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional reference image for image editing.", - }, - ), - "mask": ( - IO.MASK, - { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/OpenAI" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, - prompt, - seed=0, - image=None, - mask=None, - n=1, - size="1024x1024", - unique_id=None, - **kwargs, - 
): + async def execute( + cls, + prompt, + seed=0, + image=None, + mask=None, + n=1, + size="1024x1024", + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) model = "dall-e-2" path = "/proxy/openai/images/generations" @@ -202,7 +199,7 @@ class OpenAIDalle2(ComfyNodeABC): image_np = (rgba_tensor.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() + img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) img_binary = img_byte_arr # .getvalue() @@ -210,15 +207,11 @@ class OpenAIDalle2(ComfyNodeABC): elif image is not None or mask is not None: raise Exception("Dall-E 2 image editing requires an image AND a mask") - # Build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=request_class, - response_model=OpenAIImageGenerationResponse, - ), - request=request_class( + response = await sync_op( + cls, + ApiEndpoint(path=path, method="POST"), + response_model=OpenAIImageGenerationResponse, + data=request_class( model=model, prompt=prompt, n=n, @@ -227,115 +220,98 @@ class OpenAIDalle2(ComfyNodeABC): ), files=( { - "image": img_binary, + "image": ("image.png", img_binary, "image/png"), } if img_binary else None ), content_type=content_type, - auth_kwargs=kwargs, ) - response = await operation.execute() - - img_tensor = await validate_and_cast_response(response, node_id=unique_id) - return (img_tensor,) + return IO.NodeOutput(await validate_and_cast_response(response)) -class OpenAIDalle3(ComfyNodeABC): +class OpenAIDalle3(IO.ComfyNode): """ Generates images synchronously via OpenAI's DALL·E 3 endpoint. """ - def __init__(self): - pass + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIDalle3", + display_name="OpenAI DALL·E 3", + category="api node/image/OpenAI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text prompt for DALL·E", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2 ** 31 - 1, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="not implemented yet in backend", + optional=True, + ), + IO.Combo.Input( + "quality", + default="standard", + options=["standard", "hd"], + tooltip="Image quality", + optional=True, + ), + IO.Combo.Input( + "style", + default="natural", + options=["natural", "vivid"], + tooltip="Vivid causes the model to lean towards generating hyper-real and dramatic images. 
Natural causes the model to produce more natural, less hyper-real looking images.", + optional=True, + ), + IO.Combo.Input( + "size", + default="1024x1024", + options=["1024x1024", "1024x1792", "1792x1024"], + tooltip="Image size", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for DALL·E", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2**31 - 1, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "not implemented yet in backend", - }, - ), - "quality": ( - IO.COMBO, - { - "options": ["standard", "hd"], - "default": "standard", - "tooltip": "Image quality", - }, - ), - "style": ( - IO.COMBO, - { - "options": ["natural", "vivid"], - "default": "natural", - "tooltip": "Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images.", - }, - ), - "size": ( - IO.COMBO, - { - "options": ["1024x1024", "1024x1792", "1792x1024"], - "default": "1024x1024", - "tooltip": "Image size", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/OpenAI" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, - prompt, - seed=0, - style="natural", - quality="standard", - size="1024x1024", - unique_id=None, - **kwargs, - ): + async def execute( + cls, + prompt, + seed=0, + style="natural", + quality="standard", + size="1024x1024", + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) model = "dall-e-3" # build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/openai/images/generations", - method=HttpMethod.POST, - request_model=OpenAIImageGenerationRequest, - response_model=OpenAIImageGenerationResponse, - ), - request=OpenAIImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/openai/images/generations", method="POST"), + response_model=OpenAIImageGenerationResponse, + data=OpenAIImageGenerationRequest( model=model, prompt=prompt, quality=quality, @@ -343,125 +319,106 @@ class OpenAIDalle3(ComfyNodeABC): style=style, seed=seed, ), - auth_kwargs=kwargs, ) - response = await operation.execute() - - img_tensor = await validate_and_cast_response(response, node_id=unique_id) - return (img_tensor,) + return IO.NodeOutput(await validate_and_cast_response(response)) -class OpenAIGPTImage1(ComfyNodeABC): +class OpenAIGPTImage1(IO.ComfyNode): """ Generates images synchronously via OpenAI's GPT Image 1 endpoint. 
""" - def __init__(self): - pass + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIGPTImage1", + display_name="OpenAI GPT Image 1", + category="api node/image/OpenAI", + description=cleandoc(cls.__doc__ or ""), + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text prompt for GPT Image 1", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2 ** 31 - 1, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="not implemented yet in backend", + optional=True, + ), + IO.Combo.Input( + "quality", + default="low", + options=["low", "medium", "high"], + tooltip="Image quality, affects cost and generation time.", + optional=True, + ), + IO.Combo.Input( + "background", + default="opaque", + options=["opaque", "transparent"], + tooltip="Return image with or without background", + optional=True, + ), + IO.Combo.Input( + "size", + default="auto", + options=["auto", "1024x1024", "1024x1536", "1536x1024"], + tooltip="Image size", + optional=True, + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=8, + step=1, + tooltip="How many images to generate", + display_mode=IO.NumberDisplay.number, + optional=True, + ), + IO.Image.Input( + "image", + tooltip="Optional reference image for image editing.", + optional=True, + ), + IO.Mask.Input( + "mask", + tooltip="Optional mask for inpainting (white areas will be replaced)", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text prompt for GPT Image 1", - }, - ), - }, - "optional": { - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 2**31 - 1, - "step": 1, - "display": "number", - "control_after_generate": True, - "tooltip": "not implemented yet in backend", - }, - ), - "quality": ( - IO.COMBO, - { - "options": ["low", "medium", "high"], - "default": "low", - "tooltip": "Image quality, affects cost and generation time.", - }, - ), - "background": ( - IO.COMBO, - { - "options": ["opaque", "transparent"], - "default": "opaque", - "tooltip": "Return image with or without background", - }, - ), - "size": ( - IO.COMBO, - { - "options": ["auto", "1024x1024", "1024x1536", "1536x1024"], - "default": "auto", - "tooltip": "Image size", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 8, - "step": 1, - "display": "number", - "tooltip": "How many images to generate", - }, - ), - "image": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional reference image for image editing.", - }, - ), - "mask": ( - IO.MASK, - { - "default": None, - "tooltip": "Optional mask for inpainting (white areas will be replaced)", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = (IO.IMAGE,) - FUNCTION = "api_call" - CATEGORY = "api node/image/OpenAI" - DESCRIPTION = cleandoc(__doc__ or "") - API_NODE = True - - async def api_call( - self, - prompt, - seed=0, - quality="low", - background="opaque", - image=None, - mask=None, - n=1, - size="1024x1024", - unique_id=None, - **kwargs, - ): + async def execute( + cls, + prompt, + seed=0, + quality="low", + background="opaque", + image=None, + mask=None, + n=1, + size="1024x1024", + ) -> 
IO.NodeOutput: validate_string(prompt, strip_whitespace=False) model = "gpt-image-1" path = "/proxy/openai/images/generations" @@ -477,12 +434,12 @@ class OpenAIGPTImage1(ComfyNodeABC): batch_size = image.shape[0] for i in range(batch_size): - single_image = image[i : i + 1] + single_image = image[i: i + 1] scaled_image = downscale_image_tensor(single_image).squeeze() image_np = (scaled_image.numpy() * 255).astype(np.uint8) img = Image.fromarray(image_np) - img_byte_arr = io.BytesIO() + img_byte_arr = BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) @@ -506,20 +463,17 @@ class OpenAIGPTImage1(ComfyNodeABC): mask_np = (scaled_mask.numpy() * 255).astype(np.uint8) mask_img = Image.fromarray(mask_np) - mask_img_byte_arr = io.BytesIO() + mask_img_byte_arr = BytesIO() mask_img.save(mask_img_byte_arr, format="PNG") mask_img_byte_arr.seek(0) files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) # Build the operation - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=request_class, - response_model=OpenAIImageGenerationResponse, - ), - request=request_class( + response = await sync_op( + cls, + ApiEndpoint(path=path, method="POST"), + response_model=OpenAIImageGenerationResponse, + data=request_class( model=model, prompt=prompt, quality=quality, @@ -530,127 +484,70 @@ class OpenAIGPTImage1(ComfyNodeABC): ), files=files if files else None, content_type=content_type, - auth_kwargs=kwargs, ) - response = await operation.execute() - - img_tensor = await validate_and_cast_response(response, node_id=unique_id) - return (img_tensor,) + return IO.NodeOutput(await validate_and_cast_response(response)) -class OpenAITextNode(ComfyNodeABC): - """ - Base class for OpenAI text generation nodes. - """ - - RETURN_TYPES = (IO.STRING,) - FUNCTION = "api_call" - CATEGORY = "api node/text/OpenAI" - API_NODE = True - - -class OpenAIChatNode(OpenAITextNode): +class OpenAIChatNode(IO.ComfyNode): """ Node to generate text responses from an OpenAI model. """ - def __init__(self) -> None: - """Initialize the chat node with a new session ID and empty history.""" - self.current_session_id: str = str(uuid.uuid4()) - self.history: dict[str, list[HistoryEntry]] = {} - self.previous_response_id: Optional[str] = None + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="OpenAIChatNode", + display_name="OpenAI ChatGPT", + category="api node/text/OpenAI", + description="Generate text responses from an OpenAI model.", + inputs=[ + IO.String.Input( + "prompt", + default="", + multiline=True, + tooltip="Text inputs to the model, used to generate a response.", + ), + IO.Boolean.Input( + "persist_context", + default=False, + tooltip="This parameter is deprecated and has no effect.", + ), + IO.Combo.Input( + "model", + options=SupportedOpenAIModel, + tooltip="The model used to generate the response", + ), + IO.Image.Input( + "images", + tooltip="Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", + optional=True, + ), + IO.Custom("OPENAI_INPUT_FILES").Input( + "files", + optional=True, + tooltip="Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.", + ), + IO.Custom("OPENAI_CHAT_CONFIG").Input( + "advanced_options", + optional=True, + tooltip="Optional configuration for the model. 
Accepts inputs from the OpenAI Chat Advanced Options node.", + ), + ], + outputs=[ + IO.String.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Text inputs to the model, used to generate a response.", - }, - ), - "persist_context": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Persist chat context between calls (multi-turn conversation)", - }, - ), - "model": model_field_to_node_input( - IO.COMBO, - OpenAICreateResponse, - "model", - enum_type=SupportedOpenAIModel, - ), - }, - "optional": { - "images": ( - IO.IMAGE, - { - "default": None, - "tooltip": "Optional image(s) to use as context for the model. To include multiple images, you can use the Batch Images node.", - }, - ), - "files": ( - "OPENAI_INPUT_FILES", - { - "default": None, - "tooltip": "Optional file(s) to use as context for the model. Accepts inputs from the OpenAI Chat Input Files node.", - }, - ), - "advanced_options": ( - "OPENAI_CHAT_CONFIG", - { - "default": None, - "tooltip": "Optional configuration for the model. Accepts inputs from the OpenAI Chat Advanced Options node.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - DESCRIPTION = "Generate text responses from an OpenAI model." - - async def get_result_response( - self, - response_id: str, - include: Optional[list[Includable]] = None, - auth_kwargs: Optional[dict[str, str]] = None, - ) -> OpenAIResponse: - """ - Retrieve a model response with the given ID from the OpenAI API. - - Args: - response_id (str): The ID of the response to retrieve. - include (Optional[List[Includable]]): Additional fields to include - in the response. See the `include` parameter for Response - creation above for more information. 
- - """ - return await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{RESPONSES_ENDPOINT}/{response_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=OpenAIResponse, - query_params={"include": include}, - ), - completed_statuses=["completed"], - failed_statuses=["failed"], - status_extractor=lambda response: response.status, - auth_kwargs=auth_kwargs, - ).execute() - def get_message_content_from_response( - self, response: OpenAIResponse + cls, response: OpenAIResponse ) -> list[OutputContent]: """Extract message content from the API response.""" for output in response.output: @@ -658,8 +555,9 @@ class OpenAIChatNode(OpenAITextNode): return output.root.content raise TypeError("No output message found in response") + @classmethod def get_text_from_message_content( - self, message_content: list[OutputContent] + cls, message_content: list[OutputContent] ) -> str: """Extract text content from message content.""" for content_item in message_content: @@ -667,58 +565,9 @@ class OpenAIChatNode(OpenAITextNode): return str(content_item.root.text) return "No text output found in response" - def get_history_text(self, session_id: str) -> str: - """Convert the entire history for a given session to JSON string.""" - return json.dumps(self.history[session_id]) - - def display_history_on_node(self, session_id: str, node_id: str) -> None: - """Display formatted chat history on the node UI.""" - render_spec = { - "node_id": node_id, - "component": "ChatHistoryWidget", - "props": { - "history": self.get_history_text(session_id), - }, - } - PromptServer.instance.send_sync( - "display_component", - render_spec, - ) - - def add_to_history( - self, session_id: str, prompt: str, output_text: str, response_id: str - ) -> None: - """Add a new entry to the chat history.""" - if session_id not in self.history: - self.history[session_id] = [] - self.history[session_id].append( - { - "prompt": prompt, - "response": output_text, - "response_id": response_id, - "timestamp": time.time(), - } - ) - - def parse_output_text_from_response(self, response: OpenAIResponse) -> str: - """Extract text output from the API response.""" - message_contents = self.get_message_content_from_response(response) - return self.get_text_from_message_content(message_contents) - - def generate_new_session_id(self) -> str: - """Generate a new unique session ID.""" - return str(uuid.uuid4()) - - def get_session_id(self, persist_context: bool) -> str: - """Get the current or generate a new session ID based on context persistence.""" - return ( - self.current_session_id - if persist_context - else self.generate_new_session_id() - ) - + @classmethod def tensor_to_input_image_content( - self, image: torch.Tensor, detail_level: Detail = "auto" + cls, image: torch.Tensor, detail_level: Detail = "auto" ) -> InputImageContent: """Convert a tensor to an input image content object.""" return InputImageContent( @@ -727,21 +576,27 @@ class OpenAIChatNode(OpenAITextNode): type="input_image", ) + @classmethod def create_input_message_contents( - self, - prompt: str, - image: Optional[torch.Tensor] = None, - files: Optional[list[InputFileContent]] = None, + cls, + prompt: str, + image: torch.Tensor | None = None, + files: list[InputFileContent] | None = None, ) -> InputMessageContentList: """Create a list of input message contents from prompt and optional image.""" - content_list: list[InputContent] = [ + content_list: list[InputContent | InputTextContent | InputImageContent | InputFileContent] = [ 
InputTextContent(text=prompt, type="input_text"), ] if image is not None: for i in range(image.shape[0]): content_list.append( - self.tensor_to_input_image_content(image[i].unsqueeze(0)) + InputImageContent( + detail="auto", + image_url=f"data:image/png;base64,{tensor_to_base64_string(image[i].unsqueeze(0))}", + type="input_image", + ) ) + if files is not None: content_list.extend(files) @@ -749,80 +604,28 @@ class OpenAIChatNode(OpenAITextNode): root=content_list, ) - def parse_response_id_from_prompt(self, prompt: str) -> Optional[str]: - """Extract response ID from prompt if it exists.""" - parsed_id = re.search(STARTING_POINT_ID_PATTERN, prompt) - return parsed_id.group(1) if parsed_id else None - - def strip_response_tag_from_prompt(self, prompt: str) -> str: - """Remove the response ID tag from the prompt.""" - return re.sub(STARTING_POINT_ID_PATTERN, "", prompt.strip()) - - def delete_history_after_response_id( - self, new_start_id: str, session_id: str - ) -> None: - """Delete history entries after a specific response ID.""" - if session_id not in self.history: - return - - new_history = [] - i = 0 - while ( - i < len(self.history[session_id]) - and self.history[session_id][i]["response_id"] != new_start_id - ): - new_history.append(self.history[session_id][i]) - i += 1 - - # Since it's the new starting point (not the response being edited), we include it as well - if i < len(self.history[session_id]): - new_history.append(self.history[session_id][i]) - - self.history[session_id] = new_history - - async def api_call( - self, - prompt: str, - persist_context: bool, - model: SupportedOpenAIModel, - unique_id: Optional[str] = None, - images: Optional[torch.Tensor] = None, - files: Optional[list[InputFileContent]] = None, - advanced_options: Optional[CreateModelResponseProperties] = None, - **kwargs, - ) -> tuple[str]: - # Validate inputs + @classmethod + async def execute( + cls, + prompt: str, + persist_context: bool = False, + model: SupportedOpenAIModel = SupportedOpenAIModel.gpt_5.value, + images: torch.Tensor | None = None, + files: list[InputFileContent] | None = None, + advanced_options: CreateModelResponseProperties | None = None, + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - session_id = self.get_session_id(persist_context) - response_id_override = self.parse_response_id_from_prompt(prompt) - if response_id_override: - is_starting_from_beginning = response_id_override == "start" - if is_starting_from_beginning: - self.history[session_id] = [] - previous_response_id = None - else: - previous_response_id = response_id_override - self.delete_history_after_response_id(response_id_override, session_id) - prompt = self.strip_response_tag_from_prompt(prompt) - elif persist_context: - previous_response_id = self.previous_response_id - else: - previous_response_id = None - # Create response - create_response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=RESPONSES_ENDPOINT, - method=HttpMethod.POST, - request_model=OpenAICreateResponse, - response_model=OpenAIResponse, - ), - request=OpenAICreateResponse( + create_response = await sync_op( + cls, + ApiEndpoint(path=RESPONSES_ENDPOINT, method="POST"), + response_model=OpenAIResponse, + data=OpenAICreateResponse( input=[ Item( root=InputMessage( - content=self.create_input_message_contents( + content=cls.create_input_message_contents( prompt, images, files ), role="user", @@ -832,36 +635,34 @@ class OpenAIChatNode(OpenAITextNode): store=True, stream=False, model=model, - 
previous_response_id=previous_response_id, + previous_response_id=None, **( advanced_options.model_dump(exclude_none=True) if advanced_options else {} ), ), - auth_kwargs=kwargs, - ).execute() + ) response_id = create_response.id # Get result output - result_response = await self.get_result_response(response_id, auth_kwargs=kwargs) - output_text = self.parse_output_text_from_response(result_response) - - # Update history - self.add_to_history(session_id, prompt, output_text, response_id) - self.display_history_on_node(session_id, unique_id) - self.previous_response_id = response_id - - return (output_text,) + result_response = await poll_op( + cls, + ApiEndpoint(path=f"{RESPONSES_ENDPOINT}/{response_id}"), + response_model=OpenAIResponse, + status_extractor=lambda response: response.status, + completed_statuses=["incomplete", "completed"] + ) + return IO.NodeOutput(cls.get_text_from_message_content(cls.get_message_content_from_response(result_response))) -class OpenAIInputFiles(ComfyNodeABC): +class OpenAIInputFiles(IO.ComfyNode): """ Loads and formats input files for OpenAI API. """ @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: + def define_schema(cls): """ For details about the supported file input types, see: https://platform.openai.com/docs/guides/pdf-files?api-mode=responses @@ -871,102 +672,97 @@ class OpenAIInputFiles(ComfyNodeABC): f for f in os.scandir(input_dir) if f.is_file() - and (f.name.endswith(".txt") or f.name.endswith(".pdf")) - and f.stat().st_size < 32 * 1024 * 1024 + and (f.name.endswith(".txt") or f.name.endswith(".pdf")) + and f.stat().st_size < 32 * 1024 * 1024 ] input_files = sorted(input_files, key=lambda x: x.name) input_files = [f.name for f in input_files] - return { - "required": { - "file": ( - IO.COMBO, - { - "tooltip": "Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", - "options": input_files, - "default": input_files[0] if input_files else None, - }, + return IO.Schema( + node_id="OpenAIInputFiles", + display_name="OpenAI ChatGPT Input Files", + category="api node/text/OpenAI", + description="Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes.", + inputs=[ + IO.Combo.Input( + "file", + options=input_files, + default=input_files[0] if input_files else None, + tooltip="Input files to include as context for the model. Only accepts text (.txt) and PDF (.pdf) files for now.", ), - }, - "optional": { - "OPENAI_INPUT_FILES": ( + IO.Custom("OPENAI_INPUT_FILES").Input( "OPENAI_INPUT_FILES", - { - "tooltip": "An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", - "default": None, - }, + tooltip="An optional additional file(s) to batch together with the file loaded from this node. Allows chaining of input files so that a single message can include multiple input files.", + optional=True, ), - }, - } + ], + outputs=[ + IO.Custom("OPENAI_INPUT_FILES").Output(), + ], + ) - DESCRIPTION = "Loads and prepares input files (text, pdf, etc.) to include as inputs for the OpenAI Chat Node. The files will be read by the OpenAI model when generating a response. 🛈 TIP: Can be chained together with other OpenAI Input File nodes." 
- RETURN_TYPES = ("OPENAI_INPUT_FILES",) - FUNCTION = "prepare_files" - CATEGORY = "api node/text/OpenAI" - - def create_input_file_content(self, file_path: str) -> InputFileContent: + @classmethod + def create_input_file_content(cls, file_path: str) -> InputFileContent: return InputFileContent( file_data=text_filepath_to_data_uri(file_path), filename=os.path.basename(file_path), type="input_file", ) - def prepare_files( - self, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = [] - ) -> tuple[list[InputFileContent]]: + @classmethod + def execute(cls, file: str, OPENAI_INPUT_FILES: list[InputFileContent] = []) -> IO.NodeOutput: """ Loads and formats input files for OpenAI API. """ file_path = folder_paths.get_annotated_filepath(file) - input_file_content = self.create_input_file_content(file_path) + input_file_content = cls.create_input_file_content(file_path) files = [input_file_content] + OPENAI_INPUT_FILES - return (files,) + return IO.NodeOutput(files) -class OpenAIChatConfig(ComfyNodeABC): +class OpenAIChatConfig(IO.ComfyNode): """Allows setting additional configuration for the OpenAI Chat Node.""" - RETURN_TYPES = ("OPENAI_CHAT_CONFIG",) - FUNCTION = "configure" - DESCRIPTION = ( - "Allows specifying advanced configuration options for the OpenAI Chat Nodes." - ) - CATEGORY = "api node/text/OpenAI" - @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "truncation": ( - IO.COMBO, - { - "options": ["auto", "disabled"], - "default": "auto", - "tooltip": "The truncation strategy to use for the model response. auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error", - }, + def define_schema(cls): + return IO.Schema( + node_id="OpenAIChatConfig", + display_name="OpenAI ChatGPT Advanced Options", + category="api node/text/OpenAI", + description="Allows specifying advanced configuration options for the OpenAI Chat Nodes.", + inputs=[ + IO.Combo.Input( + "truncation", + options=["auto", "disabled"], + default="auto", + tooltip="The truncation strategy to use for the model response. 
auto: If the context of this response and previous ones exceeds the model's context window size, the model will truncate the response to fit the context window by dropping input items in the middle of the conversation.disabled: If a model response will exceed the context window size for a model, the request will fail with a 400 error", ), - }, - "optional": { - "max_output_tokens": model_field_to_node_input( - IO.INT, - OpenAICreateResponse, + IO.Int.Input( "max_output_tokens", min=16, default=4096, max=16384, tooltip="An upper bound for the number of tokens that can be generated for a response, including visible output tokens", + optional=True, ), - "instructions": model_field_to_node_input( - IO.STRING, OpenAICreateResponse, "instructions", multiline=True + IO.String.Input( + "instructions", + multiline=True, + optional=True, + tooltip="Instructions for the model on how to generate the response", ), - }, - } + ], + outputs=[ + IO.Custom("OPENAI_CHAT_CONFIG").Output(), + ], + ) - def configure( - self, - truncation: bool, - instructions: Optional[str] = None, - max_output_tokens: Optional[int] = None, - ) -> tuple[CreateModelResponseProperties]: + @classmethod + def execute( + cls, + truncation: bool, + instructions: str | None = None, + max_output_tokens: int | None = None, + ) -> IO.NodeOutput: """ Configure advanced options for the OpenAI Chat Node. @@ -976,29 +772,27 @@ class OpenAIChatConfig(ComfyNodeABC): They are not exposed as inputs at all to avoid having to manually remove depending on model choice. """ - return ( + return IO.NodeOutput( CreateModelResponseProperties( instructions=instructions, truncation=truncation, max_output_tokens=max_output_tokens, - ), + ) ) -NODE_CLASS_MAPPINGS = { - "OpenAIDalle2": OpenAIDalle2, - "OpenAIDalle3": OpenAIDalle3, - "OpenAIGPTImage1": OpenAIGPTImage1, - "OpenAIChatNode": OpenAIChatNode, - "OpenAIInputFiles": OpenAIInputFiles, - "OpenAIChatConfig": OpenAIChatConfig, -} +class OpenAIExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + OpenAIDalle2, + OpenAIDalle3, + OpenAIGPTImage1, + OpenAIChatNode, + OpenAIInputFiles, + OpenAIChatConfig, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "OpenAIDalle2": "OpenAI DALL·E 2", - "OpenAIDalle3": "OpenAI DALL·E 3", - "OpenAIGPTImage1": "OpenAI GPT Image 1", - "OpenAIChatNode": "OpenAI ChatGPT", - "OpenAIInputFiles": "OpenAI ChatGPT Input Files", - "OpenAIChatConfig": "OpenAI ChatGPT Advanced Options", -} + +async def comfy_entrypoint() -> OpenAIExtension: + return OpenAIExtension() diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 27cb0067b..acd88c391 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -7,28 +7,23 @@ from __future__ import annotations from io import BytesIO import logging -from typing import Optional, TypeVar +from typing import Optional import torch from typing_extensions import override from comfy_api.latest import ComfyExtension, IO from comfy_api.input_impl.video_types import VideoCodec, VideoContainer, VideoInput -from comfy_api_nodes.apinode_utils import ( +from comfy_api_nodes.apis import pika_api as pika_defs +from comfy_api_nodes.util import ( + validate_string, download_url_to_video_output, tensor_to_bytesio, - validate_string, -) -from comfy_api_nodes.apis import pika_defs -from comfy_api_nodes.apis.client import ( ApiEndpoint, - EmptyRequest, - HttpMethod, - PollingOperation, - SynchronousOperation, + sync_op, + poll_op, ) -R = TypeVar("R") PATH_PIKADDITIONS = 
"/proxy/pika/generate/pikadditions" PATH_PIKASWAPS = "/proxy/pika/generate/pikaswaps" @@ -44,28 +39,18 @@ PATH_VIDEO_GET = "/proxy/pika/videos" async def execute_task( - initial_operation: SynchronousOperation[R, pika_defs.PikaGenerateResponse], - auth_kwargs: Optional[dict[str, str]] = None, - node_id: Optional[str] = None, + task_id: str, + cls: type[IO.ComfyNode], ) -> IO.NodeOutput: - task_id = (await initial_operation.execute()).video_id - final_response: pika_defs.PikaVideoResponse = await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"{PATH_VIDEO_GET}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=pika_defs.PikaVideoResponse, - ), - completed_statuses=["finished"], - failed_statuses=["failed", "cancelled"], + final_response: pika_defs.PikaVideoResponse = await poll_op( + cls, + ApiEndpoint(path=f"{PATH_VIDEO_GET}/{task_id}"), + response_model=pika_defs.PikaVideoResponse, status_extractor=lambda response: (response.status.value if response.status else None), progress_extractor=lambda response: (response.progress if hasattr(response, "progress") else None), - auth_kwargs=auth_kwargs, - result_url_extractor=lambda response: (response.url if hasattr(response, "url") else None), - node_id=node_id, estimated_duration=60, max_poll_attempts=240, - ).execute() + ) if not final_response.url: error_msg = f"Pika task {task_id} succeeded but no video data found in response:\n{final_response}" logging.error(error_msg) @@ -107,6 +92,7 @@ class PikaImageToVideo(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -128,23 +114,15 @@ class PikaImageToVideo(IO.ComfyNode): resolution=resolution, duration=duration, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGenerate22I2vGenerate22I2vPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_request_data, + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikaTextToVideoNode(IO.ComfyNode): @@ -175,6 +153,7 @@ class PikaTextToVideoNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -187,18 +166,11 @@ class PikaTextToVideoNode(IO.ComfyNode): duration: int, aspect_ratio: float, ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_VIDEO, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost( + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_TEXT_TO_VIDEO, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_defs.PikaBodyGenerate22T2vGenerate22T2vPost( promptText=prompt_text, negativePrompt=negative_prompt, seed=seed, @@ -206,10 +178,9 @@ class 
PikaTextToVideoNode(IO.ComfyNode): duration=duration, aspectRatio=aspect_ratio, ), - auth_kwargs=auth, content_type="application/x-www-form-urlencoded", ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikaScenes(IO.ComfyNode): @@ -270,6 +241,7 @@ class PikaScenes(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -313,24 +285,16 @@ class PikaScenes(IO.ComfyNode): duration=duration, aspectRatio=aspect_ratio, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKASCENES, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGenerate22C2vGenerate22PikascenesPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_request_data, + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_PIKASCENES, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikAdditionsNode(IO.ComfyNode): @@ -362,6 +326,7 @@ class PikAdditionsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -387,24 +352,16 @@ class PikAdditionsNode(IO.ComfyNode): negativePrompt=negative_prompt, seed=seed, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKADDITIONS, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGeneratePikadditionsGeneratePikadditionsPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_request_data, + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_PIKADDITIONS, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikaSwapsNode(IO.ComfyNode): @@ -446,6 +403,7 @@ class PikaSwapsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -476,23 +434,15 @@ class PikaSwapsNode(IO.ComfyNode): seed=seed, modifyRegionRoi=region_to_modify if region_to_modify else None, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKASWAPS, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGeneratePikaswapsGeneratePikaswapsPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_request_data, + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_PIKASWAPS, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_request_data, files=pika_files, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class 
PikaffectsNode(IO.ComfyNode): @@ -521,6 +471,7 @@ class PikaffectsNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -532,18 +483,11 @@ class PikaffectsNode(IO.ComfyNode): negative_prompt: str, seed: int, ) -> IO.NodeOutput: - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKAFFECTS, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost( + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_PIKAFFECTS, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_defs.PikaBodyGeneratePikaffectsGeneratePikaffectsPost( pikaffect=pikaffect, promptText=prompt_text, negativePrompt=negative_prompt, @@ -551,9 +495,8 @@ class PikaffectsNode(IO.ComfyNode): ), files={"image": ("image.png", tensor_to_bytesio(image), "image/png")}, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikaStartEndFrameNode(IO.ComfyNode): @@ -578,6 +521,7 @@ class PikaStartEndFrameNode(IO.ComfyNode): IO.Hidden.unique_id, ], is_api_node=True, + is_deprecated=True, ) @classmethod @@ -596,18 +540,11 @@ class PikaStartEndFrameNode(IO.ComfyNode): ("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")), ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")), ] - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_PIKAFRAMES, - method=HttpMethod.POST, - request_model=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost, - response_model=pika_defs.PikaGenerateResponse, - ), - request=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost( + initial_operation = await sync_op( + cls, + ApiEndpoint(path=PATH_PIKAFRAMES, method="POST"), + response_model=pika_defs.PikaGenerateResponse, + data=pika_defs.PikaBodyGenerate22KeyframeGenerate22PikaframesPost( promptText=prompt_text, negativePrompt=negative_prompt, seed=seed, @@ -616,9 +553,8 @@ class PikaStartEndFrameNode(IO.ComfyNode): ), files=pika_files, content_type="multipart/form-data", - auth_kwargs=auth, ) - return await execute_task(initial_operation, auth_kwargs=auth, node_id=cls.hidden.unique_id) + return await execute_task(initial_operation.video_id, cls) class PikaApiNodesExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index 438a7f80b..6e1686af0 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -1,7 +1,6 @@ -from inspect import cleandoc -from typing import Optional +import torch from typing_extensions import override -from io import BytesIO +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.pixverse_api import ( PixverseTextVideoRequest, PixverseImageVideoRequest, @@ -17,59 +16,30 @@ from comfy_api_nodes.apis.pixverse_api import ( PixverseIO, pixverse_templates, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - 
PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( + download_url_to_video_output, + poll_op, + sync_op, tensor_to_bytesio, validate_string, ) -from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import ComfyExtension, IO - -import torch -import aiohttp - AVERAGE_DURATION_T2V = 32 AVERAGE_DURATION_I2V = 30 AVERAGE_DURATION_T2T = 52 -def get_video_url_from_response( - response: PixverseGenerationStatusResponse, -) -> Optional[str]: - if response.Resp is None or response.Resp.url is None: - return None - return str(response.Resp.url) - - -async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): - # first, upload image to Pixverse and get image id to use in actual generation call - files = {"image": tensor_to_bytesio(image)} - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/image/upload", - method=HttpMethod.POST, - request_model=EmptyRequest, - response_model=PixverseImageUploadResponse, - ), - request=EmptyRequest(), - files=files, +async def upload_image_to_pixverse(cls: type[IO.ComfyNode], image: torch.Tensor): + response_upload = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/image/upload", method="POST"), + response_model=PixverseImageUploadResponse, + files={"image": tensor_to_bytesio(image)}, content_type="multipart/form-data", - auth_kwargs=auth_kwargs, ) - response_upload: PixverseImageUploadResponse = await operation.execute() - if response_upload.Resp is None: - raise Exception( - f"PixVerse image upload request failed: '{response_upload.ErrMsg}'" - ) - + raise Exception(f"PixVerse image upload request failed: '{response_upload.ErrMsg}'") return response_upload.Resp.img_id @@ -95,22 +65,17 @@ class PixverseTemplateNode(IO.ComfyNode): template_id = pixverse_templates.get(template, None) if template_id is None: raise Exception(f"Template '{template}' is not recognized.") - # just return the integer return IO.NodeOutput(template_id) class PixverseTextToVideoNode(IO.ComfyNode): - """ - Generates videos based on prompt and output_size. 
- """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="PixverseTextToVideoNode", display_name="PixVerse Text to Video", category="api node/video/PixVerse", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos based on prompt and output_size.", inputs=[ IO.String.Input( "prompt", @@ -177,7 +142,7 @@ class PixverseTextToVideoNode(IO.ComfyNode): negative_prompt: str = None, pixverse_template: int = None, ) -> IO.NodeOutput: - validate_string(prompt, strip_whitespace=False) + validate_string(prompt, strip_whitespace=False, min_length=1) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration if quality == PixverseQuality.res_1080p: @@ -186,18 +151,11 @@ class PixverseTextToVideoNode(IO.ComfyNode): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/text/generate", - method=HttpMethod.POST, - request_model=PixverseTextVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseTextVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/text/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseTextVideoRequest( prompt=prompt, aspect_ratio=aspect_ratio, quality=quality, @@ -207,20 +165,14 @@ class PixverseTextToVideoNode(IO.ComfyNode): template_id=pixverse_template, seed=seed, ), - auth_kwargs=auth, ) - response_api = await operation.execute() - if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -228,30 +180,19 @@ class PixverseTextToVideoNode(IO.ComfyNode): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=auth, - node_id=cls.hidden.unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.Resp.url) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) class PixverseImageToVideoNode(IO.ComfyNode): - """ - Generates videos based on prompt and output_size. 
- """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="PixverseImageToVideoNode", display_name="PixVerse Image to Video", category="api node/video/PixVerse", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("image"), IO.String.Input( @@ -316,11 +257,7 @@ class PixverseImageToVideoNode(IO.ComfyNode): pixverse_template: int = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - img_id = await upload_image_to_pixverse(image, auth_kwargs=auth) + img_id = await upload_image_to_pixverse(cls, image) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -330,14 +267,11 @@ class PixverseImageToVideoNode(IO.ComfyNode): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/img/generate", - method=HttpMethod.POST, - request_model=PixverseImageVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseImageVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/img/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseImageVideoRequest( img_id=img_id, prompt=prompt, quality=quality, @@ -347,20 +281,15 @@ class PixverseImageToVideoNode(IO.ComfyNode): template_id=pixverse_template, seed=seed, ), - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -368,30 +297,19 @@ class PixverseImageToVideoNode(IO.ComfyNode): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=auth, - node_id=cls.hidden.unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_I2V, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.Resp.url) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) class PixverseTransitionVideoNode(IO.ComfyNode): - """ - Generates videos based on prompt and output_size. 
- """ - @classmethod def define_schema(cls) -> IO.Schema: return IO.Schema( node_id="PixverseTransitionVideoNode", display_name="PixVerse Transition Video", category="api node/video/PixVerse", - description=cleandoc(cls.__doc__ or ""), + description="Generates videos based on prompt and output_size.", inputs=[ IO.Image.Input("first_frame"), IO.Image.Input("last_frame"), @@ -452,12 +370,8 @@ class PixverseTransitionVideoNode(IO.ComfyNode): negative_prompt: str = None, ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=auth) - last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=auth) + first_frame_id = await upload_image_to_pixverse(cls, first_frame) + last_frame_id = await upload_image_to_pixverse(cls, last_frame) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -467,14 +381,11 @@ class PixverseTransitionVideoNode(IO.ComfyNode): elif duration_seconds != PixverseDuration.dur_5: motion_mode = PixverseMotionMode.normal - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/pixverse/video/transition/generate", - method=HttpMethod.POST, - request_model=PixverseTransitionVideoRequest, - response_model=PixverseVideoResponse, - ), - request=PixverseTransitionVideoRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/pixverse/video/transition/generate", method="POST"), + response_model=PixverseVideoResponse, + data=PixverseTransitionVideoRequest( first_frame_img=first_frame_id, last_frame_img=last_frame_id, prompt=prompt, @@ -484,20 +395,15 @@ class PixverseTransitionVideoNode(IO.ComfyNode): negative_prompt=negative_prompt if negative_prompt else None, seed=seed, ), - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=PixverseGenerationStatusResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/pixverse/video/result/{response_api.Resp.video_id}"), + response_model=PixverseGenerationStatusResponse, completed_statuses=[PixverseStatus.successful], failed_statuses=[ PixverseStatus.contents_moderation, @@ -505,16 +411,9 @@ class PixverseTransitionVideoNode(IO.ComfyNode): PixverseStatus.deleted, ], status_extractor=lambda x: x.Resp.status, - auth_kwargs=auth, - node_id=cls.hidden.unique_id, - result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = await operation.execute() - - async with aiohttp.ClientSession() as session: - async with session.get(response_poll.Resp.url) as vid_response: - return IO.NodeOutput(VideoFromFile(BytesIO(await vid_response.content.read()))) + return IO.NodeOutput(await download_url_to_video_output(response_poll.Resp.url)) class PixVerseExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index bf60f36ad..3a3984881 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -1,96 +1,81 @@ -from __future__ import annotations -from inspect import cleandoc -from typing import 
Optional +from io import BytesIO +from typing import Optional, Union + +import aiohttp +import torch +from PIL import UnidentifiedImageError +from typing_extensions import override + from comfy.utils import ProgressBar -from comfy_extras.nodes.nodes_images import SVG # Added -from comfy.comfy_types.node_typing import IO +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.recraft_api import ( - RecraftImageGenerationRequest, - RecraftImageGenerationResponse, - RecraftImageSize, - RecraftModel, - RecraftStyle, - RecraftStyleV3, RecraftColor, RecraftColorChain, RecraftControls, + RecraftImageGenerationRequest, + RecraftImageGenerationResponse, + RecraftImageSize, RecraftIO, + RecraftModel, + RecraftStyle, + RecraftStyleV3, get_v3_substyles, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( bytesio_to_image_tensor, - download_url_to_bytesio, - tensor_to_bytesio, + download_url_as_bytesio, resize_mask_to_image, + sync_op, + tensor_to_bytesio, validate_string, ) -from comfy.cmd.server import PromptServer - -import torch -from io import BytesIO -from PIL import UnidentifiedImageError -import aiohttp +from comfy_extras.nodes.nodes_images import SVG async def handle_recraft_file_request( - image: torch.Tensor, - path: str, - mask: torch.Tensor=None, - total_pixels=4096*4096, - timeout=1024, - request=None, - auth_kwargs: dict[str,str] = None, + cls: type[IO.ComfyNode], + image: torch.Tensor, + path: str, + mask: Optional[torch.Tensor] = None, + total_pixels: int = 4096 * 4096, + timeout: int = 1024, + request=None, ) -> list[BytesIO]: - """ - Handle sending common Recraft file-only request to get back file bytes. 
- """ - if request is None: - request = EmptyRequest() + """Handle sending common Recraft file-only request to get back file bytes.""" - files = { - 'image': tensor_to_bytesio(image, total_pixels=total_pixels).read() - } + files = {"image": tensor_to_bytesio(image, total_pixels=total_pixels).read()} if mask is not None: - files['mask'] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() + files["mask"] = tensor_to_bytesio(mask, total_pixels=total_pixels).read() - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=type(request), - response_model=RecraftImageGenerationResponse, - ), - request=request, + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=path, method="POST"), + response_model=RecraftImageGenerationResponse, + data=request if request else None, files=files, content_type="multipart/form-data", - auth_kwargs=auth_kwargs, multipart_parser=recraft_multipart_parser, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() all_bytesio = [] if response.image is not None: - all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout)) + all_bytesio.append(await download_url_as_bytesio(response.image.url, timeout=timeout)) else: for data in response.data: - all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout)) + all_bytesio.append(await download_url_as_bytesio(data.url, timeout=timeout)) return all_bytesio def recraft_multipart_parser( - data, - parent_key=None, - formatter: callable = None, - converted_to_check: list[list] = None, - is_list: bool = False, - return_mode: str = "formdata" # "dict" | "formdata" -) -> dict | aiohttp.FormData: + data, + parent_key=None, + formatter: Optional[type[callable]] = None, + converted_to_check: Optional[list[list]] = None, + is_list: bool = False, + return_mode: str = "formdata", # "dict" | "formdata" +) -> Union[dict, aiohttp.FormData]: """ Formats data such that multipart/form-data will work with aiohttp library when both files and data are present. @@ -107,11 +92,12 @@ def recraft_multipart_parser( etc. Whoever made this serialization up at OpenAI added the constraint that lists must be of uniform length on objects of same 'type'. 
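+
+    Illustrative example (assumed payload shape, not taken from the Recraft docs):
+    a dict such as {"prompt": "cat", "controls": {"colors": [{"rgb": [0, 0, 255]}]}}
+    is expected to flatten into the form fields
+        prompt=cat
+        controls[colors][][rgb][]=0, controls[colors][][rgb][]=0, controls[colors][][rgb][]=255
+    i.e. nested dict keys become bracketed suffixes and list items repeat the same "...[]" field name.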
""" + # Modification of a function that handled a different type of multipart parsing, big ups: # https://gist.github.com/kazqvaizer/4cebebe5db654a414132809f9f88067b - def handle_converted_lists(item, parent_key, lists_to_check=tuple[list]): - # if list already exists exists, just extend list with data + def handle_converted_lists(item, parent_key, lists_to_check=list[list]): + # if list already exists, just extend list with data for check_list in lists_to_check: for conv_tuple in check_list: if conv_tuple[0] == parent_key and isinstance(conv_tuple[1], list): @@ -127,7 +113,7 @@ def recraft_multipart_parser( formatter = lambda v: v # Multipart representation of value if not isinstance(data, dict): - # if list already exists exists, just extend list with data + # if list already exists, just extend list with data added = handle_converted_lists(data, parent_key, converted_to_check) if added: return {} @@ -148,7 +134,9 @@ def recraft_multipart_parser( elif isinstance(value, list): for ind, list_value in enumerate(value): iter_key = f"{current_key}[]" - converted.extend(recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items()) + converted.extend( + recraft_multipart_parser(list_value, iter_key, formatter, next_check, is_list=True).items() + ) else: converted.append((current_key, formatter(value))) @@ -168,6 +156,7 @@ class handle_recraft_image_output: """ Catch an exception related to receiving SVG data instead of image, when Infinite Style Library style_id is in use. """ + def __init__(self): pass @@ -176,253 +165,233 @@ class handle_recraft_image_output: def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None and exc_type is UnidentifiedImageError: - raise Exception("Received output data was not an image; likely an SVG. If you used style_id, make sure it is not a Vector art style.") + raise Exception( + "Received output data was not an image; likely an SVG. " + "If you used style_id, make sure it is not a Vector art style." + ) -class RecraftColorRGBNode: - """ - Create Recraft Color by choosing specific RGB values. - """ - - RETURN_TYPES = (RecraftIO.COLOR,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - RETURN_NAMES = ("recraft_color",) - FUNCTION = "create_color" - CATEGORY = "api node/image/Recraft" +class RecraftColorRGBNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftColorRGB", + display_name="Recraft Color RGB", + category="api node/image/Recraft", + description="Create Recraft Color by choosing specific RGB values.", + inputs=[ + IO.Int.Input("r", default=0, min=0, max=255, tooltip="Red value of color."), + IO.Int.Input("g", default=0, min=0, max=255, tooltip="Green value of color."), + IO.Int.Input("b", default=0, min=0, max=255, tooltip="Blue value of color."), + IO.Custom(RecraftIO.COLOR).Input("recraft_color", optional=True), + ], + outputs=[ + IO.Custom(RecraftIO.COLOR).Output(display_name="recraft_color"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "r": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Red value of color." - }), - "g": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Green value of color." - }), - "b": (IO.INT, { - "default": 0, - "min": 0, - "max": 255, - "tooltip": "Blue value of color." 
- }), - }, - "optional": { - "recraft_color": (RecraftIO.COLOR,), - } - } - - def create_color(self, r: int, g: int, b: int, recraft_color: RecraftColorChain=None): + def execute(cls, r: int, g: int, b: int, recraft_color: RecraftColorChain = None) -> IO.NodeOutput: recraft_color = recraft_color.clone() if recraft_color else RecraftColorChain() recraft_color.add(RecraftColor(r, g, b)) - return (recraft_color, ) + return IO.NodeOutput(recraft_color) -class RecraftControlsNode: - """ - Create Recraft Controls for customizing Recraft generation. - """ - - RETURN_TYPES = (RecraftIO.CONTROLS,) - RETURN_NAMES = ("recraft_controls",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_controls" - CATEGORY = "api node/image/Recraft" +class RecraftControlsNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftControls", + display_name="Recraft Controls", + category="api node/image/Recraft", + description="Create Recraft Controls for customizing Recraft generation.", + inputs=[ + IO.Custom(RecraftIO.COLOR).Input("colors", optional=True), + IO.Custom(RecraftIO.COLOR).Input("background_color", optional=True), + ], + outputs=[ + IO.Custom(RecraftIO.CONTROLS).Output(display_name="recraft_controls"), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - }, - "optional": { - "colors": (RecraftIO.COLOR,), - "background_color": (RecraftIO.COLOR,), - } - } - - def create_controls(self, colors: RecraftColorChain=None, background_color: RecraftColorChain=None): - return (RecraftControls(colors=colors, background_color=background_color), ) + def execute(cls, colors: RecraftColorChain = None, background_color: RecraftColorChain = None) -> IO.NodeOutput: + return IO.NodeOutput(RecraftControls(colors=colors, background_color=background_color)) -class RecraftStyleV3RealisticImageNode: - """ - Select realistic_image style and optional substyle. - """ - - RETURN_TYPES = (RecraftIO.STYLEV3,) - RETURN_NAMES = ("recraft_style",) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "create_style" - CATEGORY = "api node/image/Recraft" - +class RecraftStyleV3RealisticImageNode(IO.ComfyNode): RECRAFT_STYLE = RecraftStyleV3.realistic_image @classmethod - def INPUT_TYPES(s): - return { - "required": { - "substyle": (get_v3_substyles(s.RECRAFT_STYLE),), - } - } + def define_schema(cls): + return IO.Schema( + node_id="RecraftStyleV3RealisticImage", + display_name="Recraft Style - Realistic Image", + category="api node/image/Recraft", + description="Select realistic_image style and optional substyle.", + inputs=[ + IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)), + ], + outputs=[ + IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"), + ], + ) - def create_style(self, substyle: str): + @classmethod + def execute(cls, substyle: str) -> IO.NodeOutput: if substyle == "None": substyle = None - return (RecraftStyle(self.RECRAFT_STYLE, substyle),) + return IO.NodeOutput(RecraftStyle(cls.RECRAFT_STYLE, substyle)) class RecraftStyleV3DigitalIllustrationNode(RecraftStyleV3RealisticImageNode): - """ - Select digital_illustration style and optional substyle. 
-    """
    RECRAFT_STYLE = RecraftStyleV3.digital_illustration

+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="RecraftStyleV3DigitalIllustration",
+            display_name="Recraft Style - Digital Illustration",
+            category="api node/image/Recraft",
+            description="Select digital_illustration style and optional substyle.",
+            inputs=[
+                IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
+            ],
+            outputs=[
+                IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"),
+            ],
+        )
+

class RecraftStyleV3VectorIllustrationNode(RecraftStyleV3RealisticImageNode):
-    """
-    Select vector_illustration style and optional substyle.
-    """
    RECRAFT_STYLE = RecraftStyleV3.vector_illustration

+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="RecraftStyleV3VectorIllustrationNode",
+            display_name="Recraft Style - Vector Illustration",
+            category="api node/image/Recraft",
+            description="Select vector_illustration style and optional substyle.",
+            inputs=[
+                IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE)),
+            ],
+            outputs=[
+                IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"),
+            ],
+        )
+

class RecraftStyleV3LogoRasterNode(RecraftStyleV3RealisticImageNode):
-    """
-    Select vector_illustration style and optional substyle.
-    """
-
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "substyle": (get_v3_substyles(s.RECRAFT_STYLE, include_none=False),),
-            }
-        }
-
    RECRAFT_STYLE = RecraftStyleV3.logo_raster

+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="RecraftStyleV3LogoRaster",
+            display_name="Recraft Style - Logo Raster",
+            category="api node/image/Recraft",
+            description="Select logo_raster style and optional substyle.",
+            inputs=[
+                IO.Combo.Input("substyle", options=get_v3_substyles(cls.RECRAFT_STYLE, include_none=False)),
+            ],
+            outputs=[
+                IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"),
+            ],
+        )
+

-class RecraftStyleInfiniteStyleLibrary:
-    """
-    Select style based on preexisting UUID from Recraft's Infinite Style Library.
-    """
-    RETURN_TYPES = (RecraftIO.STYLEV3,)
-    RETURN_NAMES = ("recraft_style",)
-    DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value
-    FUNCTION = "create_style"
-    CATEGORY = "api node/image/Recraft"
+class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return IO.Schema(
+            node_id="RecraftStyleV3InfiniteStyleLibrary",
+            display_name="Recraft Style - Infinite Style Library",
+            category="api node/image/Recraft",
+            description="Select style based on preexisting UUID from Recraft's Infinite Style Library.",
+            inputs=[
+                IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
+            ],
+            outputs=[
+                IO.Custom(RecraftIO.STYLEV3).Output(display_name="recraft_style"),
+            ],
+        )

    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "style_id": (IO.STRING, {
-                    "default": "",
-                    "tooltip": "UUID of style from Infinite Style Library.",
-                })
-            }
-        }
-
-    def create_style(self, style_id: str):
+    def execute(cls, style_id: str) -> IO.NodeOutput:
        if not style_id:
            raise Exception("The style_id input cannot be empty.")
-        return (RecraftStyle(style_id=style_id),)
+        return IO.NodeOutput(RecraftStyle(style_id=style_id))


-class RecraftTextToImageNode:
-    """
-    Generates images synchronously based on prompt and resolution.
- """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftTextToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftTextToImageNode", + display_name="Recraft Text to Image", + category="api node/image/Recraft", + description="Generates images synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Combo.Input( + "size", + options=[res.value for res in RecraftImageSize], + default=RecraftImageSize.res_1024x1024, + tooltip="The size of the generated image.", + ), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "size": ( - [res.value for res in RecraftImageSize], - { - "default": RecraftImageSize.res_1024x1024, - "tooltip": "The size of the generated image.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." 
- }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, - prompt: str, - size: str, - n: int, - seed, - recraft_style: RecraftStyle = None, - negative_prompt: str = None, - recraft_controls: RecraftControls = None, - unique_id: Optional[str] = None, - **kwargs, - ): + async def execute( + cls, + prompt: str, + size: str, + n: int, + seed, + recraft_style: RecraftStyle = None, + negative_prompt: str = None, + recraft_controls: RecraftControls = None, + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -435,14 +404,11 @@ class RecraftTextToImageNode: if not negative_prompt: negative_prompt = None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/recraft/image_generation", - method=HttpMethod.POST, - request_model=RecraftImageGenerationRequest, - response_model=RecraftImageGenerationResponse, - ), - request=RecraftImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, model=RecraftModel.recraftv3, @@ -453,119 +419,92 @@ class RecraftTextToImageNode: style_id=recraft_style.style_id, controls=controls_api, ), - auth_kwargs=kwargs, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() images = [] - urls = [] for data in response.data: with handle_recraft_image_output(): - if unique_id and data.url: - urls.append(data.url) - urls_string = '\n'.join(urls) - PromptServer.instance.send_progress_text( - f"Result URL: {urls_string}", unique_id - ) - image = bytesio_to_image_tensor( - await download_url_to_bytesio(data.url, timeout=1024) - ) + image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024)) if len(image.shape) < 4: image = image.unsqueeze(0) images.append(image) - output_image = torch.cat(images, dim=0) - return (output_image,) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftImageToImageNode: - """ - Modify image based on prompt and strength. 
- """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftImageToImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftImageToImageNode", + display_name="Recraft Image to Image", + category="api node/image/Recraft", + description="Modify image based on prompt and strength.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Float.Input( + "strength", + default=0.5, + min=0.0, + max=1.0, + step=0.01, + tooltip="Defines the difference with the original image, should lie in [0, 1], " + "where 0 means almost identical, and 1 means miserable similarity.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "strength": ( - IO.FLOAT, - { - "default": 0.5, - "min": 0.0, - "max": 1.0, - "step": 0.01, - "tooltip": "Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity." - } - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." 
- }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, - image: torch.Tensor, - prompt: str, - n: int, - strength: float, - seed, - recraft_style: RecraftStyle = None, - negative_prompt: str = None, - recraft_controls: RecraftControls = None, - **kwargs, - ): + async def execute( + cls, + image: torch.Tensor, + prompt: str, + n: int, + strength: float, + seed, + recraft_style: RecraftStyle = None, + negative_prompt: str = None, + recraft_controls: RecraftControls = None, + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -595,92 +534,77 @@ class RecraftImageToImageNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/imageToImage", request=request, - auth_kwargs=kwargs, ) with handle_recraft_image_output(): images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftImageInpaintingNode: - """ - Modify image based on prompt and mask. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftImageInpaintingNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftImageInpaintingNode", + display_name="Recraft Image Inpainting", + category="api node/image/Recraft", + description="Modify image based on prompt and mask.", + inputs=[ + IO.Image.Input("image"), + IO.Mask.Input("mask"), + IO.String.Input("prompt", multiline=True, default="", tooltip="Prompt for the image generation."), + IO.Int.Input( + "n", + default=1, + min=1, + max=6, + tooltip="The number of images to generate.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "mask": (IO.MASK, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - 
"hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, - image: torch.Tensor, - mask: torch.Tensor, - prompt: str, - n: int, - seed, - recraft_style: RecraftStyle = None, - negative_prompt: str = None, - **kwargs, - ): + async def execute( + cls, + image: torch.Tensor, + mask: torch.Tensor, + prompt: str, + n: int, + seed, + recraft_style: RecraftStyle = None, + negative_prompt: str = None, + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: @@ -707,106 +631,81 @@ class RecraftImageInpaintingNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], - mask=mask[i:i+1], + mask=mask[i: i + 1], path="/proxy/recraft/images/inpaint", request=request, - auth_kwargs=kwargs, ) with handle_recraft_image_output(): images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftTextToVectorNode: - """ - Generates SVG synchronously based on prompt and resolution. - """ - - RETURN_TYPES = ("SVG",) # Changed - DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftTextToVectorNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftTextToVectorNode", + display_name="Recraft Text to Vector", + category="api node/image/Recraft", + description="Generates SVG synchronously based on prompt and resolution.", + inputs=[ + IO.String.Input("prompt", default="", tooltip="Prompt for the image generation.", multiline=True), + IO.Combo.Input("substyle", options=get_v3_substyles(RecraftStyleV3.vector_illustration)), + IO.Combo.Input( + "size", + options=[res.value for res in RecraftImageSize], + default=RecraftImageSize.res_1024x1024, + tooltip="The size of the generated image.", + ), + IO.Int.Input("n", default=1, min=1, max=6, tooltip="The number of images to generate."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + IO.Custom(RecraftIO.CONTROLS).Input( + "recraft_controls", + tooltip="Optional additional controls over the generation via the Recraft Controls node.", + optional=True, + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "substyle": (get_v3_substyles(RecraftStyleV3.vector_illustration),), - "size": ( - [res.value for res in RecraftImageSize], - { - "default": RecraftImageSize.res_1024x1024, - "tooltip": "The size of the generated image.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The 
number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - "recraft_controls": ( - RecraftIO.CONTROLS, - { - "tooltip": "Optional additional controls over the generation via the Recraft Controls node." - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - async def api_call( - self, - prompt: str, - substyle: str, - size: str, - n: int, - seed, - negative_prompt: str = None, - recraft_controls: RecraftControls = None, - unique_id: Optional[str] = None, - **kwargs, - ): + async def execute( + cls, + prompt: str, + substyle: str, + size: str, + n: int, + seed, + negative_prompt: str = None, + recraft_controls: RecraftControls = None, + ) -> IO.NodeOutput: validate_string(prompt, strip_whitespace=False, max_length=1000) # create RecraftStyle so strings will be formatted properly (i.e. "None" will become None) recraft_style = RecraftStyle(RecraftStyleV3.vector_illustration, substyle=substyle) @@ -818,14 +717,11 @@ class RecraftTextToVectorNode: if not negative_prompt: negative_prompt = None - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/recraft/image_generation", - method=HttpMethod.POST, - request_model=RecraftImageGenerationRequest, - response_model=RecraftImageGenerationResponse, - ), - request=RecraftImageGenerationRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"), + response_model=RecraftImageGenerationResponse, + data=RecraftImageGenerationRequest( prompt=prompt, negative_prompt=negative_prompt, model=RecraftModel.recraftv3, @@ -835,139 +731,105 @@ class RecraftTextToVectorNode: substyle=recraft_style.substyle, controls=controls_api, ), - auth_kwargs=kwargs, + max_retries=1, ) - response: RecraftImageGenerationResponse = await operation.execute() svg_data = [] - urls = [] for data in response.data: - if unique_id and data.url: - urls.append(data.url) - # Print result on each iteration in case of error - PromptServer.instance.send_progress_text( - f"Result URL: {' '.join(urls)}", unique_id - ) - svg_data.append(await download_url_to_bytesio(data.url, timeout=1024)) + svg_data.append(await download_url_as_bytesio(data.url, timeout=1024)) - return (SVG(svg_data),) + return IO.NodeOutput(SVG(svg_data)) -class RecraftVectorizeImageNode: - """ - Generates SVG synchronously from an input image. 
- """ - - RETURN_TYPES = ("SVG",) # Changed - DESCRIPTION = cleandoc(__doc__ or "") if 'cleandoc' in globals() else __doc__ # Keep cleandoc if other nodes use it - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftVectorizeImageNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftVectorizeImageNode", + display_name="Recraft Vectorize Image", + category="api node/image/Recraft", + description="Generates SVG synchronously from an input image.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: svgs = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/vectorize", - auth_kwargs=kwargs, ) svgs.append(SVG(sub_bytes)) pbar.update(1) - return (SVG.combine_all(svgs), ) + return IO.NodeOutput(SVG.combine_all(svgs)) -class RecraftReplaceBackgroundNode: - """ - Replace background on image, based on provided prompt. - """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftReplaceBackgroundNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftReplaceBackgroundNode", + display_name="Recraft Replace Background", + category="api node/image/Recraft", + description="Replace background on image, based on provided prompt.", + inputs=[ + IO.Image.Input("image"), + IO.String.Input("prompt", tooltip="Prompt for the image generation.", default="", multiline=True), + IO.Int.Input("n", default=1, min=1, max=6, tooltip="The number of images to generate."), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + IO.Custom(RecraftIO.STYLEV3).Input("recraft_style", optional=True), + IO.String.Input( + "negative_prompt", + default="", + force_input=True, + tooltip="An optional text description of undesired elements on an image.", + optional=True, + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - "prompt": ( - IO.STRING, - { - "multiline": True, - "default": "", - "tooltip": "Prompt for the image generation.", - }, - ), - "n": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 6, - "tooltip": "The number of images to generate.", - }, - ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "control_after_generate": True, - "tooltip": "Seed to determine if node should re-run; actual results are nondeterministic regardless of seed.", - }, - ), - }, - "optional": { - "recraft_style": (RecraftIO.STYLEV3,), - 
"negative_prompt": ( - IO.STRING, - { - "default": "", - "forceInput": True, - "tooltip": "An optional text description of undesired elements on an image.", - }, - ), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, - image: torch.Tensor, - prompt: str, - n: int, - seed, - recraft_style: RecraftStyle = None, - negative_prompt: str = None, - **kwargs, - ): + async def execute( + cls, + image: torch.Tensor, + prompt: str, + n: int, + seed, + recraft_style: RecraftStyle = None, + negative_prompt: str = None, + ) -> IO.NodeOutput: default_style = RecraftStyle(RecraftStyleV3.realistic_image) if recraft_style is None: recraft_style = default_style @@ -990,165 +852,151 @@ class RecraftReplaceBackgroundNode: pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/replaceBackground", request=request, - auth_kwargs=kwargs, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor, ) + return IO.NodeOutput(torch.cat(images, dim=0)) -class RecraftRemoveBackgroundNode: - """ - Remove background from image, and return processed image and mask. - """ - - RETURN_TYPES = (IO.IMAGE, IO.MASK) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" +class RecraftRemoveBackgroundNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftRemoveBackgroundNode", + display_name="Recraft Remove Background", + category="api node/image/Recraft", + description="Remove background from image, and return processed image and mask.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + IO.Mask.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } - - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: images = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], path="/proxy/recraft/images/removeBackground", - auth_kwargs=kwargs, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) images_tensor = torch.cat(images, dim=0) # use alpha channel as masks, in B,H,W format - masks_tensor = images_tensor[:,:,:,-1:].squeeze(-1) - return (images_tensor, masks_tensor) + masks_tensor = images_tensor[:, :, :, -1:].squeeze(-1) + return IO.NodeOutput(images_tensor, masks_tensor) -class RecraftCrispUpscaleNode: - """ - Upscale image synchronously. - Enhances a given raster image using ‘crisp upscale’ tool, increasing image resolution, making the image sharper and cleaner. 
- """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" - +class RecraftCrispUpscaleNode(IO.ComfyNode): RECRAFT_PATH = "/proxy/recraft/images/crispUpscale" @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": (IO.IMAGE, ), - }, - "optional": { - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="RecraftCrispUpscaleNode", + display_name="Recraft Crisp Upscale Image", + category="api node/image/Recraft", + description="Upscale image synchronously.\n" + "Enhances a given raster image using ‘crisp upscale’ tool, " + "increasing image resolution, making the image sharper and cleaner.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) - async def api_call( - self, - image: torch.Tensor, - **kwargs, - ): + @classmethod + async def execute(cls, image: torch.Tensor) -> IO.NodeOutput: images = [] total = image.shape[0] pbar = ProgressBar(total) for i in range(total): sub_bytes = await handle_recraft_file_request( + cls, image=image[i], - path=self.RECRAFT_PATH, - auth_kwargs=kwargs, + path=cls.RECRAFT_PATH, ) images.append(torch.cat([bytesio_to_image_tensor(x) for x in sub_bytes], dim=0)) pbar.update(1) - images_tensor = torch.cat(images, dim=0) - return (images_tensor,) + return IO.NodeOutput(torch.cat(images, dim=0)) class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode): - """ - Upscale image synchronously. - Enhances a given raster image using ‘creative upscale’ tool, boosting resolution with a focus on refining small details and faces. 
- """ - - RETURN_TYPES = (IO.IMAGE,) - DESCRIPTION = cleandoc(__doc__ or "") # Handle potential None value - FUNCTION = "api_call" - API_NODE = True - CATEGORY = "api node/image/Recraft" - RECRAFT_PATH = "/proxy/recraft/images/creativeUpscale" + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="RecraftCreativeUpscaleNode", + display_name="Recraft Creative Upscale Image", + category="api node/image/Recraft", + description="Upscale image synchronously.\n" + "Enhances a given raster image using ‘creative upscale’ tool, " + "boosting resolution with a focus on refining small details and faces.", + inputs=[ + IO.Image.Input("image"), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) -# A dictionary that contains all nodes you want to export with their names -# NOTE: names should be globally unique -NODE_CLASS_MAPPINGS = { - "RecraftTextToImageNode": RecraftTextToImageNode, - "RecraftImageToImageNode": RecraftImageToImageNode, - "RecraftImageInpaintingNode": RecraftImageInpaintingNode, - "RecraftTextToVectorNode": RecraftTextToVectorNode, - "RecraftVectorizeImageNode": RecraftVectorizeImageNode, - "RecraftRemoveBackgroundNode": RecraftRemoveBackgroundNode, - "RecraftReplaceBackgroundNode": RecraftReplaceBackgroundNode, - "RecraftCrispUpscaleNode": RecraftCrispUpscaleNode, - "RecraftCreativeUpscaleNode": RecraftCreativeUpscaleNode, - "RecraftStyleV3RealisticImage": RecraftStyleV3RealisticImageNode, - "RecraftStyleV3DigitalIllustration": RecraftStyleV3DigitalIllustrationNode, - "RecraftStyleV3LogoRaster": RecraftStyleV3LogoRasterNode, - "RecraftStyleV3InfiniteStyleLibrary": RecraftStyleInfiniteStyleLibrary, - "RecraftColorRGB": RecraftColorRGBNode, - "RecraftControls": RecraftControlsNode, -} -# A dictionary that contains the friendly/humanly readable titles for the nodes -NODE_DISPLAY_NAME_MAPPINGS = { - "RecraftTextToImageNode": "Recraft Text to Image", - "RecraftImageToImageNode": "Recraft Image to Image", - "RecraftImageInpaintingNode": "Recraft Image Inpainting", - "RecraftTextToVectorNode": "Recraft Text to Vector", - "RecraftVectorizeImageNode": "Recraft Vectorize Image", - "RecraftRemoveBackgroundNode": "Recraft Remove Background", - "RecraftReplaceBackgroundNode": "Recraft Replace Background", - "RecraftCrispUpscaleNode": "Recraft Crisp Upscale Image", - "RecraftCreativeUpscaleNode": "Recraft Creative Upscale Image", - "RecraftStyleV3RealisticImage": "Recraft Style - Realistic Image", - "RecraftStyleV3DigitalIllustration": "Recraft Style - Digital Illustration", - "RecraftStyleV3LogoRaster": "Recraft Style - Logo Raster", - "RecraftStyleV3InfiniteStyleLibrary": "Recraft Style - Infinite Style Library", - "RecraftColorRGB": "Recraft Color RGB", - "RecraftControls": "Recraft Controls", -} +class RecraftExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + RecraftTextToImageNode, + RecraftImageToImageNode, + RecraftImageInpaintingNode, + RecraftTextToVectorNode, + RecraftVectorizeImageNode, + RecraftRemoveBackgroundNode, + RecraftReplaceBackgroundNode, + RecraftCrispUpscaleNode, + RecraftCreativeUpscaleNode, + RecraftStyleV3RealisticImageNode, + RecraftStyleV3DigitalIllustrationNode, + RecraftStyleV3LogoRasterNode, + RecraftStyleInfiniteStyleLibrary, + RecraftColorRGBNode, + RecraftControlsNode, + ] + + +async def comfy_entrypoint() -> RecraftExtension: + return RecraftExtension() diff --git 
a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 04e25a899..72296f488 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -5,13 +5,10 @@ Rodin API docs: https://developer.hyper3d.ai/ """ -from __future__ import annotations from inspect import cleandoc from comfy.comfy_types.node_typing import IO from comfy.cmd import folder_paths as comfy_paths -import aiohttp import os -import asyncio import logging import math from typing import Optional @@ -27,11 +24,11 @@ from comfy_api_nodes.apis.rodin_api import ( Rodin3DDownloadResponse, JobStatus, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( + sync_op, + poll_op, ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, + download_url_to_bytesio, ) from comfy_api.latest import ComfyExtension, IO @@ -121,35 +118,31 @@ def tensor_to_filelike(tensor, max_pixels: int = 2048 * 2048): async def create_generate_task( - images=None, + cls: type[IO.ComfyNode], images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", - TAPose=False, - auth_kwargs: Optional[dict[str, str]] = None, + ta_pose: bool = False, + ): if images is None: raise Exception("Rodin 3D generate requires at least 1 image.") if len(images) > 5: raise Exception("Rodin 3D generate requires up to 5 image.") - path = "/proxy/rodin/api/v2/rodin" - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=path, - method=HttpMethod.POST, - request_model=Rodin3DGenerateRequest, - response_model=Rodin3DGenerateResponse, - ), - request=Rodin3DGenerateRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/rodin", method="POST"), + response_model=Rodin3DGenerateResponse, + data=Rodin3DGenerateRequest( seed=seed, tier=tier, material=material, quality_override=quality_override, mesh_mode=mesh_mode, - TAPose=TAPose, + TAPose=ta_pose, ), files=[ ( @@ -159,11 +152,8 @@ async def create_generate_task( for image in images if image is not None ], content_type="multipart/form-data", - auth_kwargs=auth_kwargs, ) - response = await operation.execute() - if hasattr(response, "error"): error_message = f"Rodin3D Create 3D generate Task Failed. 
Message: {response.message}, error: {response.error}" logging.error(error_message) @@ -188,74 +178,47 @@ def check_rodin_status(response: Rodin3DCheckStatusResponse) -> str: return "Generating" +def extract_progress(response: Rodin3DCheckStatusResponse) -> Optional[int]: + if not response.jobs: + return None + completed_count = sum(1 for job in response.jobs if job.status == JobStatus.Done) + return int((completed_count / len(response.jobs)) * 100) + + async def poll_for_task_status( - subscription_key, auth_kwargs: Optional[dict[str, str]] = None, -) -> Rodin3DCheckStatusResponse: - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path="/proxy/rodin/api/v2/status", - method=HttpMethod.POST, - request_model=Rodin3DCheckStatusRequest, - response_model=Rodin3DCheckStatusResponse, - ), - request=Rodin3DCheckStatusRequest(subscription_key=subscription_key), - completed_statuses=["DONE"], - failed_statuses=["FAILED"], - status_extractor=check_rodin_status, - poll_interval=3.0, - auth_kwargs=auth_kwargs, - ) + subscription_key: str, cls: type[IO.ComfyNode]) -> Rodin3DCheckStatusResponse: logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - return await poll_operation.execute() - - -async def get_rodin_download_list(uuid, auth_kwargs: Optional[dict[str, str]] = None) -> Rodin3DDownloadResponse: - logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/rodin/api/v2/download", - method=HttpMethod.POST, - request_model=Rodin3DDownloadRequest, - response_model=Rodin3DDownloadResponse, - ), - request=Rodin3DDownloadRequest(task_uuid=uuid), - auth_kwargs=auth_kwargs, + return await poll_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/status", method="POST"), + response_model=Rodin3DCheckStatusResponse, + data=Rodin3DCheckStatusRequest(subscription_key=subscription_key), + status_extractor=check_rodin_status, + progress_extractor=extract_progress, ) - return await operation.execute() -async def download_files(url_list, task_uuid): - save_path = os.path.join(comfy_paths.get_output_directory(), f"Rodin3D_{task_uuid}") +async def get_rodin_download_list(uuid: str, cls: type[IO.ComfyNode]) -> Rodin3DDownloadResponse: + logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") + return await sync_op( + cls, + ApiEndpoint(path="/proxy/rodin/api/v2/download", method="POST"), + response_model=Rodin3DDownloadResponse, + data=Rodin3DDownloadRequest(task_uuid=uuid), + monitor_progress=False, + ) + + +async def download_files(url_list, task_uuid: str): + result_folder_name = f"Rodin3D_{task_uuid}" + save_path = os.path.join(comfy_paths.get_output_directory(), result_folder_name) os.makedirs(save_path, exist_ok=True) model_file_path = None - async with aiohttp.ClientSession() as session: - for i in url_list.list: - url = i.url - file_name = i.name - file_path = os.path.join(save_path, file_name) - if file_path.endswith(".glb"): - model_file_path = file_path - logging.info("[ Rodin3D API - download_files ] Downloading file: %s", file_path) - max_retries = 5 - for attempt in range(max_retries): - try: - async with session.get(url) as resp: - resp.raise_for_status() - with open(file_path, "wb") as f: - async for chunk in resp.content.iter_chunked(32 * 1024): - f.write(chunk) - break - except Exception as e: - logging.info("[ Rodin3D API - download_files ] Error downloading %s:%s", file_path, str(e)) - if attempt < max_retries - 1: - logging.info("Retrying...") - await asyncio.sleep(2) - 
else: - logging.info( - "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", - file_path, - max_retries, - ) + for i in url_list.list: + file_path = os.path.join(save_path, i.name) + if file_path.endswith(".glb"): + model_file_path = os.path.join(result_folder_name, i.name) + await download_url_to_bytesio(i.url, file_path) return model_file_path @@ -277,6 +240,7 @@ class Rodin3D_Regular(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -295,21 +259,17 @@ class Rodin3D_Regular(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -333,6 +293,7 @@ class Rodin3D_Detail(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -351,21 +312,17 @@ class Rodin3D_Detail(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -389,6 +346,7 @@ class Rodin3D_Smooth(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -401,27 +359,22 @@ class Rodin3D_Smooth(IO.ComfyNode): Material_Type, Polygon_count, ) -> IO.NodeOutput: - tier = "Smooth" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, - tier=tier, + tier="Smooth", mesh_mode=mesh_mode, - auth_kwargs=auth, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -452,6 +405,7 @@ class 
Rodin3D_Sketch(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -462,29 +416,21 @@ class Rodin3D_Sketch(IO.ComfyNode): Images, Seed, ) -> IO.NodeOutput: - tier = "Sketch" num_images = Images.shape[0] m_images = [] for i in range(num_images): m_images.append(Images[i]) - material_type = "PBR" - quality_override = 18000 - mesh_mode = "Quad" - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, - material=material_type, - quality_override=quality_override, - tier=tier, - mesh_mode=mesh_mode, - auth_kwargs=auth, + material="PBR", + quality_override=18000, + tier="Sketch", + mesh_mode="Quad", ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) @@ -523,6 +469,7 @@ class Rodin3D_Gen2(IO.ComfyNode): hidden=[ IO.Hidden.auth_token_comfy_org, IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, ], is_api_node=True, ) @@ -542,22 +489,18 @@ class Rodin3D_Gen2(IO.ComfyNode): for i in range(num_images): m_images.append(Images[i]) mesh_mode, quality_override = get_quality_mode(Polygon_count) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } task_uuid, subscription_key = await create_generate_task( + cls, images=m_images, seed=Seed, material=Material_Type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, - TAPose=TAPose, - auth_kwargs=auth, + ta_pose=TAPose, ) - await poll_for_task_status(subscription_key, auth_kwargs=auth) - download_list = await get_rodin_download_list(task_uuid, auth_kwargs=auth) + await poll_for_task_status(subscription_key, cls) + download_list = await get_rodin_download_list(task_uuid, cls) model = await download_files(download_list, task_uuid) return IO.NodeOutput(model) diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index eb03a897d..3c55039c9 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -11,17 +11,15 @@ User Guides: """ -from typing import Union, Optional, Any -from typing_extensions import override from enum import Enum -import torch +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, Input, InputImpl from comfy_api_nodes.apis import ( RunwayImageToVideoRequest, RunwayImageToVideoResponse, RunwayTaskStatusResponse as TaskStatusResponse, - RunwayTaskStatusEnum as TaskStatus, RunwayModelEnum as Model, RunwayDurationEnum as Duration, RunwayAspectRatioEnum as AspectRatio, @@ -33,23 +31,18 @@ from comfy_api_nodes.apis import ( ReferenceImage, RunwayTextToImageAspectRatioEnum, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( - upload_images_to_comfyapi, - download_url_to_video_output, +from comfy_api_nodes.util import ( image_tensor_pair_to_batch, validate_string, + validate_image_dimensions, + validate_image_aspect_ratio, + upload_images_to_comfyapi, + download_url_to_video_output, download_url_to_image_tensor, + ApiEndpoint, + sync_op, + poll_op, ) 
-from comfy_api.input_impl import VideoFromFile -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import validate_image_dimensions, validate_image_aspect_ratio PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video" PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image" @@ -84,47 +77,22 @@ class RunwayGen3aAspectRatio(str, Enum): field_1280_768 = "1280:768" -def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: +def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the video URL from the task status response if it exists.""" if hasattr(response, "output") and len(response.output) > 0: return response.output[0] return None -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, TaskStatusResponse], - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> TaskStatusResponse: - """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response.""" - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[ - TaskStatus.SUCCEEDED.value, - ], - failed_statuses=[ - TaskStatus.FAILED.value, - TaskStatus.CANCELLED.value, - ], - status_extractor=lambda response: response.status.value, - auth_kwargs=auth_kwargs, - result_url_extractor=get_video_url_from_task_status, - estimated_duration=estimated_duration, - node_id=node_id, - progress_extractor=extract_progress_from_task_status, - ).execute() - - def extract_progress_from_task_status( response: TaskStatusResponse, -) -> Union[float, None]: +) -> float | None: if hasattr(response, "progress") and response.progress is not None: return response.progress * 100 return None -def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]: +def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None: """Returns the image URL from the task status response if it exists.""" if hasattr(response, "output") and len(response.output) > 0: return response.output[0] @@ -132,42 +100,32 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, N async def get_response( - task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None, estimated_duration: Optional[int] = None + cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None ) -> TaskStatusResponse: """Poll the task status until it is finished then get the response.""" - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=f"{PATH_GET_TASK_STATUS}/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), + return await poll_op( + cls, + ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.status.value, estimated_duration=estimated_duration, - node_id=node_id, + progress_extractor=extract_progress_from_task_status, ) async def generate_video( + cls: type[IO.ComfyNode], request: RunwayImageToVideoRequest, - auth_kwargs: dict[str, str], - node_id: Optional[str] = None, - estimated_duration: Optional[int] = None, -) -> VideoFromFile: - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_IMAGE_TO_VIDEO, - method=HttpMethod.POST, - request_model=RunwayImageToVideoRequest, - response_model=RunwayImageToVideoResponse, - ), - request=request, - auth_kwargs=auth_kwargs, + estimated_duration: int | None = None, +) -> InputImpl.VideoFromFile: 
+ initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"), + response_model=RunwayImageToVideoResponse, + data=request, ) - initial_response = await initial_operation.execute() - - final_response = await get_response(initial_response.id, auth_kwargs, node_id, estimated_duration) + final_response = await get_response(cls, initial_response.id, estimated_duration) if not final_response.output: raise RunwayApiError("Runway task succeeded but no video data found in response.") @@ -184,9 +142,9 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): display_name="Runway Image to Video (Gen3a Turbo)", category="api node/video/Runway", description="Generate a video from a single starting frame using Gen3a Turbo model. " - "Before diving in, review these best practices to ensure that " - "your input selections will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/33927968552339-Creating-with-Act-One-on-Gen-3-Alpha-and-Turbo.", inputs=[ IO.String.Input( "prompt", @@ -232,29 +190,25 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, + start_frame: Input.Image, duration: str, ratio: str, seed: int, ) -> IO.NodeOutput: validate_string(prompt, min_length=1) validate_image_dimensions(start_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) download_urls = await upload_images_to_comfyapi( + cls, start_frame, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -262,15 +216,9 @@ class RunwayImageToVideoNodeGen3a(IO.ComfyNode): duration=Duration(duration), ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] + root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, ) ) @@ -284,9 +232,9 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): display_name="Runway Image to Video (Gen4 Turbo)", category="api node/video/Runway", description="Generate a video from a single starting frame using Gen4 Turbo model. 
" - "Before diving in, review these best practices to ensure that " - "your input selections will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", + "Before diving in, review these best practices to ensure that " + "your input selections will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/37327109429011-Creating-with-Gen-4-Video.", inputs=[ IO.String.Input( "prompt", @@ -332,29 +280,25 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, + start_frame: Input.Image, duration: str, ratio: str, seed: int, ) -> IO.NodeOutput: validate_string(prompt, min_length=1) validate_image_dimensions(start_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) download_urls = await upload_images_to_comfyapi( + cls, start_frame, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -362,15 +306,9 @@ class RunwayImageToVideoNodeGen4(IO.ComfyNode): duration=Duration(duration), ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( - root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ) - ] + root=[RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first")] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_FLF_SECONDS, ) ) @@ -385,12 +323,12 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): display_name="Runway First-Last-Frame to Video", category="api node/video/Runway", description="Upload first and last keyframes, draft a prompt, and generate a video. " - "More complex transitions, such as cases where the Last frame is completely different " - "from the First frame, may benefit from the longer 10s duration. " - "This would give the generation more time to smoothly transition between the two inputs. " - "Before diving in, review these best practices to ensure that your input selections " - "will set your generation up for success: " - "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", + "More complex transitions, such as cases where the Last frame is completely different " + "from the First frame, may benefit from the longer 10s duration. " + "This would give the generation more time to smoothly transition between the two inputs. 
" + "Before diving in, review these best practices to ensure that your input selections " + "will set your generation up for success: " + "https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3.", inputs=[ IO.String.Input( "prompt", @@ -440,8 +378,8 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): async def execute( cls, prompt: str, - start_frame: torch.Tensor, - end_frame: torch.Tensor, + start_frame: Input.Image, + end_frame: Input.Image, duration: str, ratio: str, seed: int, @@ -449,26 +387,22 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): validate_string(prompt, min_length=1) validate_image_dimensions(start_frame, max_width=7999, max_height=7999) validate_image_dimensions(end_frame, max_width=7999, max_height=7999) - validate_image_aspect_ratio(start_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - validate_image_aspect_ratio(end_frame, min_aspect_ratio=0.5, max_aspect_ratio=2.0) - - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } + validate_image_aspect_ratio(start_frame, (1, 2), (2, 1)) + validate_image_aspect_ratio(end_frame, (1, 2), (2, 1)) stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) download_urls = await upload_images_to_comfyapi( + cls, stacked_input_images, max_images=2, mime_type="image/png", - auth_kwargs=auth_kwargs, ) if len(download_urls) != 2: raise RunwayApiError("Failed to upload one or more images to comfy api.") return IO.NodeOutput( await generate_video( + cls, RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -477,17 +411,11 @@ class RunwayFirstLastFrameNode(IO.ComfyNode): ratio=AspectRatio(ratio), promptImage=RunwayPromptImageObject( root=[ - RunwayPromptImageDetailedObject( - uri=str(download_urls[0]), position="first" - ), - RunwayPromptImageDetailedObject( - uri=str(download_urls[1]), position="last" - ), + RunwayPromptImageDetailedObject(uri=str(download_urls[0]), position="first"), + RunwayPromptImageDetailedObject(uri=str(download_urls[1]), position="last"), ] ), ), - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_FLF_SECONDS, ) ) @@ -502,7 +430,7 @@ class RunwayTextToImageNode(IO.ComfyNode): display_name="Runway Text to Image", category="api node/image/Runway", description="Generate an image from a text prompt using Runway's Gen 4 model. 
" - "You can also include reference image to guide the generation.", + "You can also include reference image to guide the generation.", inputs=[ IO.String.Input( "prompt", @@ -536,53 +464,38 @@ class RunwayTextToImageNode(IO.ComfyNode): cls, prompt: str, ratio: str, - reference_image: Optional[torch.Tensor] = None, + reference_image: Input.Image | None = None, ) -> IO.NodeOutput: validate_string(prompt, min_length=1) - auth_kwargs = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - # Prepare reference images if provided reference_images = None if reference_image is not None: validate_image_dimensions(reference_image, max_width=7999, max_height=7999) - validate_image_aspect_ratio(reference_image, min_aspect_ratio=0.5, max_aspect_ratio=2.0) + validate_image_aspect_ratio(reference_image, (1, 2), (2, 1)) download_urls = await upload_images_to_comfyapi( + cls, reference_image, max_images=1, mime_type="image/png", - auth_kwargs=auth_kwargs, ) reference_images = [ReferenceImage(uri=str(download_urls[0]))] - request = RunwayTextToImageRequest( - promptText=prompt, - model=Model4.gen4_image, - ratio=ratio, - referenceImages=reference_images, - ) - - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=PATH_TEXT_TO_IMAGE, - method=HttpMethod.POST, - request_model=RunwayTextToImageRequest, - response_model=RunwayTextToImageResponse, + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path=PATH_TEXT_TO_IMAGE, method="POST"), + response_model=RunwayTextToImageResponse, + data=RunwayTextToImageRequest( + promptText=prompt, + model=Model4.gen4_image, + ratio=ratio, + referenceImages=reference_images, ), - request=request, - auth_kwargs=auth_kwargs, ) - initial_response = await initial_operation.execute() - - # Poll for completion final_response = await get_response( + cls, initial_response.id, - auth_kwargs=auth_kwargs, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_T2I_SECONDS, ) if not final_response.output: @@ -601,5 +514,6 @@ class RunwayExtension(ComfyExtension): RunwayTextToImageNode, ] + async def comfy_entrypoint() -> RunwayExtension: return RunwayExtension() diff --git a/comfy_api_nodes/nodes_sora.py b/comfy_api_nodes/nodes_sora.py index efc954869..92b225d40 100644 --- a/comfy_api_nodes/nodes_sora.py +++ b/comfy_api_nodes/nodes_sora.py @@ -1,23 +1,20 @@ from typing import Optional -from typing_extensions import override import torch from pydantic import BaseModel, Field -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.util.validation_utils import get_number_of_images +from typing_extensions import override -from comfy_api_nodes.apinode_utils import ( +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( + ApiEndpoint, download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, tensor_to_bytesio, ) + class Sora2GenerationRequest(BaseModel): prompt: str = Field(...) model: str = Field(...) 
@@ -80,7 +77,7 @@ class OpenAIVideoSora2(IO.ComfyNode): control_after_generate=True, optional=True, tooltip="Seed to determine if node should re-run; " - "actual results are nondeterministic regardless of seed.", + "actual results are nondeterministic regardless of seed.", ), ], outputs=[ @@ -111,55 +108,34 @@ class OpenAIVideoSora2(IO.ComfyNode): if get_number_of_images(image) != 1: raise ValueError("Currently only one input image is supported.") files_input = {"input_reference": ("image.png", tensor_to_bytesio(image), "image/png")} - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - payload = Sora2GenerationRequest( - model=model, - prompt=prompt, - seconds=str(duration), - size=size, - ) - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/openai/v1/videos", - method=HttpMethod.POST, - request_model=Sora2GenerationRequest, - response_model=Sora2GenerationResponse + initial_response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/openai/v1/videos", method="POST"), + data=Sora2GenerationRequest( + model=model, + prompt=prompt, + seconds=str(duration), + size=size, ), - request=payload, files=files_input, - auth_kwargs=auth, + response_model=Sora2GenerationResponse, content_type="multipart/form-data", ) - initial_response = await initial_operation.execute() if initial_response.error: - raise Exception(initial_response.error.message) + raise Exception(initial_response.error["message"]) model_time_multiplier = 1 if model == "sora-2" else 2 - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/openai/v1/videos/{initial_response.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=Sora2GenerationResponse - ), - completed_statuses=["completed"], - failed_statuses=["failed"], + await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/openai/v1/videos/{initial_response.id}"), + response_model=Sora2GenerationResponse, status_extractor=lambda x: x.status, - auth_kwargs=auth, poll_interval=8.0, max_poll_attempts=160, - node_id=cls.hidden.unique_id, - estimated_duration=45 * (duration / 4) * model_time_multiplier, + estimated_duration=int(45 * (duration / 4) * model_time_multiplier), ) - await poll_operation.execute() return IO.NodeOutput( - await download_url_to_video_output( - f"/proxy/openai/v1/videos/{initial_response.id}/content", - auth_kwargs=auth, - ) + await download_url_to_video_output(f"/proxy/openai/v1/videos/{initial_response.id}/content", cls=cls), ) diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 8af03cfd1..bb7ceed78 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -20,21 +20,17 @@ from comfy_api_nodes.apis.stability_api import ( StabilityAudioInpaintRequest, StabilityAudioResponse, ) -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( +from comfy_api_nodes.util import ( + validate_audio_duration, + validate_string, + audio_input_to_mp3, bytesio_to_image_tensor, tensor_to_bytesio, - validate_string, audio_bytes_to_audio_input, - audio_input_to_mp3, + sync_op, + poll_op, + ApiEndpoint, ) -from comfy_api_nodes.util.validation_utils import validate_audio_duration import torch import base64 @@ -161,19 +157,11 @@ class StabilityStableImageUltraNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": 
cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/generate/ultra", - method=HttpMethod.POST, - request_model=StabilityStableUltraRequest, - response_model=StabilityStableUltraResponse, - ), - request=StabilityStableUltraRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityStableUltraRequest( prompt=prompt, negative_prompt=negative_prompt, aspect_ratio=aspect_ratio, @@ -183,9 +171,7 @@ class StabilityStableImageUltraNode(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.") @@ -313,19 +299,11 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/generate/sd3", - method=HttpMethod.POST, - request_model=StabilityStable3_5Request, - response_model=StabilityStableUltraResponse, - ), - request=StabilityStable3_5Request( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityStable3_5Request( prompt=prompt, negative_prompt=negative_prompt, aspect_ratio=aspect_ratio, @@ -338,9 +316,7 @@ class StabilityStableImageSD_3_5Node(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.") @@ -427,19 +403,11 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/upscale/conservative", - method=HttpMethod.POST, - request_model=StabilityUpscaleConservativeRequest, - response_model=StabilityStableUltraResponse, - ), - request=StabilityUpscaleConservativeRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"), + response_model=StabilityStableUltraResponse, + data=StabilityUpscaleConservativeRequest( prompt=prompt, negative_prompt=negative_prompt, creativity=round(creativity,2), @@ -447,9 +415,7 @@ class StabilityUpscaleConservativeNode(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.") @@ -544,19 +510,11 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - 
path="/proxy/stability/v2beta/stable-image/upscale/creative", - method=HttpMethod.POST, - request_model=StabilityUpscaleCreativeRequest, - response_model=StabilityAsyncResponse, - ), - request=StabilityUpscaleCreativeRequest( + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"), + response_model=StabilityAsyncResponse, + data=StabilityUpscaleCreativeRequest( prompt=prompt, negative_prompt=negative_prompt, creativity=round(creativity,2), @@ -565,25 +523,15 @@ class StabilityUpscaleCreativeNode(IO.ComfyNode): ), files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() - operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/stability/v2beta/results/{response_api.id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=StabilityResultsGetResponse, - ), + response_poll = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"), + response_model=StabilityResultsGetResponse, poll_interval=3, - completed_statuses=[StabilityPollStatus.finished], - failed_statuses=[StabilityPollStatus.failed], status_extractor=lambda x: get_async_dummy_status(x), - auth_kwargs=auth, - node_id=cls.hidden.unique_id, ) - response_poll: StabilityResultsGetResponse = await operation.execute() if response_poll.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.") @@ -628,24 +576,13 @@ class StabilityUpscaleFastNode(IO.ComfyNode): "image": image_binary } - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/stable-image/upscale/fast", - method=HttpMethod.POST, - request_model=EmptyRequest, - response_model=StabilityStableUltraResponse, - ), - request=EmptyRequest(), + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"), + response_model=StabilityStableUltraResponse, files=files, content_type="multipart/form-data", - auth_kwargs=auth, ) - response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.") @@ -717,21 +654,13 @@ class StabilityTextToAudio(IO.ComfyNode): async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput: validate_string(prompt, max_length=10000) payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", - method=HttpMethod.POST, - request_model=StabilityTextToAudioRequest, - response_model=StabilityAudioResponse, - ), - request=payload, + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"), + response_model=StabilityAudioResponse, + data=payload, content_type="multipart/form-data", - auth_kwargs= { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - response_api = await operation.execute() if not response_api.audio: raise ValueError("No audio file was received in response.") return 
IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) @@ -814,22 +743,14 @@ class StabilityAudioToAudio(IO.ComfyNode): payload = StabilityAudioToAudioRequest( prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength ) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", - method=HttpMethod.POST, - request_model=StabilityAudioToAudioRequest, - response_model=StabilityAudioResponse, - ), - request=payload, + response_api = await sync_op( + cls, + ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"), + response_model=StabilityAudioResponse, + data=payload, content_type="multipart/form-data", files={"audio": audio_input_to_mp3(audio)}, - auth_kwargs= { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - response_api = await operation.execute() if not response_api.audio: raise ValueError("No audio file was received in response.") return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) @@ -935,22 +856,14 @@ class StabilityAudioInpaint(IO.ComfyNode): mask_start=mask_start, mask_end=mask_end, ) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", - method=HttpMethod.POST, - request_model=StabilityAudioInpaintRequest, - response_model=StabilityAudioResponse, - ), - request=payload, + response_api = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"), + response_model=StabilityAudioResponse, + data=payload, content_type="multipart/form-data", files={"audio": audio_input_to_mp3(audio)}, - auth_kwargs={ - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, ) - response_api = await operation.execute() if not response_api.audio: raise ValueError("No audio file was received in response.") return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio))) diff --git a/comfy_api_nodes/nodes_topaz.py b/comfy_api_nodes/nodes_topaz.py new file mode 100644 index 000000000..f522756e5 --- /dev/null +++ b/comfy_api_nodes/nodes_topaz.py @@ -0,0 +1,418 @@ +import builtins +from io import BytesIO + +import aiohttp +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.apis import topaz_api +from comfy_api_nodes.util import ( + ApiEndpoint, + download_url_to_image_tensor, + download_url_to_video_output, + get_fs_object_size, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_container_format_is_mp4, +) + +UPSCALER_MODELS_MAP = { + "Starlight (Astra) Fast": "slf-1", + "Starlight (Astra) Creative": "slc-1", +} +UPSCALER_VALUES_MAP = { + "FullHD (1080p)": 1920, + "4K (2160p)": 3840, +} + + +class TopazImageEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TopazImageEnhance", + display_name="Topaz Image Enhance", + category="api node/image/Topaz", + description="Industry-standard upscaling and image enhancement.", + inputs=[ + IO.Combo.Input("model", options=["Reimagine"]), + IO.Image.Input("image"), + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Optional text prompt for creative upscaling guidance.", + optional=True, + ), + IO.Combo.Input( + "subject_detection", + 
options=["All", "Foreground", "Background"], + optional=True, + ), + IO.Boolean.Input( + "face_enhancement", + default=True, + optional=True, + tooltip="Enhance faces (if present) during processing.", + ), + IO.Float.Input( + "face_enhancement_creativity", + default=0.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Set the creativity level for face enhancement.", + ), + IO.Float.Input( + "face_enhancement_strength", + default=1.0, + min=0.0, + max=1.0, + step=0.01, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Controls how sharp enhanced faces are relative to the background.", + ), + IO.Boolean.Input( + "crop_to_fill", + default=False, + optional=True, + tooltip="By default, the image is letterboxed when the output aspect ratio differs. " + "Enable to crop the image to fill the output dimensions.", + ), + IO.Int.Input( + "output_width", + default=0, + min=0, + max=32000, + step=1, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Zero value means to calculate automatically (usually it will be original size or output_height if specified).", + ), + IO.Int.Input( + "output_height", + default=0, + min=0, + max=32000, + step=1, + display_mode=IO.NumberDisplay.number, + optional=True, + tooltip="Zero value means to output in the same height as original or output width.", + ), + IO.Int.Input( + "creativity", + default=3, + min=1, + max=9, + step=1, + display_mode=IO.NumberDisplay.slider, + optional=True, + ), + IO.Boolean.Input( + "face_preservation", + default=True, + optional=True, + tooltip="Preserve subjects' facial identity.", + ), + IO.Boolean.Input( + "color_preservation", + default=True, + optional=True, + tooltip="Preserve the original colors.", + ), + ], + outputs=[ + IO.Image.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + model: str, + image: torch.Tensor, + prompt: str = "", + subject_detection: str = "All", + face_enhancement: bool = True, + face_enhancement_creativity: float = 1.0, + face_enhancement_strength: float = 0.8, + crop_to_fill: bool = False, + output_width: int = 0, + output_height: int = 0, + creativity: int = 3, + face_preservation: bool = True, + color_preservation: bool = True, + ) -> IO.NodeOutput: + if get_number_of_images(image) != 1: + raise ValueError("Only one input image is supported.") + download_url = await upload_images_to_comfyapi(cls, image, max_images=1, mime_type="image/png") + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/topaz/image/v1/enhance-gen/async", method="POST"), + response_model=topaz_api.ImageAsyncTaskResponse, + data=topaz_api.ImageEnhanceRequest( + model=model, + prompt=prompt, + subject_detection=subject_detection, + face_enhancement=face_enhancement, + face_enhancement_creativity=face_enhancement_creativity, + face_enhancement_strength=face_enhancement_strength, + crop_to_fill=crop_to_fill, + output_width=output_width if output_width else None, + output_height=output_height if output_height else None, + creativity=creativity, + face_preservation=str(face_preservation).lower(), + color_preservation=str(color_preservation).lower(), + source_url=download_url[0], + output_format="png", + ), + content_type="multipart/form-data", + ) + + await poll_op( + cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/topaz/image/v1/status/{initial_response.process_id}"), + 
response_model=topaz_api.ImageStatusResponse, + status_extractor=lambda x: x.status, + progress_extractor=lambda x: getattr(x, "progress", 0), + price_extractor=lambda x: x.credits * 0.08, + poll_interval=8.0, + max_poll_attempts=160, + estimated_duration=60, + ) + + results = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/topaz/image/v1/download/{initial_response.process_id}"), + response_model=topaz_api.ImageDownloadResponse, + monitor_progress=False, + ) + return IO.NodeOutput(await download_url_to_image_tensor(results.download_url)) + + +class TopazVideoEnhance(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TopazVideoEnhance", + display_name="Topaz Video Enhance", + category="api node/video/Topaz", + description="Breathe new life into video with powerful upscaling and recovery technology.", + inputs=[ + IO.Video.Input("video"), + IO.Boolean.Input("upscaler_enabled", default=True), + IO.Combo.Input("upscaler_model", options=list(UPSCALER_MODELS_MAP.keys())), + IO.Combo.Input("upscaler_resolution", options=list(UPSCALER_VALUES_MAP.keys())), + IO.Combo.Input( + "upscaler_creativity", + options=["low", "middle", "high"], + default="low", + tooltip="Creativity level (applies only to Starlight (Astra) Creative).", + optional=True, + ), + IO.Boolean.Input("interpolation_enabled", default=False, optional=True), + IO.Combo.Input("interpolation_model", options=["apo-8"], default="apo-8", optional=True), + IO.Int.Input( + "interpolation_slowmo", + default=1, + min=1, + max=16, + display_mode=IO.NumberDisplay.number, + tooltip="Slow-motion factor applied to the input video. " + "For example, 2 makes the output twice as slow and doubles the duration.", + optional=True, + ), + IO.Int.Input( + "interpolation_frame_rate", + default=60, + min=15, + max=240, + display_mode=IO.NumberDisplay.number, + tooltip="Output frame rate.", + optional=True, + ), + IO.Boolean.Input( + "interpolation_duplicate", + default=False, + tooltip="Analyze the input for duplicate frames and remove them.", + optional=True, + ), + IO.Float.Input( + "interpolation_duplicate_threshold", + default=0.01, + min=0.001, + max=0.1, + step=0.001, + display_mode=IO.NumberDisplay.number, + tooltip="Detection sensitivity for duplicate frames.", + optional=True, + ), + IO.Combo.Input( + "dynamic_compression_level", + options=["Low", "Mid", "High"], + default="Low", + tooltip="CQP level.", + optional=True, + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + video: Input.Video, + upscaler_enabled: bool, + upscaler_model: str, + upscaler_resolution: str, + upscaler_creativity: str = "low", + interpolation_enabled: bool = False, + interpolation_model: str = "apo-8", + interpolation_slowmo: int = 1, + interpolation_frame_rate: int = 60, + interpolation_duplicate: bool = False, + interpolation_duplicate_threshold: float = 0.01, + dynamic_compression_level: str = "Low", + ) -> IO.NodeOutput: + if upscaler_enabled is False and interpolation_enabled is False: + raise ValueError("There is nothing to do: both upscaling and interpolation are disabled.") + validate_container_format_is_mp4(video) + src_width, src_height = video.get_dimensions() + src_frame_rate = int(video.get_frame_rate()) + duration_sec = video.get_duration() + src_video_stream = video.get_stream_source() + target_width = src_width + target_height = src_height + 
target_frame_rate = src_frame_rate + filters = [] + if upscaler_enabled: + target_width = UPSCALER_VALUES_MAP[upscaler_resolution] + target_height = UPSCALER_VALUES_MAP[upscaler_resolution] + filters.append( + topaz_api.VideoEnhancementFilter( + model=UPSCALER_MODELS_MAP[upscaler_model], + creativity=(upscaler_creativity if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), + isOptimizedMode=(True if UPSCALER_MODELS_MAP[upscaler_model] == "slc-1" else None), + ), + ) + if interpolation_enabled: + target_frame_rate = interpolation_frame_rate + filters.append( + topaz_api.VideoFrameInterpolationFilter( + model=interpolation_model, + slowmo=interpolation_slowmo, + fps=interpolation_frame_rate, + duplicate=interpolation_duplicate, + duplicate_threshold=interpolation_duplicate_threshold, + ), + ) + initial_res = await sync_op( + cls, + ApiEndpoint(path="/proxy/topaz/video/", method="POST"), + response_model=topaz_api.CreateVideoResponse, + data=topaz_api.CreateVideoRequest( + source=topaz_api.CreateCreateVideoRequestSource( + container="mp4", + size=get_fs_object_size(src_video_stream), + duration=int(duration_sec), + frameCount=video.get_frame_count(), + frameRate=src_frame_rate, + resolution=topaz_api.Resolution(width=src_width, height=src_height), + ), + filters=filters, + output=topaz_api.OutputInformationVideo( + resolution=topaz_api.Resolution(width=target_width, height=target_height), + frameRate=target_frame_rate, + audioCodec="AAC", + audioTransfer="Copy", + dynamicCompressionLevel=dynamic_compression_level, + ), + ), + wait_label="Creating task", + final_label_on_success="Task created", + ) + upload_res = await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/accept", + method="PATCH", + ), + response_model=topaz_api.VideoAcceptResponse, + wait_label="Preparing upload", + final_label_on_success="Upload started", + ) + if len(upload_res.urls) > 1: + raise NotImplementedError( + "Large files are not currently supported. Please open an issue in the ComfyUI repository." 
+ ) + async with aiohttp.ClientSession(headers={"Content-Type": "video/mp4"}) as session: + if isinstance(src_video_stream, BytesIO): + src_video_stream.seek(0) + async with session.put(upload_res.urls[0], data=src_video_stream, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + else: + with builtins.open(src_video_stream, "rb") as video_file: + async with session.put(upload_res.urls[0], data=video_file, raise_for_status=True) as res: + upload_etag = res.headers["Etag"] + await sync_op( + cls, + ApiEndpoint( + path=f"/proxy/topaz/video/{initial_res.requestId}/complete-upload", + method="PATCH", + ), + response_model=topaz_api.VideoCompleteUploadResponse, + data=topaz_api.VideoCompleteUploadRequest( + uploadResults=[ + topaz_api.VideoCompleteUploadRequestPart( + partNum=1, + eTag=upload_etag, + ), + ], + ), + wait_label="Finalizing upload", + final_label_on_success="Upload completed", + ) + final_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/topaz/video/{initial_res.requestId}/status"), + response_model=topaz_api.VideoStatusResponse, + status_extractor=lambda x: x.status, + progress_extractor=lambda x: getattr(x, "progress", 0), + price_extractor=lambda x: (x.estimates.cost[0] * 0.08 if x.estimates and x.estimates.cost[0] else None), + poll_interval=10.0, + max_poll_attempts=320, + ) + return IO.NodeOutput(await download_url_to_video_output(final_response.download.url)) + + +class TopazExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TopazImageEnhance, + TopazVideoEnhance, + ] + + +async def comfy_entrypoint() -> TopazExtension: + return TopazExtension() diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index a5162946d..41aeebd2e 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -1,46 +1,39 @@ import os -from comfy.cmd.folder_paths import get_output_directory -from comfy_api_nodes.mapper_utils import model_field_to_node_input -from comfy.comfy_types.node_typing import IO -from comfy_api_nodes.apis import ( - TripoOrientation, - TripoModelVersion, -) +from typing import Optional + +import torch +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension from comfy_api_nodes.apis.tripo_api import ( - TripoTaskType, - TripoStyle, - TripoFileReference, + TripoAnimateRetargetRequest, + TripoAnimateRigRequest, + TripoConvertModelRequest, TripoFileEmptyReference, - TripoUrlReference, + TripoFileReference, + TripoImageToModelRequest, + TripoModelVersion, + TripoMultiviewToModelRequest, + TripoOrientation, + TripoRefineModelRequest, + TripoStyle, TripoTaskResponse, TripoTaskStatus, + TripoTaskType, TripoTextToModelRequest, - TripoImageToModelRequest, - TripoMultiviewToModelRequest, TripoTextureModelRequest, - TripoRefineModelRequest, - TripoAnimateRigRequest, - TripoAnimateRetargetRequest, - TripoConvertModelRequest, + TripoUrlReference, ) - -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, -) -from comfy_api_nodes.apinode_utils import ( + download_url_as_bytesio, + poll_op, + sync_op, upload_images_to_comfyapi, - download_url_to_bytesio, ) +from comfy.cmd.folder_paths import get_output_directory -async def upload_image_to_tripo(image, **kwargs): - urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) - return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg")) - 
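The removed upload_image_to_tripo helper is inlined at its call sites later in this patch; for reference, an equivalent standalone sketch of the new pattern (the function name image_to_tripo_file is invented here, everything else mirrors the TripoImageToModelNode hunk below):

import torch

from comfy_api.latest import IO
from comfy_api_nodes.apis.tripo_api import TripoFileReference, TripoUrlReference
from comfy_api_nodes.util import upload_images_to_comfyapi


async def image_to_tripo_file(cls: type[IO.ComfyNode], image: torch.Tensor) -> TripoFileReference:
    # Upload a single frame through the Comfy API (auth comes from cls.hidden)
    # and wrap the returned URL in Tripo's file-reference models.
    urls = await upload_images_to_comfyapi(cls, image, max_images=1)
    return TripoFileReference(root=TripoUrlReference(url=urls[0], type="jpeg"))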
def get_model_url_from_response(response: TripoTaskResponse) -> str: if response.data is not None: for key in ["pbr_model", "model", "base_model"]: @@ -50,20 +43,18 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str: async def poll_until_finished( - kwargs: dict[str, str], - response: TripoTaskResponse, -) -> tuple[str, str]: + node_cls: type[IO.ComfyNode], + response: TripoTaskResponse, + average_duration: Optional[int] = None, +) -> IO.NodeOutput: """Polls the Tripo API endpoint until the task reaches a terminal state, then returns the response.""" if response.code != 0: raise RuntimeError(f"Failed to generate mesh: {response.error}") task_id = response.data.task_id - response_poll = await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/tripo/v2/openapi/task/{task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TripoTaskResponse, - ), + response_poll = await poll_op( + node_cls, + poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"), + response_model=TripoTaskResponse, completed_statuses=[TripoTaskStatus.SUCCESS], failed_statuses=[ TripoTaskStatus.FAILED, @@ -73,72 +64,84 @@ async def poll_until_finished( TripoTaskStatus.EXPIRED, ], status_extractor=lambda x: x.data.status, - auth_kwargs=kwargs, - node_id=kwargs["unique_id"], - result_url_extractor=get_model_url_from_response, progress_extractor=lambda x: x.data.progress, - ).execute() + estimated_duration=average_duration, + ) if response_poll.data.status == TripoTaskStatus.SUCCESS: url = get_model_url_from_response(response_poll) - bytesio = await download_url_to_bytesio(url) + bytesio = await download_url_as_bytesio(url) # Save the downloaded model file model_file = f"tripo_model_{task_id}.glb" with open(os.path.join(get_output_directory(), model_file), "wb") as f: f.write(bytesio.getvalue()) - return model_file, task_id + return IO.NodeOutput(model_file, task_id) raise RuntimeError(f"Failed to generate mesh: {response_poll}") -class TripoTextToModelNode: +class TripoTextToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on a text prompt using Tripo's API. 
""" - AVERAGE_DURATION = 80 + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "prompt": ("STRING", {"multiline": True}), - }, - "optional": { - "negative_prompt": ("STRING", {"multiline": True}), - "model_version": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "model_version", enum_type=TripoModelVersion), - "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "image_seed": ("INT", {"default": 42}), - "model_seed": ("INT", {"default": 42}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoTextToModelNode", + display_name="Tripo: Text to Model", + category="api node/3d/Tripo", + inputs=[ + IO.String.Input("prompt", multiline=True), + IO.String.Input("negative_prompt", multiline=True, optional=True), + IO.Combo.Input( + "model_version", options=TripoModelVersion, default=TripoModelVersion.v2_5_20250123, optional=True + ), + IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("image_seed", default=42, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: Optional[str] = None, + model_version=None, + style: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + image_seed: Optional[int] = None, + model_seed: Optional[int] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + face_limit: Optional[int] = None, + quad: Optional[bool] = None, + ) -> IO.NodeOutput: style_enum = None if style == "None" else style if not prompt: raise RuntimeError("Prompt is required") - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoTextToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoTextToModelRequest( + response = await sync_op( + 
cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoTextToModelRequest( type=TripoTaskType.TEXT_TO_MODEL, prompt=prompt, negative_prompt=negative_prompt if negative_prompt else None, @@ -152,64 +155,89 @@ class TripoTextToModelNode: texture_quality=texture_quality, face_limit=face_limit, auto_size=True, - quad=quad + quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoImageToModelNode: +class TripoImageToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on a single image using Tripo's API. """ - AVERAGE_DURATION = 80 + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - }, - "optional": { - "model_version": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "model_version", enum_type=TripoModelVersion), - "style": model_field_to_node_input(IO.COMBO, TripoTextToModelRequest, "style", enum_type=TripoStyle, default="None"), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "model_seed": ("INT", {"default": 42}), - "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoImageToModelNode", + display_name="Tripo: Image to Model", + category="api node/3d/Tripo", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input( + "model_version", + options=TripoModelVersion, + tooltip="The model version to use for generation", + optional=True, + ), + IO.Combo.Input("style", options=TripoStyle, default="None", optional=True), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Combo.Input( + "orientation", options=TripoOrientation, default=TripoOrientation.DEFAULT, optional=True + ), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, 
texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + @classmethod + async def execute( + cls, + image: torch.Tensor, + model_version: Optional[str] = None, + style: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + model_seed: Optional[int] = None, + orientation=None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + face_limit: Optional[int] = None, + quad: Optional[bool] = None, + ) -> IO.NodeOutput: style_enum = None if style == "None" else style if image is None: raise RuntimeError("Image is required") - tripo_file = await upload_image_to_tripo(image, **kwargs) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoImageToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoImageToModelRequest( + tripo_file = TripoFileReference( + root=TripoUrlReference( + url=(await upload_images_to_comfyapi(cls, image, max_images=1))[0], + type="jpeg", + ) + ) + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoImageToModelRequest( type=TripoTaskType.IMAGE_TO_MODEL, file=tripo_file, model_version=model_version, @@ -223,80 +251,105 @@ class TripoImageToModelNode: texture_quality=texture_quality, face_limit=face_limit, auto_size=True, - quad=quad + quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoMultiviewToModelNode: +class TripoMultiviewToModelNode(IO.ComfyNode): """ Generates 3D models synchronously based on up to four images (front, left, back, right) using Tripo's API. 
""" - AVERAGE_DURATION = 80 + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - }, - "optional": { - "image_left": ("IMAGE",), - "image_back": ("IMAGE",), - "image_right": ("IMAGE",), - "model_version": model_field_to_node_input(IO.COMBO, TripoMultiviewToModelRequest, "model_version", enum_type=TripoModelVersion), - "orientation": model_field_to_node_input(IO.COMBO, TripoImageToModelRequest, "orientation", enum_type=TripoOrientation), - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "model_seed": ("INT", {"default": 42}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "quad": ("BOOLEAN", {"default": False}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoMultiviewToModelNode", + display_name="Tripo: Multiview to Model", + category="api node/3d/Tripo", + inputs=[ + IO.Image.Input("image"), + IO.Image.Input("image_left", optional=True), + IO.Image.Input("image_back", optional=True), + IO.Image.Input("image_right", optional=True), + IO.Combo.Input( + "model_version", + options=TripoModelVersion, + optional=True, + tooltip="The model version to use for generation", + ), + IO.Combo.Input( + "orientation", + options=TripoOrientation, + default=TripoOrientation.DEFAULT, + optional=True, + ), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("model_seed", default=42, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + IO.Int.Input("face_limit", default=-1, min=-1, max=500000, optional=True), + IO.Boolean.Input("quad", default=False, optional=True), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - - async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): + @classmethod + async def execute( + cls, + image: torch.Tensor, + image_left: Optional[torch.Tensor] = None, + image_back: Optional[torch.Tensor] = None, + image_right: Optional[torch.Tensor] = None, + model_version: Optional[str] = None, + orientation: Optional[str] = None, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + model_seed: Optional[int] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + face_limit: Optional[int] = None, + 
quad: Optional[bool] = None, + ) -> IO.NodeOutput: if image is None: raise RuntimeError("front image for multiview is required") images = [] - image_dict = { - "image": image, - "image_left": image_left, - "image_back": image_back, - "image_right": image_right - } + image_dict = {"image": image, "image_left": image_left, "image_back": image_back, "image_right": image_right} if image_left is None and image_back is None and image_right is None: raise RuntimeError("At least one of left, back, or right image must be provided for multiview") for image_name in ["image", "image_left", "image_back", "image_right"]: image_ = image_dict[image_name] if image_ is not None: - tripo_file = await upload_image_to_tripo(image_, **kwargs) - images.append(tripo_file) + images.append( + TripoFileReference( + root=TripoUrlReference( + url=(await upload_images_to_comfyapi(cls, image_, max_images=1))[0], type="jpeg" + ) + ) + ) else: images.append(TripoFileEmptyReference()) - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoMultiviewToModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoMultiviewToModelRequest( + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoMultiviewToModelRequest( type=TripoTaskType.MULTIVIEW_TO_MODEL, files=images, model_version=model_version, @@ -310,272 +363,283 @@ class TripoMultiviewToModelNode: face_limit=face_limit, quad=quad, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoTextureNode: +class TripoTextureNode(IO.ComfyNode): + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model_task_id": ("MODEL_TASK_ID",), - }, - "optional": { - "texture": ("BOOLEAN", {"default": True}), - "pbr": ("BOOLEAN", {"default": True}), - "texture_seed": ("INT", {"default": 42}), - "texture_quality": (["standard", "detailed"], {"default": "standard"}), - "texture_alignment": (["original_image", "geometry"], {"default": "original_image"}), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoTextureNode", + display_name="Tripo: Texture model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("MODEL_TASK_ID").Input("model_task_id"), + IO.Boolean.Input("texture", default=True, optional=True), + IO.Boolean.Input("pbr", default=True, optional=True), + IO.Int.Input("texture_seed", default=42, optional=True), + IO.Combo.Input("texture_quality", default="standard", options=["standard", "detailed"], optional=True), + IO.Combo.Input( + "texture_alignment", default="original_image", options=["original_image", "geometry"], optional=True + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 80 - - async def generate_mesh(self, model_task_id, texture=None, pbr=None, 
texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoTextureModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoTextureModelRequest( + @classmethod + async def execute( + cls, + model_task_id, + texture: Optional[bool] = None, + pbr: Optional[bool] = None, + texture_seed: Optional[int] = None, + texture_quality: Optional[str] = None, + texture_alignment: Optional[str] = None, + ) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoTextureModelRequest( original_model_task_id=model_task_id, texture=texture, pbr=pbr, texture_seed=texture_seed, texture_quality=texture_quality, - texture_alignment=texture_alignment + texture_alignment=texture_alignment, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=80) -class TripoRefineNode: +class TripoRefineNode(IO.ComfyNode): + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model_task_id": ("MODEL_TASK_ID", { - "tooltip": "Must be a v1.4 Tripo model" - }), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoRefineNode", + display_name="Tripo: Refine Draft model", + category="api node/3d/Tripo", + description="Refine a draft model created by v1.4 Tripo models only.", + inputs=[ + IO.Custom("MODEL_TASK_ID").Input("model_task_id", tooltip="Must be a v1.4 Tripo model"), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - DESCRIPTION = "Refine a draft model created by v1.4 Tripo models only." 
- - RETURN_TYPES = ("STRING", "MODEL_TASK_ID",) - RETURN_NAMES = ("model_file", "model task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 240 - - async def generate_mesh(self, model_task_id, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoRefineModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoRefineModelRequest( - draft_model_task_id=model_task_id - ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) - - -class TripoRigNode: @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("MODEL_TASK_ID",), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } - - RETURN_TYPES = ("STRING", "RIG_TASK_ID") - RETURN_NAMES = ("model_file", "rig task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 180 - - async def generate_mesh(self, original_model_task_id, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoAnimateRigRequest, - response_model=TripoTaskResponse, - ), - request=TripoAnimateRigRequest( - original_model_task_id=original_model_task_id, - out_format="glb", - spec="tripo" - ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + async def execute(cls, model_task_id) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoRefineModelRequest(draft_model_task_id=model_task_id), + ) + return await poll_until_finished(cls, response, average_duration=240) -class TripoRetargetNode: +class TripoRigNode(IO.ComfyNode): + @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("RIG_TASK_ID",), - "animation": ([ - "preset:idle", - "preset:walk", - "preset:climb", - "preset:jump", - "preset:slash", - "preset:shoot", - "preset:hurt", - "preset:fall", - "preset:turn", - ],), - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TripoRigNode", + display_name="Tripo: Rig model", + category="api node/3d/Tripo", + inputs=[IO.Custom("MODEL_TASK_ID").Input("original_model_task_id")], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("RIG_TASK_ID").Output(display_name="rig task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) - RETURN_TYPES = ("STRING", "RETARGET_TASK_ID") - RETURN_NAMES = ("model_file", "retarget task_id") - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 30 + @classmethod + async def execute(cls, original_model_task_id) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoAnimateRigRequest(original_model_task_id=original_model_task_id, out_format="glb", spec="tripo"), + ) + return await 
poll_until_finished(cls, response, average_duration=180) - async def generate_mesh(self, animation, original_model_task_id, **kwargs): - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoAnimateRetargetRequest, - response_model=TripoTaskResponse, - ), - request=TripoAnimateRetargetRequest( + +class TripoRetargetNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="TripoRetargetNode", + display_name="Tripo: Retarget rigged model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("RIG_TASK_ID").Input("original_model_task_id"), + IO.Combo.Input( + "animation", + options=[ + "preset:idle", + "preset:walk", + "preset:climb", + "preset:jump", + "preset:slash", + "preset:shoot", + "preset:hurt", + "preset:fall", + "preset:turn", + ], + ), + ], + outputs=[ + IO.String.Output(display_name="model_file"), + IO.Custom("RETARGET_TASK_ID").Output(display_name="retarget task_id"), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + + @classmethod + async def execute(cls, original_model_task_id, animation: str) -> IO.NodeOutput: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoAnimateRetargetRequest( original_model_task_id=original_model_task_id, animation=animation, out_format="glb", - bake_animation=True + bake_animation=True, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=30) -class TripoConversionNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "original_model_task_id": ("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID",), - "format": (["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"],), - }, - "optional": { - "quad": ("BOOLEAN", {"default": False}), - "face_limit": ("INT", {"min": -1, "max": 500000, "default": -1}), - "texture_size": ("INT", {"min": 128, "max": 4096, "default": 4096}), - "texture_format": (["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], {"default": "JPEG"}) - }, - "hidden": { - "auth_token": "AUTH_TOKEN_COMFY_ORG", - "comfy_api_key": "API_KEY_COMFY_ORG", - "unique_id": "UNIQUE_ID", - }, - } +class TripoConversionNode(IO.ComfyNode): @classmethod - def VALIDATE_INPUTS(cls, input_types): + def define_schema(cls): + return IO.Schema( + node_id="TripoConversionNode", + display_name="Tripo: Convert model", + category="api node/3d/Tripo", + inputs=[ + IO.Custom("MODEL_TASK_ID,RIG_TASK_ID,RETARGET_TASK_ID").Input("original_model_task_id"), + IO.Combo.Input("format", options=["GLTF", "USDZ", "FBX", "OBJ", "STL", "3MF"]), + IO.Boolean.Input("quad", default=False, optional=True), + IO.Int.Input( + "face_limit", + default=-1, + min=-1, + max=500000, + optional=True, + ), + IO.Int.Input( + "texture_size", + default=4096, + min=128, + max=4096, + optional=True, + ), + IO.Combo.Input( + "texture_format", + options=["BMP", "DPX", "HDR", "JPEG", "OPEN_EXR", "PNG", "TARGA", "TIFF", "WEBP"], + default="JPEG", + optional=True, + ), + ], + outputs=[], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + is_output_node=True, + ) + + @classmethod + def validate_inputs(cls, input_types): # The min and max of input1 and input2 
are still validated because # we didn't take `input1` or `input2` as arguments if input_types["original_model_task_id"] not in ("MODEL_TASK_ID", "RIG_TASK_ID", "RETARGET_TASK_ID"): return "original_model_task_id must be MODEL_TASK_ID, RIG_TASK_ID or RETARGET_TASK_ID type" return True - RETURN_TYPES = () - FUNCTION = "generate_mesh" - CATEGORY = "api node/3d/Tripo" - API_NODE = True - OUTPUT_NODE = True - AVERAGE_DURATION = 30 - - async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): + @classmethod + async def execute( + cls, + original_model_task_id, + format: str, + quad: bool, + face_limit: int, + texture_size: int, + texture_format: str, + ) -> IO.NodeOutput: if not original_model_task_id: raise RuntimeError("original_model_task_id is required") - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path="/proxy/tripo/v2/openapi/task", - method=HttpMethod.POST, - request_model=TripoConvertModelRequest, - response_model=TripoTaskResponse, - ), - request=TripoConvertModelRequest( + response = await sync_op( + cls, + endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/task", method="POST"), + response_model=TripoTaskResponse, + data=TripoConvertModelRequest( original_model_task_id=original_model_task_id, format=format, quad=quad if quad else None, face_limit=face_limit if face_limit != -1 else None, texture_size=texture_size if texture_size != 4096 else None, - texture_format=texture_format if texture_format != "JPEG" else None + texture_format=texture_format if texture_format != "JPEG" else None, ), - auth_kwargs=kwargs, - ).execute() - return await poll_until_finished(kwargs, response) + ) + return await poll_until_finished(cls, response, average_duration=30) -NODE_CLASS_MAPPINGS = { - "TripoTextToModelNode": TripoTextToModelNode, - "TripoImageToModelNode": TripoImageToModelNode, - "TripoMultiviewToModelNode": TripoMultiviewToModelNode, - "TripoTextureNode": TripoTextureNode, - "TripoRefineNode": TripoRefineNode, - "TripoRigNode": TripoRigNode, - "TripoRetargetNode": TripoRetargetNode, - "TripoConversionNode": TripoConversionNode, -} +class TripoExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + TripoTextToModelNode, + TripoImageToModelNode, + TripoMultiviewToModelNode, + TripoTextureNode, + TripoRefineNode, + TripoRigNode, + TripoRetargetNode, + TripoConversionNode, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "TripoTextToModelNode": "Tripo: Text to Model", - "TripoImageToModelNode": "Tripo: Image to Model", - "TripoMultiviewToModelNode": "Tripo: Multiview to Model", - "TripoTextureNode": "Tripo: Texture model", - "TripoRefineNode": "Tripo: Refine Draft model", - "TripoRigNode": "Tripo: Rig model", - "TripoRetargetNode": "Tripo: Retarget rigged model", - "TripoConversionNode": "Tripo: Convert model", -} + +async def comfy_entrypoint() -> TripoExtension: + return TripoExtension() diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index daeaa823e..e165b8380 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,28 +1,23 @@ -import logging import base64 -import aiohttp -import torch from io import BytesIO -from typing import Optional + from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api.input_impl.video_types import VideoFromFile -from comfy_api_nodes.apis import ( - VeoGenVidRequest, - VeoGenVidResponse, +from comfy_api.latest import IO, ComfyExtension, 
Input, InputImpl +from comfy_api_nodes.apis.veo_api import ( VeoGenVidPollRequest, VeoGenVidPollResponse, + VeoGenVidRequest, + VeoGenVidResponse, + VeoRequestInstance, + VeoRequestInstanceImage, + VeoRequestParameters, ) -from comfy_api_nodes.apis.client import ( +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, -) - -from comfy_api_nodes.apinode_utils import ( - downscale_image_tensor, + download_url_to_video_output, + poll_op, + sync_op, tensor_to_base64_string, ) @@ -35,28 +30,6 @@ MODELS_MAP = { "veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001", } -def convert_image_to_base64(image: torch.Tensor): - if image is None: - return None - - scaled_image = downscale_image_tensor(image, total_pixels=2048*2048) - return tensor_to_base64_string(scaled_image) - - -def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]: - if ( - poll_response.response - and hasattr(poll_response.response, "videos") - and poll_response.response.videos - and len(poll_response.response.videos) > 0 - ): - video = poll_response.response.videos[0] - else: - return None - if hasattr(video, "gcsUri") and video.gcsUri: - return str(video.gcsUri) - return None - class VeoVideoGenerationNode(IO.ComfyNode): """ @@ -169,18 +142,13 @@ class VeoVideoGenerationNode(IO.ComfyNode): # Prepare the instances for the request instances = [] - instance = { - "prompt": prompt - } + instance = {"prompt": prompt} # Add image if provided if image is not None: - image_base64 = convert_image_to_base64(image) + image_base64 = tensor_to_base64_string(image) if image_base64: - instance["image"] = { - "bytesBase64Encoded": image_base64, - "mimeType": "image/png" - } + instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"} instances.append(instance) @@ -198,119 +166,77 @@ class VeoVideoGenerationNode(IO.ComfyNode): if seed > 0: parameters["seed"] = seed # Only add generateAudio for Veo 3 models - if "veo-3.0" in model: + if model.find("veo-2.0") == -1: parameters["generateAudio"] = generate_audio - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - # Initial request to start video generation - initial_operation = SynchronousOperation( - endpoint=ApiEndpoint( - path=f"/proxy/veo/{model}/generate", - method=HttpMethod.POST, - request_model=VeoGenVidRequest, - response_model=VeoGenVidResponse - ), - request=VeoGenVidRequest( + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( instances=instances, - parameters=parameters + parameters=parameters, ), - auth_kwargs=auth, ) - initial_response = await initial_operation.execute() - operation_name = initial_response.name - - logging.info("Veo generation started with operation name: %s", operation_name) - - # Define status extractor function def status_extractor(response): # Only return "completed" if the operation is done, regardless of success or failure # We'll check for errors after polling completes return "completed" if response.done else "pending" - # Define progress extractor function - def progress_extractor(response): - # Could be enhanced if the API provides progress information - return None - - # Define the polling operation - poll_operation = PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/veo/{model}/poll", - method=HttpMethod.POST, - request_model=VeoGenVidPollRequest, - 
response_model=VeoGenVidPollResponse - ), - completed_statuses=["completed"], - failed_statuses=[], # No failed statuses, we'll handle errors after polling + poll_response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, status_extractor=status_extractor, - progress_extractor=progress_extractor, - request=VeoGenVidPollRequest( - operationName=operation_name + data=VeoGenVidPollRequest( + operationName=initial_response.name, ), - auth_kwargs=auth, poll_interval=5.0, - result_url_extractor=get_video_url_from_response, - node_id=cls.hidden.unique_id, estimated_duration=AVERAGE_DURATION_VIDEO_GEN, ) - # Execute the polling operation - poll_response = await poll_operation.execute() - # Now check for errors in the final response # Check for error in poll response - if hasattr(poll_response, 'error') and poll_response.error: - error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})" - logging.error(error_message) - raise Exception(error_message) + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") # Check for RAI filtered content - if (hasattr(poll_response.response, 'raiMediaFilteredCount') and - poll_response.response.raiMediaFilteredCount > 0): + if ( + hasattr(poll_response.response, "raiMediaFilteredCount") + and poll_response.response.raiMediaFilteredCount > 0 + ): # Extract reason message if available - if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and - poll_response.response.raiMediaFilteredReasons): + if ( + hasattr(poll_response.response, "raiMediaFilteredReasons") + and poll_response.response.raiMediaFilteredReasons + ): reason = poll_response.response.raiMediaFilteredReasons[0] error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)" else: error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)" - logging.error(error_message) raise Exception(error_message) # Extract video data - if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0: + if ( + poll_response.response + and hasattr(poll_response.response, "videos") + and poll_response.response.videos + and len(poll_response.response.videos) > 0 + ): video = poll_response.response.videos[0] # Check if video is provided as base64 or URL - if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded: - # Decode base64 string to bytes - video_data = base64.b64decode(video.bytesBase64Encoded) - elif hasattr(video, 'gcsUri') and video.gcsUri: - # Download from URL - async with aiohttp.ClientSession() as session: - async with session.get(video.gcsUri) as video_response: - video_data = await video_response.content.read() - else: - raise Exception("Video returned but no data or URL was provided") - else: - raise Exception("Video generation completed but no video was returned") + if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded: + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) - if not video_data: - raise Exception("No video data was returned") + if hasattr(video, "gcsUri") and video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) - logging.info("Video generation 
completed successfully") - - # Convert video data to BytesIO object - video_io = BytesIO(video_data) - - # Return VideoFromFile object - return IO.NodeOutput(VideoFromFile(video_io)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") class Veo3VideoGenerationNode(VeoVideoGenerationNode): @@ -394,7 +320,10 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): IO.Combo.Input( "model", options=[ - "veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.0-generate-001", "veo-3.0-fast-generate-001" + "veo-3.1-generate", + "veo-3.1-fast-generate", + "veo-3.0-generate-001", + "veo-3.0-fast-generate-001", ], default="veo-3.0-generate-001", tooltip="Veo 3 model to use for video generation", @@ -419,13 +348,165 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode): ) +class Veo3FirstLastFrameNode(IO.ComfyNode): + + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="Veo3FirstLastFrameNode", + display_name="Google Veo 3 First-Last-Frame to Video", + category="api node/video/Veo", + description="Generate video using prompt and first and last frames.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the video", + ), + IO.String.Input( + "negative_prompt", + multiline=True, + default="", + tooltip="Negative text prompt to guide what to avoid in the video", + ), + IO.Combo.Input("resolution", options=["720p", "1080p"]), + IO.Combo.Input( + "aspect_ratio", + options=["16:9", "9:16"], + default="16:9", + tooltip="Aspect ratio of the output video", + ), + IO.Int.Input( + "duration", + default=8, + min=4, + max=8, + step=2, + display_mode=IO.NumberDisplay.slider, + tooltip="Duration of the output video in seconds", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFF, + step=1, + display_mode=IO.NumberDisplay.number, + control_after_generate=True, + tooltip="Seed for video generation", + ), + IO.Image.Input("first_frame", tooltip="Start frame"), + IO.Image.Input("last_frame", tooltip="End frame"), + IO.Combo.Input( + "model", + options=["veo-3.1-generate", "veo-3.1-fast-generate"], + default="veo-3.1-fast-generate", + ), + IO.Boolean.Input( + "generate_audio", + default=True, + tooltip="Generate audio for the video.", + ), + ], + outputs=[ + IO.Video.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + ) + + @classmethod + async def execute( + cls, + prompt: str, + negative_prompt: str, + resolution: str, + aspect_ratio: str, + duration: int, + seed: int, + first_frame: Input.Image, + last_frame: Input.Image, + model: str, + generate_audio: bool, + ): + model = MODELS_MAP[model] + initial_response = await sync_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"), + response_model=VeoGenVidResponse, + data=VeoGenVidRequest( + instances=[ + VeoRequestInstance( + prompt=prompt, + image=VeoRequestInstanceImage( + bytesBase64Encoded=tensor_to_base64_string(first_frame), mimeType="image/png" + ), + lastFrame=VeoRequestInstanceImage( + bytesBase64Encoded=tensor_to_base64_string(last_frame), mimeType="image/png" + ), + ), + ], + parameters=VeoRequestParameters( + aspectRatio=aspect_ratio, + personGeneration="ALLOW", + durationSeconds=duration, + enhancePrompt=True, # cannot be False for Veo3 + seed=seed, + generateAudio=generate_audio, + negativePrompt=negative_prompt, + resolution=resolution, + ), + ), + ) + poll_response 
= await poll_op( + cls, + ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"), + response_model=VeoGenVidPollResponse, + status_extractor=lambda r: "completed" if r.done else "pending", + data=VeoGenVidPollRequest( + operationName=initial_response.name, + ), + poll_interval=5.0, + estimated_duration=AVERAGE_DURATION_VIDEO_GEN, + ) + + if poll_response.error: + raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})") + + response = poll_response.response + filtered_count = response.raiMediaFilteredCount + if filtered_count: + reasons = response.raiMediaFilteredReasons or [] + reason_part = f": {reasons[0]}" if reasons else "" + raise Exception( + f"Content blocked by Google's Responsible AI filters{reason_part} " + f"({filtered_count} video{'s' if filtered_count != 1 else ''} filtered)." + ) + + if response.videos: + video = response.videos[0] + if video.bytesBase64Encoded: + return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded)))) + if video.gcsUri: + return IO.NodeOutput(await download_url_to_video_output(video.gcsUri)) + raise Exception("Video returned but no data or URL was provided") + raise Exception("Video generation completed but no video was returned") + + class VeoExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[IO.ComfyNode]]: return [ VeoVideoGenerationNode, Veo3VideoGenerationNode, + Veo3FirstLastFrameNode, ] + async def comfy_entrypoint() -> VeoExtension: return VeoExtension() diff --git a/comfy_api_nodes/nodes_vidu.py b/comfy_api_nodes/nodes_vidu.py index 639be4b2b..7a679f0d9 100644 --- a/comfy_api_nodes/nodes_vidu.py +++ b/comfy_api_nodes/nodes_vidu.py @@ -1,27 +1,23 @@ import logging from enum import Enum -from typing import Any, Callable, Optional, Literal, TypeVar -from typing_extensions import override +from typing import Literal, Optional, TypeVar import torch from pydantic import BaseModel, Field +from typing_extensions import override -from comfy_api.latest import ComfyExtension, IO -from comfy_api_nodes.util.validation_utils import ( - validate_aspect_ratio_closeness, - validate_image_dimensions, - validate_image_aspect_ratio_range, - get_number_of_images, -) -from comfy_api_nodes.apis.client import ( +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.util import ( ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, + download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, + upload_images_to_comfyapi, + validate_image_aspect_ratio, + validate_image_dimensions, + validate_images_aspect_ratio_closeness, ) -from comfy_api_nodes.apinode_utils import download_url_to_video_output, upload_images_to_comfyapi - VIDU_TEXT_TO_VIDEO = "/proxy/vidu/text2video" VIDU_IMAGE_TO_VIDEO = "/proxy/vidu/img2video" @@ -31,8 +27,9 @@ VIDU_GET_GENERATION_STATUS = "/proxy/vidu/tasks/%s/creations" R = TypeVar("R") + class VideoModelName(str, Enum): - vidu_q1 = 'viduq1' + vidu_q1 = "viduq1" class AspectRatio(str, Enum): @@ -63,17 +60,9 @@ class TaskCreationRequest(BaseModel): images: Optional[list[str]] = Field(None, description="Base64 encoded string or image URL") -class TaskStatus(str, Enum): - created = "created" - queueing = "queueing" - processing = "processing" - success = "success" - failed = "failed" - - class TaskCreationResponse(BaseModel): task_id: str = Field(...) - state: TaskStatus = Field(...) + state: str = Field(...) created_at: str = Field(...) 
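With the TaskStatus enum removed, state is a plain string and terminal-state detection now relies on the shared defaults in comfy_api_nodes/util/client.py (COMPLETED_STATUSES, FAILED_STATUSES, QUEUED_STATUSES). A small sanity check, assuming only the status constants defined in that module, shows why the Vidu states no longer need their own enum:

from comfy_api_nodes.util.client import (
    COMPLETED_STATUSES,
    FAILED_STATUSES,
    QUEUED_STATUSES,
)

# Vidu reports plain strings such as "created", "queueing", "processing", "success", "failed".
# The shared status lists already classify the terminal and queued ones, so poll_op's
# defaults apply; "processing" matches none of the lists and is treated as an active state.
assert "success" in COMPLETED_STATUSES
assert "failed" in FAILED_STATUSES
assert "created" in QUEUED_STATUSES and "queueing" in QUEUED_STATUSES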
code: Optional[int] = Field(None, description="Error code") @@ -85,32 +74,11 @@ class TaskResult(BaseModel): class TaskStatusResponse(BaseModel): - state: TaskStatus = Field(...) + state: str = Field(...) err_code: Optional[str] = Field(None) creations: list[TaskResult] = Field(..., description="Generated results") -async def poll_until_finished( - auth_kwargs: dict[str, str], - api_endpoint: ApiEndpoint[Any, R], - result_url_extractor: Optional[Callable[[R], str]] = None, - estimated_duration: Optional[int] = None, - node_id: Optional[str] = None, -) -> R: - return await PollingOperation( - poll_endpoint=api_endpoint, - completed_statuses=[TaskStatus.success.value], - failed_statuses=[TaskStatus.failed.value], - status_extractor=lambda response: response.state.value, - auth_kwargs=auth_kwargs, - result_url_extractor=result_url_extractor, - estimated_duration=estimated_duration, - node_id=node_id, - poll_interval=16.0, - max_poll_attempts=256, - ).execute() - - def get_video_url_from_response(response) -> Optional[str]: if response.creations: return response.creations[0].url @@ -127,37 +95,27 @@ def get_video_from_response(response) -> TaskResult: async def execute_task( + cls: type[IO.ComfyNode], vidu_endpoint: str, - auth_kwargs: Optional[dict[str, str]], payload: TaskCreationRequest, estimated_duration: int, - node_id: str, ) -> R: - response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=vidu_endpoint, - method=HttpMethod.POST, - request_model=TaskCreationRequest, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - if response.state == TaskStatus.failed: + response = await sync_op( + cls, + endpoint=ApiEndpoint(path=vidu_endpoint, method="POST"), + response_model=TaskCreationResponse, + data=payload, + ) + if response.state == "failed": error_msg = f"Vidu request failed. 
Code: {response.code}" logging.error(error_msg) raise RuntimeError(error_msg) - return await poll_until_finished( - auth_kwargs, - ApiEndpoint( - path=VIDU_GET_GENERATION_STATUS % response.task_id, - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=TaskStatusResponse, - ), - result_url_extractor=get_video_url_from_response, + return await poll_op( + cls, + ApiEndpoint(path=VIDU_GET_GENERATION_STATUS % response.task_id), + response_model=TaskStatusResponse, + status_extractor=lambda r: r.state, estimated_duration=estimated_duration, - node_id=node_id, ) @@ -258,11 +216,7 @@ class ViduTextToVideoNode(IO.ComfyNode): resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } - results = await execute_task(VIDU_TEXT_TO_VIDEO, auth, payload, 320, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_TEXT_TO_VIDEO, payload, 320) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -353,7 +307,7 @@ class ViduImageToVideoNode(IO.ComfyNode): ) -> IO.NodeOutput: if get_number_of_images(image) > 1: raise ValueError("Only one input image is allowed.") - validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) + validate_image_aspect_ratio(image, (1, 4), (4, 1)) payload = TaskCreationRequest( model_name=model, prompt=prompt, @@ -362,17 +316,13 @@ class ViduImageToVideoNode(IO.ComfyNode): resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = await upload_images_to_comfyapi( + cls, image, max_images=1, mime_type="image/png", - auth_kwargs=auth, ) - results = await execute_task(VIDU_IMAGE_TO_VIDEO, auth, payload, 120, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_IMAGE_TO_VIDEO, payload, 120) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -473,7 +423,7 @@ class ViduReferenceVideoNode(IO.ComfyNode): if a > 7: raise ValueError("Too many images, maximum allowed is 7.") for image in images: - validate_image_aspect_ratio_range(image, (1, 4), (4, 1)) + validate_image_aspect_ratio(image, (1, 4), (4, 1)) validate_image_dimensions(image, min_width=128, min_height=128) payload = TaskCreationRequest( model_name=model, @@ -484,17 +434,13 @@ class ViduReferenceVideoNode(IO.ComfyNode): resolution=resolution, movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = await upload_images_to_comfyapi( + cls, images, max_images=7, mime_type="image/png", - auth_kwargs=auth, ) - results = await execute_task(VIDU_REFERENCE_VIDEO, auth, payload, 120, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_REFERENCE_VIDEO, payload, 120) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -587,7 +533,7 @@ class ViduStartEndToVideoNode(IO.ComfyNode): resolution: str, movement_amplitude: str, ) -> IO.NodeOutput: - validate_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) + validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False) payload = TaskCreationRequest( model_name=model, prompt=prompt, @@ -596,15 +542,11 @@ class ViduStartEndToVideoNode(IO.ComfyNode): resolution=resolution, 
movement_amplitude=movement_amplitude, ) - auth = { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - } payload.images = [ - (await upload_images_to_comfyapi(frame, max_images=1, mime_type="image/png", auth_kwargs=auth))[0] + (await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0] for frame in (first_frame, end_frame) ] - results = await execute_task(VIDU_START_END_VIDEO, auth, payload, 96, cls.hidden.unique_id) + results = await execute_task(cls, VIDU_START_END_VIDEO, payload, 96) return IO.NodeOutput(await download_url_to_video_output(get_video_from_response(results).url)) @@ -618,5 +560,6 @@ class ViduExtension(ComfyExtension): ViduStartEndToVideoNode, ] + async def comfy_entrypoint() -> ViduExtension: return ViduExtension() diff --git a/comfy_api_nodes/nodes_wan.py b/comfy_api_nodes/nodes_wan.py index b089bd907..2aab3c2ff 100644 --- a/comfy_api_nodes/nodes_wan.py +++ b/comfy_api_nodes/nodes_wan.py @@ -1,28 +1,24 @@ import re -from typing import Optional, Type, Union -from typing_extensions import override +from typing import Optional import torch from pydantic import BaseModel, Field -from comfy_api.latest import ComfyExtension, Input, IO -from comfy_api_nodes.apis.client import ( - ApiEndpoint, - HttpMethod, - SynchronousOperation, - PollingOperation, - EmptyRequest, - R, - T, -) -from comfy_api_nodes.util.validation_utils import get_number_of_images, validate_audio_duration +from typing_extensions import override -from comfy_api_nodes.apinode_utils import ( +from comfy_api.latest import IO, ComfyExtension, Input +from comfy_api_nodes.util import ( + ApiEndpoint, + audio_to_base64_string, download_url_to_image_tensor, download_url_to_video_output, + get_number_of_images, + poll_op, + sync_op, tensor_to_base64_string, - audio_to_base64_string, + validate_audio_duration, ) + class Text2ImageInputField(BaseModel): prompt: str = Field(...) negative_prompt: Optional[str] = Field(None) @@ -146,53 +142,7 @@ class VideoTaskStatusResponse(BaseModel): request_id: str = Field(...) 
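The deleted process_task() helper is inlined at each Wan call site below: a node first creates the task with sync_op, checks output, then polls /proxy/wan/api/v1/tasks/{task_id} with poll_op. A minimal sketch of that flow, assuming the module-local request/response models (TaskCreationResponse, Text2ImageTaskCreationRequest, ImageTaskStatusResponse); the helper name _wan_create_and_poll is hypothetical and only mirrors what each node now does inline:

from comfy_api.latest import IO
from comfy_api_nodes.util import ApiEndpoint, download_url_to_image_tensor, poll_op, sync_op

async def _wan_create_and_poll(cls: type[IO.ComfyNode], payload: Text2ImageTaskCreationRequest):
    # Create the task; sync_op resolves auth and node id from the node class.
    initial = await sync_op(
        cls,
        ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"),
        response_model=TaskCreationResponse,
        data=payload,
    )
    if not initial.output:
        raise Exception(f"Unknown error occurred: {initial.code} - {initial.message}")
    # Poll the task endpoint (ApiEndpoint defaults to GET); no completed/failed lists are
    # passed here, so poll_op falls back to the client's defaults.
    result = await poll_op(
        cls,
        ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial.output.task_id}"),
        response_model=ImageTaskStatusResponse,
        status_extractor=lambda x: x.output.task_status,
        estimated_duration=9,
        poll_interval=3,
    )
    return await download_url_to_image_tensor(str(result.output.results[0].url))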
-RES_IN_PARENS = re.compile(r'\((\d+)\s*[x×]\s*(\d+)\)') - - -async def process_task( - auth_kwargs: dict[str, str], - url: str, - request_model: Type[T], - response_model: Type[R], - payload: Union[ - Text2ImageTaskCreationRequest, - Image2ImageTaskCreationRequest, - Text2VideoTaskCreationRequest, - Image2VideoTaskCreationRequest, - ], - node_id: str, - estimated_duration: int, - poll_interval: int, -) -> Type[R]: - initial_response = await SynchronousOperation( - endpoint=ApiEndpoint( - path=url, - method=HttpMethod.POST, - request_model=request_model, - response_model=TaskCreationResponse, - ), - request=payload, - auth_kwargs=auth_kwargs, - ).execute() - - if not initial_response.output: - raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") - - return await PollingOperation( - poll_endpoint=ApiEndpoint( - path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}", - method=HttpMethod.GET, - request_model=EmptyRequest, - response_model=response_model, - ), - completed_statuses=["SUCCEEDED"], - failed_statuses=["FAILED", "CANCELED", "UNKNOWN"], - status_extractor=lambda x: x.output.task_status, - estimated_duration=estimated_duration, - poll_interval=poll_interval, - node_id=node_id, - auth_kwargs=auth_kwargs, - ).execute() +RES_IN_PARENS = re.compile(r"\((\d+)\s*[x×]\s*(\d+)\)") class WanTextToImageApi(IO.ComfyNode): @@ -259,7 +209,7 @@ class WanTextToImageApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -286,26 +236,28 @@ class WanTextToImageApi(IO.ComfyNode): prompt_extend: bool = True, watermark: bool = True, ): - payload = Text2ImageTaskCreationRequest( - model=model, - input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt), - parameters=Txt2ImageParametersField( - size=f"{width}*{height}", - seed=seed, - prompt_extend=prompt_extend, - watermark=watermark, + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Text2ImageTaskCreationRequest( + model=model, + input=Text2ImageInputField(prompt=prompt, negative_prompt=negative_prompt), + parameters=Txt2ImageParametersField( + size=f"{width}*{height}", + seed=seed, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/text2image/image-synthesis", - request_model=Text2ImageTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=ImageTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=9, poll_interval=3, ) @@ -320,7 +272,7 @@ class WanImageToImageApi(IO.ComfyNode): display_name="Wan Image to Image", category="api node/image/Wan", description="Generates an image from one or two input images and a text prompt. 
" - "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", + "The output image is currently fixed at 1.6 MP; its aspect ratio matches the input image(s).", inputs=[ IO.Combo.Input( "model", @@ -376,7 +328,7 @@ class WanImageToImageApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -408,28 +360,30 @@ class WanImageToImageApi(IO.ComfyNode): raise ValueError(f"Expected 1 or 2 input images, got {n_images}.") images = [] for i in image: - images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096*4096)) - payload = Image2ImageTaskCreationRequest( - model=model, - input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), - parameters=Image2ImageParametersField( - # size=f"{width}*{height}", - seed=seed, - watermark=watermark, + images.append("data:image/png;base64," + tensor_to_base64_string(i, total_pixels=4096 * 4096)) + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Image2ImageTaskCreationRequest( + model=model, + input=Image2ImageInputField(prompt=prompt, negative_prompt=negative_prompt, images=images), + parameters=Image2ImageParametersField( + # size=f"{width}*{height}", + seed=seed, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/image2image/image-synthesis", - request_model=Image2ImageTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=ImageTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=42, - poll_interval=3, + poll_interval=4, ) return IO.NodeOutput(await download_url_to_image_tensor(str(response.output.results[0].url))) @@ -523,7 +477,7 @@ class WanTextToVideoApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -557,28 +511,31 @@ class WanTextToVideoApi(IO.ComfyNode): if audio is not None: validate_audio_duration(audio, 3.0, 29.0) audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") - payload = Text2VideoTaskCreationRequest( - model=model, - input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url), - parameters=Text2VideoParametersField( - size=f"{width}*{height}", - duration=duration, - seed=seed, - audio=generate_audio, - prompt_extend=prompt_extend, - watermark=watermark, + + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Text2VideoTaskCreationRequest( + model=model, + input=Text2VideoInputField(prompt=prompt, negative_prompt=negative_prompt, audio_url=audio_url), + 
parameters=Text2VideoParametersField( + size=f"{width}*{height}", + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", - request_model=Text2VideoTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=VideoTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=120 * int(duration / 5), poll_interval=6, ) @@ -667,7 +624,7 @@ class WanImageToVideoApi(IO.ComfyNode): IO.Boolean.Input( "watermark", default=True, - tooltip="Whether to add an \"AI generated\" watermark to the result.", + tooltip='Whether to add an "AI generated" watermark to the result.', optional=True, ), ], @@ -699,35 +656,37 @@ class WanImageToVideoApi(IO.ComfyNode): ): if get_number_of_images(image) != 1: raise ValueError("Exactly one input image is required.") - image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000*2000) + image_url = "data:image/png;base64," + tensor_to_base64_string(image, total_pixels=2000 * 2000) audio_url = None if audio is not None: validate_audio_duration(audio, 3.0, 29.0) audio_url = "data:audio/mp3;base64," + audio_to_base64_string(audio, "mp3", "libmp3lame") - payload = Image2VideoTaskCreationRequest( - model=model, - input=Image2VideoInputField( - prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url - ), - parameters=Image2VideoParametersField( - resolution=resolution, - duration=duration, - seed=seed, - audio=generate_audio, - prompt_extend=prompt_extend, - watermark=watermark, + initial_response = await sync_op( + cls, + ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"), + response_model=TaskCreationResponse, + data=Image2VideoTaskCreationRequest( + model=model, + input=Image2VideoInputField( + prompt=prompt, negative_prompt=negative_prompt, img_url=image_url, audio_url=audio_url + ), + parameters=Image2VideoParametersField( + resolution=resolution, + duration=duration, + seed=seed, + audio=generate_audio, + prompt_extend=prompt_extend, + watermark=watermark, + ), ), ) - response = await process_task( - { - "auth_token": cls.hidden.auth_token_comfy_org, - "comfy_api_key": cls.hidden.api_key_comfy_org, - }, - "/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", - request_model=Image2VideoTaskCreationRequest, + if not initial_response.output: + raise Exception(f"Unknown error occurred: {initial_response.code} - {initial_response.message}") + response = await poll_op( + cls, + ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"), response_model=VideoTaskStatusResponse, - payload=payload, - node_id=cls.hidden.unique_id, + status_extractor=lambda x: x.output.task_status, estimated_duration=120 * int(duration / 5), poll_interval=6, ) diff --git a/comfy_api_nodes/util/__init__.py b/comfy_api_nodes/util/__init__.py index e69de29bb..4cc22abfb 100644 --- a/comfy_api_nodes/util/__init__.py +++ b/comfy_api_nodes/util/__init__.py @@ -0,0 +1,101 @@ +from ._helpers import 
get_fs_object_size +from .client import ( + ApiEndpoint, + poll_op, + poll_op_raw, + sync_op, + sync_op_raw, +) +from .conversions import ( + audio_bytes_to_audio_input, + audio_input_to_mp3, + audio_to_base64_string, + bytesio_to_image_tensor, + downscale_image_tensor, + image_tensor_pair_to_batch, + pil_to_bytesio, + resize_mask_to_image, + tensor_to_base64_string, + tensor_to_bytesio, + tensor_to_pil, + text_filepath_to_base64_string, + text_filepath_to_data_uri, + trim_video, + video_to_base64_string, +) +from .download_helpers import ( + download_url_as_bytesio, + download_url_to_bytesio, + download_url_to_image_tensor, + download_url_to_video_output, +) +from .upload_helpers import ( + upload_audio_to_comfyapi, + upload_file_to_comfyapi, + upload_images_to_comfyapi, + upload_video_to_comfyapi, +) +from .validation_utils import ( + get_image_dimensions, + get_number_of_images, + validate_aspect_ratio_string, + validate_audio_duration, + validate_container_format_is_mp4, + validate_image_aspect_ratio, + validate_image_dimensions, + validate_images_aspect_ratio_closeness, + validate_string, + validate_video_dimensions, + validate_video_duration, + validate_video_frame_count, +) + +__all__ = [ + # API client + "ApiEndpoint", + "poll_op", + "poll_op_raw", + "sync_op", + "sync_op_raw", + # Upload helpers + "upload_audio_to_comfyapi", + "upload_file_to_comfyapi", + "upload_images_to_comfyapi", + "upload_video_to_comfyapi", + # Download helpers + "download_url_as_bytesio", + "download_url_to_bytesio", + "download_url_to_image_tensor", + "download_url_to_video_output", + # Conversions + "audio_bytes_to_audio_input", + "audio_input_to_mp3", + "audio_to_base64_string", + "bytesio_to_image_tensor", + "downscale_image_tensor", + "image_tensor_pair_to_batch", + "pil_to_bytesio", + "resize_mask_to_image", + "tensor_to_base64_string", + "tensor_to_bytesio", + "tensor_to_pil", + "text_filepath_to_base64_string", + "text_filepath_to_data_uri", + "trim_video", + "video_to_base64_string", + # Validation utilities + "get_image_dimensions", + "get_number_of_images", + "validate_aspect_ratio_string", + "validate_audio_duration", + "validate_container_format_is_mp4", + "validate_image_aspect_ratio", + "validate_image_dimensions", + "validate_images_aspect_ratio_closeness", + "validate_string", + "validate_video_dimensions", + "validate_video_duration", + "validate_video_frame_count", + # Misc functions + "get_fs_object_size", +] diff --git a/comfy_api_nodes/util/_helpers.py b/comfy_api_nodes/util/_helpers.py new file mode 100644 index 000000000..491e6b6a8 --- /dev/null +++ b/comfy_api_nodes/util/_helpers.py @@ -0,0 +1,71 @@ +import asyncio +import contextlib +import os +import time +from collections.abc import Callable +from io import BytesIO + +from comfy.cli_args import args +from comfy.model_management import processing_interrupted +from comfy_api.latest import IO + +from .common_exceptions import ProcessingInterrupted + + +def is_processing_interrupted() -> bool: + """Return True if user/runtime requested interruption.""" + return processing_interrupted() + + +def get_node_id(node_cls: type[IO.ComfyNode]) -> str: + return node_cls.hidden.unique_id + + +def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]: + if node_cls.hidden.auth_token_comfy_org: + return {"Authorization": f"Bearer {node_cls.hidden.auth_token_comfy_org}"} + if node_cls.hidden.api_key_comfy_org: + return {"X-API-KEY": node_cls.hidden.api_key_comfy_org} + return {} + + +def default_base_url() -> str: + return getattr(args, 
"comfy_api_base", "https://api.comfy.org") + + +async def sleep_with_interrupt( + seconds: float, + node_cls: type[IO.ComfyNode] | None, + label: str | None = None, + start_ts: float | None = None, + estimated_total: int | None = None, + *, + display_callback: Callable[[type[IO.ComfyNode], str, int, int | None], None] | None = None, +): + """ + Sleep in 1s slices while: + - Checking for interruption (raises ProcessingInterrupted). + - Optionally emitting time progress via display_callback (if provided). + """ + end = time.monotonic() + seconds + while True: + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + now = time.monotonic() + if start_ts is not None and label and display_callback: + with contextlib.suppress(Exception): + display_callback(node_cls, label, int(now - start_ts), estimated_total) + if now >= end: + break + await asyncio.sleep(min(1.0, end - now)) + + +def mimetype_to_extension(mime_type: str) -> str: + """Converts a MIME type to a file extension.""" + return mime_type.split("/")[-1].lower() + + +def get_fs_object_size(path_or_object: str | BytesIO) -> int: + if isinstance(path_or_object, str): + return os.path.getsize(path_or_object) + return len(path_or_object.getvalue()) diff --git a/comfy_api_nodes/util/client.py b/comfy_api_nodes/util/client.py new file mode 100644 index 000000000..bf37cba5f --- /dev/null +++ b/comfy_api_nodes/util/client.py @@ -0,0 +1,947 @@ +import asyncio +import contextlib +import json +import logging +import time +import uuid +from collections.abc import Callable, Iterable +from dataclasses import dataclass +from enum import Enum +from io import BytesIO +from typing import Any, Literal, TypeVar +from urllib.parse import urljoin, urlparse + +import aiohttp +from aiohttp.client_exceptions import ClientError, ContentTypeError +from pydantic import BaseModel + +from comfy import utils +from comfy_api.latest import IO +from server import PromptServer + +from . 
import request_logger +from ._helpers import ( + default_base_url, + get_auth_header, + get_node_id, + is_processing_interrupted, + sleep_with_interrupt, +) +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted + +M = TypeVar("M", bound=BaseModel) + + +class ApiEndpoint: + def __init__( + self, + path: str, + method: Literal["GET", "POST", "PUT", "DELETE", "PATCH"] = "GET", + *, + query_params: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + ): + self.path = path + self.method = method + self.query_params = query_params or {} + self.headers = headers or {} + + +@dataclass +class _RequestConfig: + node_cls: type[IO.ComfyNode] + endpoint: ApiEndpoint + timeout: float + content_type: str + data: dict[str, Any] | None + files: dict[str, Any] | list[tuple[str, Any]] | None + multipart_parser: Callable | None + max_retries: int + retry_delay: float + retry_backoff: float + wait_label: str = "Waiting" + monitor_progress: bool = True + estimated_total: int | None = None + final_label_on_success: str | None = "Completed" + progress_origin_ts: float | None = None + price_extractor: Callable[[dict[str, Any]], float | None] | None = None + + +@dataclass +class _PollUIState: + started: float + status_label: str = "Queued" + is_queued: bool = True + price: float | None = None + estimated_duration: int | None = None + base_processing_elapsed: float = 0.0 # sum of completed active intervals + active_since: float | None = None # start time of current active interval (None if queued) + + +_RETRY_STATUS = {408, 429, 500, 502, 503, 504} +COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"] +FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"] +QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"] + + +async def sync_op( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + response_model: type[M], + price_extractor: Callable[[M | Any], float | None] | None = None, + data: BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Callable | None = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_duration: int | None = None, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, + monitor_progress: bool = True, +) -> M: + raw = await sync_op_raw( + cls, + endpoint, + price_extractor=_wrap_model_extractor(response_model, price_extractor), + data=data, + files=files, + content_type=content_type, + timeout=timeout, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + estimated_duration=estimated_duration, + as_binary=False, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + monitor_progress=monitor_progress, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def poll_op( + cls: type[IO.ComfyNode], + poll_endpoint: ApiEndpoint, + *, + response_model: type[M], + status_extractor: Callable[[M | Any], str | int | None], + progress_extractor: Callable[[M | Any], int | None] | None = None, + price_extractor: 
Callable[[M | Any], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: BaseModel | None = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, + cancel_timeout: float = 10.0, +) -> M: + raw = await poll_op_raw( + cls, + poll_endpoint=poll_endpoint, + status_extractor=_wrap_model_extractor(response_model, status_extractor), + progress_extractor=_wrap_model_extractor(response_model, progress_extractor), + price_extractor=_wrap_model_extractor(response_model, price_extractor), + completed_statuses=completed_statuses, + failed_statuses=failed_statuses, + queued_statuses=queued_statuses, + data=data, + poll_interval=poll_interval, + max_poll_attempts=max_poll_attempts, + timeout_per_poll=timeout_per_poll, + max_retries_per_poll=max_retries_per_poll, + retry_delay_per_poll=retry_delay_per_poll, + retry_backoff_per_poll=retry_backoff_per_poll, + estimated_duration=estimated_duration, + cancel_endpoint=cancel_endpoint, + cancel_timeout=cancel_timeout, + ) + if not isinstance(raw, dict): + raise Exception("Expected JSON response to validate into a Pydantic model, got non-JSON (binary or text).") + return _validate_or_raise(response_model, raw) + + +async def sync_op_raw( + cls: type[IO.ComfyNode], + endpoint: ApiEndpoint, + *, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + data: dict[str, Any] | BaseModel | None = None, + files: dict[str, Any] | list[tuple[str, Any]] | None = None, + content_type: str = "application/json", + timeout: float = 3600.0, + multipart_parser: Callable | None = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str = "Waiting for server", + estimated_duration: int | None = None, + as_binary: bool = False, + final_label_on_success: str | None = "Completed", + progress_origin_ts: float | None = None, + monitor_progress: bool = True, +) -> dict[str, Any] | bytes: + """ + Make a single network request. + - If as_binary=False (default): returns JSON dict (or {'_raw': ''} if non-JSON). + - If as_binary=True: returns bytes. 
+ """ + if isinstance(data, BaseModel): + data = data.model_dump(exclude_none=True) + for k, v in list(data.items()): + if isinstance(v, Enum): + data[k] = v.value + cfg = _RequestConfig( + node_cls=cls, + endpoint=endpoint, + timeout=timeout, + content_type=content_type, + data=data, + files=files, + multipart_parser=multipart_parser, + max_retries=max_retries, + retry_delay=retry_delay, + retry_backoff=retry_backoff, + wait_label=wait_label, + monitor_progress=monitor_progress, + estimated_total=estimated_duration, + final_label_on_success=final_label_on_success, + progress_origin_ts=progress_origin_ts, + price_extractor=price_extractor, + ) + return await _request_base(cfg, expect_binary=as_binary) + + +async def poll_op_raw( + cls: type[IO.ComfyNode], + poll_endpoint: ApiEndpoint, + *, + status_extractor: Callable[[dict[str, Any]], str | int | None], + progress_extractor: Callable[[dict[str, Any]], int | None] | None = None, + price_extractor: Callable[[dict[str, Any]], float | None] | None = None, + completed_statuses: list[str | int] | None = None, + failed_statuses: list[str | int] | None = None, + queued_statuses: list[str | int] | None = None, + data: dict[str, Any] | BaseModel | None = None, + poll_interval: float = 5.0, + max_poll_attempts: int = 120, + timeout_per_poll: float = 120.0, + max_retries_per_poll: int = 3, + retry_delay_per_poll: float = 1.0, + retry_backoff_per_poll: float = 2.0, + estimated_duration: int | None = None, + cancel_endpoint: ApiEndpoint | None = None, + cancel_timeout: float = 10.0, +) -> dict[str, Any]: + """ + Polls an endpoint until the task reaches a terminal state. Displays time while queued/processing, + checks interruption every second, and calls Cancel endpoint (if provided) on interruption. + + Uses default complete, failed and queued states assumption. + + Returns the final JSON response from the poll endpoint. 
+ """ + completed_states = _normalize_statuses(COMPLETED_STATUSES if completed_statuses is None else completed_statuses) + failed_states = _normalize_statuses(FAILED_STATUSES if failed_statuses is None else failed_statuses) + queued_states = _normalize_statuses(QUEUED_STATUSES if queued_statuses is None else queued_statuses) + started = time.monotonic() + consumed_attempts = 0 # counts only non-queued polls + + progress_bar = utils.ProgressBar(100) if progress_extractor else None + last_progress: int | None = None + + state = _PollUIState(started=started, estimated_duration=estimated_duration) + stop_ticker = asyncio.Event() + + async def _ticker(): + """Emit a UI update every second while polling is in progress.""" + try: + while not stop_ticker.is_set(): + if is_processing_interrupted(): + break + now = time.monotonic() + proc_elapsed = state.base_processing_elapsed + ( + (now - state.active_since) if state.active_since is not None else 0.0 + ) + _display_time_progress( + cls, + status=state.status_label, + elapsed_seconds=int(now - state.started), + estimated_total=state.estimated_duration, + price=state.price, + is_queued=state.is_queued, + processing_elapsed_seconds=int(proc_elapsed), + ) + await asyncio.sleep(1.0) + except Exception as exc: + logging.debug("Polling ticker exited: %s", exc) + + ticker_task = asyncio.create_task(_ticker()) + try: + while consumed_attempts < max_poll_attempts: + try: + resp_json = await sync_op_raw( + cls, + poll_endpoint, + data=data, + timeout=timeout_per_poll, + max_retries=max_retries_per_poll, + retry_delay=retry_delay_per_poll, + retry_backoff=retry_backoff_per_poll, + wait_label="Checking", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + if not isinstance(resp_json, dict): + raise Exception("Polling endpoint returned non-JSON response.") + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op_raw( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + + try: + status = _normalize_status_value(status_extractor(resp_json)) + except Exception as e: + logging.error("Status extraction failed: %s", e) + status = None + + if price_extractor: + new_price = price_extractor(resp_json) + if new_price is not None: + state.price = new_price + + if progress_extractor: + new_progress = progress_extractor(resp_json) + if new_progress is not None and last_progress != new_progress: + progress_bar.update_absolute(new_progress, total=100) + last_progress = new_progress + + now_ts = time.monotonic() + is_queued = status in queued_states + + if is_queued: + if state.active_since is not None: # If we just moved from active -> queued, close the active interval + state.base_processing_elapsed += now_ts - state.active_since + state.active_since = None + else: + if state.active_since is None: # If we just moved from queued -> active, open a new active interval + state.active_since = now_ts + + state.is_queued = is_queued + state.status_label = status or ("Queued" if is_queued else "Processing") + if status in completed_states: + if state.active_since is not None: + state.base_processing_elapsed += now_ts - state.active_since + state.active_since = None + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + if progress_bar and last_progress != 100: + 
progress_bar.update_absolute(100, total=100) + + _display_time_progress( + cls, + status=status if status else "Completed", + elapsed_seconds=int(now_ts - started), + estimated_total=estimated_duration, + price=state.price, + is_queued=False, + processing_elapsed_seconds=int(state.base_processing_elapsed), + ) + return resp_json + + if status in failed_states: + msg = f"Task failed: {json.dumps(resp_json)}" + logging.error(msg) + raise Exception(msg) + + try: + await sleep_with_interrupt(poll_interval, cls, None, None, None) + except ProcessingInterrupted: + if cancel_endpoint: + with contextlib.suppress(Exception): + await sync_op_raw( + cls, + cancel_endpoint, + timeout=cancel_timeout, + max_retries=0, + wait_label="Cancelling task", + estimated_duration=None, + as_binary=False, + final_label_on_success=None, + monitor_progress=False, + ) + raise + if not is_queued: + consumed_attempts += 1 + + raise Exception( + f"Polling timed out after {max_poll_attempts} non-queued attempts " + f"(~{int(max_poll_attempts * poll_interval)}s of active polling)." + ) + except ProcessingInterrupted: + raise + except (LocalNetworkError, ApiServerError): + raise + except Exception as e: + raise Exception(f"Polling aborted due to error: {e}") from e + finally: + stop_ticker.set() + with contextlib.suppress(Exception): + await ticker_task + + +def _display_text( + node_cls: type[IO.ComfyNode], + text: str | None, + *, + status: str | int | None = None, + price: float | None = None, +) -> None: + display_lines: list[str] = [] + if status: + display_lines.append(f"Status: {status.capitalize() if isinstance(status, str) else status}") + if price is not None: + p = f"{float(price):,.4f}".rstrip("0").rstrip(".") + if p != "0": + display_lines.append(f"Price: ${p}") + if text is not None: + display_lines.append(text) + if display_lines: + PromptServer.instance.send_progress_text("\n".join(display_lines), get_node_id(node_cls)) + + +def _display_time_progress( + node_cls: type[IO.ComfyNode], + status: str | int | None, + elapsed_seconds: int, + estimated_total: int | None = None, + *, + price: float | None = None, + is_queued: bool | None = None, + processing_elapsed_seconds: int | None = None, +) -> None: + if estimated_total is not None and estimated_total > 0 and is_queued is False: + pe = processing_elapsed_seconds if processing_elapsed_seconds is not None else elapsed_seconds + remaining = max(0, int(estimated_total) - int(pe)) + time_line = f"Time elapsed: {int(elapsed_seconds)}s (~{remaining}s remaining)" + else: + time_line = f"Time elapsed: {int(elapsed_seconds)}s" + _display_text(node_cls, time_line, status=status, price=price) + + +async def _diagnose_connectivity() -> dict[str, bool]: + """Best-effort connectivity diagnostics to distinguish local vs. 
server issues.""" + results = { + "internet_accessible": False, + "api_accessible": False, + } + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + with contextlib.suppress(ClientError, OSError): + async with session.get("https://www.google.com") as resp: + results["internet_accessible"] = resp.status < 500 + if not results["internet_accessible"]: + return results + + parsed = urlparse(default_base_url()) + health_url = f"{parsed.scheme}://{parsed.netloc}/health" + with contextlib.suppress(ClientError, OSError): + async with session.get(health_url) as resp: + results["api_accessible"] = resp.status < 500 + return results + + +def _unpack_tuple(t: tuple) -> tuple[str, Any, str]: + """Normalize (filename, value, content_type).""" + if len(t) == 2: + return t[0], t[1], "application/octet-stream" + if len(t) == 3: + return t[0], t[1], t[2] + raise ValueError("files tuple must be (filename, file[, content_type])") + + +def _merge_params(endpoint_params: dict[str, Any], method: str, data: dict[str, Any] | None) -> dict[str, Any]: + params = dict(endpoint_params or {}) + if method.upper() == "GET" and data: + for k, v in data.items(): + if v is not None: + params[k] = v + return params + + +def _friendly_http_message(status: int, body: Any) -> str: + if status == 401: + return "Unauthorized: Please login first to use this node." + if status == 402: + return "Payment Required: Please add credits to your account to use this node." + if status == 409: + return "There is a problem with your account. Please contact support@comfy.org." + if status == 429: + return "Rate Limit Exceeded: Please try again later." + try: + if isinstance(body, dict): + err = body.get("error") + if isinstance(err, dict): + msg = err.get("message") + typ = err.get("type") + if msg and typ: + return f"API Error: {msg} (Type: {typ})" + if msg: + return f"API Error: {msg}" + return f"API Error: {json.dumps(body)}" + else: + txt = str(body) + if len(txt) <= 200: + return f"API Error (raw): {txt}" + return f"API Error (status {status})" + except Exception: + return f"HTTP {status}: Unknown error" + + +def _generate_operation_id(method: str, path: str, attempt: int) -> str: + slug = path.strip("/").replace("/", "_") or "op" + return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" + + +def _snapshot_request_body_for_logging( + content_type: str, + method: str, + data: dict[str, Any] | None, + files: dict[str, Any] | list[tuple[str, Any]] | None, +) -> dict[str, Any] | str | None: + if method.upper() == "GET": + return None + if content_type == "multipart/form-data": + form_fields = sorted([k for k, v in (data or {}).items() if v is not None]) + file_fields: list[dict[str, str]] = [] + if files: + file_iter = files if isinstance(files, list) else list(files.items()) + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename = file_obj[0] + else: + filename = getattr(file_obj, "name", field_name) + file_fields.append({"field": field_name, "filename": str(filename or "")}) + return {"_multipart": True, "form_fields": form_fields, "file_fields": file_fields} + if content_type == "application/x-www-form-urlencoded": + return data or {} + return data or {} + + +async def _request_base(cfg: _RequestConfig, expect_binary: bool): + """Core request with retries, per-second interruption monitoring, true cancellation, and friendly errors.""" + url = cfg.endpoint.path + parsed_url = urlparse(url) + if not 
parsed_url.scheme and not parsed_url.netloc: # is URL relative? + url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) + + method = cfg.endpoint.method + params = _merge_params(cfg.endpoint.query_params, method, cfg.data if method == "GET" else None) + + async def _monitor(stop_evt: asyncio.Event, start_ts: float): + """Every second: update elapsed time and signal interruption.""" + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(time.monotonic() - start_ts), cfg.estimated_total + ) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return # normal shutdown + + start_time = cfg.progress_origin_ts if cfg.progress_origin_ts is not None else time.monotonic() + attempt = 0 + delay = cfg.retry_delay + operation_succeeded: bool = False + final_elapsed_seconds: int | None = None + extracted_price: float | None = None + while True: + attempt += 1 + stop_event = asyncio.Event() + monitor_task: asyncio.Task | None = None + sess: aiohttp.ClientSession | None = None + + operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt) + logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt) + + payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"} + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + payload_headers.update(get_auth_header(cfg.node_cls)) + if cfg.endpoint.headers: + payload_headers.update(cfg.endpoint.headers) + + payload_kw: dict[str, Any] = {"headers": payload_headers} + if method == "GET": + payload_headers.pop("Content-Type", None) + request_body_log = _snapshot_request_body_for_logging(cfg.content_type, method, cfg.data, cfg.files) + try: + if cfg.monitor_progress: + monitor_task = asyncio.create_task(_monitor(stop_event, start_time)) + + timeout = aiohttp.ClientTimeout(total=cfg.timeout) + sess = aiohttp.ClientSession(timeout=timeout) + + if cfg.content_type == "multipart/form-data" and method != "GET": + # aiohttp will set Content-Type boundary; remove any fixed Content-Type + payload_headers.pop("Content-Type", None) + if cfg.multipart_parser and cfg.data: + form = cfg.multipart_parser(cfg.data) + if not isinstance(form, aiohttp.FormData): + raise ValueError("multipart_parser must return aiohttp.FormData") + else: + form = aiohttp.FormData(default_to_multipart=True) + if cfg.data: + for k, v in cfg.data.items(): + if v is None: + continue + form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + if cfg.files: + file_iter = cfg.files if isinstance(cfg.files, list) else cfg.files.items() + for field_name, file_obj in file_iter: + if file_obj is None: + continue + if isinstance(file_obj, tuple): + filename, file_value, content_type = _unpack_tuple(file_obj) + else: + filename = getattr(file_obj, "name", field_name) + file_value = file_obj + content_type = "application/octet-stream" + # Attempt to rewind BytesIO for retries + if isinstance(file_value, BytesIO): + with contextlib.suppress(Exception): + file_value.seek(0) + form.add_field(field_name, file_value, filename=filename, content_type=content_type) + payload_kw["data"] = form + elif cfg.content_type == "application/x-www-form-urlencoded" and method != "GET": + payload_headers["Content-Type"] = "application/x-www-form-urlencoded" + payload_kw["data"] = cfg.data or {} + elif method != "GET": + payload_headers["Content-Type"] = "application/json" + payload_kw["json"] = cfg.data or {} + 
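+            # Log a snapshot of the outgoing request (multipart bodies are summarized), then race the
+            # HTTP call against the interruption monitor (when enabled) so a user interrupt can cancel
+            # the in-flight request.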
+ try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] request logging failed: %s", _log_e) + + req_coro = sess.request(method, url, params=params, **payload_kw) + req_task = asyncio.create_task(req_coro) + + # Race: request vs. monitor (interruption) + tasks = {req_task} + if monitor_task: + tasks.add(monitor_task) + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task and monitor_task in done: + # Interrupted – cancel the request and abort + if req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Task cancelled") + + # Otherwise, request finished + resp = await req_task + async with resp: + if resp.status >= 400: + try: + body = await resp.json() + except (ContentTypeError, json.JSONDecodeError): + body = await resp.text() + if resp.status in _RETRY_STATUS and attempt <= cfg.max_retries: + logging.warning( + "HTTP %s %s -> %s. Retrying in %.2fs (retry %d of %d).", + method, + url, + resp.status, + delay, + attempt, + cfg.max_retries, + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=_friendly_http_message(resp.status, body), + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + + await sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + msg = _friendly_http_message(resp.status, body) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + raise Exception(msg) + + if expect_binary: + buff = bytearray() + last_tick = time.monotonic() + async for chunk in resp.content.iter_chunked(64 * 1024): + buff.extend(chunk) + now = time.monotonic() + if now - last_tick >= 1.0: + last_tick = now + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + if cfg.monitor_progress: + _display_time_progress( + cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total + ) + bytes_payload = bytes(buff) + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=bytes_payload, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return bytes_payload + else: + try: + payload = await resp.json() + response_content_to_log: Any = payload + except (ContentTypeError, json.JSONDecodeError): + text = await resp.text() + try: + payload = json.loads(text) if text else {} + except json.JSONDecodeError: + payload = 
{"_raw": text} + response_content_to_log = payload if isinstance(payload, dict) else text + with contextlib.suppress(Exception): + extracted_price = cfg.price_extractor(payload) if cfg.price_extractor else None + operation_succeeded = True + final_elapsed_seconds = int(time.monotonic() - start_time) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=response_content_to_log, + ) + except Exception as _log_e: + logging.debug("[DEBUG] response logging failed: %s", _log_e) + return payload + + except ProcessingInterrupted: + logging.debug("Polling was interrupted by user") + raise + except (ClientError, OSError) as e: + if attempt <= cfg.max_retries: + logging.warning( + "Connection error calling %s %s. Retrying in %.2fs (%d/%d): %s", + method, + url, + delay, + attempt, + cfg.max_retries, + str(e), + ) + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + except Exception as _log_e: + logging.debug("[DEBUG] request error logging failed: %s", _log_e) + await sleep_with_interrupt( + delay, + cfg.node_cls, + cfg.wait_label if cfg.monitor_progress else None, + start_time if cfg.monitor_progress else None, + cfg.estimated_total, + display_callback=_display_time_progress if cfg.monitor_progress else None, + ) + delay *= cfg.retry_backoff + continue + diag = await _diagnose_connectivity() + if not diag["internet_accessible"]: + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"LocalNetworkError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + request_headers=dict(payload_headers) if payload_headers else None, + request_params=dict(params) if params else None, + request_data=request_body_log, + error_message=f"ApiServerError: {str(e)}", + ) + except Exception as _log_e: + logging.debug("[DEBUG] final error logging failed: %s", _log_e) + raise ApiServerError( + f"The API server at {default_base_url()} is currently unreachable. " + f"The service may be experiencing issues." 
+ ) from e + finally: + stop_event.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + if operation_succeeded and cfg.monitor_progress and cfg.final_label_on_success: + _display_time_progress( + cfg.node_cls, + status=cfg.final_label_on_success, + elapsed_seconds=( + final_elapsed_seconds + if final_elapsed_seconds is not None + else int(time.monotonic() - start_time) + ), + estimated_total=cfg.estimated_total, + price=extracted_price, + is_queued=False, + processing_elapsed_seconds=final_elapsed_seconds, + ) + + +def _validate_or_raise(response_model: type[M], payload: Any) -> M: + try: + return response_model.model_validate(payload) + except Exception as e: + logging.error( + "Response validation failed for %s: %s", + getattr(response_model, "__name__", response_model), + e, + ) + raise Exception( + f"Response validation failed for {getattr(response_model, '__name__', response_model)}: {e}" + ) from e + + +def _wrap_model_extractor( + response_model: type[M], + extractor: Callable[[M], Any] | None, +) -> Callable[[dict[str, Any]], Any] | None: + """Wrap a typed extractor so it can be used by the dict-based poller. + Validates the dict into `response_model` before invoking `extractor`. + Uses a small per-wrapper cache keyed by `id(dict)` to avoid re-validating + the same response for multiple extractors in a single poll attempt. + """ + if extractor is None: + return None + _cache: dict[int, M] = {} + + def _wrapped(d: dict[str, Any]) -> Any: + try: + key = id(d) + model = _cache.get(key) + if model is None: + model = response_model.model_validate(d) + _cache[key] = model + return extractor(model) + except Exception as e: + logging.error("Extractor failed (typed -> dict wrapper): %s", e) + raise + + return _wrapped + + +def _normalize_statuses(values: Iterable[str | int] | None) -> set[str | int]: + if not values: + return set() + out: set[str | int] = set() + for v in values: + nv = _normalize_status_value(v) + if nv is not None: + out.add(nv) + return out + + +def _normalize_status_value(val: str | int | None) -> str | int | None: + if isinstance(val, str): + return val.strip().lower() + return val diff --git a/comfy_api_nodes/util/common_exceptions.py b/comfy_api_nodes/util/common_exceptions.py new file mode 100644 index 000000000..0606a4407 --- /dev/null +++ b/comfy_api_nodes/util/common_exceptions.py @@ -0,0 +1,14 @@ +class NetworkError(Exception): + """Base exception for network-related errors with diagnostic information.""" + + +class LocalNetworkError(NetworkError): + """Exception raised when local network connectivity issues are detected.""" + + +class ApiServerError(NetworkError): + """Exception raised when the API server is unreachable but internet is working.""" + + +class ProcessingInterrupted(Exception): + """Operation was interrupted by user/runtime via processing_interrupted().""" diff --git a/comfy_api_nodes/util/conversions.py b/comfy_api_nodes/util/conversions.py new file mode 100644 index 000000000..c57457580 --- /dev/null +++ b/comfy_api_nodes/util/conversions.py @@ -0,0 +1,467 @@ +import base64 +import logging +import math +import mimetypes +import uuid +from io import BytesIO + +import av +import numpy as np +import torch +from PIL import Image + +from comfy.utils import common_upscale +from comfy_api.latest import Input, InputImpl, Types + +from ._helpers import mimetype_to_extension + + +def bytesio_to_image_tensor(image_bytesio: 
BytesIO, mode: str = "RGBA") -> torch.Tensor: + """Converts image data from BytesIO to a torch.Tensor. + + Args: + image_bytesio: BytesIO object containing the image data. + mode: The PIL mode to convert the image to (e.g., "RGB", "RGBA"). + + Returns: + A torch.Tensor representing the image (1, H, W, C). + + Raises: + PIL.UnidentifiedImageError: If the image data cannot be identified. + ValueError: If the specified mode is invalid. + """ + image = Image.open(image_bytesio) + image = image.convert(mode) + image_array = np.array(image).astype(np.float32) / 255.0 + return torch.from_numpy(image_array).unsqueeze(0) + + +def image_tensor_pair_to_batch(image1: torch.Tensor, image2: torch.Tensor) -> torch.Tensor: + """ + Converts a pair of image tensors to a batch tensor. + If the images are not the same size, the smaller image is resized to + match the larger image. + """ + if image1.shape[1:] != image2.shape[1:]: + image2 = common_upscale( + image2.movedim(-1, 1), + image1.shape[2], + image1.shape[1], + "bilinear", + "center", + ).movedim(1, -1) + return torch.cat((image1, image2), dim=0) + + +def tensor_to_bytesio( + image: torch.Tensor, + name: str | None = None, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> BytesIO: + """Converts a torch.Tensor image to a named BytesIO object. + + Args: + image: Input torch.Tensor image. + name: Optional filename for the BytesIO object. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Named BytesIO object containing the image data, with pointer set to the start of buffer. + """ + if not mime_type: + mime_type = "image/png" + + pil_image = tensor_to_pil(image, total_pixels=total_pixels) + img_binary = pil_to_bytesio(pil_image, mime_type=mime_type) + img_binary.name = f"{name if name else uuid.uuid4()}.{mimetype_to_extension(mime_type)}" + return img_binary + + +def tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: + """Converts a single torch.Tensor image [H, W, C] to a PIL Image, optionally downscaling.""" + if len(image.shape) > 3: + image = image[0] + # TODO: remove alpha if not allowed and present + input_tensor = image.cpu() + input_tensor = downscale_image_tensor(input_tensor.unsqueeze(0), total_pixels=total_pixels).squeeze() + image_np = (input_tensor.numpy() * 255).astype(np.uint8) + img = Image.fromarray(image_np) + return img + + +def tensor_to_base64_string( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Convert [B, H, W, C] or [H, W, C] tensor to a base64 string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp', 'video/mp4'). + + Returns: + Base64 encoded string of the image. 
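+
+    Example (illustrative sketch):
+        b64 = tensor_to_base64_string(image_tensor, mime_type="image/jpeg")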
+ """ + pil_image = tensor_to_pil(image_tensor, total_pixels=total_pixels) + img_byte_arr = pil_to_bytesio(pil_image, mime_type=mime_type) + img_bytes = img_byte_arr.getvalue() + # Encode bytes to base64 string + base64_encoded_string = base64.b64encode(img_bytes).decode("utf-8") + return base64_encoded_string + + +def pil_to_bytesio(img: Image.Image, mime_type: str = "image/png") -> BytesIO: + """Converts a PIL Image to a BytesIO object.""" + if not mime_type: + mime_type = "image/png" + + img_byte_arr = BytesIO() + # Derive PIL format from MIME type (e.g., 'image/png' -> 'PNG') + pil_format = mime_type.split("/")[-1].upper() + if pil_format == "JPG": + pil_format = "JPEG" + img.save(img_byte_arr, format=pil_format) + img_byte_arr.seek(0) + return img_byte_arr + + +def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: + """Downscale input image tensor to roughly the specified total pixels.""" + samples = image.movedim(-1, 1) + total = int(total_pixels) + scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2])) + if scale_by >= 1: + return image + width = round(samples.shape[3] * scale_by) + height = round(samples.shape[2] * scale_by) + + s = common_upscale(samples, width, height, "lanczos", "disabled") + s = s.movedim(1, -1) + return s + + +def tensor_to_data_uri( + image_tensor: torch.Tensor, + total_pixels: int = 2048 * 2048, + mime_type: str = "image/png", +) -> str: + """Converts a tensor image to a Data URI string. + + Args: + image_tensor: Input torch.Tensor image. + total_pixels: Maximum total pixels for potential downscaling. + mime_type: Target image MIME type (e.g., 'image/png', 'image/jpeg', 'image/webp'). + + Returns: + Data URI string (e.g., 'data:image/png;base64,...'). + """ + base64_string = tensor_to_base64_string(image_tensor, total_pixels, mime_type) + return f"data:{mime_type};base64,{base64_string}" + + +def audio_to_base64_string(audio: Input.Audio, container_format: str = "mp4", codec_name: str = "aac") -> str: + """Converts an audio input to a base64 string.""" + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) + audio_bytes = audio_bytes_io.getvalue() + return base64.b64encode(audio_bytes).decode("utf-8") + + +def video_to_base64_string( + video: Input.Video, + container_format: Types.VideoContainer | None = None, + codec: Types.VideoCodec | None = None, +) -> str: + """ + Converts a video input to a base64 string. + + Args: + video: The video input to convert + container_format: Optional container format to use (defaults to video.container if available) + codec: Optional codec to use (defaults to video.codec if available) + """ + video_bytes_io = BytesIO() + video.save_to( + video_bytes_io, + format=container_format or getattr(video, "container", Types.VideoContainer.MP4), + codec=codec or getattr(video, "codec", Types.VideoCodec.H264), + ) + video_bytes_io.seek(0) + return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") + + +def audio_ndarray_to_bytesio( + audio_data_np: np.ndarray, + sample_rate: int, + container_format: str = "mp4", + codec_name: str = "aac", +) -> BytesIO: + """ + Encodes a numpy array of audio data into a BytesIO object. 
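+
+    The samples are wrapped in a single planar-float ("fltp") AudioFrame, encoded with the
+    requested codec, and the encoder is flushed before the buffer is rewound to the start.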
+ """ + audio_bytes_io = BytesIO() + with av.open(audio_bytes_io, mode="w", format=container_format) as output_container: + audio_stream = output_container.add_stream(codec_name, rate=sample_rate) + frame = av.AudioFrame.from_ndarray( + audio_data_np, + format="fltp", + layout="stereo" if audio_data_np.shape[0] > 1 else "mono", + ) + frame.sample_rate = sample_rate + frame.pts = 0 + + for packet in audio_stream.encode(frame): + output_container.mux(packet) + + # Flush stream + for packet in audio_stream.encode(None): + output_container.mux(packet) + + audio_bytes_io.seek(0) + return audio_bytes_io + + +def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: + """ + Prepares audio waveform for av library by converting to a contiguous numpy array. + + Args: + waveform: a tensor of shape (1, channels, samples) derived from a Comfy `AUDIO` type. + + Returns: + Contiguous numpy array of the audio waveform. If the audio was batched, + the first item is taken. + """ + if waveform.ndim != 3 or waveform.shape[0] != 1: + raise ValueError("Expected waveform tensor shape (1, channels, samples)") + + # If batch is > 1, take first item + if waveform.shape[0] > 1: + waveform = waveform[0] + + # Prepare for av: remove batch dim, move to CPU, make contiguous, convert to numpy array + audio_data_np = waveform.squeeze(0).cpu().contiguous().numpy() + if audio_data_np.dtype != np.float32: + audio_data_np = audio_data_np.astype(np.float32) + + return audio_data_np + + +def audio_input_to_mp3(audio: Input.Audio) -> BytesIO: + waveform = audio["waveform"].cpu() + + output_buffer = BytesIO() + output_container = av.open(output_buffer, mode="w", format="mp3") + + out_stream = output_container.add_stream("libmp3lame", rate=audio["sample_rate"]) + out_stream.bit_rate = 320000 + + frame = av.AudioFrame.from_ndarray( + waveform.movedim(0, 1).reshape(1, -1).float().numpy(), + format="flt", + layout="mono" if waveform.shape[0] == 1 else "stereo", + ) + frame.sample_rate = audio["sample_rate"] + frame.pts = 0 + output_container.mux(out_stream.encode(frame)) + output_container.mux(out_stream.encode(None)) + output_container.close() + output_buffer.seek(0) + return output_buffer + + +def trim_video(video: Input.Video, duration_sec: float) -> Input.Video: + """ + Returns a new VideoInput object trimmed from the beginning to the specified duration, + using av to avoid loading entire video into memory. 
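+
+    Note: the output is re-encoded (h264 video / aac audio) and the kept frame count is rounded
+    down to a multiple of 16; clips shorter than 16 frames raise a ValueError.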
+ + Args: + video: Input video to trim + duration_sec: Duration in seconds to keep from the beginning + + Returns: + VideoFromFile object that owns the output buffer + """ + output_buffer = BytesIO() + input_container = None + output_container = None + + try: + # Get the stream source - this avoids loading entire video into memory + # when the source is already a file path + input_source = video.get_stream_source() + + # Open containers + input_container = av.open(input_source, mode="r") + output_container = av.open(output_buffer, mode="w", format="mp4") + + # Set up output streams for re-encoding + video_stream = None + audio_stream = None + + for stream in input_container.streams: + logging.info("Found stream: type=%s, class=%s", stream.type, type(stream)) + if isinstance(stream, av.VideoStream): + # Create output video stream with same parameters + video_stream = output_container.add_stream("h264", rate=stream.average_rate) + video_stream.width = stream.width + video_stream.height = stream.height + video_stream.pix_fmt = "yuv420p" + logging.info("Added video stream: %sx%s @ %sfps", stream.width, stream.height, stream.average_rate) + elif isinstance(stream, av.AudioStream): + # Create output audio stream with same parameters + audio_stream = output_container.add_stream("aac", rate=stream.sample_rate) + audio_stream.sample_rate = stream.sample_rate + audio_stream.layout = stream.layout + logging.info("Added audio stream: %sHz, %s channels", stream.sample_rate, stream.channels) + + # Calculate target frame count that's divisible by 16 + fps = input_container.streams.video[0].average_rate + estimated_frames = int(duration_sec * fps) + target_frames = (estimated_frames // 16) * 16 # Round down to nearest multiple of 16 + + if target_frames == 0: + raise ValueError("Video too short: need at least 16 frames for Moonvalley") + + frame_count = 0 + audio_frame_count = 0 + + # Decode and re-encode video frames + if video_stream: + for frame in input_container.decode(video=0): + if frame_count >= target_frames: + break + + # Re-encode frame + for packet in video_stream.encode(frame): + output_container.mux(packet) + frame_count += 1 + + # Flush encoder + for packet in video_stream.encode(): + output_container.mux(packet) + + logging.info("Encoded %s video frames (target: %s)", frame_count, target_frames) + + # Decode and re-encode audio frames + if audio_stream: + input_container.seek(0) # Reset to beginning for audio + for frame in input_container.decode(audio=0): + if frame.time >= duration_sec: + break + + # Re-encode frame + for packet in audio_stream.encode(frame): + output_container.mux(packet) + audio_frame_count += 1 + + # Flush encoder + for packet in audio_stream.encode(): + output_container.mux(packet) + + logging.info("Encoded %s audio frames", audio_frame_count) + + # Close containers + output_container.close() + input_container.close() + + # Return as VideoFromFile using the buffer + output_buffer.seek(0) + return InputImpl.VideoFromFile(output_buffer) + + except Exception as e: + # Clean up on error + if input_container is not None: + input_container.close() + if output_container is not None: + output_container.close() + raise RuntimeError(f"Failed to trim video: {str(e)}") from e + + +def _f32_pcm(wav: torch.Tensor) -> torch.Tensor: + """Convert audio to float 32 bits PCM format. 
Copy-paste from nodes_audio.py file.""" + if wav.dtype.is_floating_point: + return wav + elif wav.dtype == torch.int16: + return wav.float() / (2**15) + elif wav.dtype == torch.int32: + return wav.float() / (2**31) + raise ValueError(f"Unsupported wav dtype: {wav.dtype}") + + +def audio_bytes_to_audio_input(audio_bytes: bytes) -> dict: + """ + Decode any common audio container from bytes using PyAV and return + a Comfy AUDIO dict: {"waveform": [1, C, T] float32, "sample_rate": int}. + """ + with av.open(BytesIO(audio_bytes)) as af: + if not af.streams.audio: + raise ValueError("No audio stream found in response.") + stream = af.streams.audio[0] + + in_sr = int(stream.codec_context.sample_rate) + out_sr = in_sr + + frames: list[torch.Tensor] = [] + n_channels = stream.channels or 1 + + for frame in af.decode(streams=stream.index): + arr = frame.to_ndarray() # shape can be [C, T] or [T, C] or [T] + buf = torch.from_numpy(arr) + if buf.ndim == 1: + buf = buf.unsqueeze(0) # [T] -> [1, T] + elif buf.shape[0] != n_channels and buf.shape[-1] == n_channels: + buf = buf.transpose(0, 1).contiguous() # [T, C] -> [C, T] + elif buf.shape[0] != n_channels: + buf = buf.reshape(-1, n_channels).t().contiguous() # fallback to [C, T] + frames.append(buf) + + if not frames: + raise ValueError("Decoded zero audio frames.") + + wav = torch.cat(frames, dim=1) # [C, T] + wav = _f32_pcm(wav) + return {"waveform": wav.unsqueeze(0).contiguous(), "sample_rate": out_sr} + + +def resize_mask_to_image( + mask: torch.Tensor, + image: torch.Tensor, + upscale_method="nearest-exact", + crop="disabled", + allow_gradient=True, + add_channel_dim=False, +): + """Resize mask to be the same dimensions as an image, while maintaining proper format for API calls.""" + _, height, width, _ = image.shape + mask = mask.unsqueeze(-1) + mask = mask.movedim(-1, 1) + mask = common_upscale(mask, width=width, height=height, upscale_method=upscale_method, crop=crop) + mask = mask.movedim(1, -1) + if not add_channel_dim: + mask = mask.squeeze(-1) + if not allow_gradient: + mask = (mask > 0.5).float() + return mask + + +def text_filepath_to_base64_string(filepath: str) -> str: + """Converts a text file to a base64 string.""" + with open(filepath, "rb") as f: + file_content = f.read() + return base64.b64encode(file_content).decode("utf-8") + + +def text_filepath_to_data_uri(filepath: str) -> str: + """Converts a text file to a data URI.""" + base64_string = text_filepath_to_base64_string(filepath) + mime_type, _ = mimetypes.guess_type(filepath) + if mime_type is None: + mime_type = "application/octet-stream" + return f"data:{mime_type};base64,{base64_string}" diff --git a/comfy_api_nodes/util/download_helpers.py b/comfy_api_nodes/util/download_helpers.py new file mode 100644 index 000000000..3e0d0352d --- /dev/null +++ b/comfy_api_nodes/util/download_helpers.py @@ -0,0 +1,262 @@ +import asyncio +import contextlib +import uuid +from io import BytesIO +from pathlib import Path +from typing import IO +from urllib.parse import urljoin, urlparse + +import aiohttp +import torch +from aiohttp.client_exceptions import ClientError, ContentTypeError + +from comfy_api.latest import IO as COMFY_IO +from comfy_api.latest import InputImpl + +from . 
import request_logger +from ._helpers import ( + default_base_url, + get_auth_header, + is_processing_interrupted, + sleep_with_interrupt, +) +from .client import _diagnose_connectivity +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted +from .conversions import bytesio_to_image_tensor + +_RETRY_STATUS = {408, 429, 500, 502, 503, 504} + + +async def download_url_to_bytesio( + url: str, + dest: BytesIO | IO[bytes] | str | Path | None, + *, + timeout: float | None = None, + max_retries: int = 5, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + cls: type[COMFY_IO.ComfyNode] = None, +) -> None: + """Stream-download a URL to `dest`. + + `dest` must be one of: + - a BytesIO (rewound to 0 after write), + - a file-like object opened in binary write mode (must implement .write()), + - a filesystem path (str | pathlib.Path), which will be opened with 'wb'. + + If `url` starts with `/proxy/`, `cls` must be provided so the URL can be expanded + to an absolute URL and authentication headers can be applied. + + Raises: + ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception (HTTP and other errors) + """ + if not isinstance(dest, (str, Path)) and not hasattr(dest, "write"): + raise ValueError("dest must be a path (str|Path) or a binary-writable object providing .write().") + + attempt = 0 + delay = retry_delay + headers: dict[str, str] = {} + + parsed_url = urlparse(url) + if not parsed_url.scheme and not parsed_url.netloc: # is URL relative? + if cls is None: + raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.") + url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/")) + headers = get_auth_header(cls) + + while True: + attempt += 1 + op_id = _generate_operation_id("GET", url, attempt) + timeout_cfg = aiohttp.ClientTimeout(total=timeout) + + is_path_sink = isinstance(dest, (str, Path)) + fhandle = None + session: aiohttp.ClientSession | None = None + stop_evt: asyncio.Event | None = None + monitor_task: asyncio.Task | None = None + req_task: asyncio.Task | None = None + + try: + with contextlib.suppress(Exception): + request_logger.log_request_response(operation_id=op_id, request_method="GET", request_url=url) + + session = aiohttp.ClientSession(timeout=timeout_cfg) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task = asyncio.create_task(_monitor()) + + req_task = asyncio.create_task(session.get(url, headers=headers)) + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task in done and req_task in pending: + req_task.cancel() + with contextlib.suppress(Exception): + await req_task + raise ProcessingInterrupted("Task cancelled") + + try: + resp = await req_task + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + + async with resp: + if resp.status >= 400: + with contextlib.suppress(Exception): + try: + body = await resp.json() + except (ContentTypeError, ValueError): + text = await resp.text() + body = text if len(text) <= 4096 else f"[text {len(text)} bytes]" + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=f"HTTP {resp.status}", + ) + + if resp.status in 
_RETRY_STATUS and attempt <= max_retries: + await sleep_with_interrupt(delay, cls, None, None, None) + delay *= retry_backoff + continue + raise Exception(f"Failed to download (HTTP {resp.status}).") + + if is_path_sink: + p = Path(str(dest)) + with contextlib.suppress(Exception): + p.parent.mkdir(parents=True, exist_ok=True) + fhandle = open(p, "wb") + sink = fhandle + else: + sink = dest # BytesIO or file-like + + written = 0 + while True: + try: + chunk = await asyncio.wait_for(resp.content.read(1024 * 1024), timeout=1.0) + except asyncio.TimeoutError: + chunk = b"" + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + + if is_processing_interrupted(): + raise ProcessingInterrupted("Task cancelled") + + if not chunk: + if resp.content.at_eof(): + break + continue + + sink.write(chunk) + written += len(chunk) + + if isinstance(dest, BytesIO): + with contextlib.suppress(Exception): + dest.seek(0) + + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=f"[streamed {written} bytes to dest]", + ) + return + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + except (ClientError, OSError) as e: + if attempt <= max_retries: + with contextlib.suppress(Exception): + request_logger.log_request_response( + operation_id=op_id, + request_method="GET", + request_url=url, + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + await sleep_with_interrupt(delay, cls, None, None, None) + delay *= retry_backoff + continue + + diag = await _diagnose_connectivity() + if not diag["internet_accessible"]: + raise LocalNetworkError( + "Unable to connect to the network. Please check your internet connection and try again." 
+ ) from e + raise ApiServerError("The remote service appears unreachable at this time.") from e + finally: + if stop_evt is not None: + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if req_task and not req_task.done(): + req_task.cancel() + with contextlib.suppress(Exception): + await req_task + if session: + with contextlib.suppress(Exception): + await session.close() + if fhandle: + with contextlib.suppress(Exception): + fhandle.flush() + fhandle.close() + + +async def download_url_to_image_tensor( + url: str, + *, + timeout: float = None, + cls: type[COMFY_IO.ComfyNode] = None, +) -> torch.Tensor: + """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" + result = BytesIO() + await download_url_to_bytesio(url, result, timeout=timeout, cls=cls) + return bytesio_to_image_tensor(result) + + +async def download_url_to_video_output( + video_url: str, + *, + timeout: float = None, + max_retries: int = 5, + cls: type[COMFY_IO.ComfyNode] = None, +) -> InputImpl.VideoFromFile: + """Downloads a video from a URL and returns a `VIDEO` output.""" + result = BytesIO() + await download_url_to_bytesio(video_url, result, timeout=timeout, max_retries=max_retries, cls=cls) + return InputImpl.VideoFromFile(result) + + +async def download_url_as_bytesio( + url: str, + *, + timeout: float = None, + cls: type[COMFY_IO.ComfyNode] = None, +) -> BytesIO: + """Downloads content from a URL and returns a new BytesIO (rewound to 0).""" + result = BytesIO() + await download_url_to_bytesio(url, result, timeout=timeout, cls=cls) + return result + + +def _generate_operation_id(method: str, url: str, attempt: int) -> str: + try: + parsed = urlparse(url) + slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "download").strip("/").replace("/", "_") + except Exception: + slug = "download" + return f"{method}_{slug}_try{attempt}_{uuid.uuid4().hex[:8]}" diff --git a/comfy_api_nodes/apis/request_logger.py b/comfy_api_nodes/util/request_logger.py similarity index 99% rename from comfy_api_nodes/apis/request_logger.py rename to comfy_api_nodes/util/request_logger.py index 33d07040a..5337f4d0e 100644 --- a/comfy_api_nodes/apis/request_logger.py +++ b/comfy_api_nodes/util/request_logger.py @@ -1,13 +1,11 @@ -from __future__ import annotations - -import os import datetime +import hashlib import json import logging -from comfy.cmd import folder_paths +import os import re -import hashlib from typing import Any +from comfy.cmd import folder_paths # Get the logger instance logger = logging.getLogger(__name__) diff --git a/comfy_api_nodes/util/upload_helpers.py b/comfy_api_nodes/util/upload_helpers.py new file mode 100644 index 000000000..b8d33f4d1 --- /dev/null +++ b/comfy_api_nodes/util/upload_helpers.py @@ -0,0 +1,338 @@ +import asyncio +import contextlib +import logging +import time +import uuid +from io import BytesIO +from urllib.parse import urlparse + +import aiohttp +import torch +from pydantic import BaseModel, Field + +from comfy_api.latest import IO, Input, Types + +from . 
import request_logger +from ._helpers import is_processing_interrupted, sleep_with_interrupt +from .client import ( + ApiEndpoint, + _diagnose_connectivity, + _display_time_progress, + sync_op, +) +from .common_exceptions import ApiServerError, LocalNetworkError, ProcessingInterrupted +from .conversions import ( + audio_ndarray_to_bytesio, + audio_tensor_to_contiguous_ndarray, + tensor_to_bytesio, +) + + +class UploadRequest(BaseModel): + file_name: str = Field(..., description="Filename to upload") + content_type: str | None = Field( + None, + description="Mime type of the file. For example: image/png, image/jpeg, video/mp4, etc.", + ) + + +class UploadResponse(BaseModel): + download_url: str = Field(..., description="URL to GET uploaded file") + upload_url: str = Field(..., description="URL to PUT file to upload") + + +async def upload_images_to_comfyapi( + cls: type[IO.ComfyNode], + image: torch.Tensor, + *, + max_images: int = 8, + mime_type: str | None = None, + wait_label: str | None = "Uploading", + show_batch_index: bool = True, +) -> list[str]: + """ + Uploads images to ComfyUI API and returns download URLs. + To upload multiple images, stack them in the batch dimension first. + """ + # if batched, try to upload each file if max_images is greater than 0 + download_urls: list[str] = [] + is_batch = len(image.shape) > 3 + batch_len = image.shape[0] if is_batch else 1 + num_to_upload = min(batch_len, max_images) + batch_start_ts = time.monotonic() + + for idx in range(num_to_upload): + tensor = image[idx] if is_batch else image + img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + + effective_label = wait_label + if wait_label and show_batch_index and num_to_upload > 1: + effective_label = f"{wait_label} ({idx + 1}/{num_to_upload})" + + url = await upload_file_to_comfyapi(cls, img_io, img_io.name, mime_type, effective_label, batch_start_ts) + download_urls.append(url) + return download_urls + + +async def upload_audio_to_comfyapi( + cls: type[IO.ComfyNode], + audio: Input.Audio, + *, + container_format: str = "mp4", + codec_name: str = "aac", + mime_type: str = "audio/mp4", + filename: str = "uploaded_audio.mp4", +) -> str: + """ + Uploads a single audio input to ComfyUI API and returns its download URL. + Encodes the raw waveform into the specified format before uploading. + """ + sample_rate: int = audio["sample_rate"] + waveform: torch.Tensor = audio["waveform"] + audio_data_np = audio_tensor_to_contiguous_ndarray(waveform) + audio_bytes_io = audio_ndarray_to_bytesio(audio_data_np, sample_rate, container_format, codec_name) + return await upload_file_to_comfyapi(cls, audio_bytes_io, filename, mime_type) + + +async def upload_video_to_comfyapi( + cls: type[IO.ComfyNode], + video: Input.Video, + *, + container: Types.VideoContainer = Types.VideoContainer.MP4, + codec: Types.VideoCodec = Types.VideoCodec.H264, + max_duration: int | None = None, + wait_label: str | None = "Uploading", +) -> str: + """ + Uploads a single video to ComfyUI API and returns its download URL. + Uses the specified container and codec for saving the video before upload. + """ + if max_duration is not None: + try: + actual_duration = video.get_duration() + if actual_duration > max_duration: + raise ValueError( + f"Video duration ({actual_duration:.2f}s) exceeds the maximum allowed ({max_duration}s)." 
+ ) + except Exception as e: + logging.error("Error getting video duration: %s", str(e)) + raise ValueError(f"Could not verify video duration from source: {e}") from e + + upload_mime_type = f"video/{container.value.lower()}" + filename = f"uploaded_video.{container.value.lower()}" + + # Convert VideoInput to BytesIO using specified container/codec + video_bytes_io = BytesIO() + video.save_to(video_bytes_io, format=container, codec=codec) + video_bytes_io.seek(0) + + return await upload_file_to_comfyapi(cls, video_bytes_io, filename, upload_mime_type, wait_label) + + +async def upload_file_to_comfyapi( + cls: type[IO.ComfyNode], + file_bytes_io: BytesIO, + filename: str, + upload_mime_type: str | None, + wait_label: str | None = "Uploading", + progress_origin_ts: float | None = None, +) -> str: + """Uploads a single file to ComfyUI API and returns its download URL.""" + if upload_mime_type is None: + request_object = UploadRequest(file_name=filename) + else: + request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) + create_resp = await sync_op( + cls, + endpoint=ApiEndpoint(path="/customers/storage", method="POST"), + data=request_object, + response_model=UploadResponse, + final_label_on_success=None, + monitor_progress=False, + ) + await upload_file( + cls, + create_resp.upload_url, + file_bytes_io, + content_type=upload_mime_type, + wait_label=wait_label, + progress_origin_ts=progress_origin_ts, + ) + return create_resp.download_url + + +async def upload_file( + cls: type[IO.ComfyNode], + upload_url: str, + file: BytesIO | str, + *, + content_type: str | None = None, + max_retries: int = 3, + retry_delay: float = 1.0, + retry_backoff: float = 2.0, + wait_label: str | None = None, + progress_origin_ts: float | None = None, +) -> None: + """ + Upload a file to a signed URL (e.g., S3 pre-signed PUT) with retries, Comfy progress display, and interruption. 
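+
+    If `content_type` is None, aiohttp's automatic Content-Type header is suppressed so the
+    pre-signed URL's signature is not invalidated by an unexpected header.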
+ + Raises: + ProcessingInterrupted, LocalNetworkError, ApiServerError, Exception + """ + if isinstance(file, BytesIO): + with contextlib.suppress(Exception): + file.seek(0) + data = file.read() + elif isinstance(file, str): + with open(file, "rb") as f: + data = f.read() + else: + raise ValueError("file must be a BytesIO or a filesystem path string") + + headers: dict[str, str] = {} + skip_auto_headers: set[str] = set() + if content_type: + headers["Content-Type"] = content_type + else: + skip_auto_headers.add("Content-Type") # Don't let aiohttp add Content-Type, it can break the signed request + + attempt = 0 + delay = retry_delay + start_ts = progress_origin_ts if progress_origin_ts is not None else time.monotonic() + op_uuid = uuid.uuid4().hex[:8] + while True: + attempt += 1 + operation_id = _generate_operation_id("PUT", upload_url, attempt, op_uuid) + timeout = aiohttp.ClientTimeout(total=None) + stop_evt = asyncio.Event() + + async def _monitor(): + try: + while not stop_evt.is_set(): + if is_processing_interrupted(): + return + if wait_label: + _display_time_progress(cls, wait_label, int(time.monotonic() - start_ts), None) + await asyncio.sleep(1.0) + except asyncio.CancelledError: + return + + monitor_task = asyncio.create_task(_monitor()) + sess: aiohttp.ClientSession | None = None + try: + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_params=None, + request_data=f"[File data {len(data)} bytes]", + ) + except Exception as e: + logging.debug("[DEBUG] upload request logging failed: %s", e) + + sess = aiohttp.ClientSession(timeout=timeout) + req = sess.put(upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers) + req_task = asyncio.create_task(req) + + done, pending = await asyncio.wait({req_task, monitor_task}, return_when=asyncio.FIRST_COMPLETED) + + if monitor_task in done and req_task in pending: + req_task.cancel() + raise ProcessingInterrupted("Upload cancelled") + + try: + resp = await req_task + except asyncio.CancelledError: + raise ProcessingInterrupted("Upload cancelled") from None + + async with resp: + if resp.status >= 400: + with contextlib.suppress(Exception): + try: + body = await resp.json() + except Exception: + body = await resp.text() + msg = f"Upload failed with status {resp.status}" + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=body, + error_message=msg, + ) + if resp.status in {408, 429, 500, 502, 503, 504} and attempt <= max_retries: + await sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + raise Exception(f"Failed to upload (HTTP {resp.status}).") + try: + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content="File uploaded successfully.", + ) + except Exception as e: + logging.debug("[DEBUG] upload response logging failed: %s", e) + return + except asyncio.CancelledError: + raise ProcessingInterrupted("Task cancelled") from None + except (aiohttp.ClientError, OSError) as e: + if attempt <= max_retries: + with contextlib.suppress(Exception): + 
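+                    # Best-effort: log the transient connection error before backing off and retrying.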
request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + request_headers=headers or None, + request_data=f"[File data {len(data)} bytes]", + error_message=f"{type(e).__name__}: {str(e)} (will retry)", + ) + await sleep_with_interrupt( + delay, + cls, + wait_label, + start_ts, + None, + display_callback=_display_time_progress if wait_label else None, + ) + delay *= retry_backoff + continue + + diag = await _diagnose_connectivity() + if not diag["internet_accessible"]: + raise LocalNetworkError( + "Unable to connect to the network. Please check your internet connection and try again." + ) from e + raise ApiServerError("The API service appears unreachable at this time.") from e + finally: + stop_evt.set() + if monitor_task: + monitor_task.cancel() + with contextlib.suppress(Exception): + await monitor_task + if sess: + with contextlib.suppress(Exception): + await sess.close() + + +def _generate_operation_id(method: str, url: str, attempt: int, op_uuid: str) -> str: + try: + parsed = urlparse(url) + slug = (parsed.path.rsplit("/", 1)[-1] or parsed.netloc or "upload").strip("/").replace("/", "_") + except Exception: + slug = "upload" + return f"{method}_{slug}_{op_uuid}_try{attempt}" diff --git a/comfy_api_nodes/util/validation_utils.py b/comfy_api_nodes/util/validation_utils.py index ca913e9b3..f01edea96 100644 --- a/comfy_api_nodes/util/validation_utils.py +++ b/comfy_api_nodes/util/validation_utils.py @@ -1,7 +1,7 @@ import logging -from typing import Optional import torch + from comfy_api.latest import Input @@ -16,10 +16,10 @@ def get_image_dimensions(image: torch.Tensor) -> tuple[int, int]: def validate_image_dimensions( image: torch.Tensor, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): height, width = get_image_dimensions(image) @@ -28,84 +28,77 @@ def validate_image_dimensions( if max_width is not None and width > max_width: raise ValueError(f"Image width must be at most {max_width}px, got {width}px") if min_height is not None and height < min_height: - raise ValueError( - f"Image height must be at least {min_height}px, got {height}px" - ) + raise ValueError(f"Image height must be at least {min_height}px, got {height}px") if max_height is not None and height > max_height: raise ValueError(f"Image height must be at most {max_height}px, got {height}px") def validate_image_aspect_ratio( image: torch.Tensor, - min_aspect_ratio: Optional[float] = None, - max_aspect_ratio: Optional[float] = None, -): - width, height = get_image_dimensions(image) - aspect_ratio = width / height - - if min_aspect_ratio is not None and aspect_ratio < min_aspect_ratio: - raise ValueError( - f"Image aspect ratio must be at least {min_aspect_ratio}, got {aspect_ratio}" - ) - if max_aspect_ratio is not None and aspect_ratio > max_aspect_ratio: - raise ValueError( - f"Image aspect ratio must be at most {max_aspect_ratio}, got {aspect_ratio}" - ) - - -def validate_image_aspect_ratio_range( - image: torch.Tensor, - min_ratio: tuple[float, float], # e.g. (1, 4) - max_ratio: tuple[float, float], # e.g. (4, 1) + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. 
(4, 1) *, - strict: bool = True, # True -> (min, max); False -> [min, max] + strict: bool = True, # True -> (min, max); False -> [min, max] ) -> float: - a1, b1 = min_ratio - a2, b2 = max_ratio - if a1 <= 0 or b1 <= 0 or a2 <= 0 or b2 <= 0: - raise ValueError("Ratios must be positive, like (1, 4) or (4, 1).") - lo, hi = (a1 / b1), (a2 / b2) - if lo > hi: - lo, hi = hi, lo - a1, b1, a2, b2 = a2, b2, a1, b1 # swap only for error text + """Validates that image aspect ratio is within min and max. If a bound is None, that side is not checked.""" w, h = get_image_dimensions(image) if w <= 0 or h <= 0: raise ValueError(f"Invalid image dimensions: {w}x{h}") ar = w / h - ok = (lo < ar < hi) if strict else (lo <= ar <= hi) - if not ok: - op = "<" if strict else "≤" - raise ValueError(f"Image aspect ratio {ar:.6g} is outside allowed range: {a1}:{b1} {op} ratio {op} {a2}:{b2}") + _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict) return ar -def validate_aspect_ratio_closeness( - start_img, - end_img, - min_rel: float, - max_rel: float, +def validate_images_aspect_ratio_closeness( + first_image: torch.Tensor, + second_image: torch.Tensor, + min_rel: float, # e.g. 0.8 + max_rel: float, # e.g. 1.25 *, - strict: bool = False, # True => exclusive, False => inclusive -) -> None: - w1, h1 = get_image_dimensions(start_img) - w2, h2 = get_image_dimensions(end_img) + strict: bool = False, # True -> (min, max); False -> [min, max] +) -> float: + """ + Validates that the two images' aspect ratios are 'close'. + The closeness factor is C = max(ar1, ar2) / min(ar1, ar2) (C >= 1). + We require C <= limit, where limit = max(max_rel, 1.0 / min_rel). + + Returns the computed closeness factor C. + """ + w1, h1 = get_image_dimensions(first_image) + w2, h2 = get_image_dimensions(second_image) if min(w1, h1, w2, h2) <= 0: raise ValueError("Invalid image dimensions") ar1 = w1 / h1 ar2 = w2 / h2 - # Normalize so it is symmetric (no need to check both ar1/ar2 and ar2/ar1) closeness = max(ar1, ar2) / min(ar1, ar2) - limit = max(max_rel, 1.0 / min_rel) # for 0.8..1.25 this is 1.25 + limit = max(max_rel, 1.0 / min_rel) if (closeness >= limit) if strict else (closeness > limit): - raise ValueError(f"Aspect ratios must be close: start/end={ar1/ar2:.4f}, allowed range {min_rel}–{max_rel}.") + raise ValueError( + f"Aspect ratios must be close: ar1/ar2={ar1/ar2:.2g}, " + f"allowed range {min_rel}–{max_rel} (limit {limit:.2g})." + ) + return closeness + + +def validate_aspect_ratio_string( + aspect_ratio: str, + min_ratio: tuple[float, float] | None = None, # e.g. (1, 4) + max_ratio: tuple[float, float] | None = None, # e.g. (4, 1) + *, + strict: bool = False, # True -> (min, max); False -> [min, max] +) -> float: + """Parses 'X:Y' and validates it against optional bounds. 
Returns the numeric ratio.""" + ar = _parse_aspect_ratio_string(aspect_ratio) + _assert_ratio_bounds(ar, min_ratio=min_ratio, max_ratio=max_ratio, strict=strict) + return ar def validate_video_dimensions( video: Input.Video, - min_width: Optional[int] = None, - max_width: Optional[int] = None, - min_height: Optional[int] = None, - max_height: Optional[int] = None, + min_width: int | None = None, + max_width: int | None = None, + min_height: int | None = None, + max_height: int | None = None, ): try: width, height = video.get_dimensions() @@ -118,17 +111,15 @@ def validate_video_dimensions( if max_width is not None and width > max_width: raise ValueError(f"Video width must be at most {max_width}px, got {width}px") if min_height is not None and height < min_height: - raise ValueError( - f"Video height must be at least {min_height}px, got {height}px" - ) + raise ValueError(f"Video height must be at least {min_height}px, got {height}px") if max_height is not None and height > max_height: raise ValueError(f"Video height must be at most {max_height}px, got {height}px") def validate_video_duration( video: Input.Video, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, + min_duration: float | None = None, + max_duration: float | None = None, ): try: duration = video.get_duration() @@ -138,13 +129,26 @@ def validate_video_duration( epsilon = 0.0001 if min_duration is not None and min_duration - epsilon > duration: - raise ValueError( - f"Video duration must be at least {min_duration}s, got {duration}s" - ) + raise ValueError(f"Video duration must be at least {min_duration}s, got {duration}s") if max_duration is not None and duration > max_duration + epsilon: - raise ValueError( - f"Video duration must be at most {max_duration}s, got {duration}s" - ) + raise ValueError(f"Video duration must be at most {max_duration}s, got {duration}s") + + +def validate_video_frame_count( + video: Input.Video, + min_frame_count: int | None = None, + max_frame_count: int | None = None, +): + try: + frame_count = video.get_frame_count() + except Exception as e: + logging.error("Error getting frame count of video: %s", e) + return + + if min_frame_count is not None and min_frame_count > frame_count: + raise ValueError(f"Video frame count must be at least {min_frame_count}, got {frame_count}") + if max_frame_count is not None and frame_count > max_frame_count: + raise ValueError(f"Video frame count must be at most {max_frame_count}, got {frame_count}") def get_number_of_images(images): @@ -155,8 +159,8 @@ def get_number_of_images(images): def validate_audio_duration( audio: Input.Audio, - min_duration: Optional[float] = None, - max_duration: Optional[float] = None, + min_duration: float | None = None, + max_duration: float | None = None, ) -> None: sr = int(audio["sample_rate"]) dur = int(audio["waveform"].shape[-1]) / sr @@ -165,3 +169,77 @@ def validate_audio_duration( raise ValueError(f"Audio duration must be at least {min_duration}s, got {dur + eps:.2f}s") if max_duration is not None and dur - eps > max_duration: raise ValueError(f"Audio duration must be at most {max_duration}s, got {dur - eps:.2f}s") + + +def validate_string( + string: str, + strip_whitespace=True, + field_name="prompt", + min_length=None, + max_length=None, +): + if string is None: + raise Exception(f"Field '{field_name}' cannot be empty.") + if strip_whitespace: + string = string.strip() + if min_length and len(string) < min_length: + raise Exception( + f"Field '{field_name}' cannot be shorter than {min_length} 
characters; was {len(string)} characters long." + ) + if max_length and len(string) > max_length: + raise Exception( + f" Field '{field_name} cannot be longer than {max_length} characters; was {len(string)} characters long." + ) + + +def validate_container_format_is_mp4(video: Input.Video) -> None: + """Validates video container format is MP4.""" + container_format = video.get_container_format() + if container_format not in ["mp4", "mov,mp4,m4a,3gp,3g2,mj2"]: + raise ValueError(f"Only MP4 container format supported. Got: {container_format}") + + +def _ratio_from_tuple(r: tuple[float, float]) -> float: + a, b = r + if a <= 0 or b <= 0: + raise ValueError(f"Ratios must be positive, got {a}:{b}.") + return a / b + + +def _assert_ratio_bounds( + ar: float, + *, + min_ratio: tuple[float, float] | None = None, + max_ratio: tuple[float, float] | None = None, + strict: bool = True, +) -> None: + """Validate a numeric aspect ratio against optional min/max ratio bounds.""" + lo = _ratio_from_tuple(min_ratio) if min_ratio is not None else None + hi = _ratio_from_tuple(max_ratio) if max_ratio is not None else None + + if lo is not None and hi is not None and lo > hi: + lo, hi = hi, lo # normalize order if caller swapped them + + if lo is not None: + if (ar <= lo) if strict else (ar < lo): + op = "<" if strict else "≤" + raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {lo:.2g}.") + if hi is not None: + if (ar >= hi) if strict else (ar > hi): + op = "<" if strict else "≤" + raise ValueError(f"Aspect ratio `{ar:.2g}` must be {op} {hi:.2g}.") + + +def _parse_aspect_ratio_string(ar_str: str) -> float: + """Parse 'X:Y' with integer parts into a positive float ratio X/Y.""" + parts = ar_str.split(":") + if len(parts) != 2: + raise ValueError(f"Aspect ratio must be 'X:Y' (e.g., 16:9), got '{ar_str}'.") + try: + a = int(parts[0].strip()) + b = int(parts[1].strip()) + except ValueError as exc: + raise ValueError(f"Aspect ratio must contain integers separated by ':', got '{ar_str}'.") from exc + if a <= 0 or b <= 0: + raise ValueError(f"Aspect ratio parts must be positive integers, got {a}:{b}.") + return a / b diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 85481451e..02bc43bf7 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -1,4 +1,9 @@ +import bisect +import gc import itertools +import psutil +import time +import torch from typing import Sequence, Mapping, Dict from .graph import DynamicPrompt @@ -50,7 +55,7 @@ class Unhashable: def to_hashable(obj): # So that we don't infinitely recurse since frozenset and tuples # are Sequences. 
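+    # bytes is hashable, so the change below returns it as-is instead of letting it
+    # fall through to the Sequence branch and be expanded element by element.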
- if isinstance(obj, (int, float, str, bool, type(None))): + if isinstance(obj, (int, float, str, bool, bytes, type(None))): return obj elif isinstance(obj, Mapping): return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())]) @@ -193,6 +198,9 @@ class BasicCache: self._clean_cache() self._clean_subcaches() + def poll(self, **kwargs): + pass + def _set_immediate(self, node_id, value): assert self.initialized cache_key = self.cache_key_set.get_data_key(node_id) @@ -271,6 +279,29 @@ class HierarchicalCache(BasicCache): assert cache is not None return await cache._ensure_subcache(node_id, children_ids) +class NullCache: + + async def set_prompt(self, dynprompt, node_ids, is_changed_cache): + pass + + def all_node_ids(self): + return [] + + def clean_unused(self): + pass + + def poll(self, **kwargs): + pass + + def get(self, node_id): + return None + + def set(self, node_id, value): + pass + + async def ensure_subcache_for(self, node_id, children_ids): + return self + class LRUCache(BasicCache): def __init__(self, key_class, max_size=100): super().__init__(key_class) @@ -324,155 +355,75 @@ class LRUCache(BasicCache): return self -class DependencyAwareCache(BasicCache): - """ - A cache implementation that tracks dependencies between nodes and manages - their execution and caching accordingly. It extends the BasicCache class. - Nodes are removed from this cache once all of their descendants have been - executed. - """ +#Iterating the cache for usage analysis might be expensive, so if we trigger make sure +#to take a chunk out to give breathing space on high-node / low-ram-per-node flows. + +RAM_CACHE_HYSTERESIS = 1.1 + +#This is kinda in GB but not really. It needs to be non-zero for the below heuristic +#and as long as Multi GB models dwarf this it will approximate OOM scoring OK + +RAM_CACHE_DEFAULT_RAM_USAGE = 0.1 + +#Exponential bias towards evicting older workflows so garbage will be taken out +#in constantly changing setups. + +RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3 + +class RAMPressureCache(LRUCache): def __init__(self, key_class): - """ - Initialize the DependencyAwareCache. - - Args: - key_class: The class used for generating cache keys. - """ - super().__init__(key_class) - self.descendants = {} # Maps node_id -> set of descendant node_ids - self.ancestors = {} # Maps node_id -> set of ancestor node_ids - self.executed_nodes = set() # Tracks nodes that have been executed - - async def set_prompt(self, dynprompt, node_ids, is_changed_cache): - """ - Clear the entire cache and rebuild the dependency graph. - - Args: - dynprompt: The dynamic prompt object containing node information. - node_ids: List of node IDs to initialize the cache for. - is_changed_cache: Flag indicating if the cache has changed. - """ - # Clear all existing cache data - self.cache.clear() - self.subcaches.clear() - self.descendants.clear() - self.ancestors.clear() - self.executed_nodes.clear() - - # Call the parent method to initialize the cache with the new prompt - await super().set_prompt(dynprompt, node_ids, is_changed_cache) - - # Rebuild the dependency graph - self._build_dependency_graph(dynprompt, node_ids) - - def _build_dependency_graph(self, dynprompt, node_ids): - """ - Build the dependency graph for all nodes. - - Args: - dynprompt: The dynamic prompt object containing node information. - node_ids: List of node IDs to build the graph for. 
- """ - self.descendants.clear() - self.ancestors.clear() - for node_id in node_ids: - self.descendants[node_id] = set() - self.ancestors[node_id] = set() - - for node_id in node_ids: - inputs = dynprompt.get_node(node_id)["inputs"] - for input_data in inputs.values(): - if is_link(input_data): # Check if the input is a link to another node - ancestor_id = input_data[0] - self.descendants[ancestor_id].add(node_id) - self.ancestors[node_id].add(ancestor_id) - - def set(self, node_id, value): - """ - Mark a node as executed and store its value in the cache. - - Args: - node_id: The ID of the node to store. - value: The value to store for the node. - """ - self._set_immediate(node_id, value) - self.executed_nodes.add(node_id) - self._cleanup_ancestors(node_id) - - def get(self, node_id): - """ - Retrieve the cached value for a node. - - Args: - node_id: The ID of the node to retrieve. - - Returns: - The cached value for the node. - """ - return self._get_immediate(node_id) - - async def ensure_subcache_for(self, node_id, children_ids): - """ - Ensure a subcache exists for a node and update dependencies. - - Args: - node_id: The ID of the parent node. - children_ids: List of child node IDs to associate with the parent node. - - Returns: - The subcache object for the node. - """ - subcache = await super()._ensure_subcache(node_id, children_ids) - for child_id in children_ids: - self.descendants[node_id].add(child_id) - self.ancestors[child_id].add(node_id) - return subcache - - def _cleanup_ancestors(self, node_id): - """ - Check if ancestors of a node can be removed from the cache. - - Args: - node_id: The ID of the node whose ancestors are to be checked. - """ - for ancestor_id in self.ancestors.get(node_id, []): - if ancestor_id in self.executed_nodes: - # Remove ancestor if all its descendants have been executed - if all(descendant in self.executed_nodes for descendant in self.descendants[ancestor_id]): - self._remove_node(ancestor_id) - - def _remove_node(self, node_id): - """ - Remove a node from the cache. - - Args: - node_id: The ID of the node to remove. - """ - cache_key = self.cache_key_set.get_data_key(node_id) - if cache_key in self.cache: - del self.cache[cache_key] - subcache_key = self.cache_key_set.get_subcache_key(node_id) - if subcache_key in self.subcaches: - del self.subcaches[subcache_key] + super().__init__(key_class, 0) + self.timestamps = {} def clean_unused(self): - """ - Clean up unused nodes. This is a no-op for this cache implementation. - """ - pass + self._clean_subcaches() - def recursive_debug_dump(self): - """ - Dump the cache and dependency graph for debugging. + def set(self, node_id, value): + self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() + super().set(node_id, value) - Returns: - A list containing the cache state and dependency graph. 
- """ - result = super().recursive_debug_dump() - result.append({ - "descendants": self.descendants, - "ancestors": self.ancestors, - "executed_nodes": list(self.executed_nodes), - }) - return result + def get(self, node_id): + self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() + return super().get(node_id) + + def poll(self, ram_headroom): + def _ram_gb(): + return psutil.virtual_memory().available / (1024**3) + + if _ram_gb() > ram_headroom: + return + gc.collect() + if _ram_gb() > ram_headroom: + return + + clean_list = [] + + for key, (outputs, _), in self.cache.items(): + oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key]) + + ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE + def scan_list_for_ram_usage(outputs): + nonlocal ram_usage + if outputs is None: + return + for output in outputs: + if isinstance(output, list): + scan_list_for_ram_usage(output) + elif isinstance(output, torch.Tensor) and output.device.type == 'cpu': + #score Tensors at a 50% discount for RAM usage as they are likely to + #be high value intermediates + ram_usage += (output.numel() * output.element_size()) * 0.5 + elif hasattr(output, "get_ram_usage"): + ram_usage += output.get_ram_usage() + scan_list_for_ram_usage(outputs) + + oom_score *= ram_usage + #In the case where we have no information on the node ram usage at all, + #break OOM score ties on the last touch timestamp (pure LRU) + bisect.insort(clean_list, (oom_score, self.timestamps[key], key)) + + while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list: + _, _, key = clean_list.pop() + del self.cache[key] + gc.collect() diff --git a/comfy_execution/graph.py b/comfy_execution/graph.py index 088ed1482..81a630fcb 100644 --- a/comfy_execution/graph.py +++ b/comfy_execution/graph.py @@ -161,8 +161,9 @@ class TopologicalSort: continue _, _, input_info = self.get_input_info(unique_id, input_name) is_lazy = input_info is not None and "lazy" in input_info and input_info["lazy"] - if (include_lazy or not is_lazy) and not self.is_cached(from_node_id): - node_ids.append(from_node_id) + if (include_lazy or not is_lazy): + if not self.is_cached(from_node_id): + node_ids.append(from_node_id) links.append((from_node_id, from_socket, unique_id)) for link in links: @@ -206,10 +207,40 @@ class ExecutionList(TopologicalSort): super().__init__(dynprompt) self.output_cache = output_cache self.staged_node_id = None + self.execution_cache = {} + self.execution_cache_listeners = {} def is_cached(self, node_id): return self.output_cache.get(node_id) is not None + def cache_link(self, from_node_id, to_node_id): + if not to_node_id in self.execution_cache: + self.execution_cache[to_node_id] = {} + self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) + if not from_node_id in self.execution_cache_listeners: + self.execution_cache_listeners[from_node_id] = set() + self.execution_cache_listeners[from_node_id].add(to_node_id) + + def get_cache(self, from_node_id, to_node_id): + if not to_node_id in self.execution_cache: + return None + value = self.execution_cache[to_node_id].get(from_node_id) + if value is None: + return None + #Write back to the main cache on touch. 
+ self.output_cache.set(from_node_id, value) + return value + + def cache_update(self, node_id, value): + if node_id in self.execution_cache_listeners: + for to_node_id in self.execution_cache_listeners[node_id]: + if to_node_id in self.execution_cache: + self.execution_cache[to_node_id][node_id] = value + + def add_strong_link(self, from_node_id, from_socket, to_node_id): + super().add_strong_link(from_node_id, from_socket, to_node_id) + self.cache_link(from_node_id, to_node_id) + async def stage_node_execution(self) -> tuple[Optional[str], Optional[DependencyExecutionErrorMessage], Optional[DependencyCycleError]]: assert self.staged_node_id is None if self.is_empty(): @@ -289,6 +320,8 @@ class ExecutionList(TopologicalSort): def complete_node_execution(self): node_id = self.staged_node_id self.pop_node(node_id) + self.execution_cache.pop(node_id, None) + self.execution_cache_listeners.pop(node_id, None) self.staged_node_id = None def get_nodes_in_cycle(self): diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py index cec105fc9..24c0b4ed7 100644 --- a/comfy_execution/validation.py +++ b/comfy_execution/validation.py @@ -1,4 +1,5 @@ from __future__ import annotations +from comfy_api.latest import IO def validate_node_input( @@ -23,6 +24,11 @@ def validate_node_input( if not received_type != input_type: return True + # If the received type or input_type is a MatchType, we can return True immediately; + # validation for this is handled by the frontend + if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: + return True + # Not equal, and not strings if not isinstance(received_type, str) or not isinstance(input_type, str): return False diff --git a/comfy_extras/nodes/nodes_audio.py b/comfy_extras/nodes/nodes_audio.py index 5344bb53c..753d8231c 100644 --- a/comfy_extras/nodes/nodes_audio.py +++ b/comfy_extras/nodes/nodes_audio.py @@ -1,311 +1,227 @@ from __future__ import annotations -import hashlib -import io -import json -import os -import random - import av import torch - import comfy.model_management +from comfy.cmd import folder_paths +import os +import hashlib from comfy import node_helpers import logging -from comfy.cli_args import args -from comfy.cmd import folder_paths -from comfy.comfy_types import FileLocator +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, UI class TorchAudioNotFoundError(ModuleNotFoundError): pass -class EmptyLatentAudio: - def __init__(self): - self.device = comfy.model_management.intermediate_device() +class EmptyLatentAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentAudio", + display_name="Empty Latent Audio", + category="latent/audio", + inputs=[ + IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), + IO.Int.Input( + "batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch.", optional=True, + ), + ], + outputs=[IO.Latent.Output()], + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"seconds": ("FLOAT", {"default": 47.6, "min": 1.0, "max": 1000.0, "step": 0.1}), }, - # mark as optional to not break existing workflows - "optional": {"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - }} - - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" - - CATEGORY = "latent/audio" - - def generate(self, seconds: float, batch_size: int = 1): + def execute(cls, seconds, batch_size) -> IO.NodeOutput: 
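+        # Length in latent frames: seconds of 44.1 kHz audio divided by the
+        # 2048-sample hop per latent frame, rounded to an even frame count.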
length = round((seconds * 44100 / 2048) / 2) * 2 - latent = torch.zeros([batch_size, 64, length], device=self.device) - return ({"samples": latent, "type": "audio"},) + latent = torch.zeros([batch_size, 64, length], device=comfy.model_management.intermediate_device()) + return IO.NodeOutput({"samples": latent, "type": "audio"}) + + generate = execute # TODO: remove -class ConditioningStableAudio: +class ConditioningStableAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"positive": ("CONDITIONING",), - "negative": ("CONDITIONING",), - "seconds_start": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1}), - "seconds_total": ("FLOAT", {"default": 47.0, "min": 0.0, "max": 1000.0, "step": 0.1}), - }} + def define_schema(cls): + return IO.Schema( + node_id="ConditioningStableAudio", + category="conditioning", + inputs=[ + IO.Conditioning.Input("positive"), + IO.Conditioning.Input("negative"), + IO.Float.Input("seconds_start", default=0.0, min=0.0, max=1000.0, step=0.1), + IO.Float.Input("seconds_total", default=47.0, min=0.0, max=1000.0, step=0.1), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ], + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - - FUNCTION = "append" - - CATEGORY = "conditioning" - - def append(self, positive, negative, seconds_start, seconds_total): + @classmethod + def execute(cls, positive, negative, seconds_start, seconds_total) -> IO.NodeOutput: positive = node_helpers.conditioning_set_values(positive, {"seconds_start": seconds_start, "seconds_total": seconds_total}) negative = node_helpers.conditioning_set_values(negative, {"seconds_start": seconds_start, "seconds_total": seconds_total}) - return (positive, negative) + return IO.NodeOutput(positive, negative) + + append = execute # TODO: remove -class VAEEncodeAudio: +class VAEEncodeAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"audio": ("AUDIO",), "vae": ("VAE",)}} + def define_schema(cls): + return IO.Schema( + node_id="VAEEncodeAudio", + display_name="VAE Encode Audio", + category="latent/audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Latent.Output()], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "encode" - - CATEGORY = "latent/audio" - - def encode(self, vae, audio): + @classmethod + def execute(cls, vae, audio) -> IO.NodeOutput: sample_rate = audio["sample_rate"] + try: + import torchaudio # pylint: disable=import-error + except ImportError: + raise TorchAudioNotFoundError() if 44100 != sample_rate: - try: - import torchaudio # pylint: disable=import-error - except ImportError as exc_info: - raise TorchAudioNotFoundError() waveform = torchaudio.functional.resample(audio["waveform"], sample_rate, 44100) else: waveform = audio["waveform"] t = vae.encode(waveform.movedim(1, -1)) - return ({"samples": t},) + return IO.NodeOutput({"samples": t}) + + encode = execute # TODO: remove -class VAEDecodeAudio: +class VAEDecodeAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"samples": ("LATENT",), "vae": ("VAE",)}} + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeAudio", + display_name="VAE Decode Audio", + category="latent/audio", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "decode" - - CATEGORY = "latent/audio" - - def decode(self, vae, 
samples): - if samples is None: - return None, + @classmethod + def execute(cls, vae, samples) -> IO.NodeOutput: audio = vae.decode(samples["samples"]).movedim(-1, 1) std = torch.std(audio, dim=[1, 2], keepdim=True) * 5.0 std[std < 1.0] = 1.0 audio /= std - return ({"waveform": audio, "sample_rate": 44100},) + return IO.NodeOutput({"waveform": audio, "sample_rate": 44100}) + + decode = execute # TODO: remove -def save_audio(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None, quality="128k"): - try: - import torchaudio # pylint: disable=import-error - except ImportError as exc_info: - raise TorchAudioNotFoundError() - - filename_prefix += self.prefix_append - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, self.output_dir) - results: list[FileLocator] = [] - - # Prepare metadata dictionary - metadata = {} - if not args.disable_metadata: - if prompt is not None: - metadata["prompt"] = json.dumps(prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) - - # Opus supported sample rates - OPUS_RATES = [8000, 12000, 16000, 24000, 48000] - - for (batch_number, waveform) in enumerate(audio["waveform"].cpu()): - filename_with_batch_num = filename.replace("%batch_num%", str(batch_number)) - file = f"{filename_with_batch_num}_{counter:05}_.{format}" - output_path = os.path.join(full_output_folder, file) - - # Use original sample rate initially - sample_rate = audio["sample_rate"] - - # Handle Opus sample rate requirements - if format == "opus": - if sample_rate > 48000: - sample_rate = 48000 - elif sample_rate not in OPUS_RATES: - # Find the next highest supported rate - for rate in sorted(OPUS_RATES): - if rate > sample_rate: - sample_rate = rate - break - if sample_rate not in OPUS_RATES: # Fallback if still not supported - sample_rate = 48000 - - # Resample if necessary - if sample_rate != audio["sample_rate"]: - waveform = torchaudio.functional.resample(waveform, audio["sample_rate"], sample_rate) - - # Create output with specified format - output_buffer = io.BytesIO() - output_container = av.open(output_buffer, mode='w', format=format) - - # Set metadata on the container - for key, value in metadata.items(): - output_container.metadata[key] = value - - layout = 'mono' if waveform.shape[0] == 1 else 'stereo' - # Set up the output stream with appropriate properties - if format == "opus": - out_stream = output_container.add_stream("libopus", rate=sample_rate, layout=layout) - if quality == "64k": - out_stream.bit_rate = 64000 - elif quality == "96k": - out_stream.bit_rate = 96000 - elif quality == "128k": - out_stream.bit_rate = 128000 - elif quality == "192k": - out_stream.bit_rate = 192000 - elif quality == "320k": - out_stream.bit_rate = 320000 - elif format == "mp3": - out_stream = output_container.add_stream("libmp3lame", rate=sample_rate, layout=layout) - if quality == "V0": - # TODO i would really love to support V3 and V5 but there doesn't seem to be a way to set the qscale level, the property below is a bool - out_stream.codec_context.qscale = 1 - elif quality == "128k": - out_stream.bit_rate = 128000 - elif quality == "320k": - out_stream.bit_rate = 320000 - else: # format == "flac": - out_stream = output_container.add_stream("flac", rate=sample_rate, layout=layout) - - frame = av.AudioFrame.from_ndarray(waveform.movedim(0, 1).reshape(1, -1).float().numpy(), format='flt', layout=layout) - frame.sample_rate = sample_rate - 
frame.pts = 0 - output_container.mux(out_stream.encode(frame)) - - # Flush encoder - output_container.mux(out_stream.encode(None)) - - # Close containers - output_container.close() - - # Write the output to file - output_buffer.seek(0) - with open(output_path, 'wb') as f: - f.write(output_buffer.getbuffer()) - - results.append({ - "filename": file, - "subfolder": subfolder, - "type": self.type - }) - counter += 1 - - return {"ui": {"audio": results}} - - -class SaveAudio: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudio", + display_name="Save Audio (FLAC)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"audio": ("AUDIO",), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="flac") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=format) + ) - RETURN_TYPES = () - FUNCTION = "save_flac" - - OUTPUT_NODE = True - - CATEGORY = "audio" - - def save_flac(self, audio, filename_prefix="ComfyUI", format="flac", prompt=None, extra_pnginfo=None): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo) + save_flac = execute # TODO: remove -class SaveAudioMP3: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudioMP3(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioMP3", + display_name="Save Audio (MP3)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"audio": ("AUDIO",), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - "quality": (["V0", "128k", "320k"], {"default": "V0"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="mp3", quality="128k") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) - RETURN_TYPES = () - FUNCTION = "save_mp3" - - OUTPUT_NODE = True - - CATEGORY = "audio" - - def save_mp3(self, audio, filename_prefix="ComfyUI", format="mp3", prompt=None, extra_pnginfo=None, quality="128k"): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) + save_mp3 = execute # TODO: remove -class SaveAudioOpus: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() - self.type = "output" - self.prefix_append = "" +class SaveAudioOpus(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="SaveAudioOpus", + display_name="Save Audio (Opus)", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + 
IO.String.Input("filename_prefix", default="audio/ComfyUI"), + IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": {"audio": ("AUDIO",), - "filename_prefix": ("STRING", {"default": "audio/ComfyUI"}), - "quality": (["64k", "96k", "128k", "192k", "320k"], {"default": "128k"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio, filename_prefix="ComfyUI", format="opus", quality="V3") -> IO.NodeOutput: + return IO.NodeOutput( + ui=UI.AudioSaveHelper.get_save_audio_ui( + audio, filename_prefix=filename_prefix, cls=cls, format=format, quality=quality + ) + ) - RETURN_TYPES = () - FUNCTION = "save_opus" - - OUTPUT_NODE = True - - CATEGORY = "audio" - - def save_opus(self, audio, filename_prefix="ComfyUI", format="opus", prompt=None, extra_pnginfo=None, quality="V3"): - return save_audio(self, audio, filename_prefix, format, prompt, extra_pnginfo, quality) + save_opus = execute # TODO: remove -class PreviewAudio(SaveAudio): - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() - self.type = "temp" - self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) +class PreviewAudio(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="PreviewAudio", + display_name="Preview Audio", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return {"required": - {"audio": ("AUDIO",), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } + def execute(cls, audio) -> IO.NodeOutput: + return IO.NodeOutput(ui=UI.PreviewAudio(audio, cls=cls)) + + save_flac = execute # TODO: remove def f32_pcm(wav: torch.Tensor) -> torch.Tensor: @@ -346,31 +262,30 @@ def load(filepath: str) -> tuple[torch.Tensor, int]: return wav, sr -class LoadAudio: +class LoadAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): + def define_schema(cls): input_dir = folder_paths.get_input_directory() files = folder_paths.filter_files_content_types(os.listdir(input_dir), ["audio", "video"]) - return {"required": {"audio": (sorted(files), {"audio_upload": True})}} - - CATEGORY = "audio" - - RETURN_TYPES = ("AUDIO",) - FUNCTION = "load" - - def load(self, audio): - try: - import torchaudio # pylint: disable=import-error - except ImportError as exc_info: - raise TorchAudioNotFoundError() + return IO.Schema( + node_id="LoadAudio", + display_name="Load Audio", + category="audio", + inputs=[ + IO.Combo.Input("audio", upload=IO.UploadType.audio, options=sorted(files)), + ], + outputs=[IO.Audio.Output()], + ) + @classmethod + def execute(cls, audio) -> IO.NodeOutput: audio_path = folder_paths.get_annotated_filepath(audio) waveform, sample_rate = load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} - return (audio,) + return IO.NodeOutput(audio) @classmethod - def IS_CHANGED(s, audio): + def fingerprint_inputs(cls, audio): image_path = folder_paths.get_annotated_filepath(audio) m = hashlib.sha256() with open(image_path, 'rb') as f: @@ -378,51 +293,69 @@ class LoadAudio: return m.digest().hex() @classmethod - def VALIDATE_INPUTS(s, audio): + def validate_inputs(cls, audio): if not folder_paths.exists_annotated_filepath(audio): return "Invalid audio 
file: {}".format(audio) return True + load = execute # TODO: remove -class RecordAudio: + +class RecordAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"audio": ("AUDIO_RECORD", {})}} + def define_schema(cls): + return IO.Schema( + node_id="RecordAudio", + display_name="Record Audio", + category="audio", + inputs=[ + IO.Custom("AUDIO_RECORD").Input("audio"), + ], + outputs=[IO.Audio.Output()], + ) - CATEGORY = "audio" - - RETURN_TYPES = ("AUDIO",) - FUNCTION = "load" - - def load(self, audio): + @classmethod + def execute(cls, audio) -> IO.NodeOutput: audio_path = folder_paths.get_annotated_filepath(audio) - try: - import torchaudio # pylint: disable=import-error - except (ImportError, ModuleNotFoundError): - raise TorchAudioNotFoundError() waveform, sample_rate = load(audio_path) audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} - return (audio,) + return IO.NodeOutput(audio) + + load = execute # TODO: remove -class TrimAudioDuration: +class TrimAudioDuration(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "audio": ("AUDIO",), - "start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}), - "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="TrimAudioDuration", + display_name="Trim Audio Duration", + description="Trim audio tensor into chosen time range.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Float.Input( + "start_index", + default=0.0, + min=-0xffffffffffffffff, + max=0xffffffffffffffff, + step=0.01, + tooltip="Start time in seconds, can be negative to count from the end (supports sub-seconds).", + ), + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + step=0.01, + tooltip="Duration in seconds", + ), + ], + outputs=[IO.Audio.Output()], + ) - FUNCTION = "trim" - RETURN_TYPES = ("AUDIO",) - CATEGORY = "audio" - DESCRIPTION = "Trim audio tensor into chosen time range." - - def trim(self, audio, start_index, duration): + @classmethod + def execute(cls, audio, start_index, duration) -> IO.NodeOutput: waveform = audio["waveform"] sample_rate = audio["sample_rate"] audio_length = waveform.shape[-1] @@ -439,23 +372,30 @@ class TrimAudioDuration: if start_frame >= end_frame: raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.") - return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate}) + + trim = execute # TODO: remove -class SplitAudioChannels: +class SplitAudioChannels(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio": ("AUDIO",), - }} + def define_schema(cls): + return IO.Schema( + node_id="SplitAudioChannels", + display_name="Split Audio Channels", + description="Separates the audio into left and right channels.", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + ], + outputs=[ + IO.Audio.Output(display_name="left"), + IO.Audio.Output(display_name="right"), + ], + ) - RETURN_TYPES = ("AUDIO", "AUDIO") - RETURN_NAMES = ("left", "right") - FUNCTION = "separate" - CATEGORY = "audio" - DESCRIPTION = "Separates the audio into left and right channels." 
- - def separate(self, audio): + @classmethod + def execute(cls, audio) -> IO.NodeOutput: waveform = audio["waveform"] sample_rate = audio["sample_rate"] @@ -465,15 +405,16 @@ class SplitAudioChannels: left_channel = waveform[..., 0:1, :] right_channel = waveform[..., 1:2, :] - return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + return IO.NodeOutput({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate}) + + separate = execute # TODO: remove def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2): try: import torchaudio # pylint: disable=import-error - except ImportError as exc_info: + except ImportError: raise TorchAudioNotFoundError() - if sample_rate_1 != sample_rate_2: if sample_rate_1 > sample_rate_2: waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1) @@ -488,21 +429,29 @@ def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_ return waveform_1, waveform_2, output_sample_rate -class AudioConcat: +class AudioConcat(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio1": ("AUDIO",), - "audio2": ("AUDIO",), - "direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}), - }} + def define_schema(cls): + return IO.Schema( + node_id="AudioConcat", + display_name="Audio Concat", + description="Concatenates the audio1 to audio2 in the specified direction.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "direction", + options=['after', 'before'], + default="after", + tooltip="Whether to append audio2 after or before audio1.", + ) + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "concat" - CATEGORY = "audio" - DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction." 
- - def concat(self, audio1, audio2, direction): + @classmethod + def execute(cls, audio1, audio2, direction) -> IO.NodeOutput: waveform_1 = audio1["waveform"] waveform_2 = audio2["waveform"] sample_rate_1 = audio1["sample_rate"] @@ -517,32 +466,38 @@ class AudioConcat: waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2) - concatenated_audio: torch.Tensor = waveform_1 if direction == 'after': concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2) elif direction == 'before': concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2) - return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},) + return IO.NodeOutput({"waveform": concatenated_audio, "sample_rate": output_sample_rate}) + + concat = execute # TODO: remove -class AudioMerge: +class AudioMerge(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "audio1": ("AUDIO",), - "audio2": ("AUDIO",), - "merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="AudioMerge", + display_name="Audio Merge", + description="Combine two audio tracks by overlaying their waveforms.", + category="audio", + inputs=[ + IO.Audio.Input("audio1"), + IO.Audio.Input("audio2"), + IO.Combo.Input( + "merge_method", + options=["add", "mean", "subtract", "multiply"], + tooltip="The method used to combine the audio waveforms.", + ) + ], + outputs=[IO.Audio.Output()], + ) - FUNCTION = "merge" - RETURN_TYPES = ("AUDIO",) - CATEGORY = "audio" - DESCRIPTION = "Combine two audio tracks by overlaying their waveforms." - - def merge(self, audio1, audio2, merge_method): + @classmethod + def execute(cls, audio1, audio2, merge_method) -> IO.NodeOutput: waveform_1 = audio1["waveform"] waveform_2 = audio2["waveform"] sample_rate_1 = audio1["sample_rate"] @@ -563,7 +518,6 @@ class AudioMerge: pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device) waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1) - waveform = None if merge_method == "add": waveform = waveform_1 + waveform_2 elif merge_method == "subtract": @@ -577,85 +531,111 @@ class AudioMerge: if max_val > 1.0: waveform = waveform / max_val - return ({"waveform": waveform, "sample_rate": output_sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": output_sample_rate}) + + merge = execute # TODO: remove -class AudioAdjustVolume: +class AudioAdjustVolume(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "audio": ("AUDIO",), - "volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}), - }} + def define_schema(cls): + return IO.Schema( + node_id="AudioAdjustVolume", + display_name="Audio Adjust Volume", + category="audio", + inputs=[ + IO.Audio.Input("audio"), + IO.Int.Input( + "volume", + default=1, + min=-100, + max=100, + tooltip="Volume adjustment in decibels (dB). 
0 = no change, +6 = double, -6 = half, etc", + ) + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "adjust_volume" - CATEGORY = "audio" - - def adjust_volume(self, audio, volume): + @classmethod + def execute(cls, audio, volume) -> IO.NodeOutput: if volume == 0: - return (audio,) + return IO.NodeOutput(audio) waveform = audio["waveform"] sample_rate = audio["sample_rate"] gain = 10 ** (volume / 20) waveform = waveform * gain - return ({"waveform": waveform, "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + adjust_volume = execute # TODO: remove -class EmptyAudio: +class EmptyAudio(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { - "duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}), - "sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}), - "channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}), - }} + def define_schema(cls): + return IO.Schema( + node_id="EmptyAudio", + display_name="Empty Audio", + category="audio", + inputs=[ + IO.Float.Input( + "duration", + default=60.0, + min=0.0, + max=0xffffffffffffffff, + step=0.01, + tooltip="Duration of the empty audio clip in seconds", + ), + IO.Int.Input( + "sample_rate", + default=44100, + tooltip="Sample rate of the empty audio clip.", + min=1, + max=192000, + ), + IO.Int.Input( + "channels", + default=2, + min=1, + max=2, + tooltip="Number of audio channels (1 for mono, 2 for stereo).", + ), + ], + outputs=[IO.Audio.Output()], + ) - RETURN_TYPES = ("AUDIO",) - FUNCTION = "create_empty_audio" - CATEGORY = "audio" - - def create_empty_audio(self, duration, sample_rate, channels): + @classmethod + def execute(cls, duration, sample_rate, channels) -> IO.NodeOutput: num_samples = int(round(duration * sample_rate)) waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32) - return ({"waveform": waveform, "sample_rate": sample_rate},) + return IO.NodeOutput({"waveform": waveform, "sample_rate": sample_rate}) + + create_empty_audio = execute # TODO: remove -NODE_CLASS_MAPPINGS = { - "EmptyLatentAudio": EmptyLatentAudio, - "VAEEncodeAudio": VAEEncodeAudio, - "VAEDecodeAudio": VAEDecodeAudio, - "SaveAudio": SaveAudio, - "SaveAudioMP3": SaveAudioMP3, - "SaveAudioOpus": SaveAudioOpus, - "LoadAudio": LoadAudio, - "PreviewAudio": PreviewAudio, - "ConditioningStableAudio": ConditioningStableAudio, - "RecordAudio": RecordAudio, - "TrimAudioDuration": TrimAudioDuration, - "SplitAudioChannels": SplitAudioChannels, - "AudioConcat": AudioConcat, - "AudioMerge": AudioMerge, - "AudioAdjustVolume": AudioAdjustVolume, - "EmptyAudio": EmptyAudio, -} +class AudioExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + EmptyLatentAudio, + VAEEncodeAudio, + VAEDecodeAudio, + SaveAudio, + SaveAudioMP3, + SaveAudioOpus, + LoadAudio, + PreviewAudio, + ConditioningStableAudio, + RecordAudio, + TrimAudioDuration, + SplitAudioChannels, + AudioConcat, + AudioMerge, + AudioAdjustVolume, + EmptyAudio, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "EmptyLatentAudio": "Empty Latent Audio", - "VAEEncodeAudio": "VAE Encode Audio", - "VAEDecodeAudio": "VAE Decode Audio", - "PreviewAudio": "Preview Audio", - "LoadAudio": "Load Audio", - "SaveAudio": "Save Audio (FLAC)", - "SaveAudioMP3": "Save Audio 
(MP3)", - "SaveAudioOpus": "Save Audio (Opus)", - "RecordAudio": "Record Audio", - "TrimAudioDuration": "Trim Audio Duration", - "SplitAudioChannels": "Split Audio Channels", - "AudioConcat": "Audio Concat", - "AudioMerge": "Audio Merge", - "AudioAdjustVolume": "Audio Adjust Volume", - "EmptyAudio": "Empty Audio", -} + +async def comfy_entrypoint() -> AudioExtension: + return AudioExtension() diff --git a/comfy_extras/nodes/nodes_context_windows.py b/comfy_extras/nodes/nodes_context_windows.py index 4cd8370aa..f3681a970 100644 --- a/comfy_extras/nodes/nodes_context_windows.py +++ b/comfy_extras/nodes/nodes_context_windows.py @@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode): io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."), @@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode): ) @classmethod - def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model: + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: model = model.clone() model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), @@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode): context_overlap=context_overlap, context_stride=context_stride, closed_loop=closed_loop, - dim=dim) + dim=dim, + freenoise=freenoise, + cond_retain_index_list=cond_retain_index_list, + split_conds_to_windows=split_conds_to_windows + ) # make memory usage calculation only take into account the context window latents comfy.context_windows.create_prepare_sampling_wrapper(model) + if freenoise: # no other use for this wrapper at this time + comfy.context_windows.create_sampler_sample_wrapper(model) return io.NodeOutput(model) class WanContextWindowsManualNode(ContextWindowsManualNode): @@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode): io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), 
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ] return schema @classmethod - def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model: + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 - return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2) + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows) class ContextWindowsExtension(ComfyExtension): diff --git a/comfy_extras/nodes/nodes_custom_sampler.py b/comfy_extras/nodes/nodes_custom_sampler.py index f025a95a7..30a7854d5 100644 --- a/comfy_extras/nodes/nodes_custom_sampler.py +++ b/comfy_extras/nodes/nodes_custom_sampler.py @@ -1,279 +1,316 @@ import math - -import torch - -import comfy.sampler_names -from comfy import model_management -from comfy import node_helpers -from comfy import sample -from comfy import samplers -from comfy.cmd import latent_preview -from comfy.comfy_types import ComfyNodeABC, InputTypeDict, IO +import comfy.samplers +import comfy.sample from comfy.execution_context import current_execution_context -from comfy.k_diffusion import sampling as k_diffusion_sampling, sa_solver -from comfy.nodes.package_typing import Seed64 -from comfy.samplers import KSAMPLER +from comfy.k_diffusion import sampling as k_diffusion_sampling +from comfy.k_diffusion import sa_solver +import comfy.sampler_names +from comfy.cmd import latent_preview +import torch +import comfy.utils +from comfy import node_helpers +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io -class BasicScheduler: +class BasicScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "scheduler": (comfy.sampler_names.SCHEDULER_NAMES,), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="BasicScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Combo.Input("scheduler", 
options=comfy.sampler_names.SCHEDULER_NAMES), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, scheduler, steps, denoise): + @classmethod + def execute(cls, model, scheduler, steps, denoise) -> io.NodeOutput: total_steps = steps if denoise < 1.0: if denoise <= 0.0: - return (torch.FloatTensor([]),) + return io.NodeOutput(torch.FloatTensor([])) total_steps = int(steps/denoise) - sigmas = samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() + sigmas = comfy.samplers.calculate_sigmas(model.get_model_object("model_sampling"), scheduler, total_steps).cpu() sigmas = sigmas[-(steps + 1):] - return (sigmas, ) + return io.NodeOutput(sigmas) + + get_sigmas = execute -class KarrasScheduler: +class KarrasScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "rho": ("FLOAT", {"default": 7.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="KarrasScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=7.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, rho): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) - return (sigmas, ) + return io.NodeOutput(sigmas) -class ExponentialScheduler: + get_sigmas = execute + +class ExponentialScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="ExponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min): + @classmethod + def execute(cls, steps, sigma_max, sigma_min) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_exponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max) - return (sigmas, ) + return io.NodeOutput(sigmas) 
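Every scheduler in this file follows the same conversion: the old INPUT_TYPES/RETURN_TYPES/FUNCTION class attributes become a define_schema() classmethod returning io.Schema, the entry point becomes an execute() classmethod returning io.NodeOutput, and the old method name is kept as an alias. A minimal sketch of that pattern, assuming only the comfy_api.latest io module imported above; ConstantSigmas is a hypothetical illustration, not a node added by this diff:

    import torch
    from comfy_api.latest import io

    class ConstantSigmas(io.ComfyNode):
        @classmethod
        def define_schema(cls):
            return io.Schema(
                node_id="ConstantSigmas",
                category="sampling/custom_sampling/schedulers",
                inputs=[
                    io.Int.Input("steps", default=20, min=1, max=10000),
                    io.Float.Input("sigma", default=1.0, min=0.0, max=5000.0, step=0.01, round=False),
                ],
                outputs=[io.Sigmas.Output()],
            )

        @classmethod
        def execute(cls, steps, sigma) -> io.NodeOutput:
            # Outputs are wrapped in io.NodeOutput instead of being returned as a tuple.
            return io.NodeOutput(torch.full((steps + 1,), sigma))

        get_sigmas = execute  # legacy method name kept as an alias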
-class PolyexponentialScheduler: + get_sigmas = execute + +class PolyexponentialScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "rho": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="PolyexponentialScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("rho", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, rho): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, rho) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_polyexponential(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho) - return (sigmas, ) + return io.NodeOutput(sigmas) -class LaplaceScheduler: + get_sigmas = execute + +class LaplaceScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "sigma_max": ("FLOAT", {"default": 14.614642, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "sigma_min": ("FLOAT", {"default": 0.0291675, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "mu": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step":0.1, "round": False}), - "beta": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 10.0, "step":0.1, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="LaplaceScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("sigma_max", default=14.614642, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("sigma_min", default=0.0291675, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("mu", default=0.0, min=-10.0, max=10.0, step=0.1, round=False), + io.Float.Input("beta", default=0.5, min=0.0, max=10.0, step=0.1, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, sigma_max, sigma_min, mu, beta): + @classmethod + def execute(cls, steps, sigma_max, sigma_min, mu, beta) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_laplace(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, mu=mu, beta=beta) - return (sigmas, ) + return io.NodeOutput(sigmas) + + get_sigmas = execute -class SDTurboScheduler: +class SDTurboScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "steps": ("INT", {"default": 1, "min": 1, "max": 10}), - "denoise": ("FLOAT", {"default": 1.0, "min": 0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + 
node_id="SDTurboScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=1, min=1, max=10), + io.Float.Input("denoise", default=1.0, min=0, max=1.0, step=0.01), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, steps, denoise): + @classmethod + def execute(cls, model, steps, denoise) -> io.NodeOutput: start_step = 10 - int(10 * denoise) timesteps = torch.flip(torch.arange(1, 11) * 100 - 1, (0,))[start_step:start_step + steps] sigmas = model.get_model_object("model_sampling").sigma(timesteps) sigmas = torch.cat([sigmas, sigmas.new_zeros([1])]) - return (sigmas, ) + return io.NodeOutput(sigmas) -class BetaSamplingScheduler: + get_sigmas = execute + +class BetaSamplingScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "alpha": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), - "beta": ("FLOAT", {"default": 0.6, "min": 0.0, "max": 50.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="BetaSamplingScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Model.Input("model"), + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("alpha", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + io.Float.Input("beta", default=0.6, min=0.0, max=50.0, step=0.01, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, model, steps, alpha, beta): + @classmethod + def execute(cls, model, steps, alpha, beta) -> io.NodeOutput: sigmas = comfy.samplers.beta_scheduler(model.get_model_object("model_sampling"), steps, alpha=alpha, beta=beta) - return (sigmas, ) + return io.NodeOutput(sigmas) -class VPScheduler: + get_sigmas = execute + +class VPScheduler(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"steps": ("INT", {"default": 20, "min": 1, "max": 10000}), - "beta_d": ("FLOAT", {"default": 19.9, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), #TODO: fix default values - "beta_min": ("FLOAT", {"default": 0.1, "min": 0.0, "max": 5000.0, "step":0.01, "round": False}), - "eps_s": ("FLOAT", {"default": 0.001, "min": 0.0, "max": 1.0, "step":0.0001, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/schedulers" + def define_schema(cls): + return io.Schema( + node_id="VPScheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=10000), + io.Float.Input("beta_d", default=19.9, min=0.0, max=5000.0, step=0.01, round=False), #TODO: fix default values + io.Float.Input("beta_min", default=0.1, min=0.0, max=5000.0, step=0.01, round=False), + io.Float.Input("eps_s", default=0.001, min=0.0, max=1.0, step=0.0001, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, steps, beta_d, beta_min, eps_s): + @classmethod + def execute(cls, steps, beta_d, beta_min, eps_s) -> io.NodeOutput: sigmas = k_diffusion_sampling.get_sigmas_vp(n=steps, beta_d=beta_d, beta_min=beta_min, eps_s=eps_s) - return (sigmas, ) + return io.NodeOutput(sigmas) -class SplitSigmas: + get_sigmas = execute + +class SplitSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return 
{"required": - {"sigmas": ("SIGMAS", ), - "step": ("INT", {"default": 0, "min": 0, "max": 10000}), - } - } - RETURN_TYPES = ("SIGMAS","SIGMAS") - RETURN_NAMES = ("high_sigmas", "low_sigmas") - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SplitSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("step", default=0, min=0, max=10000), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas, step): + @classmethod + def execute(cls, sigmas, step) -> io.NodeOutput: sigmas1 = sigmas[:step + 1] sigmas2 = sigmas[step:] - return (sigmas1, sigmas2) + return io.NodeOutput(sigmas1, sigmas2) -class SplitSigmasDenoise: + get_sigmas = execute + +class SplitSigmasDenoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } - RETURN_TYPES = ("SIGMAS","SIGMAS") - RETURN_NAMES = ("high_sigmas", "low_sigmas") - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SplitSigmasDenoise", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("denoise", default=1.0, min=0.0, max=1.0, step=0.01), + ], + outputs=[ + io.Sigmas.Output(display_name="high_sigmas"), + io.Sigmas.Output(display_name="low_sigmas"), + ] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas, denoise): + @classmethod + def execute(cls, sigmas, denoise) -> io.NodeOutput: steps = max(sigmas.shape[-1] - 1, 0) total_steps = round(steps * denoise) sigmas1 = sigmas[:-(total_steps)] sigmas2 = sigmas[-(total_steps + 1):] - return (sigmas1, sigmas2) + return io.NodeOutput(sigmas1, sigmas2) -class FlipSigmas: + get_sigmas = execute + +class FlipSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="FlipSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[io.Sigmas.Input("sigmas")], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "get_sigmas" - - def get_sigmas(self, sigmas): + @classmethod + def execute(cls, sigmas) -> io.NodeOutput: if len(sigmas) == 0: - return (sigmas,) + return io.NodeOutput(sigmas) sigmas = sigmas.flip(0) if sigmas[0] == 0: sigmas[0] = 0.0001 - return (sigmas,) + return io.NodeOutput(sigmas) -class SetFirstSigma: + get_sigmas = execute + +class SetFirstSigma(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "sigma": ("FLOAT", {"default": 136.0, "min": 0.0, "max": 20000.0, "step": 0.001, "round": False}), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="SetFirstSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Float.Input("sigma", default=136.0, min=0.0, max=20000.0, step=0.001, round=False), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "set_first_sigma" - - def set_first_sigma(self, sigmas, sigma): + @classmethod + def execute(cls, sigmas, sigma) -> io.NodeOutput: sigmas = sigmas.clone() sigmas[0] = sigma - return (sigmas, ) + return io.NodeOutput(sigmas) -class 
ExtendIntermediateSigmas: + set_first_sigma = execute + +class ExtendIntermediateSigmas(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sigmas": ("SIGMAS", ), - "steps": ("INT", {"default": 2, "min": 1, "max": 100}), - "start_at_sigma": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 20000.0, "step": 0.01, "round": False}), - "end_at_sigma": ("FLOAT", {"default": 12.0, "min": 0.0, "max": 20000.0, "step": 0.01, "round": False}), - "spacing": (['linear', 'cosine', 'sine'],), - } - } - RETURN_TYPES = ("SIGMAS",) - CATEGORY = "sampling/custom_sampling/sigmas" + def define_schema(cls): + return io.Schema( + node_id="ExtendIntermediateSigmas", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Sigmas.Input("sigmas"), + io.Int.Input("steps", default=2, min=1, max=100), + io.Float.Input("start_at_sigma", default=-1.0, min=-1.0, max=20000.0, step=0.01, round=False), + io.Float.Input("end_at_sigma", default=12.0, min=0.0, max=20000.0, step=0.01, round=False), + io.Combo.Input("spacing", options=['linear', 'cosine', 'sine']), + ], + outputs=[io.Sigmas.Output()] + ) - FUNCTION = "extend" - - def extend(self, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str): + @classmethod + def execute(cls, sigmas: torch.Tensor, steps: int, start_at_sigma: float, end_at_sigma: float, spacing: str) -> io.NodeOutput: if start_at_sigma < 0: start_at_sigma = float("inf") @@ -304,27 +341,27 @@ class ExtendIntermediateSigmas: extended_sigmas = torch.FloatTensor(extended_sigmas) - return (extended_sigmas,) + return io.NodeOutput(extended_sigmas) + + extend = execute -class SamplingPercentToSigma: +class SamplingPercentToSigma(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "model": (IO.MODEL, {}), - "sampling_percent": (IO.FLOAT, {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.0001}), - "return_actual_sigma": (IO.BOOLEAN, {"default": False, "tooltip": "Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplingPercentToSigma", + category="sampling/custom_sampling/sigmas", + inputs=[ + io.Model.Input("model"), + io.Float.Input("sampling_percent", default=0.0, min=0.0, max=1.0, step=0.0001), + io.Boolean.Input("return_actual_sigma", default=False, tooltip="Return the actual sigma value instead of the value used for interval checks.\nThis only affects results at 0.0 and 1.0."), + ], + outputs=[io.Float.Output(display_name="sigma_value")] + ) - RETURN_TYPES = (IO.FLOAT,) - RETURN_NAMES = ("sigma_value",) - CATEGORY = "sampling/custom_sampling/sigmas" - - FUNCTION = "get_sigma" - - def get_sigma(self, model, sampling_percent, return_actual_sigma): + @classmethod + def execute(cls, model, sampling_percent, return_actual_sigma) -> io.NodeOutput: model_sampling = model.get_model_object("model_sampling") sigma_val = model_sampling.percent_to_sigma(sampling_percent) if return_actual_sigma: @@ -332,212 +369,234 @@ class SamplingPercentToSigma: sigma_val = model_sampling.sigma_max.item() elif sampling_percent == 1.0: sigma_val = model_sampling.sigma_min.item() - return (sigma_val,) + return io.NodeOutput(sigma_val) + + get_sigma = execute -class KSamplerSelect: +class KSamplerSelect(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"sampler_name": (comfy.sampler_names.SAMPLER_NAMES,), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = 
"sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="KSamplerSelect", + category="sampling/custom_sampling/samplers", + inputs=[io.Combo.Input("sampler_name", options=comfy.samplers.SAMPLER_NAMES)], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, sampler_name): - sampler = samplers.sampler_object(sampler_name) - return (sampler, ) - -class SamplerDPMPP_3M_SDE: @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def execute(cls, sampler_name) -> io.NodeOutput: + sampler = comfy.samplers.sampler_object(sampler_name) + return io.NodeOutput(sampler) - FUNCTION = "get_sampler" + get_sampler = execute - def get_sampler(self, eta, s_noise, noise_device): +class SamplerDPMPP_3M_SDE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_3M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) + + @classmethod + def execute(cls, eta, s_noise, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_3m_sde" else: sampler_name = "dpmpp_3m_sde_gpu" sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMPP_2M_SDE: + get_sampler = execute + +class SamplerDPMPP_2M_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"solver_type": (['midpoint', 'heun'], ), - "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2M_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=['midpoint', 'heun']), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, solver_type, eta, s_noise, noise_device): + @classmethod + def execute(cls, solver_type, eta, s_noise, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_2m_sde" else: sampler_name = "dpmpp_2m_sde_gpu" - sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) - return (sampler, ) + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerDPMPP_SDE: +class SamplerDPMPP_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, 
"min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "r": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "noise_device": (['gpu', 'cpu'], ), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("r", default=0.5, min=0.0, max=100.0, step=0.01, round=False), + io.Combo.Input("noise_device", options=['gpu', 'cpu']), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise, r, noise_device): + @classmethod + def execute(cls, eta, s_noise, r, noise_device) -> io.NodeOutput: if noise_device == 'cpu': sampler_name = "dpmpp_sde" else: sampler_name = "dpmpp_sde_gpu" - sampler = samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) - return (sampler, ) + sampler = comfy.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) + return io.NodeOutput(sampler) -class SamplerDPMPP_2S_Ancestral: + get_sampler = execute + +class SamplerDPMPP_2S_Ancestral(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMPP_2S_Ancestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("dpmpp_2s_ancestral", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerEulerAncestral: + get_sampler = execute + +class SamplerEulerAncestral(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestral", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("euler_ancestral", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerEulerAncestralCFGPP: + get_sampler = execute + +class 
SamplerEulerAncestralCFGPP(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "eta": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step":0.01, "round": False}), - }} - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerEulerAncestralCFGPP", + display_name="SamplerEulerAncestralCFG++", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Float.Input("eta", default=1.0, min=0.0, max=1.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, eta, s_noise): + @classmethod + def execute(cls, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler( "euler_ancestral_cfg_pp", {"eta": eta, "s_noise": s_noise}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerLMS: + get_sampler = execute + +class SamplerLMS(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"order": ("INT", {"default": 4, "min": 1, "max": 100}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerLMS", + category="sampling/custom_sampling/samplers", + inputs=[io.Int.Input("order", default=4, min=1, max=100)], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, order): + @classmethod + def execute(cls, order) -> io.NodeOutput: sampler = comfy.samplers.ksampler("lms", {"order": order}) - return (sampler, ) + return io.NodeOutput(sampler) -class SamplerDPMAdaptative: + get_sampler = execute + +class SamplerDPMAdaptative(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"order": ("INT", {"default": 3, "min": 2, "max": 3}), - "rtol": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "atol": ("FLOAT", {"default": 0.0078, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "h_init": ("FLOAT", {"default": 0.05, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "pcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "icoeff": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "dcoeff": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "accept_safety": ("FLOAT", {"default": 0.81, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "eta": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - "s_noise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 100.0, "step":0.01, "round": False}), - } - } - RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling/samplers" + def define_schema(cls): + return io.Schema( + node_id="SamplerDPMAdaptative", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Int.Input("order", default=3, min=2, max=3), + io.Float.Input("rtol", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("atol", default=0.0078, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("h_init", default=0.05, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("pcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("icoeff", default=1.0, min=0.0, max=100.0, step=0.01, 
round=False), + io.Float.Input("dcoeff", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("accept_safety", default=0.81, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("eta", default=0.0, min=0.0, max=100.0, step=0.01, round=False), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - FUNCTION = "get_sampler" - - def get_sampler(self, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise): + @classmethod + def execute(cls, order, rtol, atol, h_init, pcoeff, icoeff, dcoeff, accept_safety, eta, s_noise) -> io.NodeOutput: sampler = comfy.samplers.ksampler("dpm_adaptive", {"order": order, "rtol": rtol, "atol": atol, "h_init": h_init, "pcoeff": pcoeff, - "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, - "s_noise":s_noise }) - return (sampler, ) + "icoeff": icoeff, "dcoeff": dcoeff, "accept_safety": accept_safety, "eta": eta, + "s_noise":s_noise }) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerER_SDE(ComfyNodeABC): +class SamplerER_SDE(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "solver_type": (IO.COMBO, {"options": ["ER-SDE", "Reverse-time SDE", "ODE"]}), - "max_stage": (IO.INT, {"default": 3, "min": 1, "max": 3}), - "eta": ( - IO.FLOAT, - {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False, "tooltip": "Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. This setting doesn't apply to ER-SDE solver type."}, - ), - "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerER_SDE", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Combo.Input("solver_type", options=["ER-SDE", "Reverse-time SDE", "ODE"]), + io.Int.Input("max_stage", default=3, min=1, max=3), + io.Float.Input("eta", default=1.0, min=0.0, max=100.0, step=0.01, round=False, tooltip="Stochastic strength of reverse-time SDE.\nWhen eta=0, it reduces to deterministic ODE. 
This setting doesn't apply to ER-SDE solver type."), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + ], + outputs=[io.Sampler.Output()] + ) - RETURN_TYPES = (IO.SAMPLER,) - CATEGORY = "sampling/custom_sampling/samplers" - - FUNCTION = "get_sampler" - - def get_sampler(self, solver_type, max_stage, eta, s_noise): + @classmethod + def execute(cls, solver_type, max_stage, eta, s_noise) -> io.NodeOutput: if solver_type == "ODE" or (solver_type == "Reverse-time SDE" and eta == 0): eta = 0 s_noise = 0 @@ -553,32 +612,33 @@ class SamplerER_SDE(ComfyNodeABC): sampler_name = "er_sde" sampler = comfy.samplers.ksampler(sampler_name, {"s_noise": s_noise, "noise_scaler": noise_scaler, "max_stage": max_stage}) - return (sampler,) + return io.NodeOutput(sampler) + + get_sampler = execute -class SamplerSASolver(ComfyNodeABC): +class SamplerSASolver(io.ComfyNode): @classmethod - def INPUT_TYPES(cls) -> InputTypeDict: - return { - "required": { - "model": (IO.MODEL, {}), - "eta": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01, "round": False},), - "sde_start_percent": (IO.FLOAT, {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.001},), - "sde_end_percent": (IO.FLOAT, {"default": 0.8, "min": 0.0, "max": 1.0, "step": 0.001},), - "s_noise": (IO.FLOAT, {"default": 1.0, "min": 0.0, "max": 100.0, "step": 0.01, "round": False},), - "predictor_order": (IO.INT, {"default": 3, "min": 1, "max": 6}), - "corrector_order": (IO.INT, {"default": 4, "min": 0, "max": 6}), - "use_pece": (IO.BOOLEAN, {}), - "simple_order_2": (IO.BOOLEAN, {}), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerSASolver", + category="sampling/custom_sampling/samplers", + inputs=[ + io.Model.Input("model"), + io.Float.Input("eta", default=1.0, min=0.0, max=10.0, step=0.01, round=False), + io.Float.Input("sde_start_percent", default=0.2, min=0.0, max=1.0, step=0.001), + io.Float.Input("sde_end_percent", default=0.8, min=0.0, max=1.0, step=0.001), + io.Float.Input("s_noise", default=1.0, min=0.0, max=100.0, step=0.01, round=False), + io.Int.Input("predictor_order", default=3, min=1, max=6), + io.Int.Input("corrector_order", default=4, min=0, max=6), + io.Boolean.Input("use_pece"), + io.Boolean.Input("simple_order_2"), + ], + outputs=[io.Sampler.Output()] + ) - RETURN_TYPES = (IO.SAMPLER,) - CATEGORY = "sampling/custom_sampling/samplers" - - FUNCTION = "get_sampler" - - def get_sampler(self, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2): + @classmethod + def execute(cls, model, eta, sde_start_percent, sde_end_percent, s_noise, predictor_order, corrector_order, use_pece, simple_order_2) -> io.NodeOutput: model_sampling = model.get_model_object("model_sampling") start_sigma = model_sampling.percent_to_sigma(sde_start_percent) end_sigma = model_sampling.percent_to_sigma(sde_end_percent) @@ -596,7 +656,9 @@ class SamplerSASolver(ComfyNodeABC): "simple_order_2": simple_order_2, }, ) - return (sampler,) + return io.NodeOutput(sampler) + + get_sampler = execute class Noise_EmptyNoise: @@ -617,30 +679,31 @@ class Noise_RandomNoise: batch_inds = input_latent["batch_index"] if "batch_index" in input_latent else None return comfy.sample.prepare_noise(latent_image, self.seed, batch_inds) -class SamplerCustom: +class SamplerCustom(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "add_noise": ("BOOLEAN", {"default": True}), - "noise_seed": Seed64, - "cfg": ("FLOAT", 
{"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "sampler": ("SAMPLER", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="SamplerCustom", + category="sampling/custom_sampling", + inputs=[ + io.Model.Input("model"), + io.Boolean.Input("add_noise", default=True), + io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) - RETURN_TYPES = ("LATENT","LATENT") - RETURN_NAMES = ("output", "denoised_output") - - FUNCTION = "sample" - - CATEGORY = "sampling/custom_sampling" - - def sample(self, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image): + @classmethod + def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler, sigmas, latent_image) -> io.NodeOutput: latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -660,7 +723,7 @@ class SamplerCustom: callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) disable_pbar = not current_execution_context().server.receive_all_progress_notifications - samples = sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) + samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) out = latent.copy() out["samples"] = samples @@ -669,52 +732,58 @@ class SamplerCustom: out_denoised["samples"] = model.model.process_latent_out(x0_output["x0"].cpu()) else: out_denoised = out - return (out, out_denoised) + return io.NodeOutput(out, out_denoised) + + sample = execute class Guider_Basic(comfy.samplers.CFGGuider): def set_conds(self, positive): self.inner_set_conds({"positive": positive}) -class BasicGuider: +class BasicGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "conditioning": ("CONDITIONING", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="BasicGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("conditioning"), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, conditioning): + @classmethod + def execute(cls, model, conditioning) -> io.NodeOutput: guider = Guider_Basic(model) guider.set_conds(conditioning) - return (guider,) + return io.NodeOutput(guider) -class CFGGuider: + get_guider = execute + +class CFGGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "positive": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - } - } + def define_schema(cls): + return io.Schema( + node_id="CFGGuider", + 
category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, positive, negative, cfg): + @classmethod + def execute(cls, model, positive, negative, cfg) -> io.NodeOutput: guider = comfy.samplers.CFGGuider(model) guider.set_conds(positive, negative) guider.set_cfg(cfg) - return (guider,) + return io.NodeOutput(guider) + + get_guider = execute class Guider_DualCFG(comfy.samplers.CFGGuider): def set_cfg(self, cfg1, cfg2, nested=False): @@ -745,79 +814,88 @@ class Guider_DualCFG(comfy.samplers.CFGGuider): out = comfy.samplers.calc_cond_batch(self.inner_model, [negative_cond, middle_cond, positive_cond], x, timestep, model_options) return comfy.samplers.cfg_function(self.inner_model, out[1], out[0], self.cfg2, x, timestep, model_options=model_options, cond=middle_cond, uncond=negative_cond) + (out[2] - out[1]) * self.cfg1 -class DualCFGGuider: +class DualCFGGuider(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "cond1": ("CONDITIONING", ), - "cond2": ("CONDITIONING", ), - "negative": ("CONDITIONING", ), - "cfg_conds": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - "cfg_cond2_negative": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), - }, "optional": { - "style": (["regular", "nested"], {"default": "regular"}), - } - } + def define_schema(cls): + return io.Schema( + node_id="DualCFGGuider", + category="sampling/custom_sampling/guiders", + inputs=[ + io.Model.Input("model"), + io.Conditioning.Input("cond1"), + io.Conditioning.Input("cond2"), + io.Conditioning.Input("negative"), + io.Float.Input("cfg_conds", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Float.Input("cfg_cond2_negative", default=8.0, min=0.0, max=100.0, step=0.1, round=0.01), + io.Combo.Input("style", options=["regular", "nested"]), + ], + outputs=[io.Guider.Output()] + ) - RETURN_TYPES = ("GUIDER",) - - FUNCTION = "get_guider" - CATEGORY = "sampling/custom_sampling/guiders" - - def get_guider(self, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style="regular"): + @classmethod + def execute(cls, model, cond1, cond2, negative, cfg_conds, cfg_cond2_negative, style) -> io.NodeOutput: guider = Guider_DualCFG(model) guider.set_conds(cond1, cond2, negative) guider.set_cfg(cfg_conds, cfg_cond2_negative, nested=(style == "nested")) - return (guider,) + return io.NodeOutput(guider) -class DisableNoise: + get_guider = execute + +class DisableNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required":{ - } - } + def define_schema(cls): + return io.Schema( + node_id="DisableNoise", + category="sampling/custom_sampling/noise", + inputs=[], + outputs=[io.Noise.Output()] + ) - RETURN_TYPES = ("NOISE",) - FUNCTION = "get_noise" - CATEGORY = "sampling/custom_sampling/noise" - - def get_noise(self): - return (Noise_EmptyNoise(),) - - -class RandomNoise(DisableNoise): @classmethod - def INPUT_TYPES(s): - return {"required":{ - "noise_seed": Seed64, - } - } + def execute(cls) -> io.NodeOutput: + return io.NodeOutput(Noise_EmptyNoise()) - def get_noise(self, noise_seed): - return (Noise_RandomNoise(noise_seed),) + get_noise = execute 
-class SamplerCustomAdvanced: +class RandomNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"noise": ("NOISE", ), - "guider": ("GUIDER", ), - "sampler": ("SAMPLER", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="RandomNoise", + category="sampling/custom_sampling/noise", + inputs=[io.Int.Input("noise_seed", default=0, min=0, max=0xffffffffffffffff, control_after_generate=True)], + outputs=[io.Noise.Output()] + ) - RETURN_TYPES = ("LATENT","LATENT") - RETURN_NAMES = ("output", "denoised_output") + @classmethod + def execute(cls, noise_seed) -> io.NodeOutput: + return io.NodeOutput(Noise_RandomNoise(noise_seed)) - FUNCTION = "sample" + get_noise = execute - CATEGORY = "sampling/custom_sampling" - def sample(self, noise, guider: comfy.samplers.CFGGuider, sampler: KSAMPLER, sigmas, latent_image): +class SamplerCustomAdvanced(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SamplerCustomAdvanced", + category="sampling/custom_sampling", + inputs=[ + io.Noise.Input("noise"), + io.Guider.Input("guider"), + io.Sampler.Input("sampler"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(display_name="output"), + io.Latent.Output(display_name="denoised_output"), + ] + ) + + @classmethod + def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput: latent = latent_image latent_image = latent["samples"] latent = latent.copy() @@ -842,28 +920,32 @@ class SamplerCustomAdvanced: out_denoised["samples"] = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu()) else: out_denoised = out - return (out, out_denoised) + return io.NodeOutput(out, out_denoised) -class AddNoise: + sample = execute + +class AddNoise(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": - {"model": ("MODEL",), - "noise": ("NOISE", ), - "sigmas": ("SIGMAS", ), - "latent_image": ("LATENT", ), - } - } + def define_schema(cls): + return io.Schema( + node_id="AddNoise", + category="_for_testing/custom_sampling/noise", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Noise.Input("noise"), + io.Sigmas.Input("sigmas"), + io.Latent.Input("latent_image"), + ], + outputs=[ + io.Latent.Output(), + ] + ) - RETURN_TYPES = ("LATENT",) - - FUNCTION = "add_noise" - - CATEGORY = "_for_testing/custom_sampling/noise" - - def add_noise(self, model, noise, sigmas, latent_image): + @classmethod + def execute(cls, model, noise, sigmas, latent_image) -> io.NodeOutput: if len(sigmas) == 0: - return latent_image + return io.NodeOutput(latent_image) latent = latent_image latent_image = latent["samples"] @@ -887,46 +969,50 @@ class AddNoise: out = latent.copy() out["samples"] = noisy - return (out,) + return io.NodeOutput(out) + + add_noise = execute -NODE_CLASS_MAPPINGS = { - "SamplerCustom": SamplerCustom, - "BasicScheduler": BasicScheduler, - "KarrasScheduler": KarrasScheduler, - "ExponentialScheduler": ExponentialScheduler, - "PolyexponentialScheduler": PolyexponentialScheduler, - "LaplaceScheduler": LaplaceScheduler, - "VPScheduler": VPScheduler, - "BetaSamplingScheduler": BetaSamplingScheduler, - "SDTurboScheduler": SDTurboScheduler, - "KSamplerSelect": KSamplerSelect, - "SamplerEulerAncestral": SamplerEulerAncestral, - "SamplerEulerAncestralCFGPP": SamplerEulerAncestralCFGPP, - "SamplerLMS": SamplerLMS, - "SamplerDPMPP_3M_SDE": SamplerDPMPP_3M_SDE, - "SamplerDPMPP_2M_SDE": 
SamplerDPMPP_2M_SDE, - "SamplerDPMPP_SDE": SamplerDPMPP_SDE, - "SamplerDPMPP_2S_Ancestral": SamplerDPMPP_2S_Ancestral, - "SamplerDPMAdaptative": SamplerDPMAdaptative, - "SamplerER_SDE": SamplerER_SDE, - "SamplerSASolver": SamplerSASolver, - "SplitSigmas": SplitSigmas, - "SplitSigmasDenoise": SplitSigmasDenoise, - "FlipSigmas": FlipSigmas, - "SetFirstSigma": SetFirstSigma, - "ExtendIntermediateSigmas": ExtendIntermediateSigmas, - "SamplingPercentToSigma": SamplingPercentToSigma, +class CustomSamplersExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + SamplerCustom, + BasicScheduler, + KarrasScheduler, + ExponentialScheduler, + PolyexponentialScheduler, + LaplaceScheduler, + VPScheduler, + BetaSamplingScheduler, + SDTurboScheduler, + KSamplerSelect, + SamplerEulerAncestral, + SamplerEulerAncestralCFGPP, + SamplerLMS, + SamplerDPMPP_3M_SDE, + SamplerDPMPP_2M_SDE, + SamplerDPMPP_SDE, + SamplerDPMPP_2S_Ancestral, + SamplerDPMAdaptative, + SamplerER_SDE, + SamplerSASolver, + SplitSigmas, + SplitSigmasDenoise, + FlipSigmas, + SetFirstSigma, + ExtendIntermediateSigmas, + SamplingPercentToSigma, + CFGGuider, + DualCFGGuider, + BasicGuider, + RandomNoise, + DisableNoise, + AddNoise, + SamplerCustomAdvanced, + ] - "CFGGuider": CFGGuider, - "DualCFGGuider": DualCFGGuider, - "BasicGuider": BasicGuider, - "RandomNoise": RandomNoise, - "DisableNoise": DisableNoise, - "AddNoise": AddNoise, - "SamplerCustomAdvanced": SamplerCustomAdvanced, -} -NODE_DISPLAY_NAME_MAPPINGS = { - "SamplerEulerAncestralCFGPP": "SamplerEulerAncestralCFG++", -} +async def comfy_entrypoint() -> CustomSamplersExtension: + return CustomSamplersExtension() diff --git a/comfy_extras/nodes/nodes_easycache.py b/comfy_extras/nodes/nodes_easycache.py index 4633b479d..2d0236b04 100644 --- a/comfy_extras/nodes/nodes_easycache.py +++ b/comfy_extras/nodes/nodes_easycache.py @@ -12,13 +12,13 @@ logger = logging.getLogger(__name__) def easycache_forward_wrapper(executor, *args, **kwargs): # get values from args - x: torch.Tensor = args[0] transformer_options: dict[str] = args[-1] if not isinstance(transformer_options, dict): transformer_options = kwargs.get("transformer_options") if not transformer_options: transformer_options = args[-2] easycache: EasyCacheHolder = transformer_options["easycache"] + x: torch.Tensor = args[0][:, :easycache.output_channels] sigmas = transformer_options["sigmas"] uuids = transformer_options["uuids"] if sigmas is not None and easycache.is_past_end_timestep(sigmas): @@ -83,13 +83,13 @@ def easycache_forward_wrapper(executor, *args, **kwargs): def lazycache_predict_noise_wrapper(executor, *args, **kwargs): # get values from args - x: torch.Tensor = args[0] timestep: float = args[1] model_options: dict[str] = args[2] easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"] if easycache.is_past_end_timestep(timestep): return executor(*args, **kwargs) # prepare next x_prev + x: torch.Tensor = args[0][:, :easycache.output_channels] next_x_prev = x input_change = None do_easycache = easycache.should_do_easycache(timestep) @@ -174,7 +174,7 @@ def easycache_sample_wrapper(executor, *args, **kwargs): class EasyCacheHolder: - def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False): + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, 
output_channels: int=None): self.name = "EasyCache" self.reuse_threshold = reuse_threshold self.start_percent = start_percent @@ -203,6 +203,7 @@ class EasyCacheHolder: self.allow_mismatch = True self.cut_from_start = True self.state_metadata = None + self.output_channels = output_channels def is_past_end_timestep(self, timestep: float) -> bool: return not (timestep[0] > self.end_t).item() @@ -265,7 +266,7 @@ class EasyCacheHolder: else: slicing.append(slice(None)) batch_slice = batch_slice + slicing - x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device) + x[tuple(batch_slice)] += self.uuid_cache_diffs[uuid].to(x.device) return x def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]): @@ -284,7 +285,7 @@ class EasyCacheHolder: else: slicing.append(slice(None)) skip_dim = False - x = x[slicing] + x = x[tuple(slicing)] diff = output - x batch_offset = diff.shape[0] // len(uuids) for i, uuid in enumerate(uuids): @@ -324,7 +325,7 @@ class EasyCacheHolder: return self def clone(self): - return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose) + return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, output_channels=self.output_channels) class EasyCacheNode(io.ComfyNode): @@ -351,7 +352,7 @@ class EasyCacheNode(io.ComfyNode): @classmethod def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: model = model.clone() - model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose) + model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper) @@ -359,7 +360,7 @@ class EasyCacheNode(io.ComfyNode): class LazyCacheHolder: - def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False): + def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False, output_channels: int=None): self.name = "LazyCache" self.reuse_threshold = reuse_threshold self.start_percent = start_percent @@ -383,6 +384,7 @@ class LazyCacheHolder: self.approx_output_change_rates = [] self.total_steps_skipped = 0 self.state_metadata = None + self.output_channels = output_channels def has_cache_diff(self) -> bool: return self.cache_diff is not None @@ -457,7 +459,7 @@ class LazyCacheHolder: return self def clone(self): - return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose) + return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose, 
output_channels=self.output_channels) class LazyCacheNode(io.ComfyNode): @classmethod @@ -483,7 +485,7 @@ class LazyCacheNode(io.ComfyNode): @classmethod def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput: model = model.clone() - model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose) + model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose, output_channels=model.model.latent_format.latent_channels) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper) model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper) return io.NodeOutput(model) diff --git a/comfy_extras/nodes/nodes_flux.py b/comfy_extras/nodes/nodes_flux.py index 78a491286..70555d371 100644 --- a/comfy_extras/nodes/nodes_flux.py +++ b/comfy_extras/nodes/nodes_flux.py @@ -4,7 +4,12 @@ from typing_extensions import override from comfy_api.latest import ComfyExtension, io from ..constants.resolutions import KONTEXT_RESOLUTIONS +from comfy.nodes.common import MAX_RESOLUTION +import comfy.model_management +import torch +import math +import nodes class CLIPTextEncodeFlux(io.ComfyNode): @classmethod @@ -32,6 +37,27 @@ class CLIPTextEncodeFlux(io.ComfyNode): encode = execute # TODO: remove +class EmptyFlux2LatentImage(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="EmptyFlux2LatentImage", + display_name="Empty Flux 2 Latent", + category="latent", + inputs=[ + io.Int.Input("width", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=1024, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("batch_size", default=1, min=1, max=4096), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, width, height, batch_size=1) -> io.NodeOutput: + latent = torch.zeros([batch_size, 128, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) class FluxGuidance(io.ComfyNode): @classmethod @@ -136,6 +162,58 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode): append = execute # TODO: remove +def generalized_time_snr_shift(t, mu: float, sigma: float): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float: + a1, b1 = 8.73809524e-05, 1.89833333 + a2, b2 = 0.00016927, 0.45666666 + + if image_seq_len > 4300: + mu = a2 * image_seq_len + b2 + return float(mu) + + m_200 = a2 * image_seq_len + b2 + m_10 = a1 * image_seq_len + b1 + + a = (m_200 - m_10) / 190.0 + b = m_200 - 200.0 * a + mu = a * num_steps + b + + return float(mu) + + +def get_schedule(num_steps: int, image_seq_len: int) -> list[float]: + mu = compute_empirical_mu(image_seq_len, num_steps) + timesteps = torch.linspace(1, 0, num_steps + 1) + timesteps = generalized_time_snr_shift(timesteps, mu, 1.0) + return timesteps + + +class Flux2Scheduler(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Flux2Scheduler", + category="sampling/custom_sampling/schedulers", + inputs=[ + io.Int.Input("steps", default=20, min=1, max=4096), + 
io.Int.Input("width", default=1024, min=16, max=MAX_RESOLUTION, step=1), + io.Int.Input("height", default=1024, min=16, max=MAX_RESOLUTION, step=1), + ], + outputs=[ + io.Sigmas.Output(), + ], + ) + + @classmethod + def execute(cls, steps, width, height) -> io.NodeOutput: + seq_len = (width * height / (16 * 16)) + sigmas = get_schedule(steps, round(seq_len)) + return io.NodeOutput(sigmas) + + class FluxExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: @@ -145,6 +223,8 @@ class FluxExtension(ComfyExtension): FluxDisableGuidance, FluxKontextImageScale, FluxKontextMultiReferenceLatentMethod, + EmptyFlux2LatentImage, + Flux2Scheduler, ] diff --git a/comfy_extras/nodes/nodes_freelunch.py b/comfy_extras/nodes/nodes_freelunch.py index de86c2c20..71a27e7d7 100644 --- a/comfy_extras/nodes/nodes_freelunch.py +++ b/comfy_extras/nodes/nodes_freelunch.py @@ -26,6 +26,8 @@ SOFTWARE. import torch import logging +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO def Fourier_filter(x, threshold, scale): # FFT @@ -46,21 +48,26 @@ def Fourier_filter(x, threshold, scale): return x_filtered.to(x.dtype) -class FreeU: +class FreeU(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "b1": ("FLOAT", {"default": 1.1, "min": 0.0, "max": 10.0, "step": 0.01}), - "b2": ("FLOAT", {"default": 1.2, "min": 0.0, "max": 10.0, "step": 0.01}), - "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), - "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return IO.Schema( + node_id="FreeU", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.1, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b2", default=1.2, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "model_patches/unet" - - def patch(self, model, b1, b2, s1, s2): + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} on_cpu_devices = {} @@ -83,23 +90,31 @@ class FreeU: m = model.clone() m.set_model_output_block_patch(output_block_patch) - return (m, ) + return IO.NodeOutput(m) -class FreeU_V2: + patch = execute # TODO: remove + + +class FreeU_V2(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "b1": ("FLOAT", {"default": 1.3, "min": 0.0, "max": 10.0, "step": 0.01}), - "b2": ("FLOAT", {"default": 1.4, "min": 0.0, "max": 10.0, "step": 0.01}), - "s1": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 10.0, "step": 0.01}), - "s2": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "patch" + def define_schema(cls): + return IO.Schema( + node_id="FreeU_V2", + category="model_patches/unet", + inputs=[ + IO.Model.Input("model"), + IO.Float.Input("b1", default=1.3, min=0.0, max=10.0, step=0.01), + IO.Float.Input("b2", default=1.4, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s1", default=0.9, min=0.0, max=10.0, step=0.01), + IO.Float.Input("s2", default=0.2, min=0.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - 
CATEGORY = "model_patches/unet" - - def patch(self, model, b1, b2, s1, s2): + @classmethod + def execute(cls, model, b1, b2, s1, s2) -> IO.NodeOutput: model_channels = model.model.model_config.unet_config["model_channels"] scale_dict = {model_channels * 4: (b1, s1), model_channels * 2: (b2, s2)} on_cpu_devices = {} @@ -129,9 +144,19 @@ class FreeU_V2: m = model.clone() m.set_model_output_block_patch(output_block_patch) - return (m, ) + return IO.NodeOutput(m) -NODE_CLASS_MAPPINGS = { - "FreeU": FreeU, - "FreeU_V2": FreeU_V2, -} + patch = execute # TODO: remove + + +class FreelunchExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + FreeU, + FreeU_V2, + ] + + +async def comfy_entrypoint() -> FreelunchExtension: + return FreelunchExtension() diff --git a/comfy_extras/nodes/nodes_hunyuan.py b/comfy_extras/nodes/nodes_hunyuan.py index b8c801fa9..fb5bafde6 100644 --- a/comfy_extras/nodes/nodes_hunyuan.py +++ b/comfy_extras/nodes/nodes_hunyuan.py @@ -8,6 +8,8 @@ from comfy.nodes.common import MAX_RESOLUTION from comfy.nodes import base_nodes as nodes from comfy import node_helpers +from comfy.ldm.hunyuan_video.upsampler import HunyuanVideo15SRModel +from comfy.cmd import folder_paths class CLIPTextEncodeHunyuanDiT(io.ComfyNode): @classmethod @@ -40,6 +42,7 @@ class EmptyHunyuanLatentVideo(io.ComfyNode): def define_schema(cls): return io.Schema( node_id="EmptyHunyuanLatentVideo", + display_name="Empty HunyuanVideo 1.0 Latent", category="latent/video", inputs=[ io.Int.Input("width", default=848, min=16, max=MAX_RESOLUTION, step=16), @@ -60,6 +63,198 @@ class EmptyHunyuanLatentVideo(io.ComfyNode): generate = execute # TODO: remove +class EmptyHunyuanVideo15Latent(EmptyHunyuanLatentVideo): + @classmethod + def define_schema(cls): + schema = super().define_schema() + schema.node_id = "EmptyHunyuanVideo15Latent" + schema.display_name = "Empty HunyuanVideo 1.5 Latent" + return schema + + @classmethod + def execute(cls, width, height, length, batch_size=1) -> io.NodeOutput: + # Using scale factor of 16 instead of 8 + latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + return io.NodeOutput({"samples": latent}) + + +class HunyuanVideo15ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15ImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=848, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=33, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 32, ((length - 1) // 4) + 1, height // 16, width // 16], device=comfy.model_management.intermediate_device()) + + if start_image is not None: + start_image = 
comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + + encoded = vae.encode(start_image[:, :, :, :3]) + concat_latent_image = torch.zeros((latent.shape[0], 32, latent.shape[2], latent.shape[3], latent.shape[4]), device=comfy.model_management.intermediate_device()) + concat_latent_image[:, :, :encoded.shape[2], :, :] = encoded + + mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask}) + + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + out_latent = {} + out_latent["samples"] = latent + return io.NodeOutput(positive, negative, out_latent) + + +class HunyuanVideo15SuperResolution(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15SuperResolution", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae", optional=True), + io.Image.Input("start_image", optional=True), + io.ClipVisionOutput.Input("clip_vision_output", optional=True), + io.Latent.Input("latent"), + io.Float.Input("noise_augmentation", default=0.70, min=0.0, max=1.0, step=0.01), + + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, positive, negative, latent, noise_augmentation, vae=None, start_image=None, clip_vision_output=None) -> io.NodeOutput: + in_latent = latent["samples"] + in_channels = in_latent.shape[1] + cond_latent = torch.zeros([in_latent.shape[0], in_channels * 2 + 2, in_latent.shape[-3], in_latent.shape[-2], in_latent.shape[-1]], device=comfy.model_management.intermediate_device()) + cond_latent[:, in_channels + 1 : 2 * in_channels + 1] = in_latent + cond_latent[:, 2 * in_channels + 1] = 1 + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image.movedim(-1, 1), in_latent.shape[-1] * 16, in_latent.shape[-2] * 16, "bilinear", "center").movedim(1, -1) + encoded = vae.encode(start_image[:, :, :, :3]) + cond_latent[:, :in_channels, :encoded.shape[2], :, :] = encoded + cond_latent[:, in_channels + 1, 0] = 1 + + positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation}) + negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": cond_latent, "noise_augmentation": noise_augmentation}) + if clip_vision_output is not None: + positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output}) + negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output}) + + return io.NodeOutput(positive, negative, latent) + + +class LatentUpscaleModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LatentUpscaleModelLoader", + display_name="Load Latent Upscale Model", + 
category="loaders", + inputs=[ + io.Combo.Input("model_name", options=folder_paths.get_filename_list("latent_upscale_models")), + ], + outputs=[ + io.LatentUpscaleModel.Output(), + ], + ) + + @classmethod + def execute(cls, model_name) -> io.NodeOutput: + model_path = folder_paths.get_full_path_or_raise("latent_upscale_models", model_name) + sd = comfy.utils.load_torch_file(model_path, safe_load=True) + + if "blocks.0.block.0.conv.weight" in sd: + config = { + "in_channels": sd["in_conv.conv.weight"].shape[1], + "out_channels": sd["out_conv.conv.weight"].shape[0], + "hidden_channels": sd["in_conv.conv.weight"].shape[0], + "num_blocks": len([k for k in sd.keys() if k.startswith("blocks.") and k.endswith(".block.0.conv.weight")]), + "global_residual": False, + } + model_type = "720p" + elif "up.0.block.0.conv1.conv.weight" in sd: + sd = {key.replace("nin_shortcut", "nin_shortcut.conv", 1): value for key, value in sd.items()} + config = { + "z_channels": sd["conv_in.conv.weight"].shape[1], + "out_channels": sd["conv_out.conv.weight"].shape[0], + "block_out_channels": tuple(sd[f"up.{i}.block.0.conv1.conv.weight"].shape[0] for i in range(len([k for k in sd.keys() if k.startswith("up.") and k.endswith(".block.0.conv1.conv.weight")]))), + } + model_type = "1080p" + + model = HunyuanVideo15SRModel(model_type, config) + model.load_sd(sd) + + return io.NodeOutput(model) + + +class HunyuanVideo15LatentUpscaleWithModel(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="HunyuanVideo15LatentUpscaleWithModel", + display_name="Hunyuan Video 15 Latent Upscale With Model", + category="latent", + inputs=[ + io.LatentUpscaleModel.Input("model"), + io.Latent.Input("samples"), + io.Combo.Input("upscale_method", options=["nearest-exact", "bilinear", "area", "bicubic", "bislerp"], default="bilinear"), + io.Int.Input("width", default=1280, min=0, max=16384, step=8), + io.Int.Input("height", default=720, min=0, max=16384, step=8), + io.Combo.Input("crop", options=["disabled", "center"]), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, model, samples, upscale_method, width, height, crop) -> io.NodeOutput: + if width == 0 and height == 0: + return io.NodeOutput(samples) + else: + if width == 0: + height = max(64, height) + width = max(64, round(samples["samples"].shape[-1] * height / samples["samples"].shape[-2])) + elif height == 0: + width = max(64, width) + height = max(64, round(samples["samples"].shape[-2] * width / samples["samples"].shape[-1])) + else: + width = max(64, width) + height = max(64, height) + s = comfy.utils.common_upscale(samples["samples"], width // 16, height // 16, upscale_method, crop) + s = model.resample_latent(s) + return io.NodeOutput({"samples": s.cpu().float()}) + + PROMPT_TEMPLATE_ENCODE_VIDEO_I2V = ( "<|start_header_id|>system<|end_header_id|>\n\n\nDescribe the video by detailing the following aspects according to the reference image: " "1. The main content and theme of the video." 
@@ -216,6 +411,11 @@ class HunyuanExtension(ComfyExtension): CLIPTextEncodeHunyuanDiT, TextEncodeHunyuanVideo_ImageToVideo, EmptyHunyuanLatentVideo, + EmptyHunyuanVideo15Latent, + HunyuanVideo15ImageToVideo, + HunyuanVideo15SuperResolution, + HunyuanVideo15LatentUpscaleWithModel, + LatentUpscaleModelLoader, HunyuanImageToVideo, EmptyHunyuanImageLatent, HunyuanRefinerLatent, diff --git a/comfy_extras/nodes/nodes_hunyuan3d.py b/comfy_extras/nodes/nodes_hunyuan3d.py index dc30179de..12607996a 100644 --- a/comfy_extras/nodes/nodes_hunyuan3d.py +++ b/comfy_extras/nodes/nodes_hunyuan3d.py @@ -7,63 +7,79 @@ from comfy.ldm.modules.diffusionmodules.mmdit import get_1d_sincos_pos_embed_fro from comfy.cmd import folder_paths import comfy.model_management from comfy.cli_args import args +from typing_extensions import override +from comfy_api.latest import ComfyExtension, IO, Types +from comfy_api.latest._util import MESH, VOXEL # only for backward compatibility if someone import it from this file (will be removed later) # noqa -class EmptyLatentHunyuan3Dv2: + +class EmptyLatentHunyuan3Dv2(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "resolution": ("INT", {"default": 3072, "min": 1, "max": 8192}), - "batch_size": ("INT", {"default": 1, "min": 1, "max": 4096, "tooltip": "The number of latent images in the batch."}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="EmptyLatentHunyuan3Dv2", + category="latent/3d", + inputs=[ + IO.Int.Input("resolution", default=3072, min=1, max=8192), + IO.Int.Input("batch_size", default=1, min=1, max=4096, tooltip="The number of latent images in the batch."), + ], + outputs=[ + IO.Latent.Output(), + ] + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "generate" - - CATEGORY = "latent/3d" - - def generate(self, resolution, batch_size): + @classmethod + def execute(cls, resolution, batch_size) -> IO.NodeOutput: latent = torch.zeros([batch_size, 64, resolution], device=comfy.model_management.intermediate_device()) - return ({"samples": latent, "type": "hunyuan3dv2"}, ) + return IO.NodeOutput({"samples": latent, "type": "hunyuan3dv2"}) -class Hunyuan3Dv2Conditioning: + generate = execute # TODO: remove + + +class Hunyuan3Dv2Conditioning(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"clip_vision_output": ("CLIP_VISION_OUTPUT",), - }} + def define_schema(cls): + return IO.Schema( + node_id="Hunyuan3Dv2Conditioning", + category="conditioning/video_models", + inputs=[ + IO.ClipVisionOutput.Input("clip_vision_output"), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ] + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, clip_vision_output): + @classmethod + def execute(cls, clip_vision_output) -> IO.NodeOutput: embeds = clip_vision_output.last_hidden_state positive = [[embeds, {}]] negative = [[torch.zeros_like(embeds), {}]] - return (positive, negative) + return IO.NodeOutput(positive, negative) + + encode = execute # TODO: remove -class Hunyuan3Dv2ConditioningMultiView: +class Hunyuan3Dv2ConditioningMultiView(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {}, - "optional": {"front": ("CLIP_VISION_OUTPUT",), - "left": ("CLIP_VISION_OUTPUT",), - "back": ("CLIP_VISION_OUTPUT",), - "right": ("CLIP_VISION_OUTPUT",), }} + def define_schema(cls): + return IO.Schema( + 
node_id="Hunyuan3Dv2ConditioningMultiView", + category="conditioning/video_models", + inputs=[ + IO.ClipVisionOutput.Input("front", optional=True), + IO.ClipVisionOutput.Input("left", optional=True), + IO.ClipVisionOutput.Input("back", optional=True), + IO.ClipVisionOutput.Input("right", optional=True), + ], + outputs=[ + IO.Conditioning.Output(display_name="positive"), + IO.Conditioning.Output(display_name="negative"), + ] + ) - RETURN_TYPES = ("CONDITIONING", "CONDITIONING") - RETURN_NAMES = ("positive", "negative") - - FUNCTION = "encode" - - CATEGORY = "conditioning/video_models" - - def encode(self, front=None, left=None, back=None, right=None): + @classmethod + def execute(cls, front=None, left=None, back=None, right=None) -> IO.NodeOutput: all_embeds = [front, left, back, right] out = [] pos_embeds = None @@ -76,31 +92,37 @@ class Hunyuan3Dv2ConditioningMultiView: embeds = torch.cat(out, dim=1) positive = [[embeds, {}]] negative = [[torch.zeros_like(embeds), {}]] - return (positive, negative) + return IO.NodeOutput(positive, negative) + + encode = execute # TODO: remove -class VOXEL: - def __init__(self, data): - self.data = data - -class VAEDecodeHunyuan3D: +class VAEDecodeHunyuan3D(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"samples": ("LATENT", ), - "vae": ("VAE", ), - "num_chunks": ("INT", {"default": 8000, "min": 1000, "max": 500000}), - "octree_resolution": ("INT", {"default": 256, "min": 16, "max": 512}), - }} - RETURN_TYPES = ("VOXEL",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VAEDecodeHunyuan3D", + category="latent/3d", + inputs=[ + IO.Latent.Input("samples"), + IO.Vae.Input("vae"), + IO.Int.Input("num_chunks", default=8000, min=1000, max=500000), + IO.Int.Input("octree_resolution", default=256, min=16, max=512), + ], + outputs=[ + IO.Voxel.Output(), + ] + ) - CATEGORY = "latent/3d" - - def decode(self, vae, samples, num_chunks, octree_resolution): + @classmethod + def execute(cls, vae, samples, num_chunks, octree_resolution) -> IO.NodeOutput: if samples is None: - return None, - voxels = VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) - return (voxels, ) + return IO.NodeOutput(None) + voxels = Types.VOXEL(vae.decode(samples["samples"], vae_options={"num_chunks": num_chunks, "octree_resolution": octree_resolution})) + return IO.NodeOutput(voxels) + + decode = execute # TODO: remove + def voxel_to_mesh(voxels, threshold=0.5, device=None): if device is None: @@ -164,10 +186,10 @@ def voxel_to_mesh(voxels, threshold=0.5, device=None): padded_indices = neighbor_indices + 1 is_exposed = padded[ - padded_indices[:, 0], - padded_indices[:, 1], - padded_indices[:, 2] - ] == 0 + padded_indices[:, 0], + padded_indices[:, 1], + padded_indices[:, 2] + ] == 0 if not is_exposed.any(): continue @@ -211,6 +233,7 @@ def voxel_to_mesh(voxels, threshold=0.5, device=None): vertices = torch.fliplr(vertices) return vertices, faces + def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): if device is None: device = torch.device("cpu") @@ -288,7 +311,7 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): vertices = [] vertex_lookup = {} - vert_progress_mod = round(len(cell_vertices)/50) + vert_progress_mod = round(len(cell_vertices) / 50) for i, points in cell_vertices.items(): if not i % vert_progress_mod: @@ -332,15 +355,15 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): cross_products = [ 
torch.linalg.cross(pos_dirs[i].float(), pos_dirs[j].float()) - for i in range(3) for j in range(i+1, 3) + for i in range(3) for j in range(i + 1, 3) ] faces = [] all_keys = set(vertex_lookup.keys()) - face_progress_mod = round(len(active_cells)/38*3) + face_progress_mod = round(len(active_cells) / 38 * 3) - for pair_idx, (i, j) in enumerate([(0,1), (0,2), (1,2)]): + for pair_idx, (i, j) in enumerate([(0, 1), (0, 2), (1, 2)]): dir_i = pos_dirs[i] dir_j = pos_dirs[j] cross_product = cross_products[pair_idx] @@ -398,24 +421,24 @@ def voxel_to_mesh_surfnet(voxels, threshold=0.5, device=None): return final_vertices, faces -class MESH: - def __init__(self, vertices, faces): - self.vertices = vertices - self.faces = faces - -class VoxelToMeshBasic: +class VoxelToMeshBasic(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"voxel": ("VOXEL", ), - "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("MESH",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VoxelToMeshBasic", + category="3d", + inputs=[ + IO.Voxel.Input("voxel"), + IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01), + ], + outputs=[ + IO.Mesh.Output(), + ] + ) - CATEGORY = "3d" - - def decode(self, voxel, threshold): + @classmethod + def execute(cls, voxel, threshold) -> IO.NodeOutput: vertices = [] faces = [] for x in voxel.data: @@ -423,21 +446,29 @@ class VoxelToMeshBasic: vertices.append(v) faces.append(f) - return (MESH(torch.stack(vertices), torch.stack(faces)), ) + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) -class VoxelToMesh: + decode = execute # TODO: remove + + +class VoxelToMesh(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"voxel": ("VOXEL", ), - "algorithm": (["surface net", "basic"], ), - "threshold": ("FLOAT", {"default": 0.6, "min": -1.0, "max": 1.0, "step": 0.01}), - }} - RETURN_TYPES = ("MESH",) - FUNCTION = "decode" + def define_schema(cls): + return IO.Schema( + node_id="VoxelToMesh", + category="3d", + inputs=[ + IO.Voxel.Input("voxel"), + IO.Combo.Input("algorithm", options=["surface net", "basic"]), + IO.Float.Input("threshold", default=0.6, min=-1.0, max=1.0, step=0.01), + ], + outputs=[ + IO.Mesh.Output(), + ] + ) - CATEGORY = "3d" - - def decode(self, voxel, algorithm, threshold): + @classmethod + def execute(cls, voxel, algorithm, threshold) -> IO.NodeOutput: vertices = [] faces = [] @@ -453,7 +484,9 @@ class VoxelToMesh: vertices.append(v) faces.append(f) - return (MESH(torch.stack(vertices), torch.stack(faces)), ) + return IO.NodeOutput(Types.MESH(torch.stack(vertices), torch.stack(faces))) + + decode = execute # TODO: remove def save_glb(vertices, faces, filepath, metadata=None): @@ -585,31 +618,32 @@ def save_glb(vertices, faces, filepath, metadata=None): return filepath -class SaveGLB: +class SaveGLB(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": {"mesh": ("MESH", ), - "filename_prefix": ("STRING", {"default": "mesh/ComfyUI"}), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, } + def define_schema(cls): + return IO.Schema( + node_id="SaveGLB", + category="3d", + is_output_node=True, + inputs=[ + IO.Mesh.Input("mesh"), + IO.String.Input("filename_prefix", default="mesh/ComfyUI"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo] + ) - RETURN_TYPES = () - FUNCTION = "save" - - OUTPUT_NODE = True - - CATEGORY = "3d" - - def save(self, mesh, filename_prefix, prompt=None, 
extra_pnginfo=None): + @classmethod + def execute(cls, mesh, filename_prefix) -> IO.NodeOutput: full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(filename_prefix, folder_paths.get_output_directory()) results = [] metadata = {} if not args.disable_metadata: - if prompt is not None: - metadata["prompt"] = json.dumps(prompt) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata[x] = json.dumps(extra_pnginfo[x]) + if cls.hidden.prompt is not None: + metadata["prompt"] = json.dumps(cls.hidden.prompt) + if cls.hidden.extra_pnginfo is not None: + for x in cls.hidden.extra_pnginfo: + metadata[x] = json.dumps(cls.hidden.extra_pnginfo[x]) for i in range(mesh.vertices.shape[0]): f = f"{filename}_{counter:05}_.glb" @@ -620,15 +654,22 @@ class SaveGLB: "type": "output" }) counter += 1 - return {"ui": {"3d": results}} + return IO.NodeOutput(ui={"3d": results}) -NODE_CLASS_MAPPINGS = { - "EmptyLatentHunyuan3Dv2": EmptyLatentHunyuan3Dv2, - "Hunyuan3Dv2Conditioning": Hunyuan3Dv2Conditioning, - "Hunyuan3Dv2ConditioningMultiView": Hunyuan3Dv2ConditioningMultiView, - "VAEDecodeHunyuan3D": VAEDecodeHunyuan3D, - "VoxelToMeshBasic": VoxelToMeshBasic, - "VoxelToMesh": VoxelToMesh, - "SaveGLB": SaveGLB, -} +class Hunyuan3dExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + EmptyLatentHunyuan3Dv2, + Hunyuan3Dv2Conditioning, + Hunyuan3Dv2ConditioningMultiView, + VAEDecodeHunyuan3D, + VoxelToMeshBasic, + VoxelToMesh, + SaveGLB, + ] + + +async def comfy_entrypoint() -> Hunyuan3dExtension: + return Hunyuan3dExtension() diff --git a/comfy_extras/nodes/nodes_hypernetwork.py b/comfy_extras/nodes/nodes_hypernetwork.py index 8853d5b15..032da206f 100644 --- a/comfy_extras/nodes/nodes_hypernetwork.py +++ b/comfy_extras/nodes/nodes_hypernetwork.py @@ -2,6 +2,9 @@ from comfy import utils from comfy.cmd import folder_paths import torch import logging +from comfy_api.latest import IO, ComfyExtension +from typing_extensions import override + def load_hypernetwork_patch(path, strength): sd = utils.load_torch_file(path, safe_load=True) @@ -78,6 +81,7 @@ def load_hypernetwork_patch(path, strength): def __init__(self, hypernet, strength): self.hypernet = hypernet self.strength = strength + def __call__(self, q, k, v, extra_options): dim = k.shape[-1] if dim in self.hypernet: @@ -94,27 +98,43 @@ def load_hypernetwork_patch(path, strength): return hypernetwork_patch(out, strength) -class HypernetworkLoader: + +class HypernetworkLoader(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return {"required": { "model": ("MODEL",), - "hypernetwork_name": (folder_paths.get_filename_list("hypernetworks"),), - "strength": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01}), - }} - RETURN_TYPES = ("MODEL",) - FUNCTION = "load_hypernetwork" + def define_schema(cls): + return IO.Schema( + node_id="HypernetworkLoader", + category="loaders", + inputs=[ + IO.Model.Input("model"), + IO.Combo.Input("hypernetwork_name", options=folder_paths.get_filename_list("hypernetworks")), + IO.Float.Input("strength", default=1.0, min=-10.0, max=10.0, step=0.01), + ], + outputs=[ + IO.Model.Output(), + ], + ) - CATEGORY = "loaders" - - def load_hypernetwork(self, model, hypernetwork_name, strength): + @classmethod + def execute(cls, model, hypernetwork_name, strength) -> IO.NodeOutput: hypernetwork_path = folder_paths.get_full_path_or_raise("hypernetworks", hypernetwork_name) model_hypernetwork = model.clone() patch = 
load_hypernetwork_patch(hypernetwork_path, strength) if patch is not None: model_hypernetwork.set_model_attn1_patch(patch) model_hypernetwork.set_model_attn2_patch(patch) - return (model_hypernetwork,) + return IO.NodeOutput(model_hypernetwork) -NODE_CLASS_MAPPINGS = { - "HypernetworkLoader": HypernetworkLoader -} + load_hypernetwork = execute # TODO: remove + + +class HyperNetworkExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + HypernetworkLoader, + ] + + +async def comfy_entrypoint() -> HyperNetworkExtension: + return HyperNetworkExtension() diff --git a/comfy_extras/nodes/nodes_latent.py b/comfy_extras/nodes/nodes_latent.py index 773bdcd7f..db6bffb85 100644 --- a/comfy_extras/nodes/nodes_latent.py +++ b/comfy_extras/nodes/nodes_latent.py @@ -7,7 +7,7 @@ from comfy.nodes.package_typing import Seed, Seed64 from .nodes_post_processing import gaussian_kernel from typing_extensions import override from comfy_api.latest import ComfyExtension, io - +import logging def reshape_latent_to(target_shape, latent, repeat_batch=True): if latent.shape[1:] != target_shape[1:]: @@ -452,6 +452,42 @@ class LatentOperationSharpen(io.ComfyNode): return io.NodeOutput(sharpen) +class ReplaceVideoLatentFrames(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ReplaceVideoLatentFrames", + category="latent/batch", + inputs=[ + io.Latent.Input("destination", tooltip="The destination latent where frames will be replaced."), + io.Latent.Input("source", optional=True, tooltip="The source latent providing frames to insert into the destination latent. If not provided, the destination latent is returned unchanged."), + io.Int.Input("index", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1, tooltip="The starting latent frame index in the destination latent where the source latent frames will be placed. 
Negative values count from the end."), + ], + outputs=[ + io.Latent.Output(), + ], + ) + + @classmethod + def execute(cls, destination, index, source=None) -> io.NodeOutput: + if source is None: + return io.NodeOutput(destination) + dest_frames = destination["samples"].shape[2] + source_frames = source["samples"].shape[2] + if index < 0: + index = dest_frames + index + if index > dest_frames: + logging.warning(f"ReplaceVideoLatentFrames: Index {index} is out of bounds for destination latent frames {dest_frames}.") + return io.NodeOutput(destination) + if index + source_frames > dest_frames: + logging.warning(f"ReplaceVideoLatentFrames: Source latent frames {source_frames} do not fit within destination latent frames {dest_frames} at the specified index {index}.") + return io.NodeOutput(destination) + s = source.copy() + s_source = source["samples"] + s_destination = destination["samples"].clone() + s_destination[:, :, index:index + s_source.shape[2]] = s_source + s["samples"] = s_destination + return io.NodeOutput(s) class LatentExtension(ComfyExtension): @override @@ -470,6 +506,7 @@ class LatentExtension(ComfyExtension): LatentApplyOperationCFG, LatentOperationTonemapReinhard, LatentOperationSharpen, + ReplaceVideoLatentFrames ] diff --git a/comfy_extras/nodes/nodes_load_3d.py b/comfy_extras/nodes/nodes_load_3d.py index 5f85d8dc2..fffb502c0 100644 --- a/comfy_extras/nodes/nodes_load_3d.py +++ b/comfy_extras/nodes/nodes_load_3d.py @@ -2,8 +2,8 @@ import os from comfy.cmd import folder_paths from comfy.nodes import base_nodes as nodes -from comfy.comfy_types import IO -from comfy_api.input_impl import VideoFromFile +from typing_extensions import override +from comfy_api.latest import IO, ComfyExtension, InputImpl, UI from pathlib import Path @@ -12,9 +12,9 @@ def normalize_path(path): return path.replace('\\', '/') -class Load3D(): +class Load3D(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): + def define_schema(cls): input_dir = os.path.join(folder_paths.get_input_directory(), "3d") os.makedirs(input_dir, exist_ok=True) @@ -27,160 +27,85 @@ class Load3D(): for file_path in input_path.rglob("*") if file_path.suffix.lower() in {'.gltf', '.glb', '.obj', '.fbx', '.stl'} ] + return IO.Schema( + node_id="Load3D", + display_name="Load 3D & Animation", + category="3d", + is_experimental=True, + inputs=[ + IO.Combo.Input("model_file", options=sorted(files), upload=IO.UploadType.model), + IO.Load3D.Input("image"), + IO.Int.Input("width", default=1024, min=1, max=4096, step=1), + IO.Int.Input("height", default=1024, min=1, max=4096, step=1), + ], + outputs=[ + IO.Image.Output(display_name="image"), + IO.Mask.Output(display_name="mask"), + IO.String.Output(display_name="mesh_path"), + IO.Image.Output(display_name="normal"), + IO.Load3DCamera.Output(display_name="camera_info"), + IO.Video.Output(display_name="recording_video"), + ], + ) - return {"required": { - "model_file": (sorted(files), {"file_upload": True}), - "image": ("LOAD_3D", {}), - "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - }} - - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "lineart", "camera_info", "recording_video") - - FUNCTION = "process" - EXPERIMENTAL = True - - CATEGORY = "3d" - - def process(self, model_file, image, **kwargs): + @classmethod + def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput: image_path = 
folder_paths.get_annotated_filepath(image['image']) mask_path = folder_paths.get_annotated_filepath(image['mask']) normal_path = folder_paths.get_annotated_filepath(image['normal']) - lineart_path = folder_paths.get_annotated_filepath(image['lineart']) load_image_node = nodes.LoadImage() output_image, ignore_mask = load_image_node.load_image(image=image_path) ignore_image, output_mask = load_image_node.load_image(image=mask_path) normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - lineart_image, ignore_mask3 = load_image_node.load_image(image=lineart_path) video = None if image['recording'] != "": recording_video_path = folder_paths.get_annotated_filepath(image['recording']) - video = VideoFromFile(recording_video_path) + video = InputImpl.VideoFromFile(recording_video_path) - return output_image, output_mask, model_file, normal_image, lineart_image, image['camera_info'], video + return IO.NodeOutput(output_image, output_mask, model_file, normal_image, image['camera_info'], video) + + process = execute # TODO: remove -class Load3DAnimation(): +class Preview3D(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - input_dir = os.path.join(folder_paths.get_input_directory(), "3d") + def define_schema(cls): + return IO.Schema( + node_id="Preview3D", + display_name="Preview 3D & Animation", + category="3d", + is_experimental=True, + is_output_node=True, + inputs=[ + IO.String.Input("model_file", default="", multiline=False), + IO.Load3DCamera.Input("camera_info", optional=True), + IO.Image.Input("bg_image", optional=True), + ], + outputs=[], + ) - os.makedirs(input_dir, exist_ok=True) + @classmethod + def execute(cls, model_file, **kwargs) -> IO.NodeOutput: + camera_info = kwargs.get("camera_info", None) + bg_image = kwargs.get("bg_image", None) + return IO.NodeOutput(ui=UI.PreviewUI3D(model_file, camera_info, bg_image=bg_image)) - input_path = Path(input_dir) - base_path = Path(folder_paths.get_input_directory()) + process = execute # TODO: remove - files = [ - normalize_path(str(file_path.relative_to(base_path))) - for file_path in input_path.rglob("*") - if file_path.suffix.lower() in {'.gltf', '.glb', '.fbx'} + + +class Load3DExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + Load3D, + Preview3D, ] - return {"required": { - "model_file": (sorted(files), {"file_upload": True}), - "image": ("LOAD_3D_ANIMATION", {}), - "width": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - "height": ("INT", {"default": 1024, "min": 1, "max": 4096, "step": 1}), - }} - RETURN_TYPES = ("IMAGE", "MASK", "STRING", "IMAGE", "LOAD3D_CAMERA", IO.VIDEO) - RETURN_NAMES = ("image", "mask", "mesh_path", "normal", "camera_info", "recording_video") - - FUNCTION = "process" - EXPERIMENTAL = True - - CATEGORY = "3d" - - def process(self, model_file, image, **kwargs): - image_path = folder_paths.get_annotated_filepath(image['image']) - mask_path = folder_paths.get_annotated_filepath(image['mask']) - normal_path = folder_paths.get_annotated_filepath(image['normal']) - - load_image_node = nodes.LoadImage() - output_image, ignore_mask = load_image_node.load_image(image=image_path) - ignore_image, output_mask = load_image_node.load_image(image=mask_path) - normal_image, ignore_mask2 = load_image_node.load_image(image=normal_path) - - video = None - - if image['recording'] != "": - recording_video_path = folder_paths.get_annotated_filepath(image['recording']) - - video = VideoFromFile(recording_video_path) - - return output_image, 
output_mask, model_file, normal_image, image['camera_info'], video - - -class Preview3D(): - @classmethod - def INPUT_TYPES(s): - return {"required": { - "model_file": ("STRING", {"default": "", "multiline": False}), - }, - "optional": { - "camera_info": ("LOAD3D_CAMERA", {}) - }} - - OUTPUT_NODE = True - RETURN_TYPES = () - - CATEGORY = "3d" - - FUNCTION = "process" - EXPERIMENTAL = True - - def process(self, model_file, **kwargs): - camera_info = kwargs.get("camera_info", None) - - return { - "ui": { - "result": [model_file, camera_info] - } - } - -class Preview3DAnimation(): - @classmethod - def INPUT_TYPES(s): - return {"required": { - "model_file": ("STRING", {"default": "", "multiline": False}), - }, - "optional": { - "camera_info": ("LOAD3D_CAMERA", {}) - }} - - OUTPUT_NODE = True - RETURN_TYPES = () - - CATEGORY = "3d" - - FUNCTION = "process" - EXPERIMENTAL = True - - def process(self, model_file, **kwargs): - camera_info = kwargs.get("camera_info", None) - - return { - "ui": { - "result": [model_file, camera_info] - } - } - - -NODE_CLASS_MAPPINGS = { - "Load3D": Load3D, - "Load3DAnimation": Load3DAnimation, - "Preview3D": Preview3D, - "Preview3DAnimation": Preview3DAnimation -} - -NODE_DISPLAY_NAME_MAPPINGS = { - "Load3D": "Load 3D", - "Load3DAnimation": "Load 3D - Animation", - "Preview3D": "Preview 3D", - "Preview3DAnimation": "Preview 3D - Animation" -} +async def comfy_entrypoint() -> Load3DExtension: + return Load3DExtension() diff --git a/comfy_extras/nodes/nodes_mask.py b/comfy_extras/nodes/nodes_mask.py index 98d7bd120..787538a00 100644 --- a/comfy_extras/nodes/nodes_mask.py +++ b/comfy_extras/nodes/nodes_mask.py @@ -1,14 +1,12 @@ import numpy as np -import random import scipy.ndimage import torch +from typing_extensions import override from comfy import node_helpers from comfy import utils -from comfy.cmd import folder_paths -from comfy.component_model.tensor_types import MaskBatch, RGBImageBatch -from comfy.nodes.base_nodes import SaveImage -from comfy.nodes.common import MAX_RESOLUTION +from comfy.nodes import base_nodes as nodes +from comfy_api.latest import ComfyExtension, IO, UI def composite(destination, source, x, y, mask=None, multiplier=8, resize_source=False): @@ -49,212 +47,213 @@ def composite(destination, source, x, y, mask=None, multiplier=8, resize_source= return destination -class LatentCompositeMasked: +class LatentCompositeMasked(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "destination": ("LATENT",), - "source": ("LATENT",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), - "resize_source": ("BOOLEAN", {"default": False}), - }, - "optional": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="LatentCompositeMasked", + category="latent", + inputs=[ + IO.Latent.Input("destination"), + IO.Latent.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=8), + IO.Boolean.Input("resize_source", default=False), + IO.Mask.Input("mask", optional=True), + ], + outputs=[IO.Latent.Output()], + ) - RETURN_TYPES = ("LATENT",) - FUNCTION = "composite" - - CATEGORY = "latent" - - def composite(self, destination, source, x, y, resize_source, mask=None): + @classmethod + def execute(cls, destination, source, x, y, resize_source, mask=None) -> IO.NodeOutput: output = destination.copy() destination = 
destination["samples"].clone() source = source["samples"] output["samples"] = composite(destination, source, x, y, mask, 8, resize_source) - return (output,) + return IO.NodeOutput(output) + + composite = execute # TODO: remove -class ImageCompositeMasked: +class ImageCompositeMasked(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "destination": ("IMAGE",), - "source": ("IMAGE",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "resize_source": ("BOOLEAN", {"default": False}), - }, - "optional": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ImageCompositeMasked", + category="image", + inputs=[ + IO.Image.Input("destination"), + IO.Image.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Boolean.Input("resize_source", default=False), + IO.Mask.Input("mask", optional=True), + ], + outputs=[IO.Image.Output()], + ) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "composite" - - CATEGORY = "image" - - def composite(self, destination, source, x, y, resize_source, mask=None): + @classmethod + def execute(cls, destination, source, x, y, resize_source, mask=None) -> IO.NodeOutput: destination, source = node_helpers.image_alpha_fix(destination, source) destination = destination.clone().movedim(-1, 1) output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) - return (output,) + return IO.NodeOutput(output) + + composite = execute # TODO: remove -class MaskToImage: +class MaskToImage(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="MaskToImage", + display_name="Convert Mask to Image", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + ], + outputs=[IO.Image.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "mask_to_image" - - def mask_to_image(self, mask: MaskBatch) -> tuple[RGBImageBatch]: + @classmethod + def execute(cls, mask) -> IO.NodeOutput: result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - return (result,) + return IO.NodeOutput(result) + + mask_to_image = execute # TODO: remove -class ImageToMask: +class ImageToMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "channel": (["red", "green", "blue", "alpha"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ImageToMask", + display_name="Convert Image to Mask", + category="mask", + inputs=[ + IO.Image.Input("image"), + IO.Combo.Input("channel", options=["red", "green", "blue", "alpha"]), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, channel): + @classmethod + def execute(cls, image, channel) -> IO.NodeOutput: channels = ["red", "green", "blue", "alpha"] mask = image[:, :, :, channels.index(channel)] - return (mask,) + return IO.NodeOutput(mask) + + image_to_mask = execute # TODO: remove -class ImageColorToMask: +class ImageColorToMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "image": ("IMAGE",), - "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), - } - } + def 
define_schema(cls): + return IO.Schema( + node_id="ImageColorToMask", + category="mask", + inputs=[ + IO.Image.Input("image"), + IO.Int.Input("color", default=0, min=0, max=0xFFFFFF, step=1, display_mode=IO.NumberDisplay.number), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, image, color): + @classmethod + def execute(cls, image, color) -> IO.NodeOutput: temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int) temp = torch.bitwise_left_shift(temp[:, :, :, 0], 16) + torch.bitwise_left_shift(temp[:, :, :, 1], 8) + temp[:, :, :, 2] mask = torch.where(temp == color, 1.0, 0).float() - return (mask,) + return IO.NodeOutput(mask) + + image_to_mask = execute # TODO: remove -class SolidMask: +class SolidMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="SolidMask", + category="mask", + inputs=[ + IO.Float.Input("value", default=1.0, min=0.0, max=1.0, step=0.01), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "solid" - - def solid(self, value, width, height): + @classmethod + def execute(cls, value, width, height) -> IO.NodeOutput: out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu") - return (out,) + return IO.NodeOutput(out) + + solid = execute # TODO: remove -class InvertMask: +class InvertMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - } - } + def define_schema(cls): + return IO.Schema( + node_id="InvertMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "invert" - - def invert(self, mask): + @classmethod + def execute(cls, mask) -> IO.NodeOutput: out = 1.0 - mask - return (out,) + return IO.NodeOutput(out) + + invert = execute # TODO: remove -class CropMask: +class CropMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="CropMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("width", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("height", default=512, min=1, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "crop" - - def crop(self, mask, x, y, width, height): + @classmethod + def execute(cls, mask, x, y, width, height) -> IO.NodeOutput: mask = 
mask.reshape((-1, mask.shape[-2], mask.shape[-1])) out = mask[:, y:y + height, x:x + width] - return (out,) + return IO.NodeOutput(out) + + crop = execute # TODO: remove -class MaskComposite: +class MaskComposite(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "destination": ("MASK",), - "source": ("MASK",), - "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), - } - } + def define_schema(cls): + return IO.Schema( + node_id="MaskComposite", + category="mask", + inputs=[ + IO.Mask.Input("destination"), + IO.Mask.Input("source"), + IO.Int.Input("x", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("y", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Combo.Input("operation", options=["multiply", "add", "subtract", "and", "or", "xor"]), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "combine" - - def combine(self, destination, source, x, y, operation): + @classmethod + def execute(cls, destination, source, x, y, operation) -> IO.NodeOutput: output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone() source = source.reshape((-1, source.shape[-2], source.shape[-1])) @@ -280,29 +279,29 @@ class MaskComposite: output = torch.clamp(output, 0.0, 1.0) - return (output,) + return IO.NodeOutput(output) + + combine = execute # TODO: remove -class FeatherMask: +class FeatherMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="FeatherMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("left", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("top", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("right", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + IO.Int.Input("bottom", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "feather" - - def feather(self, mask, left, top, right, bottom): + @classmethod + def execute(cls, mask, left, top, right, bottom) -> IO.NodeOutput: output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() left = min(left, output.shape[-1]) @@ -326,27 +325,28 @@ class FeatherMask: feather_rate = (y + 1) / bottom output[:, -y, :] *= feather_rate - return (output,) + return IO.NodeOutput(output) + + feather = execute # TODO: remove -class GrowMask: +class GrowMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(cls): - return { - "required": { - "mask": ("MASK",), - "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}), - "tapered_corners": ("BOOLEAN", {"default": True}), - }, - } + def define_schema(cls): + return IO.Schema( + node_id="GrowMask", + display_name="Grow Mask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Int.Input("expand", default=0, min=-nodes.MAX_RESOLUTION, max=nodes.MAX_RESOLUTION, step=1), + 
IO.Boolean.Input("tapered_corners", default=True), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - - FUNCTION = "expand_mask" - - def expand_mask(self, mask, expand, tapered_corners): + @classmethod + def execute(cls, mask, expand, tapered_corners) -> IO.NodeOutput: c = 0 if tapered_corners else 1 kernel = np.array([[c, 1, c], [1, 1, 1], @@ -362,71 +362,74 @@ class GrowMask: output = scipy.ndimage.grey_dilation(output, footprint=kernel) output = torch.from_numpy(output) out.append(output) - return (torch.stack(out, dim=0),) + return IO.NodeOutput(torch.stack(out, dim=0)) + + expand_mask = execute # TODO: remove -class ThresholdMask: +class ThresholdMask(IO.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "mask": ("MASK",), - "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), - } - } + def define_schema(cls): + return IO.Schema( + node_id="ThresholdMask", + category="mask", + inputs=[ + IO.Mask.Input("mask"), + IO.Float.Input("value", default=0.5, min=0.0, max=1.0, step=0.01), + ], + outputs=[IO.Mask.Output()], + ) - CATEGORY = "mask" - - RETURN_TYPES = ("MASK",) - FUNCTION = "image_to_mask" - - def image_to_mask(self, mask, value): + @classmethod + def execute(cls, mask, value) -> IO.NodeOutput: mask = (mask > value).float() - return (mask,) + return IO.NodeOutput(mask) + + image_to_mask = execute # TODO: remove # Mask Preview - original implement from # https://github.com/cubiq/ComfyUI_essentials/blob/9d9f4bedfc9f0321c19faf71855e228c93bd0dc9/mask.py#L81 # upstream requested in https://github.com/Kosinkadink/rfcs/blob/main/rfcs/0000-corenodes.md#preview-nodes -class MaskPreview(SaveImage): - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() - self.type = "temp" - self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) - self.compress_level = 4 +class MaskPreview(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="MaskPreview", + display_name="Preview Mask", + category="mask", + description="Saves the input images to your ComfyUI output directory.", + inputs=[ + IO.Mask.Input("mask"), + ], + hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo], + is_output_node=True, + ) @classmethod - def INPUT_TYPES(s): - return { - "required": {"mask": ("MASK",), }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - FUNCTION = "execute" - CATEGORY = "mask" - - def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): - preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) - return self.save_images(preview, filename_prefix, prompt, extra_pnginfo) + def execute(cls, mask, filename_prefix="ComfyUI") -> IO.NodeOutput: + return IO.NodeOutput(ui=UI.PreviewMask(mask)) -NODE_CLASS_MAPPINGS = { - "LatentCompositeMasked": LatentCompositeMasked, - "ImageCompositeMasked": ImageCompositeMasked, - "MaskToImage": MaskToImage, - "ImageToMask": ImageToMask, - "ImageColorToMask": ImageColorToMask, - "SolidMask": SolidMask, - "InvertMask": InvertMask, - "CropMask": CropMask, - "MaskComposite": MaskComposite, - "FeatherMask": FeatherMask, - "GrowMask": GrowMask, - "ThresholdMask": ThresholdMask, - "MaskPreview": MaskPreview -} +class MaskExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + LatentCompositeMasked, + ImageCompositeMasked, + MaskToImage, + ImageToMask, + 
ImageColorToMask, + SolidMask, + InvertMask, + CropMask, + MaskComposite, + FeatherMask, + GrowMask, + ThresholdMask, + MaskPreview, + ] -NODE_DISPLAY_NAME_MAPPINGS = { - "ImageToMask": "Convert Image to Mask", - "MaskToImage": "Convert Mask to Image", -} + +async def comfy_entrypoint() -> MaskExtension: + return MaskExtension() diff --git a/comfy_extras/nodes/nodes_model_downscale.py b/comfy_extras/nodes/nodes_model_downscale.py index cce02c762..1421d1cc0 100644 --- a/comfy_extras/nodes/nodes_model_downscale.py +++ b/comfy_extras/nodes/nodes_model_downscale.py @@ -55,12 +55,6 @@ class PatchModelAddDownscale(io.ComfyNode): return io.NodeOutput(m) -NODE_DISPLAY_NAME_MAPPINGS = { - # Sampling - "PatchModelAddDownscale": "", -} - - class ModelDownscaleExtension(ComfyExtension): @override async def get_node_list(self) -> list[type[io.ComfyNode]]: diff --git a/comfy_extras/nodes/nodes_model_patch.py b/comfy_extras/nodes/nodes_model_patch.py index 8723b8981..8e86d0ff4 100644 --- a/comfy_extras/nodes/nodes_model_patch.py +++ b/comfy_extras/nodes/nodes_model_patch.py @@ -6,6 +6,7 @@ import comfy.ops import comfy.model_management import comfy.ldm.common_dit import comfy.latent_formats +import comfy.ldm.lumina.controlnet from comfy.model_patcher import ModelPatcher @@ -191,6 +192,36 @@ class SigLIPMultiFeatProjModel(torch.nn.Module): return embedding +def z_image_convert(sd): + replace_keys = {".attention.to_out.0.bias": ".attention.out.bias", + ".attention.norm_k.weight": ".attention.k_norm.weight", + ".attention.norm_q.weight": ".attention.q_norm.weight", + ".attention.to_out.0.weight": ".attention.out.weight" + } + + out_sd = {} + for k in sorted(sd.keys()): + w = sd[k] + + k_out = k + if k_out.endswith(".attention.to_k.weight"): + cc = [w] + continue + if k_out.endswith(".attention.to_q.weight"): + cc = [w] + cc + continue + if k_out.endswith(".attention.to_v.weight"): + cc = cc + [w] + w = torch.cat(cc, dim=0) + k_out = k_out.replace(".attention.to_v.weight", ".attention.qkv.weight") + + for r, rr in replace_keys.items(): + k_out = k_out.replace(r, rr) + out_sd[k_out] = w + + return out_sd + + class ModelPatchLoader: @classmethod def INPUT_TYPES(s): @@ -215,6 +246,9 @@ class ModelPatchLoader: elif 'feature_embedder.mid_layer_norm.bias' in sd: sd = comfy.utils.state_dict_prefix_replace(sd, {"feature_embedder.": ""}, filter_keys=True) model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) + elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet + sd = z_image_convert(sd) + model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast) model.load_state_dict(sd) model = ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device()) @@ -268,6 +302,70 @@ class DiffSynthCnetPatch: return [self.model_patch] +class ZImageControlPatch: + def __init__(self, model_patch, vae, image, strength): + self.model_patch = model_patch + self.vae = vae + self.image = image + self.strength = strength + self.encoded_image = self.encode_latent_cond(image) + self.encoded_image_size = (image.shape[1], image.shape[2]) + self.temp_data = None + + def encode_latent_cond(self, image): + latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image)) + return latent_image + + def __call__(self, kwargs): + x = kwargs.get("x") + img = kwargs.get("img") + 
txt = kwargs.get("txt") + pe = kwargs.get("pe") + vec = kwargs.get("vec") + block_index = kwargs.get("block_index") + spacial_compression = self.vae.spacial_compression_encode() + if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression): + image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center") + loaded_models = comfy.model_management.loaded_models(only_currently_used=True) + self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1)) + self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1]) + comfy.model_management.load_models_gpu(loaded_models) + + cnet_index = (block_index // 5) + cnet_index_float = (block_index / 5) + + kwargs.pop("img") # we do ops in place + kwargs.pop("txt") + + cnet_blocks = self.model_patch.model.n_control_layers + if cnet_index_float > (cnet_blocks - 1): + self.temp_data = None + return kwargs + + if self.temp_data is None or self.temp_data[0] > cnet_index: + self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec))) + + while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks: + next_layer = self.temp_data[0] + 1 + self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec)) + + if cnet_index_float == self.temp_data[0]: + img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength) + if cnet_blocks == self.temp_data[0] + 1: + self.temp_data = None + + return kwargs + + def to(self, device_or_dtype): + if isinstance(device_or_dtype, torch.device): + self.encoded_image = self.encoded_image.to(device_or_dtype) + self.temp_data = None + return self + + def models(self): + return [self.model_patch] + + class QwenImageDiffsynthControlnet: @classmethod def INPUT_TYPES(s): @@ -295,7 +393,10 @@ class QwenImageDiffsynthControlnet: mask = mask.unsqueeze(2) mask = 1.0 - mask - model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) + if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control): + model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength)) + else: + model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask)) return (model_patched,) diff --git a/comfy_extras/nodes/nodes_preview_any.py b/comfy_extras/nodes/nodes_preview_any.py index e749fa6ae..139b07c93 100644 --- a/comfy_extras/nodes/nodes_preview_any.py +++ b/comfy_extras/nodes/nodes_preview_any.py @@ -39,5 +39,5 @@ NODE_CLASS_MAPPINGS = { } NODE_DISPLAY_NAME_MAPPINGS = { - "PreviewAny": "Preview Any", + "PreviewAny": "Preview as Text", } diff --git a/comfy_extras/nodes/nodes_train.py b/comfy_extras/nodes/nodes_train.py index e4d7a4ff5..609a1ea74 100644 --- a/comfy_extras/nodes/nodes_train.py +++ b/comfy_extras/nodes/nodes_train.py @@ -1,28 +1,22 @@ -import datetime -import json import logging import os +import tqdm import numpy as np import safetensors -import torch -from PIL import Image, ImageDraw, ImageFont -from PIL.PngImagePlugin import PngInfo import torch.utils.checkpoint -import tqdm +from PIL import Image, ImageDraw, ImageFont +import comfy.model_management import comfy.samplers import comfy.sd import comfy.utils -import comfy.model_management +import 
comfy_extras.nodes.nodes_custom_sampler from comfy.cmd import folder_paths -from comfy import node_helpers -from comfy.cli_args import args -from comfy.comfy_types.node_typing import IO -from comfy.execution_context import current_execution_context from comfy.weight_adapter import adapters, adapter_maps -from . import nodes_custom_sampler -from .nodes_custom_sampler import Noise_RandomNoise +from comfy_api.latest import ui +from .nodes_custom_sampler import * +from comfy.utils import ProgressBar def make_batch_extra_option_dict(d, indicies, full_size=None): @@ -58,7 +52,18 @@ def process_cond_list(d, prefix=""): class TrainSampler(comfy.samplers.Sampler): - def __init__(self, loss_fn, optimizer, loss_callback=None, batch_size=1, grad_acc=1, total_steps=1, seed=0, training_dtype=torch.bfloat16): + def __init__( + self, + loss_fn, + optimizer, + loss_callback=None, + batch_size=1, + grad_acc=1, + total_steps=1, + seed=0, + training_dtype=torch.bfloat16, + real_dataset=None, + ): self.loss_fn = loss_fn self.optimizer = optimizer self.loss_callback = loss_callback @@ -67,54 +72,139 @@ class TrainSampler(comfy.samplers.Sampler): self.grad_acc = grad_acc self.seed = seed self.training_dtype = training_dtype + self.real_dataset: list[torch.Tensor] | None = real_dataset - def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False): + def fwd_bwd( + self, + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ): + xt = model_wrap.inner_model.model_sampling.noise_scaling( + batch_sigmas, batch_noise, batch_latent, False + ) + x0 = model_wrap.inner_model.model_sampling.noise_scaling( + torch.zeros_like(batch_sigmas), + torch.zeros_like(batch_noise), + batch_latent, + False, + ) + + model_wrap.conds["positive"] = [cond[i] for i in indicies] + batch_extra_args = make_batch_extra_option_dict( + extra_args, indicies, full_size=dataset_size + ) + + with torch.autocast(xt.device.type, dtype=self.training_dtype): + x0_pred = model_wrap( + xt.requires_grad_(True), + batch_sigmas.requires_grad_(True), + **batch_extra_args, + ) + loss = self.loss_fn(x0_pred, x0) + if bwd: + bwd_loss = loss / self.grad_acc + bwd_loss.backward() + return loss + + def sample( + self, + model_wrap, + sigmas, + extra_args, + callback, + noise, + latent_image=None, + denoise_mask=None, + disable_pbar=False, + ): model_wrap.conds = process_cond_list(model_wrap.conds) cond = model_wrap.conds["positive"] dataset_size = sigmas.size(0) torch.cuda.empty_cache() - for i in (pbar:=tqdm.trange(self.total_steps, desc="Training LoRA", smoothing=0.01, disable=not current_execution_context().server.receive_all_progress_notifications)): - noisegen = Noise_RandomNoise(self.seed + i * 1000) - indicies = torch.randperm(dataset_size)[:self.batch_size].tolist() - - batch_latent = torch.stack([latent_image[i] for i in indicies]) - batch_noise = noisegen.generate_noise({"samples": batch_latent}).to(batch_latent.device) - batch_sigmas = [ - model_wrap.inner_model.model_sampling.percent_to_sigma( - torch.rand((1,)).item() - ) for _ in range(min(self.batch_size, dataset_size)) - ] - batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - - xt = model_wrap.inner_model.model_sampling.noise_scaling( - batch_sigmas, - batch_noise, - batch_latent, - False + ui_pbar = ProgressBar(self.total_steps) + for i in ( + pbar := tqdm.trange( + self.total_steps, + desc="Training LoRA", + smoothing=0.01, + disable=not 
current_execution_context().server.receive_all_progress_notifications ) - x0 = model_wrap.inner_model.model_sampling.noise_scaling( - torch.zeros_like(batch_sigmas), - torch.zeros_like(batch_noise), - batch_latent, - False + ): + noisegen = Noise_RandomNoise( + self.seed + i * 1000 ) + indicies = torch.randperm(dataset_size)[: self.batch_size].tolist() - model_wrap.conds["positive"] = [ - cond[i] for i in indicies - ] - batch_extra_args = make_batch_extra_option_dict(extra_args, indicies, full_size=dataset_size) + if self.real_dataset is None: + batch_latent = torch.stack([latent_image[i] for i in indicies]) + batch_noise = noisegen.generate_noise({"samples": batch_latent}).to( + batch_latent.device + ) + batch_sigmas = [ + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + for _ in range(min(self.batch_size, dataset_size)) + ] + batch_sigmas = torch.tensor(batch_sigmas).to(batch_latent.device) - with torch.autocast(xt.device.type, dtype=self.training_dtype): - x0_pred = model_wrap(xt, batch_sigmas, **batch_extra_args) - loss = self.loss_fn(x0_pred, x0) - loss.backward() - if self.loss_callback: - self.loss_callback(loss.item()) - pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + batch_latent, + cond, + indicies, + extra_args, + dataset_size, + bwd=True, + ) + if self.loss_callback: + self.loss_callback(loss.item()) + pbar.set_postfix({"loss": f"{loss.item():.4f}"}) + else: + # todo: should this be "0" or scalar_tensor? + total_loss = torch.tensor(0.0) + for index in indicies: + single_latent = self.real_dataset[index].to(latent_image) + batch_noise = noisegen.generate_noise( + {"samples": single_latent} + ).to(single_latent.device) + batch_sigmas = ( + model_wrap.inner_model.model_sampling.percent_to_sigma( + torch.rand((1,)).item() + ) + ) + batch_sigmas = torch.tensor([batch_sigmas]).to(single_latent.device) + loss = self.fwd_bwd( + model_wrap, + batch_sigmas, + batch_noise, + single_latent, + cond, + [index], + extra_args, + dataset_size, + bwd=False, + ) + total_loss += loss + total_loss = total_loss / self.grad_acc / len(indicies) + total_loss.backward() + if self.loss_callback: + self.loss_callback(total_loss.item()) + pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) - if (i+1) % self.grad_acc == 0: + if (i + 1) % self.grad_acc == 0: self.optimizer.step() self.optimizer.zero_grad() + ui_pbar.update(1) torch.cuda.empty_cache() return torch.zeros_like(latent_image) @@ -136,233 +226,6 @@ class BiasDiff(torch.nn.Module): return self.passive_memory_usage() -def load_and_process_images(image_files, input_dir, resize_method="None", w=None, h=None): - """Utility function to load and process a list of images. 
- - Args: - image_files: List of image filenames - input_dir: Base directory containing the images - resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") - - Returns: - torch.Tensor: Batch of processed images - """ - if not image_files: - raise ValueError("No valid images found in input") - - output_images = [] - - for file in image_files: - image_path = os.path.join(input_dir, file) - img = node_helpers.pillow(Image.open, image_path) - - if img.mode == "I": - img = img.point(lambda i: i * (1 / 255)) - img = img.convert("RGB") - - if w is None and h is None: - w, h = img.size[0], img.size[1] - - # Resize image to first image - if img.size[0] != w or img.size[1] != h: - if resize_method == "Stretch": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "Crop": - img = img.crop((0, 0, w, h)) - elif resize_method == "Pad": - img = img.resize((w, h), Image.Resampling.LANCZOS) - elif resize_method == "None": - raise ValueError( - "Your input image size does not match the first image in the dataset. Either select a valid resize method or use the same size for all images." - ) - - img_array = np.array(img).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array)[None,] - output_images.append(img_tensor) - - return torch.cat(output_images, dim=0) - - -class LoadImageSetNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "images": ( - [ - f - for f in os.listdir(folder_paths.get_input_directory()) - if f.endswith((".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff")) - ], - {"image_upload": True, "allow_batch": True}, - ) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - INPUT_IS_LIST = True - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." - - @classmethod - def VALIDATE_INPUTS(s, images, resize_method): - filenames = images[0] if isinstance(images[0], list) else images - - for image in filenames: - if not folder_paths.exists_annotated_filepath(image): - return "Invalid image file: {}".format(image) - return True - - def load_images(self, input_files, resize_method): - input_dir = folder_paths.get_input_directory() - valid_extensions = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".gif", ".jpe", ".apng", ".tif", ".tiff"] - image_files = [ - f - for f in input_files - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, input_dir, resize_method) - return (output_tensor,) - - -class LoadImageSetFromFolderNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}) - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - }, - } - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images from a directory for training." 
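# ----------------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the TrainSampler rewrite near the top
# of this file scales each micro-batch loss by 1/grad_acc before backward() and only
# steps the optimizer every grad_acc iterations, so the applied gradient equals the
# mean over the accumulated micro-batches. Stripped of the ComfyUI plumbing, the
# accumulation pattern it relies on looks like this (model, data and lr are generic
# stand-ins, not names from the patch):
import torch

def train_with_grad_accumulation(model: torch.nn.Module,
                                 data: list[tuple[torch.Tensor, torch.Tensor]],
                                 grad_acc: int = 4,
                                 lr: float = 1e-3) -> None:
    """Accumulate gradients over grad_acc micro-batches before each optimizer step."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    loss_fn = torch.nn.MSELoss()
    for step, (x, y) in enumerate(data):
        loss = loss_fn(model(x), y)
        (loss / grad_acc).backward()          # scale so the summed grads equal the mean
        if (step + 1) % grad_acc == 0:
            optimizer.step()
            optimizer.zero_grad()
# ----------------------------------------------------------------------------------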
- - def load_images(self, folder, resize_method): - sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) - valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] - image_files = [ - f - for f in os.listdir(sub_input_dir) - if any(f.lower().endswith(ext) for ext in valid_extensions) - ] - output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method) - return (output_tensor,) - - -class LoadImageTextSetFromFolderNode: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "folder": (folder_paths.get_input_subfolders(), {"tooltip": "The folder to load images from."}), - "clip": (IO.CLIP, {"tooltip": "The CLIP model used for encoding the text."}), - }, - "optional": { - "resize_method": ( - ["None", "Stretch", "Crop", "Pad"], - {"default": "None"}, - ), - "width": ( - IO.INT, - { - "default": -1, - "min": -1, - "max": 10000, - "step": 1, - "tooltip": "The width to resize the images to. -1 means use the original width.", - }, - ), - "height": ( - IO.INT, - { - "default": -1, - "min": -1, - "max": 10000, - "step": 1, - "tooltip": "The height to resize the images to. -1 means use the original height.", - }, - ) - }, - } - - RETURN_TYPES = ("IMAGE", IO.CONDITIONING,) - FUNCTION = "load_images" - CATEGORY = "loaders" - EXPERIMENTAL = True - DESCRIPTION = "Loads a batch of images and caption from a directory for training." - - def load_images(self, folder, clip, resize_method, width=None, height=None): - if clip is None: - raise RuntimeError("ERROR: clip input is invalid: None\n\nIf the clip is from a checkpoint loader node your checkpoint does not contain a valid clip or text encoder model.") - - logging.info(f"Loading images from folder: {folder}") - - sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) - valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] - - image_files = [] - for item in os.listdir(sub_input_dir): - path = os.path.join(sub_input_dir, item) - if any(item.lower().endswith(ext) for ext in valid_extensions): - image_files.append(path) - elif os.path.isdir(path): - # Support kohya-ss/sd-scripts folder structure - repeat = 1 - if item.split("_")[0].isdigit(): - repeat = int(item.split("_")[0]) - image_files.extend([ - os.path.join(path, f) for f in os.listdir(path) if any(f.lower().endswith(ext) for ext in valid_extensions) - ] * repeat) - - caption_file_path = [ - f.replace(os.path.splitext(f)[1], ".txt") - for f in image_files - ] - captions = [] - for caption_file in caption_file_path: - caption_path = os.path.join(sub_input_dir, caption_file) - if os.path.exists(caption_path): - with open(caption_path, "r", encoding="utf-8") as f: - caption = f.read().strip() - captions.append(caption) - else: - captions.append("") - - width = width if width != -1 else None - height = height if height != -1 else None - output_tensor = load_and_process_images(image_files, sub_input_dir, resize_method, width, height) - - logging.info(f"Loaded {len(output_tensor)} images from {sub_input_dir}.") - - logging.info(f"Encoding captions from {sub_input_dir}.") - conditions = [] - empty_cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) - for text in captions: - if text == "": - conditions.append(empty_cond) - tokens = clip.tokenize(text) - conditions.extend(clip.encode_from_tokens_scheduled(tokens)) - logging.info(f"Encoded {len(conditions)} captions from {sub_input_dir}.") - return (output_tensor, conditions) - - def draw_loss_graph(loss_map, steps): width, height = 500, 300 img = Image.new("RGB", (width, height), "white") 
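# ----------------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the caption-dataset loader removed
# above (and re-added in comfy_extras/nodes_dataset.py further down) keeps
# kohya-ss/sd-scripts folder compatibility, where a subfolder named "<N>_<name>"
# means "repeat its images N times". The repeat parsing it performs reduces to:
def kohya_repeat_count(subfolder_name: str) -> int:
    prefix = subfolder_name.split("_")[0]
    return int(prefix) if prefix.isdigit() else 1

# e.g. kohya_repeat_count("10_character") == 10, kohya_repeat_count("backgrounds") == 1
# ----------------------------------------------------------------------------------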
@@ -381,10 +244,14 @@ def draw_loss_graph(loss_map, steps): return img -def find_all_highest_child_module_with_forward(model: torch.nn.Module, result = None, name = None): +def find_all_highest_child_module_with_forward( + model: torch.nn.Module, result=None, name=None +): if result is None: result = [] - elif hasattr(model, "forward") and not isinstance(model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict)): + elif hasattr(model, "forward") and not isinstance( + model, (torch.nn.ModuleList, torch.nn.Sequential, torch.nn.ModuleDict) + ): result.append(model) logging.debug(f"Found module with forward: {name} ({model.__class__.__name__})") return result @@ -398,12 +265,13 @@ def patch(m): if not hasattr(m, "forward"): return org_forward = m.forward + def fwd(args, kwargs): return org_forward(*args, **kwargs) + def checkpointing_fwd(*args, **kwargs): - return torch.utils.checkpoint.checkpoint( - fwd, args, kwargs, use_reentrant=False - ) + return torch.utils.checkpoint.checkpoint(fwd, args, kwargs, use_reentrant=False) + m.org_forward = org_forward m.forward = checkpointing_fwd @@ -414,154 +282,211 @@ def unpatch(m): del m.org_forward -class TrainLoraNode: +class TrainLoraNode(io.ComfyNode): @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": (IO.MODEL, {"tooltip": "The model to train the LoRA on."}), - "latents": ( - "LATENT", - { - "tooltip": "The Latents to use for training, serve as dataset/input of the model." - }, + def define_schema(cls): + return io.Schema( + node_id="TrainLoraNode", + display_name="Train LoRA", + category="training", + is_experimental=True, + is_input_list=True, # All inputs become lists + inputs=[ + io.Model.Input("model", tooltip="The model to train the LoRA on."), + io.Latent.Input( + "latents", + tooltip="The Latents to use for training, serve as dataset/input of the model.", ), - "positive": ( - IO.CONDITIONING, - {"tooltip": "The positive conditioning to use for training."}, + io.Conditioning.Input( + "positive", tooltip="The positive conditioning to use for training." 
), - "batch_size": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 10000, - "step": 1, - "tooltip": "The batch size to use for training.", - }, + io.Int.Input( + "batch_size", + default=1, + min=1, + max=10000, + tooltip="The batch size to use for training.", ), - "grad_accumulation_steps": ( - IO.INT, - { - "default": 1, - "min": 1, - "max": 1024, - "step": 1, - "tooltip": "The number of gradient accumulation steps to use for training.", - } + io.Int.Input( + "grad_accumulation_steps", + default=1, + min=1, + max=1024, + tooltip="The number of gradient accumulation steps to use for training.", ), - "steps": ( - IO.INT, - { - "default": 16, - "min": 1, - "max": 100000, - "tooltip": "The number of steps to train the LoRA for.", - }, + io.Int.Input( + "steps", + default=16, + min=1, + max=100000, + tooltip="The number of steps to train the LoRA for.", ), - "learning_rate": ( - IO.FLOAT, - { - "default": 0.0005, - "min": 0.0000001, - "max": 1.0, - "step": 0.000001, - "tooltip": "The learning rate to use for training.", - }, + io.Float.Input( + "learning_rate", + default=0.0005, + min=0.0000001, + max=1.0, + step=0.0000001, + tooltip="The learning rate to use for training.", ), - "rank": ( - IO.INT, - { - "default": 8, - "min": 1, - "max": 128, - "tooltip": "The rank of the LoRA layers.", - }, + io.Int.Input( + "rank", + default=8, + min=1, + max=128, + tooltip="The rank of the LoRA layers.", ), - "optimizer": ( - ["AdamW", "Adam", "SGD", "RMSprop"], - { - "default": "AdamW", - "tooltip": "The optimizer to use for training.", - }, + io.Combo.Input( + "optimizer", + options=["AdamW", "Adam", "SGD", "RMSprop"], + default="AdamW", + tooltip="The optimizer to use for training.", ), - "loss_function": ( - ["MSE", "L1", "Huber", "SmoothL1"], - { - "default": "MSE", - "tooltip": "The loss function to use for training.", - }, + io.Combo.Input( + "loss_function", + options=["MSE", "L1", "Huber", "SmoothL1"], + default="MSE", + tooltip="The loss function to use for training.", ), - "seed": ( - IO.INT, - { - "default": 0, - "min": 0, - "max": 0xFFFFFFFFFFFFFFFF, - "tooltip": "The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", - }, + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="The seed to use for training (used in generator for LoRA weight initialization and noise sampling)", ), - "training_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for training."}, + io.Combo.Input( + "training_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for training.", ), - "lora_dtype": ( - ["bf16", "fp32"], - {"default": "bf16", "tooltip": "The dtype to use for lora."}, + io.Combo.Input( + "lora_dtype", + options=["bf16", "fp32"], + default="bf16", + tooltip="The dtype to use for lora.", ), - "algorithm": ( - list(adapter_maps.keys()), - {"default": list(adapter_maps.keys())[0], "tooltip": "The algorithm to use for training."}, + io.Combo.Input( + "algorithm", + options=list(adapter_maps.keys()), + default=list(adapter_maps.keys())[0], + tooltip="The algorithm to use for training.", ), - "gradient_checkpointing": ( - IO.BOOLEAN, - { - "default": True, - "tooltip": "Use gradient checkpointing for training.", - } + io.Boolean.Input( + "gradient_checkpointing", + default=True, + tooltip="Use gradient checkpointing for training.", ), - "existing_lora": ( - folder_paths.get_filename_list("loras") + ["[None]"], - { - "default": "[None]", - "tooltip": "The existing LoRA to 
append to. Set to None for new LoRA.", - }, + io.Combo.Input( + "existing_lora", + options=folder_paths.get_filename_list("loras") + ["[None]"], + default="[None]", + tooltip="The existing LoRA to append to. Set to None for new LoRA.", ), - }, - } + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="Model with LoRA applied" + ), + io.Custom("LORA_MODEL").Output( + display_name="lora", tooltip="LoRA weights" + ), + io.Custom("LOSS_MAP").Output( + display_name="loss_map", tooltip="Loss history" + ), + io.Int.Output(display_name="steps", tooltip="Total training steps"), + ], + ) - RETURN_TYPES = (IO.MODEL, IO.LORA_MODEL, IO.LOSS_MAP, IO.INT) - RETURN_NAMES = ("model_with_lora", "lora", "loss", "steps") - FUNCTION = "train" - CATEGORY = "training" - EXPERIMENTAL = True - - def train( - self, - model, - latents, - positive, - batch_size, - steps, - grad_accumulation_steps, - learning_rate, - rank, - optimizer, - loss_function, - seed, - training_dtype, - lora_dtype, - algorithm, - gradient_checkpointing, - existing_lora, + @classmethod + def execute( + cls, + model, + latents, + positive, + batch_size, + steps, + grad_accumulation_steps, + learning_rate, + rank, + optimizer, + loss_function, + seed, + training_dtype, + lora_dtype, + algorithm, + gradient_checkpointing, + existing_lora, ): + # Extract scalars from lists (due to is_input_list=True) + model = model[0] + batch_size = batch_size[0] + steps = steps[0] + grad_accumulation_steps = grad_accumulation_steps[0] + learning_rate = learning_rate[0] + rank = rank[0] + optimizer = optimizer[0] + loss_function = loss_function[0] + seed = seed[0] + training_dtype = training_dtype[0] + lora_dtype = lora_dtype[0] + algorithm = algorithm[0] + gradient_checkpointing = gradient_checkpointing[0] + existing_lora = existing_lora[0] + + # Handle latents - either single dict or list of dicts + if len(latents) == 1: + latents = latents[0]["samples"] # Single latent dict + else: + latent_list = [] + for latent in latents: + latent = latent["samples"] + bs = latent.shape[0] + if bs != 1: + for sub_latent in latent: + latent_list.append(sub_latent[None]) + else: + latent_list.append(latent) + latents = latent_list + + # Handle conditioning - either single list or list of lists + if len(positive) == 1: + positive = positive[0] # Single conditioning list + else: + # Multiple conditioning lists - flatten + flat_positive = [] + for cond in positive: + if isinstance(cond, list): + flat_positive.extend(cond) + else: + flat_positive.append(cond) + positive = flat_positive + mp = model.clone() dtype = node_helpers.string_to_torch_dtype(training_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) mp.set_model_compute_dtype(dtype) - latents = latents["samples"].to(dtype) - num_images = latents.shape[0] + # latents here can be list of different size latent or one large batch + if isinstance(latents, list): + all_shapes = set() + latents = [t.to(dtype) for t in latents] + for latent in latents: + all_shapes.add(latent.shape) + logging.info(f"Latent shapes: {all_shapes}") + if len(all_shapes) > 1: + multi_res = True + else: + multi_res = False + latents = torch.cat(latents, dim=0) + num_images = len(latents) + elif isinstance(latents, torch.Tensor): + latents = latents.to(dtype) + num_images = latents.shape[0] + else: + logging.error(f"Invalid latents type: {type(latents)}") + logging.info(f"Total Images: {num_images}, Total Captions: {len(positive)}") if len(positive) == 1 and num_images > 1: positive = positive * num_images @@ -593,9 +518,7 
@@ class TrainLoraNode: shape = m.weight.shape if len(shape) >= 2: alpha = float(existing_weights.get(f"{key}.alpha", 1.0)) - dora_scale = existing_weights.get( - f"{key}.dora_scale", None - ) + dora_scale = existing_weights.get(f"{key}.dora_scale", None) for adapter_cls in adapters: existing_adapter = adapter_cls.load( n, existing_weights, alpha, dora_scale @@ -607,7 +530,9 @@ class TrainLoraNode: adapter_cls = adapter_maps[algorithm] if existing_adapter is not None: - train_adapter = existing_adapter.to_train().to(lora_dtype) + train_adapter = existing_adapter.to_train().to( + lora_dtype + ) else: # Use LoRA with alpha=1.0 by default train_adapter = adapter_cls.create_train( @@ -631,7 +556,9 @@ class TrainLoraNode: if hasattr(m, "bias") and m.bias is not None: key = "{}.bias".format(n) bias = torch.nn.Parameter( - torch.zeros(m.bias.shape, dtype=lora_dtype, requires_grad=True) + torch.zeros( + m.bias.shape, dtype=lora_dtype, requires_grad=True + ) ) bias_module = BiasDiff(bias) lora_sd["{}.diff_b".format(n)] = bias @@ -661,26 +588,33 @@ class TrainLoraNode: # setup models if gradient_checkpointing: - for m in find_all_highest_child_module_with_forward(mp.model.diffusion_model): + for m in find_all_highest_child_module_with_forward( + mp.model.diffusion_model + ): patch(m) mp.model.requires_grad_(False) - comfy.model_management.load_models_gpu([mp], memory_required=1e20, force_full_load=True) + comfy.model_management.load_models_gpu( + [mp], memory_required=1e20, force_full_load=True + ) # Setup sampler and guider like in test script loss_map = {"loss": []} + def loss_callback(loss): loss_map["loss"].append(loss) + train_sampler = TrainSampler( criterion, optimizer, loss_callback=loss_callback, batch_size=batch_size, grad_acc=grad_accumulation_steps, - total_steps=steps*grad_accumulation_steps, + total_steps=steps * grad_accumulation_steps, seed=seed, - training_dtype=dtype + training_dtype=dtype, + real_dataset=latents if multi_res else None, ) - guider = nodes_custom_sampler.Guider_Basic(mp) + guider = comfy_extras.Guider_Basic(mp) guider.set_conds(positive) # Set conditioning from input # Training loop @@ -688,12 +622,15 @@ class TrainLoraNode: # Generate dummy sigmas and noise sigmas = torch.tensor(range(num_images)) noise = Noise_RandomNoise(seed) + if multi_res: + # use first latent as dummy latent if multi_res + latents = latents[0].repeat((num_images,) + ((1,) * (latents[0].ndim - 1))) guider.sample( noise.generate_noise({"samples": latents}), latents, train_sampler, sigmas, - seed=noise.seed + seed=noise.seed, ) finally: for m in mp.model.modules(): @@ -706,111 +643,118 @@ class TrainLoraNode: for param in lora_sd: lora_sd[param] = lora_sd[param].to(lora_dtype) - return (mp, lora_sd, loss_map, steps + existing_steps) + return io.NodeOutput(mp, lora_sd, loss_map, steps + existing_steps) -class LoraModelLoader: - def __init__(self): - self.loaded_lora = None +class LoraModelLoader(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoraModelLoader", + display_name="Load LoRA Model", + category="loaders", + is_experimental=True, + inputs=[ + io.Model.Input( + "model", tooltip="The diffusion model the LoRA will be applied to." + ), + io.Custom("LORA_MODEL").Input( + "lora", tooltip="The LoRA model to apply to the diffusion model." + ), + io.Float.Input( + "strength_model", + default=1.0, + min=-100.0, + max=100.0, + tooltip="How strongly to modify the diffusion model. 
This value can be negative.", + ), + ], + outputs=[ + io.Model.Output( + display_name="model", tooltip="The modified diffusion model." + ), + ], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "model": ("MODEL", {"tooltip": "The diffusion model the LoRA will be applied to."}), - "lora": (IO.LORA_MODEL, {"tooltip": "The LoRA model to apply to the diffusion model."}), - "strength_model": ("FLOAT", {"default": 1.0, "min": -100.0, "max": 100.0, "step": 0.01, "tooltip": "How strongly to modify the diffusion model. This value can be negative."}), - } - } - - RETURN_TYPES = ("MODEL",) - OUTPUT_TOOLTIPS = ("The modified diffusion model.",) - FUNCTION = "load_lora_model" - - CATEGORY = "loaders" - DESCRIPTION = "Load Trained LoRA weights from Train LoRA node." - EXPERIMENTAL = True - - def load_lora_model(self, model, lora, strength_model): + def execute(cls, model, lora, strength_model): if strength_model == 0: - return (model, ) + return io.NodeOutput(model) - model_lora, _ = comfy.sd.load_lora_for_models(model, None, lora, strength_model, 0, None) - return (model_lora, ) + model_lora, _ = comfy.sd.load_lora_for_models( + model, None, lora, strength_model, 0 + ) + return io.NodeOutput(model_lora) -class SaveLoRA: - def __init__(self): - self.output_dir = folder_paths.get_output_directory() +class SaveLoRA(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveLoRA", + display_name="Save LoRA Weights", + category="loaders", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LORA_MODEL").Input( + "lora", + tooltip="The LoRA model to save. Do not use the model with LoRA layers.", + ), + io.String.Input( + "prefix", + default="loras/ComfyUI_trained_lora", + tooltip="The prefix to use for the saved LoRA file.", + ), + io.Int.Input( + "steps", + optional=True, + tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + ), + ], + outputs=[], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "lora": ( - IO.LORA_MODEL, - { - "tooltip": "The LoRA model to save. Do not use the model with LoRA layers." 
- }, - ), - "prefix": ( - "STRING", - { - "default": "loras/ComfyUI_trained_lora", - "tooltip": "The prefix to use for the saved LoRA file.", - }, - ), - }, - "optional": { - "steps": ( - IO.INT, - { - "forceInput": True, - "tooltip": "Optional: The number of steps to LoRA has been trained for, used to name the saved file.", - }, - ), - }, - } - - RETURN_TYPES = () - FUNCTION = "save" - CATEGORY = "loaders" - EXPERIMENTAL = True - OUTPUT_NODE = True - - def save(self, lora, prefix, steps=None): - full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path(prefix, self.output_dir) + def execute(cls, lora, prefix, steps=None): + output_dir = folder_paths.get_output_directory() + full_output_folder, filename, counter, subfolder, filename_prefix = ( + folder_paths.get_save_image_path(prefix, output_dir) + ) if steps is None: output_checkpoint = f"{filename}_{counter:05}_.safetensors" else: output_checkpoint = f"{filename}_{steps}_steps_{counter:05}_.safetensors" output_checkpoint = os.path.join(full_output_folder, output_checkpoint) safetensors.torch.save_file(lora, output_checkpoint) - return {} + return io.NodeOutput() -class LossGraphNode: - def __init__(self): - self.output_dir = folder_paths.get_temp_directory() +class LossGraphNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LossGraphNode", + display_name="Plot Loss Graph", + category="training", + is_experimental=True, + is_output_node=True, + inputs=[ + io.Custom("LOSS_MAP").Input( + "loss", tooltip="Loss map from training node." + ), + io.String.Input( + "filename_prefix", + default="loss_graph", + tooltip="Prefix for the saved loss graph image.", + ), + ], + outputs=[], + hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], + ) @classmethod - def INPUT_TYPES(s): - return { - "required": { - "loss": (IO.LOSS_MAP, {"default": {}}), - "filename_prefix": (IO.STRING, {"default": "loss_graph"}), - }, - "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, - } - - RETURN_TYPES = () - FUNCTION = "plot_loss" - OUTPUT_NODE = True - CATEGORY = "training" - EXPERIMENTAL = True - DESCRIPTION = "Plots the loss graph and saves it to the output directory." 
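# ----------------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the new SaveLoRA.execute above writes
# the raw LoRA state dict with safetensors, naming files either
# "<prefix>_<counter>_.safetensors" or "<prefix>_<steps>_steps_<counter>_.safetensors".
# A standalone equivalent of that naming and save step, leaving out the directory
# handling that folder_paths.get_save_image_path performs:
import os
import safetensors.torch
import torch

def save_lora_state_dict(lora_sd: dict[str, torch.Tensor], out_dir: str,
                         filename: str = "ComfyUI_trained_lora",
                         counter: int = 1, steps: int | None = None) -> str:
    if steps is None:
        name = f"{filename}_{counter:05}_.safetensors"
    else:
        name = f"{filename}_{steps}_steps_{counter:05}_.safetensors"
    path = os.path.join(out_dir, name)
    safetensors.torch.save_file(lora_sd, path)   # serialize the tensor dict to disk
    return path
# ----------------------------------------------------------------------------------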
- - def plot_loss(self, loss, filename_prefix, prompt=None, extra_pnginfo=None): + def execute(cls, loss, filename_prefix, prompt=None, extra_pnginfo=None): loss_values = loss["loss"] width, height = 800, 480 margin = 40 @@ -853,47 +797,27 @@ class LossGraphNode: (margin - 30, height - 10), f"{min_loss:.2f}", font=font, fill="black" ) - metadata = None - if not args.disable_metadata: - metadata = PngInfo() - if prompt is not None: - metadata.add_text("prompt", json.dumps(prompt)) - if extra_pnginfo is not None: - for x in extra_pnginfo: - metadata.add_text(x, json.dumps(extra_pnginfo[x])) + # Convert PIL image to tensor for PreviewImage + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] # [1, H, W, 3] - date = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") - img.save( - os.path.join(self.output_dir, f"{filename_prefix}_{date}.png"), - pnginfo=metadata, - ) - return { - "ui": { - "images": [ - { - "filename": f"{filename_prefix}_{date}.png", - "subfolder": "", - "type": "temp", - } - ] - } - } + # Return preview UI + return io.NodeOutput(ui=ui.PreviewImage(img_tensor, cls=cls)) -NODE_CLASS_MAPPINGS = { - "TrainLoraNode": TrainLoraNode, - "SaveLoRANode": SaveLoRA, - "LoraModelLoader": LoraModelLoader, - "LoadImageSetFromFolderNode": LoadImageSetFromFolderNode, - "LoadImageTextSetFromFolderNode": LoadImageTextSetFromFolderNode, - "LossGraphNode": LossGraphNode, -} +# ========== Extension Setup ========== -NODE_DISPLAY_NAME_MAPPINGS = { - "TrainLoraNode": "Train LoRA", - "SaveLoRANode": "Save LoRA Weights", - "LoraModelLoader": "Load LoRA Model", - "LoadImageSetFromFolderNode": "Load Image Dataset from Folder", - "LoadImageTextSetFromFolderNode": "Load Image and Text Dataset from Folder", - "LossGraphNode": "Plot Loss Graph", -} + +class TrainingExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + TrainLoraNode, + LoraModelLoader, + SaveLoRA, + LossGraphNode, + ] + + +async def comfy_entrypoint() -> TrainingExtension: + return TrainingExtension() diff --git a/comfy_extras/nodes/nodes_video.py b/comfy_extras/nodes/nodes_video.py index 64fc27d74..aac41cc1e 100644 --- a/comfy_extras/nodes/nodes_video.py +++ b/comfy_extras/nodes/nodes_video.py @@ -11,10 +11,8 @@ from typing_extensions import override from comfy.cli_args import args from comfy.cmd import folder_paths -from comfy_api.input import AudioInput, ImageInput, VideoInput -from comfy_api.input_impl import VideoFromComponents, VideoFromFile from comfy_api.latest import ComfyExtension, io, ui -from comfy_api.util import VideoCodec, VideoComponents, VideoContainer +from comfy_api.latest import Input, InputImpl, Types class SaveWEBM(io.ComfyNode): @@ -31,7 +29,6 @@ class SaveWEBM(io.ComfyNode): io.Float.Input("fps", default=24.0, min=0.01, max=1000.0, step=0.01), io.Float.Input("crf", default=32.0, min=0, max=63.0, step=1, tooltip="Higher crf means lower quality with a smaller file size, lower crf means higher quality higher filesize."), ], - outputs=[], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @@ -83,16 +80,15 @@ class SaveVideo(io.ComfyNode): inputs=[ io.Video.Input("video", tooltip="The video to save."), io.String.Input("filename_prefix", default="video/ComfyUI", tooltip="The prefix for the file to save. 
This may include formatting information such as %date:yyyy-MM-dd% or %Empty Latent Image.width% to include values from nodes."), - io.Combo.Input("format", options=VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), - io.Combo.Input("codec", options=VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), + io.Combo.Input("format", options=Types.VideoContainer.as_input(), default="auto", tooltip="The format to save the video as."), + io.Combo.Input("codec", options=Types.VideoCodec.as_input(), default="auto", tooltip="The codec to use for the video."), ], - outputs=[], hidden=[io.Hidden.prompt, io.Hidden.extra_pnginfo], is_output_node=True, ) @classmethod - def execute(cls, video: VideoInput, filename_prefix, format, codec) -> io.NodeOutput: + def execute(cls, video: Input.Video, filename_prefix, format: str, codec) -> io.NodeOutput: width, height = video.get_dimensions() full_output_folder, filename, counter, subfolder, filename_prefix = folder_paths.get_save_image_path( filename_prefix, @@ -109,10 +105,10 @@ class SaveVideo(io.ComfyNode): metadata["prompt"] = cls.hidden.prompt if len(metadata) > 0: saved_metadata = metadata - file = f"{filename}_{counter:05}_.{VideoContainer.get_extension(format)}" + file = f"{filename}_{counter:05}_.{Types.VideoContainer.get_extension(format)}" video.save_to( os.path.join(full_output_folder, file), - format=format, + format=Types.VideoContainer(format), codec=codec, metadata=saved_metadata ) @@ -139,9 +135,9 @@ class CreateVideo(io.ComfyNode): ) @classmethod - def execute(cls, images: ImageInput, fps: float, audio: Optional[AudioInput] = None) -> io.NodeOutput: + def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput: return io.NodeOutput( - VideoFromComponents(VideoComponents( + InputImpl.VideoFromComponents(Types.VideoComponents( images=images, audio=audio, frame_rate=Fraction(fps))) @@ -167,9 +163,8 @@ class GetVideoComponents(io.ComfyNode): ) @classmethod - def execute(cls, video: VideoInput) -> io.NodeOutput: + def execute(cls, video: Input.Video) -> io.NodeOutput: components = video.get_components() - return io.NodeOutput(components.images, components.audio, float(components.frame_rate)) @@ -194,7 +189,7 @@ class LoadVideo(io.ComfyNode): @classmethod def execute(cls, file) -> io.NodeOutput: video_path = folder_paths.get_annotated_filepath(file) - return io.NodeOutput(VideoFromFile(video_path)) + return io.NodeOutput(InputImpl.VideoFromFile(video_path)) @classmethod def fingerprint_inputs(s, file): diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py new file mode 100644 index 000000000..4789d7d53 --- /dev/null +++ b/comfy_extras/nodes_dataset.py @@ -0,0 +1,1432 @@ +import logging +import os +import json + +import numpy as np +import torch +from PIL import Image +from typing_extensions import override + +import folder_paths +import node_helpers +from comfy_api.latest import ComfyExtension, io + + +def load_and_process_images(image_files, input_dir): + """Utility function to load and process a list of images. 
+ + Args: + image_files: List of image filenames + input_dir: Base directory containing the images + resize_method: How to handle images of different sizes ("None", "Stretch", "Crop", "Pad") + + Returns: + torch.Tensor: Batch of processed images + """ + if not image_files: + raise ValueError("No valid images found in input") + + output_images = [] + + for file in image_files: + image_path = os.path.join(input_dir, file) + img = node_helpers.pillow(Image.open, image_path) + + if img.mode == "I": + img = img.point(lambda i: i * (1 / 255)) + img = img.convert("RGB") + img_array = np.array(img).astype(np.float32) / 255.0 + img_tensor = torch.from_numpy(img_array)[None,] + output_images.append(img_tensor) + + return output_images + + +class LoadImageDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageDataSetFromFolder", + display_name="Load Image Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ) + ], + ) + + @classmethod + def execute(cls, folder): + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + image_files = [ + f + for f in os.listdir(sub_input_dir) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + output_tensor = load_and_process_images(image_files, sub_input_dir) + return io.NodeOutput(output_tensor) + + +class LoadImageTextDataSetFromFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadImageTextDataSetFromFolder", + display_name="Load Image and Text Dataset from Folder", + category="dataset", + is_experimental=True, + inputs=[ + io.Combo.Input( + "folder", + options=folder_paths.get_input_subfolders(), + tooltip="The folder to load images from.", + ) + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="List of loaded images", + ), + io.String.Output( + display_name="texts", + is_output_list=True, + tooltip="List of text captions", + ), + ], + ) + + @classmethod + def execute(cls, folder): + logging.info(f"Loading images from folder: {folder}") + + sub_input_dir = os.path.join(folder_paths.get_input_directory(), folder) + valid_extensions = [".png", ".jpg", ".jpeg", ".webp"] + + image_files = [] + for item in os.listdir(sub_input_dir): + path = os.path.join(sub_input_dir, item) + if any(item.lower().endswith(ext) for ext in valid_extensions): + image_files.append(path) + elif os.path.isdir(path): + # Support kohya-ss/sd-scripts folder structure + repeat = 1 + if item.split("_")[0].isdigit(): + repeat = int(item.split("_")[0]) + image_files.extend( + [ + os.path.join(path, f) + for f in os.listdir(path) + if any(f.lower().endswith(ext) for ext in valid_extensions) + ] + * repeat + ) + + caption_file_path = [ + f.replace(os.path.splitext(f)[1], ".txt") for f in image_files + ] + captions = [] + for caption_file in caption_file_path: + caption_path = os.path.join(sub_input_dir, caption_file) + if os.path.exists(caption_path): + with open(caption_path, "r", encoding="utf-8") as f: + caption = f.read().strip() + captions.append(caption) + else: + captions.append("") + + output_tensor = load_and_process_images(image_files, sub_input_dir) + + logging.info(f"Loaded 
{len(output_tensor)} images from {sub_input_dir}.") + return io.NodeOutput(output_tensor, captions) + + +def save_images_to_folder(image_list, output_dir, prefix="image"): + """Utility function to save a list of image tensors to disk. + + Args: + image_list: List of image tensors (each [1, H, W, C] or [H, W, C] or [C, H, W]) + output_dir: Directory to save images to + prefix: Filename prefix + + Returns: + List of saved filenames + """ + os.makedirs(output_dir, exist_ok=True) + saved_files = [] + + for idx, img_tensor in enumerate(image_list): + # Handle different tensor shapes + if isinstance(img_tensor, torch.Tensor): + # Remove batch dimension if present [1, H, W, C] -> [H, W, C] + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + + # If tensor is [C, H, W], permute to [H, W, C] + if img_tensor.dim() == 3 and img_tensor.shape[0] in [1, 3, 4]: + if ( + img_tensor.shape[0] <= 4 + and img_tensor.shape[1] > 4 + and img_tensor.shape[2] > 4 + ): + img_tensor = img_tensor.permute(1, 2, 0) + + # Convert to numpy and scale to 0-255 + img_array = img_tensor.cpu().numpy() + img_array = np.clip(img_array * 255.0, 0, 255).astype(np.uint8) + + # Convert to PIL Image + img = Image.fromarray(img_array) + else: + raise ValueError(f"Expected torch.Tensor, got {type(img_tensor)}") + + # Save image + filename = f"{prefix}_{idx:05d}.png" + filepath = os.path.join(output_dir, filename) + img.save(filepath) + saved_files.append(filename) + + return saved_files + + +class SaveImageDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageDataSetToFolder", + display_name="Save Image Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive images as list + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + logging.info(f"Saved {len(saved_files)} images to {output_dir}.") + return io.NodeOutput() + + +class SaveImageTextDataSetToFolderNode(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveImageTextDataSetToFolder", + display_name="Save Image and Text Dataset to Folder", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive both images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to save."), + io.String.Input("texts", tooltip="List of text captions to save."), + io.String.Input( + "folder_name", + default="dataset", + tooltip="Name of the folder to save images to (inside output directory).", + ), + io.String.Input( + "filename_prefix", + default="image", + tooltip="Prefix for saved image filenames.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, images, texts, folder_name, filename_prefix): + # Extract scalar values + folder_name = folder_name[0] + filename_prefix = filename_prefix[0] + + output_dir 
= os.path.join(folder_paths.get_output_directory(), folder_name) + saved_files = save_images_to_folder(images, output_dir, filename_prefix) + + # Save captions + for idx, (filename, caption) in enumerate(zip(saved_files, texts)): + caption_filename = filename.replace(".png", ".txt") + caption_path = os.path.join(output_dir, caption_filename) + with open(caption_path, "w", encoding="utf-8") as f: + f.write(caption) + + logging.info(f"Saved {len(saved_files)} images and captions to {output_dir}.") + return io.NodeOutput() + + +# ========== Helper Functions for Transform Nodes ========== + + +def tensor_to_pil(img_tensor): + """Convert tensor to PIL Image.""" + if img_tensor.dim() == 4 and img_tensor.shape[0] == 1: + img_tensor = img_tensor.squeeze(0) + img_array = (img_tensor.cpu().numpy() * 255).clip(0, 255).astype(np.uint8) + return Image.fromarray(img_array) + + +def pil_to_tensor(img): + """Convert PIL Image to tensor.""" + img_array = np.array(img).astype(np.float32) / 255.0 + return torch.from_numpy(img_array)[None,] + + +# ========== Base Classes for Transform Nodes ========== + + +class ImageProcessingNode(io.ComfyNode): + """Base class for image processing nodes that operate on images. + + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "images" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, image, **kwargs) -> tensor (for single-item processing) + _group_process(cls, images, **kwargs) -> list[tensor] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = ImageProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." 
+ ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + # Auto-detect is_output_list if not explicitly set + # Single processing: False (backend collects results into list) + # Group processing: True by default (can be False for single-output nodes) + output_is_list = ( + cls.is_output_list if cls.is_output_list is not None else is_group + ) + + inputs = [ + io.Image.Input( + "images", + tooltip=( + "List of images to process." if is_group else "Image to process." + ), + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/image", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=output_is_list, + tooltip="Processed images", + ) + ], + ) + + @classmethod + def execute(cls, images, **kwargs): + """Execute the node. Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: images is list, call _group_process + result = cls._group_process(images, **params) + else: + # Individual processing: images is single item, call _process + result = cls._process(images, **params) + + return io.NodeOutput(result) + + @classmethod + def _process(cls, image, **kwargs): + """Override this method for single-item processing. + + Args: + image: tensor - Single image tensor + **kwargs: Additional parameters (already extracted from lists) + + Returns: + tensor - Processed image + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, images, **kwargs): + """Override this method for group processing. + + Args: + images: list[tensor] - List of image tensors + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[tensor] - Processed images + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +class TextProcessingNode(io.ComfyNode): + """Base class for text processing nodes that operate on texts. 
+ + Child classes should set: + node_id: Unique node identifier (required) + display_name: Display name (optional, defaults to node_id) + description: Node description (optional) + extra_inputs: List of additional io.Input objects beyond "texts" (optional) + is_group_process: None (auto-detect), True (group), or False (individual) (optional) + is_output_list: True (list output) or False (single output) (optional, default True) + + Child classes must implement ONE of: + _process(cls, text, **kwargs) -> str (for single-item processing) + _group_process(cls, texts, **kwargs) -> list[str] (for group processing) + """ + + node_id = None + display_name = None + description = None + extra_inputs = [] + is_group_process = None # None = auto-detect, True/False = explicit + is_output_list = None # None = auto-detect based on processing mode + + @classmethod + def _detect_processing_mode(cls): + """Detect whether this node uses group or individual processing. + + Returns: + bool: True if group processing, False if individual processing + """ + # Explicit setting takes precedence + if cls.is_group_process is not None: + return cls.is_group_process + + # Check which method is overridden by looking at the defining class in MRO + base_class = TextProcessingNode + + # Find which class in MRO defines _process + process_definer = None + for klass in cls.__mro__: + if "_process" in klass.__dict__: + process_definer = klass + break + + # Find which class in MRO defines _group_process + group_definer = None + for klass in cls.__mro__: + if "_group_process" in klass.__dict__: + group_definer = klass + break + + # Check what was overridden (not defined in base class) + has_process = process_definer is not None and process_definer is not base_class + has_group = group_definer is not None and group_definer is not base_class + + if has_process and has_group: + raise ValueError( + f"{cls.__name__}: Cannot override both _process and _group_process. " + "Override only one, or set is_group_process explicitly." + ) + if not has_process and not has_group: + raise ValueError( + f"{cls.__name__}: Must override either _process or _group_process" + ) + + return has_group + + @classmethod + def define_schema(cls): + if cls.node_id is None: + raise NotImplementedError(f"{cls.__name__} must set node_id class variable") + + is_group = cls._detect_processing_mode() + + inputs = [ + io.String.Input( + "texts", + tooltip="List of texts to process." if is_group else "Text to process.", + ) + ] + inputs.extend(cls.extra_inputs) + + return io.Schema( + node_id=cls.node_id, + display_name=cls.display_name or cls.node_id, + category="dataset/text", + is_experimental=True, + is_input_list=is_group, # True for group, False for individual + inputs=inputs, + outputs=[ + io.String.Output( + display_name="texts", + is_output_list=cls.is_output_list, + tooltip="Processed texts", + ) + ], + ) + + @classmethod + def execute(cls, texts, **kwargs): + """Execute the node. 
Routes to _process or _group_process based on mode.""" + is_group = cls._detect_processing_mode() + + # Extract scalar values from lists for parameters + params = {} + for k, v in kwargs.items(): + if isinstance(v, list) and len(v) == 1: + params[k] = v[0] + else: + params[k] = v + + if is_group: + # Group processing: texts is list, call _group_process + result = cls._group_process(texts, **params) + else: + # Individual processing: texts is single item, call _process + result = cls._process(texts, **params) + + # Wrap result based on is_output_list + if cls.is_output_list: + # Result should already be a list (or will be for individual) + return io.NodeOutput(result if is_group else [result]) + else: + # Single output - wrap in list for NodeOutput + return io.NodeOutput([result]) + + @classmethod + def _process(cls, text, **kwargs): + """Override this method for single-item processing. + + Args: + text: str - Single text string + **kwargs: Additional parameters (already extracted from lists) + + Returns: + str - Processed text + """ + raise NotImplementedError(f"{cls.__name__} must implement _process method") + + @classmethod + def _group_process(cls, texts, **kwargs): + """Override this method for group processing. + + Args: + texts: list[str] - List of text strings + **kwargs: Additional parameters (already extracted from lists) + + Returns: + list[str] - Processed texts + """ + raise NotImplementedError( + f"{cls.__name__} must implement _group_process method" + ) + + +# ========== Image Transform Nodes ========== + + +class ResizeImagesByShorterEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByShorterEdge" + display_name = "Resize Images by Shorter Edge" + description = "Resize images so that the shorter edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "shorter_edge", + default=512, + min=1, + max=8192, + tooltip="Target length for the shorter edge.", + ), + ] + + @classmethod + def _process(cls, image, shorter_edge): + img = tensor_to_pil(image) + w, h = img.size + if w < h: + new_w = shorter_edge + new_h = int(h * (shorter_edge / w)) + else: + new_h = shorter_edge + new_w = int(w * (shorter_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class ResizeImagesByLongerEdgeNode(ImageProcessingNode): + node_id = "ResizeImagesByLongerEdge" + display_name = "Resize Images by Longer Edge" + description = "Resize images so that the longer edge matches the specified length while preserving aspect ratio." + extra_inputs = [ + io.Int.Input( + "longer_edge", + default=1024, + min=1, + max=8192, + tooltip="Target length for the longer edge.", + ), + ] + + @classmethod + def _process(cls, image, longer_edge): + img = tensor_to_pil(image) + w, h = img.size + if w > h: + new_w = longer_edge + new_h = int(h * (longer_edge / w)) + else: + new_h = longer_edge + new_w = int(w * (longer_edge / h)) + img = img.resize((new_w, new_h), Image.Resampling.LANCZOS) + return pil_to_tensor(img) + + +class CenterCropImagesNode(ImageProcessingNode): + node_id = "CenterCropImages" + display_name = "Center Crop Images" + description = "Center crop all images to the specified dimensions." 
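# ----------------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the two resize nodes above preserve
# aspect ratio by scaling the other edge with the same factor, e.g. a 1920x1080 image
# with shorter_edge=512 becomes int(1920 * 512 / 1080) x 512 = 910 x 512. The size
# arithmetic from their _process methods as a standalone helper (int() truncation
# kept to match the node code):
def shorter_edge_size(w: int, h: int, shorter_edge: int) -> tuple[int, int]:
    if w < h:
        return shorter_edge, int(h * (shorter_edge / w))
    return int(w * (shorter_edge / h)), shorter_edge
# ----------------------------------------------------------------------------------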
+ extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + ] + + @classmethod + def _process(cls, image, width, height): + img = tensor_to_pil(image) + left = max(0, (img.width - width) // 2) + top = max(0, (img.height - height) // 2) + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class RandomCropImagesNode(ImageProcessingNode): + node_id = "RandomCropImages" + display_name = "Random Crop Images" + description = ( + "Randomly crop all images to the specified dimensions (for data augmentation)." + ) + extra_inputs = [ + io.Int.Input("width", default=512, min=1, max=8192, tooltip="Crop width."), + io.Int.Input("height", default=512, min=1, max=8192, tooltip="Crop height."), + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." + ), + ] + + @classmethod + def _process(cls, image, width, height, seed): + np.random.seed(seed % (2**32 - 1)) + img = tensor_to_pil(image) + max_left = max(0, img.width - width) + max_top = max(0, img.height - height) + left = np.random.randint(0, max_left + 1) if max_left > 0 else 0 + top = np.random.randint(0, max_top + 1) if max_top > 0 else 0 + right = min(img.width, left + width) + bottom = min(img.height, top + height) + img = img.crop((left, top, right, bottom)) + return pil_to_tensor(img) + + +class NormalizeImagesNode(ImageProcessingNode): + node_id = "NormalizeImages" + display_name = "Normalize Images" + description = "Normalize images using mean and standard deviation." + extra_inputs = [ + io.Float.Input( + "mean", + default=0.5, + min=0.0, + max=1.0, + tooltip="Mean value for normalization.", + ), + io.Float.Input( + "std", + default=0.5, + min=0.001, + max=1.0, + tooltip="Standard deviation for normalization.", + ), + ] + + @classmethod + def _process(cls, image, mean, std): + return (image - mean) / std + + +class AdjustBrightnessNode(ImageProcessingNode): + node_id = "AdjustBrightness" + display_name = "Adjust Brightness" + description = "Adjust brightness of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Brightness factor. 1.0 = no change, <1.0 = darker, >1.0 = brighter.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return (image * factor).clamp(0.0, 1.0) + + +class AdjustContrastNode(ImageProcessingNode): + node_id = "AdjustContrast" + display_name = "Adjust Contrast" + description = "Adjust contrast of all images." + extra_inputs = [ + io.Float.Input( + "factor", + default=1.0, + min=0.0, + max=2.0, + tooltip="Contrast factor. 1.0 = no change, <1.0 = less contrast, >1.0 = more contrast.", + ), + ] + + @classmethod + def _process(cls, image, factor): + return ((image - 0.5) * factor + 0.5).clamp(0.0, 1.0) + + +class ShuffleDatasetNode(ImageProcessingNode): + node_id = "ShuffleDataset" + display_name = "Shuffle Image Dataset" + description = "Randomly shuffle the order of images in the dataset." + is_group_process = True # Requires full list to shuffle + extra_inputs = [ + io.Int.Input( + "seed", default=0, min=0, max=0xFFFFFFFFFFFFFFFF, tooltip="Random seed." 
+ ), + ] + + @classmethod + def _group_process(cls, images, seed): + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + return [images[i] for i in indices] + + +class ShuffleImageTextDatasetNode(io.ComfyNode): + """Special node that shuffles both images and texts together.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ShuffleImageTextDataset", + display_name="Shuffle Image-Text Dataset", + category="dataset/image", + is_experimental=True, + is_input_list=True, + inputs=[ + io.Image.Input("images", tooltip="List of images to shuffle."), + io.String.Input("texts", tooltip="List of texts to shuffle."), + io.Int.Input( + "seed", + default=0, + min=0, + max=0xFFFFFFFFFFFFFFFF, + tooltip="Random seed.", + ), + ], + outputs=[ + io.Image.Output( + display_name="images", + is_output_list=True, + tooltip="Shuffled images", + ), + io.String.Output( + display_name="texts", is_output_list=True, tooltip="Shuffled texts" + ), + ], + ) + + @classmethod + def execute(cls, images, texts, seed): + seed = seed[0] # Extract scalar + np.random.seed(seed % (2**32 - 1)) + indices = np.random.permutation(len(images)) + shuffled_images = [images[i] for i in indices] + shuffled_texts = [texts[i] for i in indices] + return io.NodeOutput(shuffled_images, shuffled_texts) + + +# ========== Text Transform Nodes ========== + + +class TextToLowercaseNode(TextProcessingNode): + node_id = "TextToLowercase" + display_name = "Text to Lowercase" + description = "Convert all texts to lowercase." + + @classmethod + def _process(cls, text): + return text.lower() + + +class TextToUppercaseNode(TextProcessingNode): + node_id = "TextToUppercase" + display_name = "Text to Uppercase" + description = "Convert all texts to uppercase." + + @classmethod + def _process(cls, text): + return text.upper() + + +class TruncateTextNode(TextProcessingNode): + node_id = "TruncateText" + display_name = "Truncate Text" + description = "Truncate all texts to a maximum length." + extra_inputs = [ + io.Int.Input( + "max_length", default=77, min=1, max=10000, tooltip="Maximum text length." + ), + ] + + @classmethod + def _process(cls, text, max_length): + return text[:max_length] + + +class AddTextPrefixNode(TextProcessingNode): + node_id = "AddTextPrefix" + display_name = "Add Text Prefix" + description = "Add a prefix to all texts." + extra_inputs = [ + io.String.Input("prefix", default="", tooltip="Prefix to add."), + ] + + @classmethod + def _process(cls, text, prefix): + return prefix + text + + +class AddTextSuffixNode(TextProcessingNode): + node_id = "AddTextSuffix" + display_name = "Add Text Suffix" + description = "Add a suffix to all texts." + extra_inputs = [ + io.String.Input("suffix", default="", tooltip="Suffix to add."), + ] + + @classmethod + def _process(cls, text, suffix): + return text + suffix + + +class ReplaceTextNode(TextProcessingNode): + node_id = "ReplaceText" + display_name = "Replace Text" + description = "Replace text in all texts." + extra_inputs = [ + io.String.Input("find", default="", tooltip="Text to find."), + io.String.Input("replace", default="", tooltip="Text to replace with."), + ] + + @classmethod + def _process(cls, text, find, replace): + return text.replace(find, replace) + + +class StripWhitespaceNode(TextProcessingNode): + node_id = "StripWhitespace" + display_name = "Strip Whitespace" + description = "Strip leading and trailing whitespace from all texts." 
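# ----------------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): TextProcessingNode picks single-item
# vs. group processing from whichever hook a subclass overrides, so adding another
# caption transform only needs node_id plus _process (or _group_process). A
# hypothetical extra node following the same pattern as the ones above (the node id
# and behaviour below are made up for illustration only):
class CollapseWhitespaceNode(TextProcessingNode):
    node_id = "CollapseWhitespace"
    display_name = "Collapse Whitespace"
    description = "Collapse runs of whitespace in all texts to single spaces."

    @classmethod
    def _process(cls, text):
        # Individual processing: called once per caption when a list is wired in;
        # the results are collected back into a list, as described in the base class.
        return " ".join(text.split())
# ----------------------------------------------------------------------------------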
+ + @classmethod + def _process(cls, text): + return text.strip() + + +# ========== Group Processing Example Nodes ========== + + +class ImageDeduplicationNode(ImageProcessingNode): + """Remove duplicate or very similar images from the dataset using perceptual hashing.""" + + node_id = "ImageDeduplication" + display_name = "Image Deduplication" + description = "Remove duplicate or very similar images from the dataset." + is_group_process = True # Requires full list to compare images + extra_inputs = [ + io.Float.Input( + "similarity_threshold", + default=0.95, + min=0.0, + max=1.0, + tooltip="Similarity threshold (0-1). Higher means more similar. Images above this threshold are considered duplicates.", + ), + ] + + @classmethod + def _group_process(cls, images, similarity_threshold): + """Remove duplicate images using perceptual hashing.""" + if len(images) == 0: + return [] + + # Compute simple perceptual hash for each image + def compute_hash(img_tensor): + """Compute a simple perceptual hash by resizing to 8x8 and comparing to average.""" + img = tensor_to_pil(img_tensor) + # Resize to 8x8 + img_small = img.resize((8, 8), Image.Resampling.LANCZOS).convert("L") + # Get pixels + pixels = list(img_small.getdata()) + # Compute average + avg = sum(pixels) / len(pixels) + # Create hash (1 if above average, 0 otherwise) + hash_bits = "".join("1" if p > avg else "0" for p in pixels) + return hash_bits + + def hamming_distance(hash1, hash2): + """Compute Hamming distance between two hash strings.""" + return sum(c1 != c2 for c1, c2 in zip(hash1, hash2)) + + # Compute hashes for all images + hashes = [compute_hash(img) for img in images] + + # Find duplicates + keep_indices = [] + for i in range(len(images)): + is_duplicate = False + for j in keep_indices: + # Compare hashes + distance = hamming_distance(hashes[i], hashes[j]) + similarity = 1.0 - (distance / 64.0) # 64 bits total + if similarity >= similarity_threshold: + is_duplicate = True + logging.info( + f"Image {i} is similar to image {j} (similarity: {similarity:.3f}), skipping" + ) + break + + if not is_duplicate: + keep_indices.append(i) + + # Return only unique images + unique_images = [images[i] for i in keep_indices] + logging.info( + f"Deduplication: kept {len(unique_images)} out of {len(images)} images" + ) + return unique_images + + +class ImageGridNode(ImageProcessingNode): + """Combine multiple images into a single grid/collage.""" + + node_id = "ImageGrid" + display_name = "Image Grid" + description = "Arrange multiple images into a grid layout." + is_group_process = True # Requires full list to create grid + is_output_list = False # Outputs single grid image + extra_inputs = [ + io.Int.Input( + "columns", + default=4, + min=1, + max=20, + tooltip="Number of columns in the grid.", + ), + io.Int.Input( + "cell_width", + default=256, + min=32, + max=2048, + tooltip="Width of each cell in the grid.", + ), + io.Int.Input( + "cell_height", + default=256, + min=32, + max=2048, + tooltip="Height of each cell in the grid.", + ), + io.Int.Input( + "padding", default=4, min=0, max=50, tooltip="Padding between images." 
+ ), + ] + + @classmethod + def _group_process(cls, images, columns, cell_width, cell_height, padding): + """Arrange images into a grid.""" + if len(images) == 0: + raise ValueError("Cannot create grid from empty image list") + + # Calculate grid dimensions + num_images = len(images) + rows = (num_images + columns - 1) // columns # Ceiling division + + # Calculate total grid size + grid_width = columns * cell_width + (columns - 1) * padding + grid_height = rows * cell_height + (rows - 1) * padding + + # Create blank grid + grid = Image.new("RGB", (grid_width, grid_height), (0, 0, 0)) + + # Place images + for idx, img_tensor in enumerate(images): + row = idx // columns + col = idx % columns + + # Convert to PIL and resize to cell size + img = tensor_to_pil(img_tensor) + img = img.resize((cell_width, cell_height), Image.Resampling.LANCZOS) + + # Calculate position + x = col * (cell_width + padding) + y = row * (cell_height + padding) + + # Paste into grid + grid.paste(img, (x, y)) + + logging.info( + f"Created {columns}x{rows} grid with {num_images} images ({grid_width}x{grid_height})" + ) + return pil_to_tensor(grid) + + +class MergeImageListsNode(ImageProcessingNode): + """Merge multiple image lists into a single list.""" + + node_id = "MergeImageLists" + display_name = "Merge Image Lists" + description = "Concatenate multiple image lists into one." + is_group_process = True # Receives images as list + + @classmethod + def _group_process(cls, images): + """Simply return the images list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged image list contains {len(images)} images") + return images + + +class MergeTextListsNode(TextProcessingNode): + """Merge multiple text lists into a single list.""" + + node_id = "MergeTextLists" + display_name = "Merge Text Lists" + description = "Concatenate multiple text lists into one." + is_group_process = True # Receives texts as list + + @classmethod + def _group_process(cls, texts): + """Simply return the texts list (already merged by input handling).""" + # When multiple list inputs are connected, they're concatenated + # For now, this is a simple pass-through + logging.info(f"Merged text list contains {len(texts)} texts") + return texts + + +# ========== Training Dataset Nodes ========== + + +class MakeTrainingDataset(io.ComfyNode): + """Encode images with VAE and texts with CLIP to create a training dataset.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="MakeTrainingDataset", + display_name="Make Training Dataset", + category="dataset", + is_experimental=True, + is_input_list=True, # images and texts as lists + inputs=[ + io.Image.Input("images", tooltip="List of images to encode."), + io.Vae.Input( + "vae", tooltip="VAE model for encoding images to latents." + ), + io.Clip.Input( + "clip", tooltip="CLIP model for encoding text to conditioning." + ), + io.String.Input( + "texts", + optional=True, + tooltip="List of text captions. 
Can be length n (matching images), 1 (repeated for all), or omitted (uses empty string).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, images, vae, clip, texts=None): + # Extract scalars (vae and clip are single values wrapped in lists) + vae = vae[0] + clip = clip[0] + + # Handle text list + num_images = len(images) + + if texts is None or len(texts) == 0: + # Treat as [""] for unconditional training + texts = [""] + + if len(texts) == 1 and num_images > 1: + # Repeat single text for all images + texts = texts * num_images + elif len(texts) != num_images: + raise ValueError( + f"Number of texts ({len(texts)}) does not match number of images ({num_images}). " + f"Text list should have length {num_images}, 1, or 0." + ) + + # Encode images with VAE + logging.info(f"Encoding {num_images} images with VAE...") + latents_list = [] # list[{"samples": tensor}] + for img_tensor in images: + # img_tensor is [1, H, W, 3] + latent_tensor = vae.encode(img_tensor[:, :, :, :3]) + latents_list.append({"samples": latent_tensor}) + + # Encode texts with CLIP + logging.info(f"Encoding {len(texts)} texts with CLIP...") + conditioning_list = [] # list[list[cond]] + for text in texts: + if text == "": + cond = clip.encode_from_tokens_scheduled(clip.tokenize("")) + else: + tokens = clip.tokenize(text) + cond = clip.encode_from_tokens_scheduled(tokens) + conditioning_list.append(cond) + + logging.info( + f"Created dataset with {len(latents_list)} latents and {len(conditioning_list)} conditioning." + ) + return io.NodeOutput(latents_list, conditioning_list) + + +class SaveTrainingDataset(io.ComfyNode): + """Save encoded training dataset (latents + conditioning) to disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="SaveTrainingDataset", + display_name="Save Training Dataset", + category="dataset", + is_experimental=True, + is_output_node=True, + is_input_list=True, # Receive lists + inputs=[ + io.Latent.Input( + "latents", + tooltip="List of latent dicts from MakeTrainingDataset.", + ), + io.Conditioning.Input( + "conditioning", + tooltip="List of conditioning lists from MakeTrainingDataset.", + ), + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder to save dataset (inside output directory).", + ), + io.Int.Input( + "shard_size", + default=1000, + min=1, + max=100000, + tooltip="Number of samples per shard file.", + ), + ], + outputs=[], + ) + + @classmethod + def execute(cls, latents, conditioning, folder_name, shard_size): + # Extract scalars + folder_name = folder_name[0] + shard_size = shard_size[0] + + # latents: list[{"samples": tensor}] + # conditioning: list[list[cond]] + + # Validate lengths match + if len(latents) != len(conditioning): + raise ValueError( + f"Number of latents ({len(latents)}) does not match number of conditions ({len(conditioning)}). " + f"Something went wrong in dataset preparation." + ) + + # Create output directory + output_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + os.makedirs(output_dir, exist_ok=True) + + # Prepare data pairs + num_samples = len(latents) + num_shards = (num_samples + shard_size - 1) // shard_size # Ceiling division + + logging.info( + f"Saving {num_samples} samples to {num_shards} shards in {output_dir}..." 
+ ) + + # Save data in shards + for shard_idx in range(num_shards): + start_idx = shard_idx * shard_size + end_idx = min(start_idx + shard_size, num_samples) + + # Get shard data (list of latent dicts and conditioning lists) + shard_data = { + "latents": latents[start_idx:end_idx], + "conditioning": conditioning[start_idx:end_idx], + } + + # Save shard + shard_filename = f"shard_{shard_idx:04d}.pkl" + shard_path = os.path.join(output_dir, shard_filename) + + with open(shard_path, "wb") as f: + torch.save(shard_data, f) + + logging.info( + f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)" + ) + + # Save metadata + metadata = { + "num_samples": num_samples, + "num_shards": num_shards, + "shard_size": shard_size, + } + metadata_path = os.path.join(output_dir, "metadata.json") + with open(metadata_path, "w") as f: + json.dump(metadata, f, indent=2) + + logging.info(f"Successfully saved {num_samples} samples to {output_dir}.") + return io.NodeOutput() + + +class LoadTrainingDataset(io.ComfyNode): + """Load encoded training dataset from disk.""" + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="LoadTrainingDataset", + display_name="Load Training Dataset", + category="dataset", + is_experimental=True, + inputs=[ + io.String.Input( + "folder_name", + default="training_dataset", + tooltip="Name of folder containing the saved dataset (inside output directory).", + ), + ], + outputs=[ + io.Latent.Output( + display_name="latents", + is_output_list=True, + tooltip="List of latent dicts", + ), + io.Conditioning.Output( + display_name="conditioning", + is_output_list=True, + tooltip="List of conditioning lists", + ), + ], + ) + + @classmethod + def execute(cls, folder_name): + # Get dataset directory + dataset_dir = os.path.join(folder_paths.get_output_directory(), folder_name) + + if not os.path.exists(dataset_dir): + raise ValueError(f"Dataset directory not found: {dataset_dir}") + + # Find all shard files + shard_files = sorted( + [ + f + for f in os.listdir(dataset_dir) + if f.startswith("shard_") and f.endswith(".pkl") + ] + ) + + if not shard_files: + raise ValueError(f"No shard files found in {dataset_dir}") + + logging.info(f"Loading {len(shard_files)} shards from {dataset_dir}...") + + # Load all shards + all_latents = [] # list[{"samples": tensor}] + all_conditioning = [] # list[list[cond]] + + for shard_file in shard_files: + shard_path = os.path.join(dataset_dir, shard_file) + + with open(shard_path, "rb") as f: + shard_data = torch.load(f, weights_only=True) + + all_latents.extend(shard_data["latents"]) + all_conditioning.extend(shard_data["conditioning"]) + + logging.info(f"Loaded {shard_file}: {len(shard_data['latents'])} samples") + + logging.info( + f"Successfully loaded {len(all_latents)} samples from {dataset_dir}." 
+ ) + return io.NodeOutput(all_latents, all_conditioning) + + +# ========== Extension Setup ========== + + +class DatasetExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # Data loading/saving nodes + LoadImageDataSetFromFolderNode, + LoadImageTextDataSetFromFolderNode, + SaveImageDataSetToFolderNode, + SaveImageTextDataSetToFolderNode, + # Image transform nodes + ResizeImagesByShorterEdgeNode, + ResizeImagesByLongerEdgeNode, + CenterCropImagesNode, + RandomCropImagesNode, + NormalizeImagesNode, + AdjustBrightnessNode, + AdjustContrastNode, + ShuffleDatasetNode, + ShuffleImageTextDatasetNode, + # Text transform nodes + TextToLowercaseNode, + TextToUppercaseNode, + TruncateTextNode, + AddTextPrefixNode, + AddTextSuffixNode, + ReplaceTextNode, + StripWhitespaceNode, + # Group processing examples + ImageDeduplicationNode, + ImageGridNode, + MergeImageListsNode, + MergeTextListsNode, + # Training dataset nodes + MakeTrainingDataset, + SaveTrainingDataset, + LoadTrainingDataset, + ] + + +async def comfy_entrypoint() -> DatasetExtension: + return DatasetExtension() diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py new file mode 100644 index 000000000..9cb234be1 --- /dev/null +++ b/comfy_extras/nodes_kandinsky5.py @@ -0,0 +1,136 @@ +import nodes +import node_helpers +import torch +import comfy.model_management +import comfy.utils + +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io + + +class Kandinsky5ImageToVideo(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="Kandinsky5ImageToVideo", + category="conditioning/video_models", + inputs=[ + io.Conditioning.Input("positive"), + io.Conditioning.Input("negative"), + io.Vae.Input("vae"), + io.Int.Input("width", default=768, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("height", default=512, min=16, max=nodes.MAX_RESOLUTION, step=16), + io.Int.Input("length", default=121, min=1, max=nodes.MAX_RESOLUTION, step=4), + io.Int.Input("batch_size", default=1, min=1, max=4096), + io.Image.Input("start_image", optional=True), + ], + outputs=[ + io.Conditioning.Output(display_name="positive"), + io.Conditioning.Output(display_name="negative"), + io.Latent.Output(display_name="latent", tooltip="Empty video latent"), + io.Latent.Output(display_name="cond_latent", tooltip="Clean encoded start images, used to replace the noisy start of the model output latents"), + ], + ) + + @classmethod + def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: + latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) + cond_latent_out = {} + if start_image is not None: + start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1) + encoded = vae.encode(start_image[:, :, :, :3]) + cond_latent_out["samples"] = encoded + + mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=start_image.device, dtype=start_image.dtype) + mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0 + + positive = node_helpers.conditioning_set_values(positive, {"time_dim_replace": encoded, "concat_mask": mask}) + negative = node_helpers.conditioning_set_values(negative, {"time_dim_replace": encoded, "concat_mask": mask}) + + out_latent = {} + out_latent["samples"] = latent + return 
io.NodeOutput(positive, negative, out_latent, cond_latent_out) + + +def adaptive_mean_std_normalization(source, reference, clump_mean_low=0.3, clump_mean_high=0.35, clump_std_low=0.35, clump_std_high=0.5): + source_mean = source.mean(dim=(1, 3, 4), keepdim=True) # mean over C, H, W + source_std = source.std(dim=(1, 3, 4), keepdim=True) # std over C, H, W + + reference_mean = torch.clamp(reference.mean(), source_mean - clump_mean_low, source_mean + clump_mean_high) + reference_std = torch.clamp(reference.std(), source_std - clump_std_low, source_std + clump_std_high) + + # normalization + normalized = (source - source_mean) / (source_std + 1e-8) + normalized = normalized * reference_std + reference_mean + + return normalized + + +class NormalizeVideoLatentStart(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="NormalizeVideoLatentStart", + category="conditioning/video_models", + description="Normalizes the initial frames of a video latent to match the mean and standard deviation of subsequent reference frames. Helps reduce differences between the starting frames and the rest of the video.", + inputs=[ + io.Latent.Input("latent"), + io.Int.Input("start_frame_count", default=4, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames to normalize, counted from the start"), + io.Int.Input("reference_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=1, tooltip="Number of latent frames after the start frames to use as reference"), + ], + outputs=[ + io.Latent.Output(display_name="latent"), + ], + ) + + @classmethod + def execute(cls, latent, start_frame_count, reference_frame_count) -> io.NodeOutput: + if latent["samples"].shape[2] <= 1: + return io.NodeOutput(latent) + s = latent.copy() + samples = latent["samples"].clone() + + first_frames = samples[:, :, :start_frame_count] + reference_frames_data = samples[:, :, start_frame_count:start_frame_count+min(reference_frame_count, samples.shape[2]-1)] + normalized_first_frames = adaptive_mean_std_normalization(first_frames, reference_frames_data) + + samples[:, :, :start_frame_count] = normalized_first_frames + s["samples"] = samples + return io.NodeOutput(s) + + +class CLIPTextEncodeKandinsky5(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="CLIPTextEncodeKandinsky5", + category="advanced/conditioning/kandinsky5", + inputs=[ + io.Clip.Input("clip"), + io.String.Input("clip_l", multiline=True, dynamic_prompts=True), + io.String.Input("qwen25_7b", multiline=True, dynamic_prompts=True), + ], + outputs=[ + io.Conditioning.Output(), + ], + ) + + @classmethod + def execute(cls, clip, clip_l, qwen25_7b) -> io.NodeOutput: + tokens = clip.tokenize(clip_l) + tokens["qwen25_7b"] = clip.tokenize(qwen25_7b)["qwen25_7b"] + + return io.NodeOutput(clip.encode_from_tokens_scheduled(tokens)) + + +class Kandinsky5Extension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + Kandinsky5ImageToVideo, + NormalizeVideoLatentStart, + CLIPTextEncodeKandinsky5, + ] + +async def comfy_entrypoint() -> Kandinsky5Extension: + return Kandinsky5Extension() diff --git a/comfy_extras/nodes_logic.py b/comfy_extras/nodes_logic.py new file mode 100644 index 000000000..95a6ba788 --- /dev/null +++ b/comfy_extras/nodes_logic.py @@ -0,0 +1,155 @@ +from typing import TypedDict +from typing_extensions import override +from comfy_api.latest import ComfyExtension, io +from comfy_api.latest import _io + + + +class SwitchNode(io.ComfyNode): 
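+    # Routes one of two lazily evaluated inputs ("on_true" / "on_false") to a single
+    # output based on the boolean "switch". Both value inputs share one MatchType
+    # template, so the output adopts whatever type is connected, and check_lazy_status()
+    # below asks the executor to evaluate only the branch that is actually needed.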
+ @classmethod + def define_schema(cls): + template = io.MatchType.Template("switch") + return io.Schema( + node_id="ComfySwitchNode", + display_name="Switch", + category="logic", + is_experimental=True, + inputs=[ + io.Boolean.Input("switch"), + io.MatchType.Input("on_false", template=template, lazy=True, optional=True), + io.MatchType.Input("on_true", template=template, lazy=True, optional=True), + ], + outputs=[ + io.MatchType.Output(template=template, display_name="output"), + ], + ) + + @classmethod + def check_lazy_status(cls, switch, on_false=..., on_true=...): + # We use ... instead of None, as None is passed for connected-but-unevaluated inputs. + # This trick allows us to ignore the value of the switch and still be able to run execute(). + + # One of the inputs may be missing, in which case we need to evaluate the other input + if on_false is ...: + return ["on_true"] + if on_true is ...: + return ["on_false"] + # Normal lazy switch operation + if switch and on_true is None: + return ["on_true"] + if not switch and on_false is None: + return ["on_false"] + + @classmethod + def validate_inputs(cls, switch, on_false=..., on_true=...): + # This check happens before check_lazy_status(), so we can eliminate the case where + # both inputs are missing. + if on_false is ... and on_true is ...: + return "At least one of on_false or on_true must be connected to Switch node" + return True + + @classmethod + def execute(cls, switch, on_true=..., on_false=...) -> io.NodeOutput: + if on_true is ...: + return io.NodeOutput(on_false) + if on_false is ...: + return io.NodeOutput(on_true) + return io.NodeOutput(on_true if switch else on_false) + + +class DCTestNode(io.ComfyNode): + class DCValues(TypedDict): + combo: str + string: str + integer: int + image: io.Image.Type + subcombo: dict[str] + + @classmethod + def define_schema(cls): + return io.Schema( + node_id="DCTestNode", + display_name="DCTest", + category="logic", + is_output_node=True, + inputs=[_io.DynamicCombo.Input("combo", options=[ + _io.DynamicCombo.Option("option1", [io.String.Input("string")]), + _io.DynamicCombo.Option("option2", [io.Int.Input("integer")]), + _io.DynamicCombo.Option("option3", [io.Image.Input("image")]), + _io.DynamicCombo.Option("option4", [ + _io.DynamicCombo.Input("subcombo", options=[ + _io.DynamicCombo.Option("opt1", [io.Float.Input("float_x"), io.Float.Input("float_y")]), + _io.DynamicCombo.Option("opt2", [io.Mask.Input("mask1", optional=True)]), + ]) + ])] + )], + outputs=[io.AnyType.Output()], + ) + + @classmethod + def execute(cls, combo: DCValues) -> io.NodeOutput: + combo_val = combo["combo"] + if combo_val == "option1": + return io.NodeOutput(combo["string"]) + elif combo_val == "option2": + return io.NodeOutput(combo["integer"]) + elif combo_val == "option3": + return io.NodeOutput(combo["image"]) + elif combo_val == "option4": + return io.NodeOutput(f"{combo['subcombo']}") + else: + raise ValueError(f"Invalid combo: {combo_val}") + + +class AutogrowNamesTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = _io.Autogrow.TemplateNames(input=io.Float.Input("float"), names=["a", "b", "c"]) + return io.Schema( + node_id="AutogrowNamesTestNode", + display_name="AutogrowNamesTest", + category="logic", + inputs=[ + _io.Autogrow.Input("autogrow", template=template) + ], + outputs=[io.String.Output()], + ) + + @classmethod + def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput: + vals = list(autogrow.values()) + combined = ",".join([str(x) for x in vals]) + return 
io.NodeOutput(combined) + +class AutogrowPrefixTestNode(io.ComfyNode): + @classmethod + def define_schema(cls): + template = _io.Autogrow.TemplatePrefix(input=io.Float.Input("float"), prefix="float", min=1, max=10) + return io.Schema( + node_id="AutogrowPrefixTestNode", + display_name="AutogrowPrefixTest", + category="logic", + inputs=[ + _io.Autogrow.Input("autogrow", template=template) + ], + outputs=[io.String.Output()], + ) + + @classmethod + def execute(cls, autogrow: _io.Autogrow.Type) -> io.NodeOutput: + vals = list(autogrow.values()) + combined = ",".join([str(x) for x in vals]) + return io.NodeOutput(combined) + +class LogicExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + # SwitchNode, + # DCTestNode, + # AutogrowNamesTestNode, + # AutogrowPrefixTestNode, + ] + +async def comfy_entrypoint() -> LogicExtension: + return LogicExtension() diff --git a/comfy_extras/nodes_nop.py b/comfy_extras/nodes_nop.py new file mode 100644 index 000000000..953061bcb --- /dev/null +++ b/comfy_extras/nodes_nop.py @@ -0,0 +1,39 @@ +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override +# If you write a node that is so useless that it breaks ComfyUI it will be featured in this exclusive list + +# "native" block swap nodes are placebo at best and break the ComfyUI memory management system. +# They are also considered harmful because instead of users reporting issues with the built in +# memory management they install these stupid nodes and complain even harder. Now it completely +# breaks with some of the new ComfyUI memory optimizations so I have made the decision to NOP it +# out of all workflows. +class wanBlockSwap(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="wanBlockSwap", + category="", + description="NOP", + inputs=[ + io.Model.Input("model"), + ], + outputs=[ + io.Model.Output(), + ], + is_deprecated=True, + ) + + @classmethod + def execute(cls, model) -> io.NodeOutput: + return io.NodeOutput(model) + + +class NopExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + wanBlockSwap + ] + +async def comfy_entrypoint() -> NopExtension: + return NopExtension() diff --git a/comfy_extras/nodes_rope.py b/comfy_extras/nodes_rope.py new file mode 100644 index 000000000..d1feb031e --- /dev/null +++ b/comfy_extras/nodes_rope.py @@ -0,0 +1,47 @@ +from comfy_api.latest import ComfyExtension, io +from typing_extensions import override + + +class ScaleROPE(io.ComfyNode): + @classmethod + def define_schema(cls): + return io.Schema( + node_id="ScaleROPE", + category="advanced/model_patches", + description="Scale and shift the ROPE of the model.", + is_experimental=True, + inputs=[ + io.Model.Input("model"), + io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1), + + io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1), + + io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1), + io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1), + + + ], + outputs=[ + io.Model.Output(), + ], + ) + + @classmethod + def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput: + m = model.clone() + m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) + return 
io.NodeOutput(m) + + +class RopeExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[io.ComfyNode]]: + return [ + ScaleROPE + ] + + +async def comfy_entrypoint() -> RopeExtension: + return RopeExtension() diff --git a/manager_requirements.txt b/manager_requirements.txt new file mode 100644 index 000000000..b95cefb74 --- /dev/null +++ b/manager_requirements.txt @@ -0,0 +1 @@ +comfyui_manager==4.0.3b4 diff --git a/models/latent_upscale_models/put_latent_upscale_models_here b/models/latent_upscale_models/put_latent_upscale_models_here new file mode 100644 index 000000000..e69de29bb diff --git a/pyproject.toml b/pyproject.toml index 9642a4d5c..020956ba5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "comfyui" -version = "0.3.66" +version = "0.3.76" description = "An installable version of ComfyUI" readme = "README.md" authors = [ @@ -18,16 +18,16 @@ classifiers = [ ] dependencies = [ - "comfyui-frontend-package>=1.28.7", - "comfyui-workflow-templates>=0.1.95,<0.3.0", - "comfyui-embedded-docs>=0.3.0", + "comfyui-frontend-package>=1.33.10", + "comfyui-workflow-templates>=0.7.51", + "comfyui-embedded-docs>=0.3.1", "torch", "torchvision", "torchdiffeq>=0.2.3", "torchsde>=0.2.6", "einops>=0.6.0", "open-clip-torch>=2.24.0", - "transformers>=4.46.0,!=4.53.0,!=4.53.1,!=4.53.2,!=4.57.0", + "transformers>=4.57.1", "tokenizers>=0.13.3", "sentencepiece", "peft>=0.10.0", @@ -201,6 +201,8 @@ comfyui-manager = [ "uv", "chardet", "pip", + # todo: bold move + "comfyui_manager==4.0.3b4", ] [project.scripts] @@ -302,3 +304,50 @@ allow-direct-references = true [tool.hatch.build.targets.wheel] packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/", "comfy_execution/", "comfy_compatibility/"] + +[tool.pylint] +master.py-version = "3.10" +master.extension-pkg-allow-list = [ + "pydantic", +] +reports.output-format = "colorized" +similarities.ignore-imports = "yes" +messages_control.disable = [ + "missing-module-docstring", + "missing-class-docstring", + "missing-function-docstring", + "line-too-long", + "too-few-public-methods", + "too-many-public-methods", + "too-many-instance-attributes", + "too-many-positional-arguments", + "broad-exception-raised", + "too-many-lines", + "invalid-name", + "unused-argument", + "broad-exception-caught", + "consider-using-with", + "fixme", + "too-many-statements", + "too-many-branches", + "too-many-locals", + "too-many-arguments", + "too-many-return-statements", + "too-many-nested-blocks", + "duplicate-code", + "abstract-method", + "superfluous-parens", + "arguments-differ", + "redefined-builtin", + "unnecessary-lambda", + "dangerous-default-value", + "invalid-overridden-method", + # next warnings should be fixed in future + "bad-classmethod-argument", # Class method should have 'cls' as first argument + "wrong-import-order", # Standard imports should be placed before third party imports + "ungrouped-imports", + "unnecessary-pass", + "unnecessary-lambda-assignment", + "no-else-return", + "unused-variable", +] diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py new file mode 100644 index 000000000..3a54941e6 --- /dev/null +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -0,0 +1,233 @@ +import unittest +import torch +import sys +import os +import json + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args 
import args +if not has_gpu(): + args.cpu = True + +from comfy import ops +from comfy.quant_ops import QuantizedTensor +import comfy.utils + + +class SimpleModel(torch.nn.Module): + def __init__(self, operations=ops.disable_weight_init): + super().__init__() + self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16) + self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16) + self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16) + + def forward(self, x): + x = self.layer1(x) + x = torch.nn.functional.relu(x) + x = self.layer2(x) + x = torch.nn.functional.relu(x) + x = self.layer3(x) + return x + + +class TestMixedPrecisionOps(unittest.TestCase): + + def test_all_layers_standard(self): + """Test that model with no quantization works normally""" + # Create model + model = SimpleModel(operations=ops.mixed_precision_ops({})) + + # Initialize weights manually + model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16)) + model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16)) + model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16)) + model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16)) + model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16)) + model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16)) + + # Initialize weight_function and bias_function + for layer in [model.layer1, model.layer2, model.layer3]: + layer.weight_function = [] + layer.bias_function = [] + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + self.assertEqual(output.dtype, torch.bfloat16) + + def test_mixed_precision_load(self): + """Test loading a mixed precision model from state dict""" + # Configure mixed precision: layer1 is FP8, layer2 and layer3 are standard + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + }, + "layer3": { + "format": "float8_e4m3fn", + "params": {} + } + } + + # Create state dict with mixed precision + fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn) + + state_dict = { + # Layer 1: FP8 E4M3FN + "layer1.weight": fp8_weight1, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + + # Layer 2: Standard BF16 + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + + # Layer 3: FP8 E4M3FN + "layer3.weight": fp8_weight3, + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), + } + + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + # Create model and load state dict (strict=False because custom loading pops keys) + model = SimpleModel(operations=ops.mixed_precision_ops({})) + model.load_state_dict(state_dict, strict=False) + + # Verify weights are wrapped in QuantizedTensor + self.assertIsInstance(model.layer1.weight, QuantizedTensor) + self.assertEqual(model.layer1.weight._layout_type, "TensorCoreFP8Layout") + + # Layer 2 should NOT be quantized + self.assertNotIsInstance(model.layer2.weight, QuantizedTensor) + + # Layer 3 should be quantized + 
self.assertIsInstance(model.layer3.weight, QuantizedTensor) + self.assertEqual(model.layer3.weight._layout_type, "TensorCoreFP8Layout") + + # Verify scales were loaded + self.assertEqual(model.layer1.weight._layout_params['scale'].item(), 2.0) + self.assertEqual(model.layer3.weight._layout_params['scale'].item(), 1.5) + + # Forward pass + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + with torch.inference_mode(): + output = model(input_tensor) + + self.assertEqual(output.shape, (5, 40)) + + def test_state_dict_quantized_preserved(self): + """Test that quantized weights are preserved in state_dict()""" + # Configure mixed precision + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict1 = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) + model.load_state_dict(state_dict1, strict=False) + + # Save state dict + state_dict2 = model.state_dict() + + # Verify layer1.weight is a QuantizedTensor with scale preserved + self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor) + self.assertEqual(state_dict2["layer1.weight"]._layout_params['scale'].item(), 3.0) + self.assertEqual(state_dict2["layer1.weight"]._layout_type, "TensorCoreFP8Layout") + + # Verify non-quantized layers are standard tensors + self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor) + self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor) + + def test_weight_function_compatibility(self): + """Test that weight_function (LoRA) works with quantized layers""" + # Configure FP8 quantization + layer_quant_config = { + "layer1": { + "format": "float8_e4m3fn", + "params": {} + } + } + + # Create and load model + fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn) + state_dict = { + "layer1.weight": fp8_weight, + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) + model.load_state_dict(state_dict, strict=False) + + # Add a weight function (simulating LoRA) + # This should trigger dequantization during forward pass + def apply_lora(weight): + lora_delta = torch.randn_like(weight) * 0.01 + return weight + lora_delta + + model.layer1.weight_function.append(apply_lora) + + # Forward pass should work with LoRA (triggers weight_function path) + input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) + output = model(input_tensor) + + 
self.assertEqual(output.shape, (5, 40)) + + def test_error_handling_unknown_format(self): + """Test that unknown formats raise error""" + # Configure with unknown format + layer_quant_config = { + "layer1": { + "format": "unknown_format_xyz", + "params": {} + } + } + + # Create state dict + state_dict = { + "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16), + "layer1.bias": torch.randn(20, dtype=torch.bfloat16), + "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16), + "layer2.bias": torch.randn(30, dtype=torch.bfloat16), + "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16), + "layer3.bias": torch.randn(40, dtype=torch.bfloat16), + } + + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS + model = SimpleModel(operations=ops.mixed_precision_ops({})) + with self.assertRaises(KeyError): + model.load_state_dict(state_dict, strict=False) + +if __name__ == "__main__": + unittest.main() + diff --git a/tests-unit/comfy_quant/test_quant_registry.py b/tests-unit/comfy_quant/test_quant_registry.py new file mode 100644 index 000000000..9cb54ede8 --- /dev/null +++ b/tests-unit/comfy_quant/test_quant_registry.py @@ -0,0 +1,190 @@ +import unittest +import torch +import sys +import os + +# Add comfy to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) + +def has_gpu(): + return torch.cuda.is_available() + +from comfy.cli_args import args +if not has_gpu(): + args.cpu = True + +from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout + + +class TestQuantizedTensor(unittest.TestCase): + """Test the QuantizedTensor subclass with FP8 layout""" + + def test_creation(self): + """Test creating a QuantizedTensor with TensorCoreFP8Layout""" + fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(2.0) + layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16} + + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.shape, (256, 128)) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt._layout_params['scale'], scale) + self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16) + self.assertEqual(qt._layout_type, "TensorCoreFP8Layout") + + def test_dequantize(self): + """Test explicit dequantization""" + + fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(3.0) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + dequantized = qt.dequantize() + + self.assertEqual(dequantized.dtype, torch.float32) + self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1)) + + def test_from_float(self): + """Test creating QuantizedTensor from float tensor""" + float_tensor = torch.randn(64, 32, dtype=torch.float32) + scale = torch.tensor(1.5) + + qt = QuantizedTensor.from_float( + float_tensor, + "TensorCoreFP8Layout", + scale=scale, + dtype=torch.float8_e4m3fn + ) + + self.assertIsInstance(qt, QuantizedTensor) + self.assertEqual(qt.dtype, torch.float8_e4m3fn) + self.assertEqual(qt.shape, (64, 32)) + + # Verify dequantization gives approximately original values + dequantized = qt.dequantize() + mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean() + 
self.assertLess(mean_rel_error, 0.1) + + +class TestGenericUtilities(unittest.TestCase): + """Test generic utility operations""" + + def test_detach(self): + """Test detach operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + + # Detach should return a new QuantizedTensor + qt_detached = qt.detach() + + self.assertIsInstance(qt_detached, QuantizedTensor) + self.assertEqual(qt_detached.shape, qt.shape) + self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout") + + def test_clone(self): + """Test clone operation on quantized tensor""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + + # Clone should return a new QuantizedTensor + qt_cloned = qt.clone() + + self.assertIsInstance(qt_cloned, QuantizedTensor) + self.assertEqual(qt_cloned.shape, qt.shape) + self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout") + + # Verify it's a deep copy + self.assertIsNot(qt_cloned._qdata, qt._qdata) + + @unittest.skipUnless(has_gpu(), "GPU not available") + def test_to_device(self): + """Test device transfer""" + fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn) + scale = torch.tensor(1.5) + layout_params = {'scale': scale, 'orig_dtype': torch.float32} + qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params) + + # Moving to same device should work (CPU to CPU) + qt_cpu = qt.to('cpu') + + self.assertIsInstance(qt_cpu, QuantizedTensor) + self.assertEqual(qt_cpu.device.type, 'cpu') + self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu') + + +class TestTensorCoreFP8Layout(unittest.TestCase): + """Test the TensorCoreFP8Layout implementation""" + + def test_quantize(self): + """Test quantization method""" + float_tensor = torch.randn(32, 64, dtype=torch.float32) + scale = torch.tensor(1.5) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + self.assertEqual(qdata.dtype, torch.float8_e4m3fn) + self.assertEqual(qdata.shape, float_tensor.shape) + self.assertIn('scale', layout_params) + self.assertIn('orig_dtype', layout_params) + self.assertEqual(layout_params['orig_dtype'], torch.float32) + + def test_dequantize(self): + """Test dequantization method""" + float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0 + scale = torch.tensor(1.0) + + qdata, layout_params = TensorCoreFP8Layout.quantize( + float_tensor, + scale=scale, + dtype=torch.float8_e4m3fn + ) + + dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params) + + # Should approximately match original + self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1)) + + +class TestFallbackMechanism(unittest.TestCase): + """Test fallback for unsupported operations""" + + def test_unsupported_op_dequantizes(self): + """Test that unsupported operations fall back to dequantization""" + # Set seed for reproducibility + torch.manual_seed(42) + + # Create quantized tensor + a_fp32 = torch.randn(10, 20, dtype=torch.float32) + scale = torch.tensor(1.0) + a_q = QuantizedTensor.from_float( + a_fp32, + "TensorCoreFP8Layout", + scale=scale, + dtype=torch.float8_e4m3fn + ) + + # Call an 
operation that doesn't have a registered handler + # For example, torch.abs + result = torch.abs(a_q) + + # Should work via fallback (dequantize → abs → return) + self.assertNotIsInstance(result, QuantizedTensor) + expected = torch.abs(a_fp32) + # FP8 introduces quantization error, so use loose tolerance + mean_error = (result - expected).abs().mean() + self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py index 356997950..d4ebb54e5 100644 --- a/tests/execution/test_execution.py +++ b/tests/execution/test_execution.py @@ -122,19 +122,17 @@ class ComfyClient: @pytest.mark.execution class TestExecution: # Initialize server and client - @fixture(scope="class", params=[ - # ??? no cache lru, should cache results, etc. etc. - # (lru_size, should_cache_results) - (0, True), - (100, True), + @fixture(scope="class", autouse=True, params=[ + { "extra_args" : [], "should_cache_results" : True }, + { "extra_args" : ["--cache-lru", 0], "should_cache_results" : True }, + { "extra_args" : ["--cache-lru", 100], "should_cache_results" : True }, + { "extra_args" : ["--cache-none"], "should_cache_results" : False }, ]) async def client(self, request): from ..inference.testing_pack import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS - # ??? todo: we have to deal with this - lru_size, should_cache_results = request.param configuration = default_configuration() - configuration.cache_lru = lru_size + configuration.update(request.param["extra_args"]) progress_handler = _ProgressHandler() with context_add_custom_nodes(ExportedNodes(NODE_CLASS_MAPPINGS=NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS=NODE_DISPLAY_NAME_MAPPINGS)): @@ -162,7 +160,7 @@ class TestExecution: assert result.did_run(mask) assert result.did_run(lazy_mix) - async def test_full_cache(self, client: ComfyClient, builder: GraphBuilder): + async def test_full_cache(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -174,9 +172,12 @@ class TestExecution: await client.run(g) result2 = await client.run(g) for node_id, node in g.nodes.items(): - assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached" + else: + assert result2.did_run(node), f"Node {node_id} was cached, but should have been run" - async def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder): + async def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1) input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1) @@ -188,8 +189,12 @@ class TestExecution: await client.run(g) mask.inputs['value'] = 0.4 result2 = await client.run(g) - assert not result2.did_run(input1), "Input1 should have been cached" - assert not result2.did_run(input2), "Input2 should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(input1), "Input1 should have been cached" + assert not result2.did_run(input2), "Input2 should have been cached" + else: + assert result2.did_run(input1), "Input1 should have been rerun" + assert 
result2.did_run(input2), "Input2 should have been rerun" async def test_error(self, client: ComfyClient, builder: GraphBuilder): g = builder @@ -313,7 +318,7 @@ class TestExecution: assert 'prompt_id' in e.args[0], f"Did not get back a proper error message: {e}" assert e.args[0]['node_id'] == generator.id, "Error should have been on the generator node" - async def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder): + async def test_custom_is_changed(self, client: ComfyClient, builder: GraphBuilder, server): g = builder # Creating the nodes in this specific order previously caused a bug save = g.node("SaveImage") @@ -329,7 +334,10 @@ class TestExecution: result3 = await client.run(g) result4 = await client.run(g) assert result1.did_run(is_changed), "is_changed should have been run" - assert not result2.did_run(is_changed), "is_changed should have been cached" + if server["should_cache_results"]: + assert not result2.did_run(is_changed), "is_changed should have been cached" + else: + assert result2.did_run(is_changed), "is_changed should have been re-run" assert result3.did_run(is_changed), "is_changed should have been re-run" assert result4.did_run(is_changed), "is_changed should not have been cached" @@ -435,7 +443,7 @@ class TestExecution: assert len(images2) == 1, "Should have 1 image" # This tests that only constant outputs are used in the call to `IS_CHANGED` - async def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder): + async def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server): g = builder input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1) test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5) @@ -451,7 +459,11 @@ class TestExecution: images = result.get_images(output) assert len(images) == 1, "Should have 1 image" assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25" - assert not result.did_run(test_node), "The execution should have been cached" + if server["should_cache_results"]: + assert not result.did_run(test_node), "The execution should have been cached" + else: + assert result.did_run(test_node), "The execution should have been re-run" + async def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks): # Warmup execution to ensure server is fully initialized diff --git a/tests/execution/test_public_api.py b/tests/execution/test_public_api.py new file mode 100644 index 000000000..52bc2fcd8 --- /dev/null +++ b/tests/execution/test_public_api.py @@ -0,0 +1,153 @@ +""" +Tests for public ComfyAPI and ComfyAPISync functions. + +These tests verify that the public API methods work correctly in both sync and async contexts, +ensuring that the sync wrapper generation (via get_type_hints() in async_to_sync.py) correctly +handles string annotations from 'from __future__ import annotations'. 
+""" + +import pytest +import time +import subprocess +import torch +from pytest import fixture +from comfy_execution.graph_utils import GraphBuilder +from tests.execution.test_execution import ComfyClient + + +@pytest.mark.execution +class TestPublicAPI: + """Test suite for public ComfyAPI and ComfyAPISync methods.""" + + @fixture(scope="class", autouse=True) + def _server(self, args_pytest): + """Start ComfyUI server for testing.""" + pargs = [ + 'python', 'main.py', + '--output-directory', args_pytest["output_dir"], + '--listen', args_pytest["listen"], + '--port', str(args_pytest["port"]), + '--extra-model-paths-config', 'tests/execution/extra_model_paths.yaml', + '--cpu', + ] + p = subprocess.Popen(pargs) + yield + p.kill() + torch.cuda.empty_cache() + + @fixture(scope="class", autouse=True) + def shared_client(self, args_pytest, _server): + """Create shared client with connection retry.""" + client = ComfyClient() + n_tries = 5 + for i in range(n_tries): + time.sleep(4) + try: + client.connect(listen=args_pytest["listen"], port=args_pytest["port"]) + break + except ConnectionRefusedError: + if i == n_tries - 1: + raise + yield client + del client + torch.cuda.empty_cache() + + @fixture + def client(self, shared_client, request): + """Set test name for each test.""" + shared_client.set_test_name(f"public_api[{request.node.name}]") + yield shared_client + + @fixture + def builder(self, request): + """Create GraphBuilder for each test.""" + yield GraphBuilder(prefix=request.node.name) + + def test_sync_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder): + """Test that TestSyncProgressUpdate executes without errors. + + This test validates that api_sync.execution.set_progress() works correctly, + which is the primary code path fixed by adding get_type_hints() to async_to_sync.py. + """ + g = builder + image = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1) + + # Use TestSyncProgressUpdate with short sleep + progress_node = g.node("TestSyncProgressUpdate", + value=image.out(0), + sleep_seconds=0.5) + output = g.node("SaveImage", images=progress_node.out(0)) + + # Execute workflow + result = client.run(g) + + # Verify execution + assert result.did_run(progress_node), "Progress node should have executed" + assert result.did_run(output), "Output node should have executed" + + # Verify output + images = result.get_images(output) + assert len(images) == 1, "Should have produced 1 image" + + def test_async_progress_update_executes(self, client: ComfyClient, builder: GraphBuilder): + """Test that TestAsyncProgressUpdate executes without errors. + + This test validates that await api.execution.set_progress() works correctly + in async contexts. + """ + g = builder + image = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1) + + # Use TestAsyncProgressUpdate with short sleep + progress_node = g.node("TestAsyncProgressUpdate", + value=image.out(0), + sleep_seconds=0.5) + output = g.node("SaveImage", images=progress_node.out(0)) + + # Execute workflow + result = client.run(g) + + # Verify execution + assert result.did_run(progress_node), "Async progress node should have executed" + assert result.did_run(output), "Output node should have executed" + + # Verify output + images = result.get_images(output) + assert len(images) == 1, "Should have produced 1 image" + + def test_sync_and_async_progress_together(self, client: ComfyClient, builder: GraphBuilder): + """Test both sync and async progress updates in same workflow. 
+ + This test ensures that both ComfyAPISync and ComfyAPI can coexist and work + correctly in the same workflow execution. + """ + g = builder + image1 = g.node("StubImage", content="BLACK", height=256, width=256, batch_size=1) + image2 = g.node("StubImage", content="WHITE", height=256, width=256, batch_size=1) + + # Use both types of progress nodes + sync_progress = g.node("TestSyncProgressUpdate", + value=image1.out(0), + sleep_seconds=0.3) + async_progress = g.node("TestAsyncProgressUpdate", + value=image2.out(0), + sleep_seconds=0.3) + + # Create outputs + output1 = g.node("SaveImage", images=sync_progress.out(0)) + output2 = g.node("SaveImage", images=async_progress.out(0)) + + # Execute workflow + result = client.run(g) + + # Both should execute successfully + assert result.did_run(sync_progress), "Sync progress node should have executed" + assert result.did_run(async_progress), "Async progress node should have executed" + assert result.did_run(output1), "First output node should have executed" + assert result.did_run(output2), "Second output node should have executed" + + # Verify outputs + images1 = result.get_images(output1) + images2 = result.get_images(output2) + assert len(images1) == 1, "Should have produced 1 image from sync node" + assert len(images2) == 1, "Should have produced 1 image from async node" diff --git a/tests/unit/app_test/user_manager_system_user_test.py b/tests/unit/app_test/user_manager_system_user_test.py new file mode 100644 index 000000000..63b1ac5e5 --- /dev/null +++ b/tests/unit/app_test/user_manager_system_user_test.py @@ -0,0 +1,193 @@ +"""Tests for System User Protection in user_manager.py + +Tests cover: +- get_request_user_id(): 1st defense layer - blocks System Users from HTTP headers +- get_request_user_filepath(): 2nd defense layer - structural blocking via get_public_user_directory() +- add_user(): 3rd defense layer - prevents creation of System User names +- Defense layers integration tests +""" + +import pytest +from unittest.mock import MagicMock, patch +import tempfile + +import folder_paths +from app.user_manager import UserManager + + +@pytest.fixture +def mock_user_directory(): + """Create a temporary user directory.""" + with tempfile.TemporaryDirectory() as temp_dir: + original_dir = folder_paths.get_user_directory() + folder_paths.set_user_directory(temp_dir) + yield temp_dir + folder_paths.set_user_directory(original_dir) + + +@pytest.fixture +def user_manager(mock_user_directory): + """Create a UserManager instance for testing.""" + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + manager = UserManager() + # Add a default user for testing + manager.users = {"default": "default", "test_user_123": "Test User"} + yield manager + + +@pytest.fixture +def mock_request(): + """Create a mock request object.""" + request = MagicMock() + request.headers = {} + return request + + +class TestGetRequestUserId: + """Tests for get_request_user_id() - 1st defense layer. 
+ + Verifies: + - System Users (__ prefix) in HTTP header are rejected with KeyError + - Public Users pass through successfully + """ + + def test_system_user_raises_error(self, user_manager, mock_request): + """Test System User in header raises KeyError.""" + mock_request.headers = {"comfy-user": "__system"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + def test_system_user_cache_raises_error(self, user_manager, mock_request): + """Test System User cache raises KeyError.""" + mock_request.headers = {"comfy-user": "__cache"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + def test_normal_user_works(self, user_manager, mock_request): + """Test normal user access works.""" + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + user_id = user_manager.get_request_user_id(mock_request) + assert user_id == "default" + + def test_unknown_user_raises_error(self, user_manager, mock_request): + """Test unknown user raises KeyError.""" + mock_request.headers = {"comfy-user": "unknown_user"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError, match="Unknown user"): + user_manager.get_request_user_id(mock_request) + + +class TestGetRequestUserFilepath: + """Tests for get_request_user_filepath() - 2nd defense layer. + + Verifies: + - Returns None when get_public_user_directory() returns None (System User) + - Acts as backup defense if 1st layer is bypassed + """ + + def test_system_user_returns_none(self, user_manager, mock_request, mock_user_directory): + """Test System User returns None (structural blocking).""" + # First, we need to mock get_request_user_id to return System User + # But actually, get_request_user_id will raise KeyError first + # So we test via get_public_user_directory returning None + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + # Patch get_public_user_directory to return None for testing + with patch.object(folder_paths, 'get_public_user_directory', return_value=None): + result = user_manager.get_request_user_filepath(mock_request, "test.txt") + assert result is None + + def test_normal_user_gets_path(self, user_manager, mock_request, mock_user_directory): + """Test normal user gets valid filepath.""" + mock_request.headers = {"comfy-user": "default"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + path = user_manager.get_request_user_filepath(mock_request, "test.txt") + assert path is not None + assert "default" in path + assert path.endswith("test.txt") + + +class TestAddUser: + """Tests for add_user() - 3rd defense layer (creation-time blocking). 
+ + Verifies: + - System User name (__ prefix) creation is rejected with ValueError + - Sanitized usernames that become System User are also rejected + """ + + def test_system_user_prefix_name_raises(self, user_manager): + """Test System User prefix in name raises ValueError.""" + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__system") + + def test_system_user_prefix_cache_raises(self, user_manager): + """Test System User cache prefix raises ValueError.""" + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__cache") + + def test_sanitized_system_user_prefix_raises(self, user_manager): + """Test sanitized name becoming System User prefix raises ValueError (bypass prevention).""" + # "__test" directly starts with System User prefix + with pytest.raises(ValueError, match="System User prefix not allowed"): + user_manager.add_user("__test") + + def test_normal_user_creation(self, user_manager, mock_user_directory): + """Test normal user creation works.""" + user_id = user_manager.add_user("Normal User") + assert user_id is not None + assert not user_id.startswith("__") + assert "Normal-User" in user_id or "Normal_User" in user_id + + def test_empty_name_raises(self, user_manager): + """Test empty name raises ValueError.""" + with pytest.raises(ValueError, match="username not provided"): + user_manager.add_user("") + + def test_whitespace_only_raises(self, user_manager): + """Test whitespace-only name raises ValueError.""" + with pytest.raises(ValueError, match="username not provided"): + user_manager.add_user(" ") + + +class TestDefenseLayers: + """Integration tests for all three defense layers. + + Verifies: + - Each defense layer blocks System Users independently + - System User bypass is impossible through any layer + """ + + def test_layer1_get_request_user_id(self, user_manager, mock_request): + """Test 1st defense layer blocks System Users.""" + mock_request.headers = {"comfy-user": "__system"} + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + with pytest.raises(KeyError): + user_manager.get_request_user_id(mock_request) + + def test_layer2_get_public_user_directory(self): + """Test 2nd defense layer blocks System Users.""" + result = folder_paths.get_public_user_directory("__system") + assert result is None + + def test_layer3_add_user(self, user_manager): + """Test 3rd defense layer blocks System User creation.""" + with pytest.raises(ValueError): + user_manager.add_user("__system") diff --git a/tests/unit/folder_paths_test/system_user_test.py b/tests/unit/folder_paths_test/system_user_test.py new file mode 100644 index 000000000..cd46459f1 --- /dev/null +++ b/tests/unit/folder_paths_test/system_user_test.py @@ -0,0 +1,206 @@ +"""Tests for System User Protection in folder_paths.py + +Tests cover: +- get_system_user_directory(): Internal API for custom nodes to access System User directories +- get_public_user_directory(): HTTP endpoint access with System User blocking +- Backward compatibility: Existing APIs unchanged +- Security: Path traversal and injection prevention +""" + +import pytest +import os +import tempfile + +from folder_paths import ( + get_system_user_directory, + get_public_user_directory, + get_user_directory, + set_user_directory, +) + + +@pytest.fixture(scope="module") +def mock_user_directory(): + """Create a temporary user directory for testing.""" + with tempfile.TemporaryDirectory() as temp_dir: + original_dir = 
get_user_directory() + set_user_directory(temp_dir) + yield temp_dir + set_user_directory(original_dir) + + +class TestGetSystemUserDirectory: + """Tests for get_system_user_directory() - internal API for System User directories. + + Verifies: + - Custom nodes can access System User directories via internal API + - Input validation prevents path traversal attacks + """ + + def test_default_name(self, mock_user_directory): + """Test default 'system' name.""" + path = get_system_user_directory() + assert path.endswith("__system") + assert mock_user_directory in path + + def test_custom_name(self, mock_user_directory): + """Test custom system user name.""" + path = get_system_user_directory("cache") + assert path.endswith("__cache") + assert "__cache" in path + + def test_name_with_underscore(self, mock_user_directory): + """Test name with underscore in middle.""" + path = get_system_user_directory("my_cache") + assert "__my_cache" in path + + def test_empty_name_raises(self): + """Test empty name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + get_system_user_directory("") + + def test_none_name_raises(self): + """Test None name raises ValueError.""" + with pytest.raises(ValueError, match="cannot be empty"): + get_system_user_directory(None) + + def test_name_starting_with_underscore_raises(self): + """Test name starting with underscore raises ValueError.""" + with pytest.raises(ValueError, match="should not start with underscore"): + get_system_user_directory("_system") + + def test_path_traversal_raises(self): + """Test path traversal attempt raises ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("../escape") + + def test_path_traversal_middle_raises(self): + """Test path traversal in middle raises ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("system/../other") + + def test_special_chars_raise(self): + """Test special characters raise ValueError (security).""" + with pytest.raises(ValueError, match="Invalid system user name"): + get_system_user_directory("system!") + + def test_returns_absolute_path(self, mock_user_directory): + """Test returned path is absolute.""" + path = get_system_user_directory("test") + assert os.path.isabs(path) + + +class TestGetPublicUserDirectory: + """Tests for get_public_user_directory() - HTTP endpoint access with System User blocking. 
+
+    Verifies:
+    - System Users (__ prefix) return None, blocking HTTP access
+    - Public Users get valid paths
+    - New endpoints using this function are automatically protected
+    """
+
+    def test_normal_user(self, mock_user_directory):
+        """Test normal user returns valid path."""
+        path = get_public_user_directory("default")
+        assert path is not None
+        assert "default" in path
+        assert mock_user_directory in path
+
+    def test_system_user_returns_none(self):
+        """Test System User (__ prefix) returns None - blocks HTTP access."""
+        assert get_public_user_directory("__system") is None
+
+    def test_system_user_cache_returns_none(self):
+        """Test System User cache returns None."""
+        assert get_public_user_directory("__cache") is None
+
+    def test_empty_user_returns_none(self):
+        """Test empty user returns None."""
+        assert get_public_user_directory("") is None
+
+    def test_none_user_returns_none(self):
+        """Test None user returns None."""
+        assert get_public_user_directory(None) is None
+
+    def test_header_injection_returns_none(self):
+        """Test header injection attempt returns None (security)."""
+        assert get_public_user_directory("__system\r\nX-Injected: true") is None
+
+    def test_null_byte_injection_not_blocked_by_prefix_check(self):
+        """Test null byte injection handling (security)."""
+        # The startswith("__") check happens before any path operations, and
+        # "user\x00__system" does not start with "__", so it is not blocked here;
+        # filesystem-level validation of the resulting path is the caller's job.
+        result = get_public_user_directory("user\x00__system")
+        assert result is not None
+
+    def test_path_traversal_attempt(self, mock_user_directory):
+        """Test path traversal attempt handling."""
+        # This function only enforces the reserved "__" prefix; path traversal
+        # validation happens in user_manager, so a non-System-User name still
+        # yields a path here.
+        path = get_public_user_directory("../../../etc/passwd")
+        assert path is not None
+
+    def test_returns_absolute_path(self, mock_user_directory):
+        """Test returned path is absolute."""
+        path = get_public_user_directory("testuser")
+        assert path is not None
+        assert os.path.isabs(path)
+
+
+class TestBackwardCompatibility:
+    """Tests for backward compatibility with existing APIs.
+
+    Verifies:
+    - get_user_directory() API unchanged
+    - Existing user data remains accessible
+    """
+
+    def test_get_user_directory_unchanged(self, mock_user_directory):
+        """Test get_user_directory() still works as before."""
+        user_dir = get_user_directory()
+        assert user_dir is not None
+        assert os.path.isabs(user_dir)
+        assert user_dir == mock_user_directory
+
+    def test_existing_user_accessible(self, mock_user_directory):
+        """Test existing users can access their directories."""
+        path = get_public_user_directory("default")
+        assert path is not None
+        assert "default" in path
+
+
+class TestEdgeCases:
+    """Tests for edge cases in System User detection.
+ + Verifies: + - Only __ prefix is blocked (not _, not middle __) + - Bypass attempts are prevented + """ + + def test_prefix_only(self): + """Test prefix-only string is blocked.""" + assert get_public_user_directory("__") is None + + def test_single_underscore_allowed(self): + """Test single underscore prefix is allowed (not System User).""" + path = get_public_user_directory("_system") + assert path is not None + assert "_system" in path + + def test_triple_underscore_blocked(self): + """Test triple underscore is blocked (starts with __).""" + assert get_public_user_directory("___system") is None + + def test_underscore_in_middle_allowed(self): + """Test underscore in middle is allowed.""" + path = get_public_user_directory("my__system") + assert path is not None + assert "my__system" in path + + def test_leading_space_allowed(self): + """Test leading space + prefix is allowed (doesn't start with __).""" + path = get_public_user_directory(" __system") + assert path is not None diff --git a/tests/unit/folder_paths_test/test_folder_paths_types.py b/tests/unit/folder_paths_test/test_folder_paths_types.py new file mode 100644 index 000000000..1dd710f3f --- /dev/null +++ b/tests/unit/folder_paths_test/test_folder_paths_types.py @@ -0,0 +1,23 @@ + +import pytest +from comfy.cmd import folder_paths + +def test_folder_paths_interface_sanity(): + """ + Basic sanity check to ensure functions added to folder_paths.pyi exist in folder_paths.py at runtime. + """ + # Check for functions recently added/modified + assert hasattr(folder_paths, "get_system_user_directory"), "get_system_user_directory missing from runtime" + assert hasattr(folder_paths, "get_public_user_directory"), "get_public_user_directory missing from runtime" + assert hasattr(folder_paths, "get_input_directory"), "get_input_directory missing from runtime" + + # Check variables + assert hasattr(folder_paths, "extension_mimetypes_cache"), "extension_mimetypes_cache missing from runtime" + + # Minimal signature check (can call them with defaults if possible, but some might require setup) + # get_input_directory has a default now + # We might not be able to call it if it depends on execution context not being set up, + # but we can check if it is callable. 
+ assert callable(folder_paths.get_input_directory) + assert callable(folder_paths.get_system_user_directory) + assert callable(folder_paths.get_public_user_directory) diff --git a/tests/unit/prompt_server_test/system_user_endpoint_test.py b/tests/unit/prompt_server_test/system_user_endpoint_test.py new file mode 100644 index 000000000..22ac00af9 --- /dev/null +++ b/tests/unit/prompt_server_test/system_user_endpoint_test.py @@ -0,0 +1,375 @@ +"""E2E Tests for System User Protection HTTP Endpoints + +Tests cover: +- HTTP endpoint blocking: System Users cannot access /userdata (GET, POST, DELETE, move) +- User creation blocking: System User names cannot be created via POST /users +- Backward compatibility: Public Users work as before +- Custom node scenario: Internal API works while HTTP is blocked +- Structural security: get_public_user_directory() provides automatic protection +""" + +import pytest +import os +from aiohttp import web +from app.user_manager import UserManager +from unittest.mock import patch +import folder_paths + + +@pytest.fixture +def mock_user_directory(tmp_path): + """Create a temporary user directory.""" + original_dir = folder_paths.get_user_directory() + folder_paths.set_user_directory(str(tmp_path)) + yield tmp_path + folder_paths.set_user_directory(original_dir) + + +@pytest.fixture +def user_manager_multi_user(mock_user_directory): + """Create UserManager in multi-user mode.""" + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + um = UserManager() + # Add test users + um.users = {"default": "default", "test_user_123": "Test User"} + yield um + + +@pytest.fixture +def app_multi_user(user_manager_multi_user): + """Create app with multi-user mode enabled.""" + app = web.Application() + routes = web.RouteTableDef() + user_manager_multi_user.add_routes(routes) + app.add_routes(routes) + return app + + +class TestSystemUserEndpointBlocking: + """E2E tests for System User blocking on all HTTP endpoints. + + Verifies: + - GET /userdata blocked for System Users + - POST /userdata blocked for System Users + - DELETE /userdata blocked for System Users + - POST /userdata/.../move/... blocked for System Users + """ + + @pytest.mark.asyncio + async def test_userdata_get_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + GET /userdata with System User header should be blocked. + """ + # Create test directory for System User (simulating internal creation) + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + (system_user_dir / "secret.txt").write_text("sensitive data") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + # Attempt to access System User's data via HTTP + resp = await client.get( + "/userdata?dir=.", + headers={"comfy-user": "__system"} + ) + + # Should be blocked (403 Forbidden or similar error) + assert resp.status in [400, 403, 500], \ + f"System User access should be blocked, got {resp.status}" + + @pytest.mark.asyncio + async def test_userdata_post_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + POST /userdata with System User header should be blocked. 
+ """ + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/test.txt", + headers={"comfy-user": "__system"}, + data=b"malicious content" + ) + + assert resp.status in [400, 403, 500], \ + f"System User write should be blocked, got {resp.status}" + + # Verify no file was created + assert not (mock_user_directory / "__system" / "test.txt").exists() + + @pytest.mark.asyncio + async def test_userdata_delete_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + DELETE /userdata with System User header should be blocked. + """ + # Create a file in System User directory + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + secret_file = system_user_dir / "secret.txt" + secret_file.write_text("do not delete") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.delete( + "/userdata/secret.txt", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User delete should be blocked, got {resp.status}" + + # Verify file still exists + assert secret_file.exists() + + @pytest.mark.asyncio + async def test_v2_userdata_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + GET /v2/userdata with System User header should be blocked. + """ + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/v2/userdata", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User v2 access should be blocked, got {resp.status}" + + @pytest.mark.asyncio + async def test_move_userdata_blocks_system_user( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + POST /userdata/{file}/move/{dest} with System User header should be blocked. + """ + system_user_dir = mock_user_directory / "__system" + system_user_dir.mkdir() + (system_user_dir / "source.txt").write_text("sensitive data") + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/source.txt/move/dest.txt", + headers={"comfy-user": "__system"} + ) + + assert resp.status in [400, 403, 500], \ + f"System User move should be blocked, got {resp.status}" + + # Verify source file still exists (move was blocked) + assert (system_user_dir / "source.txt").exists() + + +class TestSystemUserCreationBlocking: + """E2E tests for blocking System User name creation via POST /users. 
+ + Verifies: + - POST /users returns 400 for System User name (not 500) + """ + + @pytest.mark.asyncio + async def test_post_users_blocks_system_user_name( + self, aiohttp_client, app_multi_user + ): + """POST /users with System User name should return 400 Bad Request.""" + client = await aiohttp_client(app_multi_user) + + resp = await client.post( + "/users", + json={"username": "__system"} + ) + + assert resp.status == 400, \ + f"System User creation should return 400, got {resp.status}" + + @pytest.mark.asyncio + async def test_post_users_blocks_system_user_prefix_variations( + self, aiohttp_client, app_multi_user + ): + """POST /users with any System User prefix variation should return 400 Bad Request.""" + client = await aiohttp_client(app_multi_user) + + system_user_names = ["__system", "__cache", "__config", "__anything"] + + for name in system_user_names: + resp = await client.post("/users", json={"username": name}) + assert resp.status == 400, \ + f"System User name '{name}' should return 400, got {resp.status}" + + +class TestPublicUserStillWorks: + """E2E tests for backward compatibility - Public Users should work as before. + + Verifies: + - Public Users can access their data via HTTP + - Public Users can create files via HTTP + """ + + @pytest.mark.asyncio + async def test_public_user_can_access_userdata( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + Public Users should still be able to access their data. + """ + # Create test directory for Public User + user_dir = mock_user_directory / "default" + user_dir.mkdir() + test_dir = user_dir / "workflows" + test_dir.mkdir() + (test_dir / "test.json").write_text('{"test": true}') + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/userdata?dir=workflows", + headers={"comfy-user": "default"} + ) + + assert resp.status == 200 + data = await resp.json() + assert "test.json" in data + + @pytest.mark.asyncio + async def test_public_user_can_create_files( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + Public Users should still be able to create files. + """ + # Create user directory + user_dir = mock_user_directory / "default" + user_dir.mkdir() + + client = await aiohttp_client(app_multi_user) + + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.post( + "/userdata/newfile.txt", + headers={"comfy-user": "default"}, + data=b"user content" + ) + + assert resp.status == 200 + assert (user_dir / "newfile.txt").exists() + + +class TestCustomNodeScenario: + """Tests for custom node use case: internal API access vs HTTP blocking. + + Verifies: + - Internal API (get_system_user_directory) works for custom nodes + - HTTP endpoint cannot access data created via internal API + """ + + def test_internal_api_can_access_system_user(self, mock_user_directory): + """ + Internal API (get_system_user_directory) should work for custom nodes. 
+ """ + # Custom node uses internal API + system_path = folder_paths.get_system_user_directory("mynode_config") + + assert system_path is not None + assert "__mynode_config" in system_path + + # Can create and write to System User directory + os.makedirs(system_path, exist_ok=True) + config_file = os.path.join(system_path, "settings.json") + with open(config_file, "w") as f: + f.write('{"api_key": "secret"}') + + assert os.path.exists(config_file) + + @pytest.mark.asyncio + async def test_http_cannot_access_internal_data( + self, aiohttp_client, app_multi_user, mock_user_directory + ): + """ + HTTP endpoint cannot access data created via internal API. + """ + # Custom node creates data via internal API + system_path = folder_paths.get_system_user_directory("mynode_config") + os.makedirs(system_path, exist_ok=True) + with open(os.path.join(system_path, "secret.json"), "w") as f: + f.write('{"api_key": "secret"}') + + client = await aiohttp_client(app_multi_user) + + # Attacker tries to access via HTTP + with patch('app.user_manager.args') as mock_args: + mock_args.multi_user = True + resp = await client.get( + "/userdata/secret.json", + headers={"comfy-user": "__mynode_config"} + ) + + # Should be blocked + assert resp.status in [400, 403, 500] + + +class TestStructuralSecurity: + """Tests for structural security pattern. + + Verifies: + - get_public_user_directory() automatically blocks System Users + - New endpoints using this function are automatically protected + """ + + def test_get_public_user_directory_blocks_system_user(self): + """ + Any code using get_public_user_directory() is automatically protected. + """ + # This is the structural security - any new endpoint using this function + # will automatically block System Users + assert folder_paths.get_public_user_directory("__system") is None + assert folder_paths.get_public_user_directory("__cache") is None + assert folder_paths.get_public_user_directory("__anything") is None + + # Public Users work + assert folder_paths.get_public_user_directory("default") is not None + assert folder_paths.get_public_user_directory("user123") is not None + + def test_structural_security_pattern(self, mock_user_directory): + """ + Demonstrate the structural security pattern for new endpoints. + + Any new endpoint should follow this pattern: + 1. Get user from request + 2. Use get_public_user_directory() - automatically blocks System Users + 3. If None, return error + """ + def new_endpoint_handler(user_id: str) -> str | None: + """Example of how new endpoints should be implemented.""" + user_path = folder_paths.get_public_user_directory(user_id) + if user_path is None: + return None # Blocked + return user_path + + # System Users are automatically blocked + assert new_endpoint_handler("__system") is None + assert new_endpoint_handler("__secret") is None + + # Public Users work + assert new_endpoint_handler("default") is not None diff --git a/tests/unit/test_cli_args_types_sync.py b/tests/unit/test_cli_args_types_sync.py new file mode 100644 index 000000000..8dee0e065 --- /dev/null +++ b/tests/unit/test_cli_args_types_sync.py @@ -0,0 +1,50 @@ + +import pytest +import sys +from unittest.mock import patch +from comfy import cli_args +from comfy import cli_args_types + +def test_cli_args_types_completeness(): + """ + Verify that cli_args_types.Configuration matches the actual arguments defined in cli_args. 
+
+    """
+    # Parse with empty args to obtain the actual configuration defaults
+    parser = cli_args._create_parser()
+    with patch.object(parser, 'parse_known_args_with_config_files', return_value=(parser.parse_known_args([])[0], [], [])):
+        actual_config = cli_args._parse_args(parser, args_parsing=True)
+
+    # Get the type definition
+    type_config = cli_args_types.Configuration()
+
+    actual_keys = set(actual_config.keys())
+    type_keys = set(type_config.keys())
+
+    # Every key in the actual config must be declared in the type definition
+    missing_in_types = actual_keys - type_keys
+    assert not missing_in_types, f"Keys present in actual config but missing in types: {missing_in_types}"
+
+    # Extra keys in the type definition are tolerated for now, since the stub may
+    # carry legacy or helper fields; they could be surfaced as a warning instead:
+    # extra_in_types = type_keys - actual_keys
+    # if extra_in_types:
+    #     print(f"WARNING: Keys in types but not in actual: {extra_in_types}")
+
+    # Verify recently added fields exist
+    assert hasattr(type_config, "disable_auto_launch")
+    assert hasattr(type_config, "cache_ram")
+    assert hasattr(type_config, "enable_manager")
+    assert hasattr(type_config, "disable_manager_ui")
+    assert hasattr(type_config, "enable_manager_legacy_ui")
+    assert hasattr(type_config, "disable_async_offload")
+    assert hasattr(type_config, "disable_pinned_memory")
+
+    # async_offload is annotated as Optional[int]; Configuration() defaults it to None
+    assert type_config.async_offload is None
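
Taken together, the suites above pin down a small contract for the System User protection in folder_paths. As a reading aid, here is a minimal sketch of that contract reconstructed purely from the assertions in these tests; the `_sketch` function names, the explicit `user_root` parameter, and the character whitelist are illustrative assumptions, not the actual ComfyUI implementation.

```python
# Illustrative sketch only: behavior inferred from the tests above, not the real folder_paths code.
import os
import re

SYSTEM_USER_PREFIX = "__"  # assumed reserved prefix exercised by the tests
_VALID_SYSTEM_NAME = re.compile(r"^[A-Za-z0-9][A-Za-z0-9_]*$")  # assumed allowed charset


def get_public_user_directory_sketch(user_id, user_root):
    """HTTP-facing lookup: returns a path for Public Users, None for System Users."""
    if not user_id or user_id.startswith(SYSTEM_USER_PREFIX):
        return None  # structural blocking: reserved "__" names never resolve to a path
    return os.path.abspath(os.path.join(user_root, user_id))


def get_system_user_directory_sketch(name="system", user_root="."):
    """Internal-only lookup: returns the reserved "__<name>" directory for custom nodes."""
    if not name:
        raise ValueError("system user name cannot be empty")
    if name.startswith("_"):
        raise ValueError("system user name should not start with underscore")
    if not _VALID_SYSTEM_NAME.match(name):
        raise ValueError(f"Invalid system user name: {name!r}")
    return os.path.abspath(os.path.join(user_root, SYSTEM_USER_PREFIX + name))


# e.g. get_public_user_directory_sketch("__cache", "/tmp/users")  -> None (blocked)
#      get_public_user_directory_sketch("default", "/tmp/users")  -> ".../tmp/users/default"
#      get_system_user_directory_sketch("cache", "/tmp/users")    -> ".../tmp/users/__cache"
```

The point of the split is that any new HTTP endpoint which resolves user paths through the public helper inherits the `__` blocking automatically, while custom nodes keep a separate, validated internal entry point.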